Skip to content

Commit

Permalink
Add extensions and tests
Browse files Browse the repository at this point in the history
  • Loading branch information
pablormier committed Jul 1, 2024
1 parent adcdc41 commit f977df0
Show file tree
Hide file tree
Showing 8 changed files with 203 additions and 6 deletions.
14 changes: 8 additions & 6 deletions corneto/_graph.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import abc
import pickle
from collections import OrderedDict
from copy import deepcopy
from enum import Enum
Expand All @@ -23,7 +24,7 @@

from corneto._io import import_cobra_model
from corneto._types import CobraModel, Edge, NxDiGraph, NxGraph
from corneto._util import unique_iter
from corneto._util import obj_content_hash, unique_iter
from corneto.utils import Attr, Attributes

T = TypeVar("T")
Expand Down Expand Up @@ -175,6 +176,9 @@ def num_vertices(self) -> int:
def num_edges(self) -> int:
return self._num_edges()

def hash(self) -> str:
return obj_content_hash(self)

def get_attr_edge(self, index: int) -> Attributes:
return self._get_edge_attributes(index)

Expand Down Expand Up @@ -357,13 +361,15 @@ def add_edge(
self,
source: Union[Any, Iterable[Any]],
target: Union[Any, Iterable[Any]],
type: EdgeType = EdgeType.DIRECTED,
type: Optional[EdgeType] = EdgeType.DIRECTED,
edge_source_attr: Optional[Attributes] = None,
edge_target_attr: Optional[Attributes] = None,
**kwargs,
) -> int:
# In self loops, or hyperedges with partial self loops
# important to have dupl. vertices!. E.g {A: -1, A: 1} A (-1)->(1) A
if type is None:
type = self._default_edge_type
ve_s = Graph._extract_ve_attr(source) # {vertex: value}
ve_t = Graph._extract_ve_attr(target)
if edge_source_attr is None:
Expand Down Expand Up @@ -582,8 +588,6 @@ def from_vertex_incidence(
return g

def save(self, filename: str, compressed: Optional[bool] = True) -> None:
import pickle

if not filename:
raise ValueError("Filename must not be empty.")

Expand All @@ -603,8 +607,6 @@ def save(self, filename: str, compressed: Optional[bool] = True) -> None:

@staticmethod
def load(filename: str) -> "BaseGraph":
import pickle

if filename.endswith(".gz"):
import gzip

Expand Down
9 changes: 9 additions & 0 deletions corneto/_util.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
import hashlib
import pickle
from collections import OrderedDict
from itertools import filterfalse
from typing import Any, Callable, Dict, Iterable, Optional, Set, TypeVar
Expand All @@ -8,6 +10,13 @@
T = TypeVar("T")


def obj_content_hash(obj) -> str:
obj_serialized = pickle.dumps(obj)
hash_obj = hashlib.sha256()
hash_obj.update(obj_serialized)
return hash_obj.hexdigest()


def unique_iter(
iterable: Iterable[T], key: Optional[Callable[[T], Any]] = None
) -> Iterable[T]:
Expand Down
3 changes: 3 additions & 0 deletions corneto/extensions/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
from corneto.extensions._numba import OptionalNumba

numba = OptionalNumba()
16 changes: 16 additions & 0 deletions corneto/extensions/_numba.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
from corneto.extensions._optional import OptionalModule


class OptionalNumba(OptionalModule):
def __init__(self):
super().__init__("numba")

def _create_dummy(self, name): # type: ignore
if name in ["uint16", "uint32", "uint64", "int16", "int32", "int64"]:
return int
if name in ["float32", "float64", "complex64", "complex128"]:
return float
if name == "prange":
return range

return super()._create_dummy(name)
33 changes: 33 additions & 0 deletions corneto/extensions/_optional.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
import importlib
from functools import wraps

from corneto._settings import LOGGER


class OptionalModule:
def __init__(self, module_name):
self.module_name = module_name
self.module = None
try:
self.module = importlib.import_module(module_name)
except ImportError:
LOGGER.debug(f"Optional module {module_name} not found.")

def __getattr__(self, item):
if self.module:
return getattr(self.module, item)
else:
return self._create_dummy(item)

def _create_dummy(self, name):
def _dummy(*args, **kwargs):
def _decorator(func):
@wraps(func)
def _wrapped_func(*_args, **_kwargs):
return func(*_args, **_kwargs)

return _wrapped_func

return _decorator

return _dummy
56 changes: 56 additions & 0 deletions corneto/utils/__init__.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,66 @@
import importlib
import re
from functools import wraps
from pathlib import Path

from corneto.utils._attr import Attr, Attributes


class OptionalModule:
def __init__(self, module_name):
self.module_name = module_name
self.module = None
try:
self.module = importlib.import_module(module_name)
except ImportError:
pass

def __getattr__(self, item):
if self.module:
return getattr(self.module, item)
else:
return self._create_dummy(item)

def _create_dummy(self, name):
def dummy(*args, **kwargs):
return None

return dummy


class OptionalNumba(OptionalModule):
def __init__(self):
super().__init__("numba")

def _create_dummy(self, name):
if name == "jit":

def _jit(*_args, **_kwargs):
def _dummy_jit(func):
@wraps(func)
def _wrapped_func(*args, **kwargs):
return func(*args, **kwargs)

return _wrapped_func

return _dummy_jit

return _jit

if name in ["uint16", "uint32", "uint64", "int16", "int32", "int64"]:
return int
if name in ["float32", "float64", "complex64", "complex128"]:
return float
if name == "prange":
return range

return super()._create_dummy(name)


# Create an instance for numba
numba = OptionalModule("numba")


def get_library_version(lib_name):
pyproject_path = Path(__file__).resolve().parent.parent.parent / "pyproject.toml"
try:
Expand Down
53 changes: 53 additions & 0 deletions tests/test_extensions.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
import numpy as np

from corneto.extensions import numba


@numba.jit(nopython=True, fastmath=True, parallel=True)
def _loop(x):
r = np.empty_like(x)
n = len(x)
for i in range(n):
r[i] = np.cos(x[i]) ** 2 + np.sin(x[i]) ** 2
return r


@numba.vectorize(["float64(float64, float64)"], target="parallel")
def vec_sum(x, y):
return x + y


@numba.guvectorize(["void(float64[:], intp[:], float64[:])"], "(n),()->(n)")
def move_mean(a, window_arr, out):
window_width = window_arr[0]
asum = 0.0
count = 0
for i in range(window_width):
asum += a[i]
count += 1
out[i] = asum / count
for i in range(window_width, len(a)):
asum += a[i] - a[i - window_width]
out[i] = asum / count


def test_guvectorize_numba():
arr = np.arange(20, dtype=np.float64).reshape(2, 10)
result = move_mean(arr, 3)
expected = np.array(
[
[0.0, 0.5, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0],
[10.0, 10.5, 11.0, 12.0, 13.0, 14.0, 15.0, 16.0, 17.0, 18.0],
]
)
np.testing.assert_allclose(result, expected)


def test_jit_loop_numba():
_loop(np.ones(10000))
assert True


def test_vec_sum_numba():
vec_sum(np.ones(10000), np.ones(10000))
assert True
25 changes: 25 additions & 0 deletions tests/test_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -364,3 +364,28 @@ def test_prune_directed():
type=EdgeType.DIRECTED,
)
assert set(G.prune(["E"], ["K"]).V) == {"E", "F", "H", "K"}


def test_graph_hash():
G = Graph()
G.add_edges(
[
("A", "B"),
("A", "C"),
("A", "D"),
("D", "C"),
("D", "E"),
("B", "E"),
("E", "F"),
("A", "F"),
],
type=EdgeType.DIRECTED,
)
h1 = G.hash()
G.add_edge("F", "G")
h2 = G.hash()
G._edge_attr[0]["x"] = ""
h3 = G.hash()
assert h1 != h2
assert h1 != h3
assert h2 != h3

0 comments on commit f977df0

Please sign in to comment.