-
Notifications
You must be signed in to change notification settings - Fork 2
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Signed-off-by: Jinzhe Zeng <[email protected]>
- Loading branch information
Showing
2 changed files
with
24 additions
and
152 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,61 +1,13 @@ | ||
import numpy as np | ||
|
||
|
||
_RANDOM_GENERATOR = np.random.RandomState() | ||
|
||
|
||
def choice(a: np.ndarray, p: np.ndarray = None, **kwargs): | ||
"""Generates a random sample from a given 1-D array. | ||
Parameters | ||
---------- | ||
a : np.ndarray | ||
A random sample is generated from its elements. | ||
p : np.ndarray | ||
The probabilities associated with each entry in a. | ||
Returns | ||
------- | ||
np.ndarray | ||
arrays with results and their shapes | ||
""" | ||
return _RANDOM_GENERATOR.choice(a, p=p, **kwargs) | ||
|
||
|
||
def random(size=None): | ||
"""Return random floats in the half-open interval [0.0, 1.0). | ||
Parameters | ||
---------- | ||
size | ||
Output shape. | ||
Returns | ||
------- | ||
np.ndarray | ||
Arrays with results and their shapes. | ||
""" | ||
print('DP Called to random') | ||
return _RANDOM_GENERATOR.random_sample(size) | ||
|
||
|
||
def seed(val: int = None): | ||
"""Seed the generator. | ||
Parameters | ||
---------- | ||
val : int | ||
Seed. | ||
""" | ||
_RANDOM_GENERATOR.seed(val) | ||
|
||
|
||
def shuffle(x: np.ndarray): | ||
"""Modify a sequence in-place by shuffling its contents. | ||
Parameters | ||
---------- | ||
x : np.ndarray | ||
The array or list to be shuffled. | ||
""" | ||
_RANDOM_GENERATOR.shuffle(x) | ||
from deepmd_utils.utils.random import ( | ||
choice, | ||
random, | ||
seed, | ||
shuffle, | ||
) | ||
|
||
__all__ = [ | ||
"choice", | ||
"random", | ||
"seed", | ||
"shuffle", | ||
] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,95 +1,15 @@ | ||
# SPDX-License-Identifier: LGPL-3.0-or-later | ||
"""Base of plugin systems.""" | ||
# copied from https://github.com/deepmodeling/dpdata/blob/a3e76d75de53f6076254de82d18605a010dc3b00/dpdata/plugin.py | ||
|
||
from abc import ( | ||
ABCMeta, | ||
) | ||
from typing import ( | ||
Callable, | ||
from deepmd_utils.utils.plugin import ( | ||
Plugin, | ||
PluginVariant, | ||
VariantABCMeta, | ||
VariantMeta, | ||
) | ||
|
||
|
||
class Plugin: | ||
"""A class to register and restore plugins. | ||
Attributes | ||
---------- | ||
plugins : Dict[str, object] | ||
plugins | ||
Examples | ||
-------- | ||
>>> plugin = Plugin() | ||
>>> @plugin.register("xx") | ||
def xxx(): | ||
pass | ||
>>> print(plugin.plugins['xx']) | ||
""" | ||
|
||
def __init__(self): | ||
self.plugins = {} | ||
|
||
def __add__(self, other) -> "Plugin": | ||
self.plugins.update(other.plugins) | ||
return self | ||
|
||
def register(self, key: str) -> Callable[[object], object]: | ||
"""Register a plugin. | ||
Parameters | ||
---------- | ||
key : str | ||
key of the plugin | ||
Returns | ||
------- | ||
Callable[[object], object] | ||
decorator | ||
""" | ||
|
||
def decorator(object: object) -> object: | ||
self.plugins[key] = object | ||
return object | ||
|
||
return decorator | ||
|
||
def get_plugin(self, key) -> object: | ||
"""Visit a plugin by key. | ||
Parameters | ||
---------- | ||
key : str | ||
key of the plugin | ||
Returns | ||
------- | ||
object | ||
the plugin | ||
""" | ||
return self.plugins[key] | ||
|
||
|
||
class VariantMeta: | ||
def __call__(cls, *args, **kwargs): | ||
"""Remove `type` and keys that starts with underline.""" | ||
obj = cls.__new__(cls, *args, **kwargs) | ||
kwargs.pop("type", None) | ||
to_pop = [] | ||
for kk in kwargs: | ||
if kk[0] == "_": | ||
to_pop.append(kk) | ||
for kk in to_pop: | ||
kwargs.pop(kk, None) | ||
obj.__init__(*args, **kwargs) | ||
return obj | ||
|
||
|
||
class VariantABCMeta(VariantMeta, ABCMeta): | ||
pass | ||
|
||
|
||
class PluginVariant(metaclass=VariantABCMeta): | ||
"""A class to remove `type` from input arguments.""" | ||
|
||
pass | ||
__all__ = [ | ||
"Plugin", | ||
"VariantMeta", | ||
"VariantABCMeta", | ||
"PluginVariant", | ||
] |