Tensorboard histograms of artifact logits (#169)
davidbenjamin authored Dec 11, 2024
1 parent 00eda79 commit b757123
Showing 4 changed files with 70 additions and 20 deletions.
15 changes: 9 additions & 6 deletions permutect/architecture/artifact_model.py
@@ -377,6 +377,8 @@ def collect_evaluation_data(self, dataset: ArtifactDataset, train_loader, valid_
         assert epoch_type == Epoch.TRAIN or epoch_type == Epoch.VALID  # not doing TEST here
         loader = train_loader if epoch_type == Epoch.TRAIN else valid_loader
         pbar = tqdm(enumerate(loader), mininterval=60)
+
+        batch_cpu: ArtifactBatch
         for n, batch_cpu in pbar:
             batch = batch_cpu.copy_to(self._device, self._dtype, non_blocking=self._device.type == 'cuda')

@@ -393,13 +395,14 @@ def collect_evaluation_data(self, dataset: ArtifactDataset, train_loader, valid_
             labels = batch_cpu.get_training_labels()
             correct = ((pred > 0) == (labels > 0.5)).tolist()

-            for variant_type, predicted_logit, label, is_labeled, correct_call, alt_count, variant, weight in zip(
-                    batch_cpu.get_variant_types().tolist(), pred.tolist(), labels.tolist(), batch_cpu.get_is_labeled_mask().tolist(), correct,
+            for variant_type, predicted_logit, source, int_label, correct_call, alt_count, variant, weight in zip(
+                    batch_cpu.get_variant_types().tolist(), pred.tolist(), batch.get_sources().tolist(), batch_cpu.get_labels().tolist(), correct,
                     batch_cpu.get_alt_counts().tolist(), batch_cpu.get_variants(), weights.tolist()):
-                if is_labeled < 0.5:  # we only evaluate labeled data
-                    continue
-                evaluation_metrics.record_call(epoch_type, variant_type, predicted_logit, label, correct_call, alt_count, weight)
-                if report_worst and not correct_call:
+                label = Label(int_label)
+                evaluation_metrics.record_call(epoch_type, variant_type, predicted_logit, label, correct_call,
+                                               alt_count, weight, source=source)
+
+                if (label != Label.UNLABELED) and report_worst and not correct_call:
                     rounded_count = round_up_to_nearest_three(alt_count)
                     label_name = Label.ARTIFACT.name if label > 0.5 else Label.VARIANT.name
                     confidence = abs(predicted_logit)
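For readers following the artifact_model.py change: the loop now pulls raw integer labels and source indices from the batch and converts each label back into the Label enum, so unlabeled calls can still be recorded for the logit histograms while the worst-calls report stays restricted to labeled data. A minimal sketch of that conversion, assuming only that Label is an IntEnum with VARIANT, ARTIFACT, and UNLABELED members (the numeric values below are illustrative, not permutect's):

from enum import IntEnum

class Label(IntEnum):          # illustrative values only; permutect defines its own ordering
    VARIANT = 0
    ARTIFACT = 1
    UNLABELED = 2

for int_label in [0, 1, 2]:    # e.g. batch_cpu.get_labels().tolist()
    label = Label(int_label)   # IntEnum converts the raw integer directly
    # every call feeds record_call (and hence the histograms); only labeled
    # calls are eligible for accuracy metrics and the "worst calls" report
    if label != Label.UNLABELED:
        print(f"{label.name}: counted in accuracy metrics and worst-call report")
    else:
        print("UNLABELED: histograms only")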
72 changes: 59 additions & 13 deletions permutect/metrics/evaluation_metrics.py
@@ -5,11 +5,12 @@
 import numpy as np
 import torch
 from matplotlib import pyplot as plt
+import seaborn as sns
 from torch.utils.tensorboard import SummaryWriter

 from permutect.data.base_datum import BaseBatch, ArtifactBatch
 from permutect.metrics import plotting
-from permutect.utils import Variation, Call, Epoch, StreamingAverage
+from permutect.utils import Variation, Call, Epoch, StreamingAverage, Label

 MAX_COUNT = 18  # counts above this will be truncated
 MAX_LOGIT = 15
@@ -154,7 +155,7 @@ def __init__(self):
         self.acc_vs_logit_all_counts = {
             var_type: [StreamingAverage() for _ in range(2 * MAX_LOGIT + 1)] for var_type in Variation}

-        # indexed by variant type, then call type (artifact vs variant), then count bin
+        # indexed by variant type, then Label (artifact vs variant), then count bin
         self.acc_vs_cnt = {var_type: defaultdict(lambda: [StreamingAverage() for _ in range(NUM_COUNT_BINS)]) for
                            var_type in Variation}

@@ -164,18 +165,26 @@ def __init__(self):
         # variant type, count -> (predicted logit, actual label)
         self.roc_data_by_cnt = {var_type: [[] for _ in range(NUM_COUNT_BINS)] for var_type in Variation}

-    # Variant is an IntEnum, so variant_type can also be integer
-    # label is 1 for artifact / error; 0 for non-artifact / true variant
+        # list of logits for histograms, by variant type, count, label, source
+        self.logit_histogram_data_vcls = {var_type: [defaultdict(lambda: defaultdict(list)) for _ in range(NUM_COUNT_BINS)] for var_type in Variation}
+
+        self.all_sources = set()
+
     # correct_call is boolean -- was the prediction correct?
     # the predicted logit is the logit corresponding to the predicted probability that call in question is an artifact / error
-    def record_call(self, variant_type: Variation, predicted_logit: float, label: float, correct_call, alt_count: int, weight: float = 1.0):
+    def record_call(self, variant_type: Variation, predicted_logit: float, label: Label, correct_call, alt_count: int, weight: float = 1.0, source: int = 0):
         count_bin_index = multiple_of_three_bin_index(min(MAX_COUNT, alt_count))
-        self.acc_vs_cnt[variant_type][Call.SOMATIC if label < 0.5 else Call.ARTIFACT][count_bin_index].record(correct_call, weight)
-        self.acc_vs_logit[variant_type][count_bin_index][logit_to_bin(predicted_logit)].record(correct_call, weight)
-        self.acc_vs_logit_all_counts[variant_type][logit_to_bin(predicted_logit)].record(correct_call, weight)
+        self.all_sources.add(source)
+        self.logit_histogram_data_vcls[variant_type][count_bin_index][label][source].append(predicted_logit)
+
+        if label != Label.UNLABELED:
+            self.acc_vs_cnt[variant_type][label][count_bin_index].record(correct_call, weight)
+            self.acc_vs_logit[variant_type][count_bin_index][logit_to_bin(predicted_logit)].record(correct_call, weight)
+            self.acc_vs_logit_all_counts[variant_type][logit_to_bin(predicted_logit)].record(correct_call, weight)

-        self.roc_data[variant_type].append((predicted_logit, label))
-        self.roc_data_by_cnt[variant_type][count_bin_index].append((predicted_logit, label))
+        float_label = (1.0 if label == Label.ARTIFACT else 0.0)
+        self.roc_data[variant_type].append((predicted_logit, float_label))
+        self.roc_data_by_cnt[variant_type][count_bin_index].append((predicted_logit, float_label))

     # return a list of tuples. This outer list is over the two labels, Call.SOMATIC and Call.ARTIFACT. Each tuple consists of
     # (list of alt counts (x axis), list of accuracies (y axis), the label)
@@ -199,6 +208,36 @@ def make_data_for_calibration_plot(self, var_type: Variation):
                 str(multiple_of_three_bin_index_to_count(count_idx))) for count_idx in
             range(NUM_COUNT_BINS)]

+    def make_logit_histograms(self):
+        fig, axes = plt.subplots(len(Variation), NUM_COUNT_BINS, sharex='all', sharey='all', squeeze=False,
+                                 figsize=(4 * NUM_COUNT_BINS, 4 * len(Variation)), dpi=200)
+
+        multiple_sources = len(self.all_sources) > 1
+        source_zero_line_colors = {Label.VARIANT: 'red', Label.ARTIFACT: 'magenta', Label.UNLABELED: 'limegreen'}
+        other_source_line_colors = {Label.VARIANT: 'darkred', Label.ARTIFACT: 'darkmagenta', Label.UNLABELED: 'darkgreen'}
+        for row, variation_type in enumerate(Variation):
+            for count_bin in range(NUM_COUNT_BINS):  # this is also the column index
+                plot_data = self.logit_histogram_data_vcls[variation_type][count_bin]
+                different_labels = plot_data.keys()
+
+                # overlapping density plots for all source / label combinations
+                # source 0 is filled; others are not
+                ax = axes[row, count_bin]
+                for source in self.all_sources:
+                    for label in different_labels:
+                        line_label = f"{label.name} ({source})" if multiple_sources else label.name
+                        color = source_zero_line_colors[label] if source == 0 else other_source_line_colors[label]
+
+                        sns.kdeplot(data=np.clip(np.array(plot_data[label][source]), -10, 10), fill=(source == 0),
+                                    color=color, ax=ax, label=line_label, clip=(-10, 10))
+                ax.set_ylim(0, 0.5)  # don't go all the way to 1 because
+                ax.legend()
+
+        column_names = [str(multiple_of_three_bin_index_to_count(count_idx)) for count_idx in range(NUM_COUNT_BINS)]
+        row_names = [var_type.name for var_type in Variation]
+        plotting.tidy_subplots(fig, axes, x_label="predicted logit", y_label="frequency", row_labels=row_names, column_labels=column_names)
+        return fig, axes
+
     # now it's (list of logits, list of accuracies)
     def make_data_for_calibration_plot_all_counts(self, var_type: Variation):
         non_empty_logit_bins = [idx for idx in range(2 * MAX_LOGIT + 1) if not self.acc_vs_logit_all_counts[var_type][idx].is_empty()]
@@ -246,11 +285,10 @@ def __init__(self):
         self.mistakes = []

     # Variant is an IntEnum, so variant_type can also be integer
-    # label is 1 for artifact / error; 0 for non-artifact / true variant
     # correct_call is boolean -- was the prediction correct?
     # the predicted logit is the logit corresponding to the predicted probability that call in question is an artifact / error
-    def record_call(self, epoch_type: Epoch, variant_type: Variation, predicted_logit: float, label: float, correct_call, alt_count: int, weight: float = 1.0):
-        self.metrics[epoch_type].record_call(variant_type, predicted_logit, label, correct_call, alt_count, weight)
+    def record_call(self, epoch_type: Epoch, variant_type: Variation, predicted_logit: float, label: Label, correct_call, alt_count: int, weight: float = 1.0, source: int = 0):
+        self.metrics[epoch_type].record_call(variant_type, predicted_logit, label, correct_call, alt_count, weight, source=source)

     # track bad calls when filtering is given an optional evaluation truth VCF
     def record_mistake(self, posterior_result: PosteriorResult, call: Call):
@@ -343,6 +381,14 @@ def make_plots(self, summary_writer: SummaryWriter, given_thresholds=None, sens_
         summary_writer.add_figure("sensitivity vs precision" if sens_prec else "variant accuracy vs artifact accuracy", roc_fig, global_step=epoch)
         summary_writer.add_figure("sensitivity vs precision by alt count" if sens_prec else "variant accuracy vs artifact accuracy by alt count", roc_by_cnt_fig, global_step=epoch)

+        # one more plot, different from the rest. Here each epoch is its own figure, and within each figure the grid of subplots
+        # is by variant type and count. Within each subplot we have overlapping density plots of artifact logit predictions for all
+        # combinations of Label and source
+        for key in keys:
+            metric = self.metrics[key]
+            hist_fig, hist_ax = metric.make_logit_histograms()
+            summary_writer.add_figure(f"logit histograms ({Epoch(key).name})", hist_fig, global_step=epoch)
+

 def sample_indices_for_tensorboard(indices: List[int]):
     indices_np = np.array(indices)
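To see the plotting technique that make_logit_histograms introduces in isolation, here is a self-contained sketch (synthetic data, not permutect code): overlapping seaborn KDE curves of predicted logits in each subplot, clipped to [-10, 10], with one curve per label and the fill reserved for the "source 0" style, as in the diff.

import numpy as np
import seaborn as sns
from matplotlib import pyplot as plt

rng = np.random.default_rng(0)
# synthetic logits standing in for logit_histogram_data_vcls[var_type][count_bin][label][source]
logits_by_label = {
    "VARIANT": rng.normal(-4.0, 2.0, 500),   # true variants should score negative artifact logits
    "ARTIFACT": rng.normal(4.0, 2.0, 500),   # artifacts should score positive artifact logits
}

fig, axes = plt.subplots(1, 2, sharex='all', sharey='all', squeeze=False, figsize=(8, 4), dpi=100)
for col in range(2):
    ax = axes[0, col]
    for label_name, logits in logits_by_label.items():
        # fill the first column only, mimicking the filled-vs-unfilled source distinction
        sns.kdeplot(data=np.clip(logits, -10, 10), fill=(col == 0),
                    ax=ax, label=label_name, clip=(-10, 10))
    ax.set_ylim(0, 0.5)
    ax.legend()
fig.tight_layout()
plt.show()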
2 changes: 1 addition & 1 deletion permutect/tools/filter_variants.py
@@ -353,7 +353,7 @@ def apply_filtering_to_vcf(input_vcf, output_vcf, contig_index_to_name_map, erro
             # TODO: we stretch the definitions so that "Label.ARTIFACT" simply means "something we shouldn't call", including
             # TODO: artifact or germline (in the somatic calling case), and "Label.VARIANT" means "something we should call"
             is_correct = (called_as_error and label == Label.ARTIFACT) or (not called_as_error and label == Label.VARIANT)
-            evaluation_metrics.record_call(Epoch.TEST, variant_type, error_logit, float_label, is_correct, posterior_result.alt_count)
+            evaluation_metrics.record_call(Epoch.TEST, variant_type, error_logit, label, is_correct, posterior_result.alt_count)

             # TODO: double-check the logic here
             if is_correct:
1 change: 1 addition & 0 deletions requirements.txt
@@ -4,6 +4,7 @@ numpy >= 1.26.2
 tqdm~=4.66.1
 setuptools>=57.0.0
 matplotlib~=3.8.2
+seaborn>=0.13.2
 pysam~=0.22.0
 pandas~=2.1.3
 cyvcf2~=0.30.15
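Finally, a small sketch of how the new figures reach TensorBoard, mirroring the make_plots addition above: each matplotlib figure is written with SummaryWriter.add_figure under a per-epoch tag. The log directory and histogram data below are illustrative, not taken from permutect.

from matplotlib import pyplot as plt
from torch.utils.tensorboard import SummaryWriter

writer = SummaryWriter("tb_example_run")              # hypothetical log directory
fig, ax = plt.subplots()
ax.hist([-3.1, -0.4, 0.2, 1.7, 4.5], bins=5)          # stand-in for one logit histogram panel
writer.add_figure("logit histograms (TRAIN)", fig, global_step=0)
writer.close()
# inspect with: tensorboard --logdir tb_example_run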
