PyTorch Multi-Process Training
sixwalter

1. Basic Concepts

  • group: the process group; by default there is only one group, which contains all processes in the world
  • world: the total number of processes (world_size)
  • rank: the global index of a process, used for inter-process communication; rank 0 is conventionally the master process
  • local_rank: the GPU index the process uses on its local node (the launch sketch below shows how these values are set)
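A minimal sketch of how these values show up in practice, assuming a single machine with 4 GPUs and a hypothetical script name check_env.py: torchrun starts one process per GPU and sets the RANK, WORLD_SIZE and LOCAL_RANK environment variables in each process.

# check_env.py -- hypothetical helper, a minimal sketch
# Launch (assumption: one node with 4 GPUs):
#   torchrun --nproc_per_node=4 check_env.py
import os

rank = int(os.environ["RANK"])              # global process index: 0 .. WORLD_SIZE-1
world_size = int(os.environ["WORLD_SIZE"])  # total number of processes
local_rank = int(os.environ["LOCAL_RANK"])  # GPU index on the local node

print(f"rank={rank} world_size={world_size} local_rank={local_rank}")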

2. Enabling Distributed Mode

import os

import torch
import torch.distributed


def init_distributed_mode(args):
    # Every line in this function runs independently in each process.
    log = []  # record environment info
    if 'RANK' in os.environ and 'WORLD_SIZE' in os.environ:
        # Launched with torchrun / torch.distributed.launch
        args.rank = int(os.environ["RANK"])
        args.world_size = int(os.environ['WORLD_SIZE'])
        args.gpu = int(os.environ['LOCAL_RANK'])
        log.extend([args.rank, args.world_size, args.gpu])
    elif 'SLURM_PROCID' in os.environ:
        # Launched under SLURM
        args.rank = int(os.environ['SLURM_PROCID'])
        args.gpu = args.rank % torch.cuda.device_count()
        log.extend([args.rank, args.gpu])
    else:
        print('Not using distributed mode')
        args.distributed = False
        return

    print(f"args.gpu: {args.gpu}")
    args.distributed = True
    # print environment info
    print(f"environment info: {log}")

    torch.cuda.set_device(args.gpu)
    args.dist_backend = 'nccl'
    print('| distributed init (rank {}): {}'.format(
        args.rank, args.dist_url), flush=True)
    # initialize the default process group
    torch.distributed.init_process_group(backend=args.dist_backend, init_method=args.dist_url,
                                         world_size=args.world_size, rank=args.rank)
    # synchronize all processes
    torch.distributed.barrier()
    # only let the master process print
    setup_for_distributed(args.rank == 0)
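The helper setup_for_distributed(is_master) called on the last line is not shown in the post; in common reference implementations a helper of this kind replaces the built-in print so that only the master process writes log messages. A minimal sketch under that assumption:

import builtins

def setup_for_distributed(is_master):
    # Suppress printing on non-master processes so each message
    # appears once instead of world_size times.
    builtin_print = builtins.print

    def print_wrapper(*args, **kwargs):
        force = kwargs.pop('force', False)
        if is_master or force:
            builtin_print(*args, **kwargs)

    builtins.print = print_wrapper

After init_distributed_mode has run, each process typically wraps its model in DistributedDataParallel and shards the data with DistributedSampler. The following is a minimal sketch, not code from the post; the model, dataset, and args fields (batch_size, lr, epochs) are placeholders:

import torch
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.data import DataLoader, DistributedSampler

def train(args, model, dataset):
    # init_distributed_mode(args) has already set args.gpu / args.distributed
    device = torch.device(f"cuda:{args.gpu}")
    model = model.to(device)
    if args.distributed:
        # Gradients are all-reduced across processes during backward()
        model = DDP(model, device_ids=[args.gpu])

    # Each rank gets a disjoint shard of the dataset
    sampler = DistributedSampler(dataset) if args.distributed else None
    loader = DataLoader(dataset, batch_size=args.batch_size,
                        sampler=sampler, shuffle=(sampler is None))

    optimizer = torch.optim.SGD(model.parameters(), lr=args.lr)
    for epoch in range(args.epochs):
        if sampler is not None:
            sampler.set_epoch(epoch)  # reshuffle with a different seed each epoch
        for images, targets in loader:
            images, targets = images.to(device), targets.to(device)
            loss = model(images, targets)  # placeholder: assumes the model returns its loss
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()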
  • Post title: PyTorch Multi-Process Training
  • Post author: sixwalter
  • Create time: 2023-08-05 11:14:26
  • Post link: https://coelien.github.io/2023/08/05/projects/huawei project/distributed/
  • Copyright Notice: All articles in this blog are licensed under BY-NC-SA unless stated otherwise.