Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Unlabeled Domain Adaptation #152

Merged
merged 19 commits into from
Oct 25, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
80 changes: 65 additions & 15 deletions permutect/architecture/artifact_model.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
# bug before PyTorch 1.7.1 that warns when constructing ParameterList
import math
import warnings
from collections import defaultdict
from typing import List
Expand All @@ -15,7 +16,8 @@
from itertools import chain
from matplotlib import pyplot as plt

from permutect.architecture.base_model import calculate_batch_weights, BaseModel, base_model_from_saved_dict
from permutect.architecture.base_model import calculate_batch_weights, BaseModel, base_model_from_saved_dict, calculate_batch_source_weights
from permutect.architecture.gradient_reversal.module import GradientReversal
from permutect.architecture.mlp import MLP
from permutect.architecture.monotonic import MonoDense
from permutect.data.base_datum import ArtifactBatch, DEFAULT_GPU_FLOAT, DEFAULT_CPU_FLOAT
Expand Down Expand Up @@ -137,16 +139,20 @@ def __init__(self, params: ArtifactModelParameters, num_base_features: int, num_
self.num_ref_alt_features = num_ref_alt_features
self.params = params

# feature layers before the domain adaptation source classifier splits from the artifact classifier
self.feature_layers = MLP([num_base_features] + params.aggregation_layers, batch_normalize=params.batch_normalize, dropout_p=params.dropout_p)

# TODO: artifact classifier hidden layers are hard-coded!!!
# The [1] is for the output logit
self.aggregation = MLP([num_base_features] + params.aggregation_layers + [1], batch_normalize=params.batch_normalize, dropout_p=params.dropout_p)
self.artifact_classifier = MLP([self.feature_layers.output_dimension()] + [-1, -1, 1], batch_normalize=params.batch_normalize, dropout_p=params.dropout_p)

# one Calibration module for each variant type; that is, calibration depends on both count and type
self.calibration = nn.ModuleList([Calibration(params.calibration_layers) for variant_type in Variation])

self.to(device=self._device, dtype=self._dtype)

def training_parameters(self):
return chain(self.aggregation.parameters(), self.calibration.parameters())
return chain(self.feature_layers.parameters(), self.artifact_classifier.parameters(), self.calibration.parameters())

def calibration_parameters(self):
return self.calibration.parameters()
Expand All @@ -164,19 +170,38 @@ def set_epoch_type(self, epoch_type: utils.Epoch):

# returns 1D tensor of length batch_size of log odds ratio (logits) between artifact and non-artifact
def forward(self, batch: ArtifactBatch):
precalibrated_logits = self.aggregation.forward(batch.get_representations_2d().to(device=self._device, dtype=self._dtype)).reshape(batch.size())
calibrated_logits = torch.zeros_like(precalibrated_logits)
features = self.feature_layers.forward(batch.get_representations_2d().to(device=self._device, dtype=self._dtype))
uncalibrated_logits = self.artifact_classifier.forward(features).reshape(batch.size())
calibrated_logits = torch.zeros_like(uncalibrated_logits)
one_hot_types_2d = batch.variant_type_one_hot().to(device=self._device, dtype=self._dtype)
for n, _ in enumerate(Variation):
mask = one_hot_types_2d[:, n]
calibrated_logits += mask * self.calibration[n].forward(precalibrated_logits, batch.ref_counts, batch.alt_counts)
return calibrated_logits, precalibrated_logits
calibrated_logits += mask * self.calibration[n].forward(uncalibrated_logits, batch.ref_counts, batch.alt_counts)
return calibrated_logits, uncalibrated_logits, features

def learn(self, dataset: ArtifactDataset, training_params: TrainingParameters, summary_writer: SummaryWriter, validation_fold: int = None, epochs_per_evaluation: int = None):
bce = nn.BCEWithLogitsLoss(reduction='none') # no reduction because we may want to first multiply by weights for unbalanced data
train_optimizer = torch.optim.AdamW(self.training_parameters(), lr=training_params.learning_rate, weight_decay=training_params.weight_decay)
train_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
train_optimizer, factor=0.2, patience=3, threshold=0.001, min_lr=(training_params.learning_rate / 100),verbose=True)
# cross entropy (with logit inputs) loss for adversarial source classification task
ce = nn.CrossEntropyLoss(reduction='none')

num_sources = len(dataset.counts_by_source.keys())
if num_sources == 1:
print("Training data come from a single source (this could be multiple files with the same source annotation applied in preprocessing)")
else:
sources_list = list(dataset.counts_by_source.keys())
sources_list.sort()
assert sources_list[0] == 0, "There is no source 0"
assert sources_list[-1] == num_sources - 1, f"sources should be 0, 1, 2. . . without gaps, but sources are {sources_list}."

print(f"Training data come from multiple sources, with counts {dataset.counts_by_source}.")
source_classifier = MLP([self.feature_layers.output_dimension()] + [-1, -1, num_sources],
batch_normalize=self.params.batch_normalize, dropout_p=self.params.dropout_p)
source_gradient_reversal = GradientReversal(alpha=0.01) # initialize as barely active

train_optimizer = torch.optim.AdamW(chain(self.training_parameters(), source_classifier.parameters()), lr=training_params.learning_rate,
weight_decay=training_params.weight_decay)
train_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(train_optimizer, factor=0.2, patience=3,
threshold=0.001, min_lr=(training_params.learning_rate / 100), verbose=True)

for idx, variation_type in enumerate(utils.Variation):
print(f"For variation type {variation_type.name}, there are {int(dataset.totals[-1][Label.ARTIFACT][idx].item())} \
Expand All @@ -194,6 +219,10 @@ def learn(self, dataset: ArtifactDataset, training_params: TrainingParameters, s
print(f"Epoch {epoch}, memory usage percent: {psutil.virtual_memory().percent:.1f}")
is_calibration_epoch = epoch > training_params.num_epochs

p = epoch - 1
new_alpha = (2 / (1 + math.exp(-0.1 * p))) - 1
source_gradient_reversal.set_alpha(new_alpha)

for epoch_type in [utils.Epoch.TRAIN, utils.Epoch.VALID]:
self.set_epoch_type(epoch_type)
# in calibration epoch, freeze the model except for calibration
Expand All @@ -202,16 +231,31 @@ def learn(self, dataset: ArtifactDataset, training_params: TrainingParameters, s
utils.unfreeze(self.calibration_parameters()) # unfreeze calibration but everything else stays frozen

loss_metrics = LossMetrics(self._device) # based on calibrated logits
source_prediction_loss_metrics = LossMetrics(self._device) # based on calibrated logits
uncalibrated_loss_metrics = LossMetrics(self._device) # based on uncalibrated logits

loader = train_loader if epoch_type == utils.Epoch.TRAIN else valid_loader
pbar = tqdm(enumerate(loader), mininterval=60)
for n, batch in pbar:
logits, precalibrated_logits = self.forward(batch)
logits, precalibrated_logits, features = self.forward(batch)

# one-hot prediction of sources
if num_sources > 1:
# gradient reversal means parameters before the features try to maximize source prediction loss, i.e. features
# try to forget the source, while parameters after the features try to minimize it, i.e. they try
# to achieve the adversarial task of distinguishing sources
source_prediction_logits = source_classifier.forward(source_gradient_reversal(features))
source_prediction_probs = torch.nn.functional.softmax(source_prediction_logits, dim=-1)
source_prediction_targets = torch.nn.functional.one_hot(batch.sources.to(device=self._device).long(), num_sources)
source_prediction_losses = torch.sum(torch.square(source_prediction_probs - source_prediction_targets), dim=-1)
source_prediction_weights = calculate_batch_source_weights(batch, dataset, by_count=is_calibration_epoch)
else:
source_prediction_losses = torch.zeros_like(logits)
source_prediction_weights = torch.zeros_like(logits)

# TODO: maybe this should be done by count for all epochs?
# TODO: we need a parameter to control the relative weight of unlabeled loss to labeled loss
weights = calculate_batch_weights(batch, dataset, by_count=is_calibration_epoch)
weights = calculate_batch_weights(batch, dataset, by_count=True)

uncalibrated_cross_entropies = bce(precalibrated_logits, batch.labels)
calibrated_cross_entropies = bce(logits, batch.labels)
Expand All @@ -224,12 +268,13 @@ def learn(self, dataset: ArtifactDataset, training_params: TrainingParameters, s
unlabeled_losses = (1 - batch.is_labeled_mask) * entropies

# these losses include weights and take labeled vs unlabeled into account
losses = (labeled_losses + unlabeled_losses) * weights
losses = (labeled_losses + unlabeled_losses) * weights + (source_prediction_losses * source_prediction_weights)
loss = torch.sum(losses)

loss_metrics.record_losses(calibrated_cross_entropies.detach(), batch, weights * batch.is_labeled_mask)
uncalibrated_loss_metrics.record_losses(uncalibrated_cross_entropies.detach(), batch, weights * batch.is_labeled_mask)
uncalibrated_loss_metrics.record_losses(entropies.detach(), batch, weights * (1 - batch.is_labeled_mask))
source_prediction_loss_metrics.record_losses(source_prediction_losses.detach(), batch, source_prediction_weights)

# calibration epochs freeze the model up to calibration, so I wonder if a purely unlabeled batch
# would cause lack of gradient problems. . .
Expand All @@ -238,11 +283,16 @@ def learn(self, dataset: ArtifactDataset, training_params: TrainingParameters, s

# done with one epoch type -- training or validation -- for this epoch
loss_metrics.write_to_summary_writer(epoch_type, epoch, summary_writer)
source_prediction_loss_metrics.write_to_summary_writer(epoch_type, epoch, summary_writer, prefix="source prediction")
uncalibrated_loss_metrics.write_to_summary_writer(epoch_type, epoch, summary_writer, prefix="uncalibrated")
if epoch_type == utils.Epoch.TRAIN:
train_scheduler.step(loss_metrics.get_labeled_loss())

print(f"Labeled loss for {epoch_type.name} epoch {epoch}: {loss_metrics.get_labeled_loss():.3f}")
print(f"Unlabeled loss for {epoch_type.name} epoch {epoch}: {uncalibrated_loss_metrics.get_unlabeled_loss():.3f}")
if num_sources > 1:
print(f"Adversarial source prediction loss on labeled data for {epoch_type.name} epoch {epoch}: {source_prediction_loss_metrics.get_labeled_loss():.3f}")
print(f"Adversarial source prediction loss on unlabeled data for {epoch_type.name} epoch {epoch}: {source_prediction_loss_metrics.get_unlabeled_loss():.3f}")
# done with training and validation for this epoch
is_last = (epoch == last_epoch)
if (epochs_per_evaluation is not None and epoch % epochs_per_evaluation == 0) or is_last:
Expand Down Expand Up @@ -296,7 +346,7 @@ def collect_evaluation_data(self, dataset: ArtifactDataset, train_loader, valid_
# TODO: we need a parameter to control the relative weight of unlabeled loss to labeled loss
weights = calculate_batch_weights(batch, dataset, by_count=True)

logits, _ = self.forward(batch)
logits, _, _ = self.forward(batch)
pred = logits.detach()
correct = ((pred > 0) == (batch.labels > 0.5)).tolist()

Expand Down Expand Up @@ -352,7 +402,7 @@ def evaluate_model(self, epoch: int, dataset: ArtifactDataset, train_loader, val
pbar = tqdm(enumerate(valid_loader), mininterval=60)

for n, batch in pbar:
logits, _ = self.forward(batch)
logits, _, _ = self.forward(batch)
pred = logits.detach()
correct = ((pred > 0) == (batch.labels > 0.5)).tolist()

Expand Down
24 changes: 11 additions & 13 deletions permutect/architecture/base_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,6 @@ def sums_over_chunks(tensor2d: torch.Tensor, chunk_size: int):
# if by_count is True, each count is weighted separately for balanced loss within that count
def calculate_batch_weights(batch, dataset, by_count: bool):
# -1 is the sentinel value for aggregation over all counts
# TODO: maybe this should be done by count?
# TODO: we need a parameter to control the relative weight of unlabeled loss to labeled loss
types_one_hot = batch.variant_type_one_hot()
weights_by_label_and_type = {label: (np.vstack([dataset.weights[count][label] for count in batch.alt_counts.tolist()]) if \
Expand All @@ -47,6 +46,17 @@ def calculate_batch_weights(batch, dataset, by_count: bool):
return weights


# note: this works for both BaseBatch/BaseDataset AND ArtifactBatch/ArtifactDataset
# if by_count is True, each count is weighted separately for balanced loss within that count
def calculate_batch_source_weights(batch, dataset, by_count: bool):
# -1 is the sentinel value for aggregation over all counts
types_one_hot = batch.variant_type_one_hot()
weights_by_type = np.vstack([dataset.weights[count if by_count else -1][source] for count, source in zip(batch.alt_counts.tolist(), batch.sources.tolist())])
source_weights = torch.sum(torch.from_numpy(weights_by_type) * types_one_hot, dim=1)

return source_weights


class LearningMethod(Enum):
# train the embedding by minimizing cross-entropy loss of binary predictor on labeled data
SUPERVISED = "SUPERVISED"
Expand Down Expand Up @@ -538,18 +548,6 @@ def learn_base_model(base_model: BaseModel, dataset: BaseDataset, learning_metho
classification_loss = torch.sum(batch.is_labeled_mask * weights * classification_losses)
classifier_metrics.record_losses(classification_losses.detach(), batch, batch.is_labeled_mask * weights)

# STUPID DEBUG STUFF
if n == 2:
print(f"actual alt counts {batch.alt_counts.tolist()}")
print(f"alt count predictions: {alt_count_pred.detach().tolist()}")
print(f"alt count losses {alt_count_losses.detach().tolist()}")
print(f"weights {weights.tolist()}")
print(f"semisupervised losses {losses.detach().tolist()}")
print(f"classification logits {classification_logits.detach().tolist()}")
print(f"classification losses {classification_losses.detach().tolist()}")

# DONE DEBUG

if epoch_type == utils.Epoch.TRAIN:
utils.backpropagate(train_optimizer, loss)
utils.backpropagate(classifier_optimizer, classification_loss)
Expand Down
1 change: 1 addition & 0 deletions permutect/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
NUM_REF_ALT_FEATURES_NAME = 'num_ref_alt_features'

SOURCES_NAME = 'sources'
SOURCE_NAME = 'source'

INPUT_NAME = 'input'
OUTPUT_NAME = 'output'
Expand Down
2 changes: 2 additions & 0 deletions permutect/data/artifact_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,9 @@ class ArtifactDataset(Dataset):
def __init__(self, base_dataset: BaseDataset, base_model: BaseModel, folds_to_use: List[int] = None):
self.counts_by_source = base_dataset.counts_by_source
self.totals = base_dataset.totals
self.source_totals = base_dataset.source_totals
self.weights = base_dataset.weights
self.source_weights = base_dataset.source_weights

self.artifact_data = []
self.num_folds = base_dataset.num_folds
Expand Down
Loading