PyTorch 中實現早停(Early Stopping)的方法有很多種,其中一種常見的方法是使用驗證集(validation set)來監控模型的性能,當驗證集上的性能不再提高時,停止訓練。下面是一個簡單的實現方法:
from torch.utils.data import DataLoader, random_split
# 假設 train_dataset 是你的訓練數據集,val_dataset 是你的驗證數據集
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=32, shuffle=False)
def compute_metrics(model, val_loader):
model.eval()
correct = 0
total = 0
with torch.no_grad():
for images, labels in val_loader:
outputs = model(images)
_, predicted = torch.max(outputs.data, 1)
total += labels.size(0)
correct += (predicted == labels).sum().item()
accuracy = 100 * correct / total
return {'accuracy': accuracy}
class EarlyStopping:
def __init__(self, patience=10, verbose=False):
self.patience = patience
self.verbose = verbose
self.best_score = float('-inf')
self.wait = 0
def step(self, model, val_loader):
score = compute_metrics(model, val_loader)['accuracy']
if score > self.best_score:
self.best_score = score
self.wait = 0
else:
self.wait += 1
if self.wait >= self.patience:
if self.verbose:
print(f'Early stopping at epoch {epoch}, best accuracy: {self.best_score:.2f}')
return True
return False
early_stopping = EarlyStopping(patience=10, verbose=True)
for epoch in range(num_epochs):
model.train()
# 訓練代碼
model.eval()
should_stop = early_stopping.step(model, val_loader)
if should_stop:
break
在這個例子中,我們定義了一個 EarlyStopping
類,它會在驗證集上的性能指標不再提高時停止訓練。我們在訓練循環中使用這個類,并在每個 epoch 結束時調用它的 step
方法。如果 step
方法返回 True
,則表示應該停止訓練。