溫馨提示×

溫馨提示×

您好,登錄后才能下訂單哦!

密碼登錄×
登錄注冊×
其他方式登錄
點擊 登錄注冊 即表示同意《億速云用戶服務條款》

pytorch中如何使用遷移學習resnet18訓練mnist數據集

發布時間:2021-12-04 19:07:29 來源:億速云 閱讀:589 作者:柒染 欄目:大數據
# PyTorch中如何使用遷移學習ResNet18訓練MNIST數據集

## 1. 遷移學習概述

遷移學習(Transfer Learning)是深度學習中的重要技術,它允許我們將在一個任務上訓練好的模型參數遷移到另一個相關任務中。這種方法特別適用于以下場景:
- 目標數據集較?。ㄈ玑t學圖像)
- 計算資源有限
- 需要快速原型開發

在計算機視覺領域,預訓練的CNN模型(如ResNet、VGG等)通過遷移學習可以顯著提升在小規模數據集上的表現。本文將詳細介紹如何使用PyTorch中的ResNet18模型,通過遷移學習技術來訓練MNIST手寫數字數據集。

## 2. 環境準備與數據加載

### 2.1 安裝必要庫

```python
!pip install torch torchvision matplotlib

2.2 導入所需模塊

import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms, models
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
import numpy as np

2.3 數據預處理

MNIST圖像是28x28的灰度圖像,而ResNet18默認輸入是224x224的3通道圖像,需要進行調整:

# 定義數據轉換
transform = transforms.Compose([
    transforms.Resize(224),  # 調整大小
    transforms.Grayscale(num_output_channels=3),  # 灰度轉RGB
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,))  # 歸一化
])

# 加載數據集
train_dataset = datasets.MNIST(root='./data', train=True, download=True, transform=transform)
test_dataset = datasets.MNIST(root='./data', train=False, download=True, transform=transform)

# 創建數據加載器
train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=64, shuffle=False)

3. 模型準備與遷移學習策略

3.1 加載預訓練ResNet18

model = models.resnet18(pretrained=True)

3.2 模型結構調整

ResNet18原始輸出是1000類(ImageNet),我們需要修改最后一層以適應MNIST的10分類任務:

# 凍結所有卷積層參數
for param in model.parameters():
    param.requires_grad = False

# 替換最后的全連接層
num_ftrs = model.fc.in_features
model.fc = nn.Linear(num_ftrs, 10)  # MNIST有10個類別

3.3 遷移學習策略選擇

常見遷移學習策略有: 1. 特征提取器:凍結卷積層,只訓練全連接層(本文采用) 2. 微調:解凍部分或全部卷積層進行微調 3. 漸進解凍:逐步解凍網絡層

4. 訓練過程實現

4.1 定義訓練函數

def train(model, device, train_loader, optimizer, criterion, epoch):
    model.train()
    running_loss = 0.0
    correct = 0
    
    for batch_idx, (data, target) in enumerate(train_loader):
        data, target = data.to(device), target.to(device)
        
        optimizer.zero_grad()
        output = model(data)
        loss = criterion(output, target)
        loss.backward()
        optimizer.step()
        
        running_loss += loss.item()
        pred = output.argmax(dim=1, keepdim=True)
        correct += pred.eq(target.view_as(pred)).sum().item()
        
    train_loss = running_loss / len(train_loader)
    train_acc = 100. * correct / len(train_loader.dataset)
    
    print(f'Train Epoch: {epoch} \tLoss: {train_loss:.4f} \tAccuracy: {train_acc:.2f}%')
    return train_loss, train_acc

4.2 定義測試函數

def test(model, device, test_loader, criterion):
    model.eval()
    test_loss = 0
    correct = 0
    
    with torch.no_grad():
        for data, target in test_loader:
            data, target = data.to(device), target.to(device)
            output = model(data)
            test_loss += criterion(output, target).item()
            pred = output.argmax(dim=1, keepdim=True)
            correct += pred.eq(target.view_as(pred)).sum().item()
    
    test_loss /= len(test_loader)
    test_acc = 100. * correct / len(test_loader.dataset)
    
    print(f'Test set: Average loss: {test_loss:.4f}, Accuracy: {test_acc:.2f}%\n')
    return test_loss, test_acc

4.3 主訓練循環

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = model.to(device)

criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.fc.parameters(), lr=0.001)

epochs = 10
train_losses, train_accs = [], []
test_losses, test_accs = [], []

for epoch in range(1, epochs + 1):
    train_loss, train_acc = train(model, device, train_loader, optimizer, criterion, epoch)
    test_loss, test_acc = test(model, device, test_loader, criterion)
    
    train_losses.append(train_loss)
    train_accs.append(train_acc)
    test_losses.append(test_loss)
    test_accs.append(test_acc)

5. 結果可視化與分析

5.1 訓練過程可視化

plt.figure(figsize=(12, 4))
plt.subplot(1, 2, 1)
plt.plot(train_losses, label='Train Loss')
plt.plot(test_losses, label='Test Loss')
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.legend()

plt.subplot(1, 2, 2)
plt.plot(train_accs, label='Train Accuracy')
plt.plot(test_accs, label='Test Accuracy')
plt.xlabel('Epoch')
plt.ylabel('Accuracy (%)')
plt.legend()
plt.show()

5.2 性能分析

典型訓練結果可能顯示: - 測試準確率可達98%以上 - 訓練曲線快速收斂 - 過擬合現象不明顯(得益于預訓練特征)

6. 進階優化策略

6.1 學習率調整

scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=5, gamma=0.1)

6.2 部分層微調

# 解凍最后兩個卷積塊
for name, param in model.named_parameters():
    if 'layer4' in name or 'layer3' in name:
        param.requires_grad = True

6.3 數據增強

transform = transforms.Compose([
    transforms.Resize(224),
    transforms.Grayscale(num_output_channels=3),
    transforms.RandomRotation(10),
    transforms.RandomAffine(degrees=0, translate=(0.1, 0.1)),
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,))
])

7. 常見問題與解決方案

7.1 輸入維度不匹配

問題:ResNet期望3通道輸入,MNIST是單通道
解決:使用Grayscale(num_output_channels=3)轉換

7.2 過擬合

現象:訓練準確率高但測試準確率低
解決方案: - 增加數據增強 - 添加Dropout層 - 使用更小的學習率 - 早停(Early Stopping)

7.3 訓練速度慢

優化方案: - 使用更大的batch size - 啟用混合精度訓練 - 分布式訓練

8. 完整代碼示例

# 省略部分導入和函數定義...

def main():
    # 數據準備
    transform = transforms.Compose([
        transforms.Resize(224),
        transforms.Grayscale(3),
        transforms.ToTensor(),
        transforms.Normalize((0.5,), (0.5,))
    ])
    
    train_set = datasets.MNIST('./data', train=True, download=True, transform=transform)
    test_set = datasets.MNIST('./data', train=False, download=True, transform=transform)
    
    train_loader = DataLoader(train_set, batch_size=128, shuffle=True)
    test_loader = DataLoader(test_set, batch_size=128, shuffle=False)
    
    # 模型準備
    model = models.resnet18(pretrained=True)
    for param in model.parameters():
        param.requires_grad = False
    model.fc = nn.Linear(model.fc.in_features, 10)
    
    # 訓練配置
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = model.to(device)
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.fc.parameters(), lr=0.001)
    
    # 訓練循環
    for epoch in range(1, 11):
        train(model, device, train_loader, optimizer, criterion, epoch)
        test(model, device, test_loader, criterion)

if __name__ == '__main__':
    main()

9. 總結

本文詳細介紹了在PyTorch中使用ResNet18進行遷移學習訓練MNIST數據集的全過程,關鍵點包括: 1. 正確處理單通道圖像的輸入適配 2. 合理凍結/解凍網絡層 3. 針對小數據集的訓練技巧 4. 模型性能評估與優化

遷移學習大大降低了在特定任務上訓練模型的成本,即使像MNIST這樣簡單的數據集,使用預訓練模型也能獲得更好的特征表示和泛化能力。這種方法可以輕松擴展到其他類似任務中。 “`

向AI問一下細節

免責聲明:本站發布的內容(圖片、視頻和文字)以原創、轉載和分享為主,文章觀點不代表本網站立場,如果涉及侵權請聯系站長郵箱:is@yisu.com進行舉報,并提供相關證據,一經查實,將立刻刪除涉嫌侵權內容。

AI

亚洲午夜精品一区二区_中文无码日韩欧免_久久香蕉精品视频_欧美主播一区二区三区美女