在PyTorch中,正則化是一種常用的技術,用于防止模型過擬合。常見的正則化方法包括L1正則化和L2正則化。
在PyTorch中,可以使用nn.Module
的add_weight()
方法為模型參數添加正則化項。例如,以下代碼為模型的權重添加了L2正則化項:
import torch.nn as nn
class MyModel(nn.Module):
def __init__(self):
super(MyModel, self).__init__()
self.conv1 = nn.Conv2d(3, 64, kernel_size=3)
self.conv2 = nn.Conv2d(64, 128, kernel_size=3)
self.fc1 = nn.Linear(128 * 25 * 25, 1024)
self.fc2 = nn.Linear(1024, 512)
self.fc3 = nn.Linear(512, 10)
# 添加L2正則化項
for param in self.parameters():
param.requires_grad = True
param.register_hook(lambda x: x * (1 - 0.001))
def forward(self, x):
x = F.relu(self.conv1(x))
x = F.max_pool2d(x, 2)
x = F.relu(self.conv2(x))
x = F.max_pool2d(x, 2)
x = x.view(-1, 128 * 25 * 25)
x = F.relu(self.fc1(x))
x = F.relu(self.fc2(x))
x = self.fc3(x)
return x
在上面的代碼中,我們使用了register_hook()
方法為每個參數添加了一個鉤子函數,該函數將參數乘以一個因子(在這里是1 - 0.001
),從而實現了L2正則化。
除了L2正則化外,還可以使用其他正則化方法,例如L1正則化和Dropout。在PyTorch中,這些方法也可以很容易地實現。