diff --git a/transcriptomics_data_service/db.py b/transcriptomics_data_service/db.py index c839f4f..4bdfec4 100644 --- a/transcriptomics_data_service/db.py +++ b/transcriptomics_data_service/db.py @@ -7,15 +7,10 @@ from functools import lru_cache from pathlib import Path + from .config import Config, ConfigDependency from .logger import LoggerDependency -from .models import ExperimentResult, GeneExpression - -NORM_METHOD_COLS = { - "tpm": "tpm_count", - "tmm": "tmm_count", - "getmm": "getmm_count", -} +from .models import CountTypesEnum, ExperimentResult, GeneExpression, NormalizationMethodEnum SCHEMA_PATH = Path(__file__).parent / "sql" / "schema.sql" @@ -135,11 +130,11 @@ def _deserialize_gene_expression(self, rec: asyncpg.Record) -> GeneExpression: # Normalization Methods ############################ - async def update_normalized_expressions(self, expressions: List[GeneExpression], method: str): + async def update_normalized_expressions(self, expressions: List[GeneExpression], method: NormalizationMethodEnum): """ Update the normalized expressions in the database using batch updates. """ - column = NORM_METHOD_COLS.get(method) + column = f"{method.value}_count" if not column: raise ValueError(f"Unsupported normalization method: {method}") conn: asyncpg.Connection @@ -157,7 +152,7 @@ async def update_normalized_expressions(self, expressions: List[GeneExpression], ] await conn.execute( - f""" + """ CREATE TEMPORARY TABLE temp_updates ( value DOUBLE PRECISION, experiment_result_id VARCHAR(255), @@ -199,7 +194,7 @@ async def fetch_gene_expressions( genes: Optional[List[str]] = None, experiments: Optional[List[str]] = None, sample_ids: Optional[List[str]] = None, - method: str = "raw", + method: CountTypesEnum = CountTypesEnum.raw, page: int = 1, page_size: int = 100, paginate: bool = True, diff --git a/transcriptomics_data_service/models.py b/transcriptomics_data_service/models.py index 51f8e5c..41adfbd 100644 --- a/transcriptomics_data_service/models.py +++ b/transcriptomics_data_service/models.py @@ -1,4 +1,4 @@ -from pydantic import BaseModel, Field, validator, field_validator +from pydantic import BaseModel, Field from typing import List, Optional from enum import Enum @@ -6,18 +6,30 @@ "ExperimentResult", "GeneExpression", "GeneExpressionData", - "PaginationMeta", "GeneExpressionResponse", - "MethodEnum", + "NormalizationMethodEnum", "ExpressionQueryBody", ] -class MethodEnum(str, Enum): - raw = "raw" - tpm = "tpm" - tmm = "tmm" - getmm = "getmm" +TPM = "tpm" +TMM = "tmm" +GETMM = "getmm" +RAW = "raw" + + +class NormalizationMethodEnum(str, Enum): + tpm = TPM + tmm = TMM + getmm = GETMM + + +class CountTypesEnum(str, Enum): + raw = RAW + # normalized counts + tpm = TPM + tmm = TMM + getmm = GETMM class PaginatedRequest(BaseModel): @@ -57,7 +69,9 @@ class ExpressionQueryBody(PaginatedRequest): genes: Optional[List[str]] = Field(None, description="List of gene codes to retrieve") experiments: Optional[List[str]] = Field(None, description="List of experiment result IDs to retrieve data from") sample_ids: Optional[List[str]] = Field(None, description="List of sample IDs to retrieve data from") - method: MethodEnum = Field(MethodEnum.raw, description="Data method to retrieve: 'raw', 'tpm', 'tmm', 'getmm'") + method: CountTypesEnum = Field( + CountTypesEnum.raw, description="Data method to retrieve: 'raw', 'tpm', 'tmm', 'getmm'" + ) class GeneExpressionResponse(PaginatedResponse): diff --git a/transcriptomics_data_service/scripts/normalize.py b/transcriptomics_data_service/normalization_utils.py similarity index 100% rename from transcriptomics_data_service/scripts/normalize.py rename to transcriptomics_data_service/normalization_utils.py diff --git a/transcriptomics_data_service/routers/normalization.py b/transcriptomics_data_service/routers/normalization.py index 439ed3a..953cc82 100644 --- a/transcriptomics_data_service/routers/normalization.py +++ b/transcriptomics_data_service/routers/normalization.py @@ -3,23 +3,15 @@ from io import StringIO from transcriptomics_data_service.db import DatabaseDependency -from transcriptomics_data_service.models import GeneExpression -from transcriptomics_data_service.scripts.normalize import ( - tpm_normalization, - tmm_normalization, - getmm_normalization, -) - -# Constants for normalization methods -NORM_TPM = "tpm" -NORM_TMM = "tmm" -NORM_GETMM = "getmm" +from transcriptomics_data_service.models import CountTypesEnum, GeneExpression, NormalizationMethodEnum +from transcriptomics_data_service.normalization_utils import getmm_normalization, tmm_normalization, tpm_normalization -# List of all valid normalization methods -VALID_METHODS = [NORM_TPM, NORM_TMM, NORM_GETMM] __all__ = ["normalization_router"] + +REQUIRES_GENES_LENGHTS = [NormalizationMethodEnum.tpm, NormalizationMethodEnum.getmm] + normalization_router = APIRouter(prefix="/normalize") @@ -28,22 +20,17 @@ status_code=status.HTTP_200_OK, ) async def normalize( - experiment_result_id: str, - method: str, db: DatabaseDependency, + experiment_result_id: str, + method: NormalizationMethodEnum, gene_lengths_file: UploadFile = File(None), ): """ Normalize gene expressions using the specified method for a given experiment_result_id. """ - # Method validation - if method.lower() not in VALID_METHODS: - raise HTTPException( - status_code=status.HTTP_400_BAD_REQUEST, detail=f"Unsupported normalization method: {method}" - ) # Load gene lengths if required - if method.lower() in [NORM_TPM, NORM_GETMM]: + if method.lower() in [NormalizationMethodEnum.tpm, NormalizationMethodEnum.getmm]: if gene_lengths_file is None: raise HTTPException( status_code=status.HTTP_400_BAD_REQUEST, @@ -57,17 +44,17 @@ async def normalize( raw_counts_df = await _fetch_raw_counts(db, experiment_result_id) # Perform normalization - if method.lower() == NORM_TPM: + if method is NormalizationMethodEnum.tpm: raw_counts_df, gene_lengths_series = _align_gene_lengths(raw_counts_df, gene_lengths) normalized_df = tpm_normalization(raw_counts_df, gene_lengths_series) - elif method.lower() == NORM_TMM: + elif method is NormalizationMethodEnum.tmm: normalized_df = tmm_normalization(raw_counts_df) - elif method.lower() == NORM_GETMM: + elif method is NormalizationMethodEnum.getmm: raw_counts_df, gene_lengths_series = _align_gene_lengths(raw_counts_df, gene_lengths) normalized_df = getmm_normalization(raw_counts_df, gene_lengths_series) # Update database with normalized values - await _update_normalized_values(db, normalized_df, experiment_result_id, method=method.lower()) + await _update_normalized_values(db, normalized_df, experiment_result_id, method) return {"message": f"{method.upper()} normalization completed successfully"} @@ -88,7 +75,7 @@ async def _load_gene_lengths(gene_lengths_file: UploadFile) -> pd.Series: return gene_lengths_series -async def _fetch_raw_counts(db, experiment_result_id: str) -> pd.DataFrame: +async def _fetch_raw_counts(db: DatabaseDependency, experiment_result_id: str) -> pd.DataFrame: """ Fetch raw counts from the database for the given experiment_result_id. Returns a DataFrame with genes as rows and samples as columns. @@ -122,13 +109,15 @@ def _align_gene_lengths(raw_counts_df: pd.DataFrame, gene_lengths: pd.Series): return raw_counts_df, gene_lengths_series -async def _update_normalized_values(db, normalized_df: pd.DataFrame, experiment_result_id: str, method: str): +async def _update_normalized_values( + db: DatabaseDependency, normalized_df: pd.DataFrame, experiment_result_id: str, method: NormalizationMethodEnum +): """ Update the normalized values in the database. """ # Fetch existing expressions to get raw_count values existing_expressions, _ = await db.fetch_gene_expressions( - experiments=[experiment_result_id], method="raw", paginate=False + experiments=[experiment_result_id], method=CountTypesEnum.raw, paginate=False ) raw_count_dict = {(expr.gene_code, expr.sample_id): expr.raw_count for expr in existing_expressions} @@ -154,9 +143,9 @@ async def _update_normalized_values(db, normalized_df: pd.DataFrame, experiment_ sample_id=sample_id, experiment_result_id=experiment_result_id, raw_count=raw_count, - tpm_count=row["NormalizedValue"] if method == NORM_TPM else None, - tmm_count=row["NormalizedValue"] if method == NORM_TMM else None, - getmm_count=row["NormalizedValue"] if method == NORM_GETMM else None, + tpm_count=row["NormalizedValue"] if method == NormalizationMethodEnum.tpm else None, + tmm_count=row["NormalizedValue"] if method == NormalizationMethodEnum.tmm else None, + getmm_count=row["NormalizedValue"] if method == NormalizationMethodEnum.getmm else None, ) expressions.append(gene_expression) diff --git a/transcriptomics_data_service/routers/query.py b/transcriptomics_data_service/routers/query.py new file mode 100644 index 0000000..e1e4b67 --- /dev/null +++ b/transcriptomics_data_service/routers/query.py @@ -0,0 +1,87 @@ +from fastapi import APIRouter, HTTPException, status, Query + +from transcriptomics_data_service.db import DatabaseDependency +from transcriptomics_data_service.logger import LoggerDependency +from transcriptomics_data_service.models import ( + GeneExpressionData, + GeneExpressionResponse, + NormalizationMethodEnum, + ExpressionQueryBody, +) + +query_router = APIRouter() + + +async def get_expressions_handler( + params: ExpressionQueryBody, + db: DatabaseDependency, + logger: LoggerDependency, +): + """ + Handler for fetching and returning gene expression data. + """ + logger.info(f"Received query parameters: {params}") + + expressions, total_records = await db.fetch_gene_expressions( + genes=params.genes, + experiments=params.experiments, + sample_ids=params.sample_ids, + method=params.method.value, + page=params.page, + page_size=params.page_size, + ) + + if not expressions: + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, + detail="No gene expression data found for the given parameters.", + ) + + response_data = [] + method_field = f"{params.method.value}_count" if params.method != NormalizationMethodEnum.raw else "raw_count" + for expr in expressions: + count = getattr(expr, method_field) + response_item = GeneExpressionData( + gene_code=expr.gene_code, + sample_id=expr.sample_id, + experiment_result_id=expr.experiment_result_id, + count=count, + method=method_field, + ) + response_data.append(response_item) + + total_pages = (total_records + params.page_size - 1) // params.page_size + + return GeneExpressionResponse( + expressions=response_data, + total_records=total_records, + page=params.page, + page_size=params.page_size, + total_pages=total_pages, + ) + + +@query_router.post( + "/expressions", + status_code=status.HTTP_200_OK, + response_model=GeneExpressionResponse, +) +async def get_expressions_post( + params: ExpressionQueryBody, + db: DatabaseDependency, + logger: LoggerDependency, +): + """ + Retrieve gene expression data via POST request. + + Example JSON body: + { + "genes": ["gene1", "gene2"], + "experiments": ["exp1"], + "sample_ids": ["sample1"], + "method": "tmm", + "page": 1, + "page_size": 100 + } + """ + return await get_expressions_handler(params, db, logger)