Artifact Model Training Optimizations #155

Merged
6 commits merged on Nov 4, 2024
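The changes below follow a common PyTorch pattern for evaluation-only passes: run them under torch.inference_mode() so no autograd graph is built, and use a larger batch size than for training since no activations or gradients need to be retained. As a minimal sketch of that pattern (the model, loader, and batch fields here are illustrative placeholders, not Permutect APIs):

```python
import torch

@torch.inference_mode()  # no gradient tracking: activations are freed immediately, so batches can be large
def evaluate(model: torch.nn.Module, loader) -> float:
    correct, total = 0, 0
    for batch in loader:
        logits = model(batch.features)                          # forward pass only; no graph is recorded
        correct += ((logits > 0) == (batch.labels > 0.5)).sum().item()
        total += len(batch.labels)
    return correct / total
```

The diff applies this in two places: @torch.inference_mode() on the artifact model's evaluation methods, and a torch.inference_mode() context around the base model's representation calculation when building the ArtifactDataset, together with a dedicated, larger inference batch size.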
10 changes: 6 additions & 4 deletions permutect/architecture/artifact_model.py
@@ -226,7 +226,7 @@ def learn(self, dataset: ArtifactDataset, training_params: TrainingParameters, s
validation_fold_to_use = (dataset.num_folds - 1) if validation_fold is None else validation_fold
train_loader = dataset.make_data_loader(dataset.all_but_one_fold(validation_fold_to_use), training_params.batch_size, is_cuda, training_params.num_workers)
print(f"Train loader created, memory usage percent: {psutil.virtual_memory().percent:.1f}")
valid_loader = dataset.make_data_loader([validation_fold_to_use], training_params.batch_size, is_cuda, training_params.num_workers)
valid_loader = dataset.make_data_loader([validation_fold_to_use], training_params.inference_batch_size, is_cuda, training_params.num_workers)
print(f"Validation loader created, memory usage percent: {psutil.virtual_memory().percent:.1f}")

first_epoch, last_epoch = 1, training_params.num_epochs + training_params.num_calibration_epochs
@@ -331,6 +331,7 @@ def learn(self, dataset: ArtifactDataset, training_params: TrainingParameters, s
print(f"End of epoch {epoch}, memory usage percent: {psutil.virtual_memory().percent:.1f}, time elapsed(s): {time.time() - start_of_epoch:.2f}")
is_last = (epoch == last_epoch)
if (epochs_per_evaluation is not None and epoch % epochs_per_evaluation == 0) or is_last:
print(f"performing evaluation on epoch {epoch}")
self.evaluate_model(epoch, dataset, train_loader, valid_loader, summary_writer, collect_embeddings=False, report_worst=False)
if is_last:
# collect data in order to do final calibration
@@ -364,6 +365,7 @@ def evaluate_model_after_training(self, dataset: ArtifactDataset, batch_size, nu
valid_loader = dataset.make_data_loader(dataset.last_fold_only(), batch_size, self._device.type == 'cuda', num_workers)
self.evaluate_model(None, dataset, train_loader, valid_loader, summary_writer, collect_embeddings=True, report_worst=True)

@torch.inference_mode()
def collect_evaluation_data(self, dataset: ArtifactDataset, train_loader, valid_loader, report_worst: bool):
# the keys are tuples of (true label -- 1 for variant, 0 for artifact; rounded alt count)
worst_offenders_by_truth_and_alt_count = defaultdict(lambda: PriorityQueue(WORST_OFFENDERS_QUEUE_SIZE))
@@ -389,9 +391,9 @@ def collect_evaluation_data(self, dataset: ArtifactDataset, train_loader, valid_
# note that for metrics we use batch_cpu
correct = ((pred > 0) == (batch_cpu.labels > 0.5)).tolist()

for variant_type, predicted_logit, label, is_labeled, correct_call, alt_count, datum, weight in zip(
for variant_type, predicted_logit, label, is_labeled, correct_call, alt_count, variant, weight in zip(
batch_cpu.variant_types(), pred.tolist(), batch_cpu.labels.tolist(), batch_cpu.is_labeled_mask.tolist(), correct,
batch_cpu.alt_counts, batch_cpu.original_data, weights.tolist()):
batch_cpu.alt_counts, batch_cpu.original_variants, weights.tolist()):
if is_labeled < 0.5: # we only evaluate labeled data
continue
evaluation_metrics.record_call(epoch_type, variant_type, predicted_logit, label, correct_call, alt_count, weight)
@@ -408,13 +410,13 @@ def collect_evaluation_data(self, dataset: ArtifactDataset, train_loader, valid_
pqueue.get() # discards the least confident bad call

if not pqueue.full(): # if space was cleared or if it wasn't full already
variant = datum.get_other_stuff_1d().get_variant()
pqueue.put((confidence, str(variant.contig) + ":" + str(
variant.position) + ':' + variant.ref + "->" + variant.alt))
# done with this epoch type
# done collecting data
return evaluation_metrics, worst_offenders_by_truth_and_alt_count

@torch.inference_mode()
def evaluate_model(self, epoch: int, dataset: ArtifactDataset, train_loader, valid_loader, summary_writer: SummaryWriter,
collect_embeddings: bool = False, report_worst: bool = False):

1 change: 1 addition & 0 deletions permutect/constants.py
@@ -44,6 +44,7 @@
CHUNK_SIZE_NAME = 'chunk_size'
NUM_EPOCHS_NAME = 'num_epochs'
NUM_CALIBRATION_EPOCHS_NAME = 'num_calibration_epochs'
INFERENCE_BATCH_SIZE_NAME = 'inference_batch_size'
NUM_WORKERS_NAME = 'num_workers'
NUM_SPECTRUM_ITERATIONS_NAME = 'num_spectrum_iterations'
SPECTRUM_LEARNING_RATE_NAME = 'spectrum_learning_rate'
19 changes: 13 additions & 6 deletions permutect/data/artifact_dataset.py
@@ -1,9 +1,9 @@
import math
import random
from typing import List
from tqdm.autonotebook import tqdm

import numpy as np
import torch
from tqdm.autonotebook import tqdm
from torch.utils.data import Dataset, DataLoader, Sampler

from permutect.architecture.base_model import BaseModel
@@ -14,7 +14,11 @@
# given a ReadSetDataset, apply a BaseModel to get an ArtifactDataset (in RAM, maybe implement memory map later)
# of RepresentationReadSets
class ArtifactDataset(Dataset):
def __init__(self, base_dataset: BaseDataset, base_model: BaseModel, folds_to_use: List[int] = None):
def __init__(self, base_dataset: BaseDataset,
base_model: BaseModel,
folds_to_use: List[int] = None,
base_loader_num_workers=0,
base_loader_batch_size=8192):
self.counts_by_source = base_dataset.counts_by_source
self.totals = base_dataset.totals
self.source_totals = base_dataset.source_totals
@@ -30,16 +34,19 @@ def __init__(self, base_dataset: BaseDataset, base_model: BaseModel, folds_to_us

index = 0

loader = base_dataset.make_data_loader(base_dataset.all_folds() if folds_to_use is None else folds_to_use, batch_size=256)
loader = base_dataset.make_data_loader(base_dataset.all_folds() if folds_to_use is None else folds_to_use,
batch_size=base_loader_batch_size,
num_workers=base_loader_num_workers)
print("making artifact dataset from base dataset")

is_cuda = base_model._device.type == 'cuda'
print(f"Is base model using CUDA? {is_cuda}")

pbar = tqdm(enumerate(loader), mininterval=60)
for n, base_batch_cpu in pbar:
base_batch = base_batch_cpu.copy_to(base_model._device, non_blocking=base_model._device.type == 'cuda')
representations, ref_alt_seq_embeddings = base_model.calculate_representations(base_batch)
base_batch = base_batch_cpu.copy_to(base_model._device, non_blocking=is_cuda)
with torch.inference_mode():
representations, ref_alt_seq_embeddings = base_model.calculate_representations(base_batch)
for representation, ref_alt_emb, base_datum in zip(representations.detach().cpu(), ref_alt_seq_embeddings.detach().cpu(), base_batch_cpu.original_list()):
artifact_datum = ArtifactDatum(base_datum, representation.detach(), ref_alt_emb)
self.artifact_data.append(artifact_datum)
4 changes: 3 additions & 1 deletion permutect/data/base_datum.py
@@ -660,7 +660,9 @@ def is_labeled(self):

class ArtifactBatch:
def __init__(self, data: List[ArtifactDatum]):
self.original_data = data

self.original_variants = [d.get_other_stuff_1d().get_variant() for d in data]
self.original_counts_and_seq_lks = [d.get_other_stuff_1d().get_counts_and_seq_lks() for d in data]

self.representations_2d = torch.vstack([item.representation for item in data])
self.ref_alt_seq_embeddings_2d = torch.vstack([item.ref_alt_seq_embedding for item in data])
9 changes: 7 additions & 2 deletions permutect/parameters.py
@@ -84,13 +84,15 @@ def add_base_model_params_to_parser(parser):
# common parameters for training models
class TrainingParameters:
def __init__(self, batch_size: int, num_epochs: int, learning_rate: float = 0.001,
weight_decay: float = 0.01, num_workers: int = 0, num_calibration_epochs: int = 0):
weight_decay: float = 0.01, num_workers: int = 0, num_calibration_epochs: int = 0,
inference_batch_size: int = 8192):
self.batch_size = batch_size
self.num_epochs = num_epochs
self.learning_rate = learning_rate
self.weight_decay = weight_decay
self.num_workers = num_workers
self.num_calibration_epochs = num_calibration_epochs
self.inference_batch_size = inference_batch_size


def parse_training_params(args) -> TrainingParameters:
@@ -100,7 +102,8 @@ def parse_training_params(args) -> TrainingParameters:
num_epochs = getattr(args, constants.NUM_EPOCHS_NAME)
num_calibration_epochs = getattr(args, constants.NUM_CALIBRATION_EPOCHS_NAME)
num_workers = getattr(args, constants.NUM_WORKERS_NAME)
return TrainingParameters(batch_size, num_epochs, learning_rate, weight_decay, num_workers, num_calibration_epochs)
inference_batch_size = getattr(args, constants.INFERENCE_BATCH_SIZE_NAME)
return TrainingParameters(batch_size, num_epochs, learning_rate, weight_decay, num_workers, num_calibration_epochs, inference_batch_size)


def add_training_params_to_parser(parser):
@@ -117,6 +120,8 @@ def add_training_params_to_parser(parser):
help='number of epochs for primary training loop')
parser.add_argument('--' + constants.NUM_CALIBRATION_EPOCHS_NAME, type=int, default=0, required=False,
help='number of calibration-only epochs')
parser.add_argument('--' + constants.INFERENCE_BATCH_SIZE_NAME, type=int, default=8192, required=False,
help='batch size when performing model inference (not training)')

Owner comment: This will also need to go in the WDL scripts, but I'll handle that in a follow-up PR.


class ArtifactModelParameters:
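For reference, a minimal sketch of constructing TrainingParameters with the new field (the values here are illustrative; the default of 8192 matches the diff above):

```python
from permutect.parameters import TrainingParameters

# Training batches stay small enough for stable optimization; inference batches can be much
# larger because evaluation runs under torch.inference_mode() and retains no activations.
params = TrainingParameters(batch_size=64, num_epochs=10, inference_batch_size=8192)
```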
8 changes: 5 additions & 3 deletions permutect/tools/filter_variants.py
@@ -222,10 +222,12 @@ def make_posterior_data_loader(dataset_file, input_vcf, contig_index_to_name_map

labels = [(Label.ARTIFACT if label > 0.5 else Label.VARIANT) if is_labeled > 0.5 else Label.UNLABELED for (label, is_labeled) in zip(artifact_batch.labels, artifact_batch.is_labeled_mask)]

for artifact_datum, logit, label, embedding in zip(artifact_batch.original_data, artifact_logits.detach().tolist(), labels, artifact_batch.get_representations_2d().cpu()):
for variant, counts_and_seq_lks, logit, label, embedding in zip(artifact_batch.original_variants,
artifact_batch.original_counts_and_seq_lks,
artifact_logits.detach().tolist(),
labels,
artifact_batch.get_representations_2d().cpu()):
m += 1 # DEBUG
variant = artifact_datum.get_other_stuff_1d().get_variant()
counts_and_seq_lks = artifact_datum.get_other_stuff_1d().get_counts_and_seq_lks()
contig_name = contig_index_to_name_map[variant.contig]
encoding = encode(contig_name, variant.position, variant.ref, variant.alt)
if encoding in allele_frequencies and encoding not in m2_filtering_to_keep:
5 changes: 4 additions & 1 deletion permutect/tools/train_model.py
@@ -97,7 +97,10 @@ def main_without_parsing(args):
print(f"Memory usage percent before creating BaseDataset: {psutil.virtual_memory().percent:.1f}")
base_dataset = BaseDataset(data_tarfile=getattr(args, constants.TRAIN_TAR_NAME), num_folds=10)
print(f"Memory usage percent before creating ArtifactDataset: {psutil.virtual_memory().percent:.1f}")
artifact_dataset = ArtifactDataset(base_dataset, base_model)
artifact_dataset = ArtifactDataset(base_dataset,
base_model,
base_loader_num_workers=training_params.num_workers,
base_loader_batch_size=training_params.inference_batch_size)
print(f"Memory usage percent after creating ArtifactDataset: {psutil.virtual_memory().percent:.1f}")

model = train_artifact_model(hyperparams=params, training_params=training_params, summary_writer=summary_writer, dataset=artifact_dataset)