溫馨提示×

溫馨提示×

您好,登錄后才能下訂單哦!

密碼登錄×
登錄注冊×
其他方式登錄
點擊 登錄注冊 即表示同意《億速云用戶服務條款》

pytorch凍結某層參數的實現方法

發布時間:2021-02-01 14:15:08 來源:億速云 閱讀:652 作者:小新 欄目:開發技術

這篇文章主要介紹了pytorch凍結某層參數的實現方法,具有一定借鑒價值,感興趣的朋友可以參考下,希望大家閱讀完這篇文章之后大有收獲,下面讓小編帶著大家一起了解一下。

在遷移學習finetune時我們通常需要凍結前幾層的參數不參與訓練,在Pytorch中的實現如下:

class Model(nn.Module):
 def __init__(self):
  super(Transfer_model, self).__init__()
  self.linear1 = nn.Linear(20, 50)
  self.linear2 = nn.Linear(50, 20)
  self.linear3 = nn.Linear(20, 2)

 def forward(self, x):
 pass

假如我們想要凍結linear1層,需要做如下操作:

model = Model()
# 這里是一般情況,共享層往往不止一層,所以做一個for循環
for para in model.linear1.parameters():
 para.requires_grad = False
# 假如真的只有一層也可以這樣操作:
# model.linear1.weight.requires_grad = False

 最后我們需要將需要優化的參數傳入優化器,不需要傳入的參數過濾掉,所以要用到filter()函數。

optimizer = optim.Adam(filter(lambda p: p.requires_grad, model.parameters()), lr=0.1)

其它的博客中都沒有講解filter()函數的作用,在這里我簡單講一下有助于更好的理解。

filter(function, iterable)

  • function: 判斷函數

  • iterable: 可迭代對象

filter() 函數用于過濾序列,過濾掉不符合條件的元素,返回一個迭代器對象,如果要轉換為列表,可以使用 list() 來轉換。

該接收兩個參數,第一個為函數,第二個為序列,序列的每個元素作為參數傳遞給函數進行判,然后返回 True 或 False,最后將返回 True 的元素放到新列表中。

filter()函數將requires_grad = True的參數傳入優化器進行反向傳播,requires_grad = False的則被過濾掉。

感謝你能夠認真閱讀完這篇文章,希望小編分享的“pytorch凍結某層參數的實現方法”這篇文章對大家有幫助,同時也希望大家多多支持億速云,關注億速云行業資訊頻道,更多相關知識等著你來學習!

向AI問一下細節

免責聲明:本站發布的內容(圖片、視頻和文字)以原創、轉載和分享為主,文章觀點不代表本網站立場,如果涉及侵權請聯系站長郵箱:is@yisu.com進行舉報,并提供相關證據,一經查實,將立刻刪除涉嫌侵權內容。

AI

亚洲午夜精品一区二区_中文无码日韩欧免_久久香蕉精品视频_欧美主播一区二区三区美女