溫馨提示×

溫馨提示×

您好,登錄后才能下訂單哦!

密碼登錄×
登錄注冊×
其他方式登錄
點擊 登錄注冊 即表示同意《億速云用戶服務條款》

PyTorch中tensor.detach()和tensor.data的區別有哪些

發布時間:2023-04-07 09:50:11 來源:億速云 閱讀:137 作者:iii 欄目:開發技術

這篇文章主要介紹“PyTorch中tensor.detach()和tensor.data的區別有哪些”的相關知識,小編通過實際案例向大家展示操作過程,操作方法簡單快捷,實用性強,希望這篇“PyTorch中tensor.detach()和tensor.data的區別有哪些”文章能幫助大家解決問題。

PyTorch中 tensor.detach() 和 tensor.data 的區別

以 a.data, a.detach() 為例:
兩種方法均會返回和a相同的tensor,且與原tensor a 共享數據,一方改變,則另一方也改變。

所起的作用均是將變量tensor從原有的計算圖中分離出來,分離所得tensor的requires_grad = False。

不同點:

data是一個屬性,.detach()是一個方法;data是不安全的,.detach()是安全的;

>>> a = torch.tensor([1,2,3.], requires_grad =True)
>>> out = a.sigmoid()
>>> c = out.data
>>> c.zero_()
tensor([ 0., 0., 0.])

>>> out                   #  out的數值被c.zero_()修改
tensor([ 0., 0., 0.])

>>> out.sum().backward()  #  反向傳播
>>> a.grad                #  這個結果很嚴重的錯誤,因為out已經改變了
tensor([ 0., 0., 0.])

為什么.data是不安全的?

這是因為,當我們修改分離后的tensor,從而導致原tensora發生改變。PyTorch的自動求導Autograd是無法捕捉到這種變化的,會依然按照求導規則進行求導,導致計算出錯誤的導數值。

其風險性在于,如果我在某一處修改了某一個變量,求導的時候也無法得知這一修改,可能會在不知情的情況下計算出錯誤的導數值。

>>> a = torch.tensor([1,2,3.], requires_grad =True)
>>> out = a.sigmoid()
>>> c = out.detach()
>>> c.zero_()
tensor([ 0., 0., 0.])

>>> out                   #  out的值被c.zero_()修改 !!
tensor([ 0., 0., 0.])

>>> out.sum().backward()  #  需要原來out得值,但是已經被c.zero_()覆蓋了,結果報錯
RuntimeError: one of the variables needed for gradient
computation has been modified by an

那么.detach()為什么是安全的?

使用.detach()的好處在于,若是出現上述情況,Autograd可以檢測出某一處變量已經發生了改變,進而以如下形式報錯,從而避免了錯誤的求導。

RuntimeError: one of the variables needed for gradient computation has been modified by an inplace operation. Hint: enable anomaly detection to find the operation that failed to compute its gradient, with torch.autograd.set_detect_anomaly(True).

從以上可以看出,是在前向傳播的過程中使用就地操作(In-place operation)導致了這一問題,那么就地操作是什么呢?

補充:pytorch中的detach()函數的作用

detach()

官方文檔中,對這個方法是這么介紹的。

  • 返回一個新的從當前圖中分離的 Variable。

  • 返回的 Variable 永遠不會需要梯度 如果 被 detach

  • 的Variable volatile=True, 那么 detach 出來的 volatile 也為 True

  • 還有一個注意事項,即:返回的 Variable 和 被 detach 的Variable 指向同一個 tensor

import torch
from torch.nn import init
from torch.autograd import Variable
t1 = torch.FloatTensor([1., 2.])
v1 = Variable(t1)
t2 = torch.FloatTensor([2., 3.])
v2 = Variable(t2)
v3 = v1 + v2
v3_detached = v3.detach()
v3_detached.data.add_(t1) # 修改了 v3_detached Variable中 tensor 的值
print(v3, v3_detached)    # v3 中tensor 的值也會改變

能用來干啥

可以對部分網絡求梯度。

如果我們有兩個網絡 , 兩個關系是這樣的 現在我們想用 來為B網絡的參數來求梯度,但是又不想求A網絡參數的梯度。我們可以這樣:

# y=A(x), z=B(y) 求B中參數的梯度,不求A中參數的梯度
y = A(x)
z = B(y.detach())
z.backward()

關于“PyTorch中tensor.detach()和tensor.data的區別有哪些”的內容就介紹到這里了,感謝大家的閱讀。如果想了解更多行業相關的知識,可以關注億速云行業資訊頻道,小編每天都會為大家更新不同的知識點。

向AI問一下細節

免責聲明:本站發布的內容(圖片、視頻和文字)以原創、轉載和分享為主,文章觀點不代表本網站立場,如果涉及侵權請聯系站長郵箱:is@yisu.com進行舉報,并提供相關證據,一經查實,將立刻刪除涉嫌侵權內容。

AI

亚洲午夜精品一区二区_中文无码日韩欧免_久久香蕉精品视频_欧美主播一区二区三区美女