diff --git a/permutect/architecture/artifact_model.py b/permutect/architecture/artifact_model.py index fc969412..08b2cb50 100644 --- a/permutect/architecture/artifact_model.py +++ b/permutect/architecture/artifact_model.py @@ -15,7 +15,7 @@ from itertools import chain from matplotlib import pyplot as plt -from permutect.architecture.base_model import calculate_batch_weights +from permutect.architecture.base_model import calculate_batch_weights, BaseModel, base_model_from_saved_dict from permutect.architecture.mlp import MLP from permutect.architecture.monotonic import MonoDense from permutect.data.base_datum import ArtifactBatch, DEFAULT_GPU_FLOAT, DEFAULT_CPU_FLOAT @@ -374,27 +374,44 @@ def evaluate_model(self, epoch: int, dataset: ArtifactDataset, train_loader, val # done collecting data - def save(self, path, artifact_log_priors, artifact_spectra): - torch.save({ - constants.STATE_DICT_NAME: self.state_dict(), - constants.NUM_BASE_FEATURES_NAME: self.num_base_features, - constants.NUM_REF_ALT_FEATURES_NAME: self.num_ref_alt_features, - constants.HYPERPARAMS_NAME: self.params, - constants.ARTIFACT_LOG_PRIORS_NAME: artifact_log_priors, - constants.ARTIFACT_SPECTRA_STATE_DICT_NAME: artifact_spectra.state_dict() - }, path) + def make_dict_for_saving(self, artifact_log_priors, artifact_spectra, prefix: str = "artifact"): + return {(prefix + constants.STATE_DICT_NAME): self.state_dict(), + (prefix + constants.NUM_BASE_FEATURES_NAME): self.num_base_features, + (prefix + constants.NUM_REF_ALT_FEATURES_NAME): self.num_ref_alt_features, + (prefix + constants.HYPERPARAMS_NAME): self.params, + (prefix + constants.ARTIFACT_LOG_PRIORS_NAME): artifact_log_priors, + (prefix + constants.ARTIFACT_SPECTRA_STATE_DICT_NAME): artifact_spectra.state_dict()} + def save(self, path, artifact_log_priors, artifact_spectra, prefix: str = "artifact"): + torch.save(self.make_dict_for_saving(artifact_log_priors, artifact_spectra, prefix), path) -# log artifact priors and artifact spectra may be None -def load_artifact_model(path) -> ArtifactModel: - saved = torch.load(path) - model_params = saved[constants.HYPERPARAMS_NAME] - num_base_features = saved[constants.NUM_BASE_FEATURES_NAME] - num_ref_alt_features = saved[constants.NUM_REF_ALT_FEATURES_NAME] + def save_with_base_model(self, base_model: BaseModel, path, artifact_log_priors, artifact_spectra): + artifact_dict = self.make_dict_for_saving(artifact_log_priors, artifact_spectra, prefix="artifact") + base_dict = base_model.make_dict_for_saving(prefix="base") + torch.save({**artifact_dict, **base_dict}, path) + + +def artifact_model_from_saved_dict(saved, prefix: str = "artifact"): + model_params = saved[prefix + constants.HYPERPARAMS_NAME] + num_base_features = saved[prefix + constants.NUM_BASE_FEATURES_NAME] + num_ref_alt_features = saved[prefix + constants.NUM_REF_ALT_FEATURES_NAME] model = ArtifactModel(model_params, num_base_features, num_ref_alt_features) - model.load_state_dict(saved[constants.STATE_DICT_NAME]) + model.load_state_dict(saved[prefix + constants.STATE_DICT_NAME]) - artifact_log_priors = saved[constants.ARTIFACT_LOG_PRIORS_NAME] # possibly None - artifact_spectra_state_dict = saved[constants.ARTIFACT_SPECTRA_STATE_DICT_NAME] #possibly None + artifact_log_priors = saved[prefix + constants.ARTIFACT_LOG_PRIORS_NAME] # possibly None + artifact_spectra_state_dict = saved[prefix + constants.ARTIFACT_SPECTRA_STATE_DICT_NAME] # possibly None return model, artifact_log_priors, artifact_spectra_state_dict + +# log artifact priors and artifact spectra may be None +def load_artifact_model(path, prefix: str = "artifact") -> ArtifactModel: + saved = torch.load(path) + return artifact_model_from_saved_dict(saved, prefix) + + +def load_base_model_and_artifact_model(path) -> ArtifactModel: + saved = torch.load(path) + base_model = base_model_from_saved_dict(saved, prefix="base") + artifact_model, artifact_log_priors, artifact_spectra = artifact_model_from_saved_dict(saved, prefix="artifact") + return base_model, artifact_model, artifact_log_priors, artifact_spectra + diff --git a/permutect/architecture/base_model.py b/permutect/architecture/base_model.py index ec2093cc..2d5f8739 100644 --- a/permutect/architecture/base_model.py +++ b/permutect/architecture/base_model.py @@ -184,26 +184,26 @@ def calculate_representations(self, batch: BaseBatch, weight_range: float = 0) - return result_ve, ref_seq_embeddings_ve # ref seq embeddings are useful later + def make_dict_for_saving(self, prefix: str = ""): + return {(prefix + constants.STATE_DICT_NAME): self.state_dict(), + (prefix + constants.HYPERPARAMS_NAME): self._params, + (prefix + constants.NUM_READ_FEATURES_NAME): self.read_embedding.input_dimension(), + (prefix + constants.NUM_INFO_FEATURES_NAME): self.info_embedding.input_dimension(), + (prefix + constants.REF_SEQUENCE_LENGTH_NAME): self.ref_sequence_length()} + def save(self, path): - torch.save({ - constants.STATE_DICT_NAME: self.state_dict(), - constants.HYPERPARAMS_NAME: self._params, - constants.NUM_READ_FEATURES_NAME: self.read_embedding.input_dimension(), - constants.NUM_INFO_FEATURES_NAME: self.info_embedding.input_dimension(), - constants.REF_SEQUENCE_LENGTH_NAME: self.ref_sequence_length() - }, path) + torch.save(self.make_dict_for_saving(), path) -def load_base_model(path, device: torch.device = utils.gpu_if_available()) -> BaseModel: - saved = torch.load(path) - hyperparams = saved[constants.HYPERPARAMS_NAME] - num_read_features = saved[constants.NUM_READ_FEATURES_NAME] - num_info_features = saved[constants.NUM_INFO_FEATURES_NAME] - ref_sequence_length = saved[constants.REF_SEQUENCE_LENGTH_NAME] +def base_model_from_saved_dict(saved, prefix: str = "", device: torch.device = utils.gpu_if_available()): + hyperparams = saved[prefix + constants.HYPERPARAMS_NAME] + num_read_features = saved[prefix + constants.NUM_READ_FEATURES_NAME] + num_info_features = saved[prefix + constants.NUM_INFO_FEATURES_NAME] + ref_sequence_length = saved[prefix + constants.REF_SEQUENCE_LENGTH_NAME] model = BaseModel(hyperparams, num_read_features=num_read_features, num_info_features=num_info_features, ref_sequence_length=ref_sequence_length, device=device) - model.load_state_dict(saved[constants.STATE_DICT_NAME]) + model.load_state_dict(saved[prefix + constants.STATE_DICT_NAME]) # in case the state dict had the wrong dtype for the device we're on now eg base model was pretrained on GPU # and we're now on CPU @@ -212,6 +212,11 @@ def load_base_model(path, device: torch.device = utils.gpu_if_available()) -> Ba return model +def load_base_model(path, prefix: str = "", device: torch.device = utils.gpu_if_available()) -> BaseModel: + saved = torch.load(path) + return base_model_from_saved_dict(saved, prefix, device) + + # outputs a 1D tensor of losses over the batch. We assume it needs the representations of the batch data from the base # model. We nonetheless also use the model as an input because there are some learning strategies that involve # computing representations of a modified batch. diff --git a/permutect/architecture/gradient_reversal/__init__.py b/permutect/architecture/gradient_reversal/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/permutect/architecture/gradient_reversal/functional.py b/permutect/architecture/gradient_reversal/functional.py new file mode 100644 index 00000000..195a3d76 --- /dev/null +++ b/permutect/architecture/gradient_reversal/functional.py @@ -0,0 +1,25 @@ +from typing import Any + +from torch.autograd import Function + + +class GradientReversal(Function): + @staticmethod + def jvp(ctx: Any, *grad_inputs: Any) -> Any: + pass + + @staticmethod + def forward(ctx, x, alpha): + ctx.save_for_backward(x, alpha) + return x + + @staticmethod + def backward(ctx, grad_output): + grad_input = None + _, alpha = ctx.saved_tensors + if ctx.needs_input_grad[0]: + grad_input = - alpha * grad_output + return grad_input, None + + +revgrad = GradientReversal.apply diff --git a/permutect/architecture/gradient_reversal/module.py b/permutect/architecture/gradient_reversal/module.py new file mode 100644 index 00000000..47c70a95 --- /dev/null +++ b/permutect/architecture/gradient_reversal/module.py @@ -0,0 +1,12 @@ +from .functional import revgrad +import torch +from torch import nn + + +class GradientReversal(nn.Module): + def __init__(self, alpha): + super().__init__() + self.alpha = torch.tensor(alpha, requires_grad=False) + + def forward(self, x): + return revgrad(x, self.alpha) diff --git a/permutect/test/tools/test_filter_variants.py b/permutect/test/tools/test_filter_variants.py index 29ed73dc..d37fe23d 100644 --- a/permutect/test/tools/test_filter_variants.py +++ b/permutect/test/tools/test_filter_variants.py @@ -7,7 +7,6 @@ def test_filtering_on_dream1_chr20(): # Inputs - base_model = '/Users/davidben/mutect3/permutect/integration-tests/singular-10-Mb/base-model.pt' artifact_model = '/Users/davidben/mutect3/permutect/integration-tests/singular-10-Mb/artifact-model.pt' mutect2_vcf = '/Users/davidben/mutect3/permutect/integration-tests/dream1-chr20/mutect2_chr20.vcf' @@ -22,7 +21,6 @@ def test_filtering_on_dream1_chr20(): filtering_args = Namespace() setattr(filtering_args, constants.INPUT_NAME, mutect2_vcf) setattr(filtering_args, constants.TEST_DATASET_NAME, filtering_dataset) - setattr(filtering_args, constants.BASE_MODEL_NAME, base_model) setattr(filtering_args, constants.M3_MODEL_NAME, artifact_model) setattr(filtering_args, constants.OUTPUT_NAME, permutect_vcf.name) setattr(filtering_args, constants.TENSORBOARD_DIR_NAME, tensorboard_dir.name) diff --git a/permutect/tools/filter_variants.py b/permutect/tools/filter_variants.py index c621f0fd..4d106d5e 100644 --- a/permutect/tools/filter_variants.py +++ b/permutect/tools/filter_variants.py @@ -11,7 +11,7 @@ from tqdm.autonotebook import tqdm from permutect import constants -from permutect.architecture.artifact_model import ArtifactModel, load_artifact_model +from permutect.architecture.artifact_model import ArtifactModel, load_artifact_model, load_base_model_and_artifact_model from permutect.architecture.posterior_model import PosteriorModel from permutect.architecture.base_model import BaseModel, load_base_model from permutect.data import base_dataset, plain_text_data, base_datum @@ -73,8 +73,7 @@ def parse_arguments(): parser.add_argument('--' + constants.INPUT_NAME, required=True, help='unfiltered input Mutect2 VCF') parser.add_argument('--' + constants.TEST_DATASET_NAME, required=True, help='plain text dataset file corresponding to variants in input VCF') - parser.add_argument('--' + constants.M3_MODEL_NAME, required=True, help='trained Mutect3 artifact model from train_model.py') - parser.add_argument('--' + constants.BASE_MODEL_NAME, type=str, help='Base model from train_base_model.py') + parser.add_argument('--' + constants.M3_MODEL_NAME, required=True, help='trained Permutect model from train_model.py') parser.add_argument('--' + constants.CONTIGS_TABLE_NAME, required=True, help='table of contig names vs integer indices') parser.add_argument('--' + constants.OUTPUT_NAME, required=True, help='path to output filtered VCF') parser.add_argument('--' + constants.TENSORBOARD_DIR_NAME, type=str, default='tensorboard', required=False, help='path to output tensorboard') @@ -132,7 +131,6 @@ def get_segmentation(segments_file) -> defaultdict: def main_without_parsing(args): make_filtered_vcf(saved_artifact_model_path=getattr(args, constants.M3_MODEL_NAME), - base_model_path=getattr(args, constants.BASE_MODEL_NAME), initial_log_variant_prior=getattr(args, constants.INITIAL_LOG_VARIANT_PRIOR_NAME), initial_log_artifact_prior=getattr(args, constants.INITIAL_LOG_ARTIFACT_PRIOR_NAME), test_dataset_file=getattr(args, constants.TEST_DATASET_NAME), @@ -152,19 +150,20 @@ def main_without_parsing(args): normal_segmentation=get_segmentation(getattr(args, constants.NORMAL_MAF_SEGMENTS_NAME))) -def make_filtered_vcf(saved_artifact_model_path, base_model_path, initial_log_variant_prior: float, initial_log_artifact_prior: float, +def make_filtered_vcf(saved_artifact_model_path, initial_log_variant_prior: float, initial_log_artifact_prior: float, test_dataset_file, contigs_table, input_vcf, output_vcf, batch_size: int, num_workers: int, chunk_size: int, num_spectrum_iterations: int, spectrum_learning_rate: float, tensorboard_dir, genomic_span: int, germline_mode: bool = False, no_germline_mode: bool = False, segmentation=defaultdict(IntervalTree), normal_segmentation=defaultdict(IntervalTree)): print("Loading artifact model and test dataset") - base_model = load_base_model(base_model_path) contig_index_to_name_map = {} with open(contigs_table) as file: while line := file.readline().strip(): contig, index = line.split() contig_index_to_name_map[int(index)] = contig - artifact_model, artifact_log_priors, artifact_spectra_state_dict = load_artifact_model(saved_artifact_model_path) + base_model, artifact_model, artifact_log_priors, artifact_spectra_state_dict = \ + load_base_model_and_artifact_model(saved_artifact_model_path) + posterior_model = PosteriorModel(initial_log_variant_prior, initial_log_artifact_prior, no_germline_mode=no_germline_mode, num_base_features=artifact_model.num_base_features) posterior_data_loader = make_posterior_data_loader(test_dataset_file, input_vcf, contig_index_to_name_map, base_model, artifact_model, batch_size, num_workers=num_workers, chunk_size=chunk_size, segmentation=segmentation, normal_segmentation=normal_segmentation) diff --git a/permutect/tools/train_model.py b/permutect/tools/train_model.py index 9269f12c..e7d75273 100644 --- a/permutect/tools/train_model.py +++ b/permutect/tools/train_model.py @@ -109,7 +109,7 @@ def main_without_parsing(args): summary_writer.add_figure("Artifact AF Spectra", art_spectra_fig) summary_writer.close() - model.save(getattr(args, constants.OUTPUT_NAME), artifact_log_priors, artifact_spectra) + model.save_with_base_model(base_model, getattr(args, constants.OUTPUT_NAME), artifact_log_priors, artifact_spectra) def main(): diff --git a/scripts/permutect.wdl b/scripts/permutect.wdl index d2a09c23..b41c909b 100644 --- a/scripts/permutect.wdl +++ b/scripts/permutect.wdl @@ -5,7 +5,6 @@ import "https://api.firecloud.org/ga4gh/v1/tools/davidben:mutect2/versions/15/pl workflow Permutect { input { File permutect_model - File base_model File? intervals File? masks @@ -100,7 +99,6 @@ workflow Permutect { mutect2_vcf = IndexAfterSplitting.vcf, mutect2_vcf_idx = IndexAfterSplitting.vcf_index, permutect_model = permutect_model, - base_model = base_model, test_dataset = select_first([Mutect2.m3_dataset]), contigs_table = Mutect2.permutect_contigs_table, maf_segments = Mutect2.maf_segments, @@ -133,7 +131,6 @@ workflow Permutect { task PermutectFiltering { input { File permutect_model - File base_model File test_dataset File contigs_table File mutect2_vcf @@ -167,7 +164,6 @@ task PermutectFiltering { filter_variants --input ~{mutect2_vcf} --test_dataset ~{test_dataset} \ --permutect_model ~{permutect_model} \ - --base_model ~{base_model} \ --contigs_table ~{contigs_table} \ --output permutect-filtered.vcf \ --tensorboard_dir tensorboard \