Skip to content

Commit

Permalink
Fix/deprec nn (#2593)
Browse files Browse the repository at this point in the history
* Fix deprecated usage of torch.nn.utils.weight_norm

The previous implementation in darts.darts.models.forecasting.tcn_mode was using `torch.nn.utils.weight_norm`, which is deprecated in favor of `torch.nn.utils.parametrizations.weight_norm`. This commit replaces two occurrences of `torch.nn.utils.weight_norm` with the recommended `torch.nn.utils.parametrizations.weight_norm` to resolve the deprecation warning.

* Update torch_forecasting_model.py

Corrected file saving process for checkpoint files (ckpt) to filter out occurrences of the string '.pt' from the previous file path."

* fix: revert changes

* update changelog

---------

Co-authored-by: Saeed Foroutan <[email protected]>
  • Loading branch information
madtoinou and SaeedForoutan authored Nov 14, 2024
1 parent d909589 commit d60ef87
Show file tree
Hide file tree
Showing 2 changed files with 4 additions and 3 deletions.
3 changes: 2 additions & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ but cannot always guarantee backwards compatibility. Changes that may **break co
- fixed failing docker deployment
- removed `gradle` dependency in favor of native GitHub action plugins.
- Updated ruff to v0.7.2 and target-version to python39, also fixed various typos [#2589](https://github.com/unit8co/darts/pull/2589) by [Greg DeVosNouri](https://github.com/gdevos010) and [Antoine Madrona](https://github.com/madtoinou).
- Replaced the deprecated `torch.nn.utils.weight_norm` function with `torch.nn.utils.parametrizations.weight_norm` [#2593](https://github.com/unit8co/darts/pull/2593) by [Saeed Foroutan](https://github.com/SaeedForoutan).

## [0.31.0](https://github.com/unit8co/darts/tree/0.31.0) (2024-10-13)

Expand All @@ -40,7 +41,7 @@ but cannot always guarantee backwards compatibility. Changes that may **break co

- Improvements to `metrics`:
- Added support for computing metrics on one or multiple quantiles `q`, either from probabilistic or quantile forecasts. [#2530](https://github.com/unit8co/darts/pull/2530) by [Dennis Bader](https://github.com/dennisbader).
- Added quantile interval metrics `miw` (Mean Interval Width, time aggregated) and `iw` (Interval Width, per time step / non-aggregated) which compute the width of quantile intervals `q_intervals` (expected to be a tuple or sequence of tuples with (lower quantile, upper quantile). [#2530](https://github.com/unit8co/darts/pull/2530) by [Dennis Bader](https://github.com/dennisbader).
- Added quantile interval metrics `miw` (Mean Interval Width, time aggregated) and `iw` (Interval Width, per time step / non-aggregated) which compute the width of quantile intervals `q_intervals` (expected to be a tuple or sequence of tuples with (lower quantile, upper quantile)). [#2530](https://github.com/unit8co/darts/pull/2530) by [Dennis Bader](https://github.com/dennisbader).
- Improvements to `backtest()` and `residuals()`:
- Added support for computing backtest and residuals on one or multiple quantiles `q` in the `metric_kwargs`, either from probabilistic or quantile forecasts. [#2530](https://github.com/unit8co/darts/pull/2530) by [Dennis Bader](https://github.com/dennisbader).
- Added support for parameters `enable_optimization` and `predict_likelihood_parameters`. [#2530](https://github.com/unit8co/darts/pull/2530) by [Dennis Bader](https://github.com/dennisbader).
Expand Down
4 changes: 2 additions & 2 deletions darts/models/forecasting/tcn_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,8 +99,8 @@ def __init__(
)
if weight_norm:
self.conv1, self.conv2 = (
nn.utils.weight_norm(self.conv1),
nn.utils.weight_norm(self.conv2),
nn.utils.parametrizations.weight_norm(self.conv1),
nn.utils.parametrizations.weight_norm(self.conv2),
)

if input_dim != output_dim:
Expand Down

0 comments on commit d60ef87

Please sign in to comment.