在Linux下進行PyTorch的分布式訓練,通常需要以下幾個步驟:
環境準備:
啟動分布式訓練:
PyTorch提供了torch.distributed.launch
工具來簡化分布式訓練的啟動過程。以下是一個基本的命令行示例:
python -m torch.distributed.launch --nproc_per_node=NUM_GPUS_YOU_HAVE --nnodes=NUM_NODES --node_rank=NODE_RANK --master_addr=MASTER_NODE_IP --master_port=12345 YOUR_TRAINING_SCRIPT.py
參數說明:
--nproc_per_node
:每個節點上使用的GPU數量。--nnodes
:總的節點數。--node_rank
:當前節點的排名(從0開始)。--master_addr
:主節點的IP地址。--master_port
:主節點上用于通信的端口號。YOUR_TRAINING_SCRIPT.py
:你的訓練腳本。修改訓練腳本: 在你的訓練腳本中,需要初始化分布式環境。通常在腳本的最開始添加以下代碼:
import torch.distributed as dist
dist.init_process_group(
backend='nccl', # 'nccl' is recommended for distributed GPU training
init_method='tcp://MASTER_NODE_IP:12345',
world_size=NUM_GPUS_YOU_HAVE * NUM_NODES,
rank=NODE_RANK
)
參數說明:
backend
:分布式后端,對于GPU訓練推薦使用nccl
。init_method
:初始化分布式環境的地址。world_size
:總的進程數,等于GPU數量乘以節點數。rank
:當前進程的排名。數據并行:
在你的訓練循環中,使用torch.nn.parallel.DistributedDataParallel
來包裝你的模型:
model = YourModel().to(device)
ddp_model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[local_rank])
參數說明:
device_ids
:指定當前進程使用的GPU ID。運行訓練: 在每個節點上運行修改后的訓練腳本,確保所有節點都使用相同的命令行參數。
以下是一個完整的示例:
# 在主節點上運行
python -m torch.distributed.launch --nproc_per_node=4 --nnodes=2 --node_rank=0 --master_addr=192.168.1.1 --master_port=12345 train.py
# 在其他節點上運行
python -m torch.distributed.launch --nproc_per_node=4 --nnodes=2 --node_rank=1 --master_addr=192.168.1.1 --master_port=12345 train.py
在train.py
中:
import torch
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP
def main():
dist.init_process_group(
backend='nccl',
init_method='tcp://192.168.1.1:12345',
world_size=8,
rank=0 # 這個rank會在每個節點上變化
)
device = torch.device(f"cuda:{dist.get_rank()}")
model = YourModel().to(device)
ddp_model = DDP(model, device_ids=[dist.get_rank()])
# 訓練循環
for data, target in dataloader:
data, target = data.to(device), target.to(device)
output = ddp_model(data)
loss = torch.nn.functional.cross_entropy(output, target)
loss.backward()
optimizer.step()
if __name__ == "__main__":
main()
通過以上步驟,你可以在Linux下進行PyTorch的分布式訓練。確保所有節點的網絡配置正確,并且防火墻允許相應的端口通信。