From d3081a111f9352c766e951dcadd0f6acb3ad4260 Mon Sep 17 00:00:00 2001 From: Caglar Demir Date: Thu, 28 Nov 2024 08:11:01 +0000 Subject: [PATCH 1/4] compile is removed and avg is reduced to single in gpu usage signal --- dicee/models/ensemble.py | 2 +- dicee/trainer/model_parallelism.py | 17 ++++++++--------- 2 files changed, 9 insertions(+), 10 deletions(-) diff --git a/dicee/models/ensemble.py b/dicee/models/ensemble.py index 34f41059..2cf4774e 100644 --- a/dicee/models/ensemble.py +++ b/dicee/models/ensemble.py @@ -9,7 +9,7 @@ def __init__(self, seed_model): for i in range(torch.cuda.device_count()): i_model=copy.deepcopy(seed_model) # TODO: Why we cant send the compile model to cpu ? - i_model = torch.compile(i_model) + #i_model = torch.compile(i_model) i_model.to(torch.device(f"cuda:{i}")) self.optimizers.append(i_model.configure_optimizers()) self.models.append(i_model) diff --git a/dicee/trainer/model_parallelism.py b/dicee/trainer/model_parallelism.py index 9b10941e..5626ecb2 100644 --- a/dicee/trainer/model_parallelism.py +++ b/dicee/trainer/model_parallelism.py @@ -23,7 +23,7 @@ def extract_input_outputs(z: list, device=None): else: raise ValueError('Unexpected batch shape..') -def find_good_batch_size(train_loader,ensemble_model, max_available_gpu_memory:float=0.1): +def find_good_batch_size(train_loader,ensemble_model, max_available_gpu_memory:float=0.20): # () Initial batch size batch_size=train_loader.batch_size if batch_size >= len(train_loader.dataset): @@ -49,18 +49,17 @@ def find_good_batch_size(train_loader,ensemble_model, max_available_gpu_memory:f for i, z in enumerate(train_dataloaders): loss = forward_backward_update_loss(z,ensemble_model) global_free_memory, total_memory = torch.cuda.mem_get_info() - avg_global_free_memory.append(global_free_memory / total_memory) - if i==3: - break - avg_global_free_memory=sum(avg_global_free_memory)/len(avg_global_free_memory) + break + + avg_global_free_memory= global_free_memory / total_memory + print(f"Random Batch Loss: {loss}\tFree/Total GPU Memory: {avg_global_free_memory}\tBatch Size:{batch_size}") if avg_global_free_memory > max_available_gpu_memory and batch_size < num_datapoints : - if batch_size+first_batch_size <= num_datapoints: - batch_size+=first_batch_size + if batch_size <= num_datapoints: + batch_size+=batch_size else: batch_size=num_datapoints else: - assert batch_size<=num_datapoints if batch_size == num_datapoints: print("Batch size equals to the training dataset size") else: @@ -242,4 +241,4 @@ def torch_buggy_fit(self, *args, **kwargs): torch.distributed.destroy_process_group() # () . 
self.on_fit_end(self, model) - """ \ No newline at end of file + """ From 5551241dc67808afebddd3b9cfbf2ae675640778 Mon Sep 17 00:00:00 2001 From: Caglar Demir Date: Thu, 28 Nov 2024 10:29:56 +0100 Subject: [PATCH 2/4] Improved batch finding in TP --- dicee/trainer/model_parallelism.py | 115 ++++++++++++++++------------- 1 file changed, 64 insertions(+), 51 deletions(-) diff --git a/dicee/trainer/model_parallelism.py b/dicee/trainer/model_parallelism.py index 5626ecb2..330ed231 100644 --- a/dicee/trainer/model_parallelism.py +++ b/dicee/trainer/model_parallelism.py @@ -23,66 +23,79 @@ def extract_input_outputs(z: list, device=None): else: raise ValueError('Unexpected batch shape..') -def find_good_batch_size(train_loader,ensemble_model, max_available_gpu_memory:float=0.20): + +def find_good_batch_size(train_loader,tp_ensemble_model): # () Initial batch size - batch_size=train_loader.batch_size - if batch_size >= len(train_loader.dataset): - return batch_size - first_batch_size = train_loader.batch_size - num_datapoints=len(train_loader.dataset) - print(f"Increment the batch size by {first_batch_size} until the Free/Total GPU memory is reached to {1-max_available_gpu_memory} or batch_size={num_datapoints} is achieved.") - while True: - # () Initialize a dataloader with a current batch_size - train_dataloaders = torch.utils.data.DataLoader(train_loader.dataset, - batch_size=batch_size, - shuffle=True, - sampler=None, - batch_sampler=None, - num_workers=0, - collate_fn=train_loader.dataset.collate_fn, - pin_memory=False, drop_last=False, - timeout=0, - worker_init_fn=None, - persistent_workers=False) - loss=None - avg_global_free_memory=[] - for i, z in enumerate(train_dataloaders): - loss = forward_backward_update_loss(z,ensemble_model) - global_free_memory, total_memory = torch.cuda.mem_get_info() - break - - avg_global_free_memory= global_free_memory / total_memory - - print(f"Random Batch Loss: {loss}\tFree/Total GPU Memory: {avg_global_free_memory}\tBatch Size:{batch_size}") - if avg_global_free_memory > max_available_gpu_memory and batch_size < num_datapoints : - if batch_size <= num_datapoints: - batch_size+=batch_size - else: - batch_size=num_datapoints - else: - if batch_size == num_datapoints: - print("Batch size equals to the training dataset size") - else: - print(f"Max GPU memory used\tFree/Total GPU Memory:{avg_global_free_memory}") - return batch_size - -def forward_backward_update_loss(z:Tuple, ensemble_model): - # () Get the i-th batch of data points. 
+ initial_batch_size=train_loader.batch_size + if initial_batch_size >= len(train_loader.dataset): + return initial_batch_size + + def increase_batch_size_until_cuda_out_of_memory(ensemble_model, train_loader, batch_size, delta: int = None): + assert delta is not None, "delta must be positive integer" + batch_sizes_and_mem_usages = [] + num_datapoints = len(train_loader.dataset) + try: + while True: + # () Initialize a dataloader with a current batch_size + train_dataloaders = torch.utils.data.DataLoader(train_loader.dataset, + batch_size=batch_size, + shuffle=True, + sampler=None, + batch_sampler=None, + num_workers=0, + collate_fn=train_loader.dataset.collate_fn, + pin_memory=False, + drop_last=False, + timeout=0, + worker_init_fn=None, + persistent_workers=False) + batch_loss = None + for i, batch_of_training_data in enumerate(train_dataloaders): + batch_loss = forward_backward_update_loss(batch_of_training_data, ensemble_model) + break + global_free_memory, total_memory = torch.cuda.mem_get_info(device="cuda:0") + percentage_used_gpu_memory = (total_memory - global_free_memory) / total_memory + + print( + f"Random Batch Loss: {batch_loss}\tGPU Usage: {percentage_used_gpu_memory}\tBatch Size:{batch_size}") + + global_free_memory, total_memory = torch.cuda.mem_get_info(device="cuda:0") + percentage_used_gpu_memory = (total_memory - global_free_memory) / total_memory + + batch_sizes_and_mem_usages.append((batch_size, percentage_used_gpu_memory)) + if batch_size < num_datapoints: + batch_size += int(batch_size / delta) + else: + if batch_size == num_datapoints: + print("Batch size equals to the training dataset size") + break + except torch.OutOfMemoryError: + print(f"torch.OutOfMemoryError caught!") + return batch_sizes_and_mem_usages + + history_batch_sizes_and_mem_usages=[] + batch_size=initial_batch_size + for delta in range(1,10,1): + history_batch_sizes_and_mem_usages.extend(increase_batch_size_until_cuda_out_of_memory(tp_ensemble_model, train_loader, batch_size,delta=delta)) + batch_size=history_batch_sizes_and_mem_usages[-2][0] + print(f"A best found batch size:{batch_size} in {len(history_batch_sizes_and_mem_usages)} trials. Current GPU memory usage % :{history_batch_sizes_and_mem_usages[-2][1]}") + return batch_size + + +def forward_backward_update_loss(z:Tuple, ensemble_model)->float: + # () Get a random batch of data points (z). x_batch, y_batch = extract_input_outputs(z) - # () Move the batch of labels into the master GPU : GPU-0 + # () Move the batch of labels into the master GPU : GPU-0. y_batch = y_batch.to("cuda:0") - # () Forward Pass on the batch. Yhat located on the master GPU. + # () Forward pas on the batch of input data points (yhat on the master GPU). yhat = ensemble_model(x_batch) - # () Compute the loss + # () Compute the loss. loss = torch.nn.functional.binary_cross_entropy_with_logits(yhat, y_batch) # () Compute the gradient of the loss w.r.t. parameters. loss.backward() # () Parameter update. ensemble_model.step() - # () Report the batch and epoch losses. 
- batch_loss = loss.item() - # () Accumulate batch loss - return batch_loss + return loss.item() class TensorParallel(AbstractTrainer): def __init__(self, args, callbacks): From fbf30ab763cc8816737b2718a0bf6069649dbb79 Mon Sep 17 00:00:00 2001 From: Caglar Demir Date: Thu, 28 Nov 2024 13:59:23 +0000 Subject: [PATCH 3/4] WIP: Reducing the runtime of finding a good search & removing redandant log infos --- dicee/sanity_checkers.py | 20 ++++---- dicee/trainer/model_parallelism.py | 78 +++++++++++++++++++++--------- 2 files changed, 65 insertions(+), 33 deletions(-) diff --git a/dicee/sanity_checkers.py b/dicee/sanity_checkers.py index 70dc94b2..1e1463aa 100644 --- a/dicee/sanity_checkers.py +++ b/dicee/sanity_checkers.py @@ -32,11 +32,11 @@ def validate_knowledge_graph(args): elif args.path_single_kg is not None: if args.sparql_endpoint is not None or args.path_single_kg is not None: - print(f'The dataset_dir and sparql_endpoint arguments ' - f'must be None if path_single_kg is given.' - f'***{args.dataset_dir}***\n' - f'***{args.sparql_endpoint}***\n' - f'These two parameters are set to None.') + #print(f'The dataset_dir and sparql_endpoint arguments ' + # f'must be None if path_single_kg is given.' + # f'***{args.dataset_dir}***\n' + # f'***{args.sparql_endpoint}***\n' + # f'These two parameters are set to None.') args.dataset_dir = None args.sparql_endpoint = None @@ -61,11 +61,11 @@ def validate_knowledge_graph(args): f"Use --path_single_kg **folder/dataset.format**, if you have a single file.") if args.sparql_endpoint is not None or args.path_single_kg is not None: - print(f'The sparql_endpoint and path_single_kg arguments ' - f'must be None if dataset_dir is given.' - f'***{args.sparql_endpoint}***\n' - f'***{args.path_single_kg}***\n' - f'These two parameters are set to None.') + #print(f'The sparql_endpoint and path_single_kg arguments ' + # f'must be None if dataset_dir is given.' 
+ # f'***{args.sparql_endpoint}***\n' + # f'***{args.path_single_kg}***\n' + # f'These two parameters are set to None.') args.sparql_endpoint = None args.path_single_kg = None diff --git a/dicee/trainer/model_parallelism.py b/dicee/trainer/model_parallelism.py index 330ed231..bac1b9f6 100644 --- a/dicee/trainer/model_parallelism.py +++ b/dicee/trainer/model_parallelism.py @@ -3,6 +3,7 @@ from ..static_funcs_training import make_iterable_verbose from ..models.ensemble import EnsembleKGE from typing import Tuple +import time def extract_input_outputs(z: list, device=None): # pin arrays x,y, which allows us to move them to GPU asynchronously (non_blocking=True) @@ -27,59 +28,79 @@ def extract_input_outputs(z: list, device=None): def find_good_batch_size(train_loader,tp_ensemble_model): # () Initial batch size initial_batch_size=train_loader.batch_size - if initial_batch_size >= len(train_loader.dataset): - return initial_batch_size + training_dataset_size=len(train_loader.dataset) + if initial_batch_size >= training_dataset_size: + return training_dataset_size, None + print("Number of training data points:",training_dataset_size) - def increase_batch_size_until_cuda_out_of_memory(ensemble_model, train_loader, batch_size, delta: int = None): + def increase_batch_size_until_cuda_out_of_memory(ensemble_model, train_loader, batch_size,delta: int = None): assert delta is not None, "delta must be positive integer" batch_sizes_and_mem_usages = [] num_datapoints = len(train_loader.dataset) try: while True: + start_time=time.time() # () Initialize a dataloader with a current batch_size train_dataloaders = torch.utils.data.DataLoader(train_loader.dataset, batch_size=batch_size, shuffle=True, sampler=None, batch_sampler=None, - num_workers=0, + num_workers=train_loader.num_workers, collate_fn=train_loader.dataset.collate_fn, pin_memory=False, drop_last=False, timeout=0, worker_init_fn=None, persistent_workers=False) + batch_loss = None for i, batch_of_training_data in enumerate(train_dataloaders): batch_loss = forward_backward_update_loss(batch_of_training_data, ensemble_model) break + global_free_memory, total_memory = torch.cuda.mem_get_info(device="cuda:0") percentage_used_gpu_memory = (total_memory - global_free_memory) / total_memory - - print( - f"Random Batch Loss: {batch_loss}\tGPU Usage: {percentage_used_gpu_memory}\tBatch Size:{batch_size}") + rt=time.time()-start_time + print(f"Random Batch Loss: {batch_loss:0.4}\tGPU Usage: {percentage_used_gpu_memory:0.3}\tRuntime: {rt:.3f}\tBatch Size: {batch_size}") global_free_memory, total_memory = torch.cuda.mem_get_info(device="cuda:0") percentage_used_gpu_memory = (total_memory - global_free_memory) / total_memory - - batch_sizes_and_mem_usages.append((batch_size, percentage_used_gpu_memory)) + + # Store the batch size and the runtime + batch_sizes_and_mem_usages.append((batch_size, rt)) + if batch_size < num_datapoints: + # Increase the batch size. 
batch_size += int(batch_size / delta) else: - if batch_size == num_datapoints: - print("Batch size equals to the training dataset size") - break + return batch_sizes_and_mem_usages,True + except torch.OutOfMemoryError: print(f"torch.OutOfMemoryError caught!") - return batch_sizes_and_mem_usages + return batch_sizes_and_mem_usages, False history_batch_sizes_and_mem_usages=[] batch_size=initial_batch_size - for delta in range(1,10,1): - history_batch_sizes_and_mem_usages.extend(increase_batch_size_until_cuda_out_of_memory(tp_ensemble_model, train_loader, batch_size,delta=delta)) - batch_size=history_batch_sizes_and_mem_usages[-2][0] - print(f"A best found batch size:{batch_size} in {len(history_batch_sizes_and_mem_usages)} trials. Current GPU memory usage % :{history_batch_sizes_and_mem_usages[-2][1]}") - return batch_size + + for delta in range(1,5,1): + result,flag= increase_batch_size_until_cuda_out_of_memory(tp_ensemble_model, train_loader, batch_size,delta=delta) + + history_batch_sizes_and_mem_usages.extend(result) + + if flag: + batch_size, batch_rt = history_batch_sizes_and_mem_usages[-1] + else: + # CUDA ERROR Observed + batch_size, batch_rt=history_batch_sizes_and_mem_usages[-2] + + if batch_size>=training_dataset_size: + batch_size=training_dataset_size + break + else: + continue + + return batch_size, batch_rt def forward_backward_update_loss(z:Tuple, ensemble_model)->float: @@ -115,12 +136,14 @@ def fit(self, *args, **kwargs): self.on_fit_start(self, ensemble_model) # () Sanity checking assert torch.cuda.device_count()== len(ensemble_model) - # () + # () Get DataLoader train_dataloader = kwargs['train_dataloaders'] - # () + # () Find a batch size so that available GPU memory is *almost* fully used. if self.attributes.auto_batch_finding: + batch_size, batch_rt=find_good_batch_size(train_dataloader, ensemble_model) + train_dataloader = torch.utils.data.DataLoader(train_dataloader.dataset, - batch_size=find_good_batch_size(train_dataloader, ensemble_model), + batch_size=batch_size, shuffle=True, sampler=None, batch_sampler=None, @@ -131,17 +154,24 @@ def fit(self, *args, **kwargs): timeout=0, worker_init_fn=None, persistent_workers=False) + if batch_rt is not None: + expected_training_time=batch_rt * len(train_dataloader) * self.attributes.num_epochs + print(f"Exp.Training Runtime: {expected_training_time/60 :.3f} in mins\t|\tBatch Size:{batch_size}\t|\tBatch RT:{batch_rt:.3f}\t|\t # of batches:{len(train_dataloader)}\t|\t# of epochs:{self.attributes.num_epochs}") + # () Number of batches to reach a single epoch. num_of_batches = len(train_dataloader) # () Start training. for epoch in (tqdm_bar := make_iterable_verbose(range(self.attributes.num_epochs), verbose=True, position=0, leave=True)): + # () Accumulate the batch losses. epoch_loss = 0 # () Iterate over batches. for i, z in enumerate(train_dataloader): + # () Forward, Loss, Backward, and Update on a given batch of data points. batch_loss = forward_backward_update_loss(z,ensemble_model) + # () Accumulate the batch losses to compute the epoch loss. epoch_loss += batch_loss - + # if verbose=TRue, show info. 
if hasattr(tqdm_bar, 'set_description_str'): tqdm_bar.set_description_str(f"Epoch:{epoch + 1}") if i > 0: @@ -149,11 +179,13 @@ def fit(self, *args, **kwargs): f"batch={i} | {num_of_batches}, loss_step={batch_loss:.5f}, loss_epoch={epoch_loss / i:.5f}") else: tqdm_bar.set_postfix_str(f"loss_step={batch_loss:.5f}, loss_epoch={batch_loss:.5f}") + # Store the epoch loss ensemble_model.loss_history.append(epoch_loss) - + # Run on_fit_end callbacks after the training is done. self.on_fit_end(self, ensemble_model) # TODO: Later, maybe we should write a callback to save the models in disk return ensemble_model + """ def batchwisefit(self, *args, **kwargs): From 44b9dbd94c8aa4d44cd71661f317ea0a950db868 Mon Sep 17 00:00:00 2001 From: Caglar Demir Date: Thu, 28 Nov 2024 15:00:52 +0100 Subject: [PATCH 4/4] fstring usage without any placeholder fixed --- dicee/trainer/model_parallelism.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/dicee/trainer/model_parallelism.py b/dicee/trainer/model_parallelism.py index bac1b9f6..dd4a2929 100644 --- a/dicee/trainer/model_parallelism.py +++ b/dicee/trainer/model_parallelism.py @@ -77,7 +77,7 @@ def increase_batch_size_until_cuda_out_of_memory(ensemble_model, train_loader, b return batch_sizes_and_mem_usages,True except torch.OutOfMemoryError: - print(f"torch.OutOfMemoryError caught!") + print("torch.OutOfMemoryError caught!") return batch_sizes_and_mem_usages, False history_batch_sizes_and_mem_usages=[]
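
Note on the approach (illustration only, not part of the patches above): the batch-size search added in PATCH 2/4 and reworked in PATCH 3/4 reduces to one loop: run a single forward/backward step per candidate batch size, record the step runtime and the GPU memory usage reported by torch.cuda.mem_get_info, grow the batch size roughly geometrically until it covers the whole training set or CUDA raises torch.OutOfMemoryError, and fall back to a batch size that still fit. The sketch below mirrors that loop under stated assumptions: `one_step`, `model`, and `dataset` are hypothetical stand-ins rather than dicee APIs, the model is assumed to already live on "cuda:0", and torch.OutOfMemoryError needs a PyTorch build that exposes that alias, as the patches themselves assume.

import time
import torch
from torch.utils.data import DataLoader

def one_step(model, batch):
    # Hypothetical probe step: forward pass, loss, backward pass (no optimizer update).
    x, y = batch
    yhat = model(x.to("cuda:0"))
    loss = torch.nn.functional.binary_cross_entropy_with_logits(yhat, y.to("cuda:0"))
    loss.backward()
    return loss.item()

def find_batch_size(dataset, model, initial_batch_size=1024, delta=1):
    # Returns (batch_size, step runtime in seconds) for the chosen batch size.
    trials = []  # every (batch_size, step runtime) that completed without OOM
    batch_size = initial_batch_size
    n = len(dataset)
    try:
        while True:
            start = time.time()
            loader = DataLoader(dataset, batch_size=batch_size, shuffle=True)
            loss = one_step(model, next(iter(loader)))
            free, total = torch.cuda.mem_get_info(device="cuda:0")
            runtime = time.time() - start
            trials.append((batch_size, runtime))
            print(f"loss={loss:.4f}\tgpu_used={(total - free) / total:.2%}\t"
                  f"runtime={runtime:.3f}s\tbatch_size={batch_size}")
            if batch_size >= n:
                return n, runtime  # a single batch already covers the training set
            # delta=1 doubles the batch size, matching the patch's first sweep.
            batch_size += max(1, int(batch_size / delta))
    except torch.OutOfMemoryError:
        if not trials:
            raise  # even the initial batch size does not fit on the GPU
        # PATCH 3/4 backs off one extra entry ([-2]); here we keep the last size that fit.
        return trials[-1]

The same measurement feeds the expected-runtime print in PATCH 3/4: with a per-step runtime batch_rt, the estimate is batch_rt * number_of_batches * num_epochs, e.g.

# best_size, step_rt = find_batch_size(train_dataset, model)   # hypothetical objects
# num_batches = math.ceil(len(train_dataset) / best_size)
# print(f"expected training time ~ {step_rt * num_batches * num_epochs / 60:.1f} min")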