diff --git a/dicee/models/ensemble.py b/dicee/models/ensemble.py
index 83df1a3e..072f91bd 100644
--- a/dicee/models/ensemble.py
+++ b/dicee/models/ensemble.py
@@ -82,6 +82,9 @@ def __call__(self,x_batch):
     def step(self):
         for opt in self.optimizers:
             opt.step()
+
+    def get_embeddings(self):
+        raise NotImplementedError("Not yet Implemented")
     """
     def __getattr__(self, name):
diff --git a/dicee/trainer/model_parallelism.py b/dicee/trainer/model_parallelism.py
index f331c3f0..ffea1a46 100644
--- a/dicee/trainer/model_parallelism.py
+++ b/dicee/trainer/model_parallelism.py
@@ -26,6 +26,8 @@ def extract_input_outputs(z: list, device=None):
 def find_good_batch_size(train_loader,ensemble_model, max_available_gpu_memory:float=0.1):
     # () Initial batch size
     batch_size=train_loader.batch_size
+    if batch_size >= len(train_loader.dataset):
+        return batch_size
     first_batch_size = train_loader.batch_size
     print("Automatic batch size finding")
@@ -46,10 +48,11 @@ def find_good_batch_size(train_loader,ensemble_model, max_available_gpu_memory:f
         avg_global_free_memory=[]
         for i, z in enumerate(train_dataloaders):
             loss = forward_backward_update_loss(z,ensemble_model)
+            global_free_memory, total_memory = torch.cuda.mem_get_info()
+            avg_global_free_memory.append(global_free_memory / total_memory)
             if i==3:
-                global_free_memory, total_memory = torch.cuda.mem_get_info()
-                avg_global_free_memory.append(global_free_memory / total_memory)
                 break
+
         avg_global_free_memory=sum(avg_global_free_memory)/len(avg_global_free_memory)
         print(f"Random Batch Loss: {loss}\tFree/Total GPU Memory: {avg_global_free_memory}\tBatch Size:{batch_size}")
         # () Stepping criterion
@@ -61,11 +64,8 @@ def find_good_batch_size(train_loader,ensemble_model, max_available_gpu_memory:f
                 print("Batch size equals to the training dataset size")
             else:
                 print(f"Max GPU memory used\tFree/Total GPU Memory:{avg_global_free_memory}")
-            return batch_size
-    raise RuntimeError("The computation should be here!")
-
 def forward_backward_update_loss(z:Tuple, ensemble_model):
     # () Get the i-th batch of data points.
     x_batch, y_batch = extract_input_outputs(z)
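
Editor's note (illustration, not part of the patch above): the second hunk in model_parallelism.py changes find_good_batch_size so that the free/total GPU-memory ratio is recorded after every sampled batch and then averaged, instead of being measured only once when i==3. The sketch below mirrors that sampling-and-averaging logic with an injected stand-in for the memory query so it runs without a GPU; average_free_memory_ratio, query_free_and_total_memory, and the fake readings are hypothetical names used only for illustration, while the patch itself calls torch.cuda.mem_get_info().

# Illustrative sketch, assuming the averaging behaviour described above.
def average_free_memory_ratio(batches, query_free_and_total_memory, num_samples: int = 4):
    """Average the free/total memory ratio observed after each of the first num_samples batches."""
    ratios = []
    for i, _batch in enumerate(batches):
        # In the patch this call is torch.cuda.mem_get_info(); here it is a stand-in.
        free, total = query_free_and_total_memory()
        ratios.append(free / total)
        if i == num_samples - 1:
            break
    return sum(ratios) / len(ratios)

if __name__ == "__main__":
    # Fake readings (bytes, 16 GB card): free memory shrinks as batches are processed.
    readings = iter([(8e9, 16e9), (7e9, 16e9), (6e9, 16e9), (5e9, 16e9)])
    avg = average_free_memory_ratio(range(10), lambda: next(readings))
    print(f"Free/Total GPU Memory: {avg:.3f}")  # prints 0.406

Averaging several readings presumably gives a steadier estimate of memory head-room than the single measurement the old code took at i==3; that averaged value is what the batch-size doubling criterion compares against max_available_gpu_memory.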