diff --git a/permutect/architecture/base_model.py b/permutect/architecture/base_model.py index a5d20ace..b228f0f7 100644 --- a/permutect/architecture/base_model.py +++ b/permutect/architecture/base_model.py @@ -119,7 +119,6 @@ def __init__(self, params: BaseModelParameters, num_read_features: int, num_info self._dtype = DEFAULT_GPU_FLOAT if device != torch.device("cpu") else DEFAULT_CPU_FLOAT self._ref_sequence_length = ref_sequence_length self._params = params - self.alt_downsample = params.alt_downsample # embeddings of reads, info, and reference sequence prior to the transformer layers self.read_embedding = MLP([num_read_features] + params.read_layers, batch_normalize=params.batch_normalize, dropout_p=params.dropout_p) @@ -130,7 +129,6 @@ def __init__(self, params: BaseModelParameters, num_read_features: int, num_info assert embedding_dim % params.num_transformer_heads == 0 self.ref_alt_reads_encoder = make_gated_ref_alt_mlp_encoder(embedding_dim, params) - self.alt_encoder = make_gated_mlp_encoder(embedding_dim, params) # after encoding alt reads (along with info and ref seq embeddings and with self-attention to ref reads) # pass through another MLP @@ -163,35 +161,31 @@ def forward(self, batch: BaseBatch): # so, for example, "re" means a 2D tensor with all reads in the batch stacked and "vre" means a 3D tensor indexed # first by variant within the batch, then the read def calculate_representations(self, batch: BaseBatch, weight_range: float = 0) -> torch.Tensor: - ref_count, alt_count = batch.ref_count, batch.alt_count - total_ref, total_alt = ref_count * batch.size(), alt_count * batch.size() + ref_counts, alt_counts = batch.ref_counts, batch.alt_counts + total_ref, total_alt = torch.sum(ref_counts).item(), torch.sum(alt_counts).item() read_embeddings_re = self.read_embedding.forward(batch.get_reads_2d().to(dtype=self._dtype)) info_embeddings_ve = self.info_embedding.forward(batch.get_info_2d().to(dtype=self._dtype)) ref_seq_embeddings_ve = self.ref_seq_cnn(batch.get_ref_sequences_2d().to(dtype=self._dtype)) info_and_seq_ve = torch.hstack((info_embeddings_ve, ref_seq_embeddings_ve)) - info_and_seq_re = torch.vstack((torch.repeat_interleave(info_and_seq_ve, ref_count, dim=0), - torch.repeat_interleave(info_and_seq_ve, alt_count, dim=0))) + info_and_seq_re = torch.vstack((torch.repeat_interleave(info_and_seq_ve, repeats=ref_counts, dim=0), + torch.repeat_interleave(info_and_seq_ve, repeats=alt_counts, dim=0))) reads_info_seq_re = torch.hstack((read_embeddings_re, info_and_seq_re)) - ref_reads_info_seq_vre = None if total_ref == 0 else reads_info_seq_re[:total_ref].reshape(batch.size(), ref_count, -1) - alt_reads_info_seq_vre = reads_info_seq_re[total_ref:].reshape(batch.size(), alt_count, -1) - if self.alt_downsample < alt_count: - alt_read_indices = torch.randperm(alt_count)[:self.alt_downsample] - alt_reads_info_seq_vre = alt_reads_info_seq_vre[:, alt_read_indices, :] # downsample only along the middle (read) dimension - alt_count = self.alt_downsample - total_alt = batch.size() * self.alt_downsample + # TODO: might be a bug if every datum in batch has zero ref reads? + ref_reads_info_seq_re = reads_info_seq_re[:total_ref] + alt_reads_info_seq_re = reads_info_seq_re[total_ref:] - # undo some of the above rearrangement + # TODO: make sure it handles ref count = 0 case + transformed_ref_re, transformed_alt_re = self.ref_alt_reads_encoder.forward(ref_reads_info_seq_re, alt_reads_info_seq_re, ref_counts, alt_counts) - transformed_ref_vre, transformed_alt_vre = (None, self.alt_encoder(alt_reads_info_seq_vre)) if total_ref == 0 else \ - self.ref_alt_reads_encoder(ref_reads_info_seq_vre, alt_reads_info_seq_vre) + alt_weights_r = 1 + weight_range * (1 - 2 * torch.rand(total_alt, device=self._device, dtype=self._dtype)) - alt_weights_vr = 1 + weight_range * (1 - 2 * torch.rand(batch.size(), alt_count, device=self._device, dtype=self._dtype)) - alt_wt_sums = torch.sum(alt_weights_vr, dim=1, keepdim=True) - # normalized so read weights within each variant sum to 1 and add dummy e dimension for broadcasting the multiply below - normalized_alt_weights_vr1 = (alt_weights_vr / alt_wt_sums).reshape(batch.size(), alt_count, 1) - alt_means_ve = torch.sum(transformed_alt_vre * normalized_alt_weights_vr1, dim=1) + # normalize so read weights within each variant sum to 1 + alt_wt_sums_v = utils.sums_over_rows(alt_weights_r, alt_counts) + normalized_alt_weights_r = alt_weights_r / torch.repeat_interleave(alt_wt_sums_v, repeats=alt_counts, dim=0) + + alt_means_ve = utils.sums_over_rows(transformed_alt_re * normalized_alt_weights_r[:,None], alt_counts) result_ve = self.aggregation.forward(alt_means_ve) @@ -371,8 +365,10 @@ def loss_function(self, base_model: BaseModel, base_batch: BaseBatch, base_model alt_vre = torch.cat((alt_representations_vre, random_alt_seeds_vre), dim=-1) ref_vre = torch.cat((ref_representations_vre, random_ref_seeds_vre), dim=-1) if ref_count > 0 else None - decoded_alt_vre = self.alt_decoder(alt_vre) - decoded_ref_vre = self.ref_decoder(ref_vre) if ref_count > 0 else None + # TODO: update these to reflect mixed-count batches. Gated MLPs now take inputs flattened over batch dimension + # TODO: and have an extra input of ref and alt read counts + decoded_alt_vre = self.alt_decoder.forward(alt_vre) + decoded_ref_vre = self.ref_decoder.forward(ref_vre) if ref_count > 0 else None decoded_alt_re = torch.reshape(decoded_alt_vre, (var_count * alt_count, -1)) decoded_ref_re = torch.reshape(decoded_ref_vre, (var_count * ref_count, -1)) if ref_count > 0 else None @@ -506,11 +502,10 @@ def learn_base_model(base_model: BaseModel, dataset: BaseDataset, learning_metho .to(device=base_model._device, dtype=base_model._dtype) classifier_bce = torch.nn.BCEWithLogitsLoss(reduction='none') - # TODO: fused = is_cuda? classifier_optimizer = torch.optim.AdamW(classifier_on_top.parameters(), lr=training_params.learning_rate, weight_decay=training_params.weight_decay, - fused=True) + fused=is_cuda) classifier_metrics = LossMetrics() validation_fold_to_use = (dataset.num_folds - 1) if validation_fold is None else validation_fold diff --git a/permutect/architecture/gated_mlp.py b/permutect/architecture/gated_mlp.py index 15fba26a..e862c569 100644 --- a/permutect/architecture/gated_mlp.py +++ b/permutect/architecture/gated_mlp.py @@ -23,6 +23,8 @@ import torch from torch import nn +from permutect import utils + class GatedMLPBlock(nn.Module): """ @@ -66,19 +68,22 @@ def __init__(self, d_model: int, d_ffn: int): # *gMLP* block as a replacement for the [Transformer Layer](../models.html#Encoder). self.size = d_model - def forward(self, x_bre: torch.Tensor): + # X is 2D, counts are the numbers of elements in each consecutive group of rows that form a self-attention group + # that is, is X has 10 rows and counts = [2,3,5], elements 0-1, 2-4, and 5-9 form independent self-attention groups + # In other words, all the reads of a batch are flattened together in X -- the batch information is in counts + def forward(self, x_re: torch.Tensor, counts: torch.IntTensor): """ * `x_bre` is the input read embedding tensor of shape Batch x Reads x Embedding """ # Norm, projection to d_ffn, and activation $Z = \sigma(XU)$ - z_brd = self.activation(self.proj1(self.norm(x_bre))) + z_rd = self.activation(self.proj1(self.norm(x_re))) # Spacial Gating Unit $\tilde{Z} = s(Z)$ - gated_brd = self.sgu(z_brd) + gated_rd = self.sgu.forward(z_rd, counts) # Final projection $Y = \tilde{Z}V$ back to embedding dimension - gated_bre = self.proj2(gated_brd) + gated_re = self.proj2(gated_rd) # Add the shortcut connection - return x_bre + gated_bre + return x_re + gated_re class SpacialGatingUnit(nn.Module): @@ -105,24 +110,23 @@ def __init__(self, d_z: int): # Normalization layer before applying $f_{W,b}(\cdot)$ self.norm = nn.LayerNorm([d_z // 2]) # Weight $W$ in $f_{W,b}(\cdot)$. - # + + # TODO: shouldn't alpha and beta be element-by-element??? self.alpha = nn.Parameter(torch.tensor(0.01)) self.beta = nn.Parameter(torch.tensor(0.01)) - def forward(self, z_brd: torch.Tensor): - """ - * `z_brd` is the input tensor of shape Batch x Reads x Dimension - `[seq_len, batch_size, d_z]` - """ - + # Z is 2D, counts are the numbers of elements in each consecutive group of rows that form a self-attention group + # that is, is X has 10 rows and counts = [2,3,5], elements 0-1, 2-4, and 5-9 form independent self-attention groups + def forward(self, z_rd: torch.Tensor, counts: torch.IntTensor): # Split $Z$ into $Z_1$ and $Z_2$ over the hidden dimension and normalize $Z_2$ before $f_{W,b}(\cdot)$ - z1_brd, z2_brd = torch.chunk(z_brd, 2, dim=-1) - z2_brd = self.norm(z2_brd) + z1_rd, z2_rd = torch.chunk(z_rd, 2, dim=-1) + z2_rd = self.norm(z2_rd) - z2_brd = 1 + self.alpha * z2_brd + torch.mean(z2_brd, dim=1, keepdim=True) + # TODO: self.beta needs to multiply the mean field here!!! + z2_rd = 1 + self.alpha * z2_rd + utils.means_over_rows(z2_rd, counts, keepdim=True) # $Z_1 \odot f_{W,b}(Z_2)$ - return z1_brd * z2_brd + return z1_rd * z2_rd class GatedMLP(nn.Module): @@ -131,9 +135,11 @@ def __init__(self, d_model: int, d_ffn: int, num_blocks: int): self.blocks = nn.ModuleList([GatedMLPBlock(d_model, d_ffn) for _ in range(num_blocks)]) - def forward(self, x): + # X is 2D, counts are the numbers of elements in each consecutive group of rows that form a self-attention group + # that is, is X has 10 rows and counts = [2,3,5], elements 0-1, 2-4, and 5-9 form independent self-attention groups + def forward(self, x, counts): for block in self.blocks: - x = block(x) + x = block.forward(x, counts) return x @@ -166,22 +172,22 @@ def __init__(self, d_model: int, d_ffn: int): # *gMLP* block as a replacement for the [Transformer Layer](../models.html#Encoder). self.size = d_model - def forward(self, ref_bre: torch.Tensor, alt_bre: torch.Tensor): + def forward(self, ref_re: torch.Tensor, alt_re: torch.Tensor, ref_counts: torch.IntTensor, alt_counts: torch.IntTensor): """ * `x_bre` is the input read embedding tensor of shape Batch x Reads x Embedding """ # Norm, projection to d_ffn, and activation $Z = \sigma(XU)$ - zref_brd = self.activation(self.proj1_ref(self.norm(ref_bre))) - zalt_brd = self.activation(self.proj1_alt(self.norm(alt_bre))) + zref_rd = self.activation(self.proj1_ref(self.norm(ref_re))) + zalt_rd = self.activation(self.proj1_alt(self.norm(alt_re))) # Spacial Gating Unit $\tilde{Z} = s(Z)$ - gated_ref_brd, gated_alt_brd = self.sgu(zref_brd, zalt_brd) + gated_ref_rd, gated_alt_rd = self.sgu.forward(zref_rd, zalt_rd, ref_counts, alt_counts) # Final projection $Y = \tilde{Z}V$ back to embedding dimension - gated_ref_bre = self.proj2_ref(gated_ref_brd) - gated_alt_bre = self.proj2_alt(gated_alt_brd) + gated_ref_re = self.proj2_ref(gated_ref_rd) + gated_alt_re = self.proj2_alt(gated_alt_rd) # Add the shortcut connection - return ref_bre + gated_ref_bre, alt_bre + gated_alt_bre + return ref_re + gated_ref_re, alt_re + gated_alt_re class SpacialGatingUnitRefAlt(nn.Module): @@ -196,36 +202,41 @@ def __init__(self, d_z: int): # Normalization layer before applying $f_{W,b}(\cdot)$ self.norm = nn.LayerNorm([d_z // 2]) # Weight $W$ in $f_{W,b}(\cdot)$. - # + + # TODO: maybe let these parameters be element-by-element vectors? self.alpha_ref = nn.Parameter(torch.tensor(0.01)) self.alpha_alt = nn.Parameter(torch.tensor(0.01)) self.beta_ref = nn.Parameter(torch.tensor(0.01)) self.beta_alt = nn.Parameter(torch.tensor(0.01)) - self.gamma = nn.Parameter(torch.tensor(0.01)) - def forward(self, zref_brd: torch.Tensor, zalt_brd: torch.Tensor): - """ - * `z_brd` is the input tensor of shape Batch x Reads x Dimension - `[seq_len, batch_size, d_z]` - """ + # regularizer / sort of imputed value for when there are no ref counts + self.ref_regularizer = nn.Parameter(0.1 * torch.ones(d_z // 2)) + self.regularizer_weight = nn.Parameter(torch.tensor(0.1)) + + def forward(self, zref_rd: torch.Tensor, zalt_rd: torch.Tensor, ref_counts: torch.IntTensor, alt_counts: torch.IntTensor): # Split $Z$ into $Z_1$ and $Z_2$ over the hidden dimension and normalize $Z_2$ before $f_{W,b}(\cdot)$ - z1_ref_brd, z2_ref_brd = torch.chunk(zref_brd, 2, dim=-1) - z1_alt_brd, z2_alt_brd = torch.chunk(zalt_brd, 2, dim=-1) - z2_ref_brd = self.norm(z2_ref_brd) - z2_alt_brd = self.norm(z2_alt_brd) + z1_ref_rd, z2_ref_rd = torch.chunk(zref_rd, 2, dim=-1) + z1_alt_rd, z2_alt_rd = torch.chunk(zalt_rd, 2, dim=-1) + z2_ref_rd = self.norm(z2_ref_rd) + z2_alt_rd = self.norm(z2_alt_rd) + + # these are means by variant -- need repeat_interleave to make them by-read + ref_mean_field_vd = utils.means_over_rows_with_regularizer(z2_ref_rd, ref_counts, self.ref_regularizer, self.regularizer_weight) + alt_mean_field_vd = utils.means_over_rows(z2_alt_rd, alt_counts) - ref_mean_field_brd = torch.mean(z2_ref_brd, dim=1, keepdim=True) - alt_mean_field_brd = torch.mean(z2_alt_brd, dim=1, keepdim=True) + ref_mean_field_on_ref_rd = torch.repeat_interleave(ref_mean_field_vd, dim=0, repeats=ref_counts) + ref_mean_field_on_alt_rd = torch.repeat_interleave(ref_mean_field_vd, dim=0, repeats=alt_counts) + alt_mean_field_on_alt_rd = torch.repeat_interleave(alt_mean_field_vd, dim=0, repeats=alt_counts) # same as above except now there is an additional term for the ref mean field influence on alt # maybe later also let alt mean field influence ref - z2_ref_brd = 1 + self.alpha_ref * z2_ref_brd + self.beta_ref * ref_mean_field_brd - z2_alt_brd = 1 + self.alpha_alt * z2_alt_brd + self.beta_alt * alt_mean_field_brd + self.gamma * ref_mean_field_brd + z2_ref_rd = 1 + self.alpha_ref * z2_ref_rd + self.beta_ref * ref_mean_field_on_ref_rd + z2_alt_rd = 1 + self.alpha_alt * z2_alt_rd + self.beta_alt * alt_mean_field_on_alt_rd + self.gamma * ref_mean_field_on_alt_rd # $Z_1 \odot f_{W,b}(Z_2)$ - return z1_ref_brd * z2_ref_brd, z1_alt_brd * z2_alt_brd + return z1_ref_rd * z2_ref_rd, z1_alt_rd * z2_alt_rd class GatedRefAltMLP(nn.Module): @@ -234,7 +245,7 @@ def __init__(self, d_model: int, d_ffn: int, num_blocks: int): self.blocks = nn.ModuleList([GatedRefAltMLPBlock(d_model, d_ffn) for _ in range(num_blocks)]) - def forward(self, ref, alt): + def forward(self, ref, alt, ref_counts, alt_counts): for block in self.blocks: - ref, alt = block(ref, alt) + ref, alt = block(ref, alt, ref_counts, alt_counts) return ref, alt diff --git a/permutect/data/base_dataset.py b/permutect/data/base_dataset.py index 1e6da29c..299946e6 100644 --- a/permutect/data/base_dataset.py +++ b/permutect/data/base_dataset.py @@ -63,10 +63,8 @@ def __init__(self, data_in_ram: Iterable[BaseDatum] = None, data_tarfile=None, n self._data = RaggedMmap(self._memory_map_dir.name) self._memory_map_mode = True - # keys = (ref read count, alt read count) tuples; values = list of indices # this is used in the batch sampler to make same-shape batches - self.labeled_indices_by_count = [defaultdict(list) for _ in range(num_folds)] - self.unlabeled_indices_by_count = [defaultdict(list) for _ in range(num_folds)] + self.indices_by_fold = [[] for _ in range(num_folds)] # totals by count, then by label -- ARTIFACT, VARIANT, UNLABELED, then by variant type # variant type is done as a 1D np array parallel to the one-hot encoding of variant type @@ -87,7 +85,7 @@ def __init__(self, data_in_ram: Iterable[BaseDatum] = None, data_tarfile=None, n fold = n % num_folds counts = (len(datum.reads_2d) - datum.alt_count, datum.alt_count) - (self.unlabeled_indices_by_count if datum.label == Label.UNLABELED else self.labeled_indices_by_count)[fold][counts].append(n) + self.indices_by_fold[fold].append(n) one_hot = datum.variant_type_one_hot() self.totals[ALL_COUNTS_SENTINEL][datum.label] += one_hot @@ -191,30 +189,24 @@ def chunk(lis, chunk_size): return [lis[i:i + chunk_size] for i in range(0, len(lis), chunk_size)] -# make batches that have a single value for ref, alt counts within batches. Labeled and unlabeled data are mixed. +# Labeled and unlabeled data are mixed. # the artifact model handles weighting the losses to compensate for class imbalance between supervised and unsupervised # thus the sampler is not responsible for balancing the data class SemiSupervisedBatchSampler(Sampler): def __init__(self, dataset: BaseDataset, batch_size, folds_to_use: List[int]): # combine the index maps of all relevant folds - self.indices_by_count = defaultdict(list) + self.indices_to_use = [] for fold in folds_to_use: - new_labeled = dataset.labeled_indices_by_count[fold] - new_unlabeled = dataset.unlabeled_indices_by_count[fold] - for count, indices in new_labeled.items(): - self.indices_by_count[count].extend(indices) - for count, indices in new_unlabeled.items(): - self.indices_by_count[count].extend(indices) + self.indices_to_use.extend(dataset.indices_by_fold[fold]) self.batch_size = batch_size - self.num_batches = sum(math.ceil(len(indices) // self.batch_size) for indices in self.indices_by_count.values()) + self.num_batches = math.ceil(len(self.indices_to_use) // self.batch_size) def __iter__(self): batches = [] # list of lists of indices -- each sublist is a batch - for index_list in self.indices_by_count.values(): - random.shuffle(index_list) - batches.extend(chunk(index_list, self.batch_size)) + random.shuffle(self.indices_to_use) + batches.extend(chunk(self.indices_to_use, self.batch_size)) random.shuffle(batches) return iter(batches) diff --git a/permutect/data/base_datum.py b/permutect/data/base_datum.py index 05d5a0fc..5e1396b1 100644 --- a/permutect/data/base_datum.py +++ b/permutect/data/base_datum.py @@ -550,7 +550,8 @@ def __init__(self, data: List[BaseDatum]): self._original_list = data self.ref_count = len(data[0].reads_2d) - data[0].alt_count self.alt_count = data[0].alt_count - self.alt_counts = IntTensor([data[0].alt_count for _ in data]) + self.alt_counts = IntTensor([datum.alt_count for datum in data]) + self.ref_counts = IntTensor([len(datum.reads_2d) - datum.alt_count for datum in data]) # for datum in data: # assert (datum.label() != Label.UNLABELED) == self.labeled, "Batch may not mix labeled and unlabeled" @@ -582,6 +583,8 @@ def __init__(self, data: List[BaseDatum]): def pin_memory(self): self.ref_sequences_2d = self.ref_sequences_2d.pin_memory() self.reads_2d = self.reads_2d.pin_memory() + self.alt_counts = self.alt_counts.pin_memory() + self.ref_counts = self.ref_counts.pin_memory() self.info_2d = self.info_2d.pin_memory() self.labels = self.labels.pin_memory() self.is_labeled_mask = self.is_labeled_mask.pin_memory() @@ -598,6 +601,7 @@ def copy_to(self, device, non_blocking): new_batch.is_labeled_mask = self.is_labeled_mask.to(device, non_blocking=non_blocking) new_batch.sources = self.sources.to(device, non_blocking=non_blocking) new_batch.alt_counts = self.alt_counts.to(device, non_blocking=non_blocking) + new_batch.ref_counts = self.ref_counts.to(device, non_blocking=non_blocking) return new_batch def original_list(self): diff --git a/permutect/parameters.py b/permutect/parameters.py index 246aa7f7..3ba61024 100644 --- a/permutect/parameters.py +++ b/permutect/parameters.py @@ -14,7 +14,7 @@ class BaseModelParameters: """ def __init__(self, read_layers: List[int], num_transformer_heads: int, transformer_hidden_dimension: int, num_transformer_layers: int, info_layers: List[int], aggregation_layers: List[int], - ref_seq_layers_strings: List[str], dropout_p: float, reweighting_range: float, batch_normalize: bool = False, alt_downsample: int = 100): + ref_seq_layers_strings: List[str], dropout_p: float, reweighting_range: float, batch_normalize: bool = False): self.read_layers = read_layers self.info_layers = info_layers @@ -26,7 +26,6 @@ def __init__(self, read_layers: List[int], num_transformer_heads: int, transform self.dropout_p = dropout_p self.reweighting_range = reweighting_range self.batch_normalize = batch_normalize - self.alt_downsample = alt_downsample def output_dimension(self): return self.aggregation_layers[-1] @@ -43,10 +42,9 @@ def parse_base_model_params(args) -> BaseModelParameters: dropout_p = getattr(args, constants.DROPOUT_P_NAME) reweighting_range = getattr(args, constants.REWEIGHTING_RANGE_NAME) batch_normalize = getattr(args, constants.BATCH_NORMALIZE_NAME) - alt_downsample = getattr(args, constants.ALT_DOWNSAMPLE_NAME) return BaseModelParameters(read_layers, num_transformer_heads, transformer_hidden_dimension, num_transformer_layers, info_layers, aggregation_layers, ref_seq_layer_strings, dropout_p, - reweighting_range, batch_normalize, alt_downsample) + reweighting_range, batch_normalize) def add_base_model_params_to_parser(parser): @@ -75,8 +73,6 @@ def add_base_model_params_to_parser(parser): parser.add_argument('--' + constants.REWEIGHTING_RANGE_NAME, type=float, default=0.3, required=False, help='magnitude of data augmentation by randomly weighted average of read embeddings. ' 'a value of x yields random weights between 1 - x and 1 + x') - parser.add_argument('--' + constants.ALT_DOWNSAMPLE_NAME, type=int, default=100, required=False, - help='max number of alt reads to downsample to inside the model') parser.add_argument('--' + constants.BATCH_NORMALIZE_NAME, action='store_true', help='flag to turn on batch normalization') diff --git a/permutect/test/tools/test_train_base_model.py b/permutect/test/tools/test_train_base_model.py index be3b5e5b..0f19f734 100644 --- a/permutect/test/tools/test_train_base_model.py +++ b/permutect/test/tools/test_train_base_model.py @@ -27,7 +27,6 @@ def test_train_base_model(): 'linear/out_features=10'] setattr(train_model_args, constants.REF_SEQ_LAYER_STRINGS_NAME, cnn_layer_strings) setattr(train_model_args, constants.DROPOUT_P_NAME, 0.0) - setattr(train_model_args, constants.ALT_DOWNSAMPLE_NAME, 20) setattr(train_model_args, constants.BATCH_NORMALIZE_NAME, False) setattr(train_model_args, constants.LEARNING_METHOD_NAME, 'SEMISUPERVISED') @@ -39,7 +38,8 @@ def test_train_base_model(): # training hyperparameters setattr(train_model_args, constants.REWEIGHTING_RANGE_NAME, 0.3) setattr(train_model_args, constants.BATCH_SIZE_NAME, 64) - setattr(train_model_args, constants.NUM_WORKERS_NAME, 2) + setattr(train_model_args, constants.INFERENCE_BATCH_SIZE_NAME, 64) + setattr(train_model_args, constants.NUM_WORKERS_NAME, 0) setattr(train_model_args, constants.NUM_EPOCHS_NAME, 2) setattr(train_model_args, constants.NUM_CALIBRATION_EPOCHS_NAME, 0) setattr(train_model_args, constants.LEARNING_RATE_NAME, 0.001) diff --git a/permutect/test/tools/test_train_model.py b/permutect/test/tools/test_train_model.py index cd56dc2e..906c44f3 100644 --- a/permutect/test/tools/test_train_model.py +++ b/permutect/test/tools/test_train_model.py @@ -32,7 +32,8 @@ def test_train_model(): # training hyperparameters setattr(train_model_args, constants.BATCH_SIZE_NAME, 64) - setattr(train_model_args, constants.NUM_WORKERS_NAME, 2) + setattr(train_model_args, constants.INFERENCE_BATCH_SIZE_NAME, 64) + setattr(train_model_args, constants.NUM_WORKERS_NAME, 0) setattr(train_model_args, constants.NUM_EPOCHS_NAME, 2) setattr(train_model_args, constants.NUM_CALIBRATION_EPOCHS_NAME, 1) setattr(train_model_args, constants.LEARNING_RATE_NAME, 0.001) diff --git a/permutect/utils.py b/permutect/utils.py index 01540861..b7aca8c6 100644 --- a/permutect/utils.py +++ b/permutect/utils.py @@ -183,6 +183,46 @@ def gamma_binomial(n, k, alpha, beta): return exponent_term + gamma_term - torch.log(n + 1) +# for tensor of shape (R, C...) and row counts n1, n2. . nK, return a tensor of shape (K, C...) whose 1st row is the sum of the +# first n1 rows of the input, 2nd row is the sum of the next n2 rows etc +# note that this works for arbitrary C, including empty. That is, it works for 1D, 2D, 3D etc input. +def sums_over_rows(input_tensor: torch.Tensor, counts: torch.IntTensor): + range_ends = torch.cumsum(counts, dim=0) + assert range_ends[-1] == len(input_tensor) # the counts need to add up! + + row_cumsums = torch.cumsum(input_tensor, dim=0) + + # if counts are eg 1, 2, 3 then range ends are 1, 3, 6 and we are interested in cumsums[0, 2, 5] + relevant_cumsums = row_cumsums[(range_ends - 1).long()] + + # if counts are eg 1, 2, 3 we now have, the sum of the first 1, 3, and 6 rows. To get the sums of row 0, rows 1-2, rows 3-5 + # we need the consecutive differences, with a row of zeroes prepended + row_of_zeroes = torch.zeros_like(relevant_cumsums[0])[None] # the [None] makes it (1xC) + relevant_sums = torch.diff(relevant_cumsums, dim=0, prepend=row_of_zeroes) + return relevant_sums + + +# same but divide by the counts to get means +def means_over_rows(input_tensor: torch.Tensor, counts: torch.IntTensor, keepdim: bool = False): + extra_dims = (1,) * (input_tensor.dim() - 1) + result = sums_over_rows(input_tensor, counts) / counts.view(-1, *extra_dims) + + return torch.repeat_interleave(result, dim=0, repeats=counts) if keepdim else result + + +# same but include a regularizer in case of zeros in the counts vector +# regularizer has the dimension of one row of the input tensor +def means_over_rows_with_regularizer(input_tensor: torch.Tensor, counts: torch.IntTensor, regularizer, regularizer_weight, keepdim: bool = False): + # TODO: left off right here + extra_dims = (1,) * (input_tensor.dim() - 1) + + regularized_sums = sums_over_rows(input_tensor, counts) + regularizer[None, :] + regularized_counts = counts + regularizer_weight + result = regularized_sums / regularized_counts.view(-1, *extra_dims) + + return torch.repeat_interleave(result, dim=0, repeats=counts) if keepdim else result + + class StreamingAverage: def __init__(self): self._count = 0.0