Ubuntu下PyTorch內存優化方法
批次大小是影響GPU顯存使用的核心因素之一。較小的批次大小能直接減少顯存占用,但需平衡其對訓練速度(如梯度更新頻率)和模型性能(如泛化能力)的影響。建議通過實驗找到“顯存占用可接受且不影響模型效果”的最小批次值。
半精度(float16)相比單精度(float32)可減少50%的顯存占用,同時通過PyTorch的torch.cuda.amp
模塊實現自動混合精度(AMP),能在保持模型數值穩定性的前提下,自動在float16和float32之間切換(如梯度計算用float32保證穩定性,前向/反向傳播用float16提升速度)。示例代碼:
from torch.cuda.amp import autocast, GradScaler
scaler = GradScaler()
for inputs, targets in dataloader:
optimizer.zero_grad()
with autocast(): # 自動混合精度上下文
outputs = model(inputs)
loss = criterion(outputs, targets)
scaler.scale(loss).backward() # 縮放梯度避免underflow
scaler.step(optimizer) # 更新參數
scaler.update() # 調整縮放因子
del
關鍵字刪除不再需要的中間變量(如損失值、預測結果),斷開其對顯存的引用。torch.cuda.empty_cache()
釋放PyTorch緩存的無用顯存(如未使用的中間結果),注意該操作不會釋放被引用的張量。gc.collect()
強制Python垃圾回收器回收無用對象,配合del
和empty_cache()
效果更佳。del outputs, loss # 刪除無用變量
torch.cuda.empty_cache() # 清空GPU緩存
import gc
gc.collect() # 觸發垃圾回收
梯度累積通過“多次小批次計算梯度→累加→一次更新”的方式,模擬更大批次的效果,同時不增加顯存占用。適用于“顯存不足但需較大批次”的場景。示例代碼:
accumulation_steps = 4 # 累積4個小批次的梯度
for i, (inputs, targets) in enumerate(dataloader):
outputs = model(inputs)
loss = criterion(outputs, targets)
loss = loss / accumulation_steps # 歸一化損失(避免梯度爆炸)
loss.backward() # 累積梯度
if (i + 1) % accumulation_steps == 0: # 達到累積步數時更新參數
optimizer.step()
optimizer.zero_grad() # 清零梯度
DataLoader
的num_workers
參數(設置為CPU核心數的2-4倍)啟用多進程數據加載,避免數據預處理成為瓶頸。in_features
與out_features
轉換為卷積核的in_channels
與out_channels
),減少參數數量(如ResNet-50的全連接層參數占比約90%)。將模型訓練分布到多個GPU(單機多卡)或多臺機器(多機多卡),通過數據并行(Data Parallelism)或模型并行(Model Parallelism)減少單個設備的顯存負載。PyTorch提供torch.distributed
模塊支持分布式訓練,常用 launch 工具如torchrun
。示例代碼(數據并行):
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP
dist.init_process_group(backend='nccl') # 初始化進程組
model = DDP(model.to(device)) # 包裝模型
torch.cuda.memory_summary()
打印顯存分配詳情(如已用/剩余顯存、緩存情況),或nvidia-smi
命令實時監控GPU顯存使用率。torch.utils.checkpoint
檢查張量是否意外保留計算圖(如非訓練場景未用with torch.no_grad()
),或使用memory_profiler
庫逐行跟蹤內存變化。sudo echo 3 | sudo tee /proc/sys/vm/drop_caches
釋放系統頁面緩存(不影響正在運行的程序)。sudo dd if=/dev/zero of=/swapfile bs=64M count=16
,sudo mkswap /swapfile
,sudo swapon /swapfile
),作為物理內存的擴展(注意:Swap性能低于物理內存,僅作臨時解決方案)。