Skip to content

Commit

Permalink
Fix devices
Browse files Browse the repository at this point in the history
  • Loading branch information
relativityhd committed Dec 17, 2024
1 parent 98a49d5 commit db5c319
Showing 1 changed file with 5 additions and 4 deletions.
9 changes: 5 additions & 4 deletions darts/src/darts/training.py
Original file line number Diff line number Diff line change
Expand Up @@ -244,7 +244,7 @@ def train_smp(
plot_every_n_val_epochs: int = 5,
# Device and Manager config
num_workers: int = 0,
devices: int | str = "auto",
device: int | str = "auto",
wandb_entity: str | None = None,
wandb_project: str | None = None,
run_name: str | None = None,
Expand Down Expand Up @@ -281,7 +281,7 @@ def train_smp(
Defaults to 5.
plot_every_n_val_epochs (int, optional): Plot validation samples every n epochs. Defaults to 5.
num_workers (int, optional): Number of Dataloader workers. Defaults to 0.
devices (int | str | list[int], optional): The device(s) to run the model on. Defaults to "auto".
device (int | str, optional): The device to run the model on. Defaults to "auto".
wandb_entity (str | None, optional): Weights and Biases Entity. Defaults to None.
wandb_project (str | None, optional): Weights and Biases Project. Defaults to None.
run_name (str | None, optional): Name of this run, as a further grouping method for logs etc. Defaults to None.
Expand All @@ -306,7 +306,7 @@ def train_smp(
f"Using config:\n\t{model_arch=}\n\t{model_encoder=}\n\t{model_encoder_weights=}\n\t{augment=}\n\t"
f"{learning_rate=}\n\t{gamma=}\n\t{batch_size=}\n\t{max_epochs=}\n\t{log_every_n_steps=}\n\t"
f"{check_val_every_n_epoch=}\n\t{early_stopping_patience=}\n\t{plot_every_n_val_epochs=}\n\t{num_workers=}"
f"\n\t{devices=}"
f"\n\t{device=}"
)

lovely_tensors.monkey_patch()
Expand Down Expand Up @@ -365,7 +365,8 @@ def train_smp(
log_every_n_steps=log_every_n_steps,
logger=trainer_loggers,
check_val_every_n_epoch=check_val_every_n_epoch,
devices=devices,
accelerator="gpu" if isinstance(device, int) else device,
devices=[device] if isinstance(device, int) else device,
)
trainer.fit(model, datamodule)

Expand Down

0 comments on commit db5c319

Please sign in to comment.