Merge pull request #280 from dice-group/tensor_parallel
Tensor parallel
Demirrr authored Nov 28, 2024
2 parents 94ab305 + 44b9dbd commit 58aa98c
Showing 3 changed files with 112 additions and 68 deletions.
dicee/models/ensemble.py (2 changes: 1 addition & 1 deletion)
@@ -9,7 +9,7 @@ def __init__(self, seed_model):
for i in range(torch.cuda.device_count()):
i_model=copy.deepcopy(seed_model)
# TODO: Why can't we send the compiled model to the CPU?
i_model = torch.compile(i_model)
#i_model = torch.compile(i_model)
i_model.to(torch.device(f"cuda:{i}"))
self.optimizers.append(i_model.configure_optimizers())
self.models.append(i_model)
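For context, the constructor this hunk touches (EnsembleKGE.__init__, judging by the import in model_parallelism.py below) replicates one seed model per GPU. A simplified sketch of that pattern, with placeholder names and illustrative only:

import copy
import torch

class ReplicatedModels:
    """Simplified stand-in for the per-GPU replication in EnsembleKGE.__init__ (not the full class)."""
    def __init__(self, seed_model):
        self.models, self.optimizers = [], []
        for i in range(torch.cuda.device_count()):
            i_model = copy.deepcopy(seed_model)
            # torch.compile(i_model) stays disabled, mirroring the commit: per the TODO
            # in the hunk, a compiled model apparently could not be moved to another device.
            i_model.to(torch.device(f"cuda:{i}"))
            self.optimizers.append(i_model.configure_optimizers())
            self.models.append(i_model)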
dicee/sanity_checkers.py (20 changes: 10 additions & 10 deletions)
@@ -32,11 +32,11 @@ def validate_knowledge_graph(args):

elif args.path_single_kg is not None:
if args.sparql_endpoint is not None or args.path_single_kg is not None:
print(f'The dataset_dir and sparql_endpoint arguments '
f'must be None if path_single_kg is given.'
f'***{args.dataset_dir}***\n'
f'***{args.sparql_endpoint}***\n'
f'These two parameters are set to None.')
#print(f'The dataset_dir and sparql_endpoint arguments '
# f'must be None if path_single_kg is given.'
# f'***{args.dataset_dir}***\n'
# f'***{args.sparql_endpoint}***\n'
# f'These two parameters are set to None.')
args.dataset_dir = None
args.sparql_endpoint = None

@@ -61,11 +61,11 @@ def validate_knowledge_graph(args):
f"Use --path_single_kg **folder/dataset.format**, if you have a single file.")

if args.sparql_endpoint is not None or args.path_single_kg is not None:
print(f'The sparql_endpoint and path_single_kg arguments '
f'must be None if dataset_dir is given.'
f'***{args.sparql_endpoint}***\n'
f'***{args.path_single_kg}***\n'
f'These two parameters are set to None.')
#print(f'The sparql_endpoint and path_single_kg arguments '
# f'must be None if dataset_dir is given.'
# f'***{args.sparql_endpoint}***\n'
# f'***{args.path_single_kg}***\n'
# f'These two parameters are set to None.')
args.sparql_endpoint = None
args.path_single_kg = None

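Both hunks above only silence informational prints; the behaviour itself is unchanged: whichever input source is selected, the competing arguments are still reset to None. A condensed sketch of that logic (the function name is illustrative, and args stands for any object carrying these three attributes rather than the actual CLI namespace):

def drop_competing_input_sources(args):
    # Mirrors the two hunks above: the chosen source silently wins.
    if args.path_single_kg is not None:
        # A single KG file is used; dataset directory and SPARQL endpoint are dropped.
        args.dataset_dir = None
        args.sparql_endpoint = None
    elif args.dataset_dir is not None:
        # A dataset directory is used; SPARQL endpoint and single-KG path are dropped.
        args.sparql_endpoint = None
        args.path_single_kg = None
    return args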
dicee/trainer/model_parallelism.py (158 changes: 101 additions & 57 deletions)
@@ -3,6 +3,7 @@
from ..static_funcs_training import make_iterable_verbose
from ..models.ensemble import EnsembleKGE
from typing import Tuple
import time

def extract_input_outputs(z: list, device=None):
# pin arrays x,y, which allows us to move them to GPU asynchronously (non_blocking=True)
@@ -23,67 +24,99 @@ def extract_input_outputs(z: list, device=None):
else:
raise ValueError('Unexpected batch shape..')

def find_good_batch_size(train_loader,ensemble_model, max_available_gpu_memory:float=0.1):

def find_good_batch_size(train_loader,tp_ensemble_model):
# () Initial batch size
batch_size=train_loader.batch_size
if batch_size >= len(train_loader.dataset):
return batch_size
first_batch_size = train_loader.batch_size
num_datapoints=len(train_loader.dataset)
print(f"Increment the batch size by {first_batch_size} until the Free/Total GPU memory is reached to {1-max_available_gpu_memory} or batch_size={num_datapoints} is achieved.")
while True:
# () Initialize a dataloader with a current batch_size
train_dataloaders = torch.utils.data.DataLoader(train_loader.dataset,
batch_size=batch_size,
shuffle=True,
sampler=None,
batch_sampler=None,
num_workers=0,
collate_fn=train_loader.dataset.collate_fn,
pin_memory=False, drop_last=False,
timeout=0,
worker_init_fn=None,
persistent_workers=False)
loss=None
avg_global_free_memory=[]
for i, z in enumerate(train_dataloaders):
loss = forward_backward_update_loss(z,ensemble_model)
global_free_memory, total_memory = torch.cuda.mem_get_info()
avg_global_free_memory.append(global_free_memory / total_memory)
if i==3:
break
avg_global_free_memory=sum(avg_global_free_memory)/len(avg_global_free_memory)
print(f"Random Batch Loss: {loss}\tFree/Total GPU Memory: {avg_global_free_memory}\tBatch Size:{batch_size}")
if avg_global_free_memory > max_available_gpu_memory and batch_size < num_datapoints :
if batch_size+first_batch_size <= num_datapoints:
batch_size+=first_batch_size
else:
batch_size=num_datapoints
initial_batch_size=train_loader.batch_size
training_dataset_size=len(train_loader.dataset)
if initial_batch_size >= training_dataset_size:
return training_dataset_size, None
print("Number of training data points:",training_dataset_size)

def increase_batch_size_until_cuda_out_of_memory(ensemble_model, train_loader, batch_size,delta: int = None):
assert delta is not None, "delta must be positive integer"
batch_sizes_and_mem_usages = []
num_datapoints = len(train_loader.dataset)
try:
while True:
start_time=time.time()
# () Initialize a dataloader with a current batch_size
train_dataloaders = torch.utils.data.DataLoader(train_loader.dataset,
batch_size=batch_size,
shuffle=True,
sampler=None,
batch_sampler=None,
num_workers=train_loader.num_workers,
collate_fn=train_loader.dataset.collate_fn,
pin_memory=False,
drop_last=False,
timeout=0,
worker_init_fn=None,
persistent_workers=False)

batch_loss = None
for i, batch_of_training_data in enumerate(train_dataloaders):
batch_loss = forward_backward_update_loss(batch_of_training_data, ensemble_model)
break

global_free_memory, total_memory = torch.cuda.mem_get_info(device="cuda:0")
percentage_used_gpu_memory = (total_memory - global_free_memory) / total_memory
rt=time.time()-start_time
print(f"Random Batch Loss: {batch_loss:0.4}\tGPU Usage: {percentage_used_gpu_memory:0.3}\tRuntime: {rt:.3f}\tBatch Size: {batch_size}")

global_free_memory, total_memory = torch.cuda.mem_get_info(device="cuda:0")
percentage_used_gpu_memory = (total_memory - global_free_memory) / total_memory

# Store the batch size and the runtime
batch_sizes_and_mem_usages.append((batch_size, rt))

if batch_size < num_datapoints:
# Increase the batch size.
batch_size += int(batch_size / delta)
else:
return batch_sizes_and_mem_usages,True

except torch.OutOfMemoryError:
print("torch.OutOfMemoryError caught!")
return batch_sizes_and_mem_usages, False

history_batch_sizes_and_mem_usages=[]
batch_size=initial_batch_size

for delta in range(1,5,1):
result,flag= increase_batch_size_until_cuda_out_of_memory(tp_ensemble_model, train_loader, batch_size,delta=delta)

history_batch_sizes_and_mem_usages.extend(result)

if flag:
batch_size, batch_rt = history_batch_sizes_and_mem_usages[-1]
else:
# CUDA ERROR Observed
batch_size, batch_rt=history_batch_sizes_and_mem_usages[-2]

if batch_size>=training_dataset_size:
batch_size=training_dataset_size
break
else:
assert batch_size<=num_datapoints
if batch_size == num_datapoints:
print("Batch size equals to the training dataset size")
else:
print(f"Max GPU memory used\tFree/Total GPU Memory:{avg_global_free_memory}")
return batch_size

def forward_backward_update_loss(z:Tuple, ensemble_model):
# () Get the i-th batch of data points.
continue

return batch_size, batch_rt


def forward_backward_update_loss(z:Tuple, ensemble_model)->float:
# () Get a random batch of data points (z).
x_batch, y_batch = extract_input_outputs(z)
# () Move the batch of labels into the master GPU : GPU-0
# () Move the batch of labels into the master GPU : GPU-0.
y_batch = y_batch.to("cuda:0")
# () Forward Pass on the batch. Yhat located on the master GPU.
    # () Forward pass on the batch of input data points (yhat on the master GPU).
yhat = ensemble_model(x_batch)
# () Compute the loss
# () Compute the loss.
loss = torch.nn.functional.binary_cross_entropy_with_logits(yhat, y_batch)
# () Compute the gradient of the loss w.r.t. parameters.
loss.backward()
# () Parameter update.
ensemble_model.step()
# () Report the batch and epoch losses.
batch_loss = loss.item()
# () Accumulate batch loss
return batch_loss
return loss.item()

class TensorParallel(AbstractTrainer):
def __init__(self, args, callbacks):
@@ -103,12 +136,14 @@ def fit(self, *args, **kwargs):
self.on_fit_start(self, ensemble_model)
# () Sanity checking
assert torch.cuda.device_count()== len(ensemble_model)
# ()
# () Get DataLoader
train_dataloader = kwargs['train_dataloaders']
# ()
# () Find a batch size so that available GPU memory is *almost* fully used.
if self.attributes.auto_batch_finding:
batch_size, batch_rt=find_good_batch_size(train_dataloader, ensemble_model)

train_dataloader = torch.utils.data.DataLoader(train_dataloader.dataset,
batch_size=find_good_batch_size(train_dataloader, ensemble_model),
batch_size=batch_size,
shuffle=True,
sampler=None,
batch_sampler=None,
@@ -119,29 +154,38 @@
timeout=0,
worker_init_fn=None,
persistent_workers=False)
if batch_rt is not None:
expected_training_time=batch_rt * len(train_dataloader) * self.attributes.num_epochs
print(f"Exp.Training Runtime: {expected_training_time/60 :.3f} in mins\t|\tBatch Size:{batch_size}\t|\tBatch RT:{batch_rt:.3f}\t|\t # of batches:{len(train_dataloader)}\t|\t# of epochs:{self.attributes.num_epochs}")

# () Number of batches to reach a single epoch.
num_of_batches = len(train_dataloader)
# () Start training.
for epoch in (tqdm_bar := make_iterable_verbose(range(self.attributes.num_epochs),
verbose=True, position=0, leave=True)):
# () Accumulate the batch losses.
epoch_loss = 0
# () Iterate over batches.
for i, z in enumerate(train_dataloader):
# () Forward, Loss, Backward, and Update on a given batch of data points.
batch_loss = forward_backward_update_loss(z,ensemble_model)
# () Accumulate the batch losses to compute the epoch loss.
epoch_loss += batch_loss

                # If verbose=True, show info.
if hasattr(tqdm_bar, 'set_description_str'):
tqdm_bar.set_description_str(f"Epoch:{epoch + 1}")
if i > 0:
tqdm_bar.set_postfix_str(
f"batch={i} | {num_of_batches}, loss_step={batch_loss:.5f}, loss_epoch={epoch_loss / i:.5f}")
else:
tqdm_bar.set_postfix_str(f"loss_step={batch_loss:.5f}, loss_epoch={batch_loss:.5f}")
# Store the epoch loss
ensemble_model.loss_history.append(epoch_loss)

# Run on_fit_end callbacks after the training is done.
self.on_fit_end(self, ensemble_model)
        # TODO: Later, maybe we should write a callback to save the models to disk.
return ensemble_model

"""
def batchwisefit(self, *args, **kwargs):
@@ -242,4 +286,4 @@ def torch_buggy_fit(self, *args, **kwargs):
torch.distributed.destroy_process_group()
# () .
self.on_fit_end(self, model)
"""
"""
