dist_utils.py
# some code in this file is adapted from
# https://github.com/facebookresearch/moco
# Original Copyright 2020 Facebook, Inc. and its affiliates. Licensed under the CC-BY-NC 4.0 License.
# Modifications Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
# SPDX-License-Identifier: CC-BY-NC-4.0
import os

import torch
import torch.distributed as dist

@torch.no_grad()
def batch_shuffle_ddp(x):
    """
    Batch shuffle, for making use of BatchNorm.
    *** Only supports DistributedDataParallel (DDP) models. ***
    """
    # gather from all gpus
    batch_size_this = x.shape[0]
    x_gather = concat_all_gather(x)
    batch_size_all = x_gather.shape[0]

    num_gpus = batch_size_all // batch_size_this

    # random shuffle index
    idx_shuffle = torch.randperm(batch_size_all).cuda()

    # broadcast to all gpus
    torch.distributed.broadcast(idx_shuffle, src=0)

    # index for restoring
    idx_unshuffle = torch.argsort(idx_shuffle)

    # shuffled index for this gpu
    gpu_idx = torch.distributed.get_rank()
    idx_this = idx_shuffle.view(num_gpus, -1)[gpu_idx]

    return x_gather[idx_this], idx_unshuffle

@torch.no_grad()
def batch_unshuffle_ddp(x, idx_unshuffle):
    """
    Undo batch shuffle.
    *** Only supports DistributedDataParallel (DDP) models. ***
    """
    # gather from all gpus
    batch_size_this = x.shape[0]
    x_gather = concat_all_gather(x)
    batch_size_all = x_gather.shape[0]

    num_gpus = batch_size_all // batch_size_this

    # restored index for this gpu
    gpu_idx = torch.distributed.get_rank()
    idx_this = idx_unshuffle.view(num_gpus, -1)[gpu_idx]

    return x_gather[idx_this]

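# A minimal usage sketch showing how the two functions above pair up, assuming
# torch.distributed is already initialized (e.g. via dist_init below), `images`
# is the per-GPU batch on the current CUDA device, and `key_encoder` is a
# placeholder for the caller's key/momentum network (not defined in this file).
@torch.no_grad()
def _example_shuffle_roundtrip(key_encoder, images):
    # shuffle the global batch so each GPU's BatchNorm statistics are computed
    # on a different subset than the one the query encoder sees
    images_shuffled, idx_unshuffle = batch_shuffle_ddp(images)
    keys = key_encoder(images_shuffled)
    # restore the original order so the keys line up with this GPU's queries
    return batch_unshuffle_ddp(keys, idx_unshuffle)
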
@torch.no_grad()
def concat_all_gather(tensor):
    """
    Performs an all_gather operation on the provided tensor.
    *** Warning ***: torch.distributed.all_gather has no gradient.
    """
    tensors_gather = [torch.ones_like(tensor)
                      for _ in range(torch.distributed.get_world_size())]
    torch.distributed.all_gather(tensors_gather, tensor, async_op=False)

    output = torch.cat(tensors_gather, dim=0)
    return output

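# A minimal sketch of a typical call site: gathering L2-normalized per-GPU
# features into one global tensor (e.g. to build a negative bank). The `feats`
# argument is hypothetical; it is assumed to live on the current GPU and to
# have the same shape on every rank.
@torch.no_grad()
def _example_gather_features(feats):
    feats = torch.nn.functional.normalize(feats, dim=1)
    # result has shape [world_size * local_batch, dim]; no gradient flows back
    return concat_all_gather(feats)
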
def is_dist_avail_and_initialized():
    if not dist.is_available():
        return False
    if not dist.is_initialized():
        return False
    return True


def get_world_size():
    if not is_dist_avail_and_initialized():
        return 1
    return dist.get_world_size()


def get_rank():
    if not is_dist_avail_and_initialized():
        return 0
    return dist.get_rank()

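# A small sketch of why the wrappers above are convenient: they fall back to
# rank 0 / world size 1 when torch.distributed is not initialized, so the same
# code also runs in a plain single-process debug session. The `message`
# argument is hypothetical.
def _example_log_on_main(message):
    if get_rank() == 0:
        print('[{} process(es)] {}'.format(get_world_size(), message))
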
def dist_init(port=23456):
    """
    Initialize torch.distributed (NCCL backend) from SLURM environment
    variables and return (rank, local_rank, world_size).
    """
    def init_parrots(host_addr, rank, local_rank, world_size, port):
        os.environ['MASTER_ADDR'] = str(host_addr)
        os.environ['MASTER_PORT'] = str(port)
        os.environ['WORLD_SIZE'] = str(world_size)
        os.environ['RANK'] = str(rank)
        torch.distributed.init_process_group(backend="nccl")
        torch.cuda.set_device(local_rank)

    def init(host_addr, rank, local_rank, world_size, port):
        host_addr_full = 'tcp://' + host_addr + ':' + str(port)
        torch.distributed.init_process_group("nccl", init_method=host_addr_full,
                                             rank=rank, world_size=world_size)
        torch.cuda.set_device(local_rank)
        assert torch.distributed.is_initialized()

    def parse_host_addr(s):
        # SLURM_STEP_NODELIST may be a compressed list such as "node[01-04,07]";
        # take the first node in the list as the rendezvous host.
        if '[' in s:
            left_bracket = s.index('[')
            right_bracket = s.index(']')
            prefix = s[:left_bracket]
            first_number = s[left_bracket+1:right_bracket].split(',')[0].split('-')[0]
            return prefix + first_number
        else:
            return s

    rank = int(os.environ['SLURM_PROCID'])
    local_rank = int(os.environ['SLURM_LOCALID'])
    world_size = int(os.environ['SLURM_NTASKS'])
    ip = parse_host_addr(os.environ['SLURM_STEP_NODELIST'])

    if torch.__version__ == 'parrots':
        init_parrots(ip, rank, local_rank, world_size, port)
        from parrots import config
        config.set_attr('engine', 'timeout', value=500)
    else:
        init(ip, rank, local_rank, world_size, port)

    return rank, local_rank, world_size

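# A usage sketch for a SLURM job step (e.g. `srun python train.py`), assuming
# SLURM_PROCID / SLURM_LOCALID / SLURM_NTASKS / SLURM_STEP_NODELIST are set by
# the scheduler. The `model`, `dataset`, and `batch_size` arguments are
# placeholders for the caller's own objects; only dist_init comes from this
# file, the rest is standard PyTorch.
def _example_slurm_setup(model, dataset, batch_size=32, port=23456):
    rank, local_rank, world_size = dist_init(port=port)
    model = torch.nn.parallel.DistributedDataParallel(
        model.cuda(local_rank), device_ids=[local_rank])
    sampler = torch.utils.data.distributed.DistributedSampler(dataset)
    loader = torch.utils.data.DataLoader(
        dataset, batch_size=batch_size, sampler=sampler,
        num_workers=4, pin_memory=True, drop_last=True)
    return model, loader, sampler, rank, world_size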