这篇文章主要介绍“PyTorch中tensor.detach()和tensor.data的区别有哪些”的相关知识,小编通过实际案例向大家展示操作过程,操作方法简单快捷,实用性强,希望这篇“PyTorch中tensor.detach()和tensor.data的区别有哪些”文章能帮助大家解决问题。
PyTorch中 tensor.detach() 和 tensor.data 的区别
以 a.data, a.detach() 为例:
两种方法均会返回和a相同的tensor,且与原tensor a 共享数据,一方改变,则另一方也改变。
所起的作用均是将变量tensor从原有的计算图中分离出来,分离所得tensor的requires_grad = False。
不同点:
data是一个属性,.detach()是一个方法;data是不安全的,.detach()是安全的;
>>> a = torch.tensor([1,2,3.], requires_grad =True) >>> out = a.sigmoid() >>> c = out.data >>> c.zero_() tensor([ 0., 0., 0.]) >>> out # out的数值被c.zero_()修改 tensor([ 0., 0., 0.]) >>> out.sum().backward() # 反向传播 >>> a.grad # 这个结果很严重的错误,因为out已经改变了 tensor([ 0., 0., 0.])
为什么.data是不安全的?
这是因为,当我们修改分离后的tensor,从而导致原tensora发生改变。PyTorch的自动求导Autograd是无法捕捉到这种变化的,会依然按照求导规则进行求导,导致计算出错误的导数值。
其风险性在于,如果我在某一处修改了某一个变量,求导的时候也无法得知这一修改,可能会在不知情的情况下计算出错误的导数值。
>>> a = torch.tensor([1,2,3.], requires_grad =True) >>> out = a.sigmoid() >>> c = out.detach() >>> c.zero_() tensor([ 0., 0., 0.]) >>> out # out的值被c.zero_()修改 !! tensor([ 0., 0., 0.]) >>> out.sum().backward() # 需要原来out得值,但是已经被c.zero_()覆盖了,结果报错 RuntimeError: one of the variables needed for gradient computation has been modified by an
那么.detach()为什么是安全的?
使用.detach()的好处在于,若是出现上述情况,Autograd可以检测出某一处变量已经发生了改变,进而以如下形式报错,从而避免了错误的求导。
RuntimeError: one of the variables needed for gradient computation has been modified by an inplace operation. Hint: enable anomaly detection to find the operation that failed to compute its gradient, with torch.autograd.set_detect_anomaly(True).
从以上可以看出,是在前向传播的过程中使用就地操作(In-place operation)导致了这一问题,那么就地操作是什么呢?
补充:pytorch中的detach()函数的作用
detach()
官方文档中,对这个方法是这么介绍的。
返回一个新的从当前图中分离的 Variable。
返回的 Variable 永远不会需要梯度 如果 被 detach
的Variable volatile=True, 那么 detach 出来的 volatile 也为 True
还有一个注意事项,即:返回的 Variable 和 被 detach 的Variable 指向同一个 tensor
import torch from torch.nn import init from torch.autograd import Variable t1 = torch.FloatTensor([1., 2.]) v1 = Variable(t1) t2 = torch.FloatTensor([2., 3.]) v2 = Variable(t2) v3 = v1 + v2 v3_detached = v3.detach() v3_detached.data.add_(t1) # 修改了 v3_detached Variable中 tensor 的值 print(v3, v3_detached) # v3 中tensor 的值也会改变
能用来干啥
可以对部分网络求梯度。
如果我们有两个网络 , 两个关系是这样的 现在我们想用 来为B网络的参数来求梯度,但是又不想求A网络参数的梯度。我们可以这样:
# y=A(x), z=B(y) 求B中参数的梯度,不求A中参数的梯度 y = A(x) z = B(y.detach()) z.backward()