# CentOS
sudo yum update -y
sudo yum install -y gcc-c++ make cmake git
# Ubuntu
sudo apt-get update && sudo apt-get install -y build-essential cmake git
Verify that the nodes can reach each other (test with ping <node IP>), then open the port used for the distributed rendezvous (23456 in the examples below) on every node:
# CentOS
sudo firewall-cmd --zone=public --add-port=23456/tcp --permanent
sudo firewall-cmd --reload
# Ubuntu (ufw)
sudo ufw allow 23456/tcp
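To confirm the firewall change from another node, a minimal sketch using Python's standard socket module (the address and port below are the example values used throughout this guide):
import socket

MASTER_ADDR = "192.168.1.100"  # example master IP from this guide
MASTER_PORT = 23456            # the port opened above

try:
    # Plain TCP connection attempt with a short timeout.
    with socket.create_connection((MASTER_ADDR, MASTER_PORT), timeout=3):
        print(f"{MASTER_ADDR}:{MASTER_PORT} is reachable")
except ConnectionRefusedError:
    # Refused means the port is open but nothing is listening yet --
    # expected before training has started.
    print("Port open, no listener yet")
except OSError as e:
    # A timeout usually points at the firewall or routing.
    print(f"Cannot reach {MASTER_ADDR}:{MASTER_PORT}: {e}")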
ssh-keygen -t rsa  # press Enter at each prompt; key is saved to ~/.ssh/id_rsa by default
ssh-copy-id user@worker1_ip  # replace with the worker node's username and IP
ssh-copy-id user@worker2_ip
ssh user@worker1_ip  # should log in without prompting for a password
# Default build from PyPI
pip3 install torch torchvision torchaudio
# Build matched to CUDA 11.3, from the official PyTorch index (GPU nodes)
pip3 install torch torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/cu113
import torch
print(torch.cuda.is_available())  # should print True
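It can also help to confirm the number of visible GPUs and that the NCCL backend is compiled in, since the examples below use backend='nccl'; a short check:
import torch
import torch.distributed as dist

print(torch.__version__)           # installed PyTorch version
print(torch.cuda.device_count())   # GPUs visible on this node
print(dist.is_nccl_available())    # True if the NCCL backend is built in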
pip3 install dask distributed  # install Dask
pip3 install mpi4py  # install mpi4py (Python bindings for MPI)
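For context, the Dask dashboard mentioned in the monitoring section below comes from a running scheduler. A minimal connection sketch, assuming the scheduler was started on the master node with the dask-scheduler command (default scheduler port 8786, dashboard on 8787) and workers joined with dask-worker tcp://192.168.1.100:8786:
from dask.distributed import Client

# Connect to the scheduler running on the master node.
client = Client("tcp://192.168.1.100:8786")
print(client)  # summarizes workers/memory; dashboard at http://192.168.1.100:8787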
In the training script, initialize the distributed environment with torch.distributed.init_process_group:
import torch.distributed as dist

def setup(rank, world_size):
    # Initialize the process group; the NCCL backend is recommended for GPUs.
    dist.init_process_group(
        backend='nccl',
        init_method='tcp://<master_ip>:<master_port>',  # master node's IP and port
        world_size=world_size,  # total number of processes (nodes × GPUs per node)
        rank=rank  # this process's global rank (0 to world_size-1)
    )

def cleanup():
    dist.destroy_process_group()  # tear down the process group after training
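Before going multi-node, the initialization logic can be smoke-tested on a single machine with torch.multiprocessing.spawn. A minimal self-contained sketch that points the rendezvous at localhost and uses the gloo backend so it also runs on CPU-only machines:
import torch.distributed as dist
import torch.multiprocessing as mp

def worker(rank, world_size):
    # Same initialization as setup(), but local: all processes on one node.
    dist.init_process_group(
        backend='gloo',  # works without GPUs; use 'nccl' on GPU nodes
        init_method='tcp://127.0.0.1:23456',
        world_size=world_size,
        rank=rank,
    )
    print(f"rank {rank}/{world_size} initialized")
    dist.destroy_process_group()

if __name__ == '__main__':
    world_size = 2  # number of local processes to spawn
    mp.spawn(worker, args=(world_size,), nprocs=world_size, join=True)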
Wrap the model with torch.nn.parallel.DistributedDataParallel (DDP) to get data parallelism:
import torch.nn as nn

model = YourModel().to(rank)  # move the model to this process's GPU
ddp_model = nn.parallel.DistributedDataParallel(
    model,
    device_ids=[rank]  # GPU index this process uses on its node
)
Use DistributedSampler so each process loads a distinct subset of the data:
from torch.utils.data import DataLoader, DistributedSampler

dataset = YourDataset()  # your custom dataset
sampler = DistributedSampler(
    dataset,
    num_replicas=world_size,  # total number of processes
    rank=rank  # this process's rank
)
dataloader = DataLoader(
    dataset,
    batch_size=32,
    sampler=sampler  # use the distributed sampler
)
In the training loop, call sampler.set_epoch(epoch) before each epoch so that shuffling differs across epochs:
for epoch in range(num_epochs):
    sampler.set_epoch(epoch)  # reseed the sampler each epoch to avoid repeating the same shards
    for data, target in dataloader:
        data, target = data.to(rank), target.to(rank)
        optimizer.zero_grad()
        output = ddp_model(data)
        loss = criterion(output, target)
        loss.backward()
        optimizer.step()
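One detail worth adding to this loop: under DDP every process holds an identical copy of the weights, so checkpoints are normally written by rank 0 only. A small sketch (the file name is illustrative):
import torch
import torch.distributed as dist

def save_checkpoint(ddp_model, epoch, rank):
    if rank == 0:
        # ddp_model.module is the underlying, unwrapped model.
        torch.save(ddp_model.module.state_dict(), f"checkpoint_epoch{epoch}.pt")
    # Keep ranks in step so none races ahead while the file is being written.
    dist.barrier()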
Start distributed training by running the following command on each node (beginning with the master), adjusting node_rank per node:
# --nproc_per_node: GPUs per node (e.g., 4)
# --nnodes: total number of nodes (e.g., 2)
# --node_rank: this node's rank (0 for the master; 1, 2, ... for workers)
# --master_addr / --master_port: the master node's IP and port (e.g., 23456)
python -m torch.distributed.launch \
    --nproc_per_node=<num_gpus> \
    --nnodes=<total_nodes> \
    --node_rank=<current_node_rank> \
    --master_addr=<master_ip> \
    --master_port=<port> \
    your_training_script.py
# On the master node (node_rank=0):
python -m torch.distributed.launch --nproc_per_node=4 --nnodes=2 --node_rank=0 --master_addr="192.168.1.100" --master_port=23456 train.py
# On the worker node (node_rank=1):
python -m torch.distributed.launch --nproc_per_node=4 --nnodes=2 --node_rank=1 --master_addr="192.168.1.100" --master_port=23456 train.py
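For reference, torch.distributed.launch passes each spawned process a --local_rank argument and exports RANK, WORLD_SIZE, MASTER_ADDR and MASTER_PORT into its environment; below is a sketch of how train.py might pick these up (the argument handling is an assumption about your script, not shown above). Note that newer PyTorch releases deprecate torch.distributed.launch in favor of torchrun, which sets LOCAL_RANK as an environment variable instead of passing the argument.
import argparse
import os
import torch
import torch.distributed as dist

parser = argparse.ArgumentParser()
# Injected by torch.distributed.launch for each local process.
parser.add_argument('--local_rank', type=int, default=0)
args = parser.parse_args()

rank = int(os.environ['RANK'])              # global rank, set by the launcher
world_size = int(os.environ['WORLD_SIZE'])  # total processes, set by the launcher

torch.cuda.set_device(args.local_rank)  # bind this process to its GPU
dist.init_process_group(backend='nccl', init_method='env://',
                        world_size=world_size, rank=rank)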
To avoid retyping the command, write a launch script, start_train.sh, that runs the right variant on each node:
#!/bin/bash
# Master node IP and port
MASTER_ADDR="192.168.1.100"
MASTER_PORT=23456
# Total number of nodes
NNODES=2
# GPUs per node
GPUS_PER_NODE=4

# Run on the master node
if [ "$1" == "master" ]; then
    echo "Starting master node..."
    python -m torch.distributed.launch \
        --nproc_per_node=$GPUS_PER_NODE \
        --nnodes=$NNODES \
        --node_rank=0 \
        --master_addr=$MASTER_ADDR \
        --master_port=$MASTER_PORT \
        train.py
# Run on the worker node
elif [ "$1" == "worker" ]; then
    echo "Starting worker node..."
    python -m torch.distributed.launch \
        --nproc_per_node=$GPUS_PER_NODE \
        --nnodes=$NNODES \
        --node_rank=1 \
        --master_addr=$MASTER_ADDR \
        --master_port=$MASTER_PORT \
        train.py
else
    echo "Usage: $0 {master|worker}"
fi
chmod +x start_train.sh
./start_train.sh master  # on the master node
./start_train.sh worker  # on the worker node
To debug rank assignment, print each process's rank and world_size (e.g., print(f"Rank {rank}, World Size {world_size}")) and confirm that every process reports a unique rank in the range 0 to world_size-1. If you use Dask, its dashboard at http://<master_ip>:8787 shows task progress and cluster status. To monitor training performance and find bottlenecks, use torch.profiler; the resulting trace.json can be opened in Chrome's chrome://tracing viewer:
with torch.profiler.profile(
    activities=[torch.profiler.ProfilerActivity.CPU, torch.profiler.ProfilerActivity.CUDA],
    schedule=torch.profiler.schedule(wait=1, warmup=1, active=3),
    on_trace_ready=lambda prof: prof.export_chrome_trace("trace.json")
) as prof:
    for data, target in dataloader:
        # training code
        prof.step()
Keep the node clocks synchronized with NTP:
# CentOS
sudo yum install -y ntp
sudo systemctl start ntpd
sudo systemctl enable ntpd
# Ubuntu
sudo apt-get install -y ntp
sudo systemctl start ntp
sudo systemctl enable ntp
Use requirements.txt or environment.yml to keep the Python environment identical on every node:
pip freeze > requirements.txt  # export the master node's environment
# Install the same environment on each worker node
pip install -r requirements.txt
Common troubleshooting checks: set --nproc_per_node to the actual number of GPUs per node to avoid wasting resources, and make sure world_size equals nodes × GPUs per node; verify connectivity with ping <node IP> and log in with ssh user@worker_ip; and inspect the training logs (e.g., train.log) on each node to locate error messages.