Commit
WIP: Reducing the runtime of finding a good batch size & removing redundant log infos

Demirrr committed Nov 28, 2024
1 parent 5551241 commit fbf30ab
Showing 2 changed files with 65 additions and 33 deletions.
20 changes: 10 additions & 10 deletions dicee/sanity_checkers.py
@@ -32,11 +32,11 @@ def validate_knowledge_graph(args):

elif args.path_single_kg is not None:
if args.sparql_endpoint is not None or args.path_single_kg is not None:
print(f'The dataset_dir and sparql_endpoint arguments '
f'must be None if path_single_kg is given.'
f'***{args.dataset_dir}***\n'
f'***{args.sparql_endpoint}***\n'
f'These two parameters are set to None.')
#print(f'The dataset_dir and sparql_endpoint arguments '
# f'must be None if path_single_kg is given.'
# f'***{args.dataset_dir}***\n'
# f'***{args.sparql_endpoint}***\n'
# f'These two parameters are set to None.')
args.dataset_dir = None
args.sparql_endpoint = None

@@ -61,11 +61,11 @@ def validate_knowledge_graph(args):
f"Use --path_single_kg **folder/dataset.format**, if you have a single file.")

if args.sparql_endpoint is not None or args.path_single_kg is not None:
print(f'The sparql_endpoint and path_single_kg arguments '
f'must be None if dataset_dir is given.'
f'***{args.sparql_endpoint}***\n'
f'***{args.path_single_kg}***\n'
f'These two parameters are set to None.')
#print(f'The sparql_endpoint and path_single_kg arguments '
# f'must be None if dataset_dir is given.'
# f'***{args.sparql_endpoint}***\n'
# f'***{args.path_single_kg}***\n'
# f'These two parameters are set to None.')
args.sparql_endpoint = None
args.path_single_kg = None

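The two hunks above only silence the warnings: when path_single_kg is given, dataset_dir and sparql_endpoint are still reset to None (and the dataset_dir branch resets the other two sources in the same way), just without printing anything. A minimal sketch of that precedence, assuming an argparse-style namespace with the attribute names from the diff and that the path_single_kg branch shown above is the one taken; any other fields the full validator may touch are omitted:

from types import SimpleNamespace

from dicee.sanity_checkers import validate_knowledge_graph

# Hypothetical arguments; only the three attributes from the diff are set.
args = SimpleNamespace(
    dataset_dir="KGs/UMLS",               # conflicting source, expected to be dropped
    sparql_endpoint=None,
    path_single_kg="KGs/UMLS/train.txt",  # takes precedence in the branch above
)

validate_knowledge_graph(args)

# After this commit the reset happens silently (the explanatory print is commented out):
# args.dataset_dir -> None, args.sparql_endpoint -> None, args.path_single_kg unchanged.
print(args.dataset_dir, args.sparql_endpoint, args.path_single_kg)
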
78 changes: 55 additions & 23 deletions dicee/trainer/model_parallelism.py
@@ -3,6 +3,7 @@
from ..static_funcs_training import make_iterable_verbose
from ..models.ensemble import EnsembleKGE
from typing import Tuple
import time

def extract_input_outputs(z: list, device=None):
# pin arrays x,y, which allows us to move them to GPU asynchronously (non_blocking=True)
@@ -27,59 +28,79 @@ def extract_input_outputs(z: list, device=None):
def find_good_batch_size(train_loader,tp_ensemble_model):
# () Initial batch size
initial_batch_size=train_loader.batch_size
if initial_batch_size >= len(train_loader.dataset):
return initial_batch_size
training_dataset_size=len(train_loader.dataset)
if initial_batch_size >= training_dataset_size:
return training_dataset_size, None
print("Number of training data points:",training_dataset_size)

def increase_batch_size_until_cuda_out_of_memory(ensemble_model, train_loader, batch_size, delta: int = None):
def increase_batch_size_until_cuda_out_of_memory(ensemble_model, train_loader, batch_size,delta: int = None):
assert delta is not None, "delta must be positive integer"
batch_sizes_and_mem_usages = []
num_datapoints = len(train_loader.dataset)
try:
while True:
start_time=time.time()
# () Initialize a dataloader with a current batch_size
train_dataloaders = torch.utils.data.DataLoader(train_loader.dataset,
batch_size=batch_size,
shuffle=True,
sampler=None,
batch_sampler=None,
num_workers=0,
num_workers=train_loader.num_workers,
collate_fn=train_loader.dataset.collate_fn,
pin_memory=False,
drop_last=False,
timeout=0,
worker_init_fn=None,
persistent_workers=False)

batch_loss = None
for i, batch_of_training_data in enumerate(train_dataloaders):
batch_loss = forward_backward_update_loss(batch_of_training_data, ensemble_model)
break

global_free_memory, total_memory = torch.cuda.mem_get_info(device="cuda:0")
percentage_used_gpu_memory = (total_memory - global_free_memory) / total_memory

print(
f"Random Batch Loss: {batch_loss}\tGPU Usage: {percentage_used_gpu_memory}\tBatch Size:{batch_size}")
rt=time.time()-start_time
print(f"Random Batch Loss: {batch_loss:0.4}\tGPU Usage: {percentage_used_gpu_memory:0.3}\tRuntime: {rt:.3f}\tBatch Size: {batch_size}")

global_free_memory, total_memory = torch.cuda.mem_get_info(device="cuda:0")
percentage_used_gpu_memory = (total_memory - global_free_memory) / total_memory

batch_sizes_and_mem_usages.append((batch_size, percentage_used_gpu_memory))

# Store the batch size and the runtime
batch_sizes_and_mem_usages.append((batch_size, rt))

if batch_size < num_datapoints:
# Increase the batch size.
batch_size += int(batch_size / delta)
else:
if batch_size == num_datapoints:
print("Batch size equals to the training dataset size")
break
return batch_sizes_and_mem_usages,True

except torch.OutOfMemoryError:
print(f"torch.OutOfMemoryError caught!")
return batch_sizes_and_mem_usages
return batch_sizes_and_mem_usages, False

history_batch_sizes_and_mem_usages=[]
batch_size=initial_batch_size
for delta in range(1,10,1):
history_batch_sizes_and_mem_usages.extend(increase_batch_size_until_cuda_out_of_memory(tp_ensemble_model, train_loader, batch_size,delta=delta))
batch_size=history_batch_sizes_and_mem_usages[-2][0]
print(f"A best found batch size:{batch_size} in {len(history_batch_sizes_and_mem_usages)} trials. Current GPU memory usage % :{history_batch_sizes_and_mem_usages[-2][1]}")
return batch_size

for delta in range(1,5,1):
result,flag= increase_batch_size_until_cuda_out_of_memory(tp_ensemble_model, train_loader, batch_size,delta=delta)

history_batch_sizes_and_mem_usages.extend(result)

if flag:
batch_size, batch_rt = history_batch_sizes_and_mem_usages[-1]
else:
# CUDA ERROR Observed
batch_size, batch_rt=history_batch_sizes_and_mem_usages[-2]

if batch_size>=training_dataset_size:
batch_size=training_dataset_size
break
else:
continue

return batch_size, batch_rt


def forward_backward_update_loss(z:Tuple, ensemble_model)->float:
@@ -115,12 +136,14 @@ def fit(self, *args, **kwargs):
self.on_fit_start(self, ensemble_model)
# () Sanity checking
assert torch.cuda.device_count()== len(ensemble_model)
# ()
# () Get DataLoader
train_dataloader = kwargs['train_dataloaders']
# ()
# () Find a batch size so that available GPU memory is *almost* fully used.
if self.attributes.auto_batch_finding:
batch_size, batch_rt=find_good_batch_size(train_dataloader, ensemble_model)

train_dataloader = torch.utils.data.DataLoader(train_dataloader.dataset,
batch_size=find_good_batch_size(train_dataloader, ensemble_model),
batch_size=batch_size,
shuffle=True,
sampler=None,
batch_sampler=None,
Expand All @@ -131,29 +154,38 @@ def fit(self, *args, **kwargs):
timeout=0,
worker_init_fn=None,
persistent_workers=False)
if batch_rt is not None:
expected_training_time=batch_rt * len(train_dataloader) * self.attributes.num_epochs
print(f"Exp.Training Runtime: {expected_training_time/60 :.3f} in mins\t|\tBatch Size:{batch_size}\t|\tBatch RT:{batch_rt:.3f}\t|\t # of batches:{len(train_dataloader)}\t|\t# of epochs:{self.attributes.num_epochs}")

# () Number of batches per epoch.
num_of_batches = len(train_dataloader)
# () Start training.
for epoch in (tqdm_bar := make_iterable_verbose(range(self.attributes.num_epochs),
verbose=True, position=0, leave=True)):
# () Accumulate the batch losses.
epoch_loss = 0
# () Iterate over batches.
for i, z in enumerate(train_dataloader):
# () Forward, Loss, Backward, and Update on a given batch of data points.
batch_loss = forward_backward_update_loss(z,ensemble_model)
# () Accumulate the batch losses to compute the epoch loss.
epoch_loss += batch_loss

# If verbose=True, show info.
if hasattr(tqdm_bar, 'set_description_str'):
tqdm_bar.set_description_str(f"Epoch:{epoch + 1}")
if i > 0:
tqdm_bar.set_postfix_str(
f"batch={i} | {num_of_batches}, loss_step={batch_loss:.5f}, loss_epoch={epoch_loss / i:.5f}")
else:
tqdm_bar.set_postfix_str(f"loss_step={batch_loss:.5f}, loss_epoch={batch_loss:.5f}")
# Store the epoch loss
ensemble_model.loss_history.append(epoch_loss)

# Run on_fit_end callbacks after the training is done.
self.on_fit_end(self, ensemble_model)
# TODO: Later, maybe we should write a callback to save the models on disk
return ensemble_model

"""
def batchwisefit(self, *args, **kwargs):
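For reference, the search that find_good_batch_size now performs boils down to: time a single forward/backward/update step at the current batch size, grow the batch size geometrically (batch_size += batch_size / delta), and stop once one batch covers the whole training set or CUDA runs out of memory, returning the last safe batch size together with its measured per-batch runtime. Below is a minimal self-contained sketch of that loop; run_one_batch is a stand-in for forward_backward_update_loss plus the DataLoader plumbing and is not part of dicee:

import time

import torch

def run_one_batch(batch_size: int) -> float:
    # Stand-in for one forward/backward/update step on a random batch.
    x = torch.randn(batch_size, 128, device="cuda", requires_grad=True)
    loss = (x * x).sum()
    loss.backward()
    return loss.item()

def find_batch_size(initial_batch_size: int, num_datapoints: int, delta: int = 4):
    # Grow the batch size geometrically until OOM or the whole dataset fits into one batch.
    batch_size = initial_batch_size
    last_safe = (initial_batch_size, None)
    try:
        while True:
            start = time.time()
            run_one_batch(batch_size)
            rt = time.time() - start
            last_safe = (batch_size, rt)
            if batch_size >= num_datapoints:
                return num_datapoints, rt              # one batch covers the dataset
            batch_size += max(1, batch_size // delta)  # step of roughly 1/delta
    except torch.OutOfMemoryError:
        return last_safe                               # last batch size that did not OOM

# batch_size, batch_rt = find_batch_size(1024, num_datapoints=1_000_000)

The second return value is what fit() multiplies by the number of batches per epoch and the number of epochs to print the expected training runtime.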
