From 577f5ac12ed0d3ca4631301cb40697f320988845 Mon Sep 17 00:00:00 2001 From: Julian Date: Fri, 13 Dec 2024 14:23:07 -0500 Subject: [PATCH] lint --- transcriptomics_data_service/db.py | 4 +++- .../routers/normalization.py | 4 +--- transcriptomics_data_service/scripts/normalize.py | 15 +++++++++++++-- 3 files changed, 17 insertions(+), 6 deletions(-) diff --git a/transcriptomics_data_service/db.py b/transcriptomics_data_service/db.py index 636677a..92f793f 100644 --- a/transcriptomics_data_service/db.py +++ b/transcriptomics_data_service/db.py @@ -135,7 +135,9 @@ async def fetch_gene_expressions_by_experiment_id(self, experiment_result_id: st res = await conn.fetch(query, experiment_result_id) return tuple([self._deserialize_gene_expression(record) for record in res]) - async def fetch_gene_expressions(self, experiments: list[str], method: str = "raw", paginate: bool = False) -> Tuple[Tuple[GeneExpression, ...], int]: + async def fetch_gene_expressions( + self, experiments: list[str], method: str = "raw", paginate: bool = False + ) -> Tuple[Tuple[GeneExpression, ...], int]: if not experiments: return (), 0 # TODO: refactor this fetch_gene_expressions_by_experiment_id and implement pagination diff --git a/transcriptomics_data_service/routers/normalization.py b/transcriptomics_data_service/routers/normalization.py index d845ce7..439ed3a 100644 --- a/transcriptomics_data_service/routers/normalization.py +++ b/transcriptomics_data_service/routers/normalization.py @@ -93,9 +93,7 @@ async def _fetch_raw_counts(db, experiment_result_id: str) -> pd.DataFrame: Fetch raw counts from the database for the given experiment_result_id. Returns a DataFrame with genes as rows and samples as columns. """ - expressions, _ = await db.fetch_gene_expressions( - experiments=[experiment_result_id], method="raw", paginate=False - ) + expressions, _ = await db.fetch_gene_expressions(experiments=[experiment_result_id], method="raw", paginate=False) if not expressions: raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Experiment result not found.") diff --git a/transcriptomics_data_service/scripts/normalize.py b/transcriptomics_data_service/scripts/normalize.py index d211dbf..55f9799 100644 --- a/transcriptomics_data_service/scripts/normalize.py +++ b/transcriptomics_data_service/scripts/normalize.py @@ -2,12 +2,14 @@ import numpy as np from joblib import Parallel, delayed + def filter_counts(counts_df): """Filter out genes (rows) and samples (columns) with zero total counts.""" row_filter = counts_df.sum(axis=1) > 0 col_filter = counts_df.sum(axis=0) > 0 return counts_df.loc[row_filter, col_filter] + def prepare_counts_and_lengths(counts_df, gene_lengths, scale_length=None): """Align counts and gene_lengths, drop zeros, and optionally scale gene lengths.""" counts_df = counts_df.loc[gene_lengths.index] @@ -18,11 +20,13 @@ def prepare_counts_and_lengths(counts_df, gene_lengths, scale_length=None): gene_lengths = gene_lengths / scale_length return filter_counts(counts_df), gene_lengths + def parallel_apply(columns, func, n_jobs=-1): """Apply a function to each column in parallel and combine results.""" results = Parallel(n_jobs=n_jobs)(delayed(func)(col) for col in columns) return pd.concat(results, axis=1) + def trim_values(log_ratio, log_mean, w, logratio_trim, sum_trim): """Perform log ratio and sum trimming.""" n = len(log_ratio) @@ -43,6 +47,7 @@ def trim_values(log_ratio, log_mean, w, logratio_trim, sum_trim): return lr_t[final_idx], w_t[final_idx] + def compute_TMM_normalization_factors(counts_df, logratio_trim=0.3, sum_trim=0.05, weighting=True, n_jobs=-1): """Compute TMM normalization factors for counts data.""" lib_sizes = counts_df.sum(axis=0) @@ -53,7 +58,7 @@ def compute_TMM_normalization_factors(counts_df, logratio_trim=0.3, sum_trim=0.0 sample_names = counts_df.columns data_values = counts_df.values - norm_factors = pd.Series(index=sample_names, dtype='float64') + norm_factors = pd.Series(index=sample_names, dtype="float64") norm_factors[ref_sample] = 1.0 def compute_norm_factor(sample): @@ -81,7 +86,7 @@ def compute_norm_factor(sample): lr_final, w_final = trim_values(log_ratio, log_mean, w, logratio_trim, sum_trim) mean_M = np.sum(w_final * lr_final) / np.sum(w_final) - norm_factor = 2 ** mean_M + norm_factor = 2**mean_M return sample, norm_factor samples = [s for s in sample_names if s != ref_sample] @@ -93,6 +98,7 @@ def compute_norm_factor(sample): norm_factors = norm_factors / np.exp(np.mean(np.log(norm_factors))) return norm_factors + def tmm_normalization(counts_df, logratio_trim=0.3, sum_trim=0.05, weighting=True, n_jobs=-1): """Perform TMM normalization on counts data.""" counts_df = filter_counts(counts_df) @@ -101,21 +107,26 @@ def tmm_normalization(counts_df, logratio_trim=0.3, sum_trim=0.05, weighting=Tru normalized_data = counts_df.div(lib_sizes, axis=1).div(norm_factors, axis=1) * lib_sizes.mean() return normalized_data + def getmm_normalization(counts_df, gene_lengths, logratio_trim=0.3, sum_trim=0.05, weighting=True, n_jobs=-1): """Perform GeTMM normalization on counts data.""" counts_df, gene_lengths = prepare_counts_and_lengths(counts_df, gene_lengths) rpk = counts_df.mul(1e3).div(gene_lengths, axis=0) return tmm_normalization(rpk, logratio_trim, sum_trim, weighting, n_jobs) + def compute_rpk(counts_df, gene_lengths_scaled, n_jobs=-1): """Compute RPK values in parallel.""" columns = counts_df.columns + def rpk_col(col): return counts_df[col] / gene_lengths_scaled + rpk = parallel_apply(columns, rpk_col, n_jobs) rpk.columns = columns return rpk + def tpm_normalization(counts_df, gene_lengths, scale_library=1e6, scale_length=1e3, n_jobs=-1): """Convert raw read counts to TPM in parallel.""" counts_df, gene_lengths_scaled = prepare_counts_and_lengths(counts_df, gene_lengths, scale_length=scale_length)