Skip to content

Commit

Permalink
Base model and artifact model are saved together after artifact model…
Browse files Browse the repository at this point in the history
… training (#149)
  • Loading branch information
davidbenjamin authored Oct 10, 2024
1 parent f224112 commit 72b6a08
Show file tree
Hide file tree
Showing 9 changed files with 99 additions and 47 deletions.
55 changes: 36 additions & 19 deletions permutect/architecture/artifact_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
from itertools import chain
from matplotlib import pyplot as plt

from permutect.architecture.base_model import calculate_batch_weights
from permutect.architecture.base_model import calculate_batch_weights, BaseModel, base_model_from_saved_dict
from permutect.architecture.mlp import MLP
from permutect.architecture.monotonic import MonoDense
from permutect.data.base_datum import ArtifactBatch, DEFAULT_GPU_FLOAT, DEFAULT_CPU_FLOAT
Expand Down Expand Up @@ -374,27 +374,44 @@ def evaluate_model(self, epoch: int, dataset: ArtifactDataset, train_loader, val

# done collecting data

def save(self, path, artifact_log_priors, artifact_spectra):
torch.save({
constants.STATE_DICT_NAME: self.state_dict(),
constants.NUM_BASE_FEATURES_NAME: self.num_base_features,
constants.NUM_REF_ALT_FEATURES_NAME: self.num_ref_alt_features,
constants.HYPERPARAMS_NAME: self.params,
constants.ARTIFACT_LOG_PRIORS_NAME: artifact_log_priors,
constants.ARTIFACT_SPECTRA_STATE_DICT_NAME: artifact_spectra.state_dict()
}, path)
def make_dict_for_saving(self, artifact_log_priors, artifact_spectra, prefix: str = "artifact"):
    """Build the dictionary that is written to disk when this model is saved.

    Every key is `prefix` concatenated with one of the name constants, so the result can be merged with the
    dictionary of another model that was built with a different prefix.

    :param artifact_log_priors: stored unchanged; may be None
    :param artifact_spectra: object whose state_dict() is stored; may be None, in which case None is stored
    :param prefix: string prepended to every key
    :return: dict mapping prefixed names to the values needed to rebuild the model
    """
    # The loading code treats the saved spectra state dict as possibly None, so accept a missing spectra
    # object here instead of failing on None.state_dict().
    spectra_state_dict = None if artifact_spectra is None else artifact_spectra.state_dict()

    return {(prefix + constants.STATE_DICT_NAME): self.state_dict(),
            (prefix + constants.NUM_BASE_FEATURES_NAME): self.num_base_features,
            (prefix + constants.NUM_REF_ALT_FEATURES_NAME): self.num_ref_alt_features,
            (prefix + constants.HYPERPARAMS_NAME): self.params,
            (prefix + constants.ARTIFACT_LOG_PRIORS_NAME): artifact_log_priors,
            (prefix + constants.ARTIFACT_SPECTRA_STATE_DICT_NAME): spectra_state_dict}

def save(self, path, artifact_log_priors, artifact_spectra, prefix: str = "artifact"):
    """Write this model, the artifact log priors and the artifact spectra to `path` with torch.save.

    The saved object is the dictionary returned by make_dict_for_saving for the same arguments.
    """
    dict_to_save = self.make_dict_for_saving(artifact_log_priors, artifact_spectra, prefix)
    torch.save(dict_to_save, path)

# log artifact priors and artifact spectra may be None
def load_artifact_model(path) -> ArtifactModel:
saved = torch.load(path)
model_params = saved[constants.HYPERPARAMS_NAME]
num_base_features = saved[constants.NUM_BASE_FEATURES_NAME]
num_ref_alt_features = saved[constants.NUM_REF_ALT_FEATURES_NAME]
def save_with_base_model(self, base_model: BaseModel, path, artifact_log_priors, artifact_spectra):
    """Save this artifact model and `base_model` together in a single file at `path`.

    Artifact-model entries are keyed with the prefix "artifact" and base-model entries with the prefix "base".
    """
    combined = dict(self.make_dict_for_saving(artifact_log_priors, artifact_spectra, prefix="artifact"))
    # base-model entries are merged in second, so they take precedence if a key ever occurs in both dicts
    combined.update(base_model.make_dict_for_saving(prefix="base"))
    torch.save(combined, path)


def artifact_model_from_saved_dict(saved, prefix: str = "artifact"):
model_params = saved[prefix + constants.HYPERPARAMS_NAME]
num_base_features = saved[prefix + constants.NUM_BASE_FEATURES_NAME]
num_ref_alt_features = saved[prefix + constants.NUM_REF_ALT_FEATURES_NAME]
model = ArtifactModel(model_params, num_base_features, num_ref_alt_features)
model.load_state_dict(saved[constants.STATE_DICT_NAME])
model.load_state_dict(saved[prefix + constants.STATE_DICT_NAME])

artifact_log_priors = saved[constants.ARTIFACT_LOG_PRIORS_NAME] # possibly None
artifact_spectra_state_dict = saved[constants.ARTIFACT_SPECTRA_STATE_DICT_NAME] #possibly None
artifact_log_priors = saved[prefix + constants.ARTIFACT_LOG_PRIORS_NAME] # possibly None
artifact_spectra_state_dict = saved[prefix + constants.ARTIFACT_SPECTRA_STATE_DICT_NAME] # possibly None
return model, artifact_log_priors, artifact_spectra_state_dict


# the returned artifact log priors and artifact spectra state dict may each be None
def load_artifact_model(path, prefix: str = "artifact") -> "tuple[ArtifactModel, object, object]":
    """Load an artifact model and the data saved alongside it from `path`.

    :param path: file read with torch.load; expected to hold a dict keyed as in make_dict_for_saving
    :param prefix: prefix of the artifact model's keys within the saved dict
    :return: tuple (artifact model, artifact log priors, artifact spectra state dict); the last two may be None
    """
    # NOTE: the return annotation is a string so that it is never evaluated at import time.
    saved = torch.load(path)
    return artifact_model_from_saved_dict(saved, prefix)


def load_base_model_and_artifact_model(path) -> "tuple[BaseModel, ArtifactModel, object, object]":
    """Load a base model and an artifact model that were saved together in one file.

    Base-model entries are read with the key prefix "base" and artifact-model entries with the prefix "artifact".

    :param path: file read with torch.load
    :return: tuple (base model, artifact model, artifact log priors, artifact spectra state dict); the last two
        may be None
    """
    # NOTE: the return annotation is a string so that it is never evaluated at import time.
    saved = torch.load(path)
    base_model = base_model_from_saved_dict(saved, prefix="base")
    artifact_model, artifact_log_priors, artifact_spectra = artifact_model_from_saved_dict(saved, prefix="artifact")
    return base_model, artifact_model, artifact_log_priors, artifact_spectra

33 changes: 19 additions & 14 deletions permutect/architecture/base_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -184,26 +184,26 @@ def calculate_representations(self, batch: BaseBatch, weight_range: float = 0) -

return result_ve, ref_seq_embeddings_ve # ref seq embeddings are useful later

def make_dict_for_saving(self, prefix: str = ""):
    """Return the dictionary of everything that is saved for this model, with `prefix` prepended to each key.

    The entries are the state dict, the hyperparameters, the input dimensions of the read and info embeddings,
    and the reference sequence length.
    """
    unprefixed = {
        constants.STATE_DICT_NAME: self.state_dict(),
        constants.HYPERPARAMS_NAME: self._params,
        constants.NUM_READ_FEATURES_NAME: self.read_embedding.input_dimension(),
        constants.NUM_INFO_FEATURES_NAME: self.info_embedding.input_dimension(),
        constants.REF_SEQUENCE_LENGTH_NAME: self.ref_sequence_length(),
    }
    return {prefix + name: value for name, value in unprefixed.items()}

def save(self, path):
torch.save({
constants.STATE_DICT_NAME: self.state_dict(),
constants.HYPERPARAMS_NAME: self._params,
constants.NUM_READ_FEATURES_NAME: self.read_embedding.input_dimension(),
constants.NUM_INFO_FEATURES_NAME: self.info_embedding.input_dimension(),
constants.REF_SEQUENCE_LENGTH_NAME: self.ref_sequence_length()
}, path)
torch.save(self.make_dict_for_saving(), path)


def load_base_model(path, device: torch.device = utils.gpu_if_available()) -> BaseModel:
saved = torch.load(path)
hyperparams = saved[constants.HYPERPARAMS_NAME]
num_read_features = saved[constants.NUM_READ_FEATURES_NAME]
num_info_features = saved[constants.NUM_INFO_FEATURES_NAME]
ref_sequence_length = saved[constants.REF_SEQUENCE_LENGTH_NAME]
def base_model_from_saved_dict(saved, prefix: str = "", device: torch.device = utils.gpu_if_available()):
hyperparams = saved[prefix + constants.HYPERPARAMS_NAME]
num_read_features = saved[prefix + constants.NUM_READ_FEATURES_NAME]
num_info_features = saved[prefix + constants.NUM_INFO_FEATURES_NAME]
ref_sequence_length = saved[prefix + constants.REF_SEQUENCE_LENGTH_NAME]

model = BaseModel(hyperparams, num_read_features=num_read_features, num_info_features=num_info_features,
ref_sequence_length=ref_sequence_length, device=device)
model.load_state_dict(saved[constants.STATE_DICT_NAME])
model.load_state_dict(saved[prefix + constants.STATE_DICT_NAME])

# in case the state dict had the wrong dtype for the device we're on now eg base model was pretrained on GPU
# and we're now on CPU
Expand All @@ -212,6 +212,11 @@ def load_base_model(path, device: torch.device = utils.gpu_if_available()) -> Ba
return model


def load_base_model(path, prefix: str = "", device: torch.device = utils.gpu_if_available()) -> BaseModel:
    """Read the saved dictionary at `path` with torch.load and rebuild the base model from it.

    `prefix` and `device` are passed through unchanged to base_model_from_saved_dict.
    """
    return base_model_from_saved_dict(torch.load(path), prefix, device)


# outputs a 1D tensor of losses over the batch. We assume it needs the representations of the batch data from the base
# model. We nonetheless also use the model as an input because there are some learning strategies that involve
# computing representations of a modified batch.
Expand Down
Empty file.
25 changes: 25 additions & 0 deletions permutect/architecture/gradient_reversal/functional.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
from typing import Any

from torch.autograd import Function


class GradientReversal(Function):
    """Autograd function that passes its input through unchanged and reverses the gradient on the way back.

    forward(x, alpha) returns x; backward multiplies the incoming gradient by -alpha. `alpha` must be a
    tensor because it is stored with ctx.save_for_backward.
    """

    @staticmethod
    def jvp(ctx: Any, *grad_inputs: Any) -> Any:
        # Forward-mode differentiation is not implemented for this function. Fail loudly instead of
        # returning None as if it were a valid result.
        raise NotImplementedError("GradientReversal does not implement forward-mode differentiation (jvp)")

    @staticmethod
    def forward(ctx, x, alpha):
        # backward only needs alpha, so x is not saved and is not kept alive by the autograd graph
        ctx.save_for_backward(alpha)
        return x

    @staticmethod
    def backward(ctx, grad_output):
        (alpha,) = ctx.saved_tensors
        grad_input = -alpha * grad_output if ctx.needs_input_grad[0] else None
        # one entry per forward input: the reversed gradient for x and no gradient for alpha
        return grad_input, None


# Function-style entry point: revgrad(x, alpha) runs GradientReversal through autograd's apply.
revgrad = GradientReversal.apply
12 changes: 12 additions & 0 deletions permutect/architecture/gradient_reversal/module.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
from .functional import revgrad
import torch
from torch import nn


class GradientReversal(nn.Module):
    """Module wrapper that calls revgrad on its input with a fixed alpha."""

    def __init__(self, alpha):
        super().__init__()
        # alpha is held as a tensor that does not require grad and is handed to revgrad on every forward call
        alpha_tensor = torch.tensor(alpha, requires_grad=False)
        self.alpha = alpha_tensor

    def forward(self, x):
        reversed_output = revgrad(x, self.alpha)
        return reversed_output
2 changes: 0 additions & 2 deletions permutect/test/tools/test_filter_variants.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@

def test_filtering_on_dream1_chr20():
# Inputs
base_model = '/Users/davidben/mutect3/permutect/integration-tests/singular-10-Mb/base-model.pt'
artifact_model = '/Users/davidben/mutect3/permutect/integration-tests/singular-10-Mb/artifact-model.pt'

mutect2_vcf = '/Users/davidben/mutect3/permutect/integration-tests/dream1-chr20/mutect2_chr20.vcf'
Expand All @@ -22,7 +21,6 @@ def test_filtering_on_dream1_chr20():
filtering_args = Namespace()
setattr(filtering_args, constants.INPUT_NAME, mutect2_vcf)
setattr(filtering_args, constants.TEST_DATASET_NAME, filtering_dataset)
setattr(filtering_args, constants.BASE_MODEL_NAME, base_model)
setattr(filtering_args, constants.M3_MODEL_NAME, artifact_model)
setattr(filtering_args, constants.OUTPUT_NAME, permutect_vcf.name)
setattr(filtering_args, constants.TENSORBOARD_DIR_NAME, tensorboard_dir.name)
Expand Down
13 changes: 6 additions & 7 deletions permutect/tools/filter_variants.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
from tqdm.autonotebook import tqdm

from permutect import constants
from permutect.architecture.artifact_model import ArtifactModel, load_artifact_model
from permutect.architecture.artifact_model import ArtifactModel, load_artifact_model, load_base_model_and_artifact_model
from permutect.architecture.posterior_model import PosteriorModel
from permutect.architecture.base_model import BaseModel, load_base_model
from permutect.data import base_dataset, plain_text_data, base_datum
Expand Down Expand Up @@ -73,8 +73,7 @@ def parse_arguments():
parser.add_argument('--' + constants.INPUT_NAME, required=True, help='unfiltered input Mutect2 VCF')
parser.add_argument('--' + constants.TEST_DATASET_NAME, required=True,
help='plain text dataset file corresponding to variants in input VCF')
parser.add_argument('--' + constants.M3_MODEL_NAME, required=True, help='trained Mutect3 artifact model from train_model.py')
parser.add_argument('--' + constants.BASE_MODEL_NAME, type=str, help='Base model from train_base_model.py')
parser.add_argument('--' + constants.M3_MODEL_NAME, required=True, help='trained Permutect model from train_model.py')
parser.add_argument('--' + constants.CONTIGS_TABLE_NAME, required=True, help='table of contig names vs integer indices')
parser.add_argument('--' + constants.OUTPUT_NAME, required=True, help='path to output filtered VCF')
parser.add_argument('--' + constants.TENSORBOARD_DIR_NAME, type=str, default='tensorboard', required=False, help='path to output tensorboard')
Expand Down Expand Up @@ -132,7 +131,6 @@ def get_segmentation(segments_file) -> defaultdict:

def main_without_parsing(args):
make_filtered_vcf(saved_artifact_model_path=getattr(args, constants.M3_MODEL_NAME),
base_model_path=getattr(args, constants.BASE_MODEL_NAME),
initial_log_variant_prior=getattr(args, constants.INITIAL_LOG_VARIANT_PRIOR_NAME),
initial_log_artifact_prior=getattr(args, constants.INITIAL_LOG_ARTIFACT_PRIOR_NAME),
test_dataset_file=getattr(args, constants.TEST_DATASET_NAME),
Expand All @@ -152,19 +150,20 @@ def main_without_parsing(args):
normal_segmentation=get_segmentation(getattr(args, constants.NORMAL_MAF_SEGMENTS_NAME)))


def make_filtered_vcf(saved_artifact_model_path, base_model_path, initial_log_variant_prior: float, initial_log_artifact_prior: float,
def make_filtered_vcf(saved_artifact_model_path, initial_log_variant_prior: float, initial_log_artifact_prior: float,
test_dataset_file, contigs_table, input_vcf, output_vcf, batch_size: int, num_workers: int, chunk_size: int, num_spectrum_iterations: int,
spectrum_learning_rate: float, tensorboard_dir, genomic_span: int, germline_mode: bool = False, no_germline_mode: bool = False,
segmentation=defaultdict(IntervalTree), normal_segmentation=defaultdict(IntervalTree)):
print("Loading artifact model and test dataset")
base_model = load_base_model(base_model_path)
contig_index_to_name_map = {}
with open(contigs_table) as file:
while line := file.readline().strip():
contig, index = line.split()
contig_index_to_name_map[int(index)] = contig

artifact_model, artifact_log_priors, artifact_spectra_state_dict = load_artifact_model(saved_artifact_model_path)
base_model, artifact_model, artifact_log_priors, artifact_spectra_state_dict = \
load_base_model_and_artifact_model(saved_artifact_model_path)

posterior_model = PosteriorModel(initial_log_variant_prior, initial_log_artifact_prior, no_germline_mode=no_germline_mode, num_base_features=artifact_model.num_base_features)
posterior_data_loader = make_posterior_data_loader(test_dataset_file, input_vcf, contig_index_to_name_map,
base_model, artifact_model, batch_size, num_workers=num_workers, chunk_size=chunk_size, segmentation=segmentation, normal_segmentation=normal_segmentation)
Expand Down
2 changes: 1 addition & 1 deletion permutect/tools/train_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,7 +109,7 @@ def main_without_parsing(args):
summary_writer.add_figure("Artifact AF Spectra", art_spectra_fig)

summary_writer.close()
model.save(getattr(args, constants.OUTPUT_NAME), artifact_log_priors, artifact_spectra)
model.save_with_base_model(base_model, getattr(args, constants.OUTPUT_NAME), artifact_log_priors, artifact_spectra)


def main():
Expand Down
4 changes: 0 additions & 4 deletions scripts/permutect.wdl
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@ import "https://api.firecloud.org/ga4gh/v1/tools/davidben:mutect2/versions/15/pl
workflow Permutect {
input {
File permutect_model
File base_model

File? intervals
File? masks
Expand Down Expand Up @@ -100,7 +99,6 @@ workflow Permutect {
mutect2_vcf = IndexAfterSplitting.vcf,
mutect2_vcf_idx = IndexAfterSplitting.vcf_index,
permutect_model = permutect_model,
base_model = base_model,
test_dataset = select_first([Mutect2.m3_dataset]),
contigs_table = Mutect2.permutect_contigs_table,
maf_segments = Mutect2.maf_segments,
Expand Down Expand Up @@ -133,7 +131,6 @@ workflow Permutect {
task PermutectFiltering {
input {
File permutect_model
File base_model
File test_dataset
File contigs_table
File mutect2_vcf
Expand Down Expand Up @@ -167,7 +164,6 @@ task PermutectFiltering {
filter_variants --input ~{mutect2_vcf} --test_dataset ~{test_dataset} \
--permutect_model ~{permutect_model} \
--base_model ~{base_model} \
--contigs_table ~{contigs_table} \
--output permutect-filtered.vcf \
--tensorboard_dir tensorboard \
Expand Down

0 comments on commit 72b6a08

Please sign in to comment.