From 5a18297d2bdfee021df7b8d6eb0658777a0766e3 Mon Sep 17 00:00:00 2001 From: yinghuan Date: Wed, 13 Nov 2024 17:51:28 -0500 Subject: [PATCH] Graph operations compatible with np array Signed-off-by: Endre Moen --- causallearn/graph/Dag.py | 21 ++++++++++++++++++--- causallearn/graph/Node.py | 19 +++++++++++++------ causallearn/utils/DAG2CPDAG.py | 11 +++++++++-- 3 files changed, 40 insertions(+), 11 deletions(-) diff --git a/causallearn/graph/Dag.py b/causallearn/graph/Dag.py index 1f549ff6..7df8cc32 100644 --- a/causallearn/graph/Dag.py +++ b/causallearn/graph/Dag.py @@ -1,6 +1,6 @@ #!/usr/bin/env python3 from itertools import combinations -from typing import List +from typing import List, Optional, Union import networkx as nx import numpy as np @@ -18,8 +18,23 @@ # or latent, with at most one edge per node pair, and no edges to self. class Dag(GeneralGraph): - def __init__(self, nodes: List[Node]): - + def __init__(self, nodes: Optional[List[Node]]=None, graph: Union[np.ndarray, nx.Graph, None]=None): + if nodes is not None: + self._init_from_nodes(nodes) + elif graph is not None: + if isinstance(graph, np.ndarray): + nodes = [Node(node_name=str(i)) for i in range(len(graph))] + self._init_from_nodes(nodes) + for i in range(len(nodes)): + for j in range(len(nodes)): + if graph[i, j] == 1: + self.add_directed_edge(nodes[i], nodes[j]) + else: + pass + else: + raise ValueError("Dag.__init__() requires argument 'nodes' or 'graph'") + + def _init_from_nodes(self, nodes: List[Node]): # for node in nodes: # if not isinstance(node, type(GraphNode)): # raise TypeError("Graphs must be instantiated with a list of GraphNodes") diff --git a/causallearn/graph/Node.py b/causallearn/graph/Node.py index 3b9b3e2a..798a616f 100644 --- a/causallearn/graph/Node.py +++ b/causallearn/graph/Node.py @@ -2,27 +2,34 @@ # Represents an object with a name, node type, and position that can serve as a # node in a graph. +from typing import Optional from causallearn.graph.NodeType import NodeType from causallearn.graph.NodeVariableType import NodeVariableType class Node: + node_type: NodeType + node_name: str + def __init__(self, node_name: Optional[str] = None, node_type: Optional[NodeType] = None) -> None: + self.node_name = node_name + self.node_type = node_type + # @return the name of the variable. def get_name(self) -> str: - pass + return self.node_name # set the name of the variable def set_name(self, name: str): - pass + self.node_name = name # @return the node type of the variable def get_node_type(self) -> NodeType: - pass + return self.node_type # set the node type of the variable def set_node_type(self, node_type: NodeType): - pass + self.node_type = node_type # @return the intervention type def get_node_variable_type(self) -> NodeVariableType: @@ -35,7 +42,7 @@ def set_node_variable_type(self, var_type: NodeVariableType): # @return the name of the node as its string representation def __str__(self): - pass + return self.node_name # @return the x coordinate of the center of the node def get_center_x(self) -> int: @@ -59,7 +66,7 @@ def set_center(self, center_x: int, center_y: int): # @return a hashcode for this variable def __hash__(self): - pass + return hash(self.node_name) # @return true iff this variable is equal to the given variable def __eq__(self, other): diff --git a/causallearn/utils/DAG2CPDAG.py b/causallearn/utils/DAG2CPDAG.py index 2604b1f7..01d2f2ee 100644 --- a/causallearn/utils/DAG2CPDAG.py +++ b/causallearn/utils/DAG2CPDAG.py @@ -1,3 +1,4 @@ +from typing import Union import numpy as np from causallearn.graph.Dag import Dag @@ -6,7 +7,7 @@ from causallearn.graph.GeneralGraph import GeneralGraph -def dag2cpdag(G: Dag) -> GeneralGraph: +def dag2cpdag(G: Union[Dag, np.ndarray]) -> GeneralGraph: """ Convert a DAG to its corresponding PDAG @@ -22,7 +23,13 @@ def dag2cpdag(G: Dag) -> GeneralGraph: ------- Yuequn Liu@dmirlab, Wei Chen@dmirlab, Kun Zhang@CMU """ - + + if isinstance(G, np.ndarray): + # convert np array to Dag graph + G = Dag(graph=G) + elif not isinstance(G, Dag): + raise TypeError("parameter graph should be `Dag` or `np.ndarry`") + # order the edges in G nodes_order = list( map(lambda x: G.node_map[x], G.get_causal_ordering())