Skip to content

Commit

Permalink
update ans
Browse files Browse the repository at this point in the history
  • Loading branch information
xingzhongyu committed Jan 18, 2025
1 parent 2baf956 commit 8037b1c
Show file tree
Hide file tree
Showing 5 changed files with 57 additions and 373 deletions.
47 changes: 43 additions & 4 deletions dance/atlas/sc_similarity/anndata_similarity.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,15 +18,32 @@
from scipy.spatial.distance import cdist, directed_hausdorff, jaccard, jensenshannon
from sklearn.metrics.pairwise import cosine_similarity, rbf_kernel

from dance.settings import METADIR

# Suppress scipy warnings for constant input in Pearson correlation
warnings.filterwarnings("ignore", message="An input array is constant")
from dance.datasets.singlemodality import CellTypeAnnotationDataset


def get_anndata(tissue: str = "Blood", species: str = "human", filetype: str = "h5ad", train_dataset=[],
test_dataset=[], valid_dataset=[], data_dir="../temp_data"):
if train_dataset == ['84230ea4-998d-4aa8-8456-81dd54ce23af']:
pass

def find_dataset_in_metadata(datasets, tissue):
    """Resolve dataset ids to the names embedded in ``scdeepsort.csv`` file names.

    For each id in ``datasets``, the first ``data_fname`` entry of the given
    ``tissue`` that contains the id is taken. The returned name is the text
    following the first occurrence of ``tissue`` in that file name (with a
    second ``tissue``-separated segment re-joined if present), cut at the
    first underscore. Ids with no matching file name are silently dropped.

    Parameters
    ----------
    datasets
        Iterable of dataset id strings to look up.
    tissue
        Tissue name used both to filter metadata rows and to split file names.

    Returns
    -------
    list
        Resolved dataset names, in the order of the ids that were found.

    """
    datasets = list(datasets)
    if not datasets:
        # Nothing to resolve; avoid reading the metadata file at all.
        return []

    # The metadata table does not depend on the dataset id, so load and filter
    # it once instead of once per id. File rows 1-67 (right after the header)
    # are skipped, exactly as before.
    metadata = pd.read_csv(METADIR / "scdeepsort.csv", header=0, skiprows=list(range(1, 68)))
    tissue_fnames = metadata[metadata["tissue"] == tissue]["data_fname"].tolist()

    datasets_in_metadata = []
    for dataset_id in datasets:
        for fname in tissue_fnames:
            if dataset_id not in fname:
                continue
            parts = fname.split(tissue)
            name = parts[1]
            if len(parts) >= 3:
                # The tissue string occurs again later in the file name: put
                # it back so the second segment is not lost by the split.
                name += tissue + parts[2]
            datasets_in_metadata.append(name.split('_')[0])
            break  # only the first matching file name is used per id
    return datasets_in_metadata

train_dataset = find_dataset_in_metadata(train_dataset, tissue)
valid_dataset = find_dataset_in_metadata(valid_dataset, tissue)
test_dataset = find_dataset_in_metadata(test_dataset, tissue)
data = CellTypeAnnotationDataset(train_dataset=train_dataset, test_dataset=test_dataset,
valid_dataset=valid_dataset, data_dir=data_dir, tissue=tissue, species=species,
filetype=filetype).load_data()
Expand Down Expand Up @@ -89,6 +106,8 @@ def filter_gene(self, n_top_genes=3000):
Number of top variable genes to select
"""
sc.pp.filter_genes(self.origin_adata1, min_counts=3)
sc.pp.filter_genes(self.origin_adata2, min_counts=3)
sc.pp.highly_variable_genes(self.origin_adata1, n_top_genes=n_top_genes, flavor='seurat_v3')
sc.pp.highly_variable_genes(self.origin_adata2, n_top_genes=n_top_genes, flavor='seurat_v3')

Expand Down Expand Up @@ -195,6 +214,20 @@ def jsd(p, q):
similarity_matrix = 1 - divergence_matrix
return np.nanmean(similarity_matrix)

def compute_mmd_alternative(self) -> float:
    """Return an MMD-based similarity between ``self.X`` and ``self.Y``.

    RBF kernels with ``gamma = 1.0`` are evaluated within and across the two
    datasets. The within-dataset terms exclude the kernel diagonal and are
    normalised by ``n * (n - 1)``; the cross term is twice the mean of the
    cross kernel. The resulting value is clipped at zero and mapped to a
    similarity via ``1 / (1 + sqrt(mmd))``.

    NOTE: the within-dataset denominators are zero when either dataset has
    a single row.

    Returns
    -------
    float
        Similarity score derived from the discrepancy.

    """
    gamma = 1.0
    first, second = self.X, self.Y

    kernel_xx = rbf_kernel(first, first, gamma)
    kernel_yy = rbf_kernel(second, second, gamma)
    kernel_xy = rbf_kernel(first, second, gamma)

    n_x = first.shape[0]
    n_y = second.shape[0]

    # Off-diagonal kernel mass, averaged over the n * (n - 1) ordered pairs.
    within_x = (kernel_xx.sum() - np.trace(kernel_xx)) / (n_x * (n_x - 1))
    within_y = (kernel_yy.sum() - np.trace(kernel_yy)) / (n_y * (n_y - 1))
    cross = 2 * kernel_xy.mean()

    mmd = within_x + within_y - cross
    # The estimate can be slightly negative; clip before taking the root.
    return 1 / (1 + np.sqrt(max(mmd, 0)))

def compute_mmd(self) -> float:
"""Compute Maximum Mean Discrepancy between datasets.
Expand Down Expand Up @@ -359,6 +392,12 @@ def get_dataset_info(data: ad.AnnData):
con_sim["n_measured_vars"] = np.mean(data.obs["n_measured_vars"])
con_sim["cell_num"] = len(data.obs)
con_sim["gene_num"] = len(data.var)
if "n_counts" not in data.obs.columns:
if scipy.sparse.issparse(data.X):
cell_counts = np.array(data.X.sum(axis=1)).flatten()
else:
cell_counts = data.X.sum(axis=1)
data.obs["n_counts"] = cell_counts
con_sim["n_counts_mean"] = np.mean(data.obs["n_counts"])
con_sim["n_counts_var"] = np.var(data.obs["n_counts"])
# if "n_counts" not in data.var.columns:
Expand Down Expand Up @@ -404,9 +443,9 @@ def get_targets(dataset_truth: str):
sim_targets = []
for method in self.methods:
query_dataset_truth = ground_truth_conf.loc[ground_truth_conf["dataset_id"] == self.adata1_name,
f"{method}_best_yaml"].iloc[0]
f"{method}_step2_best_yaml"].iloc[0]
atlas_dataset_truth = ground_truth_conf.loc[ground_truth_conf["dataset_id"] == self.adata2_name,
f"{method}_best_yaml"].iloc[0]
f"{method}_step2_best_yaml"].iloc[0]
if type(atlas_dataset_truth) == float and np.isnan(atlas_dataset_truth):
return 0
query_targets = get_targets(query_dataset_truth)
Expand Down
5 changes: 4 additions & 1 deletion dance/settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,10 @@ def change_log_level(name: str = "dance", /, *, level: Union[str, int]):
# Directory one level above the directory that contains this file.
DANCEDIR = Path(__file__).resolve().parents[1]
# The ``dance`` directory under DANCEDIR.
DANCEPKGDIR = DANCEDIR / "dance"
# Metadata directory inside DANCEPKGDIR.
METADIR = DANCEPKGDIR / "metadata"

# Atlas examples directory under DANCEDIR.
ATLASDIR = DANCEDIR / "examples/atlas"
# Similarity examples directory inside ATLASDIR.
SIMILARITYDIR = ATLASDIR / "sc_similarity_examples"
# NOTE(review): presumably the experiment-tracking account ("entity") and
# project names used by the atlas scripts; hard-coded here -- confirm usage.
entity = "xzy11632"
project = "dance-dev"
__all__ = [
"change_log_level",
]
2 changes: 1 addition & 1 deletion examples/atlas/get_result_web.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
from tqdm import tqdm

from dance import logger
from dance.settings import DANCEDIR, METADIR
from dance.settings import ATLASDIR, DANCEDIR, METADIR
from dance.utils import try_import

# get yaml of best method
Expand Down
Loading

0 comments on commit 8037b1c

Please sign in to comment.