Skip to content

Commit

Permalink
Merge branch 'main' into fix/docs_and_refactoring
Browse files Browse the repository at this point in the history
  • Loading branch information
elephaint committed Oct 8, 2024
2 parents 87af3ac + 8d378c6 commit af070a9
Show file tree
Hide file tree
Showing 9 changed files with 23 additions and 9 deletions.
4 changes: 2 additions & 2 deletions .github/workflows/build-docs.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -15,10 +15,10 @@ jobs:
runs-on: ubuntu-latest
steps:
- name: Clone repo
uses: actions/checkout@692973e3d937129bcbf40652eb9f2f61becf3332 # v4.1.7
uses: actions/checkout@d632683dd7b4114ad314bca15554477dd762a938 # v4.2.0

- name: Clone docs repo
uses: actions/checkout@692973e3d937129bcbf40652eb9f2f61becf3332 # v4.1.7
uses: actions/checkout@d632683dd7b4114ad314bca15554477dd762a938 # v4.2.0
with:
repository: Nixtla/docs
ref: scripts
Expand Down
2 changes: 1 addition & 1 deletion .github/workflows/ci.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ jobs:
AWS_SECRET_ACCESS_KEY: ${{ secrets.AWS_SECRET_ACCESS_KEY_NIXTLA_TMP }}
steps:
- name: Clone repo
uses: actions/checkout@692973e3d937129bcbf40652eb9f2f61becf3332 # v4.1.7
uses: actions/checkout@d632683dd7b4114ad314bca15554477dd762a938 # v4.2.0

- name: Set up environment
uses: mamba-org/setup-micromamba@f8b8a1e23a26f60a44c853292711bacfd3eac822 # v1.9.0
Expand Down
2 changes: 1 addition & 1 deletion .github/workflows/lint.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ jobs:
runs-on: ubuntu-latest
steps:
- name: Clone repo
uses: actions/checkout@692973e3d937129bcbf40652eb9f2f61becf3332 # v4.1.7
uses: actions/checkout@d632683dd7b4114ad314bca15554477dd762a938 # v4.2.0

- name: Set up python
uses: actions/setup-python@f677139bbe7f9c59b41e40162b753c062f5d49a3 # 5.2.0
Expand Down
2 changes: 1 addition & 1 deletion .github/workflows/python-publish.yml
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ jobs:
runs-on: ubuntu-latest

steps:
- uses: actions/checkout@692973e3d937129bcbf40652eb9f2f61becf3332 # v4.1.7
- uses: actions/checkout@d632683dd7b4114ad314bca15554477dd762a938 # v4.2.0
- name: Set up Python
uses: actions/setup-python@f677139bbe7f9c59b41e40162b753c062f5d49a3 # 5.2.0
with:
Expand Down
2 changes: 1 addition & 1 deletion action_files/test_models/src/multivariate_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
from neuralforecast.models.tsmixer import TSMixer
from neuralforecast.models.tsmixerx import TSMixerx
from neuralforecast.models.itransformer import iTransformer
# from neuralforecast.models.stemgnn import StemGNN
# # from neuralforecast.models.stemgnn import StemGNN
from neuralforecast.models.mlpmultivariate import MLPMultivariate
from neuralforecast.models.timemixer import TimeMixer

Expand Down
9 changes: 9 additions & 0 deletions nbs/common.base_model.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -109,6 +109,14 @@
" nn.init.xavier_normal_ = xavier_normal"
]
},
{
"cell_type": "markdown",
"id": "fffc7edd",
"metadata": {},
"source": [
"`<<<<<<< HEAD`"
]
},
{
"cell_type": "code",
"execution_count": null,
Expand Down Expand Up @@ -590,6 +598,7 @@
" if self.val_size == 0:\n",
" return\n",
" losses = torch.stack(self.validation_step_outputs)\n",
" avg_loss = losses.mean().detach().item()\n",
" avg_loss = losses.mean().detach()\n",
" self.log(\n",
" \"ptl/val_loss\",\n",
Expand Down
5 changes: 3 additions & 2 deletions nbs/core.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,7 @@
" StemGNN, PatchTST, TimesNet, TimeLLM, TSMixer, TSMixerx,\n",
" MLPMultivariate, iTransformer,\n",
" BiTCN, TiDE, DeepNPTS, SOFTS,\n",
" TimeMixer, KAN\n",
" TimeMixer, KAN, RMoK\n",
")\n",
"from neuralforecast.common._base_auto import BaseAuto, MockTrial"
]
Expand Down Expand Up @@ -245,7 +245,8 @@
" 'deepnpts': DeepNPTS, 'autodeepnpts': DeepNPTS,\n",
" 'softs': SOFTS, 'autosofts': SOFTS,\n",
" 'timemixer': TimeMixer, 'autotimemixer': TimeMixer,\n",
" 'kan': KAN, 'autokan': KAN\n",
" 'kan': KAN, 'autokan': KAN,\n",
" 'rmok': RMoK, 'autormok': RMoK\n",
"}"
]
},
Expand Down
3 changes: 2 additions & 1 deletion neuralforecast/common/_base_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ def noop(*args, **kwargs):
nn.init.xavier_uniform_ = xavier_uniform
nn.init.xavier_normal_ = xavier_normal

# %% ../../nbs/common.base_model.ipynb 5
# %% ../../nbs/common.base_model.ipynb 6
class BaseModel(pl.LightningModule):
EXOGENOUS_FUTR = True # If the model can handle future exogenous variables
EXOGENOUS_HIST = True # If the model can handle historical exogenous variables
Expand Down Expand Up @@ -597,6 +597,7 @@ def on_validation_epoch_end(self):
if self.val_size == 0:
return
losses = torch.stack(self.validation_step_outputs)
avg_loss = losses.mean().detach().item()
avg_loss = losses.mean().detach()
self.log(
"ptl/val_loss",
Expand Down
3 changes: 3 additions & 0 deletions neuralforecast/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,7 @@
SOFTS,
TimeMixer,
KAN,
RMoK,
)
from .common._base_auto import BaseAuto, MockTrial

Expand Down Expand Up @@ -190,6 +191,8 @@ def _insample_times(
"autotimemixer": TimeMixer,
"kan": KAN,
"autokan": KAN,
"rmok": RMoK,
"autormok": RMoK,
}

# %% ../nbs/core.ipynb 8
Expand Down

0 comments on commit af070a9

Please sign in to comment.