Improvements and refactoring of loss weighting #148

Merged: 30 commits, Oct 9, 2024

Commits (30)
a559383
work on getting a source integer in the data
davidbenjamin Oct 1, 2024
2ef5918
sources now added in preprocess dataset
davidbenjamin Oct 1, 2024
0ec8932
preprocessing WDL takes sources optionally
davidbenjamin Oct 1, 2024
cec6caa
record total losses merged into record by type and count
davidbenjamin Oct 1, 2024
ae4508e
evaluation_metrics.record_loss handles mixed labeled and unlabeled ba…
davidbenjamin Oct 2, 2024
cfa0fc1
artifact model learning ready for mixed labeled / unlabeled batches
davidbenjamin Oct 2, 2024
7909f26
did the same for BaseModel learning and also replaced the posterior p…
davidbenjamin Oct 2, 2024
89c6175
Removed references to batch.is_labeled in base batch embedding metrics
davidbenjamin Oct 2, 2024
5820e92
a little adjustment in filter variants
davidbenjamin Oct 2, 2024
32624bd
labeled-only data loader for pruning; only one use of is_labeled() left
davidbenjamin Oct 2, 2024
af813be
got the last is_labeled and maybe also fixed bug where pruning delete…
davidbenjamin Oct 2, 2024
4f71c0c
got rid of last references to base batch is_labeled
davidbenjamin Oct 2, 2024
da00bb4
base batch sampler mixes unlabeled and labeled
davidbenjamin Oct 2, 2024
94bea0f
artifact batch sampler also mixes unlabeled and labeled
davidbenjamin Oct 2, 2024
f8430c6
whoops
davidbenjamin Oct 2, 2024
775a9ca
whoops
davidbenjamin Oct 2, 2024
21380ad
BaseDataset tracks counts by source
davidbenjamin Oct 4, 2024
334d6d0
same for ARtifactDataset
davidbenjamin Oct 4, 2024
c946548
datasets record unlabeled totals by variant type and count
davidbenjamin Oct 4, 2024
b141ab3
some work for computing weights, including unlabeled being smarter
davidbenjamin Oct 7, 2024
a27386c
base dataset total counts stored in a single total dict
davidbenjamin Oct 7, 2024
981f2c9
calculate weights in base dataset
davidbenjamin Oct 7, 2024
a84fdc2
extract method for calculating batch weights
davidbenjamin Oct 7, 2024
de42606
using the extracted method
davidbenjamin Oct 7, 2024
8654f91
perhaps done with batch weights
davidbenjamin Oct 7, 2024
0a94f13
whoops
davidbenjamin Oct 8, 2024
18d3fc8
whoops
davidbenjamin Oct 8, 2024
876e3cb
debug print
davidbenjamin Oct 8, 2024
fd9df51
Whoops
davidbenjamin Oct 8, 2024
9bf9388
Experiment: always by count
davidbenjamin Oct 8, 2024
117 changes: 51 additions & 66 deletions permutect/architecture/artifact_model.py
@@ -15,6 +15,7 @@
from itertools import chain
from matplotlib import pyplot as plt

+ from permutect.architecture.base_model import calculate_batch_weights
from permutect.architecture.mlp import MLP
from permutect.architecture.monotonic import MonoDense
from permutect.data.base_datum import ArtifactBatch, DEFAULT_GPU_FLOAT, DEFAULT_CPU_FLOAT
@@ -177,17 +178,10 @@ def learn(self, dataset: ArtifactDataset, training_params: TrainingParameters, s
train_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
train_optimizer, factor=0.2, patience=3, threshold=0.001, min_lr=(training_params.learning_rate / 100),verbose=True)

- artifact_to_non_artifact_ratios = torch.from_numpy(dataset.artifact_to_non_artifact_ratios()).to(self._device)
-
- # balance training by weighting the loss function
- # if total unlabeled is less than total labeled, we do not compensate, since labeled data are more informative
- total_labeled, total_unlabeled = dataset.total_labeled_and_unlabeled()
- labeled_to_unlabeled_ratio = 1 if total_unlabeled < total_labeled else total_labeled / total_unlabeled
-
- print(f"Training data contains {total_labeled:.0f} labeled examples and {total_unlabeled:.0f} unlabeled examples")
- for variation_type in utils.Variation:
- idx = variation_type.value
- print(f"For variation type {variation_type.name} there are {int(dataset.artifact_totals[idx].item())} labeled artifact examples and {int(dataset.non_artifact_totals[idx].item())} labeled non-artifact examples")
+ for idx, variation_type in enumerate(utils.Variation):
+ print(f"For variation type {variation_type.name}, there are {int(dataset.totals[-1][Label.ARTIFACT][idx].item())} \
+ artifacts, {int(dataset.totals[-1][Label.VARIANT][idx].item())} \
+ non-artifacts, and {int(dataset.totals[-1][Label.UNLABELED][idx].item())} unlabeled data.")

validation_fold_to_use = (dataset.num_folds - 1) if validation_fold is None else validation_fold
train_loader = dataset.make_data_loader(dataset.all_but_one_fold(validation_fold_to_use), training_params.batch_size, self._device.type == 'cuda', training_params.num_workers)
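
Editor's note: the totals structure printed above is defined in the dataset classes rather than in this file's diff. The usage dataset.totals[-1][Label.ARTIFACT][idx] suggests counts keyed by source, then label, then variant-type index, with the key -1 presumably aggregating over all sources. A minimal stand-alone sketch of that bookkeeping, using a stand-in Label enum and an assumed number of variation types:

from collections import defaultdict
from enum import IntEnum
import torch

class Label(IntEnum):  # stand-in for permutect's Label enum (assumption)
    ARTIFACT = 0
    VARIANT = 1
    UNLABELED = 2

NUM_VARIATION_TYPES = 5  # assumption: one slot per utils.Variation member

# totals[source][label] is a tensor of counts indexed by variant type
totals = defaultdict(lambda: {label: torch.zeros(NUM_VARIATION_TYPES) for label in Label})

def record(source: int, label: Label, variant_type_idx: int) -> None:
    # tally under the datum's own source and under the assumed all-sources key -1
    for key in (source, -1):
        totals[key][label][variant_type_idx] += 1

record(source=0, label=Label.ARTIFACT, variant_type_idx=2)
print(int(totals[-1][Label.ARTIFACT][2].item()))  # 1
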
@@ -208,58 +202,43 @@
utils.unfreeze(self.calibration_parameters()) # unfreeze calibration but everything else stays frozen

loss_metrics = LossMetrics(self._device) # based on calibrated logits
- precalibrated_loss_metrics = LossMetrics(self._device) # based on precalibrated logits
+ uncalibrated_loss_metrics = LossMetrics(self._device) # based on uncalibrated logits

loader = train_loader if epoch_type == utils.Epoch.TRAIN else valid_loader
pbar = tqdm(enumerate(loader), mininterval=60)
for n, batch in pbar:
logits, precalibrated_logits = self.forward(batch)
- types_one_hot = batch.variant_type_one_hot()
-
- # if it's a calibration epoch, get count- and variant type-dependent artifact ratio; otherwise it depends only on type
- if is_calibration_epoch:
- ratios = [dataset.artifact_to_non_artifact_ratios_by_count(alt_count)[var_type] for alt_count, var_type in zip(batch.alt_counts, batch.variant_types())]
- non_artifact_weights = torch.FloatTensor(ratios)
- else:
- non_artifact_weights = torch.sum(artifact_to_non_artifact_ratios * types_one_hot, dim=1)
-
- if batch.is_labeled():
- # maintain the interpretation of the logits as a likelihood ratio by weighting to effectively
- # achieve a balanced data set eg equal prior between artifact and non-artifact
- # for artifacts, weight is 1; for non-artifacts it's artifact to nonartifact ratio
- weights = batch.labels + (1 - batch.labels) * non_artifact_weights
-
- separate_losses_calibrated = weights * bce(logits, batch.labels)
-
- separate_losses_precalibrated = weights * bce(precalibrated_logits, batch.labels)
- calibrated_loss = torch.sum(separate_losses_calibrated)
- precalibrated_loss = torch.sum(separate_losses_precalibrated)
- loss = calibrated_loss + precalibrated_loss
-
- loss_metrics.record_total_batch_loss(calibrated_loss.detach(), batch, weights)
- loss_metrics.record_losses_by_type_and_count(separate_losses_calibrated.detach(), batch)
- precalibrated_loss_metrics.record_total_batch_loss(precalibrated_loss.detach(), batch, weights)
- precalibrated_loss_metrics.record_losses_by_type_and_count(separate_losses_precalibrated.detach(), batch)
- # calibration epochs freeze the model up to calibration, so the unlabeled loss is irrelevant
- else:
- # unlabeled loss: entropy regularization
- # Note that we use the precalibrated logits because otherwise entropy regularization simply biases
- # calibration to be overconfident.
- probabilities = torch.sigmoid(precalibrated_logits)
- entropies = torch.nn.functional.binary_cross_entropy_with_logits(precalibrated_logits, probabilities, reduction='none')
- # TODO: we need a parameter to control the relative weight of unlabeled loss to labeled loss
- # TODO: this interacts with the artifact / non-artifact weighting of labeled data!!!
- loss = torch.sum(entropies) * labeled_to_unlabeled_ratio
- precalibrated_loss_metrics.record_total_batch_loss(loss.detach(), batch)
-
- # I don't get this next line: loss is a sum, not a mean, so it's already weighted by size!!!
- # batch_weight = batch.size() / training_params.batch_size
- if epoch_type == utils.Epoch.TRAIN and not (is_calibration_epoch and not batch.is_labeled()):
+
+ # TODO: maybe this should be done by count for all epochs?
+ # TODO: we need a parameter to control the relative weight of unlabeled loss to labeled loss
+ weights = calculate_batch_weights(batch, dataset, by_count=is_calibration_epoch)
+
+ uncalibrated_cross_entropies = bce(precalibrated_logits, batch.labels)
+ calibrated_cross_entropies = bce(logits, batch.labels)
+ labeled_losses = batch.is_labeled_mask * (uncalibrated_cross_entropies + calibrated_cross_entropies) / 2
+
+ # unlabeled loss: entropy regularization. We use the uncalibrated logits because otherwise entropy
+ # regularization simply biases calibration to be overconfident.
+ probabilities = torch.sigmoid(precalibrated_logits)
+ entropies = torch.nn.functional.binary_cross_entropy_with_logits(precalibrated_logits, probabilities, reduction='none')
+ unlabeled_losses = (1 - batch.is_labeled_mask) * entropies
+
+ # these losses include weights and take labeled vs unlabeled into account
+ losses = (labeled_losses + unlabeled_losses) * weights
+ loss = torch.sum(losses)
+
+ loss_metrics.record_losses(calibrated_cross_entropies.detach(), batch, weights * batch.is_labeled_mask)
+ uncalibrated_loss_metrics.record_losses(uncalibrated_cross_entropies.detach(), batch, weights * batch.is_labeled_mask)
+ uncalibrated_loss_metrics.record_losses(entropies.detach(), batch, weights * (1 - batch.is_labeled_mask))
+
+ # calibration epochs freeze the model up to calibration, so I wonder if a purely unlabeled batch
+ # would cause lack of gradient problems. . .
+ if epoch_type == utils.Epoch.TRAIN:
utils.backpropagate(train_optimizer, loss)

# done with one epoch type -- training or validation -- for this epoch
loss_metrics.write_to_summary_writer(epoch_type, epoch, summary_writer)
- precalibrated_loss_metrics.write_to_summary_writer(epoch_type, epoch, summary_writer, prefix="uncalibrated")
+ uncalibrated_loss_metrics.write_to_summary_writer(epoch_type, epoch, summary_writer, prefix="uncalibrated")
if epoch_type == utils.Epoch.TRAIN:
train_scheduler.step(loss_metrics.get_labeled_loss())
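
Editor's note on the refactored loss: replacing the whole-batch batch.is_labeled() flag with a per-datum is_labeled_mask lets a single batch mix labeled and unlabeled data, applying cross-entropy where the mask is 1 and entropy regularization where it is 0. The entropy term works because binary_cross_entropy_with_logits(z, sigmoid(z)) equals the entropy of the Bernoulli distribution with p = sigmoid(z), so minimizing it pushes unlabeled predictions toward confident 0 or 1. A self-contained sketch with toy tensors (not the PR's code; weights stands in for the output of calculate_batch_weights):

import torch

logits = torch.tensor([2.0, -1.0, 0.3])
labels = torch.tensor([1.0, 0.0, 0.0])            # placeholder 0.0 where unlabeled
is_labeled_mask = torch.tensor([1.0, 1.0, 0.0])   # 1 = labeled, 0 = unlabeled
weights = torch.ones(3)                           # stand-in for calculate_batch_weights

bce = torch.nn.BCEWithLogitsLoss(reduction='none')
labeled_losses = is_labeled_mask * bce(logits, labels)

# entropy regularization: BCE of the logits against their own probabilities equals
# the entropy -p*log(p) - (1-p)*log(1-p) of p = sigmoid(z)
probabilities = torch.sigmoid(logits)
entropies = torch.nn.functional.binary_cross_entropy_with_logits(logits, probabilities, reduction='none')
direct = -(probabilities * probabilities.log() + (1 - probabilities) * (1 - probabilities).log())
assert torch.allclose(entropies, direct)

unlabeled_losses = (1 - is_labeled_mask) * entropies
loss = torch.sum((labeled_losses + unlabeled_losses) * weights)
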

@@ -304,28 +283,28 @@ def collect_evaluation_data(self, dataset: ArtifactDataset, train_loader, valid_
# the keys are tuples of (true label -- 1 for variant, 0 for artifact; rounded alt count)
worst_offenders_by_truth_and_alt_count = defaultdict(lambda: PriorityQueue(WORST_OFFENDERS_QUEUE_SIZE))

- artifact_to_non_artifact_ratios = torch.from_numpy(dataset.artifact_to_non_artifact_ratios())
evaluation_metrics = EvaluationMetrics()
epoch_types = [Epoch.TRAIN, Epoch.VALID]
for epoch_type in epoch_types:
assert epoch_type == Epoch.TRAIN or epoch_type == Epoch.VALID # not doing TEST here
loader = train_loader if epoch_type == Epoch.TRAIN else valid_loader
- pbar = tqdm(enumerate(filter(lambda bat: bat.is_labeled(), loader)), mininterval=60)
+ pbar = tqdm(enumerate(loader), mininterval=60)
for n, batch in pbar:

- # these are the same weights used in training to effectively balance the data between artifact and
- # non-artifact for each variant type
- types_one_hot = batch.variant_type_one_hot()
- non_artifact_weights = torch.sum(artifact_to_non_artifact_ratios * types_one_hot, dim=1)
- weights = batch.labels + (1 - batch.labels) * non_artifact_weights
+ # these are the same weights used in training
+ # TODO: maybe this should be done by count?
+ # TODO: we need a parameter to control the relative weight of unlabeled loss to labeled loss
+ weights = calculate_batch_weights(batch, dataset, by_count=True)

logits, _ = self.forward(batch)
pred = logits.detach()
correct = ((pred > 0) == (batch.labels > 0.5)).tolist()

- for variant_type, predicted_logit, label, correct_call, alt_count, datum, weight in zip(
- batch.variant_types(), pred.tolist(), batch.labels.tolist(), correct,
+ for variant_type, predicted_logit, label, is_labeled, correct_call, alt_count, datum, weight in zip(
+ batch.variant_types(), pred.tolist(), batch.labels.tolist(), batch.is_labeled_mask.tolist(), correct,
batch.alt_counts, batch.original_data, weights.tolist()):
+ if is_labeled < 0.5: # we only evaluate labeled data
+ continue
evaluation_metrics.record_call(epoch_type, variant_type, predicted_logit, label, correct_call, alt_count, weight)
if report_worst and not correct_call:
rounded_count = round_up_to_nearest_three(alt_count)
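
Editor's note: calculate_batch_weights is imported from permutect.architecture.base_model, whose diff is not shown on this page. Judging from the removed code above, a plausible sketch follows: artifacts weighted at 1, non-artifacts at the artifact-to-non-artifact ratio for their variant type (stratified by alt count when by_count is True), and unlabeled data at the capped labeled-to-unlabeled ratio. Every name and detail below is an assumption reconstructed from the removed lines, not the PR's actual implementation:

import torch

def calculate_batch_weights_sketch(batch, dataset, by_count: bool) -> torch.Tensor:
    # artifact-to-non-artifact ratio for each datum, per variant type and
    # optionally per alt count, as in the old calibration-epoch branch
    if by_count:
        ratios = torch.FloatTensor([dataset.artifact_to_non_artifact_ratios_by_count(alt_count)[var_type]
                                    for alt_count, var_type in zip(batch.alt_counts, batch.variant_types())])
    else:
        types_one_hot = batch.variant_type_one_hot()
        ratios = torch.sum(torch.from_numpy(dataset.artifact_to_non_artifact_ratios()) * types_one_hot, dim=1)

    # artifacts get weight 1; non-artifacts get the ratio, emulating a balanced labeled set
    labeled_weights = batch.labels + (1 - batch.labels) * ratios

    # unlabeled data downweighted by the labeled-to-unlabeled ratio, capped at 1,
    # mirroring the removed labeled_to_unlabeled_ratio logic
    total_labeled, total_unlabeled = dataset.total_labeled_and_unlabeled()
    unlabeled_weight = min(1.0, total_labeled / max(total_unlabeled, 1))

    return batch.is_labeled_mask * labeled_weights + (1 - batch.is_labeled_mask) * unlabeled_weight
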
@@ -370,17 +349,23 @@ def evaluate_model(self, epoch: int, dataset: ArtifactDataset, train_loader, val
ref_alt_seq_metrics = EmbeddingMetrics()

# now go over just the validation data and generate feature vectors / metadata for tensorboard projectors (UMAP)
- pbar = tqdm(enumerate(filter(lambda bat: bat.is_labeled(), valid_loader)), mininterval=60)
+ pbar = tqdm(enumerate(valid_loader), mininterval=60)

for n, batch in pbar:
logits, _ = self.forward(batch)
pred = logits.detach()
correct = ((pred > 0) == (batch.labels > 0.5)).tolist()

+ label_strings = [("artifact" if label > 0.5 else "non-artifact") if is_labeled > 0.5 else "unlabeled"
+ for (label, is_labeled) in zip(batch.labels.tolist(), batch.is_labeled_mask.tolist())]
+
+ correct_strings = [str(correctness) if is_labeled > 0.5 else "-1"
+ for (correctness, is_labeled) in zip(correct, batch.is_labeled_mask.tolist())]
+
for (metrics, embedding) in [(embedding_metrics, batch.get_representations_2d().detach()),
(ref_alt_seq_metrics, batch.get_ref_alt_seq_embeddings_2d().detach())]:
- metrics.label_metadata.extend(["artifact" if x > 0.5 else "non-artifact" for x in batch.labels.tolist()])
- metrics.correct_metadata.extend([str(val) for val in correct])
+ metrics.label_metadata.extend(label_strings)
+ metrics.correct_metadata.extend(correct_strings)
metrics.type_metadata.extend([Variation(idx).name for idx in batch.variant_types()])
metrics.truncated_count_metadata.extend([str(round_up_to_nearest_three(min(MAX_COUNT, alt_count))) for alt_count in batch.alt_counts])
metrics.representations.append(embedding)