Merge pull request #275 from dice-group/tensor_parallel
Linear batch size finding for Tensor Parallel Training
Demirrr authored Nov 26, 2024
2 parents abc5226 + 0417f19 commit f38bfa8
Showing 2 changed files with 39 additions and 21 deletions.
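The core change: `find_good_batch_size` now grows the trial batch size linearly, adding the initial batch size at each probe step, instead of doubling it as before. A minimal sketch of the two schedules, using made-up numbers (an initial batch size of 1024 and a 10,000-triple dataset) purely for illustration:

```python
# Illustrative only: compares the old doubling schedule with the new linear one.
first_batch_size, dataset_size = 1024, 10_000  # hypothetical values

doubling, linear = [first_batch_size], [first_batch_size]
while doubling[-1] < dataset_size:
    doubling.append(doubling[-1] * 2)              # old: batch_size += batch_size
while linear[-1] < dataset_size:
    linear.append(linear[-1] + first_batch_size)   # new: batch_size += first_batch_size

print(doubling)  # [1024, 2048, 4096, 8192, 16384] -> overshoots quickly
print(linear)    # [1024, 2048, 3072, ..., 10240]  -> finer-grained steps
```

The linear steps give finer control near the memory limit at the cost of more probe iterations; the actual stopping logic (GPU memory headroom and dataset size) lives in `dicee/trainer/model_parallelism.py` below.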
31 changes: 22 additions & 9 deletions dicee/models/ensemble.py
@@ -1,25 +1,20 @@
import torch
import copy

import torch._dynamo

torch._dynamo.config.suppress_errors = True


class EnsembleKGE:
    def __init__(self, seed_model):
        self.models = []
        self.optimizers = []
        self.loss_history = []
        for i in range(torch.cuda.device_count()):
            i_model=copy.deepcopy(seed_model)
            i_model.to(torch.device(f"cuda:{i}"))
            # TODO: Why we cant send the compile model to cpu ?
            # i_model = torch.compile(i_model)
            i_model = torch.compile(i_model)
            i_model.to(torch.device(f"cuda:{i}"))
            self.optimizers.append(i_model.configure_optimizers())
            self.models.append(i_model)
        # Maybe use the original model's name ?
        self.name="TP_"+self.models[0].name
        self.name=self.models[0].name
        self.train_mode=True

    def named_children(self):
@@ -87,7 +82,25 @@ def __call__(self,x_batch):
    def step(self):
        for opt in self.optimizers:
            opt.step()


    def get_embeddings(self):
        entity_embeddings=[]
        relation_embeddings=[]
        # () Iterate
        for trained_model in self.models:
            entity_emb, relation_ebm = trained_model.get_embeddings()
            entity_embeddings.append(entity_emb)
            if relation_ebm is not None:
                relation_embeddings.append(relation_ebm)
        # () Concat the embedding vectors horizontally.
        entity_embeddings=torch.cat(entity_embeddings,dim=1)
        if relation_embeddings:
            relation_embeddings=torch.cat(relation_embeddings,dim=1)
        else:
            relation_embeddings=None

        return entity_embeddings, relation_embeddings

"""
def __getattr__(self, name):
# Create a function that will call the same attribute/method on each model
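The new `get_embeddings` collects one embedding matrix from each GPU replica and stacks them side by side along the feature dimension. A minimal sketch of that concatenation with hypothetical tensors standing in for the per-model results (shapes and values are illustrative, not taken from dicee):

```python
import torch

# Hypothetical per-replica entity embeddings: 3 GPUs, 5 entities, 4 dimensions each.
per_gpu_entity_emb = [torch.randn(5, 4) for _ in range(3)]

# Mirrors the dim=1 concatenation in EnsembleKGE.get_embeddings:
# every entity row ends up with 3 * 4 = 12 concatenated dimensions.
entity_embeddings = torch.cat(per_gpu_entity_emb, dim=1)
print(entity_embeddings.shape)  # torch.Size([5, 12])
```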
29 changes: 17 additions & 12 deletions dicee/trainer/model_parallelism.py
@@ -23,11 +23,15 @@ def extract_input_outputs(z: list, device=None):
    else:
        raise ValueError('Unexpected batch shape..')

def find_good_batch_size(train_loader,ensemble_model,max_available_gpu_memory:float=0.05):
def find_good_batch_size(train_loader,ensemble_model, max_available_gpu_memory:float=0.1):
    # () Initial batch size
    batch_size=train_loader.batch_size
    if batch_size >= len(train_loader.dataset):
        return batch_size
    first_batch_size = train_loader.batch_size

    print("Automatic batch size finding")
    for n in range(200):
    while True:
        # () Initialize a dataloader with a current batch_size
        train_dataloaders = torch.utils.data.DataLoader(train_loader.dataset,
                                                        batch_size=batch_size,
@@ -41,26 +45,27 @@ def find_good_batch_size(train_loader,ensemble_model,max_available_gpu_memory:fl
                                                        worker_init_fn=None,
                                                        persistent_workers=False)
        loss=None
        avg_global_free_memory=[]
        for i, z in enumerate(train_dataloaders):
            loss = forward_backward_update_loss(z,ensemble_model)
            break
        global_free_memory, total_memory = torch.cuda.mem_get_info()
        available_gpu_memory = global_free_memory / total_memory
        print(f"Random Batch Loss: {loss}\tFree/Total GPU Memory: {available_gpu_memory}\tBatch Size:{batch_size}")
            global_free_memory, total_memory = torch.cuda.mem_get_info()
            avg_global_free_memory.append(global_free_memory / total_memory)
            if i==3:
                break

        avg_global_free_memory=sum(avg_global_free_memory)/len(avg_global_free_memory)
        print(f"Random Batch Loss: {loss}\tFree/Total GPU Memory: {avg_global_free_memory}\tBatch Size:{batch_size}")
        # () Stepping criterion
        if available_gpu_memory > max_available_gpu_memory and batch_size < len(train_loader.dataset) :
        if avg_global_free_memory > max_available_gpu_memory and batch_size < len(train_loader.dataset) :
            # Increment the current batch size
            batch_size+=batch_size
            batch_size+=first_batch_size
        else:
            if batch_size >= len(train_loader.dataset):
                print("Batch size equals to the training dataset size")
            else:
                print(f"Max GPU memory used\tFree/Total GPU Memory:{available_gpu_memory}")

                print(f"Max GPU memory used\tFree/Total GPU Memory:{avg_global_free_memory}")
            return batch_size

    raise RuntimeError("The computation should be here!")

def forward_backward_update_loss(z:Tuple, ensemble_model):
    # () Get the i-th batch of data points.
    x_batch, y_batch = extract_input_outputs(z)
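The stepping criterion now averages the free/total GPU memory ratio over up to four probe batches instead of reading it once. A minimal sketch of that check, assuming a CUDA device is available; the helper `free_memory_fraction` and the surrounding driver code are illustrative, not part of the commit:

```python
import torch

def free_memory_fraction() -> float:
    # torch.cuda.mem_get_info() returns (free_bytes, total_bytes) for the current device.
    free_bytes, total_bytes = torch.cuda.mem_get_info()
    return free_bytes / total_bytes

max_available_gpu_memory = 0.1          # default threshold in the new signature
fractions = [free_memory_fraction() for _ in range(4)]  # one reading per probed batch
if sum(fractions) / len(fractions) > max_available_gpu_memory:
    print("Enough headroom: grow batch_size by first_batch_size and probe again")
else:
    print("Stop: averaged free GPU memory is at or below the threshold")
```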
