Skip to content

Commit

Permalink
Another fix for polars
Browse files Browse the repository at this point in the history
  • Loading branch information
marcopeix committed Jan 10, 2025
1 parent 8e3e975 commit 555ad95
Show file tree
Hide file tree
Showing 2 changed files with 16 additions and 6 deletions.
8 changes: 6 additions & 2 deletions nbs/core.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,7 @@
"from typing import Any, Dict, List, Optional, Sequence, Union\n",
"\n",
"import fsspec\n",
"import polars\n",
"import numpy as np\n",
"import pandas as pd\n",
"import pytorch_lightning as pl\n",
Expand Down Expand Up @@ -1378,7 +1379,7 @@
"\n",
" # Combine all series forecasts DataFrames\n",
" if isinstance(fcsts_dfs[0], pl_DataFrame):\n",
" fcsts_df = pl.concat(fcsts_dfs, how='vertical')\n",
" fcsts_df = polars.concat(fcsts_dfs, how='vertical')\n",
" else:\n",
" fcsts_df = pd.concat(fcsts_dfs, axis=0, ignore_index=True)\n",
" \n",
Expand Down Expand Up @@ -1439,7 +1440,10 @@
" indptr\n",
" )\n",
" # Drop duplicates when step_size < h\n",
" fcsts_df = fcsts_df.drop_duplicates(subset=[self.id_col, self.time_col], keep='first')\n",
" if isinstance(fcsts_df, polars.DataFrame):\n",
" fcsts_df = fcsts_df.unique(subset=[self.id_col, self.time_col], keep='first')\n",
" else:\n",
" fcsts_df = fcsts_df.drop_duplicates(subset=[self.id_col, self.time_col], keep='first') \n",
" return fcsts_df\n",
"\n",
" # Save list of models with pytorch lightning save_checkpoint function\n",
Expand Down
14 changes: 10 additions & 4 deletions neuralforecast/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
from typing import Any, Dict, List, Optional, Sequence, Union

import fsspec
import polars
import numpy as np
import pandas as pd
import pytorch_lightning as pl
Expand Down Expand Up @@ -1350,7 +1351,7 @@ def predict_insample(self, step_size: int = 1):

# Combine all series forecasts DataFrames
if isinstance(fcsts_dfs[0], pl_DataFrame):
fcsts_df = pl.concat(fcsts_dfs, how="vertical")
fcsts_df = polars.concat(fcsts_dfs, how="vertical")
else:
fcsts_df = pd.concat(fcsts_dfs, axis=0, ignore_index=True)

Expand Down Expand Up @@ -1415,9 +1416,14 @@ def predict_insample(self, step_size: int = 1):
fcsts_df[invert_cols].to_numpy(), indptr
)
# Drop duplicates when step_size < h
fcsts_df = fcsts_df.drop_duplicates(
subset=[self.id_col, self.time_col], keep="first"
)
if isinstance(fcsts_df, polars.DataFrame):
fcsts_df = fcsts_df.unique(
subset=[self.id_col, self.time_col], keep="first"
)
else:
fcsts_df = fcsts_df.drop_duplicates(
subset=[self.id_col, self.time_col], keep="first"
)
return fcsts_df

# Save list of models with pytorch lightning save_checkpoint function
Expand Down

0 comments on commit 555ad95

Please sign in to comment.