Refactor data tensors (#168)
* variant type one hot no longer in INFO since str info includes that

* get rid of all variant_type_one_hot methods in favor of using the integer variant type (see the sketch after this list)

* merge BaseDatum1DStuff and ArtifactDatum1DStuff into a single OneDimensionalData class

* simplified calculation of batch weights and source weights 

* fixed a bug where artifact model training was using the regular weights instead of the source weights

* got rid of an implicit dim warning
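The central refactor replaces one-hot variant-type tensors with plain integer type codes that index directly into per-type parameter tables. A minimal sketch of the idea (not code from this commit; the type count and parameter table are hypothetical):

```python
import torch

num_variant_types = 5                                  # hypothetical count of Variation members
variant_types = torch.tensor([0, 3, 1])                # integer type codes, one per batch element

# old style: a (batch, num_types) one-hot matrix
one_hot_bv = torch.nn.functional.one_hot(variant_types, num_variant_types).float()

# selecting per-type parameters, old vs new
params_vk = torch.randn(num_variant_types, 4)          # one row of parameters per variant type
selected_old = one_hot_bv @ params_vk                  # matrix product picks out the matching row
selected_new = params_vk[variant_types, :]             # direct integer indexing, same result

assert torch.allclose(selected_old, selected_new)
```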
davidbenjamin authored Dec 6, 2024
1 parent 79773ba commit 00eda79
Showing 20 changed files with 394 additions and 463 deletions.
25 changes: 14 additions & 11 deletions permutect/architecture/artifact_model.py
@@ -182,9 +182,9 @@ def forward(self, batch: ArtifactBatch):
         features = self.feature_layers.forward(batch.get_representations_2d())
         uncalibrated_logits = self.artifact_classifier.forward(features).reshape(batch.size())
         calibrated_logits = torch.zeros_like(uncalibrated_logits, device=self._device)
-        one_hot_types_2d = batch.variant_type_one_hot()
+        variant_types = batch.get_variant_types()
         for n, _ in enumerate(Variation):
-            mask = one_hot_types_2d[:, n]
+            mask = (variant_types == n)
             calibrated_logits += mask * self.calibration[n].forward(uncalibrated_logits, batch.get_ref_counts(), batch.get_alt_counts())
         return calibrated_logits, uncalibrated_logits, features
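A small sketch of the per-type calibration pattern in the hunk above: each variant type has its own calibration module, and a boolean mask selects which batch elements it applies to. The calibrations list below is a hypothetical stand-in for self.calibration (the real modules also take ref and alt counts):

```python
import torch

calibrations = [lambda logits, shift=s: logits + shift for s in (0.0, 0.5, -0.5)]  # stand-ins

variant_types = torch.tensor([0, 2, 0, 1])
uncalibrated_logits = torch.randn(4)
calibrated_logits = torch.zeros_like(uncalibrated_logits)

for n, calibrate in enumerate(calibrations):
    mask = (variant_types == n)                 # boolean mask acts as 0/1 weights
    calibrated_logits += mask * calibrate(uncalibrated_logits)
```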

@@ -289,8 +289,9 @@ def learn(self, dataset: ArtifactDataset, training_params: TrainingParameters, s
                 weights = calculate_batch_weights(batch_cpu, dataset, by_count=True)
                 weights = weights.to(device=self._device, dtype=self._dtype, non_blocking=True)

-                uncalibrated_cross_entropies = bce(precalibrated_logits, batch.labels)
-                calibrated_cross_entropies = bce(logits, batch.labels)
+                labels = batch.get_training_labels()
+                uncalibrated_cross_entropies = bce(precalibrated_logits, labels)
+                calibrated_cross_entropies = bce(logits, labels)
                 labeled_losses = batch.get_is_labeled_mask() * (uncalibrated_cross_entropies + calibrated_cross_entropies) / 2

                 # unlabeled loss: entropy regularization. We use the uncalibrated logits because otherwise entropy
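A sketch of the semi-supervised loss structure referred to above, assuming the unlabeled term is the entropy of the model's own predictions (binary cross-entropy of the logits against their own sigmoid probabilities); all tensor values are illustrative:

```python
import torch

bce = torch.nn.BCEWithLogitsLoss(reduction='none')

logits = torch.randn(4)
labels = torch.tensor([1., 0., 1., 0.])            # meaningless where is_labeled == 0
is_labeled = torch.tensor([1., 1., 0., 0.])

cross_entropies = bce(logits, labels)              # labeled loss
entropies = bce(logits, torch.sigmoid(logits))     # entropy regularization for unlabeled data
loss = torch.sum(is_labeled * cross_entropies + (1 - is_labeled) * entropies)
```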
@@ -389,11 +390,12 @@ def collect_evaluation_data(self, dataset: ArtifactDataset, train_loader, valid_
                 pred = logits.detach().cpu()

                 # note that for metrics we use batch_cpu
-                correct = ((pred > 0) == (batch_cpu.labels > 0.5)).tolist()
+                labels = batch_cpu.get_training_labels()
+                correct = ((pred > 0) == (labels > 0.5)).tolist()

                 for variant_type, predicted_logit, label, is_labeled, correct_call, alt_count, variant, weight in zip(
-                        batch_cpu.variant_types(), pred.tolist(), batch_cpu.labels.tolist(), batch_cpu.get_is_labeled_mask().tolist(), correct,
-                        batch_cpu.get_alt_counts(), batch_cpu.get_variants(), weights.tolist()):
+                        batch_cpu.get_variant_types().tolist(), pred.tolist(), labels.tolist(), batch_cpu.get_is_labeled_mask().tolist(), correct,
+                        batch_cpu.get_alt_counts().tolist(), batch_cpu.get_variants(), weights.tolist()):
                     if is_labeled < 0.5:  # we only evaluate labeled data
                         continue
                     evaluation_metrics.record_call(epoch_type, variant_type, predicted_logit, label, correct_call, alt_count, weight)
@@ -445,19 +447,20 @@ def evaluate_model(self, epoch: int, dataset: ArtifactDataset, train_loader, val
                 batch = batch_cpu.copy_to(self._device, self._dtype, non_blocking=self._device.type == 'cuda')
                 logits, _, _ = self.forward(batch)
                 pred = logits.detach().cpu()
-                correct = ((pred > 0) == (batch_cpu.labels > 0.5)).tolist()
+                labels = batch_cpu.get_training_labels()
+                correct = ((pred > 0) == (labels > 0.5)).tolist()

                 label_strings = [("artifact" if label > 0.5 else "non-artifact") if is_labeled > 0.5 else "unlabeled"
-                                 for (label, is_labeled) in zip(batch_cpu.labels.tolist(), batch_cpu.get_is_labeled_mask().tolist())]
+                                 for (label, is_labeled) in zip(labels.tolist(), batch_cpu.get_is_labeled_mask().tolist())]

                 correct_strings = [str(correctness) if is_labeled > 0.5 else "-1"
                                    for (correctness, is_labeled) in zip(correct, batch_cpu.get_is_labeled_mask().tolist())]

                 for (metrics, embedding) in [(embedding_metrics, batch_cpu.get_representations_2d().detach())]:
                     metrics.label_metadata.extend(label_strings)
                     metrics.correct_metadata.extend(correct_strings)
-                    metrics.type_metadata.extend([Variation(idx).name for idx in batch_cpu.variant_types()])
-                    metrics.truncated_count_metadata.extend([str(round_up_to_nearest_three(min(MAX_COUNT, alt_count))) for alt_count in batch_cpu.get_alt_counts()])
+                    metrics.type_metadata.extend([Variation(idx).name for idx in batch_cpu.get_variant_types().tolist()])
+                    metrics.truncated_count_metadata.extend([str(round_up_to_nearest_three(min(MAX_COUNT, alt_count))) for alt_count in batch_cpu.get_alt_counts().tolist()])
                     metrics.representations.append(embedding)
                 embedding_metrics.output_to_summary_writer(summary_writer, epoch=epoch)
                 # done collecting data
43 changes: 21 additions & 22 deletions permutect/architecture/artifact_spectra.py
@@ -2,7 +2,7 @@

 import torch
 from permutect import utils
-from torch import nn
+from torch import nn, IntTensor

 from permutect.metrics.plotting import simple_plot
 from permutect.utils import beta_binomial, Variation
@@ -47,43 +47,42 @@ def __init__(self, num_components: int):
         # for each component and variant type:
         # alpha = exp(alpha0_pre_exp - exp(eta_pre_exp)*sigmoid(depth * exp(delta_pre_exp)))

-    '''
-    here x is a 2D tensor, 1st dimension batch, 2nd dimension being features that determine which Beta mixture to use
-    n and k are 1D tensors, the only dimension being batch.
-    '''
-    def forward(self, types_one_hot_bv, depths_b, alt_counts_b):
-        alt_counts_bk = torch.unsqueeze(alt_counts_b, dim=1).expand(-1, self.K - 1)
-        depths_bk = torch.unsqueeze(depths_b, dim=1).expand(-1, self.K - 1)
-        depths_bvk = depths_bk[:, None, :]
+    def forward(self, variant_types_b: torch.IntTensor, depths_b, alt_counts_b):
+        # indexing convention: b is batch, v is variant type, k is cluster component
+        alt_counts_bk = alt_counts_b[:, None]
+        depths_bk = depths_b[:, None]

-        self.alpha0_pre_exp_vk
         eta_vk = torch.exp(self.eta_pre_exp_vk)
         delta_vk = torch.exp(self.delta_pre_exp_vk)
-        alpha0_pre_exp_bvk, eta_bvk, delta_bvk = self.alpha0_pre_exp_vk[None, :, :], eta_vk[None, :, :], delta_vk[None, :, :]
-        weights0_pre_softmax_bvk, gamma_bvk, kappa_bvk = self.weights0_pre_softmax_vk[None, :, :], self.gamma_vk[None, :, :], self.kappa_vk[None, :, :]
-
-        alpha_bvk = torch.exp(alpha0_pre_exp_bvk - eta_bvk * torch.sigmoid(depths_bvk * delta_bvk))
-
-        types_one_hot_bvk = torch.unsqueeze(types_one_hot_bv, dim=-1)  # gives it broadcastable length-1 component dimension
-        alpha_bk = torch.sum(types_one_hot_bvk * alpha_bvk, dim=1)  # due to one-hotness only one v contributes to the sum
+        var_types_b = variant_types_b.long()
+        alpha0_pre_exp_bk = self.alpha0_pre_exp_vk[var_types_b, :]
+        delta_bk = delta_vk[var_types_b, :]
+        eta_bk = eta_vk[var_types_b, :]
+        alpha_bk = torch.exp(alpha0_pre_exp_bk - eta_bk * torch.sigmoid(depths_bk * delta_bk))
         beta_bk = self.beta * torch.ones_like(alpha_bk)

         beta_binomial_likelihoods_bk = beta_binomial(depths_bk, alt_counts_bk, alpha_bk, beta_bk)

-        weights_pre_softmax_bvk = weights0_pre_softmax_bvk + gamma_bvk * torch.sigmoid(depths_bvk * kappa_bvk)
+        if alpha_bk.isnan().any():
+            print("NaN found in alpha_bk")
+            assert 1 < 0, "FAIL"

-        log_weights_bvk = torch.log_softmax(weights_pre_softmax_bvk, dim=-1)  # softmax over component dimension
-        log_weights_bk = torch.sum(types_one_hot_bvk * log_weights_bvk, dim=1)  # same idea as above
+        weights0_pre_softmax_bk = self.weights0_pre_softmax_vk[var_types_b, :]
+        gamma_bk = self.gamma_vk[var_types_b, :]
+        kappa_bk = self.kappa_vk[var_types_b, :]
+        weights_pre_softmax_bk = weights0_pre_softmax_bk + gamma_bk * torch.sigmoid(depths_bk * kappa_bk)
+        log_weights_bk = torch.log_softmax(weights_pre_softmax_bk, dim=-1)  # softmax over component dimension

         weighted_likelihoods_bk = log_weights_bk + beta_binomial_likelihoods_bk

-        result_b = torch.logsumexp(weighted_likelihoods_bk, dim=1, keepdim=False)
+        result_b = torch.logsumexp(weighted_likelihoods_bk, dim=-1, keepdim=False)
         return result_b

     # TODO: utter code duplication with somatic spectrum
-    def fit(self, num_epochs, types_one_hot_2d, depths_1d_tensor, alt_counts_1d_tensor, batch_size=64):
+    def fit(self, num_epochs, types_b: IntTensor, depths_1d_tensor, alt_counts_1d_tensor, batch_size=64):
         optimizer = torch.optim.Adam(self.parameters())
         num_batches = math.ceil(len(alt_counts_1d_tensor) / batch_size)

@@ -92,7 +91,7 @@ def fit(self, num_epochs, types_one_hot_2d, depths_1d_tensor, alt_counts_1d_tens
             batch_start = batch * batch_size
             batch_end = min(batch_start + batch_size, len(alt_counts_1d_tensor))
             batch_slice = slice(batch_start, batch_end)
-            loss = -torch.mean(self.forward(types_one_hot_2d[batch_slice], depths_1d_tensor[batch_slice], alt_counts_1d_tensor[batch_slice]))
+            loss = -torch.mean(self.forward(types_b[batch_slice], depths_1d_tensor[batch_slice], alt_counts_1d_tensor[batch_slice]))
             utils.backpropagate(optimizer, loss)

     '''
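A sketch of the mixture-likelihood pattern in the new forward above: per-component parameters are gathered by integer variant type, and logsumexp over the component dimension combines per-component log weights with per-component log likelihoods. A plain torch.distributions.Binomial stands in for permutect's beta_binomial helper, and all parameter values are made up:

```python
import torch

V, K = 5, 4                                           # variant types, mixture components
weights_pre_softmax_vk = torch.randn(V, K)
probs_vk = torch.rand(V, K)                           # per-component success probabilities

variant_types_b = torch.tensor([0, 4, 2])
depths_b = torch.tensor([30., 50., 20.])
alt_counts_b = torch.tensor([3., 10., 2.])

types = variant_types_b.long()
log_weights_bk = torch.log_softmax(weights_pre_softmax_vk[types, :], dim=-1)   # gather rows by type
probs_bk = probs_vk[types, :]

log_likelihoods_bk = torch.distributions.Binomial(
    total_count=depths_b[:, None], probs=probs_bk).log_prob(alt_counts_b[:, None])

log_mixture_likelihood_b = torch.logsumexp(log_weights_bk + log_likelihoods_bk, dim=-1)
```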
45 changes: 23 additions & 22 deletions permutect/architecture/base_model.py
Expand Up @@ -36,26 +36,26 @@ def sums_over_chunks(tensor2d: torch.Tensor, chunk_size: int):
# note: this works for both BaseBatch/BaseDataset AND ArtifactBatch/ArtifactDataset
# if by_count is True, each count is weighted separately for balanced loss within that count
def calculate_batch_weights(batch, dataset, by_count: bool):
# -1 is the sentinel value for aggregation over all counts
# TODO: we need a parameter to control the relative weight of unlabeled loss to labeled loss
types_one_hot = batch.variant_type_one_hot()
weights_by_label_and_type = {label: (np.vstack([dataset.weights[count][label] for count in batch.get_alt_counts().tolist()]) if \
by_count else dataset.weights[-1][label]) for label in Label}
weights_by_label = {label: torch.sum(torch.from_numpy(weights_by_label_and_type[label]) * types_one_hot, dim=1) for label in Label}
weights = batch.get_is_labeled_mask() * (batch.labels * weights_by_label[Label.ARTIFACT] + (1 - batch.labels) * weights_by_label[Label.VARIANT]) + \
(1 - batch.get_is_labeled_mask()) * weights_by_label[Label.UNLABELED]
return weights
# For batch index n, we want weight[n] = dataset.weights[alt_counts[n], labels[n], variant_types[n]]
counts = batch.get_alt_counts()
labels = batch.get_labels()
variant_types = batch.get_variant_types()

return utils.index_3d_array(dataset.weights, counts, labels, variant_types) if by_count else \
utils.index_2d_array(dataset.weights[ALL_COUNTS_SENTINEL], labels, variant_types)


# note: this works for both BaseBatch/BaseDataset AND ArtifactBatch/ArtifactDataset
# if by_count is True, each count is weighted separately for balanced loss within that count
def calculate_batch_source_weights(batch, dataset, by_count: bool):
# -1 is the sentinel value for aggregation over all counts
types_one_hot = batch.variant_type_one_hot()
weights_by_type = np.vstack([dataset.weights[count if by_count else -1][source] for count, source in zip(batch.get_alt_counts().tolist(), batch.get_sources().tolist())])
source_weights = torch.sum(torch.from_numpy(weights_by_type).to(device=types_one_hot.device) * types_one_hot, dim=1)
# For batch index n, we want weight[n] = dataset.source_weights[alt_counts[n], sources[n], variant_types[n]]
counts = batch.get_alt_counts()
sources = batch.get_sources()
variant_types = batch.get_variant_types()

return source_weights
return utils.index_3d_array(dataset.source_weights, counts, sources, variant_types) if by_count else \
utils.index_2d_array(dataset.source_weights[ALL_COUNTS_SENTINEL], sources, variant_types)


class LearningMethod(Enum):
@@ -156,7 +156,7 @@ def forward(self, batch: BaseBatch):
     # so, for example, "re" means a 2D tensor with all reads in the batch stacked and "vre" means a 3D tensor indexed
     # first by variant within the batch, then the read
     def calculate_representations(self, batch: BaseBatch, weight_range: float = 0) -> torch.Tensor:
-        ref_counts, alt_counts = batch.ref_counts, batch.alt_counts
+        ref_counts, alt_counts = batch.get_ref_counts(), batch.get_alt_counts()
         total_ref, total_alt = torch.sum(ref_counts).item(), torch.sum(alt_counts).item()

         read_embeddings_re = self.read_embedding.forward(batch.get_reads_2d().to(dtype=self._dtype))
@@ -239,13 +239,14 @@ def __init__(self, input_dim: int, hidden_top_layers: List[int], params: BaseMod

     def loss_function(self, base_model: BaseModel, base_batch: BaseBatch, base_model_representations: torch.Tensor):
         logits = self.logit_predictor.forward(base_model_representations).reshape((base_batch.size()))
+        labels = base_batch.get_training_labels()

         # base batch always has labels, but for unlabeled elements these labels are meaningless and is_labeled_mask is zero
-        cross_entropies = self.bce(logits, base_batch.labels)
+        cross_entropies = self.bce(logits, labels)
         probabilities = torch.sigmoid(logits)
         entropies = self.bce(logits, probabilities)

-        return base_batch.is_labeled_mask * cross_entropies + (1 - base_batch.is_labeled_mask) * entropies
+        return base_batch.get_is_labeled_mask() * cross_entropies + (1 - base_batch.get_is_labeled_mask()) * entropies

     # I don't like implicit forward!!
     def forward(self):
@@ -402,7 +403,7 @@ def loss_function(self, base_model: BaseModel, base_batch: BaseBatch, base_model
         # labels are 1 for artifact, 0 otherwise. We convert to +1 if normal, -1 if artifact
         # DeepSAD assumes most unlabeled data are normal and so the unlabeled loss is identical to the normal loss, that is,
         # squared Euclidean distance from the centroid
-        signs = (1 - 2 * base_batch.labels) * base_batch.is_labeled_mask + 1 * (1 - base_batch.is_labeled_mask)
+        signs = (1 - 2 * base_batch.get_training_labels()) * base_batch.get_is_labeled_mask() + 1 * (1 - base_batch.get_is_labeled_mask())

         # distance squared for normal and unlabeled, inverse distance squared for artifact
         return dist_squared ** signs
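A toy illustration of the DeepSAD-style sign trick in the hunk above: labeled artifacts get sign -1, so their loss is the inverse squared distance from the centroid, while labeled non-artifacts and unlabeled data get sign +1, the plain squared distance. Values are illustrative:

```python
import torch

labels = torch.tensor([1., 0., 1., 0.])        # 1 = artifact, 0 = non-artifact
is_labeled = torch.tensor([1., 1., 0., 0.])
dist_squared = torch.tensor([4., 4., 4., 4.])

signs = (1 - 2 * labels) * is_labeled + 1 * (1 - is_labeled)   # -1 for labeled artifacts, +1 otherwise
losses = dist_squared ** signs                                  # tensor([0.25, 4., 4., 4.])
```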
@@ -437,9 +438,9 @@ def loss_function(self, base_model: BaseModel, base_batch: BaseBatch, base_model
         min_dist_squared_b = torch.min(dist_squared_bc, dim=-1).values

         # closest centroid with correct label is labeled, otherwise just the closest centroid
-        labeled_losses_b = (base_batch.labels * artifact_dist_squared_b + (1 - base_batch.labels) * normal_dist_squared_b)
+        labeled_losses_b = (base_batch.get_training_labels() * artifact_dist_squared_b + (1 - base_batch.get_training_labels()) * normal_dist_squared_b)
         unlabeled_losses_b = min_dist_squared_b
-        embedding_centroid_losses_b = base_batch.is_labeled_mask * labeled_losses_b + (1 - base_batch.is_labeled_mask) * unlabeled_losses_b
+        embedding_centroid_losses_b = base_batch.get_is_labeled_mask() * labeled_losses_b + (1 - base_batch.get_is_labeled_mask()) * unlabeled_losses_b

         # average distance between centroids
         centroid_seps_cce = torch.unsqueeze(self.centroids_ce, dim=0) - torch.unsqueeze(self.centroids_ce, dim=1)
@@ -562,7 +563,7 @@ def learn_base_model(base_model: BaseModel, dataset: BaseDataset, learning_metho
             loss = torch.sum((weights * losses) + alt_count_losses)

             classification_logits = classifier_on_top.forward(representations.detach()).reshape(batch.size())
-            classification_losses = classifier_bce(classification_logits, batch.labels)
+            classification_losses = classifier_bce(classification_logits, batch.get_training_labels())
             classification_loss = torch.sum(batch.get_is_labeled_mask() * weights * classification_losses)
             classifier_metrics.record_losses(classification_losses.detach(), batch, batch.get_is_labeled_mask() * weights)

@@ -604,11 +605,11 @@ def record_embeddings(base_model: BaseModel, loader, summary_writer: SummaryWrit
         ref_alt_seq_embeddings = ref_alt_seq_embeddings.cpu()

         labels = [("artifact" if label > 0.5 else "non-artifact") if is_labeled > 0.5 else "unlabeled" for (label, is_labeled) in
-                  zip(batch.labels.tolist(), batch.get_is_labeled_mask().tolist())]
+                  zip(batch.get_training_labels().tolist(), batch.get_is_labeled_mask().tolist())]
         for (metrics, embeddings) in [(embedding_metrics, representations), (ref_alt_seq_metrics, ref_alt_seq_embeddings)]:
             metrics.label_metadata.extend(labels)
             metrics.correct_metadata.extend(["unknown"] * batch.size())
-            metrics.type_metadata.extend([Variation(idx).name for idx in batch.variant_types()])
+            metrics.type_metadata.extend([Variation(idx).name for idx in batch.get_variant_types().tolist()])
             alt_count_strings = [str(round_up_to_nearest_three(min(MAX_COUNT, ac))) for ac in batch.get_alt_counts().tolist()]
             metrics.truncated_count_metadata.extend(alt_count_strings)
             metrics.representations.append(embeddings)
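A sketch of what the simplified weight calculation amounts to, assuming the dataset stores weights as a dense (count, label, variant type) tensor; index_3d_array is presumably a thin wrapper around this kind of advanced indexing, but that is an assumption, and the table and batch values below are made up:

```python
import torch

MAX_COUNT, NUM_LABELS, NUM_TYPES = 20, 3, 5
weights_clt = torch.rand(MAX_COUNT + 1, NUM_LABELS, NUM_TYPES)   # hypothetical (count, label, type) table

alt_counts = torch.tensor([3, 7, 3])
labels = torch.tensor([0, 1, 2])
variant_types = torch.tensor([1, 4, 0])

# batch_weights[n] = weights_clt[alt_counts[n], labels[n], variant_types[n]]
batch_weights = weights_clt[alt_counts, labels, variant_types]
```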
1 change: 0 additions & 1 deletion permutect/architecture/normal_artifact_spectrum.py
@@ -26,7 +26,6 @@ def __init__(self, num_samples: int):

     def forward(self, tumor_alt_1d: torch.Tensor, tumor_ref_1d: torch.Tensor, normal_alt_1d: torch.Tensor, normal_ref_1d: torch.Tensor):
         if torch.sum(normal_alt_1d) < 1:    # shortcut if no normal alts in the whole batch
-            print("debug, no normal alts in batch")
             return -9999 * torch.ones_like(tumor_alt_1d)
         batch_size = len(tumor_alt_1d)
         tumor_fractions_2d, normal_fractions_2d = self.get_tumor_and_normal_fraction(batch_size, self.num_samples)