Skip to content

Commit

Permalink
Add KernelExpression::get_cbc
Browse files Browse the repository at this point in the history
Add class method for computing cross-boundary correctness score.
  • Loading branch information
WeilerP committed Feb 23, 2024
1 parent 25123b2 commit d677cc2
Showing 1 changed file with 31 additions and 0 deletions.
31 changes: 31 additions & 0 deletions src/cellrank/kernels/_base_kernel.py
Original file line number Diff line number Diff line change
Expand Up @@ -515,6 +515,37 @@ def get_vector_field_estimate(self, rep: str):
extrapolated_gex = self.transition_matrix @ self.adata.obsm[rep]
return extrapolated_gex - self.adata.obsm[rep]

def get_cbc(self, source: str, target: str, cluster_key: str, rep: str, graph: str = "distances"):
"""Compute cross-boundary correctness score between source and target cluster."""

def get_pearson_corr(x, y):
x_centered = x - np.mean(x, axis=1).reshape(-1, 1)
if y.ndim == 1:
y_centered = y - np.mean(y)
pearson_corr = (
np.dot(x_centered, y_centered) / np.linalg.norm(x_centered, axis=1) / np.linalg.norm(y_centered)
)
else:
y_centered = y - np.mean(y, axis=1).reshape(-1, 1)
pearson_corr = (
np.sum(x_centered * y_centered, axis=1)
/ np.linalg.norm(x_centered, axis=1)
/ np.linalg.norm(y_centered, axis=1)
)
return pearson_corr

target_obs_mask = self._get_mask(series=self.adata.obs[cluster_key], subset=target)
boundary_ids = self.get_boundary(source=source, target=target, cluster_key=cluster_key, graph=graph)
empirical_velo = self.get_empirical_velocity_field(
boundary_ids=boundary_ids, target_obs_mask=target_obs_mask, rep=rep, graph=graph
)
estimated_velo = self.get_vector_field_estimate(rep=rep)[boundary_ids, :]

cbc = get_pearson_corr(x=estimated_velo, y=empirical_velo)
if hasattr(self, cbc):
self.cbc[(source, target)] = cbc
return cbc


@d.dedent
class Kernel(KernelExpression, abc.ABC):
Expand Down

0 comments on commit d677cc2

Please sign in to comment.