Skip to content

Commit

Permalink
quality
Browse files Browse the repository at this point in the history
Signed-off-by: Wesley M. Gifford <[email protected]>
  • Loading branch information
wgifford committed Apr 3, 2024
1 parent 2dbd15d commit 2e66f31
Showing 1 changed file with 9 additions and 31 deletions.
40 changes: 9 additions & 31 deletions tsfm_public/toolkit/time_series_forecasting_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,9 +32,7 @@


@add_end_docstrings(
build_pipeline_init_args(
has_tokenizer=False, has_feature_extractor=True, has_image_processor=False
)
build_pipeline_init_args(has_tokenizer=False, has_feature_extractor=True, has_image_processor=False)
)
class TimeSeriesForecastingPipeline(Pipeline):
"""Hugging Face Pipeline for Time Series Forecasting
Expand Down Expand Up @@ -67,9 +65,7 @@ def _sanitize_parameters(self, **kwargs):
"""

context_length = kwargs.get("context_length", self.model.config.context_length)
prediction_length = kwargs.get(
"prediction_length", self.model.config.prediction_length
)
prediction_length = kwargs.get("prediction_length", self.model.config.prediction_length)

preprocess_kwargs = {
"prediction_length": prediction_length,
Expand Down Expand Up @@ -191,9 +187,7 @@ def __call__(

return super().__call__(time_series, **kwargs)

def preprocess(
self, time_series, **kwargs
) -> Dict[str, Union[GenericTensor, List[Any]]]:
def preprocess(self, time_series, **kwargs) -> Dict[str, Union[GenericTensor, List[Any]]]:
"""Preprocess step
Load the data, if not already loaded, and then generate a pytorch dataset.
"""
Expand Down Expand Up @@ -221,16 +215,12 @@ def preprocess(
# do we need to check the timestamp column?
pass
else:
raise ValueError(
f"`future_time_series` of type {type(future_time_series)} is not supported."
)
raise ValueError(f"`future_time_series` of type {type(future_time_series)} is not supported.")

# stack the time series
for c in future_time_series.columns:
if c not in time_series.columns:
raise ValueError(
f"Future time series input contains an unknown column {c}."
)
raise ValueError(f"Future time series input contains an unknown column {c}.")

time_series = pd.concat((time_series, future_time_series), axis=0)
else:
Expand Down Expand Up @@ -291,11 +281,7 @@ def _forward(self, model_inputs, **kwargs):

# copy the other inputs
copy_inputs = True
for k in [
akey
for akey in model_inputs.keys()
if (akey not in model_input_keys) or copy_inputs
]:
for k in [akey for akey in model_inputs.keys() if (akey not in model_input_keys) or copy_inputs]:
model_outputs[k] = model_inputs[k]

return model_outputs
Expand All @@ -307,20 +293,14 @@ def postprocess(self, input, **kwargs):
"""
out = {}

model_output_key = (
"prediction_outputs"
if "prediction_outputs" in input.keys()
else "prediction_logits"
)
model_output_key = "prediction_outputs" if "prediction_outputs" in input.keys() else "prediction_logits"

# name the predictions of target columns
# outputs should only have size equal to target columns
prediction_columns = []
for i, c in enumerate(kwargs["target_columns"]):
prediction_columns.append(f"{c}_prediction")
out[prediction_columns[-1]] = (
input[model_output_key][:, :, i].numpy().tolist()
)
out[prediction_columns[-1]] = input[model_output_key][:, :, i].numpy().tolist()
# provide the ground truth values for the targets
# when future is unknown, we will have augmented the provided dataframe with NaN values to cover the future
for i, c in enumerate(kwargs["target_columns"]):
Expand Down Expand Up @@ -366,8 +346,6 @@ def postprocess(self, input, **kwargs):

# inverse scale if we have a feature extractor
if self.feature_extractor is not None:
out = self.feature_extractor.inverse_scale_targets(
out, suffix="_prediction"
)
out = self.feature_extractor.inverse_scale_targets(out, suffix="_prediction")

return out

0 comments on commit 2e66f31

Please sign in to comment.