From f977df015e71de3fa4adb8ecf26d577acfa018e8 Mon Sep 17 00:00:00 2001
From: "Pablo R. Mier"
Date: Mon, 1 Jul 2024 11:42:44 +0200
Subject: [PATCH] Add extensions and tests

Add the corneto.extensions package for optional dependencies.
OptionalModule imports a module by name; when the import fails, every
attribute access returns a pass-through decorator factory instead of
raising. OptionalNumba additionally maps numba's integer type names to
int, its float/complex type names to float, and prange to range.

Also:
- add obj_content_hash() (SHA-256 hex digest of the pickled object) and
  a graph hash() method that returns it;
- let add_edge() accept type=None and fall back to
  self._default_edge_type;
- move the function-local "import pickle" of save()/load() to module
  level in corneto/_graph.py;
- add OptionalModule/OptionalNumba helpers and a numba instance to
  corneto.utils;
- add tests for the numba wrapper and for graph hashing.

---
Review note: corneto.utils.numba is created with OptionalNumba() rather
than OptionalModule("numba"). With the base class the OptionalNumba
fallbacks defined right above were never used, and without numba
installed numba.jit(...) returned None, which fails when applied as a
decorator.

 corneto/_graph.py               | 14 +++++----
 corneto/_util.py                |  9 ++++++
 corneto/extensions/__init__.py  |  3 ++
 corneto/extensions/_numba.py    | 16 ++++++++++
 corneto/extensions/_optional.py | 33 +++++++++++++++++++
 corneto/utils/__init__.py       | 56 +++++++++++++++++++++++++++++++++
 tests/test_extensions.py        | 53 +++++++++++++++++++++++++++++++
 tests/test_graph.py             | 25 +++++++++++++++
 8 files changed, 203 insertions(+), 6 deletions(-)
 create mode 100644 corneto/extensions/__init__.py
 create mode 100644 corneto/extensions/_numba.py
 create mode 100644 corneto/extensions/_optional.py
 create mode 100644 tests/test_extensions.py

diff --git a/corneto/_graph.py b/corneto/_graph.py
index 7c2ce84c..9d0c5cc1 100644
--- a/corneto/_graph.py
+++ b/corneto/_graph.py
@@ -1,4 +1,5 @@
 import abc
+import pickle
 from collections import OrderedDict
 from copy import deepcopy
 from enum import Enum
@@ -23,7 +24,7 @@
 
 from corneto._io import import_cobra_model
 from corneto._types import CobraModel, Edge, NxDiGraph, NxGraph
-from corneto._util import unique_iter
+from corneto._util import obj_content_hash, unique_iter
 from corneto.utils import Attr, Attributes
 
 T = TypeVar("T")
@@ -175,6 +176,9 @@ def num_vertices(self) -> int:
     def num_edges(self) -> int:
         return self._num_edges()
 
+    def hash(self) -> str:
+        return obj_content_hash(self)
+
     def get_attr_edge(self, index: int) -> Attributes:
         return self._get_edge_attributes(index)
 
@@ -357,13 +361,15 @@ def add_edge(
         self,
         source: Union[Any, Iterable[Any]],
         target: Union[Any, Iterable[Any]],
-        type: EdgeType = EdgeType.DIRECTED,
+        type: Optional[EdgeType] = EdgeType.DIRECTED,
         edge_source_attr: Optional[Attributes] = None,
         edge_target_attr: Optional[Attributes] = None,
         **kwargs,
     ) -> int:
         # In self loops, or hyperedges with partial self loops
         # important to have dupl. vertices!. E.g {A: -1, A: 1} A (-1)->(1) A
+        if type is None:
+            type = self._default_edge_type
         ve_s = Graph._extract_ve_attr(source)  # {vertex: value}
         ve_t = Graph._extract_ve_attr(target)
         if edge_source_attr is None:
@@ -582,8 +588,6 @@ def from_vertex_incidence(
         return g
 
     def save(self, filename: str, compressed: Optional[bool] = True) -> None:
-        import pickle
-
         if not filename:
             raise ValueError("Filename must not be empty.")
 
@@ -603,8 +607,6 @@ def save(self, filename: str, compressed: Optional[bool] = True) -> None:
 
     @staticmethod
     def load(filename: str) -> "BaseGraph":
-        import pickle
-
         if filename.endswith(".gz"):
             import gzip
 
diff --git a/corneto/_util.py b/corneto/_util.py
index 8af34fba..41a50278 100644
--- a/corneto/_util.py
+++ b/corneto/_util.py
@@ -1,3 +1,5 @@
+import hashlib
+import pickle
 from collections import OrderedDict
 from itertools import filterfalse
 from typing import Any, Callable, Dict, Iterable, Optional, Set, TypeVar
@@ -8,6 +10,13 @@
 T = TypeVar("T")
 
 
+def obj_content_hash(obj) -> str:
+    obj_serialized = pickle.dumps(obj)
+    hash_obj = hashlib.sha256()
+    hash_obj.update(obj_serialized)
+    return hash_obj.hexdigest()
+
+
 def unique_iter(
     iterable: Iterable[T], key: Optional[Callable[[T], Any]] = None
 ) -> Iterable[T]:
diff --git a/corneto/extensions/__init__.py b/corneto/extensions/__init__.py
new file mode 100644
index 00000000..285af1ba
--- /dev/null
+++ b/corneto/extensions/__init__.py
@@ -0,0 +1,3 @@
+from corneto.extensions._numba import OptionalNumba
+
+numba = OptionalNumba()
diff --git a/corneto/extensions/_numba.py b/corneto/extensions/_numba.py
new file mode 100644
index 00000000..fa043541
--- /dev/null
+++ b/corneto/extensions/_numba.py
@@ -0,0 +1,16 @@
+from corneto.extensions._optional import OptionalModule
+
+
+class OptionalNumba(OptionalModule):
+    def __init__(self):
+        super().__init__("numba")
+
+    def _create_dummy(self, name):  # type: ignore
+        if name in ["uint16", "uint32", "uint64", "int16", "int32", "int64"]:
+            return int
+        if name in ["float32", "float64", "complex64", "complex128"]:
+            return float
+        if name == "prange":
+            return range
+
+        return super()._create_dummy(name)
diff --git a/corneto/extensions/_optional.py b/corneto/extensions/_optional.py
new file mode 100644
index 00000000..04d73991
--- /dev/null
+++ b/corneto/extensions/_optional.py
@@ -0,0 +1,33 @@
+import importlib
+from functools import wraps
+
+from corneto._settings import LOGGER
+
+
+class OptionalModule:
+    def __init__(self, module_name):
+        self.module_name = module_name
+        self.module = None
+        try:
+            self.module = importlib.import_module(module_name)
+        except ImportError:
+            LOGGER.debug(f"Optional module {module_name} not found.")
+
+    def __getattr__(self, item):
+        if self.module:
+            return getattr(self.module, item)
+        else:
+            return self._create_dummy(item)
+
+    def _create_dummy(self, name):
+        def _dummy(*args, **kwargs):
+            def _decorator(func):
+                @wraps(func)
+                def _wrapped_func(*_args, **_kwargs):
+                    return func(*_args, **_kwargs)
+
+                return _wrapped_func
+
+            return _decorator
+
+        return _dummy
diff --git a/corneto/utils/__init__.py b/corneto/utils/__init__.py
index 1dc92fb9..ee5efdbf 100644
--- a/corneto/utils/__init__.py
+++ b/corneto/utils/__init__.py
@@ -1,10 +1,66 @@
 import importlib
 import re
+from functools import wraps
 from pathlib import Path
 
 from corneto.utils._attr import Attr, Attributes
 
 
+class OptionalModule:
+    def __init__(self, module_name):
+        self.module_name = module_name
+        self.module = None
+        try:
+            self.module = importlib.import_module(module_name)
+        except ImportError:
+            pass
+
+    def __getattr__(self, item):
+        if self.module:
+            return getattr(self.module, item)
+        else:
+            return self._create_dummy(item)
+
+    def _create_dummy(self, name):
+        def dummy(*args, **kwargs):
+            return None
+
+        return dummy
+
+
+class OptionalNumba(OptionalModule):
+    def __init__(self):
+        super().__init__("numba")
+
+    def _create_dummy(self, name):
+        if name == "jit":
+
+            def _jit(*_args, **_kwargs):
+                def _dummy_jit(func):
+                    @wraps(func)
+                    def _wrapped_func(*args, **kwargs):
+                        return func(*args, **kwargs)
+
+                    return _wrapped_func
+
+                return _dummy_jit
+
+            return _jit
+
+        if name in ["uint16", "uint32", "uint64", "int16", "int32", "int64"]:
+            return int
+        if name in ["float32", "float64", "complex64", "complex128"]:
+            return float
+        if name == "prange":
+            return range
+
+        return super()._create_dummy(name)
+
+
+# Create an instance for numba
+numba = OptionalNumba()
+
+
 def get_library_version(lib_name):
     pyproject_path = Path(__file__).resolve().parent.parent.parent / "pyproject.toml"
     try:
diff --git a/tests/test_extensions.py b/tests/test_extensions.py
new file mode 100644
index 00000000..a63bc3af
--- /dev/null
+++ b/tests/test_extensions.py
@@ -0,0 +1,53 @@
+import numpy as np
+
+from corneto.extensions import numba
+
+
+@numba.jit(nopython=True, fastmath=True, parallel=True)
+def _loop(x):
+    r = np.empty_like(x)
+    n = len(x)
+    for i in range(n):
+        r[i] = np.cos(x[i]) ** 2 + np.sin(x[i]) ** 2
+    return r
+
+
+@numba.vectorize(["float64(float64, float64)"], target="parallel")
+def vec_sum(x, y):
+    return x + y
+
+
+@numba.guvectorize(["void(float64[:], intp[:], float64[:])"], "(n),()->(n)")
+def move_mean(a, window_arr, out):
+    window_width = window_arr[0]
+    asum = 0.0
+    count = 0
+    for i in range(window_width):
+        asum += a[i]
+        count += 1
+        out[i] = asum / count
+    for i in range(window_width, len(a)):
+        asum += a[i] - a[i - window_width]
+        out[i] = asum / count
+
+
+def test_guvectorize_numba():
+    arr = np.arange(20, dtype=np.float64).reshape(2, 10)
+    result = move_mean(arr, 3)
+    expected = np.array(
+        [
+            [0.0, 0.5, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0],
+            [10.0, 10.5, 11.0, 12.0, 13.0, 14.0, 15.0, 16.0, 17.0, 18.0],
+        ]
+    )
+    np.testing.assert_allclose(result, expected)
+
+
+def test_jit_loop_numba():
+    _loop(np.ones(10000))
+    assert True
+
+
+def test_vec_sum_numba():
+    vec_sum(np.ones(10000), np.ones(10000))
+    assert True
diff --git a/tests/test_graph.py b/tests/test_graph.py
index df2a236c..c4fdb382 100644
--- a/tests/test_graph.py
+++ b/tests/test_graph.py
@@ -364,3 +364,28 @@ def test_prune_directed():
         type=EdgeType.DIRECTED,
     )
     assert set(G.prune(["E"], ["K"]).V) == {"E", "F", "H", "K"}
+
+
+def test_graph_hash():
+    G = Graph()
+    G.add_edges(
+        [
+            ("A", "B"),
+            ("A", "C"),
+            ("A", "D"),
+            ("D", "C"),
+            ("D", "E"),
+            ("B", "E"),
+            ("E", "F"),
+            ("A", "F"),
+        ],
+        type=EdgeType.DIRECTED,
+    )
+    h1 = G.hash()
+    G.add_edge("F", "G")
+    h2 = G.hash()
+    G._edge_attr[0]["x"] = ""
+    h3 = G.hash()
+    assert h1 != h2
+    assert h1 != h3
+    assert h2 != h3