From 71780cd091d3c54a5081440f5c11dc76dc82e95d Mon Sep 17 00:00:00 2001 From: fatsmcgee Date: Fri, 1 Nov 2024 22:53:12 -0400 Subject: [PATCH 1/6] Speed up ArtifactDataset constructor with larger batch size, inference mode --- permutect/data/artifact_dataset.py | 18 +++++++++++++----- permutect/tools/train_model.py | 2 +- 2 files changed, 14 insertions(+), 6 deletions(-) diff --git a/permutect/data/artifact_dataset.py b/permutect/data/artifact_dataset.py index 1a1c3d72..d819305a 100644 --- a/permutect/data/artifact_dataset.py +++ b/permutect/data/artifact_dataset.py @@ -1,9 +1,10 @@ import math import random from typing import List -from tqdm.autonotebook import tqdm import numpy as np +import torch +from tqdm.autonotebook import tqdm from torch.utils.data import Dataset, DataLoader, Sampler from permutect.architecture.base_model import BaseModel @@ -14,7 +15,11 @@ # given a ReadSetDataset, apply a BaseModel to get an ArtifactDataset (in RAM, maybe implement memory map later) # of RepresentationReadSets class ArtifactDataset(Dataset): - def __init__(self, base_dataset: BaseDataset, base_model: BaseModel, folds_to_use: List[int] = None): + def __init__(self, base_dataset: BaseDataset, + base_model: BaseModel, + folds_to_use: List[int] = None, + base_loader_num_workers=0, + base_loader_batch_size=4096): self.counts_by_source = base_dataset.counts_by_source self.totals = base_dataset.totals self.source_totals = base_dataset.source_totals @@ -30,7 +35,9 @@ def __init__(self, base_dataset: BaseDataset, base_model: BaseModel, folds_to_us index = 0 - loader = base_dataset.make_data_loader(base_dataset.all_folds() if folds_to_use is None else folds_to_use, batch_size=256) + loader = base_dataset.make_data_loader(base_dataset.all_folds() if folds_to_use is None else folds_to_use, + batch_size=base_loader_batch_size, + num_workers=base_loader_num_workers) print("making artifact dataset from base dataset") is_cuda = base_model._device.type == 'cuda' @@ -38,8 +45,9 @@ def __init__(self, base_dataset: BaseDataset, base_model: BaseModel, folds_to_us pbar = tqdm(enumerate(loader), mininterval=60) for n, base_batch_cpu in pbar: - base_batch = base_batch_cpu.copy_to(base_model._device, non_blocking=base_model._device.type == 'cuda') - representations, ref_alt_seq_embeddings = base_model.calculate_representations(base_batch) + base_batch = base_batch_cpu.copy_to(base_model._device, non_blocking=is_cuda) + with torch.inference_mode(): + representations, ref_alt_seq_embeddings = base_model.calculate_representations(base_batch) for representation, ref_alt_emb, base_datum in zip(representations.detach().cpu(), ref_alt_seq_embeddings.detach().cpu(), base_batch_cpu.original_list()): artifact_datum = ArtifactDatum(base_datum, representation.detach(), ref_alt_emb) self.artifact_data.append(artifact_datum) diff --git a/permutect/tools/train_model.py b/permutect/tools/train_model.py index e7d75273..d21bbba7 100644 --- a/permutect/tools/train_model.py +++ b/permutect/tools/train_model.py @@ -97,7 +97,7 @@ def main_without_parsing(args): print(f"Memory usage percent before creating BaseDataset: {psutil.virtual_memory().percent:.1f}") base_dataset = BaseDataset(data_tarfile=getattr(args, constants.TRAIN_TAR_NAME), num_folds=10) print(f"Memory usage percent before creating ArtifactDataset: {psutil.virtual_memory().percent:.1f}") - artifact_dataset = ArtifactDataset(base_dataset, base_model) + artifact_dataset = ArtifactDataset(base_dataset, base_model, base_loader_num_workers=training_params.num_workers) print(f"Memory 
usage percent after creating ArtifactDataset: {psutil.virtual_memory().percent:.1f}") model = train_artifact_model(hyperparams=params, training_params=training_params, summary_writer=summary_writer, dataset=artifact_dataset) From 27b64c3a19cd811ae65699992c37919f96072d7e Mon Sep 17 00:00:00 2001 From: fatsmcgee Date: Fri, 1 Nov 2024 23:17:43 -0400 Subject: [PATCH 2/6] Get rid of unused import, more hyooge batch size --- permutect/data/artifact_dataset.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/permutect/data/artifact_dataset.py b/permutect/data/artifact_dataset.py index d819305a..3a4cd32a 100644 --- a/permutect/data/artifact_dataset.py +++ b/permutect/data/artifact_dataset.py @@ -2,7 +2,6 @@ import random from typing import List -import numpy as np import torch from tqdm.autonotebook import tqdm from torch.utils.data import Dataset, DataLoader, Sampler @@ -19,7 +18,7 @@ def __init__(self, base_dataset: BaseDataset, base_model: BaseModel, folds_to_use: List[int] = None, base_loader_num_workers=0, - base_loader_batch_size=4096): + base_loader_batch_size=8192): self.counts_by_source = base_dataset.counts_by_source self.totals = base_dataset.totals self.source_totals = base_dataset.source_totals From 94eba196a8d5183b4451643e0025da79bdf50303 Mon Sep 17 00:00:00 2001 From: fatsmcgee Date: Sat, 2 Nov 2024 13:59:31 -0400 Subject: [PATCH 3/6] See if storing only the needed original values delivers the serialization optimization --- permutect/data/base_datum.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/permutect/data/base_datum.py b/permutect/data/base_datum.py index e717e9e8..202667f0 100644 --- a/permutect/data/base_datum.py +++ b/permutect/data/base_datum.py @@ -660,7 +660,10 @@ def is_labeled(self): class ArtifactBatch: def __init__(self, data: List[ArtifactDatum]): - self.original_data = data + + #self.original_data = data + self.original_variants = [d.get_other_stuff_1d().get_variant() for d in data] + self.original_counts_and_seq_lks = [d.get_other_stuff_1d().get_counts_and_seq_lks() for d in data] self.representations_2d = torch.vstack([item.representation for item in data]) self.ref_alt_seq_embeddings_2d = torch.vstack([item.ref_alt_seq_embedding for item in data]) From 768948ad3e31050695744b58b7221ed483a4290d Mon Sep 17 00:00:00 2001 From: fatsmcgee Date: Sat, 2 Nov 2024 14:12:48 -0400 Subject: [PATCH 4/6] Yep, it works. 
Only store needed original data in ArtifactBatch --- permutect/architecture/artifact_model.py | 5 ++--- permutect/data/base_datum.py | 1 - permutect/tools/filter_variants.py | 8 +++++--- 3 files changed, 7 insertions(+), 7 deletions(-) diff --git a/permutect/architecture/artifact_model.py b/permutect/architecture/artifact_model.py index b8da6cac..bf35ab39 100644 --- a/permutect/architecture/artifact_model.py +++ b/permutect/architecture/artifact_model.py @@ -389,9 +389,9 @@ def collect_evaluation_data(self, dataset: ArtifactDataset, train_loader, valid_ # note that for metrics we use batch_cpu correct = ((pred > 0) == (batch_cpu.labels > 0.5)).tolist() - for variant_type, predicted_logit, label, is_labeled, correct_call, alt_count, datum, weight in zip( + for variant_type, predicted_logit, label, is_labeled, correct_call, alt_count, variant, weight in zip( batch_cpu.variant_types(), pred.tolist(), batch_cpu.labels.tolist(), batch_cpu.is_labeled_mask.tolist(), correct, - batch_cpu.alt_counts, batch_cpu.original_data, weights.tolist()): + batch_cpu.alt_counts, batch_cpu.original_variants, weights.tolist()): if is_labeled < 0.5: # we only evaluate labeled data continue evaluation_metrics.record_call(epoch_type, variant_type, predicted_logit, label, correct_call, alt_count, weight) @@ -408,7 +408,6 @@ def collect_evaluation_data(self, dataset: ArtifactDataset, train_loader, valid_ pqueue.get() # discards the least confident bad call if not pqueue.full(): # if space was cleared or if it wasn't full already - variant = datum.get_other_stuff_1d().get_variant() pqueue.put((confidence, str(variant.contig) + ":" + str( variant.position) + ':' + variant.ref + "->" + variant.alt)) # done with this epoch type diff --git a/permutect/data/base_datum.py b/permutect/data/base_datum.py index 202667f0..7b911b2c 100644 --- a/permutect/data/base_datum.py +++ b/permutect/data/base_datum.py @@ -661,7 +661,6 @@ def is_labeled(self): class ArtifactBatch: def __init__(self, data: List[ArtifactDatum]): - #self.original_data = data self.original_variants = [d.get_other_stuff_1d().get_variant() for d in data] self.original_counts_and_seq_lks = [d.get_other_stuff_1d().get_counts_and_seq_lks() for d in data] diff --git a/permutect/tools/filter_variants.py b/permutect/tools/filter_variants.py index 4d9f6ca4..3415e3f8 100644 --- a/permutect/tools/filter_variants.py +++ b/permutect/tools/filter_variants.py @@ -222,10 +222,12 @@ def make_posterior_data_loader(dataset_file, input_vcf, contig_index_to_name_map labels = [(Label.ARTIFACT if label > 0.5 else Label.VARIANT) if is_labeled > 0.5 else Label.UNLABELED for (label, is_labeled) in zip(artifact_batch.labels, artifact_batch.is_labeled_mask)] - for artifact_datum, logit, label, embedding in zip(artifact_batch.original_data, artifact_logits.detach().tolist(), labels, artifact_batch.get_representations_2d().cpu()): + for variant,counts_and_seq_lks, logit, label, embedding in zip(artifact_batch.original_variants, + artifact_batch.original_counts_and_seq_lks, + artifact_logits.detach().tolist(), + labels, + artifact_batch.get_representations_2d().cpu()): m += 1 # DEBUG - variant = artifact_datum.get_other_stuff_1d().get_variant() - counts_and_seq_lks = artifact_datum.get_other_stuff_1d().get_counts_and_seq_lks() contig_name = contig_index_to_name_map[variant.contig] encoding = encode(contig_name, variant.position, variant.ref, variant.alt) if encoding in allele_frequencies and encoding not in m2_filtering_to_keep: From b89ed9136524d5933cd4c24f24878e8d5f96c43d Mon Sep 17 
00:00:00 2001 From: fatsmcgee Date: Sat, 2 Nov 2024 14:32:09 -0400 Subject: [PATCH 5/6] Use larger batch size for inference when learning is not happening --- permutect/architecture/artifact_model.py | 15 ++++++++++----- permutect/constants.py | 1 + permutect/parameters.py | 9 +++++++-- permutect/tools/train_model.py | 5 ++++- 4 files changed, 22 insertions(+), 8 deletions(-) diff --git a/permutect/architecture/artifact_model.py b/permutect/architecture/artifact_model.py index bf35ab39..477e1b6a 100644 --- a/permutect/architecture/artifact_model.py +++ b/permutect/architecture/artifact_model.py @@ -224,9 +224,12 @@ def learn(self, dataset: ArtifactDataset, training_params: TrainingParameters, s print(f"Is CUDA available? {is_cuda}") validation_fold_to_use = (dataset.num_folds - 1) if validation_fold is None else validation_fold - train_loader = dataset.make_data_loader(dataset.all_but_one_fold(validation_fold_to_use), training_params.batch_size, is_cuda, training_params.num_workers) + def create_train_loader(batch_size): + return dataset.make_data_loader(dataset.all_but_one_fold(validation_fold_to_use), batch_size, is_cuda, training_params.num_workers) + train_loader = create_train_loader(training_params.batch_size) + train_loader_inference_only = create_train_loader(training_params.inference_batch_size) print(f"Train loader created, memory usage percent: {psutil.virtual_memory().percent:.1f}") - valid_loader = dataset.make_data_loader([validation_fold_to_use], training_params.batch_size, is_cuda, training_params.num_workers) + valid_loader = dataset.make_data_loader([validation_fold_to_use], training_params.inference_batch_size, is_cuda, training_params.num_workers) print(f"Validation loader created, memory usage percent: {psutil.virtual_memory().percent:.1f}") first_epoch, last_epoch = 1, training_params.num_epochs + training_params.num_calibration_epochs @@ -331,11 +334,11 @@ def learn(self, dataset: ArtifactDataset, training_params: TrainingParameters, s print(f"End of epoch {epoch}, memory usage percent: {psutil.virtual_memory().percent:.1f}, time elapsed(s): {time.time() - start_of_epoch:.2f}") is_last = (epoch == last_epoch) if (epochs_per_evaluation is not None and epoch % epochs_per_evaluation == 0) or is_last: - self.evaluate_model(epoch, dataset, train_loader, valid_loader, summary_writer, collect_embeddings=False, report_worst=False) + self.evaluate_model(epoch, dataset, train_loader_inference_only, valid_loader, summary_writer, collect_embeddings=False, report_worst=False) if is_last: # collect data in order to do final calibration print("collecting data for final calibration") - evaluation_metrics, _ = self.collect_evaluation_data(dataset, train_loader, valid_loader, report_worst=False) + evaluation_metrics, _ = self.collect_evaluation_data(dataset, train_loader_inference_only, valid_loader, report_worst=False) logit_adjustments_by_var_type_and_count_bin = evaluation_metrics.metrics[Epoch.VALID].calculate_logit_adjustments(use_harmonic_mean=False) print("here are the logit adjustments:") @@ -354,7 +357,7 @@ def learn(self, dataset: ArtifactDataset, training_params: TrainingParameters, s # consider this an extra post-postprocessing/final calibration epoch, hence epoch+1 print("doing one final evaluation after the last logit adjustment") - self.evaluate_model(epoch + 1, dataset, train_loader, valid_loader, summary_writer, collect_embeddings=True, report_worst=True) + self.evaluate_model(epoch + 1, dataset, train_loader_inference_only, valid_loader, summary_writer, 
collect_embeddings=True, report_worst=True) # note that we have not learned the AF spectrum yet # done with training @@ -364,6 +367,7 @@ def evaluate_model_after_training(self, dataset: ArtifactDataset, batch_size, nu valid_loader = dataset.make_data_loader(dataset.last_fold_only(), batch_size, self._device.type == 'cuda', num_workers) self.evaluate_model(None, dataset, train_loader, valid_loader, summary_writer, collect_embeddings=True, report_worst=True) + @torch.inference_mode() def collect_evaluation_data(self, dataset: ArtifactDataset, train_loader, valid_loader, report_worst: bool): # the keys are tuples of (true label -- 1 for variant, 0 for artifact; rounded alt count) worst_offenders_by_truth_and_alt_count = defaultdict(lambda: PriorityQueue(WORST_OFFENDERS_QUEUE_SIZE)) @@ -414,6 +418,7 @@ def collect_evaluation_data(self, dataset: ArtifactDataset, train_loader, valid_ # done collecting data return evaluation_metrics, worst_offenders_by_truth_and_alt_count + @torch.inference_mode() def evaluate_model(self, epoch: int, dataset: ArtifactDataset, train_loader, valid_loader, summary_writer: SummaryWriter, collect_embeddings: bool = False, report_worst: bool = False): diff --git a/permutect/constants.py b/permutect/constants.py index c548a95d..18b79371 100644 --- a/permutect/constants.py +++ b/permutect/constants.py @@ -44,6 +44,7 @@ CHUNK_SIZE_NAME = 'chunk_size' NUM_EPOCHS_NAME = 'num_epochs' NUM_CALIBRATION_EPOCHS_NAME = 'num_calibration_epochs' +INFERENCE_BATCH_SIZE_NAME = 'inference_batch_size' NUM_WORKERS_NAME = 'num_workers' NUM_SPECTRUM_ITERATIONS_NAME = 'num_spectrum_iterations' SPECTRUM_LEARNING_RATE_NAME = 'spectrum_learning_rate' diff --git a/permutect/parameters.py b/permutect/parameters.py index 41fcace0..246aa7f7 100644 --- a/permutect/parameters.py +++ b/permutect/parameters.py @@ -84,13 +84,15 @@ def add_base_model_params_to_parser(parser): # common parameters for training models class TrainingParameters: def __init__(self, batch_size: int, num_epochs: int, learning_rate: float = 0.001, - weight_decay: float = 0.01, num_workers: int = 0, num_calibration_epochs: int = 0): + weight_decay: float = 0.01, num_workers: int = 0, num_calibration_epochs: int = 0, + inference_batch_size: int = 8192): self.batch_size = batch_size self.num_epochs = num_epochs self.learning_rate = learning_rate self.weight_decay = weight_decay self.num_workers = num_workers self.num_calibration_epochs = num_calibration_epochs + self.inference_batch_size = inference_batch_size def parse_training_params(args) -> TrainingParameters: @@ -100,7 +102,8 @@ def parse_training_params(args) -> TrainingParameters: num_epochs = getattr(args, constants.NUM_EPOCHS_NAME) num_calibration_epochs = getattr(args, constants.NUM_CALIBRATION_EPOCHS_NAME) num_workers = getattr(args, constants.NUM_WORKERS_NAME) - return TrainingParameters(batch_size, num_epochs, learning_rate, weight_decay, num_workers, num_calibration_epochs) + inference_batch_size = getattr(args, constants.INFERENCE_BATCH_SIZE_NAME) + return TrainingParameters(batch_size, num_epochs, learning_rate, weight_decay, num_workers, num_calibration_epochs, inference_batch_size) def add_training_params_to_parser(parser): @@ -117,6 +120,8 @@ def add_training_params_to_parser(parser): help='number of epochs for primary training loop') parser.add_argument('--' + constants.NUM_CALIBRATION_EPOCHS_NAME, type=int, default=0, required=False, help='number of calibration-only epochs') + parser.add_argument('--' + constants.INFERENCE_BATCH_SIZE_NAME, type=int, 
default=8192, required=False, + help='batch size when performing model inference (not training)') class ArtifactModelParameters: diff --git a/permutect/tools/train_model.py b/permutect/tools/train_model.py index d21bbba7..7465a8d0 100644 --- a/permutect/tools/train_model.py +++ b/permutect/tools/train_model.py @@ -97,7 +97,10 @@ def main_without_parsing(args): print(f"Memory usage percent before creating BaseDataset: {psutil.virtual_memory().percent:.1f}") base_dataset = BaseDataset(data_tarfile=getattr(args, constants.TRAIN_TAR_NAME), num_folds=10) print(f"Memory usage percent before creating ArtifactDataset: {psutil.virtual_memory().percent:.1f}") - artifact_dataset = ArtifactDataset(base_dataset, base_model, base_loader_num_workers=training_params.num_workers) + artifact_dataset = ArtifactDataset(base_dataset, + base_model, + base_loader_num_workers=training_params.num_workers, + base_loader_batch_size=training_params.inference_batch_size) print(f"Memory usage percent after creating ArtifactDataset: {psutil.virtual_memory().percent:.1f}") model = train_artifact_model(hyperparams=params, training_params=training_params, summary_writer=summary_writer, dataset=artifact_dataset) From 7e5dfffd175680fd89856229536b275910002bc9 Mon Sep 17 00:00:00 2001 From: fatsmcgee Date: Sat, 2 Nov 2024 22:35:57 -0400 Subject: [PATCH 6/6] For now, get rid of inference-only train loader optimization due to memory issues --- permutect/architecture/artifact_model.py | 12 +++++------- 1 file changed, 5 insertions(+), 7 deletions(-) diff --git a/permutect/architecture/artifact_model.py b/permutect/architecture/artifact_model.py index 477e1b6a..150ad8e4 100644 --- a/permutect/architecture/artifact_model.py +++ b/permutect/architecture/artifact_model.py @@ -224,10 +224,7 @@ def learn(self, dataset: ArtifactDataset, training_params: TrainingParameters, s print(f"Is CUDA available? 
{is_cuda}") validation_fold_to_use = (dataset.num_folds - 1) if validation_fold is None else validation_fold - def create_train_loader(batch_size): - return dataset.make_data_loader(dataset.all_but_one_fold(validation_fold_to_use), batch_size, is_cuda, training_params.num_workers) - train_loader = create_train_loader(training_params.batch_size) - train_loader_inference_only = create_train_loader(training_params.inference_batch_size) + train_loader = dataset.make_data_loader(dataset.all_but_one_fold(validation_fold_to_use), training_params.batch_size, is_cuda, training_params.num_workers) print(f"Train loader created, memory usage percent: {psutil.virtual_memory().percent:.1f}") valid_loader = dataset.make_data_loader([validation_fold_to_use], training_params.inference_batch_size, is_cuda, training_params.num_workers) print(f"Validation loader created, memory usage percent: {psutil.virtual_memory().percent:.1f}") @@ -334,11 +331,12 @@ def create_train_loader(batch_size): print(f"End of epoch {epoch}, memory usage percent: {psutil.virtual_memory().percent:.1f}, time elapsed(s): {time.time() - start_of_epoch:.2f}") is_last = (epoch == last_epoch) if (epochs_per_evaluation is not None and epoch % epochs_per_evaluation == 0) or is_last: - self.evaluate_model(epoch, dataset, train_loader_inference_only, valid_loader, summary_writer, collect_embeddings=False, report_worst=False) + print(f"performing evaluation on epoch {epoch}") + self.evaluate_model(epoch, dataset, train_loader, valid_loader, summary_writer, collect_embeddings=False, report_worst=False) if is_last: # collect data in order to do final calibration print("collecting data for final calibration") - evaluation_metrics, _ = self.collect_evaluation_data(dataset, train_loader_inference_only, valid_loader, report_worst=False) + evaluation_metrics, _ = self.collect_evaluation_data(dataset, train_loader, valid_loader, report_worst=False) logit_adjustments_by_var_type_and_count_bin = evaluation_metrics.metrics[Epoch.VALID].calculate_logit_adjustments(use_harmonic_mean=False) print("here are the logit adjustments:") @@ -357,7 +355,7 @@ def create_train_loader(batch_size): # consider this an extra post-postprocessing/final calibration epoch, hence epoch+1 print("doing one final evaluation after the last logit adjustment") - self.evaluate_model(epoch + 1, dataset, train_loader_inference_only, valid_loader, summary_writer, collect_embeddings=True, report_worst=True) + self.evaluate_model(epoch + 1, dataset, train_loader, valid_loader, summary_writer, collect_embeddings=True, report_worst=True) # note that we have not learned the AF spectrum yet # done with training