溫馨提示×

溫馨提示×

您好,登錄后才能下訂單哦!

密碼登錄×
登錄注冊×
其他方式登錄
點擊 登錄注冊 即表示同意《億速云用戶服務條款》

基于MSELoss()與CrossEntropyLoss()的區別詳解

發布時間:2020-10-13 00:48:26 來源:腳本之家 閱讀:280 作者:Foneone 欄目:開發技術

基于pytorch來講

MSELoss()多用于回歸問題,也可以用于one_hotted編碼形式,

CrossEntropyLoss()名字為交叉熵損失函數,不用于one_hotted編碼形式

MSELoss()要求batch_x與batch_y的tensor都是FloatTensor類型

CrossEntropyLoss()要求batch_x為Float,batch_y為LongTensor類型

(1)CrossEntropyLoss() 舉例說明:

比如二分類問題,最后一層輸出的為2個值,比如下面的代碼:

class CNN (nn.Module ) :
  def __init__ ( self , hidden_size1 , output_size , dropout_p) :
    super ( CNN , self ).__init__ ( )
    self.hidden_size1 = hidden_size1
    self.output_size = output_size
    self.dropout_p = dropout_p
    
    self.conv1 = nn.Conv1d ( 1,8,3,padding =1) 
    self.fc1 = nn.Linear (8*500, self.hidden_size1 )
    self.out = nn.Linear (self.hidden_size1,self.output_size ) 
 
  
  def forward ( self , encoder_outputs ) :
    cnn_out = F.max_pool1d ( F.relu (self.conv1(encoder_outputs)),2) 
    cnn_out = F.dropout ( cnn_out ,self.dropout_p) #加一個dropout
    cnn_out = cnn_out.view (-1,8*500) 
    output_1 = torch.tanh ( self.fc1 ( cnn_out ) )
    output = self.out ( ouput_1)
    return output

最后的輸出結果為:

基于MSELoss()與CrossEntropyLoss()的區別詳解

上面一個tensor為output結果,下面為target,沒有使用one_hotted編碼。

訓練過程如下:

cnn_optimizer = torch.optim.SGD(cnn.parameters(),learning_rate,momentum=0.9,\
              weight_decay=1e-5)
criterion = nn.CrossEntropyLoss()
 
def train ( input_variable , target_variable , cnn , cnn_optimizer , criterion ) :
  cnn_output = cnn( input_variable )
  print(cnn_output)
  print(target_variable)
  loss = criterion ( cnn_output , target_variable)
  cnn_optimizer.zero_grad ()
  loss.backward( )
  cnn_optimizer.step( )
  #print('loss: ',loss.item())
  return loss.item() #返回損失

說明CrossEntropyLoss()是output兩位為one_hotted編碼形式,但target不是one_hotted編碼形式。

(2)MSELoss() 舉例說明:

網絡結構不變,但是標簽是one_hotted編碼形式。下面的圖僅做說明,網絡結構不太對,出來的預測也不太對。

基于MSELoss()與CrossEntropyLoss()的區別詳解

如果target不是one_hotted編碼形式會報錯,報的錯誤如下。

基于MSELoss()與CrossEntropyLoss()的區別詳解

目前自己理解的兩者的區別,就是這樣的,至于多分類問題是不是也是樣的有待考察。

以上這篇基于MSELoss()與CrossEntropyLoss()的區別詳解就是小編分享給大家的全部內容了,希望能給大家一個參考,也希望大家多多支持億速云。

向AI問一下細節

免責聲明:本站發布的內容(圖片、視頻和文字)以原創、轉載和分享為主,文章觀點不代表本網站立場,如果涉及侵權請聯系站長郵箱:is@yisu.com進行舉報,并提供相關證據,一經查實,將立刻刪除涉嫌侵權內容。

AI

亚洲午夜精品一区二区_中文无码日韩欧免_久久香蕉精品视频_欧美主播一区二区三区美女