From f3668186329d75155eff044882b109f1e3d88948 Mon Sep 17 00:00:00 2001 From: Victor Rocheleau Date: Mon, 16 Dec 2024 12:49:02 -0500 Subject: [PATCH 1/2] refact: norm algos enum model for implicit validation --- transcriptomics_data_service/db.py | 15 +++----- transcriptomics_data_service/models.py | 9 +++++ .../routers/normalization.py | 37 +++++++------------ 3 files changed, 28 insertions(+), 33 deletions(-) diff --git a/transcriptomics_data_service/db.py b/transcriptomics_data_service/db.py index 91a4cad..75e8215 100644 --- a/transcriptomics_data_service/db.py +++ b/transcriptomics_data_service/db.py @@ -7,15 +7,10 @@ from functools import lru_cache from pathlib import Path + from .config import Config, ConfigDependency from .logger import LoggerDependency -from .models import ExperimentResult, GeneExpression - -NORM_METHOD_COLS = { - "tpm": "tpm_count", - "tmm": "tmm_count", - "getmm": "getmm_count", -} +from .models import ExperimentResult, GeneExpression, NormalizationAlgos SCHEMA_PATH = Path(__file__).parent / "sql" / "schema.sql" @@ -150,11 +145,11 @@ def _deserialize_gene_expression(self, rec: asyncpg.Record) -> GeneExpression: # Normalization Methods ############################ - async def update_normalized_expressions(self, expressions: List[GeneExpression], method: str): + async def update_normalized_expressions(self, expressions: List[GeneExpression], method: NormalizationAlgos): """ Update the normalized expressions in the database using batch updates. """ - column = NORM_METHOD_COLS.get(method) + column = f"{method.value}_count" if not column: raise ValueError(f"Unsupported normalization method: {method}") conn: asyncpg.Connection @@ -172,7 +167,7 @@ async def update_normalized_expressions(self, expressions: List[GeneExpression], ] await conn.execute( - f""" + """ CREATE TEMPORARY TABLE temp_updates ( value DOUBLE PRECISION, experiment_result_id VARCHAR(255), diff --git a/transcriptomics_data_service/models.py b/transcriptomics_data_service/models.py index ab4317b..aed7c9f 100644 --- a/transcriptomics_data_service/models.py +++ b/transcriptomics_data_service/models.py @@ -1,8 +1,10 @@ +from enum import Enum from pydantic import BaseModel __all__ = [ "ExperimentResult", "GeneExpression", + "NormalizationAlgos" ] @@ -20,3 +22,10 @@ class GeneExpression(BaseModel): tpm_count: float | None = None tmm_count: float | None = None getmm_count: float | None = None + + +class NormalizationAlgos(str, Enum): + # Constants for normalization methods + TPM = "tpm" + TMM = "tmm" + GETMM = "getmm" diff --git a/transcriptomics_data_service/routers/normalization.py b/transcriptomics_data_service/routers/normalization.py index a507208..111b84e 100644 --- a/transcriptomics_data_service/routers/normalization.py +++ b/transcriptomics_data_service/routers/normalization.py @@ -1,25 +1,19 @@ +from enum import Enum from fastapi import APIRouter, HTTPException, UploadFile, File, status import pandas as pd from io import StringIO from transcriptomics_data_service.db import DatabaseDependency -from transcriptomics_data_service.models import GeneExpression +from transcriptomics_data_service.models import GeneExpression, NormalizationAlgos from transcriptomics_data_service.scripts.normalize import ( read_counts2tpm, tmm_normalization, getmm_normalization, ) -# Constants for normalization methods -NORM_TPM = "tpm" -NORM_TMM = "tmm" -NORM_GETMM = "getmm" - -# List of all valid normalization methods -VALID_METHODS = [NORM_TPM, NORM_TMM, NORM_GETMM] - __all__ = ["normalization_router"] + normalization_router = APIRouter(prefix="/normalize") @@ -29,21 +23,16 @@ ) async def normalize( experiment_result_id: str, - method: str, + method: NormalizationAlgos, db: DatabaseDependency, gene_lengths_file: UploadFile = File(None), ): """ Normalize gene expressions using the specified method for a given experiment_result_id. """ - # method validation - if method not in VALID_METHODS: - raise HTTPException( - status_code=status.HTTP_400_BAD_REQUEST, detail=f"Unsupported normalization method: {method}" - ) # load gene lengths - if method in [NORM_TPM, NORM_GETMM]: + if method in [NormalizationAlgos.TPM, NormalizationAlgos.GETMM]: if gene_lengths_file is None: raise HTTPException( status_code=status.HTTP_400_BAD_REQUEST, @@ -54,12 +43,12 @@ async def normalize( raw_counts_df = await _fetch_raw_counts(db, experiment_result_id) # normalization - if method == NORM_TPM: + if method == NormalizationAlgos.TPM: raw_counts_df, gene_lengths_series = _align_gene_lengths(raw_counts_df, gene_lengths) normalized_df = read_counts2tpm(raw_counts_df, gene_lengths_series) - elif method == NORM_TMM: + elif method == NormalizationAlgos.TMM: normalized_df = tmm_normalization(raw_counts_df) - elif method == NORM_GETMM: + elif method == NormalizationAlgos.GETMM: raw_counts_df, gene_lengths_series = _align_gene_lengths(raw_counts_df, gene_lengths) normalized_df = getmm_normalization(raw_counts_df, gene_lengths_series) @@ -114,7 +103,9 @@ def _align_gene_lengths(raw_counts_df: pd.DataFrame, gene_lengths: pd.Series): return raw_counts_df, gene_lengths_series -async def _update_normalized_values(db, normalized_df: pd.DataFrame, experiment_result_id: str, method: str): +async def _update_normalized_values( + db, normalized_df: pd.DataFrame, experiment_result_id: str, method: NormalizationAlgos +): """ Update the normalized values in the database """ @@ -143,9 +134,9 @@ async def _update_normalized_values(db, normalized_df: pd.DataFrame, experiment_ sample_id=sample_id, experiment_result_id=experiment_result_id, raw_count=raw_count, - tpm_count=row["NormalizedValue"] if method == NORM_TPM else None, - tmm_count=row["NormalizedValue"] if method == NORM_TMM else None, - getmm_count=row["NormalizedValue"] if method == NORM_GETMM else None, + tpm_count=row["NormalizedValue"] if method == NormalizationAlgos.TPM else None, + tmm_count=row["NormalizedValue"] if method == NormalizationAlgos.TMM else None, + getmm_count=row["NormalizedValue"] if method == NormalizationAlgos.GETMM else None, ) expressions.append(gene_expression) From 442b7355b9b27ce77c2b74fe18ab6c27857cd1a0 Mon Sep 17 00:00:00 2001 From: Victor Rocheleau Date: Mon, 16 Dec 2024 15:15:40 -0500 Subject: [PATCH 2/2] refact: reorganize normalization algos --- .../normalize.py => normalization_utils.py} | 17 ++++------------- .../routers/normalization.py | 14 ++++++++------ 2 files changed, 12 insertions(+), 19 deletions(-) rename transcriptomics_data_service/{scripts/normalize.py => normalization_utils.py} (79%) diff --git a/transcriptomics_data_service/scripts/normalize.py b/transcriptomics_data_service/normalization_utils.py similarity index 79% rename from transcriptomics_data_service/scripts/normalize.py rename to transcriptomics_data_service/normalization_utils.py index a1135ca..1d95d9c 100644 --- a/transcriptomics_data_service/scripts/normalize.py +++ b/transcriptomics_data_service/normalization_utils.py @@ -1,7 +1,7 @@ import pandas as pd +import conorm - -def read_counts2tpm(counts_df, gene_lengths, scale_library=1e6, scale_length=1e3): +def read_counts2tpm(counts_df: pd.DataFrame, gene_lengths: pd.Series, scale_library=1e6, scale_length=1e3): """ Convert raw read counts to TPM (Transcripts Per Million). @@ -32,7 +32,7 @@ def read_counts2tpm(counts_df, gene_lengths, scale_library=1e6, scale_length=1e3 return tpm -def tmm_normalization(counts_df): +def tmm_normalization(counts_df: pd.DataFrame): """ Perform TMM normalization on counts data. @@ -42,16 +42,12 @@ def tmm_normalization(counts_df): Returns: DataFrame: TMM-normalized values. """ - try: - import conorm - except ImportError: - raise ImportError("The 'conorm' package is required for this function but is not installed.") normalized_array = conorm.tmm(counts_df) normalized_df = pd.DataFrame(normalized_array, columns=counts_df.columns, index=counts_df.index) return normalized_df -def getmm_normalization(counts_df, gene_lengths): +def getmm_normalization(counts_df: pd.DataFrame, gene_lengths: pd.Series): """ Perform GeTMM normalization on counts data. @@ -62,11 +58,6 @@ def getmm_normalization(counts_df, gene_lengths): Returns: DataFrame: GeTMM-normalized values. """ - try: - import conorm - except ImportError: - raise ImportError("The 'conorm' package is required for this function but is not installed.") - normalized_array = conorm.getmm(counts_df, gene_lengths) normalized_df = pd.DataFrame(normalized_array, columns=counts_df.columns, index=counts_df.index) return normalized_df diff --git a/transcriptomics_data_service/routers/normalization.py b/transcriptomics_data_service/routers/normalization.py index 111b84e..6f9d88f 100644 --- a/transcriptomics_data_service/routers/normalization.py +++ b/transcriptomics_data_service/routers/normalization.py @@ -5,7 +5,7 @@ from transcriptomics_data_service.db import DatabaseDependency from transcriptomics_data_service.models import GeneExpression, NormalizationAlgos -from transcriptomics_data_service.scripts.normalize import ( +from transcriptomics_data_service.normalization_utils import ( read_counts2tpm, tmm_normalization, getmm_normalization, @@ -14,6 +14,8 @@ __all__ = ["normalization_router"] +REQUIRES_GENES_LENGHTS = [NormalizationAlgos.TPM, NormalizationAlgos.GETMM] + normalization_router = APIRouter(prefix="/normalize") @@ -22,17 +24,17 @@ status_code=status.HTTP_200_OK, ) async def normalize( + db: DatabaseDependency, experiment_result_id: str, method: NormalizationAlgos, - db: DatabaseDependency, gene_lengths_file: UploadFile = File(None), ): """ Normalize gene expressions using the specified method for a given experiment_result_id. """ - # load gene lengths - if method in [NormalizationAlgos.TPM, NormalizationAlgos.GETMM]: + # load gene lengths if required + if method in REQUIRES_GENES_LENGHTS: if gene_lengths_file is None: raise HTTPException( status_code=status.HTTP_400_BAD_REQUEST, @@ -69,7 +71,7 @@ async def _load_gene_lengths(gene_lengths_file: UploadFile) -> pd.Series: return gene_lengths_series -async def _fetch_raw_counts(db, experiment_result_id: str) -> pd.DataFrame: +async def _fetch_raw_counts(db: DatabaseDependency, experiment_result_id: str) -> pd.DataFrame: """ Fetch raw counts from the database for the given experiment_result_id. Returns a DataFrame with genes as rows and samples as columns. @@ -104,7 +106,7 @@ def _align_gene_lengths(raw_counts_df: pd.DataFrame, gene_lengths: pd.Series): async def _update_normalized_values( - db, normalized_df: pd.DataFrame, experiment_result_id: str, method: NormalizationAlgos + db: DatabaseDependency, normalized_df: pd.DataFrame, experiment_result_id: str, method: NormalizationAlgos ): """ Update the normalized values in the database