0%

论文阅读 [精读]-TORCH.FX: PRACTICAL PROGRAM CAPTURE AND TRANSFORMATION FOR DEEP LEARNING IN PYTHON

我理解这篇论文就是 torch.fx 的论文,作者是站在设计 torch.fx 的角度思考 “我们为什么要这么做”,把他们的一系列实现整理成了论文发了出来。

摘要

摘要部分就是说很多 python 框架虽然使用 eager execution 提升了易用性,但是在真实的落地场景中,用户其实希望对模型做性能优化,可视化,分析和硬件调优。为了满足这个 gap,作者设计了 Torch.fx,一个纯 python 的程序捕捉和变换的库。

Eager execution: 其实就是 define-by-run,可以让用户直接用 high-level 编程语言去定义计算图,可自动进行求导,进而训练和预测。非常方便开发

Introduction

先区分了以下两种模式:

  • define-and-run:早期的架构,定义很多 API。用户来操作这些 api 去构建一些图的 IR,进而框架去进行图优化、算子融合等编译优化。问题在于,用户需要同时懂 host language 和一种运行时的 domain-Specific 的语言,后者尝尝很麻烦。比如 python 的 pdb debug 系统。
  • define-by-run(eager mode):PyTorch 或者 TensorFlow eager 等支持的。直接去翻译高层语言,去支持自动求导等过程,方便开发。但这时,就只能用一些 just-in-time 的编译方法,而不是 ahead-of-time 里有的很多算子融合等技术,因此效果不如上面的好。

事实上,只是使用已知的计算图的一些 IR 结构就能方便的做一些图优化的手段。想要使用上他们,就需要一个 eager mode 框架也可以去捕捉程序结构,进而获取上面的信息。

其实,TorchScript 通过捕捉 python 程序所有的 AST 是可以做这件事的。但问题在于,它捕捉所有的程序,cost 太大,其实我们只需要捕捉计算图 (nn.module)。因为优化手段基本上都局限在 DAG 图的范围内,因此这个捕捉手段是可以化简的。

因此作者实现了 torch.fx, 聚焦于 dag 图的表示,设计了一系列用户接口来让图满足这种格式,同时对各种图优化进行了支持

  • 分析了图优化之于 AI 的重要性
  • 实现了纯 python 的满足上述条件的库
  • 一个只有 6 条指令的 IR,便于表示 dag 图
  • 可以进行图变化,进而生成优化后的代码返回给 host language
  • 分析了 torch.fx 的使用实例

BACKGROUND

这一部分,提出了 eager-mode 的程序需要的几个要素

Capturing Program Structure

这一部分对比了已有的框架们获取程序表示的几种方式:

  • trace 每一个操作的情况,需要用户给出样例输入。PyTorch’s jit.trace 采用。
  • trace 每个的抽象的值而不是样例输入,不用用户给出样例输入。Tensor- Flow’s tf.function 采用
    • 上两种的问题都是,只能支持 python 功能的一个子集,并且错误的追踪不精确
    • 另一个问题是只能用一些定义的 API,而不是直接的 python 语言做开发。
  • 允许让用户直接用 embedded programming language 开发,如 TorchScript。
    • 问题是这种框架实现特别复杂,而且,还是只能支持 python 功能的一个子集 (比上面更大的子集)
  • 一些框架提供 python 到别的语言的接口,然后用别的语言的方式做优化,如 Swift for TensorFlow。
    • 问题是需要退出 python 生态。但这个生态,包含里面的各种库,只有 python 有好的实现。

Specializing Programs

这一部分讲特化,其实是关于 shape 的定义。

这个表达式在 python 里可以后推断各种 type 的输入,但对于计算图它需要比较确定。

  • PyTorch’s jit.trace:只支持固定形状 (样例输入) 的输入,别的就炸了
  • LazyTensor:可以 just-in-time 进行 tracing。更灵活,但由于每次输入需要重新 trace,代价太大
  • JAX’s jit combinator:每次重新 capture 不是必要的,可以根据输入决定是否需要重新 tracing。但可能会导致无法预测的各种运行时的问题。

Intermediate Representation Design

这一部分讲 IR 的设计,总体而言,IR 复杂,会使得优化效果更好,但也更难实现。

  • Language:有些框架是跨编程语言的,比如 PyTorch’s JIT and MXNet 用 c++ 作为他们的 data structure。运行时表现更好,更好序列化,但是需要用户在开发 python 时有额外的学习成本

  • Control flow:很多 network 不需要 control flow,只要一系列 if-statements or loops 操作,称作 basic block program。

    • basic block program 可以直接表示成一个 DAG 图,MLP,CNN,transformer 都符合这个情况
    • RNN 不符合,因为什么时候停止需要动态推断。因此每一次 state 的内部都是一个 basic block,总体有一个控制流
  • State:这个 state 其实就是指模型参数。以为 pytorch 支持 aliasing 和 mutation 的语法,因此需要检查对 state 的操作是合法的。
    • TorchScript IR 支持别名分析、指代分析,对 IR 做变换。代价很大,因为每一个表达式都要计算别名、指代消解。而且很多函数可能有改全局变量这种阴间操作。这种方法会降低优化的能力,但用户友好。
    • JAX 使用一些别的框架来做这件事,比如把模型用 FLAX 封装一遍。任何一种变换都比较复杂,因为需要跨框架的交互。

DESIGN PRINCIPLES

根据上面的分析,提出了 torch.fx 的几个设计理念:

  • 更关注于已有的、经典的模型,不去执着于 long-tailed、复杂的实现
  • 更多使用开发者熟悉的 python 的数据结构、已支持的算子
  • 让程序的捕获具有灵活性,方便用户去实现的自己的 long-tailed 需求

TORCH.FX OVERVIEW

前面作者的思考和分析讲完了,这一部分讲作者到底怎么设计。

不在捕捉的时候特化计算图,而是在优化的时候做。

一个变换的例子如上图,直接用 python 语言书写图变换的方法

capture

  • GraphModule 是 torch. nn. Module 的子类,其 forward 方法运行捕获的 Graph。我们可以打印此图的 Nodes 以查看捕获的 IR。
  • placeholder 节点表示输入,单个 output 节点表示 Graph 的结果。
  • call_function 节点直接引用了它将调用的 Python 函数。
  • call_method 节点直接调用其第一个参数的方法。
  • Graph 被重组为 Python 代码(traced.code)以供调用。

捕捉的例子如下:

IR

设计了只有 6 个语句的 IR,非常简单

  • torch. fx 的中间表示(IR)由一个 Python 数据结构 Graph 来做的。

  • 这个 Graph 实际上是一个包含一系列 Node 的线性表。

    • 节点有一个字符串操作码 opcode,描述节点代表什么类型的操作(操作码的语义可以在附录 A.1 中找到)。
    • 节点有一个关联的目标,它是调用节点(call_Module、call_function 和 call_method)的调用目标。
    • 节点有 args 和 kwargs,在 trace 期间它们一起表示 Python 调用约定中的目标参数(每个 opcode 对应的 args 和 kwargs 的语义可以在附录 A.2 中找到)。
    • 节点之间的数据依赖关系表示为 args 和 kwargs 中对其他节点的引用。
  • torch. fx 将程序的状态存储在 GraphModule 类中。

    • GraphModule 是转换程序的容器,暴露转换后生成的代码,并提供 nn. Module 类似的参数管理 APIs。
    • GraphModule 可以在任何可以使用普通的 nn. Module 的地方使用,以提供转换后的代码和 PyTorch 生态系统的其余部分之间的互操作性。

Source-to-Source Transformation

transformation 的最后一步是重新从 IR 翻译回 python 语言,而不是到其他生态系统。同时也可以继续进行 transformation。

CASE STUDIES AND EVALUATION

大概就是说,torch.fx 是做 IR 抽象还原的,所以 IR 很重要,比较了一下 IR 可读性、简便性。

左右一对比,确实可读性、简便性胜出。

还做了个实验比较效果,比如算子融合之类的,发现确实都可以做。

思考

  • 感觉这个就是让工程师可以很方便的 “用 python 优化 python”,可以自己定义自己的 pass,还能非常方便的再变换回 python 语言
  • 我其实有点好奇这个东西的作用在哪里,后来看了一下是方便开发者可以直接把搞好的模型出来做 python-python 的转换:
    • 比如把所有的 op 都换一下
    • 比如不改原始代码,直接提取 pretrain model 的中间层的输出来做下游任务
Powered By Valine
v1.5.2