Pytorch 自动微分
发布日期:2021-07-01 03:35:11
浏览次数:2
分类:技术文章
本文共 2649 字,大约阅读时间需要 8 分钟。
Tensor.requires_grad = True
记录对Tensor的所有操作,后序.backward()
自动计算所有梯度到.grad
属性
import torchx = torch.ones(2,2, requires_grad=True) # 默认是Falseprint(x)tensor([[1., 1.], [1., 1.]], requires_grad=True)
- 停止记录调用
.detach()
x.detach_()print(x.requires_grad) # False
.grad_fn
保存了创建张量的 Function 的引用
x = torch.ones(2,2, requires_grad=True)y = x + 2print(y)print(y.grad_fn)tensor([[3., 3.], [3., 3.]], grad_fn=)
z = y*y*3out = z.mean()print(z, out)tensor([[27., 27.], [27., 27.]], grad_fn=) tensor(27., grad_fn= )
# requires_grad 默认为 Falsea = torch.randn(2, 2)a = ((a*3)/(a-1))print(a.requires_grad) # Falseb = (a*a).sum()print(b.grad_fn) # Nonea.requires_grad_(True) # 设置为 Trueprint(a.requires_grad) # Trueb = (a*a).sum()print(b.grad_fn)#
backward()
后向传播
z = y*y*3y = x+2计算 d(out)/dx
o u t = 1 4 ( ∑ 3 ( x i + 2 ) 2 ) → d o u t d x i = 3 2 ( x i + 2 ) out = \frac{1}{4}(\sum3(x_i+2)^2) \rightarrow \frac{d_{out}}{dx_i} = \frac{3}{2}(x_i+2) out=41(∑3(xi+2)2)→dxidout=23(xi+2)
x i = 1 , d o u t / d x i = 4.5 x_i = 1, d_{out}/dx_i = 4.5 xi=1,dout/dxi=4.5out.backward()print(y.grad) # None, 为什么?是 Noneprint(x.grad)tensor([[4.5000, 4.5000], [4.5000, 4.5000]])
J = ( ∂ y 1 ∂ x 1 ⋯ ∂ y m ∂ x 1 ⋮ ⋱ ⋮ ∂ y 1 ∂ x n ⋯ ∂ y m ∂ x n ) J=\left(\begin{array}{ccc}\frac{\partial y_{1}}{\partial x_{1}} & \cdots & \frac{\partial y_{m}}{\partial x_{1}} \\ \vdots & \ddots & \vdots \\ \frac{\partial y_{1}}{\partial x_{n}} & \cdots & \frac{\partial y_{m}}{\partial x_{n}}\end{array}\right) J=⎝⎜⎛∂x1∂y1⋮∂xn∂y1⋯⋱⋯∂x1∂ym⋮∂xn∂ym⎠⎟⎞
- 当又使用了一个函数 l = g ( y ) l = g(y) l=g(y),v 是 l l l 对 y y y 的导数,链式求导相乘,得到 l l l 对 x x x 的导数 J ⋅ v = ( ∂ y 1 ∂ x 1 ⋯ ∂ y m ∂ x 1 ⋮ ⋱ ⋮ ∂ y 1 ∂ x n ⋯ ∂ y m ∂ x n ) ( ∂ l ∂ y 1 ⋮ ∂ l ∂ y m ) = ( ∂ l ∂ x 1 ⋮ ∂ l ∂ x n ) J \cdot v=\left(\begin{array}{ccc}\frac{\partial y_{1}}{\partial x_{1}} & \cdots & \frac{\partial y_{m}}{\partial x_{1}} \\ \vdots & \ddots & \vdots \\ \frac{\partial y_{1}}{\partial x_{n}} & \cdots & \frac{\partial y_{m}}{\partial x_{n}}\end{array}\right)\left(\begin{array}{c}\frac{\partial l}{\partial y_{1}} \\ \vdots \\ \frac{\partial l}{\partial y_{m}}\end{array}\right)=\left(\begin{array}{c}\frac{\partial l}{\partial x_{1}} \\ \vdots \\ \frac{\partial l}{\partial x_{n}}\end{array}\right) J⋅v=⎝⎜⎛∂x1∂y1⋮∂xn∂y1⋯⋱⋯∂x1∂ym⋮∂xn∂ym⎠⎟⎞⎝⎜⎛∂y1∂l⋮∂ym∂l⎠⎟⎞=⎝⎜⎛∂x1∂l⋮∂xn∂l⎠⎟⎞
上面代码改为:
v = torch.tensor(2, dtype=torch.float)out.backward(v)print(x.grad)# 梯度乘以了 2tensor([[9., 9.], [9., 9.]])
- 评估阶段可以使用
with torch.no_grad():
不需要梯度计算和更新
print(x.requires_grad) # Trueprint((x ** 2).requires_grad) # True# 取消梯度记录with torch.no_grad(): print((x ** 2).requires_grad) # False
转载地址:https://michael.blog.csdn.net/article/details/111657579 如侵犯您的版权,请留言回复原文章的地址,我们会给您删除此文章,给您带来不便请您谅解!
发表评论
最新留言
网站不错 人气很旺了 加油
[***.192.178.218]2024年05月04日 19时55分50秒
关于作者
喝酒易醉,品茶养心,人生如梦,品茶悟道,何以解忧?唯有杜康!
-- 愿君每日到此一游!
推荐文章
source insight快捷键及使用技巧
2019-05-01
映 射 ALT 键
2019-05-01
vim使用快捷键F4生成文件头注释、F5生成main函数模板、F6生成.h文件框架模板
2019-05-01
OV5620的视频驱动
2019-05-01
C++中两个类交叉定义或递归定义的解决办法
2019-05-01
记一次Hive 行转列 引起的GC overhead limit exceeded
2019-05-01
OpenGL ES八 - 交叉存取顶点数据
2019-05-01
crontab定时任务写法
2019-05-01
nginx: [emerg] unknown directive "if($remote_addr" in /usr/local/tools/nginx/conf/nginx.conf:57
2019-05-01
module pip has no attribute main问题解决
2019-05-01
LeetCode 134.Gas Station (加油站)
2019-05-01
Python之命名元组 (namedtuple)
2019-05-01
使用libpcap过滤arp
2019-05-01
[转帖]Robots.txt指南
2019-05-01
正则表达式简介(微软)--6.优先权顺序
2019-05-01
多用户与多租户的区别
2019-05-01
Python自动化运维 - day14 - JavaScript基础
2019-05-02
oracle保存小数点前为"0"的问题
2019-05-02