Skip to content

Commit

Permalink
add scaling example
Browse files Browse the repository at this point in the history
  • Loading branch information
wgifford committed May 13, 2024
1 parent db7d6c6 commit d2cc911
Showing 1 changed file with 34 additions and 21 deletions.
55 changes: 34 additions & 21 deletions tests/toolkit/test_time_series_forecasting_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,30 +120,13 @@ def test_forecasting_pipeline_forecasts_with_preprocessor():
model = PatchTSTForPrediction.from_pretrained(model_path)
context_length = model.config.context_length

tsp = TimeSeriesPreprocessor(
timestamp_column=timestamp_column,
id_columns=id_columns,
target_columns=target_columns,
context_length=context_length,
prediction_length=prediction_length,
freq="1h",
)

forecast_pipeline = TimeSeriesForecastingPipeline(
model=model,
timestamp_column=timestamp_column,
id_columns=id_columns,
target_columns=target_columns,
freq="1h",
feature_extractor=tsp,
explode_forecasts=False,
)

dataset_path = "https://raw.githubusercontent.com/zhouhaoyi/ETDataset/main/ETT-small/ETTh2.csv"
data = pd.read_csv(
dataset_path,
parse_dates=[timestamp_column],
)
train_end_index = 12 * 30 * 24

test_end_index = 12 * 30 * 24 + 8 * 30 * 24
test_start_index = test_end_index - context_length - 4

Expand All @@ -152,18 +135,48 @@ def test_forecasting_pipeline_forecasts_with_preprocessor():
parse_dates=[timestamp_column],
)

train_data = select_by_index(
data,
id_columns=id_columns,
start_index=0,
end_index=train_end_index,
)
test_data = select_by_index(
data,
id_columns=id_columns,
start_index=test_start_index,
end_index=test_end_index,
)

forecasts = forecast_pipeline(test_data)
tsp = TimeSeriesPreprocessor(
timestamp_column=timestamp_column,
id_columns=id_columns,
target_columns=target_columns,
context_length=context_length,
prediction_length=prediction_length,
freq="1h",
scaling=True,
)

tsp.train(train_data)

forecast_pipeline = TimeSeriesForecastingPipeline(
model=model,
timestamp_column=timestamp_column,
id_columns=id_columns,
target_columns=target_columns,
freq="1h",
feature_extractor=tsp,
explode_forecasts=False,
inverse_scale_outputs=True,
)

forecasts = forecast_pipeline(tsp.preprocess(test_data))

assert forecasts.shape == (
test_end_index - test_start_index - context_length + 1,
2 * len(target_columns) + 1,
)

# TODO: add a proper check on the scaling
# if the outputs were inverse-scaled, the mean should be larger
assert forecasts["HUFL_prediction"].mean().mean() > 10

0 comments on commit d2cc911

Please sign in to comment.