Commit: use python to launch

Sikan Li committed Jul 3, 2024
1 parent 1a8546c commit 7f31242
Showing 3 changed files with 63 additions and 46 deletions.
50 changes: 29 additions & 21 deletions gns/distribute.py
@@ -2,32 +2,40 @@
import torch.distributed as dist
from torch.utils import collect_env
from torch.utils.data.distributed import DistributedSampler
+ import os


def setup(local_rank: int):
"""Initializes distributed training."""
# Initialize group, blocks until all processes join.
- torch.distributed.init_process_group(
- backend="nccl",
- init_method="env://",
- )
- world_size = dist.get_world_size()
- torch.cuda.set_device(local_rank)
- torch.cuda.manual_seed(0)
- verbose = dist.get_rank() == 0
-
- if verbose:
- print("Collecting env info...")
- print(collect_env.get_pretty_env_info())
- print()
-
- for r in range(torch.distributed.get_world_size()):
- if r == torch.distributed.get_rank():
- print(
- f"Global rank {torch.distributed.get_rank()} initialized: "
- f"local_rank = {local_rank}, "
- f"world_size = {torch.distributed.get_world_size()}",
- )
if "LOCAL_RANK" in os.environ:
torch.distributed.init_process_group(
backend="nccl",
init_method="env://",
)
world_size = dist.get_world_size()
torch.cuda.set_device(local_rank)
torch.cuda.manual_seed(0)
verbose = dist.get_rank() == 0

if verbose:
print("Collecting env info...")
print(collect_env.get_pretty_env_info())
print()

for r in range(torch.distributed.get_world_size()):
if r == torch.distributed.get_rank():
print(
f"Global rank {torch.distributed.get_rank()} initialized: "
f"local_rank = {local_rank}, "
f"world_size = {torch.distributed.get_world_size()}",
)
else:
verbose = True
world_size = torch.cuda.device_count()
print(
f"world_size = {world_size}",
)
return verbose, world_size
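
This branch is what makes the commit's "use python to launch" mode work: torchrun exports LOCAL_RANK (along with RANK, WORLD_SIZE, MASTER_ADDR and MASTER_PORT), which the env:// init method consumes, while a plain python launch leaves LOCAL_RANK unset and falls through to the single-process path. A minimal sketch of how a caller might exercise setup() under either launcher — the gns.distribute import path and the call site are assumptions, not shown in this hunk:

import os
from gns import distribute  # assumed module path

# torchrun sets LOCAL_RANK per process; a plain `python` launch does not.
local_rank = int(os.environ.get("LOCAL_RANK", 0))
verbose, world_size = distribute.setup(local_rank)
print(f"distributed={'LOCAL_RANK' in os.environ}, world_size={world_size}")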


6 changes: 3 additions & 3 deletions gns/particle_data_loader.py
@@ -178,7 +178,7 @@ def get_data_loader(
input_sequence_length=6,
batch_size=32,
shuffle=True,
- is_distributed=False,
+ use_dist=False,
):
"""
Get a data loader for the ParticleDataset.
@@ -189,14 +189,14 @@ def get_data_loader(
input_sequence_length (int): Length of input sequence.
batch_size (int): Batch size for the data loader.
shuffle (bool): Whether to shuffle the data.
- is_distributed (bool): Whether to use DistributedSampler for distributed training.
+ use_dist (bool): Whether to use DistributedSampler for distributed training.
Returns:
DataLoader: A PyTorch DataLoader object.
"""
dataset = ParticleDataset(file_path, input_sequence_length, mode)

- if is_distributed:
+ if use_dist:
sampler = DistributedSampler(dataset, shuffle=shuffle)
shuffle = False # DistributedSampler handles shuffling
else:
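
With use_dist=True the loader wraps the dataset in a DistributedSampler so each rank reads a disjoint shard, and shuffling is handed off to the sampler. A hedged usage sketch — the file path is a placeholder, and the per-epoch set_epoch call is standard DistributedSampler practice rather than something shown in this diff:

from gns import particle_data_loader as pdl  # assumed module path

train_dl = pdl.get_data_loader(
    file_path="data/train.npz",  # placeholder path
    mode="sample",
    input_sequence_length=6,
    batch_size=32,
    shuffle=True,
    use_dist=True,  # wrap in DistributedSampler when launched with torchrun
)
# Reseed the sampler each epoch so every rank reshuffles its shard consistently:
# train_dl.sampler.set_epoch(epoch)
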
53 changes: 31 additions & 22 deletions gns/train.py
@@ -253,6 +253,7 @@ def save_model_and_train_state(
valid_loss,
train_loss_hist,
valid_loss_hist,
+ use_dist,
):
"""Save model state
@@ -270,7 +271,7 @@ def save_model_and_train_state(
valid_loss_hist: validation loss history at each epoch
"""
if verbose:
- if device == torch.device("cpu"):
+ if not use_dist:
simulator.save(cfg.model.path + "model-" + str(step) + ".pt")
else:
simulator.module.save(cfg.model.path + "model-" + str(step) + ".pt")
@@ -288,7 +289,7 @@
torch.save(train_state, f"{cfg.model.path}train_state-{step}.pt")


- def setup_simulator_and_optimizer(cfg, metadata, rank, world_size, device):
+ def setup_simulator_and_optimizer(cfg, metadata, rank, world_size, device, use_dist):
"""Setup simulator and optimizer.
Args:
@@ -306,7 +307,10 @@ def setup_simulator_and_optimizer(cfg, metadata, rank, world_size, device):
cfg.data.noise_std,
rank,
)
- simulator = DDP(serial_simulator.to("cuda"), device_ids=[rank])
+ if use_dist:
+ simulator = DDP(serial_simulator.to("cuda"), device_ids=[rank])
+ else:
+ simulator = serial_simulator.to("cuda")
optimizer = torch.optim.Adam(
simulator.parameters(), lr=cfg.training.learning_rate.initial * world_size
)
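
Because the model is wrapped in DDP only when use_dist is set, downstream code has to pick between simulator and simulator.module (as the save and predict_accelerations call sites in this file do), and the learning rate is scaled linearly with world_size. A small helper along these lines — a sketch, not part of the commit — would let call sites stay branch-free:

from torch.nn.parallel import DistributedDataParallel as DDP

def unwrap(model):
    """Return the underlying module whether or not it is wrapped in DDP."""
    return model.module if isinstance(model, DDP) else model

# e.g. unwrap(simulator).save(path) covers both the DDP and the single-process case.
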
@@ -324,30 +328,31 @@ def setup_simulator_and_optimizer(cfg, metadata, rank, world_size, device):
return simulator, optimizer


- def initialize_training(cfg, rank, world_size, device):
+ def initialize_training(cfg, rank, world_size, device, use_dist):
"""Initialize training.
Args:
cfg: Configuration dictionary.
rank: Local rank.
world_size: Total number of ranks.
device: torch device type.
+ use_dist: whether to use torch.distributed
"""
metadata = reading_utils.read_metadata(cfg.data.path, "train")
simulator, optimizer = setup_simulator_and_optimizer(
- cfg, metadata, rank, world_size, device
+ cfg, metadata, rank, world_size, device, use_dist
)
return simulator, optimizer, metadata


- def load_datasets(cfg, is_distributed):
+ def load_datasets(cfg, use_dist):
# Train data loader
train_dl = pdl.get_data_loader(
file_path=f"{cfg.data.path}train.npz",
mode="sample",
input_sequence_length=cfg.data.input_sequence_length,
batch_size=cfg.data.batch_size,
- is_distributed=is_distributed,
+ use_dist=use_dist,
)
train_dataset = pdl.ParticleDataset(f"{cfg.data.path}train.npz")
n_features = train_dataset.get_num_features()
@@ -360,7 +365,7 @@ def load_datasets(cfg, is_distributed):
mode="sample",
input_sequence_length=cfg.data.input_sequence_length,
batch_size=cfg.data.batch_size,
- is_distributed=is_distributed,
+ use_dist=use_dist,
)
valid_dataset = pdl.ParticleDataset(f"{cfg.data.path}valid.npz")
if valid_dataset.get_num_features() != n_features:
@@ -417,7 +422,7 @@ def prepare_data(example, device_id):
return position, particle_type, material_property, n_particles_per_example, labels


- def train(rank, cfg, world_size, device, verbose):
+ def train(rank, cfg, world_size, device, verbose, use_dist):
"""Train the model.
Args:
@@ -426,11 +431,12 @@
world_size: total number of ranks
device: torch device type
verbose: global rank 0 or cpu
+ use_dist: whether to use torch.distributed
"""
device_id = rank if device == torch.device("cuda") else device

# Initialize simulator and optimizer
- simulator, optimizer, metadata = initialize_training(cfg, rank, world_size, device)
+ simulator, optimizer, metadata = initialize_training(cfg, rank, world_size, device, use_dist)

# Initialize training state
step = 0
@@ -493,27 +499,26 @@
simulator.train()
simulator.to(device_id)

- # Determine if we're using distributed training
- is_distributed = device == torch.device("cuda") and world_size > 1
# Load datasets
- train_dl, valid_dl, n_features = load_datasets(cfg, is_distributed)
+ train_dl, valid_dl, n_features = load_datasets(cfg, use_dist)

print(f"rank = {rank}, cuda = {torch.cuda.is_available()}")

writer = setup_tensorboard(cfg, metadata) if verbose else None

try:
num_epochs = max(1, (cfg.training.steps + len(train_dl) - 1) // len(train_dl))
print(f"Total epochs = {num_epochs}")
for epoch in tqdm(range(epoch, num_epochs), desc="Training", unit="epoch"):
if device == torch.device("cuda"):
if verbose:
print(f"Total epochs = {num_epochs}")
for epoch in tqdm(range(epoch, num_epochs), desc="Training", unit="epoch", disable=not verbose):
if use_dist:
torch.distributed.barrier()

epoch_loss = 0.0
steps_this_epoch = 0

# Create a tqdm progress bar for each epoch
- with tqdm(total=len(train_dl), desc=f"Epoch {epoch}", unit="batch") as pbar:
+ with tqdm(total=len(train_dl), desc=f"Epoch {epoch}", unit="batch", disable=not verbose) as pbar:
for example in train_dl:
steps_per_epoch += 1
# Prepare data
@@ -541,10 +546,10 @@
)
sampled_noise *= non_kinematic_mask.view(-1, 1, 1)

device_or_rank = rank if device == torch.device("cuda") else device
predict_fn = (
simulator.module.predict_accelerations
- if device == torch.device("cuda")
+ if use_dist
else simulator.predict_accelerations
)
pred_acc, target_acc = predict_fn(
@@ -626,6 +631,7 @@
valid_loss,
train_loss_hist,
valid_loss_hist,
+ use_dist,
)

step += 1
@@ -634,7 +640,7 @@

# Epoch level statistics
avg_loss = torch.tensor([epoch_loss / steps_this_epoch]).to(device_id)
- if device == torch.device("cuda"):
+ if use_dist:
torch.distributed.reduce(
avg_loss, dst=0, op=torch.distributed.ReduceOp.SUM
)
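
The SUM reduce leaves rank 0 holding the total of the per-rank epoch-average losses; turning that into a global mean still requires dividing by world_size (the division itself is not visible in this hunk). A hedged sketch of the usual pattern, with the function name chosen here purely for illustration:

import torch
import torch.distributed as dist

def global_avg_loss(avg_loss: torch.Tensor, world_size: int, use_dist: bool) -> float:
    # Each rank contributes its epoch-average loss; rank 0 ends up with the sum
    # and converts it to a mean. In the single-process case the tensor already is the mean.
    if use_dist:
        dist.reduce(avg_loss, dst=0, op=dist.ReduceOp.SUM)
        avg_loss = avg_loss / world_size
    return avg_loss.item()
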
@@ -679,12 +685,13 @@
valid_loss,
train_loss_hist,
valid_loss_hist,
+ use_dist,
)

if verbose:
writer.close()

- if torch.cuda.is_available():
+ if use_dist:
distribute.cleanup()


@@ -799,6 +806,8 @@ def validation(simulator, example, n_features, cfg, rank, device_id):
def main(cfg: Config):
"""Train or evaluates the model."""
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+ local_rank = 0
+ use_dist = ("LOCAL_RANK" in os.environ)
if "LOCAL_RANK" in os.environ:
local_rank = int(os.environ["LOCAL_RANK"])

@@ -822,7 +831,7 @@ def main(cfg: Config):
world_size = 1
verbose = True

- train(local_rank, cfg, world_size, device, verbose)
+ train(local_rank, cfg, world_size, device, verbose, use_dist)

elif cfg.mode in ["valid", "rollout"]:
# Set device
