TP with auto batch finding can be used to train KGE models with >20B parameters
Demirrr committed Nov 29, 2024
1 parent fbf30ab commit f1b263b
Showing 2 changed files with 25 additions and 12 deletions.
6 changes: 3 additions & 3 deletions dicee/scripts/run.py
@@ -55,13 +55,13 @@ def get_default_arguments(description=None):
                         default={},
                         help='{"PPE":{ "last_percent_to_consider": 10}}'
                              '"Perturb": {"level": "out", "ratio": 0.2, "method": "RN", "scaler": 0.3}')
-    parser.add_argument("--trainer", type=str, default='TP',
+    parser.add_argument("--trainer", type=str, default='PL',
                         choices=['torchCPUTrainer', 'PL', 'torchDDP', "TP"],
                         help='PL (pytorch lightning trainer), torchDDP (custom ddp), torchCPUTrainer (custom cpu only), TP (Model Parallelism)')
-    parser.add_argument('--scoring_technique', default="KvsSample",
+    parser.add_argument('--scoring_technique', default="NegSample",
                         help="Training technique for knowledge graph embedding model",
                         choices=["AllvsAll", "KvsAll", "1vsAll", "NegSample", "1vsSample", "KvsSample"])
-    parser.add_argument('--neg_ratio', type=int, default=10,
+    parser.add_argument('--neg_ratio', type=int, default=2,
                         help='The number of negative triples generated per positive triple.')
     parser.add_argument('--weight_decay', type=float, default=0.0, help='L2 penalty e.g.(0.00001)')
     parser.add_argument('--input_dropout_rate', type=float, default=0.0)
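
For reference, a minimal launch sketch combining the flags touched above. This assumes the dicee console entry point and the --model/--dataset_dir flags from the same argument parser; the model name and dataset path are illustrative placeholders:

    dicee --trainer TP --scoring_technique KvsSample --neg_ratio 10 --model Keci --dataset_dir KGs/UMLS

Selecting --trainer TP routes training through dicee/trainer/model_parallelism.py, where the automatic batch-size search changed below is implemented.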
31 changes: 22 additions & 9 deletions dicee/trainer/model_parallelism.py
@@ -26,17 +26,22 @@ def extract_input_outputs(z: list, device=None):
 
 
 def find_good_batch_size(train_loader,tp_ensemble_model):
-    # () Initial batch size
+    # () Initial batch size.
     initial_batch_size=train_loader.batch_size
     # () # of training data points.
     training_dataset_size=len(train_loader.dataset)
     # () Batch is large enough.
     if initial_batch_size >= training_dataset_size:
         return training_dataset_size, None
+    # () Log the number of training data points.
+    print("Number of training data points:",training_dataset_size)
 
     def increase_batch_size_until_cuda_out_of_memory(ensemble_model, train_loader, batch_size,delta: int = None):
-        assert delta is not None, "delta must be positive integer"
+        assert delta is not None, "delta cannot be None."
+        assert isinstance(delta, int), "delta must be a positive integer."
+        # () Store the batch sizes and GPU memory usages in a tuple.
         batch_sizes_and_mem_usages = []
-        num_datapoints = len(train_loader.dataset)
+        # () Increase the batch size until a stopping criterion is reached.
         try:
             while True:
                 start_time=time.time()
@@ -62,22 +62,27 @@ def increase_batch_size_until_cuda_out_of_memory(ensemble_model, train_loader, batch_size,delta: int = None):
                 global_free_memory, total_memory = torch.cuda.mem_get_info(device="cuda:0")
                 percentage_used_gpu_memory = (total_memory - global_free_memory) / total_memory
                 rt=time.time()-start_time
 
                 print(f"Random Batch Loss: {batch_loss:0.4}\tGPU Usage: {percentage_used_gpu_memory:0.3}\tRuntime: {rt:.3f}\tBatch Size: {batch_size}")
 
-                global_free_memory, total_memory = torch.cuda.mem_get_info(device="cuda:0")
-                percentage_used_gpu_memory = (total_memory - global_free_memory) / total_memory
 
                 # Store the batch size and the runtime
                 batch_sizes_and_mem_usages.append((batch_size, rt))
 
-                if batch_size < num_datapoints:
+                # ()
+                # https://github.com/pytorch/pytorch/issues/21819
+                # CD: As we reach close to 1.0 GPU memory usage, we observe RuntimeError: CUDA error: an illegal memory access was encountered.
+                # CD: To avoid this problem, we add the following condition as a temporary solution.
+                if percentage_used_gpu_memory > 0.9:
+                    # Mimic an out-of-memory error.
+                    return batch_sizes_and_mem_usages, False
+                if batch_size < training_dataset_size:
                     # Increase the batch size.
                     batch_size += int(batch_size / delta)
                 else:
                     return batch_sizes_and_mem_usages,True
 
-        except torch.OutOfMemoryError:
-            print(f"torch.OutOfMemoryError caught!")
+        except torch.OutOfMemoryError as e:
+            print(f"torch.OutOfMemoryError caught! {e}")
             return batch_sizes_and_mem_usages, False
 
     history_batch_sizes_and_mem_usages=[]
Expand All @@ -91,8 +101,11 @@ def increase_batch_size_until_cuda_out_of_memory(ensemble_model, train_loader, b
if flag:
batch_size, batch_rt = history_batch_sizes_and_mem_usages[-1]
else:
assert len(history_batch_sizes_and_mem_usages)>2, "GPU memory errorin the first try"
# CUDA ERROR Observed
batch_size, batch_rt=history_batch_sizes_and_mem_usages[-2]
# https://github.com/pytorch/pytorch/issues/21819
break

if batch_size>=training_dataset_size:
batch_size=training_dataset_size
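
Taken together, find_good_batch_size grows the batch size geometrically (each step adds batch_size/delta), records each size it tried, and stops when GPU usage crosses 0.9, when a CUDA OOM is raised (falling back to the second-to-last size that ran), or when the batch covers the whole dataset. A minimal standalone sketch of that search, assuming a hypothetical can_run(b) probe in place of the real forward/backward pass:

    import torch

    def find_max_batch_size(batch_size: int, dataset_size: int, can_run, delta: int = 4):
        # Sketch only; can_run(b) runs one training step at batch size b and returns
        # the used-GPU-memory fraction, or raises torch.OutOfMemoryError.
        history = []
        try:
            while True:
                used = can_run(batch_size)
                history.append(batch_size)
                if used > 0.9:
                    # Near-full memory: treat like OOM to avoid illegal memory accesses.
                    break
                if batch_size >= dataset_size:
                    return dataset_size  # full-batch training fits
                batch_size += max(1, batch_size // delta)  # geometric growth
        except torch.OutOfMemoryError:
            pass
        assert len(history) > 1, "GPU memory error in the first try"
        return history[-2]  # last size observed before the failure region

With delta=4 and an initial batch size of 1024, the schedule visits 1024, 1280, 1600, 2000, ... until one of the stopping criteria fires.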
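
The 0.9 usage fraction checked in the loop comes from torch.cuda.mem_get_info, which returns the free and total device memory in bytes. A small helper sketch of that probe:

    import torch

    def gpu_memory_fraction(device: str = "cuda:0") -> float:
        # mem_get_info returns (global_free_bytes, total_bytes) for the device.
        free_bytes, total_bytes = torch.cuda.mem_get_info(device=device)
        return (total_bytes - free_bytes) / total_bytes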
