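In PyTorch, the weights of an MNIST classification model can be initialized in several ways.

One option is to use the predefined functions in the torch.nn.init module, for example Xavier initialization (also known as Glorot initialization) or He initialization (exposed in PyTorch as the kaiming_* functions). These schemes scale the initial weights to the size of each layer, which helps the network converge faster early in training.

import torch.nn as nn
import torch.nn.init as init

def initialize_weights(model):
    for m in model.modules():
        if isinstance(m, nn.Linear):
            # Xavier (Glorot) uniform initialization for fully connected layers
            init.xavier_uniform_(m.weight)
            if m.bias is not None:  # layers created with bias=False have no bias
                init.zeros_(m.bias)
        elif isinstance(m, nn.Conv2d):
            # He (Kaiming) uniform initialization for conv layers followed by ReLU
            init.kaiming_uniform_(m.weight, nonlinearity='relu')
            if m.bias is not None:
                init.zeros_(m.bias)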
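
To make the sampling ranges concrete, the sketch below recomputes the bounds that xavier_uniform_ and kaiming_uniform_ draw from and compares them with a freshly initialized layer. The layer sizes are assumptions chosen for illustration. The formulas follow the torch.nn.init documentation: gain 1 for Xavier, and gain sqrt(2) with the default fan_in mode for ReLU.

```python
import math

import torch
import torch.nn as nn
import torch.nn.init as init

torch.manual_seed(0)  # fixed seed only so the sketch is repeatable

# Illustrative layer sizes (assumed for this example).
linear = nn.Linear(28 * 28, 128)
conv = nn.Conv2d(1, 16, kernel_size=3)

init.xavier_uniform_(linear.weight)
init.kaiming_uniform_(conv.weight, nonlinearity='relu')

# Xavier uniform samples from U(-a, a) with a = sqrt(6 / (fan_in + fan_out)).
xavier_bound = math.sqrt(6.0 / (28 * 28 + 128))

# Kaiming uniform (fan_in mode, ReLU gain sqrt(2)) uses
# a = sqrt(2) * sqrt(3 / fan_in), where for a conv layer
# fan_in = in_channels * kernel_height * kernel_width.
conv_fan_in = 1 * 3 * 3
kaiming_bound = math.sqrt(2.0) * math.sqrt(3.0 / conv_fan_in)

linear_max = linear.weight.abs().max().item()
conv_max = conv.weight.abs().max().item()
print(f"xavier bound  {xavier_bound:.4f}, largest |w| {linear_max:.4f}")
print(f"kaiming bound {kaiming_bound:.4f}, largest |w| {conv_max:.4f}")
```

Every sampled weight should lie inside its bound. The bound shrinks as the layer gets wider, which is what keeps the activation scale roughly stable from layer to layer.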
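
Another option is the normal_ function in the torch.nn.init module, whose std parameter sets the standard deviation of the sampled weights. For example, the weights can be drawn with a standard deviation of 0.05:

def initialize_weights(model):
    for m in model.modules():
        # Linear and conv layers get the same treatment here
        if isinstance(m, (nn.Linear, nn.Conv2d)):
            init.normal_(m.weight, mean=0.0, std=0.05)
            if m.bias is not None:
                init.zeros_(m.bias)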
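
As a quick sanity check, the following sketch initializes a single layer with normal_ and prints the empirical mean and standard deviation of its weights. The layer size is an assumption made for the example; with about 100,000 weights the measured standard deviation should land very close to the requested 0.05.

```python
import torch
import torch.nn as nn
import torch.nn.init as init

torch.manual_seed(0)  # fixed seed only so the sketch is repeatable

layer = nn.Linear(28 * 28, 128)  # illustrative size (assumed)
init.normal_(layer.weight, mean=0.0, std=0.05)
init.zeros_(layer.bias)

weight_mean = layer.weight.mean().item()
weight_std = layer.weight.std().item()
print(f"mean {weight_mean:.4f}, std {weight_std:.4f}")
```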
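
The initialization can also be customized per layer type:

def initialize_weights(model):
    for m in model.modules():
        if isinstance(m, nn.Linear):
            # Custom initialization for linear layer weights
            init.uniform_(m.weight, -1, 1)
            if m.bias is not None:
                init.zeros_(m.bias)
        elif isinstance(m, nn.Conv2d):
            # Custom initialization for convolutional layer weights
            init.kaiming_uniform_(m.weight, nonlinearity='relu')
            if m.bias is not None:
                init.zeros_(m.bias)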
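
When choosing a custom range, it can help to look at the scale of a layer's outputs. The sketch below compares uniform_(-1, 1) with xavier_uniform_ on one 784-to-128 layer. It assumes standard-normal random inputs as a stand-in for real MNIST batches, so the numbers are only indicative.

```python
import torch
import torch.nn as nn
import torch.nn.init as init

torch.manual_seed(0)  # fixed seed only so the sketch is repeatable

# Random stand-in for a batch of flattened 28x28 images (an assumption).
x = torch.randn(256, 28 * 28)

wide = nn.Linear(28 * 28, 128)
init.uniform_(wide.weight, -1, 1)
init.zeros_(wide.bias)

scaled = nn.Linear(28 * 28, 128)
init.xavier_uniform_(scaled.weight)
init.zeros_(scaled.bias)

with torch.no_grad():
    wide_std = wide(x).std().item()
    scaled_std = scaled(x).std().item()

print(f"uniform(-1, 1): output std {wide_std:.2f}")
print(f"xavier_uniform: output std {scaled_std:.2f}")
```

A fixed range that ignores the layer width tends to produce much larger pre-activations on wide layers, which is one reason the fan-based schemes are a common default.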
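
Once the initialization function is defined, call it right after creating the model instance so that the weights are initialized as intended before training starts.

model = nn.Sequential(
    nn.Linear(28 * 28, 128),  # flattened 28x28 input image
    nn.ReLU(),
    nn.Linear(128, 64),
    nn.ReLU(),
    nn.Linear(64, 10)         # one output per digit class
)
initialize_weights(model)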
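
An alternative way to run the initialization is nn.Module.apply, which calls a function on every submodule and on the model itself. The sketch below uses it on a small convolutional model so that the Conv2d branch is exercised as well. The init_module helper and the CNN layout are hypothetical and exist only for this illustration.

```python
import torch
import torch.nn as nn
import torch.nn.init as init


def init_module(m):
    """Initialize a single module; meant to be passed to model.apply()."""
    if isinstance(m, nn.Linear):
        init.xavier_uniform_(m.weight)
    elif isinstance(m, nn.Conv2d):
        init.kaiming_uniform_(m.weight, nonlinearity='relu')
    else:
        return  # leave activations, pooling, containers untouched
    if m.bias is not None:
        init.zeros_(m.bias)


# Hypothetical small CNN for 1x28x28 inputs (not from the original text).
cnn = nn.Sequential(
    nn.Conv2d(1, 8, kernel_size=3, padding=1),
    nn.ReLU(),
    nn.MaxPool2d(2),             # 28x28 -> 14x14
    nn.Flatten(),
    nn.Linear(8 * 14 * 14, 10),
)
cnn.apply(init_module)  # visits each submodule, then the container itself

# Forward pass on a dummy batch to confirm the shapes line up.
with torch.no_grad():
    logits = cnn(torch.zeros(4, 1, 28, 28))
print(logits.shape)
```

Writing the function per module rather than per model keeps the loop out of user code, and the same helper can be reused for any architecture built from these layer types.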