diff --git a/tez/model/config.py b/tez/model/config.py index 58f7884..cf32e7f 100644 --- a/tez/model/config.py +++ b/tez/model/config.py @@ -5,7 +5,7 @@ @dataclass class TezConfig: experiment_name = "default" - device: Optional[str] = "cuda" # cuda or cpu + device: Optional[str] = "cuda" # cuda or cpu or mps # batch sizes training_batch_size: Optional[int] = 32 diff --git a/tez/model/tez.py b/tez/model/tez.py index 536df2b..3f9304d 100644 --- a/tez/model/tez.py +++ b/tez/model/tez.py @@ -52,6 +52,9 @@ def _configure_model(self): if self.config.device == "cpu": device = torch.device("cpu") self.num_gpu = 0 + elif self.config.device == "mps": + device = torch.device("mps:0") + self.num_gpu = 1 elif self.config.device == "cuda": if torch.cuda.device_count() > 1: if self.local_rank == -1: