Skip to content

Commit

Permalink
updated and added documentation for grasp and boss respectively
Browse files Browse the repository at this point in the history
  • Loading branch information
bja43 committed Nov 23, 2024
1 parent 5399515 commit d17f4c5
Show file tree
Hide file tree
Showing 4 changed files with 86 additions and 20 deletions.
16 changes: 8 additions & 8 deletions causallearn/search/PermutationBased/BOSS.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ def boss(
if n < p:
warnings.warn("The number of features is much larger than the sample size!")

if score_func == "local_score_CV_general":
if score_func == "local_score_CV_general":
# % k-fold negative cross validated likelihood based on regression in RKHS
if parameters is None:
parameters = {
Expand All @@ -63,13 +63,13 @@ def boss(
localScoreClass = LocalScoreClass(
data=X, local_score_fun=local_score_cv_general, parameters=parameters
)
elif score_func == "local_score_marginal_general":
elif score_func == "local_score_marginal_general":
# negative marginal likelihood based on regression in RKHS
parameters = {}
localScoreClass = LocalScoreClass(
data=X, local_score_fun=local_score_marginal_general, parameters=parameters
)
elif score_func == "local_score_CV_multi":
elif score_func == "local_score_CV_multi":
# k-fold negative cross validated likelihood based on regression in RKHS
# for data with multi-variate dimensions
if parameters is None:
Expand All @@ -83,7 +83,7 @@ def boss(
localScoreClass = LocalScoreClass(
data=X, local_score_fun=local_score_cv_multi, parameters=parameters
)
elif score_func == "local_score_marginal_multi":
elif score_func == "local_score_marginal_multi":
# negative marginal likelihood based on regression in RKHS
# for data with multi-variate dimensions
if parameters is None:
Expand All @@ -93,22 +93,22 @@ def boss(
localScoreClass = LocalScoreClass(
data=X, local_score_fun=local_score_marginal_multi, parameters=parameters
)
elif score_func == "local_score_BIC":
elif score_func == "local_score_BIC":
# SEM BIC score
warnings.warn("Please use 'local_score_BIC_from_cov' instead")
if parameters is None:
parameters = {"lambda_value": 2}
localScoreClass = LocalScoreClass(
data=X, local_score_fun=local_score_BIC, parameters=parameters
)
elif score_func == "local_score_BIC_from_cov":
elif score_func == "local_score_BIC_from_cov":
# SEM BIC score
if parameters is None:
parameters = {"lambda_value": 2}
localScoreClass = LocalScoreClass(
data=X, local_score_fun=local_score_BIC_from_cov, parameters=parameters
)
elif score_func == "local_score_BDeu":
elif score_func == "local_score_BDeu":
# BDeu score
localScoreClass = LocalScoreClass(
data=X, local_score_fun=local_score_BDeu, parameters=None
Expand Down Expand Up @@ -204,4 +204,4 @@ def better_mutation(v, order, gsts):
order.remove(v)
order.insert(best - int(best > i), v)

return True
return True
16 changes: 8 additions & 8 deletions causallearn/search/PermutationBased/GRaSP.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
local_score_marginal_general,
local_score_marginal_multi,
)
from causallearn.search.PermutationBased.gst import GST;
from causallearn.search.PermutationBased.gst import GST
from causallearn.score.LocalScoreFunctionClass import LocalScoreClass
from causallearn.utils.DAG2CPDAG import dag2cpdag

Expand Down Expand Up @@ -111,7 +111,7 @@ def grasp(
if n < p:
warnings.warn("The number of features is much larger than the sample size!")

if score_func == "local_score_CV_general":
if score_func == "local_score_CV_general":
# k-fold negative cross validated likelihood based on regression in RKHS
if parameters is None:
parameters = {
Expand All @@ -127,7 +127,7 @@ def grasp(
localScoreClass = LocalScoreClass(
data=X, local_score_fun=local_score_marginal_general, parameters=parameters
)
elif score_func == "local_score_CV_multi":
elif score_func == "local_score_CV_multi":
# k-fold negative cross validated likelihood based on regression in RKHS
# for data with multi-variate dimensions
if parameters is None:
Expand All @@ -141,7 +141,7 @@ def grasp(
localScoreClass = LocalScoreClass(
data=X, local_score_fun=local_score_cv_multi, parameters=parameters
)
elif score_func == "local_score_marginal_multi":
elif score_func == "local_score_marginal_multi":
# negative marginal likelihood based on regression in RKHS
# for data with multi-variate dimensions
if parameters is None:
Expand All @@ -151,22 +151,22 @@ def grasp(
localScoreClass = LocalScoreClass(
data=X, local_score_fun=local_score_marginal_multi, parameters=parameters
)
elif score_func == "local_score_BIC":
elif score_func == "local_score_BIC":
# SEM BIC score
warnings.warn("Please use 'local_score_BIC_from_cov' instead")
if parameters is None:
parameters = {"lambda_value": 2}
localScoreClass = LocalScoreClass(
data=X, local_score_fun=local_score_BIC, parameters=parameters
)
elif score_func == "local_score_BIC_from_cov":
elif score_func == "local_score_BIC_from_cov":
# SEM BIC score
if parameters is None:
parameters = {"lambda_value": 2}
localScoreClass = LocalScoreClass(
data=X, local_score_fun=local_score_BIC_from_cov, parameters=parameters
)
elif score_func == "local_score_BDeu":
elif score_func == "local_score_BDeu":
# BDeu score
localScoreClass = LocalScoreClass(
data=X, local_score_fun=local_score_BDeu, parameters=None
Expand Down Expand Up @@ -204,7 +204,7 @@ def grasp(
sys.stdout.flush()

runtime = time.perf_counter() - runtime

if verbose:
sys.stdout.write("\nGRaSP completed in: %.2fs \n" % runtime)
sys.stdout.flush()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ GRaSP
Algorithm Introduction
--------------------------------------

Greedy relaxation of the sparsest permutation (GRaSP) algorithm [1]_.
Greedy relaxations of the sparsest permutation (GRaSP) algorithm [1]_.


Usage
Expand All @@ -19,7 +19,7 @@ Usage
G = grasp(X)
# or customized parameters
G = grasp(X, score_func, depth, maxP, parameters)
G = grasp(X, score_func, depth, parameters)
# Visualization using pydot
from causallearn.utils.GraphUtils import GraphUtils
Expand Down Expand Up @@ -50,8 +50,6 @@ and n_features is the number of features.
- ":ref:`local_score_CV_multi <Generalized score with cross validation>`": Generalized score with cross validation for data with multi-dimensional variables [2]_.
- ":ref:`local_score_marginal_multi <Generalized score with marginal likelihood>`": Generalized score with marginal likelihood for data with multi-dimensional variables [2]_.

**maxP**: Allowed maximum number of parents when searching the graph. Default: None.

**parameters**: Needed when using CV likelihood. Default: None.
- parameters['kfold']: k-fold cross validation.
- parameters['lambda']: regularization parameter.
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
.. _GRaSP:

GRaSP
==============================================

Algorithm Introduction
--------------------------------------

Best order score search (BOSS) algorithm [1]_.


Usage
----------------------------
.. code-block:: python
from causallearn.search.PermutationBased.BOSS import boss
# default parameters
G = boss(X)
# or customized parameters
G = boss(X, score_func, parameters)
# Visualization using pydot
from causallearn.utils.GraphUtils import GraphUtils
import matplotlib.image as mpimg
import matplotlib.pyplot as plt
import io
pyd = GraphUtils.to_pydot(G)
tmp_png = pyd.create_png(f="png")
fp = io.BytesIO(tmp_png)
img = mpimg.imread(fp, format='png')
plt.axis('off')
plt.imshow(img)
plt.show()
Visualization using pydot is recommended (`usage example <https://github.com/cmu-phil/causal-learn/blob/main/tests/TestBOSS.py>`_). If specific label names are needed, please refer to this `usage example <https://github.com/cmu-phil/causal-learn/blob/e4e73f8b58510a3cd5a9125ba50c0ac62a425ef3/tests/TestGraphVisualization.py#L106>`_ (e.g., GraphUtils.to_pydot(G, labels=["A", "B", "C"]).

Parameters
-------------------
**X**: numpy.ndarray, shape (n_samples, n_features). Data, where n_samples is the number of samples
and n_features is the number of features.

**score_func**: The score function you would like to use, including (see :ref:`score_functions`.). Default: 'local_score_BIC'.
- ":ref:`local_score_BIC <BIC score>`": BIC score [3]_.
- ":ref:`local_score_BDeu <BDeu score>`": BDeu score [4]_.
- ":ref:`local_score_CV_general <Generalized score with cross validation>`": Generalized score with cross validation for data with single-dimensional variables [2]_.
- ":ref:`local_score_marginal_general <Generalized score with marginal likelihood>`": Generalized score with marginal likelihood for data with single-dimensional variables [2]_.
- ":ref:`local_score_CV_multi <Generalized score with cross validation>`": Generalized score with cross validation for data with multi-dimensional variables [2]_.
- ":ref:`local_score_marginal_multi <Generalized score with marginal likelihood>`": Generalized score with marginal likelihood for data with multi-dimensional variables [2]_.

**parameters**: Needed when using CV likelihood. Default: None.
- parameters['kfold']: k-fold cross validation.
- parameters['lambda']: regularization parameter.
- parameters['dlabel']: for variables with multi-dimensions, indicate which dimensions belong to the i-th variable.



Returns
-------------------
- **G**: learned general graph, where G.graph[j,i]=1 and G.graph[i,j]=-1 indicate i --> j; G.graph[i,j] = G.graph[j,i] = -1 indicates i --- j.


.. [1] Andrews, B., Ramsey, J., Sanchez Romero, R., Camchong, J., & Kummerfeld, E. (2023). Fast scalable and accurate discovery of dags using the best order score search and grow shrink trees. Advances in Neural Information Processing Systems, 36, 63945-63956.
.. [2] Huang, B., Zhang, K., Lin, Y., Schölkopf, B., & Glymour, C. (2018, July). Generalized score functions for causal discovery. In Proceedings of the 24th ACM SIGKDD International Conference on Knowledge Discovery & Data Mining (pp. 1551-1560).
.. [3] Schwarz, G. (1978). Estimating the dimension of a model. The annals of statistics, 461-464.
.. [4] Buntine, W. (1991). Theory refinement on Bayesian networks. In Uncertainty proceedings 1991 (pp. 52-60). Morgan Kaufmann.

0 comments on commit d17f4c5

Please sign in to comment.