Skip to content

Commit

Permalink
Merge pull request #206 from sherryzyh/main
Browse files Browse the repository at this point in the history
Graph operations compatible with np array
  • Loading branch information
kunwuz authored Dec 20, 2024
2 parents 08655cd + 705bcec commit 8c72b52
Show file tree
Hide file tree
Showing 3 changed files with 40 additions and 11 deletions.
21 changes: 18 additions & 3 deletions causallearn/graph/Dag.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
#!/usr/bin/env python3
from itertools import combinations
from typing import List
from typing import List, Optional, Union

import networkx as nx
import numpy as np
Expand All @@ -18,8 +18,23 @@
# or latent, with at most one edge per node pair, and no edges to self.
class Dag(GeneralGraph):

def __init__(self, nodes: List[Node]):

def __init__(self, nodes: Optional[List[Node]]=None, graph: Union[np.ndarray, nx.Graph, None]=None):
if nodes is not None:
self._init_from_nodes(nodes)
elif graph is not None:
if isinstance(graph, np.ndarray):
nodes = [Node(node_name=str(i)) for i in range(len(graph))]
self._init_from_nodes(nodes)
for i in range(len(nodes)):
for j in range(len(nodes)):
if graph[i, j] == 1:
self.add_directed_edge(nodes[i], nodes[j])
else:
pass
else:
raise ValueError("Dag.__init__() requires argument 'nodes' or 'graph'")

def _init_from_nodes(self, nodes: List[Node]):
# for node in nodes:
# if not isinstance(node, type(GraphNode)):
# raise TypeError("Graphs must be instantiated with a list of GraphNodes")
Expand Down
19 changes: 13 additions & 6 deletions causallearn/graph/Node.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,27 +2,34 @@

# Represents an object with a name, node type, and position that can serve as a
# node in a graph.
from typing import Optional
from causallearn.graph.NodeType import NodeType
from causallearn.graph.NodeVariableType import NodeVariableType


class Node:
node_type: NodeType
node_name: str

def __init__(self, node_name: Optional[str] = None, node_type: Optional[NodeType] = None) -> None:
self.node_name = node_name
self.node_type = node_type

# @return the name of the variable.
def get_name(self) -> str:
pass
return self.node_name

# set the name of the variable
def set_name(self, name: str):
pass
self.node_name = name

# @return the node type of the variable
def get_node_type(self) -> NodeType:
pass
return self.node_type

# set the node type of the variable
def set_node_type(self, node_type: NodeType):
pass
self.node_type = node_type

# @return the intervention type
def get_node_variable_type(self) -> NodeVariableType:
Expand All @@ -35,7 +42,7 @@ def set_node_variable_type(self, var_type: NodeVariableType):

# @return the name of the node as its string representation
def __str__(self):
pass
return self.node_name

# @return the x coordinate of the center of the node
def get_center_x(self) -> int:
Expand All @@ -59,7 +66,7 @@ def set_center(self, center_x: int, center_y: int):

# @return a hashcode for this variable
def __hash__(self):
pass
return hash(self.node_name)

# @return true iff this variable is equal to the given variable
def __eq__(self, other):
Expand Down
11 changes: 9 additions & 2 deletions causallearn/utils/DAG2CPDAG.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from typing import Union
import numpy as np

from causallearn.graph.Dag import Dag
Expand All @@ -6,7 +7,7 @@
from causallearn.graph.GeneralGraph import GeneralGraph


def dag2cpdag(G: Dag) -> GeneralGraph:
def dag2cpdag(G: Union[Dag, np.ndarray]) -> GeneralGraph:
"""
Convert a DAG to its corresponding PDAG
Expand All @@ -22,7 +23,13 @@ def dag2cpdag(G: Dag) -> GeneralGraph:
-------
Yuequn Liu@dmirlab, Wei Chen@dmirlab, Kun Zhang@CMU
"""


if isinstance(G, np.ndarray):
# convert np array to Dag graph
G = Dag(graph=G)
elif not isinstance(G, Dag):
raise TypeError("parameter graph should be `Dag` or `np.ndarry`")

# order the edges in G
nodes_order = list(
map(lambda x: G.node_map[x], G.get_causal_ordering())
Expand Down

0 comments on commit 8c72b52

Please sign in to comment.