Skip to content

Commit

Permalink
Option to broaden germline allele fraction range for linseq evaluation
Browse files Browse the repository at this point in the history
  • Loading branch information
davidbenjamin authored Nov 14, 2024
1 parent 6a8919a commit b7d66d0
Show file tree
Hide file tree
Showing 5 changed files with 24 additions and 14 deletions.
17 changes: 10 additions & 7 deletions permutect/architecture/posterior_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,9 +21,9 @@

# TODO: write unit test asserting that this comes out to zero when counts are zero
# given germline, the probability of these particular reads being alt
def germline_log_likelihood(afs, mafs, alt_counts, depths):
def germline_log_likelihood(afs, mafs, alt_counts, depths, het_beta=None):
HOM_ALPHA, HOM_BETA = torch.tensor([98.0], device=depths.device), torch.tensor([2.0], device=depths.device)

HET_ALPHA, HET_BETA = torch.tensor([het_beta], device=depths.device), torch.tensor([het_beta], device=depths.device)
het_probs = 2 * afs * (1 - afs)
hom_probs = afs * afs
het_proportion = het_probs / (het_probs + hom_probs)
Expand All @@ -37,8 +37,10 @@ def germline_log_likelihood(afs, mafs, alt_counts, depths):

combinatorial_term = torch.lgamma(depths + 1) - torch.lgamma(alt_counts + 1) - torch.lgamma(ref_counts + 1)
# the following should both be 1D tensors of length batch size
alt_minor_ll = combinatorial_term + log_half_het_prop + alt_counts * log_mafs + ref_counts * log_1m_mafs
alt_major_ll = combinatorial_term + log_half_het_prop + ref_counts * log_mafs + alt_counts * log_1m_mafs
alt_minor_binomial = combinatorial_term + alt_counts * log_mafs + ref_counts * log_1m_mafs
alt_major_binomial = combinatorial_term + ref_counts * log_mafs + alt_counts * log_1m_mafs
alt_minor_ll = log_half_het_prop + (alt_minor_binomial if het_beta is None else utils.beta_binomial(depths, alt_counts, HET_ALPHA, HET_BETA))
alt_major_ll = log_half_het_prop + (alt_major_binomial if het_beta is None else utils.beta_binomial(depths, alt_counts, HET_ALPHA, HET_BETA))
hom_ll = torch.log(hom_proportion) + utils.beta_binomial(depths, alt_counts, HOM_ALPHA, HOM_BETA)

return torch.logsumexp(torch.vstack((alt_minor_ll, alt_major_ll, hom_ll)), dim=0)
Expand Down Expand Up @@ -72,13 +74,14 @@ class PosteriorModel(torch.nn.Module):
"""
"""
def __init__(self, variant_log_prior: float, artifact_log_prior: float, num_base_features: int, no_germline_mode: bool = False, device=utils.gpu_if_available()):
def __init__(self, variant_log_prior: float, artifact_log_prior: float, num_base_features: int, no_germline_mode: bool = False, device=utils.gpu_if_available(), het_beta: float = None):
super(PosteriorModel, self).__init__()

self._device = device
self._dtype = DEFAULT_GPU_FLOAT if device != torch.device("cpu") else DEFAULT_CPU_FLOAT
self.no_germline_mode = no_germline_mode
self.num_base_features = num_base_features
self.het_beta = het_beta

# TODO introduce parameters class so that num_components is not hard-coded
self.somatic_spectrum = SomaticSpectrum(num_components=5)
Expand Down Expand Up @@ -180,11 +183,11 @@ def log_posterior_and_ingredients(self, batch: PosteriorBatch) -> torch.Tensor:
torch.logical_not(no_alt_in_normal_mask) * self.normal_artifact_spectra.forward(types, normal_depths, normal_alt_counts)

afs = batch.get_allele_frequencies()
spectra_log_likelihoods[:, Call.GERMLINE] = germline_log_likelihood(afs, batch.get_mafs(), alt_counts, depths) - flat_prior_spectra_log_likelihoods
spectra_log_likelihoods[:, Call.GERMLINE] = germline_log_likelihood(afs, batch.get_mafs(), alt_counts, depths, self.het_beta) - flat_prior_spectra_log_likelihoods

# it is correct not to subtract the flat prior likelihood from the normal term because this is an absolute likelihood, not
# relative to seq error as the M2 TLOD is defined
normal_log_likelihoods[:, Call.GERMLINE] = germline_log_likelihood(afs, batch.get_normal_mafs(), normal_alt_counts, normal_depths)
normal_log_likelihoods[:, Call.GERMLINE] = germline_log_likelihood(afs, batch.get_normal_mafs(), normal_alt_counts, normal_depths, self.het_beta)

log_posteriors = log_priors + spectra_log_likelihoods + normal_log_likelihoods
log_posteriors[:, Call.ARTIFACT] += batch.get_artifact_logits()
Expand Down
1 change: 1 addition & 0 deletions permutect/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,7 @@
NORMAL_MAF_SEGMENTS_NAME = 'normal_maf_segments'
GERMLINE_MODE_NAME = 'germline_mode'
NO_GERMLINE_MODE_NAME = 'no_germline_mode'
HET_BETA_NAME = 'het_beta'

BASE_MODEL_NAME = 'base_model'
M3_MODEL_NAME = 'permutect_model'
Expand Down
11 changes: 7 additions & 4 deletions permutect/metrics/evaluation_metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -315,19 +315,22 @@ def make_plots(self, summary_writer: SummaryWriter, given_thresholds=None, sens_
metric.plot_roc_curves_by_count(var_type, roc_by_cnt_axes[row_idx, var_type], given_threshold, sens_prec)
# done collecting stats for all loaders and filling in subplots

nonart_label = "sensitivity" if sens_prec else "non-artifact accuracy"
art_label = "precision" if sens_prec else "artifact accuracy"

variation_types = [var_type.name for var_type in Variation]
row_names = [epoch_type.name for epoch_type in self.metrics.keys()]
plotting.tidy_subplots(acc_vs_cnt_fig, acc_vs_cnt_axes, x_label="alt count", y_label="accuracy", row_labels=row_names, column_labels=variation_types)
plotting.tidy_subplots(roc_fig, roc_axes, x_label="non-artifact accuracy", y_label="artifact accuracy", row_labels=row_names, column_labels=variation_types)
plotting.tidy_subplots(roc_by_cnt_fig, roc_by_cnt_axes, x_label="non-artifact accuracy", y_label="artifact accuracy", row_labels=row_names, column_labels=variation_types)
plotting.tidy_subplots(roc_fig, roc_axes, x_label=nonart_label, y_label=art_label, row_labels=row_names, column_labels=variation_types)
plotting.tidy_subplots(roc_by_cnt_fig, roc_by_cnt_axes, x_label=nonart_label, y_label=art_label, row_labels=row_names, column_labels=variation_types)
plotting.tidy_subplots(cal_fig, cal_axes, x_label="predicted logit", y_label="accuracy", row_labels=row_names, column_labels=variation_types)
plotting.tidy_subplots(cal_fig_all_counts, cal_axes_all_counts, x_label="predicted logit", y_label="accuracy", row_labels=row_names, column_labels=variation_types)

summary_writer.add_figure("accuracy by alt count", acc_vs_cnt_fig, global_step=epoch)
summary_writer.add_figure(" accuracy by logit output by count", cal_fig, global_step=epoch)
summary_writer.add_figure(" accuracy by logit output", cal_fig_all_counts, global_step=epoch)
summary_writer.add_figure(" variant accuracy vs artifact accuracy curve", roc_fig, global_step=epoch)
summary_writer.add_figure(" variant accuracy vs artifact accuracy curves by alt count", roc_by_cnt_fig, global_step=epoch)
summary_writer.add_figure("sensitivity vs precision" if sens_prec else "variant accuracy vs artifact accuracy", roc_fig, global_step=epoch)
summary_writer.add_figure("sensitivity vs precision by alt count" if sens_prec else "variant accuracy vs artifact accuracy by alt count", roc_by_cnt_fig, global_step=epoch)


def sample_indices_for_tensorboard(indices: List[int]):
Expand Down
2 changes: 1 addition & 1 deletion permutect/metrics/plotting.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,7 @@ def plot_accuracy_vs_accuracy_roc_on_axis(lists_of_predictions_and_labels, curve
else:
small_dots.append((art_acc, non_art_acc, 'go')) # green circle

simple_plot_on_axis(axis, x_y_lab_tuples, "artifact accuracy", "non-artifact accuracy")
simple_plot_on_axis(axis, x_y_lab_tuples, "precision" if sens_prec else "artifact accuracy", "sensitivity" if sens_prec else "non-artifact accuracy")
for x, y, spec in small_dots:
axis.plot(x, y, spec, markersize=2,label="") # point
for x, y, spec in big_dots:
Expand Down
7 changes: 5 additions & 2 deletions permutect/tools/filter_variants.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,6 +102,8 @@ def parse_arguments():
parser.add_argument('--' + constants.GERMLINE_MODE_NAME, action='store_true',
help='flag for genotyping both somatic and somatic variants distinctly but considering both '
'as non-errors (true positives), which affects the posterior threshold set by optimal F1 score')
parser.add_argument('--' + constants.HET_BETA_NAME, type=float, required=False,
help='beta shape parameter for germline spectrum beta binomial if we want to override binomial')

parser.add_argument('--' + constants.NO_GERMLINE_MODE_NAME, action='store_true',
help='flag for not genotyping germline events so that the only possibilities considered are '
Expand Down Expand Up @@ -146,13 +148,14 @@ def main_without_parsing(args):
genomic_span=getattr(args, constants.GENOMIC_SPAN_NAME),
germline_mode=getattr(args, constants.GERMLINE_MODE_NAME),
no_germline_mode=getattr(args, constants.NO_GERMLINE_MODE_NAME),
het_beta=getattr(args, constants.HET_BETA_NAME),
segmentation=get_segmentation(getattr(args, constants.MAF_SEGMENTS_NAME)),
normal_segmentation=get_segmentation(getattr(args, constants.NORMAL_MAF_SEGMENTS_NAME)))


def make_filtered_vcf(saved_artifact_model_path, initial_log_variant_prior: float, initial_log_artifact_prior: float,
test_dataset_file, contigs_table, input_vcf, output_vcf, batch_size: int, num_workers: int, chunk_size: int, num_spectrum_iterations: int,
spectrum_learning_rate: float, tensorboard_dir, genomic_span: int, germline_mode: bool = False, no_germline_mode: bool = False,
spectrum_learning_rate: float, tensorboard_dir, genomic_span: int, germline_mode: bool = False, no_germline_mode: bool = False, het_beta: float = None,
segmentation=defaultdict(IntervalTree), normal_segmentation=defaultdict(IntervalTree)):
print("Loading artifact model and test dataset")
contig_index_to_name_map = {}
Expand All @@ -165,7 +168,7 @@ def make_filtered_vcf(saved_artifact_model_path, initial_log_variant_prior: floa
base_model, artifact_model, artifact_log_priors, artifact_spectra_state_dict = \
load_base_model_and_artifact_model(saved_artifact_model_path, device=device)

posterior_model = PosteriorModel(initial_log_variant_prior, initial_log_artifact_prior, no_germline_mode=no_germline_mode, num_base_features=artifact_model.num_base_features)
posterior_model = PosteriorModel(initial_log_variant_prior, initial_log_artifact_prior, no_germline_mode=no_germline_mode, num_base_features=artifact_model.num_base_features, het_beta=het_beta)
posterior_data_loader = make_posterior_data_loader(test_dataset_file, input_vcf, contig_index_to_name_map,
base_model, artifact_model, batch_size, num_workers=num_workers, chunk_size=chunk_size, segmentation=segmentation, normal_segmentation=normal_segmentation)

Expand Down

0 comments on commit b7d66d0

Please sign in to comment.