New label-balancing weight scheme; option to use specific samples for calibration (#170)
davidbenjamin authored Jan 14, 2025
1 parent b757123 commit 32d587d
Showing 13 changed files with 309 additions and 240 deletions.
149 changes: 84 additions & 65 deletions permutect/architecture/artifact_model.py

Large diffs are not rendered by default.

25 changes: 13 additions & 12 deletions permutect/architecture/base_model.py
@@ -18,7 +18,7 @@
from permutect.architecture.gradient_reversal.module import GradientReversal
from permutect.architecture.mlp import MLP
from permutect.data.base_datum import BaseBatch, DEFAULT_GPU_FLOAT, DEFAULT_CPU_FLOAT
from permutect.data.base_dataset import BaseDataset, ALL_COUNTS_SENTINEL
from permutect.data.base_dataset import BaseDataset, ALL_COUNTS_INDEX
from permutect.metrics.evaluation_metrics import LossMetrics, EmbeddingMetrics, round_up_to_nearest_three, MAX_COUNT
from permutect.parameters import BaseModelParameters, TrainingParameters

@@ -38,24 +38,23 @@ def sums_over_chunks(tensor2d: torch.Tensor, chunk_size: int):
def calculate_batch_weights(batch, dataset, by_count: bool):
# TODO: we need a parameter to control the relative weight of unlabeled loss to labeled loss
# For batch index n, we want weight[n] = dataset.weights[alt_counts[n], labels[n], variant_types[n]]
counts = batch.get_alt_counts()
sources = batch.get_sources()
counts = batch.get_alt_counts() if by_count else torch.full(size=(len(sources), ), fill_value=ALL_COUNTS_INDEX, dtype=torch.int)
labels = batch.get_labels()
variant_types = batch.get_variant_types()

return utils.index_3d_array(dataset.weights, counts, labels, variant_types) if by_count else \
utils.index_2d_array(dataset.weights[ALL_COUNTS_SENTINEL], labels, variant_types)
return utils.index_4d_array(dataset.label_balancing_weights_sclt, sources, counts, labels, variant_types)
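
# [editor's note] utils.index_4d_array is not shown in this diff. A minimal
# sketch, assuming it does a vectorized per-row lookup via fancy indexing
# (hypothetical implementation; the real helper lives in permutect.utils):
#
#     def index_4d_array(array_4d, idx0, idx1, idx2, idx3):
#         # result[n] = array_4d[idx0[n], idx1[n], idx2[n], idx3[n]]
#         return array_4d[idx0.long(), idx1.long(), idx2.long(), idx3.long()]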


# note: this works for both BaseBatch/BaseDataset AND ArtifactBatch/ArtifactDataset
# if by_count is True, each count is weighted separately for balanced loss within that count
def calculate_batch_source_weights(batch, dataset, by_count: bool):
# For batch index n, we want weight[n] = dataset.source_weights[alt_counts[n], sources[n], variant_types[n]]
counts = batch.get_alt_counts()
# For batch index n, we want weight[n] = dataset.source_weights[sources[n], alt_counts[n], variant_types[n]]
sources = batch.get_sources()
counts = batch.get_alt_counts() if by_count else torch.full(size=(len(sources), ), fill_value=ALL_COUNTS_INDEX, dtype=torch.int)
variant_types = batch.get_variant_types()

return utils.index_3d_array(dataset.source_weights, counts, sources, variant_types) if by_count else \
utils.index_2d_array(dataset.source_weights[ALL_COUNTS_SENTINEL], sources, variant_types)
return utils.index_3d_array(dataset.source_balancing_weights_sct, sources, counts, variant_types)
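
# [editor's note] a small illustration of the by_count=False branch shared by
# both functions above (values made up): torch.full routes every row to the
# count-aggregated slot, since ALL_COUNTS_INDEX == 0 denotes aggregation over
# all alt counts.
#
#     sources = torch.tensor([0, 1, 0])
#     counts = torch.full(size=(len(sources),), fill_value=ALL_COUNTS_INDEX, dtype=torch.int)
#     # counts is tensor([0, 0, 0], dtype=torch.int32)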


class LearningMethod(Enum):
@@ -464,10 +463,12 @@ def learn_base_model(base_model: BaseModel, dataset: BaseDataset, learning_metho
is_cuda = base_model._device.type == 'cuda'
print(f"Is CUDA available? {is_cuda}")

for idx, variation_type in enumerate(utils.Variation):
print(f"For variation type {variation_type.name}, there are {int(dataset.totals[ALL_COUNTS_SENTINEL][Label.ARTIFACT][idx].item())} \
artifacts, {int(dataset.totals[ALL_COUNTS_SENTINEL][Label.VARIANT][idx].item())} \
non-artifacts, and {int(dataset.totals[ALL_COUNTS_SENTINEL][Label.UNLABELED][idx].item())} unlabeled data.")
for source in range(dataset.max_source + 1):
print(f"Data counts for source {source}:")
for var_type in utils.Variation:
print(f"Data counts for variant type {var_type.name}:")
for label in Label:
print(f"{label.name}: {int(dataset.totals_sclt[source][ALL_COUNTS_INDEX][label][var_type].item())}")

# TODO: use Python's match syntax, but this requires updating Python version in the docker
# TODO: hidden_top_layers are hard-coded!
1 change: 1 addition & 0 deletions permutect/constants.py
@@ -11,6 +11,7 @@

SOURCES_NAME = 'sources'
SOURCE_NAME = 'source'
CALIBRATION_SOURCES_NAME = 'calibration_sources'

INPUT_NAME = 'input'
OUTPUT_NAME = 'output'
23 changes: 13 additions & 10 deletions permutect/data/artifact_dataset.py
@@ -20,10 +20,9 @@ def __init__(self, base_dataset: BaseDataset,
base_loader_num_workers=0,
base_loader_batch_size=8192):
self.counts_by_source = base_dataset.counts_by_source
self.totals = base_dataset.totals
self.source_totals = base_dataset.source_totals
self.weights = base_dataset.weights
self.source_weights = base_dataset.source_weights
self.totals_sclt = base_dataset.totals_sclt
self.label_balancing_weights_sclt = base_dataset.label_balancing_weights_sclt
self.source_balancing_weights_sct = base_dataset.source_balancing_weights_sct

self.artifact_data = []
self.num_folds = base_dataset.num_folds
@@ -77,22 +76,26 @@ def all_but_one_fold(self, fold_to_exclude: int):
def all_folds(self):
return list(range(self.num_folds))

def make_data_loader(self, folds_to_use: List[int], batch_size: int, pin_memory=False, num_workers: int = 0, labeled_only: bool = False):
sampler = SemiSupervisedArtifactBatchSampler(self, batch_size, folds_to_use, labeled_only)
def make_data_loader(self, folds_to_use: List[int], batch_size: int, pin_memory=False, num_workers: int = 0, labeled_only: bool = False, sources_to_use: List[int] = None):
sampler = SemiSupervisedArtifactBatchSampler(self, batch_size, folds_to_use, labeled_only, sources_to_use)
return DataLoader(dataset=self, batch_sampler=sampler, collate_fn=ArtifactBatch, pin_memory=pin_memory, num_workers=num_workers)
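
# [editor's note] hypothetical usage of the new sources_to_use parameter; the
# dataset variable, batch size, and source ids below are illustrative. This is
# presumably how calibration can be restricted to designated samples:
#
#     loader = artifact_dataset.make_data_loader(folds_to_use=artifact_dataset.all_folds(),
#         batch_size=64, labeled_only=True, sources_to_use=[0, 2])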


# make ArtifactBatches that mix different ref, alt counts, labeled, unlabeled
# with an option to emit only labeled data
class SemiSupervisedArtifactBatchSampler(Sampler):
def __init__(self, dataset: ArtifactDataset, batch_size, folds_to_use: List[int], labeled_only: bool = False):
def __init__(self, dataset: ArtifactDataset, batch_size, folds_to_use: List[int], labeled_only: bool = False, sources_to_use: List[int] = None):
# combine the index lists of all relevant folds
self.indices_to_use = []
source_set = None if sources_to_use is None else set(sources_to_use)

for fold in folds_to_use:
self.indices_to_use.extend(dataset.labeled_indices[fold])
if not labeled_only:
self.indices_to_use.extend(dataset.unlabeled_indices[fold])
indices_in_fold = dataset.labeled_indices[fold] if labeled_only else (dataset.labeled_indices[fold] + dataset.unlabeled_indices[fold])
if sources_to_use is None:
source_indices_in_fold = indices_in_fold
else:
source_indices_in_fold = [idx for idx in indices_in_fold if dataset[idx].get_source() in source_set]
self.indices_to_use.extend(source_indices_in_fold)

self.batch_size = batch_size
self.num_batches = math.ceil(len(self.indices_to_use) / self.batch_size)
134 changes: 70 additions & 64 deletions permutect/data/base_dataset.py
@@ -16,14 +16,14 @@
from mmap_ninja.ragged import RaggedMmap
from permutect import utils
from permutect.data.base_datum import BaseDatum, BaseBatch, load_list_of_base_data, OneDimensionalData
from permutect.utils import Label, MutableInt
from permutect.utils import Label, MutableInt, Variation

TENSORS_PER_BASE_DATUM = 2 # 1) 2D reads (ref and alt), 2) 1D concatenated stuff

# tarfiles on disk take up about 4x as much space as the dataset in RAM
TARFILE_TO_RAM_RATIO = 4

ALL_COUNTS_SENTINEL = 0
ALL_COUNTS_INDEX = 0

WEIGHT_PSEUDOCOUNT = 10
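
# [editor's note] ratio_with_pseudocount is used below but its definition is
# collapsed in this diff view. A plausible sketch, inferred from the name and
# from how WEIGHT_PSEUDOCOUNT is used (an assumption, not the verified definition):
#
#     def ratio_with_pseudocount(numerator, denominator):
#         # smoothed elementwise ratio that stays finite where denominator is zero
#         return (numerator + WEIGHT_PSEUDOCOUNT) / (denominator + WEIGHT_PSEUDOCOUNT)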

@@ -67,73 +67,73 @@ def __init__(self, data_in_ram: Iterable[BaseDatum] = None, data_tarfile=None, n
# this is used in the batch sampler to make same-shape batches
self.indices_by_fold = [[] for _ in range(num_folds)]

# totals by count, then by label -- ARTIFACT, VARIANT, UNLABELED, then by variant type
# variant type is done as a 1D np array parallel to the one-hot encoding of variant type
# we use a sentinel count value of 0 to denote aggregation over all counts
# eg totals[4][Label.ARTIFACT] = [2,4,6,8,10] means there are 2 artifact SNVs with alt count 4
self.totals = defaultdict(lambda: {label: np.zeros(len(utils.Variation)) for label in Label})
# determine the maximum count and source in order to allocate arrays
max_count = 0
self.max_source = 0
datum: BaseDatum
for datum in self:
max_count = max(datum.alt_count, max_count)
self.max_source = max(datum.source, self.max_source)

# totals by count, then by source (integer) then by variant type
# basically same as above but with source instead of label. Since we don't know a priori how
# many sources there are, we use a default dict
# outer default dict is count, inner is source
self.source_totals = defaultdict(lambda: defaultdict(lambda: np.zeros(len(utils.Variation))))
# totals by source, count, label, variant type
# we use a sentinel count value of 0 to denote aggregation over all counts
self.totals_sclt = np.zeros((self.max_source + 1, max_count + 1, len(Label), len(Variation)))

self.counts_by_source = defaultdict(lambda: MutableInt()) # amount of data for each source (which is an integer key)

for n, datum in enumerate(self):
self.counts_by_source[datum.source].increment()
source = datum.source
self.counts_by_source[source].increment()

fold = n % num_folds
self.indices_by_fold[fold].append(n)

variant_type_idx = datum.get_variant_type()
self.totals[ALL_COUNTS_SENTINEL][datum.label][variant_type_idx] += 1
self.totals[datum.alt_count][datum.label][variant_type_idx] += 1
self.source_totals[ALL_COUNTS_SENTINEL][datum.source][variant_type_idx] += 1
self.source_totals[datum.alt_count][datum.source][variant_type_idx] += 1


# compute weights to balance loss even for unbalanced data
# in the weights array, count == 0 (which never occurs as a real alt count) is the sentinel value for
# aggregation over all alt counts. The array is indexed by count, then label, then variation type
max_count = max(self.totals.keys())
self.weights = np.zeros((max_count + 1, len(Label), len(utils.Variation)))

# similar but indexed by count, then source, then variant type
max_source = max(self.source_totals[ALL_COUNTS_SENTINEL].keys())
self.source_weights = np.zeros((max_count + 1, max_source + 1, len(utils.Variation)))

sources = self.source_totals[ALL_COUNTS_SENTINEL].keys()
for count in self.totals.keys():
# eg: if there are 1000 artifact and 10 non-artifact SNVs, the ratio is 100, and artifacts get a weight of 1/sqrt(100) = 1/10
# while non-artifacts get a weight of 10 -- hence the effective count of each is 1000/10 = 10*10 = 100
art_to_nonart_ratios = ratio_with_pseudocount(self.totals[count][Label.ARTIFACT], self.totals[count][Label.VARIANT])
self.weights[count][Label.VARIANT] = np.sqrt(art_to_nonart_ratios)
self.weights[count][Label.ARTIFACT] = 1 / np.sqrt(art_to_nonart_ratios)

effective_labeled_counts = self.totals[count][Label.ARTIFACT] * self.weights[count][Label.ARTIFACT] + \
self.totals[count][Label.VARIANT] * self.weights[count][Label.VARIANT]

# unlabeled data are weighted down to have at most the same total weight as labeled data
# example, 1000 unlabeled SNVs and 100 labeled SNVs -- unlabeled weight is 100/1000 = 1/10
# example, 10 unlabeled and 100 labeled -- unlabeled weight is 1
self.weights[count][Label.UNLABELED] = np.clip(ratio_with_pseudocount(effective_labeled_counts, self.totals[count][Label.UNLABELED]), 0,1)

# by variant type, for this count
totals_over_sources = np.sum([self.source_totals[count][source] for source in sources])
for source in sources:
self.source_weights[count][source] = np.sqrt(ratio_with_pseudocount(totals_over_sources, self.source_weights[count][source]))

# normalize source prediction weights to have same total effective count. Note that this is modulated
# downstream by set_alpha on the gradient reversal layer applied before source prediction
effective_source_counts = np.sum([self.source_totals[count][source] * self.source_weights[count][source] for source in sources])
source_weight_normalization = effective_labeled_counts / effective_source_counts
for source in sources:
self.source_weights[count][source] = self.source_weights[count][source] * source_weight_normalization

self.weights = torch.from_numpy(self.weights)
self.source_weights = torch.from_numpy(self.source_weights)
self.totals_sclt[source][ALL_COUNTS_INDEX][datum.label][variant_type_idx] += 1
self.totals_sclt[source][datum.alt_count][datum.label][variant_type_idx] += 1

# general balancing idea: if total along some axis eg label is T and count for one particular label is C,
# assign weight T/C -- then effective count is (T/C)*C = T, which is independent of label
# we therefore need sums along certain axes:
totals_sct = np.sum(self.totals_sclt, axis=2) # sum over label for label-balancing
labeled_totals_sct = totals_sct - self.totals_sclt[:, :, Label.UNLABELED, :]
totals_ct = np.sum(totals_sct, axis=0) # sum over label and source for source-balancing
labeled_total = np.sum(labeled_totals_sct)

# note: count == 0 (which never occurs as a real alt count) means aggregation over all alt counts
# thus if we ever sum over count (which we currently don't do), make sure to exclude count == 0

self.label_balancing_weights_sclt = ratio_with_pseudocount(labeled_totals_sct[:, :, None, :], self.totals_sclt)
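
# [editor's note] worked example of the T/C scheme (made-up numbers): one
# (source, count, variant type) cell with 900 artifacts, 100 variants, and no
# unlabeled data has labeled total T = 1000, so, up to the pseudocount,
#
#     cell_totals = np.array([900., 100.])        # ARTIFACT, VARIANT
#     weights = 1000.0 / cell_totals              # -> [1.11, 10.0], i.e. T/C
#     effective = weights * cell_totals           # -> [1000., 1000.], balanced
#
# The [:, :, None, :] above broadcasts the labeled totals across the label axis.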

# next we want to normalize so that the average weight encountered on labeled data is 1 -- this way the learning rate
# parameter has a fixed meaning.
total_weight = np.sum(self.totals_sclt * self.label_balancing_weights_sclt)
total_supervised_weight = total_weight - np.sum(self.totals_sclt[:, :, Label.UNLABELED, :] * self.label_balancing_weights_sclt[:, :, Label.UNLABELED, :])
average_supervised_weight = total_supervised_weight / labeled_total

# after the following line, average label-balancing weight encountered on labeled data is 1
self.label_balancing_weights_sclt = self.label_balancing_weights_sclt / average_supervised_weight
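
# [editor's note] continuing the made-up numbers above: the total supervised
# weight is 900 * 1.11 + 100 * 10.0 = 2000 over 1000 labeled data, so
# average_supervised_weight = 2.0 and the weights rescale to [0.56, 5.0],
# making the average weight encountered on labeled data exactly 1.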

# the balancing process can reduce the influence of unlabeled data to match that of labeled data, but we don't want to
# weight it strongly when there's little unlabeled data. That is, if we have plenty of labeled data we are happy with
# supervised learning!
self.label_balancing_weights_sclt[:, :, Label.UNLABELED, :] = \
np.clip(self.label_balancing_weights_sclt[:, :, Label.UNLABELED, :], 0, 1)

# at this point, average labeled weight is 1 and weights balance artifacts with non-artifacts for each combination
# of source, count, and variant type

# weights for adversarial source prediction task. Balance over sources for each count and variant type
self.source_balancing_weights_sct = ratio_with_pseudocount(totals_ct[None, :, :], totals_sct)

# we now normalize the source balancing weights to have the same total weight as supervised learning.
# The average supervised weight has been normalized to 1, so the total supervised weight is just the total labeled
# count.
total_source_balancing_weight = np.sum(totals_sct * self.source_balancing_weights_sct)
self.source_balancing_weights_sct = self.source_balancing_weights_sct * labeled_total / total_source_balancing_weight
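
# [editor's note] illustration of the source balancing above (made-up numbers):
# two sources with 800 and 200 examples in one (count, variant type) cell get
#
#     totals_s = np.array([800., 200.])
#     weights_s = totals_s.sum() / totals_s       # -> [1.25, 5.0]
#
# so each source has effective count 1000; the rescaling by
# labeled_total / total_source_balancing_weight then matches the total source
# balancing weight to the total supervised weight.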

self.label_balancing_weights_sclt = torch.from_numpy(self.label_balancing_weights_sclt)
self.source_balancing_weights_sct = torch.from_numpy(self.source_balancing_weights_sct)
self.num_read_features = self[0].get_reads_2d().shape[1]
self.num_info_features = len(self[0].get_info_tensor_1d())
self.ref_sequence_length = len(self[0].get_ref_sequence_1d())
@@ -165,8 +165,8 @@ def all_but_one_fold(self, fold_to_exclude: int):
def all_folds(self):
return list(range(self.num_folds))

def make_data_loader(self, folds_to_use: List[int], batch_size: int, pin_memory=False, num_workers: int = 0):
sampler = SemiSupervisedBatchSampler(self, batch_size, folds_to_use)
def make_data_loader(self, folds_to_use: List[int], batch_size: int, pin_memory=False, num_workers: int = 0, sources_to_use: List[int] = None):
sampler = SemiSupervisedBatchSampler(self, batch_size, folds_to_use, sources_to_use)
return DataLoader(dataset=self, batch_sampler=sampler, collate_fn=BaseBatch, pin_memory=pin_memory, num_workers=num_workers)


@@ -199,12 +199,18 @@ def chunk(lis, chunk_size):
# the artifact model handles weighting the losses to compensate for class imbalance between supervised and unsupervised
# thus the sampler is not responsible for balancing the data
class SemiSupervisedBatchSampler(Sampler):
def __init__(self, dataset: BaseDataset, batch_size, folds_to_use: List[int]):
def __init__(self, dataset: BaseDataset, batch_size, folds_to_use: List[int], sources_to_use: List[int] = None):
# combine the index maps of all relevant folds
self.indices_to_use = []

source_set = None if sources_to_use is None else set(sources_to_use)
for fold in folds_to_use:
self.indices_to_use.extend(dataset.indices_by_fold[fold])
indices_in_fold = dataset.indices_by_fold[fold]
if sources_to_use is None:
source_indices_in_fold = indices_in_fold
else:
source_indices_in_fold = [idx for idx in indices_in_fold if dataset[idx].source in source_set]

self.indices_to_use.extend(source_indices_in_fold)

self.batch_size = batch_size
self.num_batches = math.ceil(len(self.indices_to_use) / self.batch_size)
6 changes: 3 additions & 3 deletions permutect/data/base_datum.py
@@ -339,8 +339,8 @@ def get_variant_type(self) -> int:
def set_variant_type(self, variant_type: Variation):
self.array[self.__class__.VAR_TYPE_IDX] = variant_type

def get_label(self):
return self.array[self.__class__.LABEL_IDX]
def get_label(self) -> int:
return round(self.array[self.__class__.LABEL_IDX])

def set_label(self, label: Label):
self.array[self.__class__.LABEL_IDX] = label
@@ -651,7 +651,7 @@ def get_depth(self) -> int:
def get_variant_type(self) -> int:
return self.one_dimensional_data.get_variant_type()

def get_label(self):
def get_label(self) -> int:
return self.one_dimensional_data.get_label()

def get_source(self) -> int: