Skip to content

Commit

Permalink
Merge branch 'normalization' into paralell_normalization
Browse files Browse the repository at this point in the history
  • Loading branch information
v-rocheleau committed Jan 8, 2025
2 parents 39fbcd0 + 442b735 commit 474c87d
Show file tree
Hide file tree
Showing 5 changed files with 136 additions and 51 deletions.
17 changes: 6 additions & 11 deletions transcriptomics_data_service/db.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,15 +7,10 @@
from functools import lru_cache
from pathlib import Path


from .config import Config, ConfigDependency
from .logger import LoggerDependency
from .models import ExperimentResult, GeneExpression

NORM_METHOD_COLS = {
"tpm": "tpm_count",
"tmm": "tmm_count",
"getmm": "getmm_count",
}
from .models import CountTypesEnum, ExperimentResult, GeneExpression, NormalizationMethodEnum

SCHEMA_PATH = Path(__file__).parent / "sql" / "schema.sql"

Expand Down Expand Up @@ -135,11 +130,11 @@ def _deserialize_gene_expression(self, rec: asyncpg.Record) -> GeneExpression:
# Normalization Methods
############################

async def update_normalized_expressions(self, expressions: List[GeneExpression], method: str):
async def update_normalized_expressions(self, expressions: List[GeneExpression], method: NormalizationMethodEnum):
"""
Update the normalized expressions in the database using batch updates.
"""
column = NORM_METHOD_COLS.get(method)
column = f"{method.value}_count"
if not column:
raise ValueError(f"Unsupported normalization method: {method}")
conn: asyncpg.Connection
Expand All @@ -157,7 +152,7 @@ async def update_normalized_expressions(self, expressions: List[GeneExpression],
]

await conn.execute(
f"""
"""
CREATE TEMPORARY TABLE temp_updates (
value DOUBLE PRECISION,
experiment_result_id VARCHAR(255),
Expand Down Expand Up @@ -199,7 +194,7 @@ async def fetch_gene_expressions(
genes: Optional[List[str]] = None,
experiments: Optional[List[str]] = None,
sample_ids: Optional[List[str]] = None,
method: str = "raw",
method: CountTypesEnum = CountTypesEnum.raw,
page: int = 1,
page_size: int = 100,
paginate: bool = True,
Expand Down
32 changes: 23 additions & 9 deletions transcriptomics_data_service/models.py
Original file line number Diff line number Diff line change
@@ -1,23 +1,35 @@
from pydantic import BaseModel, Field, validator, field_validator
from pydantic import BaseModel, Field
from typing import List, Optional
from enum import Enum

__all__ = [
"ExperimentResult",
"GeneExpression",
"GeneExpressionData",
"PaginationMeta",
"GeneExpressionResponse",
"MethodEnum",
"NormalizationMethodEnum",
"ExpressionQueryBody",
]


class MethodEnum(str, Enum):
raw = "raw"
tpm = "tpm"
tmm = "tmm"
getmm = "getmm"
TPM = "tpm"
TMM = "tmm"
GETMM = "getmm"
RAW = "raw"


class NormalizationMethodEnum(str, Enum):
tpm = TPM
tmm = TMM
getmm = GETMM


class CountTypesEnum(str, Enum):
raw = RAW
# normalized counts
tpm = TPM
tmm = TMM
getmm = GETMM


class PaginatedRequest(BaseModel):
Expand Down Expand Up @@ -57,7 +69,9 @@ class ExpressionQueryBody(PaginatedRequest):
genes: Optional[List[str]] = Field(None, description="List of gene codes to retrieve")
experiments: Optional[List[str]] = Field(None, description="List of experiment result IDs to retrieve data from")
sample_ids: Optional[List[str]] = Field(None, description="List of sample IDs to retrieve data from")
method: MethodEnum = Field(MethodEnum.raw, description="Data method to retrieve: 'raw', 'tpm', 'tmm', 'getmm'")
method: CountTypesEnum = Field(
CountTypesEnum.raw, description="Data method to retrieve: 'raw', 'tpm', 'tmm', 'getmm'"
)


class GeneExpressionResponse(PaginatedResponse):
Expand Down
51 changes: 20 additions & 31 deletions transcriptomics_data_service/routers/normalization.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,23 +3,15 @@
from io import StringIO

from transcriptomics_data_service.db import DatabaseDependency
from transcriptomics_data_service.models import GeneExpression
from transcriptomics_data_service.scripts.normalize import (
tpm_normalization,
tmm_normalization,
getmm_normalization,
)

# Constants for normalization methods
NORM_TPM = "tpm"
NORM_TMM = "tmm"
NORM_GETMM = "getmm"
from transcriptomics_data_service.models import CountTypesEnum, GeneExpression, NormalizationMethodEnum
from transcriptomics_data_service.normalization_utils import getmm_normalization, tmm_normalization, tpm_normalization

# List of all valid normalization methods
VALID_METHODS = [NORM_TPM, NORM_TMM, NORM_GETMM]

__all__ = ["normalization_router"]


REQUIRES_GENES_LENGHTS = [NormalizationMethodEnum.tpm, NormalizationMethodEnum.getmm]

normalization_router = APIRouter(prefix="/normalize")


Expand All @@ -28,22 +20,17 @@
status_code=status.HTTP_200_OK,
)
async def normalize(
experiment_result_id: str,
method: str,
db: DatabaseDependency,
experiment_result_id: str,
method: NormalizationMethodEnum,
gene_lengths_file: UploadFile = File(None),
):
"""
Normalize gene expressions using the specified method for a given experiment_result_id.
"""
# Method validation
if method.lower() not in VALID_METHODS:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST, detail=f"Unsupported normalization method: {method}"
)

# Load gene lengths if required
if method.lower() in [NORM_TPM, NORM_GETMM]:
if method.lower() in [NormalizationMethodEnum.tpm, NormalizationMethodEnum.getmm]:
if gene_lengths_file is None:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
Expand All @@ -57,17 +44,17 @@ async def normalize(
raw_counts_df = await _fetch_raw_counts(db, experiment_result_id)

# Perform normalization
if method.lower() == NORM_TPM:
if method is NormalizationMethodEnum.tpm:
raw_counts_df, gene_lengths_series = _align_gene_lengths(raw_counts_df, gene_lengths)
normalized_df = tpm_normalization(raw_counts_df, gene_lengths_series)
elif method.lower() == NORM_TMM:
elif method is NormalizationMethodEnum.tmm:
normalized_df = tmm_normalization(raw_counts_df)
elif method.lower() == NORM_GETMM:
elif method is NormalizationMethodEnum.getmm:
raw_counts_df, gene_lengths_series = _align_gene_lengths(raw_counts_df, gene_lengths)
normalized_df = getmm_normalization(raw_counts_df, gene_lengths_series)

# Update database with normalized values
await _update_normalized_values(db, normalized_df, experiment_result_id, method=method.lower())
await _update_normalized_values(db, normalized_df, experiment_result_id, method)

return {"message": f"{method.upper()} normalization completed successfully"}

Expand All @@ -88,7 +75,7 @@ async def _load_gene_lengths(gene_lengths_file: UploadFile) -> pd.Series:
return gene_lengths_series


async def _fetch_raw_counts(db, experiment_result_id: str) -> pd.DataFrame:
async def _fetch_raw_counts(db: DatabaseDependency, experiment_result_id: str) -> pd.DataFrame:
"""
Fetch raw counts from the database for the given experiment_result_id.
Returns a DataFrame with genes as rows and samples as columns.
Expand Down Expand Up @@ -122,13 +109,15 @@ def _align_gene_lengths(raw_counts_df: pd.DataFrame, gene_lengths: pd.Series):
return raw_counts_df, gene_lengths_series


async def _update_normalized_values(db, normalized_df: pd.DataFrame, experiment_result_id: str, method: str):
async def _update_normalized_values(
db: DatabaseDependency, normalized_df: pd.DataFrame, experiment_result_id: str, method: NormalizationMethodEnum
):
"""
Update the normalized values in the database.
"""
# Fetch existing expressions to get raw_count values
existing_expressions, _ = await db.fetch_gene_expressions(
experiments=[experiment_result_id], method="raw", paginate=False
experiments=[experiment_result_id], method=CountTypesEnum.raw, paginate=False
)
raw_count_dict = {(expr.gene_code, expr.sample_id): expr.raw_count for expr in existing_expressions}

Expand All @@ -154,9 +143,9 @@ async def _update_normalized_values(db, normalized_df: pd.DataFrame, experiment_
sample_id=sample_id,
experiment_result_id=experiment_result_id,
raw_count=raw_count,
tpm_count=row["NormalizedValue"] if method == NORM_TPM else None,
tmm_count=row["NormalizedValue"] if method == NORM_TMM else None,
getmm_count=row["NormalizedValue"] if method == NORM_GETMM else None,
tpm_count=row["NormalizedValue"] if method == NormalizationMethodEnum.tpm else None,
tmm_count=row["NormalizedValue"] if method == NormalizationMethodEnum.tmm else None,
getmm_count=row["NormalizedValue"] if method == NormalizationMethodEnum.getmm else None,
)
expressions.append(gene_expression)

Expand Down
87 changes: 87 additions & 0 deletions transcriptomics_data_service/routers/query.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,87 @@
from fastapi import APIRouter, HTTPException, status, Query

from transcriptomics_data_service.db import DatabaseDependency
from transcriptomics_data_service.logger import LoggerDependency
from transcriptomics_data_service.models import (
GeneExpressionData,
GeneExpressionResponse,
NormalizationMethodEnum,
ExpressionQueryBody,
)

query_router = APIRouter()


async def get_expressions_handler(
params: ExpressionQueryBody,
db: DatabaseDependency,
logger: LoggerDependency,
):
"""
Handler for fetching and returning gene expression data.
"""
logger.info(f"Received query parameters: {params}")

expressions, total_records = await db.fetch_gene_expressions(
genes=params.genes,
experiments=params.experiments,
sample_ids=params.sample_ids,
method=params.method.value,
page=params.page,
page_size=params.page_size,
)

if not expressions:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="No gene expression data found for the given parameters.",
)

response_data = []
method_field = f"{params.method.value}_count" if params.method != NormalizationMethodEnum.raw else "raw_count"
for expr in expressions:
count = getattr(expr, method_field)
response_item = GeneExpressionData(
gene_code=expr.gene_code,
sample_id=expr.sample_id,
experiment_result_id=expr.experiment_result_id,
count=count,
method=method_field,
)
response_data.append(response_item)

total_pages = (total_records + params.page_size - 1) // params.page_size

return GeneExpressionResponse(
expressions=response_data,
total_records=total_records,
page=params.page,
page_size=params.page_size,
total_pages=total_pages,
)


@query_router.post(
"/expressions",
status_code=status.HTTP_200_OK,
response_model=GeneExpressionResponse,
)
async def get_expressions_post(
params: ExpressionQueryBody,
db: DatabaseDependency,
logger: LoggerDependency,
):
"""
Retrieve gene expression data via POST request.
Example JSON body:
{
"genes": ["gene1", "gene2"],
"experiments": ["exp1"],
"sample_ids": ["sample1"],
"method": "tmm",
"page": 1,
"page_size": 100
}
"""
return await get_expressions_handler(params, db, logger)

0 comments on commit 474c87d

Please sign in to comment.