Mixed count base batches (#158)
davidbenjamin authored Nov 21, 2024
1 parent b7d66d0 commit dd74395
Showing 8 changed files with 133 additions and 94 deletions.
45 changes: 20 additions & 25 deletions permutect/architecture/base_model.py
@@ -119,7 +119,6 @@ def __init__(self, params: BaseModelParameters, num_read_features: int, num_info
self._dtype = DEFAULT_GPU_FLOAT if device != torch.device("cpu") else DEFAULT_CPU_FLOAT
self._ref_sequence_length = ref_sequence_length
self._params = params
self.alt_downsample = params.alt_downsample

# embeddings of reads, info, and reference sequence prior to the transformer layers
self.read_embedding = MLP([num_read_features] + params.read_layers, batch_normalize=params.batch_normalize, dropout_p=params.dropout_p)
@@ -130,7 +129,6 @@ def __init__(self, params: BaseModelParameters, num_read_features: int, num_info
assert embedding_dim % params.num_transformer_heads == 0

self.ref_alt_reads_encoder = make_gated_ref_alt_mlp_encoder(embedding_dim, params)
self.alt_encoder = make_gated_mlp_encoder(embedding_dim, params)

# after encoding alt reads (along with info and ref seq embeddings and with self-attention to ref reads)
# pass through another MLP
@@ -163,35 +161,31 @@ def forward(self, batch: BaseBatch):
# so, for example, "re" means a 2D tensor with all reads in the batch stacked and "vre" means a 3D tensor indexed
# first by variant within the batch, then the read
def calculate_representations(self, batch: BaseBatch, weight_range: float = 0) -> torch.Tensor:
ref_count, alt_count = batch.ref_count, batch.alt_count
total_ref, total_alt = ref_count * batch.size(), alt_count * batch.size()
ref_counts, alt_counts = batch.ref_counts, batch.alt_counts
total_ref, total_alt = torch.sum(ref_counts).item(), torch.sum(alt_counts).item()

read_embeddings_re = self.read_embedding.forward(batch.get_reads_2d().to(dtype=self._dtype))
info_embeddings_ve = self.info_embedding.forward(batch.get_info_2d().to(dtype=self._dtype))
ref_seq_embeddings_ve = self.ref_seq_cnn(batch.get_ref_sequences_2d().to(dtype=self._dtype))
info_and_seq_ve = torch.hstack((info_embeddings_ve, ref_seq_embeddings_ve))
info_and_seq_re = torch.vstack((torch.repeat_interleave(info_and_seq_ve, ref_count, dim=0),
torch.repeat_interleave(info_and_seq_ve, alt_count, dim=0)))
info_and_seq_re = torch.vstack((torch.repeat_interleave(info_and_seq_ve, repeats=ref_counts, dim=0),
torch.repeat_interleave(info_and_seq_ve, repeats=alt_counts, dim=0)))
reads_info_seq_re = torch.hstack((read_embeddings_re, info_and_seq_re))
ref_reads_info_seq_vre = None if total_ref == 0 else reads_info_seq_re[:total_ref].reshape(batch.size(), ref_count, -1)
alt_reads_info_seq_vre = reads_info_seq_re[total_ref:].reshape(batch.size(), alt_count, -1)

if self.alt_downsample < alt_count:
alt_read_indices = torch.randperm(alt_count)[:self.alt_downsample]
alt_reads_info_seq_vre = alt_reads_info_seq_vre[:, alt_read_indices, :] # downsample only along the middle (read) dimension
alt_count = self.alt_downsample
total_alt = batch.size() * self.alt_downsample
# TODO: might be a bug if every datum in batch has zero ref reads?
ref_reads_info_seq_re = reads_info_seq_re[:total_ref]
alt_reads_info_seq_re = reads_info_seq_re[total_ref:]

# undo some of the above rearrangement
# TODO: make sure it handles ref count = 0 case
transformed_ref_re, transformed_alt_re = self.ref_alt_reads_encoder.forward(ref_reads_info_seq_re, alt_reads_info_seq_re, ref_counts, alt_counts)

transformed_ref_vre, transformed_alt_vre = (None, self.alt_encoder(alt_reads_info_seq_vre)) if total_ref == 0 else \
self.ref_alt_reads_encoder(ref_reads_info_seq_vre, alt_reads_info_seq_vre)
alt_weights_r = 1 + weight_range * (1 - 2 * torch.rand(total_alt, device=self._device, dtype=self._dtype))

alt_weights_vr = 1 + weight_range * (1 - 2 * torch.rand(batch.size(), alt_count, device=self._device, dtype=self._dtype))
alt_wt_sums = torch.sum(alt_weights_vr, dim=1, keepdim=True)
# normalized so read weights within each variant sum to 1 and add dummy e dimension for broadcasting the multiply below
normalized_alt_weights_vr1 = (alt_weights_vr / alt_wt_sums).reshape(batch.size(), alt_count, 1)
alt_means_ve = torch.sum(transformed_alt_vre * normalized_alt_weights_vr1, dim=1)
# normalize so read weights within each variant sum to 1
alt_wt_sums_v = utils.sums_over_rows(alt_weights_r, alt_counts)
normalized_alt_weights_r = alt_weights_r / torch.repeat_interleave(alt_wt_sums_v, repeats=alt_counts, dim=0)

alt_means_ve = utils.sums_over_rows(transformed_alt_re * normalized_alt_weights_r[:,None], alt_counts)

result_ve = self.aggregation.forward(alt_means_ve)
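Note: the rewritten calculate_representations relies on grouped row reductions from permutect.utils (sums_over_rows, means_over_rows) whose implementations are not part of this diff. A minimal sketch of the assumed semantics -- consecutive rows grouped by per-variant counts, as described in the gated_mlp comments below -- with _sketch suffixes to mark these as illustrations rather than the actual utils code:

import torch

def sums_over_rows_sketch(x: torch.Tensor, counts: torch.IntTensor) -> torch.Tensor:
    # consecutive rows form groups: the first counts[0] rows belong to variant 0, the next counts[1] to variant 1, etc.
    group_ids = torch.repeat_interleave(torch.arange(len(counts), device=x.device), counts.long())
    out = torch.zeros((len(counts),) + x.shape[1:], dtype=x.dtype, device=x.device)
    return out.index_add_(0, group_ids, x)  # per-group sums; works for 1D weight vectors and 2D embeddings alike

def means_over_rows_sketch(x: torch.Tensor, counts: torch.IntTensor) -> torch.Tensor:
    # keepdim handling omitted; denominator reshaped to broadcast for either 1D or 2D input
    denom = counts.to(x.dtype).reshape((-1,) + (1,) * (x.dim() - 1))
    return sums_over_rows_sketch(x, counts) / denom

With the per-variant weight sums computed this way, alt_means_ve above is a weighted mean of the transformed alt read embeddings within each variant.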

@@ -371,8 +365,10 @@ def loss_function(self, base_model: BaseModel, base_batch: BaseBatch, base_model
alt_vre = torch.cat((alt_representations_vre, random_alt_seeds_vre), dim=-1)
ref_vre = torch.cat((ref_representations_vre, random_ref_seeds_vre), dim=-1) if ref_count > 0 else None

decoded_alt_vre = self.alt_decoder(alt_vre)
decoded_ref_vre = self.ref_decoder(ref_vre) if ref_count > 0 else None
# TODO: update these to reflect mixed-count batches. Gated MLPs now take inputs flattened over batch dimension
# TODO: and have an extra input of ref and alt read counts
decoded_alt_vre = self.alt_decoder.forward(alt_vre)
decoded_ref_vre = self.ref_decoder.forward(ref_vre) if ref_count > 0 else None

decoded_alt_re = torch.reshape(decoded_alt_vre, (var_count * alt_count, -1))
decoded_ref_re = torch.reshape(decoded_ref_vre, (var_count * ref_count, -1)) if ref_count > 0 else None
@@ -506,11 +502,10 @@ def learn_base_model(base_model: BaseModel, dataset: BaseDataset, learning_metho
.to(device=base_model._device, dtype=base_model._dtype)
classifier_bce = torch.nn.BCEWithLogitsLoss(reduction='none')

# TODO: fused = is_cuda?
classifier_optimizer = torch.optim.AdamW(classifier_on_top.parameters(),
lr=training_params.learning_rate,
weight_decay=training_params.weight_decay,
fused=True)
fused=is_cuda)
classifier_metrics = LossMetrics()

validation_fold_to_use = (dataset.num_folds - 1) if validation_fold is None else validation_fold
97 changes: 54 additions & 43 deletions permutect/architecture/gated_mlp.py
@@ -23,6 +23,8 @@
import torch
from torch import nn

from permutect import utils


class GatedMLPBlock(nn.Module):
"""
@@ -66,19 +68,22 @@ def __init__(self, d_model: int, d_ffn: int):
# *gMLP* block as a replacement for the [Transformer Layer](../models.html#Encoder).
self.size = d_model

def forward(self, x_bre: torch.Tensor):
# X is 2D; counts gives the number of rows in each consecutive group of rows that forms a self-attention group
# that is, if X has 10 rows and counts = [2,3,5], rows 0-1, 2-4, and 5-9 form independent self-attention groups
# In other words, all the reads of a batch are flattened together in X -- the batch information is in counts
def forward(self, x_re: torch.Tensor, counts: torch.IntTensor):
"""
* `x_re` is the flattened read embedding tensor of shape (total reads in batch) x Embedding; `counts` holds the reads per self-attention group
"""
# Norm, projection to d_ffn, and activation $Z = \sigma(XU)$
z_brd = self.activation(self.proj1(self.norm(x_bre)))
z_rd = self.activation(self.proj1(self.norm(x_re)))
# Spacial Gating Unit $\tilde{Z} = s(Z)$
gated_brd = self.sgu(z_brd)
gated_rd = self.sgu.forward(z_rd, counts)
# Final projection $Y = \tilde{Z}V$ back to embedding dimension
gated_bre = self.proj2(gated_brd)
gated_re = self.proj2(gated_rd)

# Add the shortcut connection
return x_bre + gated_bre
return x_re + gated_re


class SpacialGatingUnit(nn.Module):
@@ -105,24 +110,23 @@ def __init__(self, d_z: int):
# Normalization layer before applying $f_{W,b}(\cdot)$
self.norm = nn.LayerNorm([d_z // 2])
# Weight $W$ in $f_{W,b}(\cdot)$.
#

# TODO: shouldn't alpha and beta be element-by-element???
self.alpha = nn.Parameter(torch.tensor(0.01))
self.beta = nn.Parameter(torch.tensor(0.01))

def forward(self, z_brd: torch.Tensor):
"""
* `z_brd` is the input tensor of shape Batch x Reads x Dimension
`[seq_len, batch_size, d_z]`
"""

# Z is 2D; counts gives the number of rows in each consecutive group of rows that forms a self-attention group
# that is, if Z has 10 rows and counts = [2,3,5], rows 0-1, 2-4, and 5-9 form independent self-attention groups
def forward(self, z_rd: torch.Tensor, counts: torch.IntTensor):
# Split $Z$ into $Z_1$ and $Z_2$ over the hidden dimension and normalize $Z_2$ before $f_{W,b}(\cdot)$
z1_brd, z2_brd = torch.chunk(z_brd, 2, dim=-1)
z2_brd = self.norm(z2_brd)
z1_rd, z2_rd = torch.chunk(z_rd, 2, dim=-1)
z2_rd = self.norm(z2_rd)

z2_brd = 1 + self.alpha * z2_brd + torch.mean(z2_brd, dim=1, keepdim=True)
# TODO: self.beta needs to multiply the mean field here!!!
z2_rd = 1 + self.alpha * z2_rd + utils.means_over_rows(z2_rd, counts, keepdim=True)

# $Z_1 \odot f_{W,b}(Z_2)$
return z1_brd * z2_brd
return z1_rd * z2_rd
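A hedged reading of the new per-group code path above: for read $r$ in self-attention group $g(r)$, the gate is $f_{W,b}(Z_2)_r = 1 + \alpha Z_{2,r} + \operatorname{mean}_{r' \in g(r)} Z_{2,r'}$ and the unit returns $Z_1 \odot f_{W,b}(Z_2)$; per the TODO, $\beta$ is intended to scale the mean-field term but is not applied yet.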


class GatedMLP(nn.Module):
@@ -131,9 +135,11 @@ def __init__(self, d_model: int, d_ffn: int, num_blocks: int):

self.blocks = nn.ModuleList([GatedMLPBlock(d_model, d_ffn) for _ in range(num_blocks)])

def forward(self, x):
# X is 2D; counts gives the number of rows in each consecutive group of rows that forms a self-attention group
# that is, if X has 10 rows and counts = [2,3,5], rows 0-1, 2-4, and 5-9 form independent self-attention groups
def forward(self, x, counts):
for block in self.blocks:
x = block(x)
x = block.forward(x, counts)
return x
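A minimal usage sketch of the flattened-reads interface, reusing the illustrative counts = [2,3,5] from the comments above (dimensions are arbitrary; assumes the classes in this file and permutect.utils are importable):

import torch

gmlp = GatedMLP(d_model=16, d_ffn=32, num_blocks=2)
x_re = torch.randn(10, 16)                           # 2 + 3 + 5 = 10 reads flattened over the batch
counts = torch.tensor([2, 3, 5], dtype=torch.int32)  # rows 0-1, 2-4, 5-9 are independent self-attention groups
out_re = gmlp.forward(x_re, counts)                  # output keeps the same flattened shape as x_re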


@@ -166,22 +172,22 @@ def __init__(self, d_model: int, d_ffn: int):
# *gMLP* block as a replacement for the [Transformer Layer](../models.html#Encoder).
self.size = d_model

def forward(self, ref_bre: torch.Tensor, alt_bre: torch.Tensor):
def forward(self, ref_re: torch.Tensor, alt_re: torch.Tensor, ref_counts: torch.IntTensor, alt_counts: torch.IntTensor):
"""
* `ref_re` and `alt_re` are flattened read embedding tensors of shape (total reads) x Embedding; `ref_counts` and `alt_counts` hold the reads per variant
"""
# Norm, projection to d_ffn, and activation $Z = \sigma(XU)$
zref_brd = self.activation(self.proj1_ref(self.norm(ref_bre)))
zalt_brd = self.activation(self.proj1_alt(self.norm(alt_bre)))
zref_rd = self.activation(self.proj1_ref(self.norm(ref_re)))
zalt_rd = self.activation(self.proj1_alt(self.norm(alt_re)))

# Spacial Gating Unit $\tilde{Z} = s(Z)$
gated_ref_brd, gated_alt_brd = self.sgu(zref_brd, zalt_brd)
gated_ref_rd, gated_alt_rd = self.sgu.forward(zref_rd, zalt_rd, ref_counts, alt_counts)
# Final projection $Y = \tilde{Z}V$ back to embedding dimension
gated_ref_bre = self.proj2_ref(gated_ref_brd)
gated_alt_bre = self.proj2_alt(gated_alt_brd)
gated_ref_re = self.proj2_ref(gated_ref_rd)
gated_alt_re = self.proj2_alt(gated_alt_rd)

# Add the shortcut connection
return ref_bre + gated_ref_bre, alt_bre + gated_alt_bre
return ref_re + gated_ref_re, alt_re + gated_alt_re


class SpacialGatingUnitRefAlt(nn.Module):
@@ -196,36 +202,41 @@ def __init__(self, d_z: int):
# Normalization layer before applying $f_{W,b}(\cdot)$
self.norm = nn.LayerNorm([d_z // 2])
# Weight $W$ in $f_{W,b}(\cdot)$.
#

# TODO: maybe let these parameters be element-by-element vectors?
self.alpha_ref = nn.Parameter(torch.tensor(0.01))
self.alpha_alt = nn.Parameter(torch.tensor(0.01))
self.beta_ref = nn.Parameter(torch.tensor(0.01))
self.beta_alt = nn.Parameter(torch.tensor(0.01))

self.gamma = nn.Parameter(torch.tensor(0.01))

def forward(self, zref_brd: torch.Tensor, zalt_brd: torch.Tensor):
"""
* `z_brd` is the input tensor of shape Batch x Reads x Dimension
`[seq_len, batch_size, d_z]`
"""
# regularizer / sort of imputed value for when there are no ref counts
self.ref_regularizer = nn.Parameter(0.1 * torch.ones(d_z // 2))
self.regularizer_weight = nn.Parameter(torch.tensor(0.1))

def forward(self, zref_rd: torch.Tensor, zalt_rd: torch.Tensor, ref_counts: torch.IntTensor, alt_counts: torch.IntTensor):

# Split $Z$ into $Z_1$ and $Z_2$ over the hidden dimension and normalize $Z_2$ before $f_{W,b}(\cdot)$
z1_ref_brd, z2_ref_brd = torch.chunk(zref_brd, 2, dim=-1)
z1_alt_brd, z2_alt_brd = torch.chunk(zalt_brd, 2, dim=-1)
z2_ref_brd = self.norm(z2_ref_brd)
z2_alt_brd = self.norm(z2_alt_brd)
z1_ref_rd, z2_ref_rd = torch.chunk(zref_rd, 2, dim=-1)
z1_alt_rd, z2_alt_rd = torch.chunk(zalt_rd, 2, dim=-1)
z2_ref_rd = self.norm(z2_ref_rd)
z2_alt_rd = self.norm(z2_alt_rd)

# these are means by variant -- need repeat_interleave to make them by-read
ref_mean_field_vd = utils.means_over_rows_with_regularizer(z2_ref_rd, ref_counts, self.ref_regularizer, self.regularizer_weight)
alt_mean_field_vd = utils.means_over_rows(z2_alt_rd, alt_counts)

ref_mean_field_brd = torch.mean(z2_ref_brd, dim=1, keepdim=True)
alt_mean_field_brd = torch.mean(z2_alt_brd, dim=1, keepdim=True)
ref_mean_field_on_ref_rd = torch.repeat_interleave(ref_mean_field_vd, dim=0, repeats=ref_counts)
ref_mean_field_on_alt_rd = torch.repeat_interleave(ref_mean_field_vd, dim=0, repeats=alt_counts)
alt_mean_field_on_alt_rd = torch.repeat_interleave(alt_mean_field_vd, dim=0, repeats=alt_counts)

# same as above except now there is an additional term for the ref mean field influence on alt
# maybe later also let alt mean field influence ref
z2_ref_brd = 1 + self.alpha_ref * z2_ref_brd + self.beta_ref * ref_mean_field_brd
z2_alt_brd = 1 + self.alpha_alt * z2_alt_brd + self.beta_alt * alt_mean_field_brd + self.gamma * ref_mean_field_brd
z2_ref_rd = 1 + self.alpha_ref * z2_ref_rd + self.beta_ref * ref_mean_field_on_ref_rd
z2_alt_rd = 1 + self.alpha_alt * z2_alt_rd + self.beta_alt * alt_mean_field_on_alt_rd + self.gamma * ref_mean_field_on_alt_rd

# $Z_1 \odot f_{W,b}(Z_2)$
return z1_ref_brd * z2_ref_brd, z1_alt_brd * z2_alt_brd
return z1_ref_rd * z2_ref_rd, z1_alt_rd * z2_alt_rd
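means_over_rows_with_regularizer is likewise not shown in this diff; one plausible reading that matches the "imputed value for when there are no ref counts" comment above -- an assumption, not the actual permutect.utils code:

import torch

def means_over_rows_with_regularizer_sketch(x_rd: torch.Tensor, counts: torch.IntTensor,
                                            regularizer_d: torch.Tensor, weight: torch.Tensor) -> torch.Tensor:
    # blend each variant's sum of rows with a learned imputed vector; a variant with
    # zero ref reads then yields regularizer_d rather than a division by zero
    groups = torch.split(x_rd, counts.tolist())
    means_vd = [(g.sum(dim=0) + weight * regularizer_d) / (len(g) + weight) for g in groups]
    return torch.stack(means_vd)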


class GatedRefAltMLP(nn.Module):
@@ -234,7 +245,7 @@ def __init__(self, d_model: int, d_ffn: int, num_blocks: int):

self.blocks = nn.ModuleList([GatedRefAltMLPBlock(d_model, d_ffn) for _ in range(num_blocks)])

def forward(self, ref, alt):
def forward(self, ref, alt, ref_counts, alt_counts):
for block in self.blocks:
ref, alt = block(ref, alt)
ref, alt = block(ref, alt, ref_counts, alt_counts)
return ref, alt
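A corresponding usage sketch for the ref/alt encoder with flattened reads (counts and dimensions are illustrative; per the regularizer comment above, a variant with zero ref reads is meant to be handled gracefully):

import torch

encoder = GatedRefAltMLP(d_model=16, d_ffn=32, num_blocks=2)
ref_counts = torch.tensor([4, 0, 3], dtype=torch.int32)   # one variant has no ref reads
alt_counts = torch.tensor([2, 3, 5], dtype=torch.int32)
ref_re = torch.randn(int(ref_counts.sum()), 16)           # 7 ref reads flattened
alt_re = torch.randn(int(alt_counts.sum()), 16)           # 10 alt reads flattened
ref_out_re, alt_out_re = encoder.forward(ref_re, alt_re, ref_counts, alt_counts)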
24 changes: 8 additions & 16 deletions permutect/data/base_dataset.py
@@ -63,10 +63,8 @@ def __init__(self, data_in_ram: Iterable[BaseDatum] = None, data_tarfile=None, n
self._data = RaggedMmap(self._memory_map_dir.name)
self._memory_map_mode = True

# keys = (ref read count, alt read count) tuples; values = list of indices
# this is used in the batch sampler to make same-shape batches
self.labeled_indices_by_count = [defaultdict(list) for _ in range(num_folds)]
self.unlabeled_indices_by_count = [defaultdict(list) for _ in range(num_folds)]
self.indices_by_fold = [[] for _ in range(num_folds)]

# totals by count, then by label -- ARTIFACT, VARIANT, UNLABELED, then by variant type
# variant type is done as a 1D np array parallel to the one-hot encoding of variant type
@@ -87,7 +85,7 @@ def __init__(self, data_in_ram: Iterable[BaseDatum] = None, data_tarfile=None, n

fold = n % num_folds
counts = (len(datum.reads_2d) - datum.alt_count, datum.alt_count)
(self.unlabeled_indices_by_count if datum.label == Label.UNLABELED else self.labeled_indices_by_count)[fold][counts].append(n)
self.indices_by_fold[fold].append(n)

one_hot = datum.variant_type_one_hot()
self.totals[ALL_COUNTS_SENTINEL][datum.label] += one_hot
@@ -191,30 +189,24 @@ def chunk(lis, chunk_size):
return [lis[i:i + chunk_size] for i in range(0, len(lis), chunk_size)]


# make batches that have a single value for ref, alt counts within batches. Labeled and unlabeled data are mixed.
# Labeled and unlabeled data are mixed.
# the artifact model handles weighting the losses to compensate for class imbalance between supervised and unsupervised
# thus the sampler is not responsible for balancing the data
class SemiSupervisedBatchSampler(Sampler):
def __init__(self, dataset: BaseDataset, batch_size, folds_to_use: List[int]):
# combine the index maps of all relevant folds
self.indices_by_count = defaultdict(list)
self.indices_to_use = []

for fold in folds_to_use:
new_labeled = dataset.labeled_indices_by_count[fold]
new_unlabeled = dataset.unlabeled_indices_by_count[fold]
for count, indices in new_labeled.items():
self.indices_by_count[count].extend(indices)
for count, indices in new_unlabeled.items():
self.indices_by_count[count].extend(indices)
self.indices_to_use.extend(dataset.indices_by_fold[fold])

self.batch_size = batch_size
self.num_batches = sum(math.ceil(len(indices) // self.batch_size) for indices in self.indices_by_count.values())
self.num_batches = math.ceil(len(self.indices_to_use) / self.batch_size)

def __iter__(self):
batches = [] # list of lists of indices -- each sublist is a batch
for index_list in self.indices_by_count.values():
random.shuffle(index_list)
batches.extend(chunk(index_list, self.batch_size))
random.shuffle(self.indices_to_use)
batches.extend(chunk(self.indices_to_use, self.batch_size))
random.shuffle(batches)

return iter(batches)
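With mixed-count batches the sampler no longer groups indices by (ref count, alt count), so it plugs into a DataLoader through batch_sampler in the usual way. A hedged usage sketch (batch size and folds are illustrative; a collate_fn that packs the sampled BaseDatum objects into a BaseBatch would be supplied by the caller and is omitted here):

from torch.utils.data import DataLoader

sampler = SemiSupervisedBatchSampler(dataset, batch_size=64, folds_to_use=[0, 1])
loader = DataLoader(dataset, batch_sampler=sampler)  # a collate_fn building BaseBatch objects would be passed here (omitted)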
