
Several new WDLs and a gated MLP bug fix (#160)
* WDL that combines gathering plain-text data and preprocessing

* WDL to generate train and test datasets with Mutect2, make a UDA dataset and train on it, and run Permutect filtering

* bug fix for NaNs in base model output features introduced by the recent mixed-count base batch PR
davidbenjamin authored Nov 22, 2024
1 parent 1f0b980 commit dc8aa47
Showing 9 changed files with 472 additions and 28 deletions.
4 changes: 2 additions & 2 deletions permutect/architecture/gated_mlp.py
@@ -212,7 +212,7 @@ def __init__(self, d_z: int):

# regularizer / sort of imputed value for when there are no ref counts
self.ref_regularizer = nn.Parameter(0.1 * torch.ones(d_z // 2))
self.regularizer_weight = nn.Parameter(torch.tensor(0.1))
self.regularizer_weight_pre_exp = nn.Parameter(torch.log(torch.tensor(0.1)))

def forward(self, zref_rd: torch.Tensor, zalt_rd: torch.Tensor, ref_counts: torch.IntTensor, alt_counts: torch.IntTensor):

@@ -223,7 +223,7 @@ def forward(self, zref_rd: torch.Tensor, zalt_rd: torch.Tensor, ref_counts: torch.IntTensor, alt_counts: torch.IntTensor):
z2_alt_rd = self.norm(z2_alt_rd)

# these are means by variant -- need repeat_interleave to make them by-read
ref_mean_field_vd = utils.means_over_rows_with_regularizer(z2_ref_rd, ref_counts, self.ref_regularizer, self.regularizer_weight)
ref_mean_field_vd = utils.means_over_rows_with_regularizer(z2_ref_rd, ref_counts, self.ref_regularizer, torch.exp(self.regularizer_weight_pre_exp) + 0.25)
alt_mean_field_vd = utils.means_over_rows(z2_alt_rd, alt_counts)

ref_mean_field_on_ref_rd = torch.repeat_interleave(ref_mean_field_vd, dim=0, repeats=ref_counts)
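The change above replaces the directly learned regularizer_weight with a log-space parameter that is exponentiated (and given a 0.25 floor) at the point of use. Below is a minimal, illustrative PyTorch sketch of that positivity re-parameterization; the module name PositiveWeight is an assumption, not part of the diff.

import torch
from torch import nn

class PositiveWeight(nn.Module):
    """Illustrative only: a scalar weight constrained to stay strictly positive."""
    def __init__(self, initial: float = 0.1, floor: float = 0.25):
        super().__init__()
        # store the parameter in log space ("pre-exp"), as in the diff above
        self.pre_exp = nn.Parameter(torch.log(torch.tensor(initial)))
        self.floor = floor

    def forward(self) -> torch.Tensor:
        # exp(.) is positive for any real input, so the effective weight never reaches zero
        return torch.exp(self.pre_exp) + self.floor

Nothing in the old parameterization prevented the optimizer from driving the weight to zero or below, which lets the regularized count in means_over_rows_with_regularizer vanish for variants with no ref reads and produces the NaNs mentioned in the commit message.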
5 changes: 4 additions & 1 deletion permutect/data/base_datum.py
@@ -16,6 +16,9 @@
# base strings longer than this when encoding data
MAX_NUM_BASES_FOR_ENCODING = 13

MAX_FLOAT_16 = torch.finfo(torch.float16).max
MIN_FLOAT_16 = torch.finfo(torch.float16).min


def make_1d_sequence_tensor(sequence_string: str) -> np.ndarray:
"""
@@ -640,7 +643,7 @@ class ArtifactDatum:
def __init__(self, base_datum: BaseDatum, representation: Tensor):
# Note: if changing any of the data fields below, make sure to modify the size_in_bytes() method below accordingly!
assert representation.dim() == 1
self.representation = representation
self.representation = torch.clamp(representation, MIN_FLOAT_16, MAX_FLOAT_16)
self.other_stuff = ArtifactDatum1DStuff(base_datum.get_other_stuff_1d())
self.set_dtype(np.float16)

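For context, a hedged sketch (with made-up tensor values) of why clamping to the float16 range matters before the set_dtype(np.float16) call above: values outside that range overflow to +/-inf when downcast, and later arithmetic on infinities yields NaNs.

import numpy as np
import torch

MAX_FLOAT_16 = torch.finfo(torch.float16).max   # 65504.0
MIN_FLOAT_16 = torch.finfo(torch.float16).min   # -65504.0

representation = torch.tensor([1.0e6, -1.0e6, 3.0])   # made-up values, two outside float16 range

overflowed = representation.numpy().astype(np.float16)           # [inf, -inf, 3.] -- NaNs downstream
clamped = torch.clamp(representation, MIN_FLOAT_16, MAX_FLOAT_16)
safe = clamped.numpy().astype(np.float16)                        # all values finite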
4 changes: 0 additions & 4 deletions permutect/test/tools/test_train_model.py
@@ -53,10 +53,6 @@ def test_train_model():
assert artifact_log_priors is not None
assert artifact_spectra_state_dict is not None

saved = torch.load(saved_artifact_model, device=device)
assert constants.ARTIFACT_LOG_PRIORS_NAME in saved
assert constants.ARTIFACT_SPECTRA_STATE_DICT_NAME in saved

print(artifact_log_priors)
h = 99

1 change: 1 addition & 0 deletions permutect/tools/prune_dataset.py
@@ -115,6 +115,7 @@ def generated_pruned_data_for_fold(art_threshold: float, nonart_threshold: float
representation, _ = base_model.calculate_representations(base_batch)

artifact_batch = ArtifactBatch([ArtifactDatum(rs, rep) for rs, rep in zip(base_batch.original_list(), representation.detach())])

art_logits, _, _ = artifact_model.forward(artifact_batch)
art_probs = torch.sigmoid(art_logits.detach())
art_label_mask = (base_batch.labels > 0.5)
3 changes: 1 addition & 2 deletions permutect/utils.py
@@ -213,10 +213,9 @@ def means_over_rows(input_tensor: torch.Tensor, counts: torch.IntTensor, keepdim
# same but include a regularizer in case of zeros in the counts vector
# regularizer has the dimension of one row of the input tensor
def means_over_rows_with_regularizer(input_tensor: torch.Tensor, counts: torch.IntTensor, regularizer, regularizer_weight, keepdim: bool = False):
# TODO: left off right here
extra_dims = (1,) * (input_tensor.dim() - 1)

regularized_sums = sums_over_rows(input_tensor, counts) + regularizer[None, :]
regularized_sums = sums_over_rows(input_tensor, counts) + (regularizer_weight * regularizer)[None, :]
regularized_counts = counts + regularizer_weight
result = regularized_sums / regularized_counts.view(-1, *extra_dims)

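A self-contained sketch of the regularized mean above (the sums_over_rows stand-in here is an assumption; the real helper lives elsewhere in permutect/utils.py): adding regularizer_weight * regularizer to the per-variant sums and regularizer_weight to the counts acts like a fractional phantom read, so a variant with zero reads gets the regularizer vector as its mean instead of 0/0 = NaN.

import torch

def sums_over_rows(x_rd: torch.Tensor, counts: torch.IntTensor) -> torch.Tensor:
    # stand-in: counts[v] consecutive rows of x_rd belong to variant v
    return torch.stack([chunk.sum(dim=0) for chunk in torch.split(x_rd, counts.tolist())])

def means_over_rows_with_regularizer(x_rd, counts, regularizer, regularizer_weight):
    extra_dims = (1,) * (x_rd.dim() - 1)
    regularized_sums = sums_over_rows(x_rd, counts) + (regularizer_weight * regularizer)[None, :]
    regularized_counts = counts + regularizer_weight
    return regularized_sums / regularized_counts.view(-1, *extra_dims)

# variant 0 has two reads, variant 1 has none; variant 1's "mean" falls back to the regularizer
x = torch.tensor([[1.0, 3.0], [3.0, 5.0]])
counts = torch.tensor([2, 0], dtype=torch.int32)
print(means_over_rows_with_regularizer(x, counts, torch.tensor([0.1, 0.1]), torch.tensor(0.35)))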
300 changes: 300 additions & 0 deletions scripts/call_variants_with_uda.wdl
@@ -0,0 +1,300 @@
version 1.0

# run Mutect2 to get both training AND test datasets. The training dataset is preprocessed and combined with
# high-quality labeled data to make a UDA dataset, then used to train an artifact model. The test dataset is used
# for the posterior model and filtering.
# note that the artifact model can be trained before the Mutect2 workflow runs FilterMutectCalls
import "https://api.firecloud.org/ga4gh/v1/tools/davidben:mutect2/versions/18/plain-WDL/descriptor" as m2
import "https://api.firecloud.org/ga4gh/v1/tools/davidben:permutect-uda-dataset/versions/3/plain-WDL/descriptor" as uda
import "https://api.firecloud.org/ga4gh/v1/tools/davidben:permutect-train-artifact-model/versions/12/plain-WDL/descriptor" as training
import "https://api.firecloud.org/ga4gh/v1/tools/davidben:permutect-call-variants/versions/18/plain-WDL/descriptor" as calling

workflow CallVariantsWithUDA {
input {
# basic inputs for Mutect2
File? intervals
File? masked_intervals
File ref_fasta
File ref_fai
File ref_dict
File primary_bam
File primary_bai
File? control_bam
File? control_bai
File? gnomad
File? gnomad_idx
String? m2_extra_args
File? dragstr_model
Boolean make_bamout = false
Boolean compress_vcfs = false

# Mutect2 filtering
Boolean skip_m2_filtering
File? variants_for_contamination
File? variants_for_contamination_idx
File? realignment_index_bundle
String? realignment_extra_args
Boolean? run_orientation_bias_mixture_model_filter

# preprocessing arguments
Int chunk_size

# training arguments for both artifact model and posterior model
Boolean use_gpu = true
Int batch_size
Int inference_batch_size
Int num_workers
Int? gpu_count
Int? training_mem

# UDA training arguments
File base_model
File source_train_tar
String source_edit_type = "keep_everything"
String target_edit_type = "unlabel_everything"
Int num_epochs
Int num_calibration_epochs
Float dropout_p
Array[Int] aggregation_layers
Array[Int] calibration_layers
String? training_extra_args
Boolean learn_artifact_spectra
Float? genomic_span

# Permutect filtering / posterior model
File? test_dataset_truth_vcf # used for evaluation
File? test_dataset_truth_vcf_idx
Int? num_spectrum_iterations
Float? spectrum_learning_rate
String? permutect_filtering_extra_args
String bcftools_docker = "us.gcr.io/broad-dsde-methods/davidben/bcftools"
File? obscene_hack_leave_unset


# runtime
String gatk_docker
String permutect_docker
File? gatk_override
String basic_bash_docker = "ubuntu:16.04"
Int scatter_count
Int preemptible = 2
Int max_retries = 1
Int small_task_cpu = 2
Int small_task_mem = 4
Int small_task_disk = 100
Int boot_disk_size = 12
Int learn_read_orientation_mem = 8000
Int filter_alignment_artifacts_mem = 9000
String? gcs_project_for_requester_pays

# Use as a last resort to increase the disk given to every task in case of ill behaving data
Int emergency_extra_disk = 0
}

# note: we make both training and test datasets
# note: for speed we may skip filtering in order to begin UDA artifact model training immediately
# the only M2 filtering we may need is contamination, and that may be skipped
call m2.Mutect2 {
input:
intervals = intervals,
masked_intervals = masked_intervals,
ref_fasta = ref_fasta,
ref_fai = ref_fai,
ref_dict = ref_dict,
tumor_reads = primary_bam,
tumor_reads_index = primary_bai,
normal_reads = control_bam,
normal_reads_index = control_bai,
gnomad = gnomad,
gnomad_idx = gnomad_idx,
variants_for_contamination = variants_for_contamination,
variants_for_contamination_idx = variants_for_contamination_idx,
realignment_index_bundle = realignment_index_bundle,
realignment_extra_args = realignment_extra_args,
run_orientation_bias_mixture_model_filter = run_orientation_bias_mixture_model_filter,
m2_extra_args = m2_extra_args,
dragstr_model = dragstr_model,
make_bamout = make_bamout,
make_permutect_training_dataset = true,
make_permutect_test_dataset = true,
permutect_test_dataset_truth_vcf = test_dataset_truth_vcf,
permutect_test_dataset_truth_vcf_idx = test_dataset_truth_vcf_idx,
skip_filtering = skip_m2_filtering,
gatk_docker = gatk_docker,
gatk_override = gatk_override,
scatter_count = scatter_count,
preemptible = preemptible,
max_retries = max_retries,
small_task_cpu = small_task_cpu,
small_task_mem = small_task_mem,
small_task_disk = small_task_disk,
boot_disk_size = boot_disk_size,
gcs_project_for_requester_pays = gcs_project_for_requester_pays,
emergency_extra_disk = emergency_extra_disk
}

# preprocess the training data from Mutect2
call Preprocess {
input:
training_dataset = select_first([Mutect2.permutect_training_dataset]),
chunk_size = chunk_size,
permutect_docker = permutect_docker
}

# combine the source_tar and preprocessed training data into a UDA dataset
call uda.PermutectUDADataset {
input:
source_train_tar = source_train_tar,
target_train_tar = Preprocess.train_tar,
source_edit_type = source_edit_type,
target_edit_type = target_edit_type,
chunk_size = chunk_size,
permutect_docker = permutect_docker,
preemptible = 0,
max_retries = 0
}

# train an artifact model on the UDA dataset
call training.TrainPermutect {
input:
train_tar = PermutectUDADataset.uda_train_tar,
base_model = base_model,
num_epochs = num_epochs,
num_calibration_epochs = num_calibration_epochs,
batch_size = batch_size,
inference_batch_size = inference_batch_size,
num_workers = num_workers,
mem = training_mem,
use_gpu = use_gpu,
gpu_count = gpu_count,
dropout_p = dropout_p,
aggregation_layers = aggregation_layers,
calibration_layers = calibration_layers,
train_m3_extra_args = training_extra_args,
learn_artifact_spectra = learn_artifact_spectra,
genomic_span = genomic_span,
permutect_docker = permutect_docker,
preemptible = 0,
max_retries = 0
}

# we already ran M2 so we don't need the entire calling workflow, just the post-M2 parts of it
call calling.SplitMultiallelics {
input:
input_vcf = Mutect2.output_vcf,
input_vcf_idx = Mutect2.output_vcf_idx,
ref_fasta = ref_fasta,
ref_fai = ref_fai,
ref_dict = ref_dict,
bcftools_docker = bcftools_docker
}

call calling.IndexVCF as IndexAfterSplitting {
input:
unindexed_vcf = SplitMultiallelics.output_vcf,
gatk_docker = gatk_docker
}

if (use_gpu) {
call calling.PermutectFilteringGPU {
input:
mutect2_vcf = IndexAfterSplitting.vcf,
mutect2_vcf_idx = IndexAfterSplitting.vcf_index,
permutect_model = TrainPermutect.artifact_model,
test_dataset = select_first([Mutect2.permutect_test_dataset]),
contigs_table = Mutect2.permutect_contigs_table,
maf_segments = Mutect2.maf_segments,
mutect_stats = Mutect2.mutect_stats,
batch_size = batch_size,
num_workers = num_workers,
gpu_count = gpu_count,
num_spectrum_iterations = num_spectrum_iterations,
spectrum_learning_rate = spectrum_learning_rate,
chunk_size = chunk_size,
permutect_filtering_extra_args = permutect_filtering_extra_args,
permutect_docker = permutect_docker,
}
}

if (!use_gpu) {
call calling.PermutectFilteringCPU {
input:
mutect2_vcf = IndexAfterSplitting.vcf,
mutect2_vcf_idx = IndexAfterSplitting.vcf_index,
permutect_model = TrainPermutect.artifact_model,
test_dataset = select_first([Mutect2.permutect_test_dataset]),
contigs_table = Mutect2.permutect_contigs_table,
maf_segments = Mutect2.maf_segments,
mutect_stats = Mutect2.mutect_stats,
batch_size = batch_size,
num_workers = num_workers,
num_spectrum_iterations = num_spectrum_iterations,
spectrum_learning_rate = spectrum_learning_rate,
chunk_size = chunk_size,
permutect_filtering_extra_args = permutect_filtering_extra_args,
permutect_docker = permutect_docker,
}
}

call calling.IndexVCF as IndexAfterFiltering {
input:
unindexed_vcf = select_first([PermutectFilteringGPU.output_vcf, PermutectFilteringCPU.output_vcf]),
gatk_docker = gatk_docker
}

output {
File? bamout = Mutect2.bamout
File? bamout_index = Mutect2.bamout_index
File mutect_stats = Mutect2.mutect_stats
File permutect_contigs_table = Mutect2.permutect_contigs_table
File permutect_read_groups_table = Mutect2.permutect_read_groups_table
File train_tar = Preprocess.train_tar
File training_tensorboard_tar = TrainPermutect.training_tensorboard_tar
File output_vcf = IndexAfterFiltering.vcf
File output_vcf_idx = IndexAfterFiltering.vcf_index
File calling_tensorboard_tar = select_first([PermutectFilteringGPU.tensorboard_report, PermutectFilteringCPU.tensorboard_report])
}

}

task Preprocess {
input {
File training_dataset
Int chunk_size
Int? source_label

String permutect_docker
Int? preemptible
Int? max_retries
Int? disk_space
Int? cpu
Int? mem
Boolean use_ssd = true
}

# Mem is in units of GB but our command and memory runtime values are in MB
Int machine_mem = if defined(mem) then mem * 1000 else 16000
Int command_mem = machine_mem - 500

command <<<
set -e

preprocess_dataset --training_datasets ~{training_dataset} --chunk_size ~{chunk_size} ~{"--sources " + source_label} --output train.tar
>>>

runtime {
docker: permutect_docker
bootDiskSizeGb: 12
memory: machine_mem + " MB"
disks: "local-disk " + select_first([disk_space, 100]) + if use_ssd then " SSD" else " HDD"
preemptible: select_first([preemptible, 2])
maxRetries: select_first([max_retries, 0])
cpu: select_first([cpu, 1])
}

output {
File train_tar = "train.tar"
}
}
