Exponential batch size increment is reduced to linear
Demirrr committed Nov 26, 2024
1 parent a6e15b7 commit 4e89b1d
Showing 2 changed files with 8 additions and 5 deletions.
dicee/models/ensemble.py: 3 additions & 0 deletions
@@ -82,6 +82,9 @@ def __call__(self,x_batch):
     def step(self):
         for opt in self.optimizers:
             opt.step()
+
+    def get_embeddings(self):
+        raise NotImplementedError("Not yet Implemented")
 
     """
     def __getattr__(self, name):
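For context, a minimal usage sketch of the optimizer fan-out shown above: one step() call applies the accumulated gradients on every ensemble member. The wrapper class, models, and names below are illustrative, not taken from the diff.

import torch

# Two hypothetical ensemble members sharing one training loop.
model_a = torch.nn.Linear(8, 1)
model_b = torch.nn.Linear(8, 1)

class FanOutOptimizer:
    # Same fan-out pattern as the step() method in the hunk above.
    def __init__(self, optimizers):
        self.optimizers = optimizers
    def step(self):
        for opt in self.optimizers:
            opt.step()
    def zero_grad(self):
        for opt in self.optimizers:
            opt.zero_grad()

opt = FanOutOptimizer([torch.optim.SGD(model_a.parameters(), lr=0.1),
                       torch.optim.SGD(model_b.parameters(), lr=0.1)])
x, y = torch.randn(4, 8), torch.randn(4, 1)
loss = ((model_a(x) - y) ** 2).mean() + ((model_b(x) - y) ** 2).mean()
opt.zero_grad()
loss.backward()
opt.step()  # one call steps every member optimizer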
dicee/trainer/model_parallelism.py: 5 additions & 5 deletions
@@ -26,6 +26,8 @@ def extract_input_outputs(z: list, device=None):
 def find_good_batch_size(train_loader,ensemble_model, max_available_gpu_memory:float=0.1):
     # () Initial batch size
     batch_size=train_loader.batch_size
+    if batch_size >= len(train_loader.dataset):
+        return batch_size
     first_batch_size = train_loader.batch_size
 
     print("Automatic batch size finding")
@@ -46,10 +48,11 @@ def find_good_batch_size(train_loader,ensemble_model, max_available_gpu_memory:f
         avg_global_free_memory=[]
         for i, z in enumerate(train_dataloaders):
             loss = forward_backward_update_loss(z,ensemble_model)
-            global_free_memory, total_memory = torch.cuda.mem_get_info()
-            avg_global_free_memory.append(global_free_memory / total_memory)
             if i==3:
+                global_free_memory, total_memory = torch.cuda.mem_get_info()
+                avg_global_free_memory.append(global_free_memory / total_memory)
                 break
+
         avg_global_free_memory=sum(avg_global_free_memory)/len(avg_global_free_memory)
         print(f"Random Batch Loss: {loss}\tFree/Total GPU Memory: {avg_global_free_memory}\tBatch Size:{batch_size}")
         # () Stepping criterion
@@ -61,11 +64,8 @@ def find_good_batch_size(train_loader,ensemble_model, max_available_gpu_memory:f
print("Batch size equals to the training dataset size")
else:
print(f"Max GPU memory used\tFree/Total GPU Memory:{avg_global_free_memory}")

return batch_size

raise RuntimeError("The computation should be here!")

def forward_backward_update_loss(z:Tuple, ensemble_model):
# () Get the i-th batch of data points.
x_batch, y_batch = extract_input_outputs(z)
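Taken together, find_good_batch_size probes increasing batch sizes until GPU memory is nearly used up. A minimal sketch of that loop follows, assuming the collapsed stepping line now grows the batch size by the initial batch size (linear) where it previously doubled it (exponential), as the commit message states. The sketch reuses the module's forward_backward_update_loss and torch.cuda.mem_get_info from the diff; the DataLoader construction, the _sketch suffix, and the exact pre-commit doubling are assumptions.

import torch

def find_good_batch_size_sketch(train_loader, ensemble_model,
                                max_available_gpu_memory: float = 0.1) -> int:
    # () Start from the batch size the user configured.
    batch_size = train_loader.batch_size
    num_datapoints = len(train_loader.dataset)
    if batch_size >= num_datapoints:
        return batch_size
    first_batch_size = train_loader.batch_size
    while True:
        # () Probe a few forward/backward passes at the candidate size;
        # the memory ratio is read once, on the fourth batch, matching
        # the `if i==3` early-exit branch introduced by this commit.
        loader = torch.utils.data.DataLoader(train_loader.dataset,
                                             batch_size=batch_size,
                                             shuffle=True)
        free_ratio = 0.0
        for i, z in enumerate(loader):
            loss = forward_backward_update_loss(z, ensemble_model)
            if i == 3:
                free, total = torch.cuda.mem_get_info()
                free_ratio = free / total
                break
        print(f"Random Batch Loss: {loss}\tFree/Total GPU Memory: {free_ratio}\tBatch Size: {batch_size}")
        # () Stepping criterion: enough free memory and room left to grow.
        if free_ratio > max_available_gpu_memory and batch_size < num_datapoints:
            batch_size += first_batch_size  # linear step; was batch_size *= 2 before this commit (assumed)
        else:
            return min(batch_size, num_datapoints)

A linear step trades a few extra probe iterations for a finer-grained final batch size: doubling can overshoot the largest feasible size by almost a factor of two, while stepping by first_batch_size lands within one initial batch of it.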
