溫馨提示×

溫馨提示×

您好,登錄后才能下訂單哦!

密碼登錄×
登錄注冊×
其他方式登錄
點擊 登錄注冊 即表示同意《億速云用戶服務條款》

PyTorch梯度裁剪如何避免訓練loss nan

發布時間:2022-02-24 09:41:17 來源:億速云 閱讀:307 作者:小新 欄目:開發技術

這篇文章主要為大家展示了“PyTorch梯度裁剪如何避免訓練loss nan”,內容簡而易懂,條理清晰,希望能夠幫助大家解決疑惑,下面讓小編帶領大家一起研究并學習一下“PyTorch梯度裁剪如何避免訓練loss nan”這篇文章吧。

訓練代碼使用示例如下:

from torch.nn.utils import clip_grad_norm_
outputs = model(data)
loss= loss_fn(outputs, target)
optimizer.zero_grad()
loss.backward()
# clip the grad
clip_grad_norm_(model.parameters(), max_norm=20, norm_type=2)
optimizer.step()

其中,max_norm為梯度的最大范數,也是梯度裁剪時主要設置的參數。

備注:網上有同學提醒在(強化學習)使用了梯度裁剪之后訓練時間會大大增加。目前在我的檢測網絡訓練中暫時還沒有碰到這個問題,以后遇到再來更新。

補充:pytorch訓練過程中出現nan的排查思路

1、最常見的就是出現了除0或者log0這種

看看代碼中在這種操作的時候有沒有加一個很小的數,但是這個數數量級要和運算的數的數量級要差很多。一般是1e-8。

2、在optim.step()之前裁剪梯度

optim.zero_grad()
loss.backward()
nn.utils.clip_grad_norm(model.parameters, max_norm, norm_type=2)
optim.step()

max_norm一般是1,3,5。

3、前面兩條還不能解決nan的話

就按照下面的流程來判斷。

...
loss = model(input)
# 1. 先看loss是不是nan,如果loss是nan,那么說明可能是在forward的過程中出現了第一條列舉的除0或者log0的操作
assert torch.isnan(loss).sum() == 0, print(loss)
optim.zero_grad()
loss.backward()
# 2. 如果loss不是nan,那么說明forward過程沒問題,可能是梯度爆炸,所以用梯度裁剪試試
nn.utils.clip_grad_norm(model.parameters, max_norm, norm_type=2)
# 3.1 在step之前,判斷參數是不是nan, 如果不是判斷step之后是不是nan
assert torch.isnan(model.mu).sum() == 0, print(model.mu)
optim.step()
# 3.2 在step之后判斷,參數和其梯度是不是nan,如果3.1不是nan,而3.2是nan,
# 特別是梯度出現了Nan,考慮學習速率是否太大,調小學習速率或者換個優化器試試。
assert torch.isnan(model.mu).sum() == 0, print(model.mu)
assert torch.isnan(model.mu.grad).sum() == 0, print(model.mu.grad)

以上是“PyTorch梯度裁剪如何避免訓練loss nan”這篇文章的所有內容,感謝各位的閱讀!相信大家都有了一定的了解,希望分享的內容對大家有所幫助,如果還想學習更多知識,歡迎關注億速云行業資訊頻道!

向AI問一下細節

免責聲明:本站發布的內容(圖片、視頻和文字)以原創、轉載和分享為主,文章觀點不代表本網站立場,如果涉及侵權請聯系站長郵箱:is@yisu.com進行舉報,并提供相關證據,一經查實,將立刻刪除涉嫌侵權內容。

AI

亚洲午夜精品一区二区_中文无码日韩欧免_久久香蕉精品视频_欧美主播一区二区三区美女