diff --git a/deepmd_pt/utils/dp_random.py b/deepmd_pt/utils/dp_random.py index 3062715..81ecead 100644 --- a/deepmd_pt/utils/dp_random.py +++ b/deepmd_pt/utils/dp_random.py @@ -1,61 +1,13 @@ -import numpy as np - - -_RANDOM_GENERATOR = np.random.RandomState() - - -def choice(a: np.ndarray, p: np.ndarray = None, **kwargs): - """Generates a random sample from a given 1-D array. - - Parameters - ---------- - a : np.ndarray - A random sample is generated from its elements. - p : np.ndarray - The probabilities associated with each entry in a. - - Returns - ------- - np.ndarray - arrays with results and their shapes - """ - return _RANDOM_GENERATOR.choice(a, p=p, **kwargs) - - -def random(size=None): - """Return random floats in the half-open interval [0.0, 1.0). - - Parameters - ---------- - size - Output shape. - - Returns - ------- - np.ndarray - Arrays with results and their shapes. - """ - print('DP Called to random') - return _RANDOM_GENERATOR.random_sample(size) - - -def seed(val: int = None): - """Seed the generator. - - Parameters - ---------- - val : int - Seed. - """ - _RANDOM_GENERATOR.seed(val) - - -def shuffle(x: np.ndarray): - """Modify a sequence in-place by shuffling its contents. - - Parameters - ---------- - x : np.ndarray - The array or list to be shuffled. - """ - _RANDOM_GENERATOR.shuffle(x) +from deepmd_utils.utils.random import ( + choice, + random, + seed, + shuffle, +) + +__all__ = [ + "choice", + "random", + "seed", + "shuffle", +] \ No newline at end of file diff --git a/deepmd_pt/utils/plugin.py b/deepmd_pt/utils/plugin.py index 2a77b74..dbc0237 100644 --- a/deepmd_pt/utils/plugin.py +++ b/deepmd_pt/utils/plugin.py @@ -1,95 +1,15 @@ # SPDX-License-Identifier: LGPL-3.0-or-later """Base of plugin systems.""" -# copied from https://github.com/deepmodeling/dpdata/blob/a3e76d75de53f6076254de82d18605a010dc3b00/dpdata/plugin.py - -from abc import ( - ABCMeta, -) -from typing import ( - Callable, +from deepmd_utils.utils.plugin import ( + Plugin, + PluginVariant, + VariantABCMeta, + VariantMeta, ) - -class Plugin: - """A class to register and restore plugins. - - Attributes - ---------- - plugins : Dict[str, object] - plugins - - Examples - -------- - >>> plugin = Plugin() - >>> @plugin.register("xx") - def xxx(): - pass - >>> print(plugin.plugins['xx']) - """ - - def __init__(self): - self.plugins = {} - - def __add__(self, other) -> "Plugin": - self.plugins.update(other.plugins) - return self - - def register(self, key: str) -> Callable[[object], object]: - """Register a plugin. - - Parameters - ---------- - key : str - key of the plugin - - Returns - ------- - Callable[[object], object] - decorator - """ - - def decorator(object: object) -> object: - self.plugins[key] = object - return object - - return decorator - - def get_plugin(self, key) -> object: - """Visit a plugin by key. - - Parameters - ---------- - key : str - key of the plugin - - Returns - ------- - object - the plugin - """ - return self.plugins[key] - - -class VariantMeta: - def __call__(cls, *args, **kwargs): - """Remove `type` and keys that starts with underline.""" - obj = cls.__new__(cls, *args, **kwargs) - kwargs.pop("type", None) - to_pop = [] - for kk in kwargs: - if kk[0] == "_": - to_pop.append(kk) - for kk in to_pop: - kwargs.pop(kk, None) - obj.__init__(*args, **kwargs) - return obj - - -class VariantABCMeta(VariantMeta, ABCMeta): - pass - - -class PluginVariant(metaclass=VariantABCMeta): - """A class to remove `type` from input arguments.""" - - pass +__all__ = [ + "Plugin", + "VariantMeta", + "VariantABCMeta", + "PluginVariant", +]