1. 降低批次大?。˙atch Size)
批次大小是影響GPU內存使用的核心因素之一。較小的批次會顯著減少單次迭代的內存占用,但需平衡其對訓練速度和模型性能的影響(如過小的批次可能導致收斂變慢)。建議通過實驗找到模型穩定性和內存占用的最佳平衡點。
2. 使用半精度浮點數(Half-Precision, float16)
通過**自動混合精度(AMP)**訓練,將計算從單精度(float32)切換至半精度(float16),可在保持數值穩定性的同時減少內存使用(約50%)。PyTorch的torch.cuda.amp
模塊提供了便捷支持:
scaler = torch.cuda.amp.GradScaler() # 用于梯度縮放,防止數值溢出
with torch.cuda.amp.autocast(): # 自動選擇float16/float32計算
output = model(input)
loss = criterion(output, target)
scaler.scale(loss).backward() # 縮放梯度以避免underflow
scaler.step(optimizer) # 更新參數
scaler.update() # 調整縮放因子
此方法尤其適用于大型模型(如Transformer、CNN)。
3. 啟用梯度累積(Gradient Accumulation)
若減小批次大小影響模型性能,可通過梯度累積模擬更大批次的效果。即在多個小批次上計算梯度并累加,最后再進行一次參數更新。示例代碼:
accum_steps = 4 # 累積4個小批次的梯度
optimizer.zero_grad()
for i, (inputs, labels) in enumerate(dataloader):
outputs = model(inputs)
loss = criterion(outputs, labels) / accum_steps # 平均損失
loss.backward() # 累積梯度
if (i + 1) % accum_steps == 0: # 達到累積步數后更新參數
optimizer.step()
optimizer.zero_grad()
該方法可在不增加顯存的情況下,提升有效批次大小。
4. 及時釋放無用內存
del
關鍵字刪除不再需要的張量(如中間結果、舊模型),釋放其占用的內存。torch.cuda.empty_cache()
清除PyTorch緩存的無用顯存塊(如已釋放的張量),避免內存碎片。gc.collect()
強制觸發垃圾回收,徹底釋放無引用的對象。del tensor_name # 刪除無用張量
torch.cuda.empty_cache() # 清空GPU緩存
import gc
gc.collect() # 垃圾回收
5. 優化數據加載流程
數據加載是內存瓶頸的常見來源。通過以下設置提升數據加載效率,減少內存占用:
num_workers
參數(如num_workers=4
),利用多核CPU并行讀取數據。pin_memory=True
,將數據預加載到固定內存(Pinned Memory),加速GPU傳輸。dataloader = torch.utils.data.DataLoader(
dataset,
batch_size=32,
num_workers=4, # 根據CPU核心數調整
pin_memory=True # 加速GPU數據傳輸
)
6. 使用內存高效的模型結構
選擇或設計輕量級模型,減少參數數量和內存占用:
7. 利用分布式訓練
將模型訓練分布到多個GPU或多臺機器上,通過數據并行(DistributedDataParallel
,推薦)或模型并行減少單個設備的內存負載。DistributedDataParallel
(DDP)是PyTorch推薦的方式,支持多進程并行,效率高且內存占用低:
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP
dist.init_process_group(backend='nccl') # 初始化進程組
model = DDP(model.to(device)) # 包裝模型
需注意,分布式訓練需調整批次大?。偱?單卡批次×GPU數量)。
8. 監控內存使用
使用PyTorch內置工具監控內存占用,定位瓶頸:
torch.cuda.memory_summary()
顯示顯存分配詳情(如已用/剩余顯存、緩存狀態)。torch.cuda.memory_allocated()
返回當前分配的顯存大小,torch.cuda.max_memory_allocated()
返回歷史最大顯存占用。torch.profiler
分析內存使用情況,識別高內存消耗的操作(如特定層的張量分配)。print(torch.cuda.memory_summary()) # 打印顯存摘要
print(f"當前顯存占用: {torch.cuda.memory_allocated() / 1024**2:.2f} MB")
print(f"最大顯存占用: {torch.cuda.max_memory_allocated() / 1024**2:.2f} MB")
9. 系統級優化
sudo echo 3 | sudo tee /proc/sys/vm/drop_caches
,釋放系統緩存(不影響PyTorch已分配的顯存)。