# How to Read a Cat-vs-Dog Binary Classification Training Set in PyTorch
## 1. Introduction
Image classification is one of the most fundamental and widely used tasks in deep learning, and cat-vs-dog classification is a classic binary problem commonly used for teaching and algorithm validation. PyTorch, one of today's most popular deep learning frameworks, provides a complete toolchain for this kind of task. This article explains in detail how to read a cat-vs-dog binary classification training set with PyTorch, covering the full pipeline from data preparation to model training.
### 1.1 Why PyTorch
PyTorch offers the following advantages:
- Dynamic computation graphs
- A clean, intuitive API
- An active community
- Seamless integration with the Python ecosystem
- Mature GPU acceleration support
### 1.2 Article Structure
This article is organized as follows:
1. Dataset preparation and directory structure
2. PyTorch's core data-loading components
3. Implementing a custom dataset class
4. Data augmentation and preprocessing
5. Configuring the data loader
6. A complete code example
7. Common problems and solutions
---
## 2. Dataset Preparation and Directory Structure
### 2.1 Getting a Standard Dataset
The Kaggle "Dogs vs Cats" dataset is recommended:
```bash
kaggle competitions download -c dogs-vs-cats
```

After extraction, the archive should contain the following structure:

```
data/
├── train/
│   ├── cat.0.jpg
│   ├── cat.1.jpg
│   ├── ...
│   ├── dog.0.jpg
│   ├── dog.1.jpg
│   └── ...
└── test/
    ├── 0.jpg
    ├── 1.jpg
    └── ...
```
### 2.2 Recommended Directory Structure
For training, the following class-per-folder layout is recommended:

```
custom_data/
├── train/
│   ├── cat/
│   │   ├── cat001.jpg
│   │   └── ...
│   └── dog/
│       ├── dog001.jpg
│       └── ...
└── val/
    ├── cat/
    └── dog/
```
A typical split:
- Training set: 20,000 images (10,000 cats, 10,000 dogs)
- Validation set: 5,000 images (2,500 cats, 2,500 dogs)
- Test set: 12,500 images (unlabeled)
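The Kaggle archive stores all training images flat in `train/` with names such as `cat.0.jpg`. Below is a minimal sketch for reorganizing it into the layout above; the paths, the 80/20 split ratio, and the choice to copy rather than move files are assumptions you may want to adjust.

```python
import os
import random
import shutil

def reorganize(src_dir='data/train', dst_dir='custom_data', val_ratio=0.2, seed=42):
    """Copy cat.*.jpg / dog.*.jpg from a flat folder into train/val class folders."""
    random.seed(seed)
    for class_name in ('cat', 'dog'):
        # collect all files whose name starts with the class prefix
        files = [f for f in os.listdir(src_dir) if f.startswith(class_name)]
        random.shuffle(files)
        n_val = int(len(files) * val_ratio)
        splits = {'val': files[:n_val], 'train': files[n_val:]}
        for split, names in splits.items():
            out_dir = os.path.join(dst_dir, split, class_name)
            os.makedirs(out_dir, exist_ok=True)
            for name in names:
                shutil.copy(os.path.join(src_dir, name), os.path.join(out_dir, name))

reorganize()
```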
## 3. PyTorch's Core Data-Loading Components
Every map-style dataset subclasses `torch.utils.data.Dataset`, whose (simplified) definition is:

```python
class Dataset(Generic[T_co]):
    def __getitem__(self, index) -> T_co:
        ...

    def __len__(self) -> int:
        ...
```
The `DataLoader` wraps a dataset and handles batching, shuffling, and parallel loading. Key parameters:

```python
DataLoader(
    dataset,            # any Dataset instance
    batch_size=32,      # samples per batch
    shuffle=True,       # reshuffle the data every epoch
    num_workers=4,      # subprocesses used for loading
    pin_memory=True,    # page-locked memory for faster host-to-GPU copies
    drop_last=False     # keep the final, possibly smaller, batch
)
```
The `torchvision.transforms` module provides composable preprocessing and augmentation operations. Commonly used transforms:

```python
transforms.Compose([
    transforms.Resize(256),
    transforms.RandomCrop(224),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize(               # ImageNet mean/std, standard for pretrained models
        mean=[0.485, 0.456, 0.406],
        std=[0.229, 0.224, 0.225]
    )
])
```
## 4. Implementing a Custom Dataset Class
Based on the directory layout from Section 2, a basic implementation looks like this:

```python
import os
from PIL import Image
import torch
from torch.utils.data import Dataset

class CatDogDataset(Dataset):
    def __init__(self, root_dir, transform=None):
        self.root_dir = root_dir
        self.transform = transform
        self.classes = ['cat', 'dog']
        self.samples = []
        # index every image as a (path, class_index) pair
        for class_idx, class_name in enumerate(self.classes):
            class_dir = os.path.join(root_dir, class_name)
            for img_name in os.listdir(class_dir):
                self.samples.append(
                    (os.path.join(class_dir, img_name), class_idx)
                )

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        img_path, label = self.samples[idx]
        image = Image.open(img_path).convert('RGB')
        if self.transform:
            image = self.transform(image)
        return image, label
```
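A quick sanity check of the class (a sketch assuming the `custom_data/train` layout from Section 2; the printed values depend on your copy of the data):

```python
# Instantiate without transforms to inspect the raw samples
dataset = CatDogDataset(root_dir='custom_data/train')
print(len(dataset))       # total number of indexed images
img, label = dataset[0]
print(img.size, label)    # PIL (width, height) and 0 for cat / 1 for dog
```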
A variant that caches decoded samples in memory to avoid repeated disk reads:

```python
class CachedCatDogDataset(CatDogDataset):
    def __init__(self, root_dir, transform=None, cache_size=1000):
        super().__init__(root_dir, transform)
        self.cache = {}
        self.cache_size = cache_size

    def __getitem__(self, idx):
        if idx in self.cache:
            return self.cache[idx]
        img, label = super().__getitem__(idx)
        # only cache up to cache_size samples to bound memory use
        if len(self.cache) < self.cache_size:
            self.cache[idx] = (img, label)
        return img, label
```

Two caveats: random augmentations are frozen for cached samples, and with `num_workers > 0` each worker process holds its own private copy of the cache.
## 5. Data Augmentation and Preprocessing
A typical torchvision pipeline uses aggressive augmentation for training and deterministic preprocessing for validation:

```python
from torchvision import transforms

train_transform = transforms.Compose([
    transforms.Resize(256),
    transforms.RandomResizedCrop(224),
    transforms.RandomHorizontalFlip(),
    transforms.RandomRotation(15),
    transforms.ColorJitter(
        brightness=0.2,
        contrast=0.2,
        saturation=0.2
    ),
    transforms.ToTensor(),
    transforms.Normalize(
        mean=[0.485, 0.456, 0.406],
        std=[0.229, 0.224, 0.225]
    )
])

val_transform = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(
        mean=[0.485, 0.456, 0.406],
        std=[0.229, 0.224, 0.225]
    )
])
```
Alternatively, the Albumentations library offers a richer set of augmentations:

```python
from albumentations import (
    HorizontalFlip, Rotate, RandomBrightnessContrast,
    HueSaturationValue, Compose
)
import cv2

def albumentations_transform():
    return Compose([
        HorizontalFlip(p=0.5),
        Rotate(limit=15, p=0.5),
        RandomBrightnessContrast(p=0.2),
        HueSaturationValue(
            hue_shift_limit=20,
            sat_shift_limit=30,
            val_shift_limit=20,
            p=0.5
        )
    ])
```
Because Albumentations operates on NumPy arrays, the dataset reads images with OpenCV and converts to a tensor manually:

```python
class AlbumentationsDataset(CatDogDataset):
    def __getitem__(self, idx):
        img_path, label = self.samples[idx]
        image = cv2.imread(img_path)
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        if self.transform:
            augmented = self.transform(image=image)
            image = augmented['image']
        # HWC uint8 -> CHW float in [0, 1]
        image = torch.from_numpy(image).permute(2, 0, 1).float() / 255.0
        return image, label
```
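A hypothetical usage sketch follows. Note that the pipeline above does not resize images, so in practice you would add an Albumentations `Resize` (or crop) step before batching samples of different sizes into one tensor.

```python
# Hypothetical wiring of the Albumentations pipeline into the dataset
alb_dataset = AlbumentationsDataset(
    root_dir='custom_data/train',
    transform=albumentations_transform()
)
image, label = alb_dataset[0]
print(image.shape, label)   # torch.Size([3, H, W]) with H, W from the source image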
## 6. Configuring the Data Loader
With the dataset class and transforms in place, the loaders are set up as follows:

```python
from torch.utils.data import DataLoader

train_dataset = CatDogDataset(
    root_dir='data/train',
    transform=train_transform
)
val_dataset = CatDogDataset(
    root_dir='data/val',
    transform=val_transform
)

train_loader = DataLoader(
    train_dataset,
    batch_size=64,
    shuffle=True,
    num_workers=4,
    pin_memory=True
)
val_loader = DataLoader(
    val_dataset,
    batch_size=64,
    shuffle=False,
    num_workers=4
)
```
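With the loaders configured, a quick shape check (a sketch assuming the transforms and batch size above) confirms the pipeline end to end:

```python
# Pull one batch and verify tensor shapes
images, labels = next(iter(train_loader))
print(images.shape)   # torch.Size([64, 3, 224, 224])
print(labels.shape)   # torch.Size([64])
```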
A custom `collate_fn` controls how individual samples are combined into a batch (the version below reproduces the default behaviour; customize it for variable-size inputs or extra metadata):

```python
def collate_fn(batch):
    images = [item[0] for item in batch]
    labels = [item[1] for item in batch]
    return torch.stack(images), torch.tensor(labels)

loader = DataLoader(
    train_dataset,
    batch_size=64,
    collate_fn=collate_fn
)
```
For imbalanced data, a `WeightedRandomSampler` resamples classes in inverse proportion to their frequency:

```python
from torch.utils.data.sampler import WeightedRandomSampler

class_counts = [10000, 10000]  # number of cat and dog samples
weights = 1. / torch.tensor(class_counts, dtype=torch.float)
# per-sample weights, looked up from each sample's class label
labels = torch.tensor([label for _, label in train_dataset.samples])
samples_weights = weights[labels]

sampler = WeightedRandomSampler(
    weights=samples_weights,
    num_samples=len(samples_weights),
    replacement=True
)
loader = DataLoader(
    train_dataset,
    batch_size=64,
    sampler=sampler
)
```
## 7. Complete Code Example
The following script combines the pieces above into a complete training run:

```python
import os
import torch
from torch import nn, optim
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, models
from PIL import Image
from tqdm import tqdm

# 1. Define the dataset class
class CatDogDataset(Dataset):
    def __init__(self, root_dir, transform=None):
        self.root_dir = root_dir
        self.transform = transform
        self.classes = ['cat', 'dog']
        self.samples = []
        for class_idx, class_name in enumerate(self.classes):
            class_dir = os.path.join(root_dir, class_name)
            for img_name in os.listdir(class_dir):
                self.samples.append(
                    (os.path.join(class_dir, img_name), class_idx)
                )

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        img_path, label = self.samples[idx]
        image = Image.open(img_path).convert('RGB')
        if self.transform:
            image = self.transform(image)
        return image, label

# 2. Define the data transforms
train_transform = transforms.Compose([
    transforms.Resize(256),
    transforms.RandomCrop(224),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
])
val_transform = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
])

# 3. Create the datasets and data loaders
train_dataset = CatDogDataset('data/train', train_transform)
val_dataset = CatDogDataset('data/val', val_transform)
train_loader = DataLoader(
    train_dataset,
    batch_size=32,
    shuffle=True,
    num_workers=4
)
val_loader = DataLoader(
    val_dataset,
    batch_size=32,
    shuffle=False,
    num_workers=4
)

# 4. Define the model
model = models.resnet18(pretrained=True)  # torchvision >= 0.13: weights=models.ResNet18_Weights.DEFAULT
model.fc = nn.Linear(model.fc.in_features, 2)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = model.to(device)

# 5. Training loop
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)

for epoch in range(10):
    model.train()
    train_loss = 0.0
    for inputs, labels in tqdm(train_loader):
        inputs, labels = inputs.to(device), labels.to(device)
        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        train_loss += loss.item()

    # Validation loop
    model.eval()
    val_loss = 0.0
    correct = 0
    with torch.no_grad():
        for inputs, labels in val_loader:
            inputs, labels = inputs.to(device), labels.to(device)
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            val_loss += loss.item()
            _, preds = torch.max(outputs, 1)
            correct += torch.sum(preds == labels.data)

    print(f'Epoch {epoch+1}: '
          f'Train Loss: {train_loss/len(train_loader):.4f} '
          f'Val Loss: {val_loss/len(val_loader):.4f} '
          f'Val Acc: {correct.double()/len(val_dataset):.4f}')
```
## 8. Common Problems and Solutions
### 8.1 Out of GPU Memory
Symptoms:
- "CUDA out of memory" errors
- Training crashes frequently

Solutions:
1. Reduce `batch_size` (for example from 64 to 32).
2. Use gradient accumulation:

```python
accumulation_steps = 4
optimizer.zero_grad()
for i, (inputs, labels) in enumerate(train_loader):
    inputs, labels = inputs.to(device), labels.to(device)
    outputs = model(inputs)
    # scale the loss so the accumulated gradient matches a full batch
    loss = criterion(outputs, labels) / accumulation_steps
    loss.backward()
    if (i + 1) % accumulation_steps == 0:
        optimizer.step()
        optimizer.zero_grad()
```
### 8.2 Slow Data Loading
Optimizations:
1. Increase `num_workers` (the number of CPU cores is a common starting point).
2. Enable `pin_memory=True`.
3. Use faster storage (e.g. an NVMe SSD).
4. Prefetch batches in a background thread:

```python
from prefetch_generator import BackgroundGenerator

class DataLoaderX(DataLoader):
    def __iter__(self):
        # wrap the iterator so the next batch is prepared in the background
        return BackgroundGenerator(super().__iter__())
```
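Hypothetical drop-in usage (assuming `prefetch_generator` is installed and `train_dataset` is defined as above):

```python
# DataLoaderX accepts the same arguments as DataLoader
train_loader = DataLoaderX(
    train_dataset,
    batch_size=64,
    shuffle=True,
    num_workers=4,
    pin_memory=True
)
```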
### 8.3 Class Imbalance
Approaches:
1. A weighted loss function:

```python
class_weights = torch.tensor([1.0, 2.0]).to(device)  # e.g. when dogs are underrepresented
criterion = nn.CrossEntropyLoss(weight=class_weights)
```

2. Weighted sampling with the `WeightedRandomSampler` shown in Section 6.
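A sketch (assuming the `train_dataset` and `device` defined in Section 7) that derives the weights from the actual class counts instead of hard-coding them:

```python
from collections import Counter

# count samples per class index (0 = cat, 1 = dog)
label_counts = Counter(label for _, label in train_dataset.samples)
total = sum(label_counts.values())
# inverse-frequency weights: rarer classes get larger weights
class_weights = torch.tensor(
    [total / label_counts[c] for c in range(len(train_dataset.classes))],
    dtype=torch.float
).to(device)
criterion = nn.CrossEntropyLoss(weight=class_weights)
```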
### 8.4 Inconsistent Image Sizes
Solutions:
1. Resize everything to a fixed size.
2. Use `RandomResizedCrop` for added robustness.
3. Pad images to a square before resizing. Since `transforms.Pad` only accepts a fixed padding, a dynamic pad-to-square needs a small custom step; one possible approach:

```python
from torchvision.transforms import functional as TF

transforms.Compose([
    # pad right/bottom so the image becomes square; PIL size is (width, height)
    transforms.Lambda(lambda img: TF.pad(
        img, (0, 0, max(img.size) - img.size[0], max(img.size) - img.size[1]))),
    transforms.Resize(256),
    transforms.CenterCrop(224),
    # ... ToTensor / Normalize as above
])
```
## 9. Summary
This article walked through the complete workflow for reading a cat-vs-dog binary classification training set in PyTorch: organizing the dataset on disk, implementing a custom `Dataset` class, applying data augmentation, configuring the `DataLoader`, and handling common pitfalls.
By making full use of PyTorch's data-handling tools, we can efficiently build a data pipeline suited to deep learning training. A well-implemented data-reading pipeline not only speeds up training but also improves model generalization through effective data augmentation.
Note: adjust parameters such as `batch_size` and `num_workers` to your hardware. The complete code is roughly 200 lines; running it step by step in a Jupyter Notebook is recommended.