Skip to content

Commit

Permalink
Measure GPU usage as the average over the last three batches
Browse files (browse the repository at this point in the history)
Demirrr committed Nov 25, 2024
1 parent 4e518bf commit 4b1a876
Showing 1 changed file with 10 additions and 7 deletions.
17 changes: 10 additions & 7 deletions dicee/trainer/model_parallelism.py
Original file line number | Diff line number | Diff line change
@@ -23,7 +23,7 @@ def extract_input_outputs(z: list, device=None):
else:
raise ValueError('Unexpected batch shape..')

def find_good_batch_size(train_loader,ensemble_model,max_available_gpu_memory:float=0.1):
def find_good_batch_size(train_loader,ensemble_model, max_available_gpu_memory:float=0.1):
# () Initial batch size
batch_size=train_loader.batch_size
first_batch_size = train_loader.batch_size
@@ -43,21 +43,24 @@ def find_good_batch_size(train_loader,ensemble_model,max_available_gpu_memory:fl
worker_init_fn=None,
persistent_workers=False)
loss=None
avg_global_free_memory=[]
for i, z in enumerate(train_dataloaders):
loss = forward_backward_update_loss(z,ensemble_model)
break
global_free_memory, total_memory = torch.cuda.mem_get_info()
available_gpu_memory = global_free_memory / total_memory
print(f"Random Batch Loss: {loss}\tFree/Total GPU Memory: {available_gpu_memory}\tBatch Size:{batch_size}")
if i==3:
global_free_memory, total_memory = torch.cuda.mem_get_info()
avg_global_free_memory.append(global_free_memory / total_memory)
break
avg_global_free_memory=sum(avg_global_free_memory)/len(avg_global_free_memory)
print(f"Random Batch Loss: {loss}\tFree/Total GPU Memory: {avg_global_free_memory}\tBatch Size:{batch_size}")
# () Stepping criterion
if available_gpu_memory > max_available_gpu_memory and batch_size < len(train_loader.dataset) :
if avg_global_free_memory > max_available_gpu_memory and batch_size < len(train_loader.dataset) :
# Increment the current batch size
batch_size+=first_batch_size
else:
if batch_size >= len(train_loader.dataset):
print("Batch size equals to the training dataset size")
else:
print(f"Max GPU memory used\tFree/Total GPU Memory:{available_gpu_memory}")
print(f"Max GPU memory used\tFree/Total GPU Memory:{avg_global_free_memory}")

return batch_size

0 comments on commit 4b1a876

Please sign in to comment.