-
Notifications
You must be signed in to change notification settings - Fork 3
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
adcdc41
commit f977df0
Showing
8 changed files
with
203 additions
and
6 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,3 @@ | ||
from corneto.extensions._numba import OptionalNumba | ||
|
||
numba = OptionalNumba() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,16 @@ | ||
from corneto.extensions._optional import OptionalModule | ||
|
||
|
||
class OptionalNumba(OptionalModule): | ||
def __init__(self): | ||
super().__init__("numba") | ||
|
||
def _create_dummy(self, name): # type: ignore | ||
if name in ["uint16", "uint32", "uint64", "int16", "int32", "int64"]: | ||
return int | ||
if name in ["float32", "float64", "complex64", "complex128"]: | ||
return float | ||
if name == "prange": | ||
return range | ||
|
||
return super()._create_dummy(name) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,33 @@ | ||
import importlib | ||
from functools import wraps | ||
|
||
from corneto._settings import LOGGER | ||
|
||
|
||
class OptionalModule: | ||
def __init__(self, module_name): | ||
self.module_name = module_name | ||
self.module = None | ||
try: | ||
self.module = importlib.import_module(module_name) | ||
except ImportError: | ||
LOGGER.debug(f"Optional module {module_name} not found.") | ||
|
||
def __getattr__(self, item): | ||
if self.module: | ||
return getattr(self.module, item) | ||
else: | ||
return self._create_dummy(item) | ||
|
||
def _create_dummy(self, name): | ||
def _dummy(*args, **kwargs): | ||
def _decorator(func): | ||
@wraps(func) | ||
def _wrapped_func(*_args, **_kwargs): | ||
return func(*_args, **_kwargs) | ||
|
||
return _wrapped_func | ||
|
||
return _decorator | ||
|
||
return _dummy |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,53 @@ | ||
import numpy as np | ||
|
||
from corneto.extensions import numba | ||
|
||
|
||
@numba.jit(nopython=True, fastmath=True, parallel=True) | ||
def _loop(x): | ||
r = np.empty_like(x) | ||
n = len(x) | ||
for i in range(n): | ||
r[i] = np.cos(x[i]) ** 2 + np.sin(x[i]) ** 2 | ||
return r | ||
|
||
|
||
@numba.vectorize(["float64(float64, float64)"], target="parallel") | ||
def vec_sum(x, y): | ||
return x + y | ||
|
||
|
||
@numba.guvectorize(["void(float64[:], intp[:], float64[:])"], "(n),()->(n)") | ||
def move_mean(a, window_arr, out): | ||
window_width = window_arr[0] | ||
asum = 0.0 | ||
count = 0 | ||
for i in range(window_width): | ||
asum += a[i] | ||
count += 1 | ||
out[i] = asum / count | ||
for i in range(window_width, len(a)): | ||
asum += a[i] - a[i - window_width] | ||
out[i] = asum / count | ||
|
||
|
||
def test_guvectorize_numba(): | ||
arr = np.arange(20, dtype=np.float64).reshape(2, 10) | ||
result = move_mean(arr, 3) | ||
expected = np.array( | ||
[ | ||
[0.0, 0.5, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0], | ||
[10.0, 10.5, 11.0, 12.0, 13.0, 14.0, 15.0, 16.0, 17.0, 18.0], | ||
] | ||
) | ||
np.testing.assert_allclose(result, expected) | ||
|
||
|
||
def test_jit_loop_numba(): | ||
_loop(np.ones(10000)) | ||
assert True | ||
|
||
|
||
def test_vec_sum_numba(): | ||
vec_sum(np.ones(10000), np.ones(10000)) | ||
assert True |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters