From 7ad44726102e0ab14365792ce5fcfb009a286607 Mon Sep 17 00:00:00 2001 From: Sounish Nath <40270033+sounishnath003@users.noreply.github.com> Date: Fri, 17 Jun 2022 21:14:34 +0530 Subject: [PATCH] #39 Support For Mac M1 GPU --- tez/model/config.py | 2 +- tez/model/tez.py | 3 +++ 2 files changed, 4 insertions(+), 1 deletion(-) diff --git a/tez/model/config.py b/tez/model/config.py index 58f7884..cf32e7f 100644 --- a/tez/model/config.py +++ b/tez/model/config.py @@ -5,7 +5,7 @@ @dataclass class TezConfig: experiment_name = "default" - device: Optional[str] = "cuda" # cuda or cpu + device: Optional[str] = "cuda" # cuda or cpu or mps # batch sizes training_batch_size: Optional[int] = 32 diff --git a/tez/model/tez.py b/tez/model/tez.py index 536df2b..3f9304d 100644 --- a/tez/model/tez.py +++ b/tez/model/tez.py @@ -52,6 +52,9 @@ def _configure_model(self): if self.config.device == "cpu": device = torch.device("cpu") self.num_gpu = 0 + elif self.config.device == "mps": + device = torch.device("mps:0") + self.num_gpu = 1 elif self.config.device == "cuda": if torch.cuda.device_count() > 1: if self.local_rank == -1: