Skip to content

Commit

Permalink
Fix usage of parallel in EEMD [Github] #145 (#146)
Browse files Browse the repository at this point in the history
* Fix usage of parallel in EEMD [Github] #145

* bump version
  • Loading branch information
laszukdawid authored Oct 12, 2023
1 parent 7361a88 commit c160da0
Show file tree
Hide file tree
Showing 3 changed files with 14 additions and 6 deletions.
7 changes: 3 additions & 4 deletions PyEMD/EEMD.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,8 +103,6 @@ def __call__(

def __getstate__(self) -> Dict:
self_dict = self.__dict__.copy()
if "pool" in self_dict:
del self_dict["pool"]
return self_dict

def generate_noise(self, scale: float, size: Union[int, Sequence[int]]) -> np.ndarray:
Expand Down Expand Up @@ -183,13 +181,14 @@ def eemd(
# For trial number of iterations perform EMD on a signal
# with added white noise
if self.parallel:
map_pool = Pool(processes=self.processes)
pool = Pool(processes=self.processes)
map_pool = pool.map
else:
map_pool = map
all_IMFs = map_pool(self._trial_update, range(self.trials))

if self.parallel:
map_pool.close()
pool.close()

self._all_imfs = defaultdict(list)
it = iter if not progress else lambda x: tqdm(x, desc="EEMD", total=self.trials)
Expand Down
2 changes: 1 addition & 1 deletion PyEMD/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import logging

__version__ = "1.5.1"
__version__ = "1.5.2"
logger = logging.getLogger("pyemd")

from PyEMD.CEEMDAN import CEEMDAN # noqa
Expand Down
11 changes: 10 additions & 1 deletion PyEMD/tests/test_eemd.py
Original file line number Diff line number Diff line change
Expand Up @@ -139,7 +139,16 @@ def test_eemd_notParallel(self):

self.assertTrue(eIMFs.shape[0] > 0)
self.assertTrue(eIMFs.shape[1], len(S))
self.assertFalse("pool" in eemd.__dict__)

def test_eemd_yesParallel(self):
S = np.random.random(100)

eemd = EEMD(trials=5, max_imf=2, parallel=True)
eemd.EMD.FIXE_H = 2
eIMFs = eemd.eemd(S)

self.assertTrue(eIMFs.shape[0] > 0)
self.assertTrue(eIMFs.shape[1], len(S))

def test_imfs_and_residue_accessor(self):
S = np.random.random(100)
Expand Down

0 comments on commit c160da0

Please sign in to comment.