Skip to content

Commit

Permalink
reuse random and plugin (#171)
Browse files Browse the repository at this point in the history
Signed-off-by: Jinzhe Zeng <[email protected]>
  • Loading branch information
njzjz authored Jan 17, 2024
1 parent b1dd5ad commit 5939f3a
Show file tree
Hide file tree
Showing 2 changed files with 24 additions and 152 deletions.
74 changes: 13 additions & 61 deletions deepmd_pt/utils/dp_random.py
Original file line number Diff line number Diff line change
@@ -1,61 +1,13 @@
import numpy as np


_RANDOM_GENERATOR = np.random.RandomState()


def choice(a: np.ndarray, p: np.ndarray = None, **kwargs):
"""Generates a random sample from a given 1-D array.
Parameters
----------
a : np.ndarray
A random sample is generated from its elements.
p : np.ndarray
The probabilities associated with each entry in a.
Returns
-------
np.ndarray
arrays with results and their shapes
"""
return _RANDOM_GENERATOR.choice(a, p=p, **kwargs)


def random(size=None):
"""Return random floats in the half-open interval [0.0, 1.0).
Parameters
----------
size
Output shape.
Returns
-------
np.ndarray
Arrays with results and their shapes.
"""
print('DP Called to random')
return _RANDOM_GENERATOR.random_sample(size)


def seed(val: int = None):
"""Seed the generator.
Parameters
----------
val : int
Seed.
"""
_RANDOM_GENERATOR.seed(val)


def shuffle(x: np.ndarray):
"""Modify a sequence in-place by shuffling its contents.
Parameters
----------
x : np.ndarray
The array or list to be shuffled.
"""
_RANDOM_GENERATOR.shuffle(x)
from deepmd_utils.utils.random import (
choice,
random,
seed,
shuffle,
)

__all__ = [
"choice",
"random",
"seed",
"shuffle",
]
102 changes: 11 additions & 91 deletions deepmd_pt/utils/plugin.py
Original file line number Diff line number Diff line change
@@ -1,95 +1,15 @@
# SPDX-License-Identifier: LGPL-3.0-or-later
"""Base of plugin systems."""
# copied from https://github.com/deepmodeling/dpdata/blob/a3e76d75de53f6076254de82d18605a010dc3b00/dpdata/plugin.py

from abc import (
ABCMeta,
)
from typing import (
Callable,
from deepmd_utils.utils.plugin import (
Plugin,
PluginVariant,
VariantABCMeta,
VariantMeta,
)


class Plugin:
"""A class to register and restore plugins.
Attributes
----------
plugins : Dict[str, object]
plugins
Examples
--------
>>> plugin = Plugin()
>>> @plugin.register("xx")
def xxx():
pass
>>> print(plugin.plugins['xx'])
"""

def __init__(self):
self.plugins = {}

def __add__(self, other) -> "Plugin":
self.plugins.update(other.plugins)
return self

def register(self, key: str) -> Callable[[object], object]:
"""Register a plugin.
Parameters
----------
key : str
key of the plugin
Returns
-------
Callable[[object], object]
decorator
"""

def decorator(object: object) -> object:
self.plugins[key] = object
return object

return decorator

def get_plugin(self, key) -> object:
"""Visit a plugin by key.
Parameters
----------
key : str
key of the plugin
Returns
-------
object
the plugin
"""
return self.plugins[key]


class VariantMeta:
def __call__(cls, *args, **kwargs):
"""Remove `type` and keys that starts with underline."""
obj = cls.__new__(cls, *args, **kwargs)
kwargs.pop("type", None)
to_pop = []
for kk in kwargs:
if kk[0] == "_":
to_pop.append(kk)
for kk in to_pop:
kwargs.pop(kk, None)
obj.__init__(*args, **kwargs)
return obj


class VariantABCMeta(VariantMeta, ABCMeta):
pass


class PluginVariant(metaclass=VariantABCMeta):
"""A class to remove `type` from input arguments."""

pass
__all__ = [
"Plugin",
"VariantMeta",
"VariantABCMeta",
"PluginVariant",
]

0 comments on commit 5939f3a

Please sign in to comment.