{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
"# 前向操作符重载自动微分实现\n", | ||
"\n", | ||
"在这篇文章里,ZOMI会介绍是怎么实现自动微分的,因为代码量非常小,也许你也可以写一个玩玩。前面的文章当中,已经把自动微分的原理深入浅出的讲了一下,也引用了非常多的论文。有兴趣的可以顺着综述A survey这篇深扒一下。\n", | ||
"\n", | ||
"-【自动微分原理】01. 原理介绍\n", | ||
"-【自动微分原理】02. 正反模式\n", | ||
"-【自动微分原理】03. 具体实现\n", | ||
"-【自动微分原理】04. 前向操作符重载实现\n", | ||
"\n", | ||
"# 前向自动微分原理\n", | ||
"\n", | ||
"了解自动微分的不同实现方式非常有用。在这里呢,我们将介绍主要的前向自动微分,通过Python这个高级语言来实现操作符重载。在正反向模式中的这篇的文章中,我们介绍了前向自动微分的基本数学原理。\n", | ||
"\n", | ||
"> 前向模式(Forward Automatic Differentiation,也叫做 tangent mode AD)或者前向累积梯度(前向模式)\n", | ||
"\n", | ||
"前向自动微分中,从计算图的起点开始,沿着计算图边的方向依次向前计算,最终到达计算图的终点。它根据自变量的值计算出计算图中每个节点的值 以及其导数值,并保留中间结果。一直得到整个函数的值和其导数值。整个过程对应于一元复合函数求导时从最内层逐步向外层求导。\n", | ||
"\n", | ||
"\n", | ||
"\n", | ||
"简单确实简单,可以总结前向自动微分关键步骤为:\n", | ||
"\n", | ||
"- 分解程序为一系列已知微分规则的基础表达式的组合\n", | ||
"- 根据已知微分规则给出各基础表达式的微分结果\n", | ||
"- 根据基础表达式间的数据依赖关系,使用链式法则将微分结果组合完成程序的微分结果\n", | ||
"\n", | ||
"而通过Python高级语言,进行操作符重载后的关键步骤其实也相类似:\n", | ||
"\n", | ||
"- 分解程序为一系列已知微分规则的基础表达式组合,并使用高级语言的重载操作\n", | ||
"- 在重载运算操作的过程中,根据已知微分规则给出各基础表达式的微分结果\n", | ||
"- 根据基础表达式间的数据依赖关系,使用链式法则将微分结果组合完成程序的微分结果\n", | ||
"\n", | ||
"# 具体实现\n", | ||
"\n", | ||
"首先呢,我们需要加载通用的numpy库,用于实际运算的,如果不用numpy,在python中也可以使用math来代替。" | ||
   ]
  },
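  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As a small worked example of these steps (an illustration added here, not from the original article), consider $y = x \\cdot \\sin(x)$ at $x = 2$ with the tangent seed $\\dot{x} = 1$. Decomposing into basic expressions and propagating values and tangents together gives:\n",
    "\n",
    "$$\n",
    "v_1 = \\sin(x), \\quad \\dot{v}_1 = \\cos(x)\\,\\dot{x} = \\cos(2)\n",
    "$$\n",
    "\n",
    "$$\n",
    "y = x \\cdot v_1, \\quad \\dot{y} = \\dot{x}\\,v_1 + x\\,\\dot{v}_1 = \\sin(2) + 2\\cos(2)\n",
    "$$\n",
    "\n",
    "Every intermediate node carries a (value, tangent) pair, which is exactly what the ADTangent class below stores."
   ]
  },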
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Forward-mode AD is also called tangent-mode AD, so we define a class named ADTangent. It is initialized with two parameters: x, the concrete input value, and dx, the value of the derivative with respect to the independent variable x.\n",
    "\n",
    "Note that, unlike source-to-source transformation, operator-overloading AD generally does not produce a symbolic formula for the derivative; it directly produces the final derivative value, which is why dx appears here."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "class ADTangent:\n",
    "\n",
    "    # x is the value of the independent variable; dx is the derivative\n",
    "    # of the expression with respect to that variable\n",
    "    def __init__(self, x, dx):\n",
    "        self.x = x\n",
    "        self.dx = dx\n",
    "\n",
    "    # Overload __str__ so that printing shows both the value and the derivative\n",
    "    def __str__(self):\n",
    "        context = f'value:{self.x:.4f}, grad:{self.dx}'\n",
    "        return context"
   ]
  },
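  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "A quick sanity check (an extra cell added here for illustration): constructing an ADTangent and printing it exercises __init__ and __str__."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# The seed dx=1 marks this as the variable we will differentiate with respect to\n",
    "v = ADTangent(x=2., dx=1)\n",
    "print(v)  # expected: value:2.0000, grad:1"
   ]
  },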
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Next comes the core of the implementation: the operator overloads. In the ADTangent class we overload the + operator through Python's special method __add__. It first checks whether the operand other is an ADTangent; if so, the x values of the two operands are added.\n",
    "\n",
    "The part worth noting is the computation of dx: because this is forward-mode AD, every forward computation is paired with the corresponding derivative computation. That derivative propagation is the core of the program, but rest assured, it only uses the most basic differentiation rules. Finally, the method returns a new ADTangent(x, dx) object."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "    # Defined inside the ADTangent class above; shown separately for exposition\n",
    "    def __add__(self, other):\n",
    "        if isinstance(other, ADTangent):\n",
    "            x = self.x + other.x\n",
    "            dx = self.dx + other.dx\n",
    "        elif isinstance(other, float):\n",
    "            x = self.x + other\n",
    "            dx = self.dx\n",
    "        else:\n",
    "            return NotImplemented\n",
    "        return ADTangent(x, dx)"
   ]
  },
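  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "In dual-number notation (a framing added here for intuition) this is simply the sum rule: writing each operand as a value plus a tangent,\n",
    "\n",
    "$$\n",
    "(x + \\dot{x}\\,\\epsilon) + (y + \\dot{y}\\,\\epsilon) = (x + y) + (\\dot{x} + \\dot{y})\\,\\epsilon,\n",
    "$$\n",
    "\n",
    "which is exactly the line dx = self.dx + other.dx above."
   ]
  },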
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Below, the subtraction, multiplication, log, and sin operations are overloaded in the same way. The forward overloads are straightforward and essentially follow the same pattern as the __add__ code discussed above."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "    # These methods also live inside the ADTangent class above\n",
    "    def __sub__(self, other):\n",
    "        if isinstance(other, ADTangent):\n",
    "            x = self.x - other.x\n",
    "            dx = self.dx - other.dx\n",
    "        elif isinstance(other, float):\n",
    "            x = self.x - other\n",
    "            dx = self.dx\n",
    "        else:\n",
    "            return NotImplemented\n",
    "        return ADTangent(x, dx)\n",
    "\n",
    "    def __mul__(self, other):\n",
    "        if isinstance(other, ADTangent):\n",
    "            x = self.x * other.x\n",
    "            # product rule: d(uv) = u dv + du v\n",
    "            dx = self.x * other.dx + self.dx * other.x\n",
    "        elif isinstance(other, float):\n",
    "            x = self.x * other\n",
    "            dx = self.dx * other\n",
    "        else:\n",
    "            return NotImplemented\n",
    "        return ADTangent(x, dx)\n",
    "\n",
    "    def log(self):\n",
    "        x = np.log(self.x)\n",
    "        # d(ln u) = du / u\n",
    "        dx = 1 / self.x * self.dx\n",
    "        return ADTangent(x, dx)\n",
    "\n",
    "    def sin(self):\n",
    "        x = np.sin(self.x)\n",
    "        # d(sin u) = cos(u) du\n",
    "        dx = self.dx * np.cos(self.x)\n",
    "        return ADTangent(x, dx)"
   ]
  },
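  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "One limitation of the overloads above: an expression such as 1.0 + x or 2.0 * x first calls float's own operator, which returns NotImplemented, and Python then looks for the reflected methods __radd__ and __rmul__ on ADTangent. A minimal sketch of how these could be added (hypothetical additions, not part of the original notebook):"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "    # Hypothetical additions, not in the original notebook: reflected\n",
    "    # operators so that float + ADTangent and float * ADTangent also work.\n",
    "    # Addition and multiplication are commutative, so we can reuse them:\n",
    "    __radd__ = __add__\n",
    "    __rmul__ = __mul__\n",
    "\n",
    "    def __rsub__(self, other):\n",
    "        # other - self, with other a plain float: d(c - u) = -du\n",
    "        return ADTangent(other - self.x, -self.dx)"
   ]
  },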
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Take equation (1) below as the example:\n",
    "\n",
    "$$\n",
    "f(x_1, x_2) = \\ln(x_1) + x_1 x_2 - \\sin(x_2) \\tag{1}\n",
    "$$\n",
    "\n",
    "Since the operators are overloaded on the ADTangent class, the independent variables x and y must be initialized as ADTangent instances; the function is then expressed as f = ADTangent.log(x) + x * y - ADTangent.sin(y).\n",
    "\n",
    "Because we want the derivative of f with respect to the variable x, we seed the tangents accordingly when initializing the data: dx of x is set to 1, while dx of y is set to 0."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 91,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "value:11.6521, grad:5.5\n"
     ]
    }
   ],
   "source": [
    "x = ADTangent(x=2., dx=1)\n",
    "y = ADTangent(x=5., dx=0)\n",
    "f = ADTangent.log(x) + x * y - ADTangent.sin(y)\n",
    "print(f)"
   ]
  },
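  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We can verify this output by hand:\n",
    "\n",
    "$$\n",
    "f(2, 5) = \\ln 2 + 2 \\cdot 5 - \\sin 5 \\approx 0.6931 + 10 + 0.9589 = 11.6521\n",
    "$$\n",
    "\n",
    "$$\n",
    "\\frac{\\partial f}{\\partial x_1} = \\frac{1}{x_1} + x_2 = \\frac{1}{2} + 5 = 5.5\n",
    "$$\n",
    "\n",
    "(note that $\\sin 5 \\approx -0.9589$, so subtracting it adds $0.9589$), matching the printed value and gradient exactly."
   ]
  },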
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "From the output, both the forward-computed function value and the derivative match the computational-graph trace shown earlier. Below is a comparison against PyTorch, followed by one against MindSpore. As you can see, our simple hand-rolled automatic differentiation produces the same results as PyTorch and MindSpore. Quite satisfying."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 65,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "tensor([11.6521], grad_fn=<SubBackward0>)\n",
      "tensor([5.5000])\n",
      "tensor([1.7163])\n"
     ]
    }
   ],
   "source": [
    "import torch\n",
    "# torch.autograd.Variable is a legacy API; torch.tensor(..., requires_grad=True)\n",
    "# is the modern equivalent and behaves the same here\n",
    "from torch.autograd import Variable\n",
    "\n",
    "x = Variable(torch.Tensor([2.]), requires_grad=True)\n",
    "y = Variable(torch.Tensor([5.]), requires_grad=True)\n",
    "f = torch.log(x) + x * y - torch.sin(y)\n",
    "f.backward()\n",
    "\n",
    "print(f)\n",
    "print(x.grad)\n",
    "print(y.grad)"
   ]
  },
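  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "PyTorch's reverse mode delivers both x.grad and y.grad from a single backward pass. Forward mode needs one pass per input variable: to obtain the derivative with respect to y with our ADTangent class, we simply swap the tangent seeds (an extra cell added here for illustration)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Seed y instead of x: dx=0 for x, dx=1 for y\n",
    "x = ADTangent(x=2., dx=0)\n",
    "y = ADTangent(x=5., dx=1)\n",
    "f = ADTangent.log(x) + x * y - ADTangent.sin(y)\n",
    "print(f)  # expected grad: x1 - cos(x2) = 2 - cos(5) ≈ 1.7163, matching y.grad above"
   ]
  },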
  {
   "cell_type": "code",
   "execution_count": 130,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[11.65207]\n",
      "5.5\n"
     ]
    }
   ],
   "source": [
    "import numpy as np\n",
    "import mindspore.nn as nn\n",
    "import mindspore.ops as ops\n",
    "from mindspore import Tensor\n",
    "\n",
    "class Fun(nn.Cell):\n",
    "    def __init__(self):\n",
    "        super(Fun, self).__init__()\n",
    "\n",
    "    def construct(self, x, y):\n",
    "        f = ops.log(x) + x * y - ops.sin(y)\n",
    "        return f\n",
    "\n",
    "x = Tensor(np.array([2.], np.float32))\n",
    "y = Tensor(np.array([5.], np.float32))\n",
    "f = Fun()(x, y)\n",
    "\n",
    "grad_all = ops.GradOperation()\n",
    "grad = grad_all(Fun())(x, y)\n",
    "\n",
    "print(f)\n",
    "print(grad[0])\n"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}