diff --git a/src/timesfm/timesfm_base.py b/src/timesfm/timesfm_base.py index 9e2c0ec..088cfab 100644 --- a/src/timesfm/timesfm_base.py +++ b/src/timesfm/timesfm_base.py @@ -50,13 +50,13 @@ def moving_average(arr, window_size): def freq_map(freq: str): """Returns the frequency map for the given frequency string.""" freq = str.upper(freq) - if (freq.endswith("H") or freq.endswith("T") or freq.endswith("MIN") or - freq.endswith("D") or freq.endswith("B") or freq.endswith("U") or - freq.endswith("S")): + if freq.endswith("MS"): + return 1 + elif freq.endswith(("H", "T", "MIN", "D", "B", "U", "S")): return 0 - elif freq.endswith(("W", "M", "MS")): + elif freq.endswith(("W", "M")): return 1 - elif freq.endswith("Y") or freq.endswith("Q") or freq.endswith("A"): + elif freq.endswith(("Y", "Q", "A")): return 2 else: raise ValueError(f"Invalid frequency: {freq}")