From db2a7b536b7a99852e7607c8f8a1591b66ea8341 Mon Sep 17 00:00:00 2001 From: Dawid Date: Thu, 12 Oct 2023 09:24:14 -0700 Subject: [PATCH 1/2] Fix usage of parallel in EEMD [Github] #145 --- PyEMD/EEMD.py | 7 +++---- PyEMD/tests/test_eemd.py | 11 ++++++++++- 2 files changed, 13 insertions(+), 5 deletions(-) diff --git a/PyEMD/EEMD.py b/PyEMD/EEMD.py index 1391dae..c370abb 100644 --- a/PyEMD/EEMD.py +++ b/PyEMD/EEMD.py @@ -103,8 +103,6 @@ def __call__( def __getstate__(self) -> Dict: self_dict = self.__dict__.copy() - if "pool" in self_dict: - del self_dict["pool"] return self_dict def generate_noise(self, scale: float, size: Union[int, Sequence[int]]) -> np.ndarray: @@ -183,13 +181,14 @@ def eemd( # For trial number of iterations perform EMD on a signal # with added white noise if self.parallel: - map_pool = Pool(processes=self.processes) + pool = Pool(processes=self.processes) + map_pool = pool.map else: map_pool = map all_IMFs = map_pool(self._trial_update, range(self.trials)) if self.parallel: - map_pool.close() + pool.close() self._all_imfs = defaultdict(list) it = iter if not progress else lambda x: tqdm(x, desc="EEMD", total=self.trials) diff --git a/PyEMD/tests/test_eemd.py b/PyEMD/tests/test_eemd.py index d1426a6..8c4880f 100644 --- a/PyEMD/tests/test_eemd.py +++ b/PyEMD/tests/test_eemd.py @@ -139,7 +139,16 @@ def test_eemd_notParallel(self): self.assertTrue(eIMFs.shape[0] > 0) self.assertTrue(eIMFs.shape[1], len(S)) - self.assertFalse("pool" in eemd.__dict__) + + def test_eemd_yesParallel(self): + S = np.random.random(100) + + eemd = EEMD(trials=5, max_imf=2, parallel=True) + eemd.EMD.FIXE_H = 2 + eIMFs = eemd.eemd(S) + + self.assertTrue(eIMFs.shape[0] > 0) + self.assertTrue(eIMFs.shape[1], len(S)) def test_imfs_and_residue_accessor(self): S = np.random.random(100) From a86e77f9927bca6b8acea8849deacb7297d3e10e Mon Sep 17 00:00:00 2001 From: Dawid Date: Thu, 12 Oct 2023 09:27:24 -0700 Subject: [PATCH 2/2] bump version --- PyEMD/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/PyEMD/__init__.py b/PyEMD/__init__.py index ce0c222..d53369a 100644 --- a/PyEMD/__init__.py +++ b/PyEMD/__init__.py @@ -1,6 +1,6 @@ import logging -__version__ = "1.5.1" +__version__ = "1.5.2" logger = logging.getLogger("pyemd") from PyEMD.CEEMDAN import CEEMDAN # noqa