Base model and artifact model are saved together after artifact model training #149

Merged: 2 commits, Oct 10, 2024
55 changes: 36 additions & 19 deletions permutect/architecture/artifact_model.py
@@ -15,7 +15,7 @@
from itertools import chain
from matplotlib import pyplot as plt

from permutect.architecture.base_model import calculate_batch_weights
from permutect.architecture.base_model import calculate_batch_weights, BaseModel, base_model_from_saved_dict
from permutect.architecture.mlp import MLP
from permutect.architecture.monotonic import MonoDense
from permutect.data.base_datum import ArtifactBatch, DEFAULT_GPU_FLOAT, DEFAULT_CPU_FLOAT
@@ -374,27 +374,44 @@ def evaluate_model(self, epoch: int, dataset: ArtifactDataset, train_loader, val

# done collecting data

def save(self, path, artifact_log_priors, artifact_spectra):
torch.save({
constants.STATE_DICT_NAME: self.state_dict(),
constants.NUM_BASE_FEATURES_NAME: self.num_base_features,
constants.NUM_REF_ALT_FEATURES_NAME: self.num_ref_alt_features,
constants.HYPERPARAMS_NAME: self.params,
constants.ARTIFACT_LOG_PRIORS_NAME: artifact_log_priors,
constants.ARTIFACT_SPECTRA_STATE_DICT_NAME: artifact_spectra.state_dict()
}, path)
def make_dict_for_saving(self, artifact_log_priors, artifact_spectra, prefix: str = "artifact"):
return {(prefix + constants.STATE_DICT_NAME): self.state_dict(),
(prefix + constants.NUM_BASE_FEATURES_NAME): self.num_base_features,
(prefix + constants.NUM_REF_ALT_FEATURES_NAME): self.num_ref_alt_features,
(prefix + constants.HYPERPARAMS_NAME): self.params,
(prefix + constants.ARTIFACT_LOG_PRIORS_NAME): artifact_log_priors,
(prefix + constants.ARTIFACT_SPECTRA_STATE_DICT_NAME): artifact_spectra.state_dict()}

def save(self, path, artifact_log_priors, artifact_spectra, prefix: str = "artifact"):
torch.save(self.make_dict_for_saving(artifact_log_priors, artifact_spectra, prefix), path)

# log artifact priors and artifact spectra may be None
def load_artifact_model(path) -> ArtifactModel:
saved = torch.load(path)
model_params = saved[constants.HYPERPARAMS_NAME]
num_base_features = saved[constants.NUM_BASE_FEATURES_NAME]
num_ref_alt_features = saved[constants.NUM_REF_ALT_FEATURES_NAME]
def save_with_base_model(self, base_model: BaseModel, path, artifact_log_priors, artifact_spectra):
artifact_dict = self.make_dict_for_saving(artifact_log_priors, artifact_spectra, prefix="artifact")
base_dict = base_model.make_dict_for_saving(prefix="base")
torch.save({**artifact_dict, **base_dict}, path)


def artifact_model_from_saved_dict(saved, prefix: str = "artifact"):
model_params = saved[prefix + constants.HYPERPARAMS_NAME]
num_base_features = saved[prefix + constants.NUM_BASE_FEATURES_NAME]
num_ref_alt_features = saved[prefix + constants.NUM_REF_ALT_FEATURES_NAME]
model = ArtifactModel(model_params, num_base_features, num_ref_alt_features)
model.load_state_dict(saved[constants.STATE_DICT_NAME])
model.load_state_dict(saved[prefix + constants.STATE_DICT_NAME])

artifact_log_priors = saved[constants.ARTIFACT_LOG_PRIORS_NAME] # possibly None
artifact_spectra_state_dict = saved[constants.ARTIFACT_SPECTRA_STATE_DICT_NAME] #possibly None
artifact_log_priors = saved[prefix + constants.ARTIFACT_LOG_PRIORS_NAME] # possibly None
artifact_spectra_state_dict = saved[prefix + constants.ARTIFACT_SPECTRA_STATE_DICT_NAME] # possibly None
return model, artifact_log_priors, artifact_spectra_state_dict


# log artifact priors and artifact spectra may be None
def load_artifact_model(path, prefix: str = "artifact") -> ArtifactModel:
saved = torch.load(path)
return artifact_model_from_saved_dict(saved, prefix)


def load_base_model_and_artifact_model(path) -> ArtifactModel:
saved = torch.load(path)
base_model = base_model_from_saved_dict(saved, prefix="base")
artifact_model, artifact_log_priors, artifact_spectra = artifact_model_from_saved_dict(saved, prefix="artifact")
return base_model, artifact_model, artifact_log_priors, artifact_spectra
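
The refactor above stores two models in a single checkpoint by key prefix: each model contributes its entries under "base" or "artifact", the dicts are merged with **, and loading selects keys back out by the same prefix. A minimal sketch of the pattern in isolation (TinyModel and the file name are illustrative stand-ins, not PR code):

import torch
from torch import nn


class TinyModel(nn.Module):
    # stand-in for BaseModel / ArtifactModel
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(4, 2)


def make_dict_for_saving(model: nn.Module, prefix: str) -> dict:
    # prefix every key so two models can coexist in one file
    return {prefix + "state_dict": model.state_dict()}


base, artifact = TinyModel(), TinyModel()
torch.save({**make_dict_for_saving(base, "base"),
            **make_dict_for_saving(artifact, "artifact")}, "combined.pt")

saved = torch.load("combined.pt")
restored = TinyModel()
restored.load_state_dict(saved["base" + "state_dict"])  # pick out one model by prefix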

33 changes: 19 additions & 14 deletions permutect/architecture/base_model.py
@@ -184,26 +184,26 @@ def calculate_representations(self, batch: BaseBatch, weight_range: float = 0) -

return result_ve, ref_seq_embeddings_ve # ref seq embeddings are useful later

def make_dict_for_saving(self, prefix: str = ""):
return {(prefix + constants.STATE_DICT_NAME): self.state_dict(),
(prefix + constants.HYPERPARAMS_NAME): self._params,
(prefix + constants.NUM_READ_FEATURES_NAME): self.read_embedding.input_dimension(),
(prefix + constants.NUM_INFO_FEATURES_NAME): self.info_embedding.input_dimension(),
(prefix + constants.REF_SEQUENCE_LENGTH_NAME): self.ref_sequence_length()}

def save(self, path):
torch.save({
constants.STATE_DICT_NAME: self.state_dict(),
constants.HYPERPARAMS_NAME: self._params,
constants.NUM_READ_FEATURES_NAME: self.read_embedding.input_dimension(),
constants.NUM_INFO_FEATURES_NAME: self.info_embedding.input_dimension(),
constants.REF_SEQUENCE_LENGTH_NAME: self.ref_sequence_length()
}, path)
torch.save(self.make_dict_for_saving(), path)


def load_base_model(path, device: torch.device = utils.gpu_if_available()) -> BaseModel:
saved = torch.load(path)
hyperparams = saved[constants.HYPERPARAMS_NAME]
num_read_features = saved[constants.NUM_READ_FEATURES_NAME]
num_info_features = saved[constants.NUM_INFO_FEATURES_NAME]
ref_sequence_length = saved[constants.REF_SEQUENCE_LENGTH_NAME]
def base_model_from_saved_dict(saved, prefix: str = "", device: torch.device = utils.gpu_if_available()):
hyperparams = saved[prefix + constants.HYPERPARAMS_NAME]
num_read_features = saved[prefix + constants.NUM_READ_FEATURES_NAME]
num_info_features = saved[prefix + constants.NUM_INFO_FEATURES_NAME]
ref_sequence_length = saved[prefix + constants.REF_SEQUENCE_LENGTH_NAME]

model = BaseModel(hyperparams, num_read_features=num_read_features, num_info_features=num_info_features,
ref_sequence_length=ref_sequence_length, device=device)
model.load_state_dict(saved[constants.STATE_DICT_NAME])
model.load_state_dict(saved[prefix + constants.STATE_DICT_NAME])

# in case the state dict had the wrong dtype for the device we're on now eg base model was pretrained on GPU
# and we're now on CPU
@@ -212,6 +212,11 @@ def load_base_model(path, device: torch.device = utils.gpu_if_available()) -> Ba
return model


def load_base_model(path, prefix: str = "", device: torch.device = utils.gpu_if_available()) -> BaseModel:
saved = torch.load(path)
return base_model_from_saved_dict(saved, prefix, device)


# outputs a 1D tensor of losses over the batch. We assume it needs the representations of the batch data from the base
# model. We nonetheless also use the model as an input because there are some learning strategies that involve
# computing representations of a modified batch.
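
The comment above notes that a state dict saved on GPU may carry the wrong dtype for the current device; the actual fix-up lines are collapsed in this diff view. A hedged sketch of that kind of post-load conversion (the helper name is an assumption, not the PR's code; DEFAULT_CPU_FLOAT and DEFAULT_GPU_FLOAT are the constants imported in artifact_model.py, assumed here to be torch dtypes):

import torch

from permutect.data.base_datum import DEFAULT_CPU_FLOAT, DEFAULT_GPU_FLOAT


def coerce_dtype_for_device(model: torch.nn.Module, device: torch.device) -> torch.nn.Module:
    # a GPU-trained checkpoint may hold half-precision weights; convert so the
    # loaded model matches the default float type of the current device
    target = DEFAULT_GPU_FLOAT if device.type == "cuda" else DEFAULT_CPU_FLOAT
    return model.to(device=device, dtype=target)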
Empty file (presumably the new gradient_reversal package's __init__.py).
25 changes: 25 additions & 0 deletions permutect/architecture/gradient_reversal/functional.py
@@ -0,0 +1,25 @@
from typing import Any

from torch.autograd import Function


class GradientReversal(Function):
@staticmethod
def jvp(ctx: Any, *grad_inputs: Any) -> Any:
pass

@staticmethod
def forward(ctx, x, alpha):
ctx.save_for_backward(x, alpha)
return x

@staticmethod
def backward(ctx, grad_output):
grad_input = None
_, alpha = ctx.saved_tensors
if ctx.needs_input_grad[0]:
grad_input = - alpha * grad_output
return grad_input, None


revgrad = GradientReversal.apply
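
A quick numerical check of revgrad (the import path follows this PR's new package layout): the forward pass is the identity, while the backward pass negates the incoming gradient and scales it by alpha.

import torch

from permutect.architecture.gradient_reversal.functional import revgrad

x = torch.tensor([1.0, -2.0, 3.0], requires_grad=True)
alpha = torch.tensor(0.5)

y = revgrad(x, alpha)  # forward: y == x
y.sum().backward()     # backward: grad is negated and scaled by alpha

print(x.grad)  # expect tensor([-0.5000, -0.5000, -0.5000]), i.e. -alpha * dL/dy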
12 changes: 12 additions & 0 deletions permutect/architecture/gradient_reversal/module.py
@@ -0,0 +1,12 @@
from .functional import revgrad
import torch
from torch import nn


class GradientReversal(nn.Module):
def __init__(self, alpha):
super().__init__()
self.alpha = torch.tensor(alpha, requires_grad=False)

def forward(self, x):
return revgrad(x, self.alpha)
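
Typical use of the module form is to place it in front of a domain or source discriminator so the upstream encoder is trained adversarially. A hypothetical head (the PR's actual call site is not part of this diff):

import torch
from torch import nn

from permutect.architecture.gradient_reversal.module import GradientReversal

# gradients flowing back from this head into whatever produced `features`
# are multiplied by -alpha, pushing the encoder toward invariant features
discriminator = nn.Sequential(
    GradientReversal(alpha=1.0),
    nn.Linear(16, 8),
    nn.ReLU(),
    nn.Linear(8, 1),
)

features = torch.randn(4, 16, requires_grad=True)
discriminator(features).sum().backward()  # features.grad has the reversed sign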
2 changes: 0 additions & 2 deletions permutect/test/tools/test_filter_variants.py
@@ -7,7 +7,6 @@

def test_filtering_on_dream1_chr20():
# Inputs
base_model = '/Users/davidben/mutect3/permutect/integration-tests/singular-10-Mb/base-model.pt'
artifact_model = '/Users/davidben/mutect3/permutect/integration-tests/singular-10-Mb/artifact-model.pt'

mutect2_vcf = '/Users/davidben/mutect3/permutect/integration-tests/dream1-chr20/mutect2_chr20.vcf'
@@ -22,7 +21,6 @@ def test_filtering_on_dream1_chr20():
filtering_args = Namespace()
setattr(filtering_args, constants.INPUT_NAME, mutect2_vcf)
setattr(filtering_args, constants.TEST_DATASET_NAME, filtering_dataset)
setattr(filtering_args, constants.BASE_MODEL_NAME, base_model)
setattr(filtering_args, constants.M3_MODEL_NAME, artifact_model)
setattr(filtering_args, constants.OUTPUT_NAME, permutect_vcf.name)
setattr(filtering_args, constants.TENSORBOARD_DIR_NAME, tensorboard_dir.name)
13 changes: 6 additions & 7 deletions permutect/tools/filter_variants.py
@@ -11,7 +11,7 @@
from tqdm.autonotebook import tqdm

from permutect import constants
from permutect.architecture.artifact_model import ArtifactModel, load_artifact_model
from permutect.architecture.artifact_model import ArtifactModel, load_artifact_model, load_base_model_and_artifact_model
from permutect.architecture.posterior_model import PosteriorModel
from permutect.architecture.base_model import BaseModel, load_base_model
from permutect.data import base_dataset, plain_text_data, base_datum
@@ -73,8 +73,7 @@ def parse_arguments():
parser.add_argument('--' + constants.INPUT_NAME, required=True, help='unfiltered input Mutect2 VCF')
parser.add_argument('--' + constants.TEST_DATASET_NAME, required=True,
help='plain text dataset file corresponding to variants in input VCF')
parser.add_argument('--' + constants.M3_MODEL_NAME, required=True, help='trained Mutect3 artifact model from train_model.py')
parser.add_argument('--' + constants.BASE_MODEL_NAME, type=str, help='Base model from train_base_model.py')
parser.add_argument('--' + constants.M3_MODEL_NAME, required=True, help='trained Permutect model from train_model.py')
parser.add_argument('--' + constants.CONTIGS_TABLE_NAME, required=True, help='table of contig names vs integer indices')
parser.add_argument('--' + constants.OUTPUT_NAME, required=True, help='path to output filtered VCF')
parser.add_argument('--' + constants.TENSORBOARD_DIR_NAME, type=str, default='tensorboard', required=False, help='path to output tensorboard')
@@ -132,7 +131,6 @@ def get_segmentation(segments_file) -> defaultdict:

def main_without_parsing(args):
make_filtered_vcf(saved_artifact_model_path=getattr(args, constants.M3_MODEL_NAME),
base_model_path=getattr(args, constants.BASE_MODEL_NAME),
initial_log_variant_prior=getattr(args, constants.INITIAL_LOG_VARIANT_PRIOR_NAME),
initial_log_artifact_prior=getattr(args, constants.INITIAL_LOG_ARTIFACT_PRIOR_NAME),
test_dataset_file=getattr(args, constants.TEST_DATASET_NAME),
@@ -152,19 +150,20 @@
normal_segmentation=get_segmentation(getattr(args, constants.NORMAL_MAF_SEGMENTS_NAME)))


def make_filtered_vcf(saved_artifact_model_path, base_model_path, initial_log_variant_prior: float, initial_log_artifact_prior: float,
def make_filtered_vcf(saved_artifact_model_path, initial_log_variant_prior: float, initial_log_artifact_prior: float,
test_dataset_file, contigs_table, input_vcf, output_vcf, batch_size: int, num_workers: int, chunk_size: int, num_spectrum_iterations: int,
spectrum_learning_rate: float, tensorboard_dir, genomic_span: int, germline_mode: bool = False, no_germline_mode: bool = False,
segmentation=defaultdict(IntervalTree), normal_segmentation=defaultdict(IntervalTree)):
print("Loading artifact model and test dataset")
base_model = load_base_model(base_model_path)
contig_index_to_name_map = {}
with open(contigs_table) as file:
while line := file.readline().strip():
contig, index = line.split()
contig_index_to_name_map[int(index)] = contig

artifact_model, artifact_log_priors, artifact_spectra_state_dict = load_artifact_model(saved_artifact_model_path)
base_model, artifact_model, artifact_log_priors, artifact_spectra_state_dict = \
load_base_model_and_artifact_model(saved_artifact_model_path)

posterior_model = PosteriorModel(initial_log_variant_prior, initial_log_artifact_prior, no_germline_mode=no_germline_mode, num_base_features=artifact_model.num_base_features)
posterior_data_loader = make_posterior_data_loader(test_dataset_file, input_vcf, contig_index_to_name_map,
base_model, artifact_model, batch_size, num_workers=num_workers, chunk_size=chunk_size, segmentation=segmentation, normal_segmentation=normal_segmentation)
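
With this change the filtering tool needs only the combined checkpoint; the loading flow reduces to one call (the path is illustrative):

from permutect.architecture.artifact_model import load_base_model_and_artifact_model

# one file now carries the base model, the artifact model, and the learned
# artifact log priors and spectra
base_model, artifact_model, artifact_log_priors, artifact_spectra_state_dict = \
    load_base_model_and_artifact_model("permutect-model.pt")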
2 changes: 1 addition & 1 deletion permutect/tools/train_model.py
@@ -109,7 +109,7 @@ def main_without_parsing(args):
summary_writer.add_figure("Artifact AF Spectra", art_spectra_fig)

summary_writer.close()
model.save(getattr(args, constants.OUTPUT_NAME), artifact_log_priors, artifact_spectra)
model.save_with_base_model(base_model, getattr(args, constants.OUTPUT_NAME), artifact_log_priors, artifact_spectra)


def main():
4 changes: 0 additions & 4 deletions scripts/permutect.wdl
@@ -5,7 +5,6 @@ import "https://api.firecloud.org/ga4gh/v1/tools/davidben:mutect2/versions/15/pl
workflow Permutect {
input {
File permutect_model
File base_model

File? intervals
File? masks
@@ -100,7 +99,6 @@ workflow Permutect {
mutect2_vcf = IndexAfterSplitting.vcf,
mutect2_vcf_idx = IndexAfterSplitting.vcf_index,
permutect_model = permutect_model,
base_model = base_model,
test_dataset = select_first([Mutect2.m3_dataset]),
contigs_table = Mutect2.permutect_contigs_table,
maf_segments = Mutect2.maf_segments,
@@ -133,7 +131,6 @@
task PermutectFiltering {
input {
File permutect_model
File base_model
File test_dataset
File contigs_table
File mutect2_vcf
@@ -167,7 +164,6 @@

filter_variants --input ~{mutect2_vcf} --test_dataset ~{test_dataset} \
--permutect_model ~{permutect_model} \
--base_model ~{base_model} \
--contigs_table ~{contigs_table} \
--output permutect-filtered.vcf \
--tensorboard_dir tensorboard \