Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add CBC score #1168

Merged
merged 13 commits into from
Mar 4, 2024
125 changes: 125 additions & 0 deletions src/cellrank/kernels/_base_kernel.py
Original file line number Diff line number Diff line change
Expand Up @@ -463,6 +463,131 @@ def _reuse_cache(self, expected_params: Dict[str, Any], *, time: Optional[Any] =
self._params = expected_params
# fmt: on

def _get_boundary(self, source: str, target: str, cluster_key: str, graph_key: str = "distances") -> List[int]:
"""Identify source observations at boundary to target cluster.

Parameters
----------
source
Name of source cluster.
target
Name of target cluster.
cluster_key
Key in :attr:`~anndata.AnnData.obs` to obtain cluster annotations.
graph_key
Name of graph representation to use from :attr:`~anndata.AnnData.obsp`.

Returns
-------
List of observation IDs at boundary to target cluster.
"""
source_obs_mask = self.adata.obs[cluster_key].isin([source] if isinstance(source, str) else source)
target_obs_mask = self.adata.obs[cluster_key].isin([target] if isinstance(target, str) else target)

source_ids = np.where(source_obs_mask)[0]
boundary_ids = []

graph = self.adata.obsp[graph_key]
for source_id in source_ids:
obs_mask = graph[source_id, :].toarray().squeeze().astype(bool)

if (obs_mask & target_obs_mask).any():
boundary_ids.append(source_id)

return boundary_ids

def _get_empirical_velocity_field(
self, boundary_ids: List[int], target_obs_mask, rep: str, graph_key: str = "distances"
) -> np.ndarray:
"""Compute an emprical estimate of velocity field between two clusters.

Parameters
----------
boundary_ids
List of observation IDs at boundary to target cluster.
target_obs_mask
Boolean indicator identifying relevant observations from target.
graph_key
Name of graph representation to use from :attr:`~anndata.AnnData.obsp`.

Returns
-------
Empirical velocity estimate.
"""
obs_ids = np.arange(0, self.adata.n_obs)
graph = self.adata.obsp[graph_key]
features = self.adata.obsm[rep]
empirical_velo = np.empty(shape=(len(boundary_ids), features.shape[1]))

for idx, boundary_id in enumerate(boundary_ids):
row = graph[boundary_id, :].toarray().squeeze()
obs_mask = row.astype(bool) & target_obs_mask
neighbors = obs_ids[obs_mask]
weights = row[obs_mask]

empirical_velo[idx, :] = np.sum(
weights.reshape(-1, 1) * (features[neighbors, :] - features[boundary_id, :]), axis=0
)

empirical_velo = np.array(empirical_velo)
WeilerP marked this conversation as resolved.
Show resolved Hide resolved
obs_mask = np.isnan(empirical_velo).any(axis=1)
empirical_velo = empirical_velo[~obs_mask, :]

return empirical_velo

def _get_vector_field_estimate(self, rep: str) -> np.ndarray:
michalk8 marked this conversation as resolved.
Show resolved Hide resolved
"""Compute estimate of vector field under one step of the transition matrix.

Parameters
----------
rep
Key in :attr:`~anndata.AnnData.obsm` to use as data representation.

Returns
-------
Vector field estimate based on kernel dynamics.
"""
extrapolated_gex = self.transition_matrix @ self.adata.obsm[rep]
return extrapolated_gex - self.adata.obsm[rep]

# TODO: Add definition/reference to paper
def cbc(self, source: str, target: str, cluster_key: str, rep: str, graph_key: str = "distances") -> np.ndarray:
"""Compute cross-boundary correctness score between source and target cluster.

WeilerP marked this conversation as resolved.
Show resolved Hide resolved
Parameters
----------
source
Name of the source cluster.
target
Name of the target cluster.
cluster_key
Key in :attr:`~anndata.AnnData.obs` to obtain cluster annotations.
rep
Key in :attr:`~anndata.AnnData.obsm` to use as data representation.
graph_key
Name of graph representation to use from :attr:`~anndata.AnnData.obsp`.

Returns
-------
Cross-boundary correctness score for each observation.
"""

def _pearsonr(x: np.ndarray, y: np.ndarray) -> np.ndarray:
x_centered = x - np.mean(x, axis=1, keepdims=True)
y_centered = y - np.mean(y, axis=1, keepdims=True)
denom = np.linalg.norm(x_centered, axis=1) * np.linalg.norm(y_centered, axis=1)

return np.sum(x_centered * y_centered, axis=1) / denom

target_obs_mask = self.adata.obs[cluster_key].isin([target])
boundary_ids = self._get_boundary(source=source, target=target, cluster_key=cluster_key, graph_key=graph_key)
empirical_velo = self._get_empirical_velocity_field(
boundary_ids=boundary_ids, target_obs_mask=target_obs_mask, rep=rep, graph_key=graph_key
)
estimated_velo = self._get_vector_field_estimate(rep=rep)[boundary_ids, :]

return _pearsonr(x=estimated_velo, y=empirical_velo)


@d.dedent
class Kernel(KernelExpression, abc.ABC):
Expand Down
Binary file modified tests/_ground_truth_adatas/adata_50.h5ad
Binary file not shown.
22 changes: 22 additions & 0 deletions tests/test_kernels.py
Original file line number Diff line number Diff line change
Expand Up @@ -583,6 +583,28 @@ def test_connectivities_key_kernel(self, adata: AnnData):
assert T_cr is not adata.obsp[key]
np.testing.assert_array_equal(T_cr.A, adata.obsp[key])

@pytest.mark.parametrize("cluster_pair", [("Granule immature", "Granule mature"), ("nIPC", "Neuroblast")])
@pytest.mark.parametrize("graph_key", ["distances", "connectivities"])
def test_cbc(self, adata: AnnData, cluster_pair: Tuple[str, str], graph_key: str):
cluster_key = "clusters"
rep = "X_pca"
source, target = cluster_pair

vk = cr.kernels.VelocityKernel(adata)
vk.compute_transition_matrix()

ck = cr.kernels.ConnectivityKernel(adata)
ck.compute_transition_matrix()
combined_kernel = 0.8 * vk + 0.2 * ck

cbc_vk = vk.cbc(source=source, target=target, cluster_key=cluster_key, rep=rep)
np.testing.assert_almost_equal(cbc_vk, adata.uns["cbc"][f"{source}-{target}-{graph_key}-vk"])

cbc_combined_kernel = combined_kernel.cbc(source=source, target=target, cluster_key=cluster_key, rep=rep)
np.testing.assert_almost_equal(
cbc_combined_kernel, adata.uns["cbc"][f"{source}-{target}-{graph_key}-0.8vk+0.2ck"]
)


class TestVelocityKernelReadData:
@pytest.mark.parametrize("attr", ["layers", "obsm"])
Expand Down
Loading