-
Notifications
You must be signed in to change notification settings - Fork 900
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Feat/neural prophet #1436
Draft
BlazejNowicki
wants to merge
25
commits into
master
Choose a base branch
from
feat/neural-prophet
base: master
Could not load branches
Branch not found: {{ refName }}
Loading
Could not load tags
Nothing to show
Loading
Are you sure you want to change the base?
Some commits from the old base branch may be removed from the timeline,
and old review comments may become outdated.
Draft
Feat/neural prophet #1436
Changes from 17 commits
Commits
Show all changes
25 commits
Select commit
Hold shift + click to select a range
b139ee7
First draft
BlazejNowickiU8 2422e47
Allow multivariate time series
BlazejNowickiU8 906de4c
Add examples and improve conversion
BlazejNowickiU8 f9937d0
Attempt at global model with past covariates
BlazejNowickiU8 b4f2f18
Add past covariates
BlazejNowickiU8 bed3700
Add future covariates
BlazejNowickiU8 1d293fe
Update requirements
BlazejNowickiU8 4bacbe4
Merge branch 'master' into feat/neural-prophet
hrzn 2b6c875
Merge branch 'master' into feat/neural-prophet
BlazejNowicki 67ea1d9
Test with newer version
BlazejNowickiU8 f7f2ad4
Merge branch 'feat/neural-prophet' of github.com:unit8co/darts into f…
BlazejNowickiU8 b180e87
Test rollback
BlazejNowickiU8 632b850
Manually add tensorboardX
BlazejNowickiU8 14f215a
Remove unused imports
BlazejNowickiU8 00ef06e
Merge branch 'master' into feat/neural-prophet
piaz97 75c5888
Merge branch 'master' into feat/neural-prophet
hrzn 163269f
Merge branch 'master' into feat/neural-prophet
hrzn 7b41571
Merge branch 'master' into feat/neural-prophet
hrzn 15c8f9a
Merge branch 'master' into feat/neural-prophet
BlazejNowickiU8 78c5e76
Require neural prophet with updated requirements
BlazejNowickiU8 4116e17
Revert changes from the notebooks
BlazejNowickiU8 8c449b1
Add model import in module init file
BlazejNowickiU8 35b88a6
Add docstring
BlazejNowickiU8 94db772
Merge branch 'master' into feat/neural-prophet
BlazejNowicki 3a6d8cc
Merge branch 'master' into feat/neural-prophet
dennisbader File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,230 @@ | ||
""" | ||
Neural Prophet | ||
------------ | ||
""" | ||
|
||
import warnings | ||
from typing import Optional, Sequence, Union | ||
|
||
import neuralprophet | ||
import pandas as pd | ||
from neuralprophet.utils import fcst_df_to_latest_forecast | ||
|
||
from darts.logging import raise_if_not | ||
from darts.models.forecasting.forecasting_model import ForecastingModel | ||
from darts.timeseries import TimeSeries, concatenate | ||
|
||
|
||
class NeuralProphet(ForecastingModel): | ||
def __init__(self, n_lags: int = 0, n_forecasts: int = 1, **kwargs): | ||
super().__init__() | ||
# TODO improve passing arguments to the model | ||
|
||
raise_if_not(n_lags >= 0, "Argument n_lags should be a non-negative integer") | ||
|
||
self.n_lags = n_lags | ||
self.n_forecasts = n_forecasts | ||
self.model = neuralprophet.NeuralProphet( | ||
n_lags=n_lags, n_forecasts=n_forecasts, **kwargs | ||
) | ||
|
||
def fit( | ||
self, | ||
series: TimeSeries, | ||
past_covariates: Optional[TimeSeries] = None, | ||
future_covariates: Optional[TimeSeries] = None, | ||
) -> "NeuralProphet": | ||
super().fit(series) | ||
|
||
raise_if_not( | ||
series.has_datetime_index, | ||
"NeuralProphet model is limited to TimeSeries indexed with DatetimeIndex", | ||
) | ||
|
||
raise_if_not( | ||
past_covariates is None or self.n_lags > 0, | ||
"Past covariates are only supported when auto-regression is enabled (n_lags > 1)", | ||
) | ||
|
||
self.training_series = series | ||
fit_df = self._convert_ts_to_df(series) | ||
|
||
if past_covariates is not None: | ||
fit_df = self._add_past_covariates(self.model, fit_df, past_covariates) | ||
|
||
if future_covariates is not None: | ||
fit_df = self._add_future_covariates(self.model, fit_df, future_covariates) | ||
self.future_components = future_covariates.components | ||
else: | ||
self.future_components = None | ||
|
||
with warnings.catch_warnings(): | ||
self.model.fit(fit_df, freq=series.freq_str) | ||
|
||
self.fit_df = fit_df | ||
return self | ||
|
||
def predict( | ||
self, | ||
n: int, | ||
future_covariates: Optional[TimeSeries] = None, | ||
num_samples: int = 1, | ||
verbose: bool = False, | ||
) -> Union[TimeSeries, Sequence[TimeSeries]]: | ||
super().predict(n, num_samples) | ||
|
||
raise_if_not( | ||
self.n_lags == 0 or n <= self.n_forecasts, | ||
"Auto-regression has been enabled. `n` must be smaller than or equal to" | ||
"`n_forecasts` parameter in the constructor.", | ||
) | ||
|
||
self._future_covariates_checks(future_covariates) | ||
|
||
regressors_df = ( | ||
self._future_covariates_df(future_covariates) | ||
if self.future_components is not None | ||
else None | ||
) | ||
|
||
future_df = self.model.make_future_dataframe( | ||
df=self.fit_df, regressors_df=regressors_df, periods=n | ||
) | ||
|
||
with warnings.catch_warnings(): | ||
forecast_df = self.model.predict(future_df) | ||
|
||
return self._convert_df_to_ts( | ||
forecast_df, | ||
self.training_series.end_time(), | ||
self.training_series.components, | ||
) | ||
|
||
def _convert_ts_to_df(self, series: TimeSeries) -> pd.DataFrame: | ||
"""Convert TimeSeries to pandas DataFrame format required by Neural Prophet""" | ||
dfs = [] # ID y | ||
|
||
for component in series.components: | ||
component_df = ( | ||
series[component] | ||
.pd_dataframe(copy=False) | ||
.reset_index(names=["ds"]) | ||
.filter(items=["ds", component]) | ||
.rename(columns={component: "y"}) | ||
) | ||
component_df["ID"] = component | ||
dfs.append(component_df) | ||
|
||
return pd.concat(dfs).copy(deep=True) | ||
|
||
def _add_past_covariates( | ||
self, | ||
model: neuralprophet.NeuralProphet, | ||
df: pd.DataFrame, | ||
covariates: TimeSeries, | ||
): | ||
df = self._add_covariate(df, covariates) | ||
model.add_lagged_regressor(names=list(covariates.components)) | ||
return df | ||
|
||
def _add_future_covariates( | ||
self, | ||
model: neuralprophet.NeuralProphet, | ||
df: pd.DataFrame, | ||
covariates: TimeSeries, | ||
): | ||
df = self._add_covariate(df, covariates) | ||
for component in covariates.components: | ||
model.add_future_regressor(name=component) | ||
|
||
return df | ||
|
||
def _add_covariate( | ||
self, | ||
df: pd.DataFrame, | ||
covariates: TimeSeries, | ||
) -> pd.DataFrame: | ||
"""Convert past covariates from TimeSeries and add them to DataFrame""" | ||
|
||
raise_if_not( | ||
self.training_series.freq == covariates.freq, | ||
"Covariate TimeSeries has to have the same frequency as the TimeSeries that model is fitted on.", | ||
) | ||
|
||
raise_if_not( | ||
covariates.start_time() <= self.training_series.start_time() | ||
and self.training_series.end_time() <= covariates.end_time(), | ||
"Covaraite TimeSeries has to span across all TimeSeries that model is fitted on", | ||
) | ||
|
||
for component in covariates.components: | ||
covariate_df = ( | ||
covariates[component] | ||
.pd_dataframe(copy=False) | ||
.reset_index(names=["ds"]) | ||
.filter(items=["ds", component]) | ||
) | ||
|
||
df = df.merge(covariate_df, how="left", on="ds") | ||
|
||
return df | ||
|
||
def _convert_df_to_ts(self, forecast: pd.DataFrame, last_train_date, components): | ||
groups = [] | ||
for component in components: | ||
if self.n_lags == 0: | ||
# output format is different when AR is not enabled | ||
groups.append( | ||
forecast[ | ||
(forecast["ID"] == component) | ||
& (forecast["ds"] > last_train_date) | ||
] | ||
.filter(items=["ds", "yhat1"]) | ||
.rename(columns={"yhat1": component}) | ||
) | ||
else: | ||
df = fcst_df_to_latest_forecast( | ||
forecast[(forecast["ID"] == component)], | ||
quantiles=[0.5], | ||
n_last=1, | ||
) | ||
groups.append( | ||
df[df["ds"] > last_train_date] | ||
.filter(items=["ds", "origin-0"]) | ||
.rename(columns={"origin-0": component}) | ||
) | ||
|
||
return concatenate( | ||
[TimeSeries.from_dataframe(group, time_col="ds") for group in groups], | ||
axis=1, | ||
) | ||
|
||
def _future_covariates_df(self, series: TimeSeries) -> pd.DataFrame: | ||
component_dfs = [] | ||
for component in series.components: | ||
component_dfs.append(series[component].pd_dataframe()) | ||
|
||
return pd.concat(component_dfs, axis=1).reset_index(names=["ds"]) | ||
|
||
def _future_covariates_checks(self, future_covariates: Optional[TimeSeries]): | ||
raise_if_not( | ||
self.future_components is None | ||
or ( | ||
future_covariates is not None | ||
and set(self.future_components) == set(future_covariates.components) | ||
), | ||
f"Missing future covariate TimeSeries. Model was trained with {self.future_components} " | ||
"future components", | ||
) | ||
|
||
raise_if_not( | ||
self.future_components is None | ||
or future_covariates.freq == self.training_series.freq, | ||
"Invalid frequency in future covariate TimeSeries", | ||
) | ||
|
||
def uses_future_covariates(self): | ||
return True | ||
|
||
def __str__(self): | ||
return "Neural Prophet" |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
A few generic comments about the PR:
darts/models/__init__.py
?_model_encoder_settings
). But this can wait (don't spend more time on this until the dependency situation is figured out).