Skip to content

Commit

Permalink
lint
Browse files Browse the repository at this point in the history
  • Loading branch information
noctillion committed Dec 13, 2024
1 parent 8c8fd2c commit 577f5ac
Show file tree
Hide file tree
Showing 3 changed files with 17 additions and 6 deletions.
4 changes: 3 additions & 1 deletion transcriptomics_data_service/db.py
Original file line number Diff line number Diff line change
Expand Up @@ -135,7 +135,9 @@ async def fetch_gene_expressions_by_experiment_id(self, experiment_result_id: st
res = await conn.fetch(query, experiment_result_id)
return tuple([self._deserialize_gene_expression(record) for record in res])

async def fetch_gene_expressions(self, experiments: list[str], method: str = "raw", paginate: bool = False) -> Tuple[Tuple[GeneExpression, ...], int]:
async def fetch_gene_expressions(
self, experiments: list[str], method: str = "raw", paginate: bool = False
) -> Tuple[Tuple[GeneExpression, ...], int]:
if not experiments:
return (), 0
# TODO: refactor this fetch_gene_expressions_by_experiment_id and implement pagination
Expand Down
4 changes: 1 addition & 3 deletions transcriptomics_data_service/routers/normalization.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,9 +93,7 @@ async def _fetch_raw_counts(db, experiment_result_id: str) -> pd.DataFrame:
Fetch raw counts from the database for the given experiment_result_id.
Returns a DataFrame with genes as rows and samples as columns.
"""
expressions, _ = await db.fetch_gene_expressions(
experiments=[experiment_result_id], method="raw", paginate=False
)
expressions, _ = await db.fetch_gene_expressions(experiments=[experiment_result_id], method="raw", paginate=False)
if not expressions:
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Experiment result not found.")

Expand Down
15 changes: 13 additions & 2 deletions transcriptomics_data_service/scripts/normalize.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,14 @@
import numpy as np
from joblib import Parallel, delayed


def filter_counts(counts_df):
"""Filter out genes (rows) and samples (columns) with zero total counts."""
row_filter = counts_df.sum(axis=1) > 0
col_filter = counts_df.sum(axis=0) > 0
return counts_df.loc[row_filter, col_filter]


def prepare_counts_and_lengths(counts_df, gene_lengths, scale_length=None):
"""Align counts and gene_lengths, drop zeros, and optionally scale gene lengths."""
counts_df = counts_df.loc[gene_lengths.index]
Expand All @@ -18,11 +20,13 @@ def prepare_counts_and_lengths(counts_df, gene_lengths, scale_length=None):
gene_lengths = gene_lengths / scale_length
return filter_counts(counts_df), gene_lengths


def parallel_apply(columns, func, n_jobs=-1):
"""Apply a function to each column in parallel and combine results."""
results = Parallel(n_jobs=n_jobs)(delayed(func)(col) for col in columns)
return pd.concat(results, axis=1)


def trim_values(log_ratio, log_mean, w, logratio_trim, sum_trim):
"""Perform log ratio and sum trimming."""
n = len(log_ratio)
Expand All @@ -43,6 +47,7 @@ def trim_values(log_ratio, log_mean, w, logratio_trim, sum_trim):

return lr_t[final_idx], w_t[final_idx]


def compute_TMM_normalization_factors(counts_df, logratio_trim=0.3, sum_trim=0.05, weighting=True, n_jobs=-1):
"""Compute TMM normalization factors for counts data."""
lib_sizes = counts_df.sum(axis=0)
Expand All @@ -53,7 +58,7 @@ def compute_TMM_normalization_factors(counts_df, logratio_trim=0.3, sum_trim=0.0
sample_names = counts_df.columns
data_values = counts_df.values

norm_factors = pd.Series(index=sample_names, dtype='float64')
norm_factors = pd.Series(index=sample_names, dtype="float64")
norm_factors[ref_sample] = 1.0

def compute_norm_factor(sample):
Expand Down Expand Up @@ -81,7 +86,7 @@ def compute_norm_factor(sample):
lr_final, w_final = trim_values(log_ratio, log_mean, w, logratio_trim, sum_trim)

mean_M = np.sum(w_final * lr_final) / np.sum(w_final)
norm_factor = 2 ** mean_M
norm_factor = 2**mean_M
return sample, norm_factor

samples = [s for s in sample_names if s != ref_sample]
Expand All @@ -93,6 +98,7 @@ def compute_norm_factor(sample):
norm_factors = norm_factors / np.exp(np.mean(np.log(norm_factors)))
return norm_factors


def tmm_normalization(counts_df, logratio_trim=0.3, sum_trim=0.05, weighting=True, n_jobs=-1):
"""Perform TMM normalization on counts data."""
counts_df = filter_counts(counts_df)
Expand All @@ -101,21 +107,26 @@ def tmm_normalization(counts_df, logratio_trim=0.3, sum_trim=0.05, weighting=Tru
normalized_data = counts_df.div(lib_sizes, axis=1).div(norm_factors, axis=1) * lib_sizes.mean()
return normalized_data


def getmm_normalization(counts_df, gene_lengths, logratio_trim=0.3, sum_trim=0.05, weighting=True, n_jobs=-1):
"""Perform GeTMM normalization on counts data."""
counts_df, gene_lengths = prepare_counts_and_lengths(counts_df, gene_lengths)
rpk = counts_df.mul(1e3).div(gene_lengths, axis=0)
return tmm_normalization(rpk, logratio_trim, sum_trim, weighting, n_jobs)


def compute_rpk(counts_df, gene_lengths_scaled, n_jobs=-1):
"""Compute RPK values in parallel."""
columns = counts_df.columns

def rpk_col(col):
return counts_df[col] / gene_lengths_scaled

rpk = parallel_apply(columns, rpk_col, n_jobs)
rpk.columns = columns
return rpk


def tpm_normalization(counts_df, gene_lengths, scale_library=1e6, scale_length=1e3, n_jobs=-1):
"""Convert raw read counts to TPM in parallel."""
counts_df, gene_lengths_scaled = prepare_counts_and_lengths(counts_df, gene_lengths, scale_length=scale_length)
Expand Down

0 comments on commit 577f5ac

Please sign in to comment.