From e1944c013000fdc0eca9e49973e4bdd18905923f Mon Sep 17 00:00:00 2001 From: nardew <28791551+nardew@users.noreply.github.com> Date: Mon, 25 Mar 2024 22:32:37 +0100 Subject: [PATCH] Fix NPE in ADX --- talipp/indicator_util.py | 7 +++++++ talipp/indicators/ADX.py | 16 ++++++++++++---- tests/test_indicator_util.py | 9 ++++++++- 3 files changed, 27 insertions(+), 5 deletions(-) diff --git a/talipp/indicator_util.py b/talipp/indicator_util.py index 0905ba9..563bc3a 100644 --- a/talipp/indicator_util.py +++ b/talipp/indicator_util.py @@ -15,6 +15,13 @@ def has_valid_values(sequence: Union[Indicator, List[Any]], window: int = 1, exa (len(sequence) > window and sequence[-window] is not None and sequence[-window-1] is None) +def previous_if_exists(sequence: Union[Indicator, List[Any]], previous_index: int = -1, default: Any = 0): + try: + return sequence[previous_index] + except IndexError: + return default + + def composite_to_lists(indicator: Indicator) -> Dict[str, List[float]]: if not has_valid_values(indicator, 1): return {} diff --git a/talipp/indicators/ADX.py b/talipp/indicators/ADX.py index 21c286c..bc1f6da 100644 --- a/talipp/indicators/ADX.py +++ b/talipp/indicators/ADX.py @@ -1,7 +1,7 @@ from dataclasses import dataclass from typing import List, Any -from talipp.indicator_util import has_valid_values +from talipp.indicator_util import has_valid_values, previous_if_exists from talipp.indicators.ATR import ATR from talipp.indicators.Indicator import Indicator, InputModifierType from talipp.input import SamplingPeriodType @@ -94,10 +94,18 @@ def _calculate_new_value(self) -> Any: self.spdm.append((self.spdm[-1] * (self.di_period - 1) + self.pdm[-1]) / float(self.di_period)) self.smdm.append((self.smdm[-1] * (self.di_period - 1) + self.mdm[-1]) / float(self.di_period)) - self.pdi.append(100.0 * self.spdm[-1] / float(self.atr[-1])) - self.mdi.append(100.0 * self.smdm[-1] / float(self.atr[-1])) + if self.atr[-1] != 0: + self.pdi.append(100.0 * self.spdm[-1] / float(self.atr[-1])) + self.mdi.append(100.0 * self.smdm[-1] / float(self.atr[-1])) + else: + self.pdi.append(previous_if_exists(self.pdi)) + self.mdi.append(previous_if_exists(self.mdi)) - self.dx.append(100.0 * float(abs(self.pdi[-1] - self.mdi[-1])) / (self.pdi[-1] + self.mdi[-1])) + dx_denom = (self.pdi[-1] + self.mdi[-1]) + if dx_denom != 0: + self.dx.append(100.0 * float(abs(self.pdi[-1] - self.mdi[-1])) / dx_denom) + else: + self.dx.append(previous_if_exists(self.dx, default=0)) adx = None if len(self.dx) == self.adx_period: diff --git a/tests/test_indicator_util.py b/tests/test_indicator_util.py index 3f28996..19718bd 100644 --- a/tests/test_indicator_util.py +++ b/tests/test_indicator_util.py @@ -1,6 +1,6 @@ import unittest -from talipp.indicator_util import composite_to_lists, has_valid_values +from talipp.indicator_util import composite_to_lists, has_valid_values, previous_if_exists from talipp.indicators import BB, SMA @@ -53,6 +53,13 @@ def test_has_valid_values(self): self.assertFalse(has_valid_values([None], 1)) self.assertFalse(has_valid_values([1, None], 1)) + def test_previous_if_exists(self): + self.assertEqual(previous_if_exists([]), 0) + self.assertEqual(previous_if_exists([], default=1), 1) + self.assertEqual(previous_if_exists([1]), 1) + self.assertEqual(previous_if_exists([1], previous_index=-2), 0) + self.assertEqual(previous_if_exists([1,2], previous_index=-2), 1) + if __name__ == '__main__': unittest.main()