diff --git a/tests/toolkit/test_time_series_forecasting_pipeline.py b/tests/toolkit/test_time_series_forecasting_pipeline.py index c22b1cca..defdd123 100644 --- a/tests/toolkit/test_time_series_forecasting_pipeline.py +++ b/tests/toolkit/test_time_series_forecasting_pipeline.py @@ -120,30 +120,13 @@ def test_forecasting_pipeline_forecasts_with_preprocessor(): model = PatchTSTForPrediction.from_pretrained(model_path) context_length = model.config.context_length - tsp = TimeSeriesPreprocessor( - timestamp_column=timestamp_column, - id_columns=id_columns, - target_columns=target_columns, - context_length=context_length, - prediction_length=prediction_length, - freq="1h", - ) - - forecast_pipeline = TimeSeriesForecastingPipeline( - model=model, - timestamp_column=timestamp_column, - id_columns=id_columns, - target_columns=target_columns, - freq="1h", - feature_extractor=tsp, - explode_forecasts=False, - ) - dataset_path = "https://raw.githubusercontent.com/zhouhaoyi/ETDataset/main/ETT-small/ETTh2.csv" data = pd.read_csv( dataset_path, parse_dates=[timestamp_column], ) + train_end_index = 12 * 30 * 24 + test_end_index = 12 * 30 * 24 + 8 * 30 * 24 test_start_index = test_end_index - context_length - 4 @@ -152,6 +135,12 @@ def test_forecasting_pipeline_forecasts_with_preprocessor(): parse_dates=[timestamp_column], ) + train_data = select_by_index( + data, + id_columns=id_columns, + start_index=0, + end_index=train_end_index, + ) test_data = select_by_index( data, id_columns=id_columns, @@ -159,11 +148,35 @@ def test_forecasting_pipeline_forecasts_with_preprocessor(): end_index=test_end_index, ) - forecasts = forecast_pipeline(test_data) + tsp = TimeSeriesPreprocessor( + timestamp_column=timestamp_column, + id_columns=id_columns, + target_columns=target_columns, + context_length=context_length, + prediction_length=prediction_length, + freq="1h", + scaling=True, + ) + + tsp.train(train_data) + + forecast_pipeline = TimeSeriesForecastingPipeline( + model=model, + timestamp_column=timestamp_column, + id_columns=id_columns, + target_columns=target_columns, + freq="1h", + feature_extractor=tsp, + explode_forecasts=False, + inverse_scale_outputs=True, + ) + + forecasts = forecast_pipeline(tsp.preprocess(test_data)) assert forecasts.shape == ( test_end_index - test_start_index - context_length + 1, 2 * len(target_columns) + 1, ) - # to do: add check on the scaling + # if we have inverse scaled mean should be larger + assert forecasts["HUFL_prediction"].mean().mean() > 10