Skip to content

Commit

Permalink
black
Browse files (browse the repository at this point in the history)
  • Loading branch information
Sikan Li committed Jul 3, 2024
1 parent 7f31242 commit 2bcd131
Showing 1 changed file with 14 additions and 5 deletions.
19 changes: 14 additions & 5 deletions gns/train.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -436,7 +436,9 @@ def train(rank, cfg, world_size, device, verbose, use_dist):
device_id = rank if device == torch.device("cuda") else device

# Initialize simulator and optimizer
simulator, optimizer, metadata = initialize_training(cfg, rank, world_size, device, use_dist)
simulator, optimizer, metadata = initialize_training(
cfg, rank, world_size, device, use_dist
)

# Initialize training state
step = 0
Expand Down Expand Up @@ -510,15 +512,22 @@ def train(rank, cfg, world_size, device, verbose, use_dist):
num_epochs = max(1, (cfg.training.steps + len(train_dl) - 1) // len(train_dl))
if verbose:
print(f"Total epochs = {num_epochs}")
for epoch in tqdm(range(epoch, num_epochs), desc="Training", unit="epoch", disable=not verbose):
for epoch in tqdm(
range(epoch, num_epochs), desc="Training", unit="epoch", disable=not verbose
):
if use_dist:
torch.distributed.barrier()

epoch_loss = 0.0
steps_this_epoch = 0

# Create a tqdm progress bar for each epoch
with tqdm(total=len(train_dl), desc=f"Epoch {epoch}", unit="batch", disable=not verbose) as pbar:
with tqdm(
total=len(train_dl),
desc=f"Epoch {epoch}",
unit="batch",
disable=not verbose,
) as pbar:
for example in train_dl:
steps_per_epoch += 1
# Prepare data
Expand Down Expand Up @@ -546,7 +555,7 @@ def train(rank, cfg, world_size, device, verbose, use_dist):
)
sampled_noise *= non_kinematic_mask.view(-1, 1, 1)

device_or_rank = rank if device == torch.device("cuda") else device
device_or_rank = rank if device == torch.device("cuda") else device
predict_fn = (
simulator.module.predict_accelerations
if use_dist
Expand Down Expand Up @@ -807,7 +816,7 @@ def main(cfg: Config):
"""Train or evaluates the model."""
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
local_rank = 0
use_dist = ("LOCAL_RANK" in os.environ)
use_dist = "LOCAL_RANK" in os.environ
if "LOCAL_RANK" in os.environ:
local_rank = int(os.environ["LOCAL_RANK"])

Expand Down

0 comments on commit 2bcd131

Please sign in to comment.