Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Feat/wrapper model gridsearch #2594

Open
wants to merge 14 commits into
base: master
Choose a base branch
from
Open

Conversation

madtoinou
Copy link
Collaborator

@madtoinou madtoinou commented Nov 12, 2024

Checklist before merging this PR:

  • Mentioned all issues that this PR fixes or addresses.
  • Summarized the updates of this PR under Summary.
  • Added an entry under Unreleased in the Changelog.

Fixes #2104.

Summary

Copied from #2133.

When the model key is found in the parameters dictionary passed to the gridsearch classmethod (meaning we are in the context of a model that wraps another), the class method expects either a list of wrapped model instances , or it expects a dictionary with a special key called model_class whose value is the class of the model to be wrapped. The other keys in the dictionary are the parameters that will be used to construct the grid dedicated to the wrapped model. Example

from sklearn.ensemble import RandomForestRegressor

from darts.models import RegressionModel
from darts.utils import timeseries_generation as tg

parameters = {
    "model": {
        "model_class": RandomForestRegressor,
        "min_samples_split": [2,3],
        "min_samples_leaf": [1,2],
    },
    "lags": [1,2,3],
}
series = tg.sine_timeseries(length=100)

RegressionModel.gridsearch(
    parameters=parameters, series=series, forecast_horizon=1
)

@andresliszt
Copy link

andresliszt commented Nov 12, 2024

Seems that you added a validator on input parameters, the wrapped model dictionary is not allowed anymore. Here, I think might be loosen with an except for the model key

Copy link

codecov bot commented Nov 17, 2024

Codecov Report

All modified and coverable lines are covered by tests ✅

Project coverage is 94.20%. Comparing base (aad1440) to head (407b38d).

Additional details and impacted files
@@            Coverage Diff             @@
##           master    #2594      +/-   ##
==========================================
- Coverage   94.24%   94.20%   -0.05%     
==========================================
  Files         141      141              
  Lines       15463    15483      +20     
==========================================
+ Hits        14573    14585      +12     
- Misses        890      898       +8     

☔ View full report in Codecov by Sentry.
📢 Have feedback on the report? Share it here.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

[BUG] gridsearch with RegressionModel
3 participants