Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[FEAT] Add option to modify the default configure_optimizers() behavior #1015

Open
wants to merge 11 commits into
base: main
Choose a base branch
from
105 changes: 56 additions & 49 deletions nbs/common.base_model.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,6 @@
"import random\n",
"import warnings\n",
"from contextlib import contextmanager\n",
"from copy import deepcopy\n",
"from dataclasses import dataclass\n",
"\n",
"import fsspec\n",
Expand Down Expand Up @@ -89,7 +88,7 @@
" kaiming_normal = nn.init.kaiming_normal_\n",
" xavier_uniform = nn.init.xavier_uniform_\n",
" xavier_normal = nn.init.xavier_normal_\n",
" \n",
" \n",
" nn.init.kaiming_uniform_ = noop\n",
" nn.init.kaiming_normal_ = noop\n",
" nn.init.xavier_uniform_ = noop\n",
Expand Down Expand Up @@ -121,10 +120,6 @@
" random_seed,\n",
" loss,\n",
" valid_loss,\n",
" optimizer,\n",
" optimizer_kwargs,\n",
" lr_scheduler,\n",
" lr_scheduler_kwargs,\n",
" futr_exog_list,\n",
" hist_exog_list,\n",
" stat_exog_list,\n",
Expand All @@ -150,18 +145,8 @@
" self.train_trajectories = []\n",
" self.valid_trajectories = []\n",
"\n",
" # Optimization\n",
" if optimizer is not None and not issubclass(optimizer, torch.optim.Optimizer):\n",
" raise TypeError(\"optimizer is not a valid subclass of torch.optim.Optimizer\")\n",
" self.optimizer = optimizer\n",
" self.optimizer_kwargs = optimizer_kwargs if optimizer_kwargs is not None else {}\n",
"\n",
" # lr scheduler\n",
" if lr_scheduler is not None and not issubclass(lr_scheduler, torch.optim.lr_scheduler.LRScheduler):\n",
" raise TypeError(\"lr_scheduler is not a valid subclass of torch.optim.lr_scheduler.LRScheduler\")\n",
" self.lr_scheduler = lr_scheduler\n",
" self.lr_scheduler_kwargs = lr_scheduler_kwargs if lr_scheduler_kwargs is not None else {}\n",
"\n",
" # customized by set_configure_optimizers()\n",
" self.config_optimizers = None\n",
"\n",
" # Variables\n",
" self.futr_exog_list = list(futr_exog_list) if futr_exog_list is not None else []\n",
Expand Down Expand Up @@ -409,39 +394,61 @@
" random.seed(self.random_seed)\n",
"\n",
" def configure_optimizers(self):\n",
" if self.optimizer:\n",
" optimizer_signature = inspect.signature(self.optimizer)\n",
" optimizer_kwargs = deepcopy(self.optimizer_kwargs)\n",
" if 'lr' in optimizer_signature.parameters:\n",
" if 'lr' in optimizer_kwargs:\n",
" warnings.warn(\"ignoring learning rate passed in optimizer_kwargs, using the model's learning rate\")\n",
" optimizer_kwargs['lr'] = self.learning_rate\n",
" optimizer = self.optimizer(params=self.parameters(), **optimizer_kwargs)\n",
" else:\n",
" if self.optimizer_kwargs:\n",
" warnings.warn(\n",
" \"ignoring optimizer_kwargs as the optimizer is not specified\"\n",
" ) \n",
" optimizer = torch.optim.Adam(self.parameters(), lr=self.learning_rate)\n",
" if self.config_optimizers is not None:\n",
" # return the customized optimizer settings if specified\n",
" return self.config_optimizers\n",
" \n",
" lr_scheduler = {'frequency': 1, 'interval': 'step'}\n",
" if self.lr_scheduler:\n",
" lr_scheduler_signature = inspect.signature(self.lr_scheduler)\n",
" lr_scheduler_kwargs = deepcopy(self.lr_scheduler_kwargs)\n",
" if 'optimizer' in lr_scheduler_signature.parameters:\n",
" if 'optimizer' in lr_scheduler_kwargs:\n",
" warnings.warn(\"ignoring optimizer passed in lr_scheduler_kwargs, using the model's optimizer\")\n",
" del lr_scheduler_kwargs['optimizer']\n",
" lr_scheduler['scheduler'] = self.lr_scheduler(optimizer=optimizer, **lr_scheduler_kwargs)\n",
" else:\n",
" if self.lr_scheduler_kwargs:\n",
" warnings.warn(\n",
" \"ignoring lr_scheduler_kwargs as the lr_scheduler is not specified\"\n",
" ) \n",
" lr_scheduler['scheduler'] = torch.optim.lr_scheduler.StepLR(\n",
" # default choice\n",
" optimizer = torch.optim.Adam(self.parameters(), lr=self.learning_rate)\n",
" scheduler = {\n",
" \"scheduler\": torch.optim.lr_scheduler.StepLR(\n",
" optimizer=optimizer, step_size=self.lr_decay_steps, gamma=0.5\n",
" )\n",
" return {'optimizer': optimizer, 'lr_scheduler': lr_scheduler}\n",
" ),\n",
" \"frequency\": 1,\n",
" \"interval\": \"step\",\n",
" }\n",
" return {\"optimizer\": optimizer, \"lr_scheduler\": scheduler}\n",
"\n",
" def set_configure_optimizers(\n",
" self, \n",
" optimizer=None,\n",
" scheduler=None,\n",
" interval='step',\n",
" frequency=1,\n",
" monitor='val_loss',\n",
" strict=True,\n",
" name=None\n",
" ):\n",
" \"\"\"Helper function to customize the lr_scheduler_config as detailed in \n",
" https://lightning.ai/docs/pytorch/stable/common/lightning_module.html#configure-optimizers\n",
"\n",
" Calling set_configure_optimizers() with valid `optimizer`, `scheduler` shall modify the returned \n",
" dictionary of key='optimizer', key='lr_scheduler' in configure_optimizers().\n",
" Note that the default choice of `interval` in set_configure_optiizers() is 'step',\n",
" which differs from the choice of 'epoch' used in lightning_module. \n",
" \"\"\"\n",
" lr_scheduler_config = {\n",
" 'interval': interval,\n",
" 'frequency': frequency,\n",
" 'monitor': monitor,\n",
" 'strict': strict,\n",
" 'name': name,\n",
" }\n",
"\n",
" if scheduler is not None and optimizer is not None:\n",
" if not isinstance(scheduler, torch.optim.lr_scheduler.LRScheduler):\n",
" raise TypeError(\"scheduler is not a valid instance of torch.optim.lr_scheduler.LRScheduler\")\n",
" if not isinstance(optimizer, torch.optim.Optimizer):\n",
" raise TypeError(\"optimizer is not a valid instance of torch.optim.Optimizer\") \n",
" \n",
" lr_scheduler_config[\"scheduler\"] = scheduler\n",
" self.config_optimizers = {\n",
" 'optimizer': optimizer,\n",
" 'lr_scheduler': lr_scheduler_config,\n",
" }\n",
" else:\n",
" # falls back to default option as specified in configure_optimizers()\n",
" self.config_optimizers = None\n",
"\n",
" def get_test_size(self):\n",
" return self.test_size\n",
Expand Down
10 changes: 1 addition & 9 deletions nbs/common.base_multivariate.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -105,20 +105,12 @@
" drop_last_loader=False,\n",
" random_seed=1, \n",
" alias=None,\n",
" optimizer=None,\n",
" optimizer_kwargs=None,\n",
" lr_scheduler=None,\n",
" lr_scheduler_kwargs=None,\n",
" dataloader_kwargs=None,\n",
" **trainer_kwargs):\n",
" super().__init__(\n",
" random_seed=random_seed,\n",
" loss=loss,\n",
" valid_loss=valid_loss,\n",
" optimizer=optimizer,\n",
" optimizer_kwargs=optimizer_kwargs,\n",
" lr_scheduler=lr_scheduler,\n",
" lr_scheduler_kwargs=lr_scheduler_kwargs, \n",
" valid_loss=valid_loss, \n",
" futr_exog_list=futr_exog_list,\n",
" hist_exog_list=hist_exog_list,\n",
" stat_exog_list=stat_exog_list,\n",
Expand Down
8 changes: 0 additions & 8 deletions nbs/common.base_recurrent.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -111,20 +111,12 @@
" drop_last_loader=False,\n",
" random_seed=1, \n",
" alias=None,\n",
" optimizer=None,\n",
" optimizer_kwargs=None,\n",
" lr_scheduler=None,\n",
" lr_scheduler_kwargs=None,\n",
" dataloader_kwargs=None,\n",
" **trainer_kwargs):\n",
" super().__init__(\n",
" random_seed=random_seed,\n",
" loss=loss,\n",
" valid_loss=valid_loss,\n",
" optimizer=optimizer,\n",
" optimizer_kwargs=optimizer_kwargs,\n",
" lr_scheduler=lr_scheduler,\n",
" lr_scheduler_kwargs=lr_scheduler_kwargs,\n",
" futr_exog_list=futr_exog_list,\n",
" hist_exog_list=hist_exog_list,\n",
" stat_exog_list=stat_exog_list,\n",
Expand Down
8 changes: 0 additions & 8 deletions nbs/common.base_windows.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -115,20 +115,12 @@
" drop_last_loader=False,\n",
" random_seed=1,\n",
" alias=None,\n",
" optimizer=None,\n",
" optimizer_kwargs=None,\n",
" lr_scheduler=None,\n",
" lr_scheduler_kwargs=None,\n",
" dataloader_kwargs=None,\n",
" **trainer_kwargs):\n",
" super().__init__(\n",
" random_seed=random_seed,\n",
" loss=loss,\n",
" valid_loss=valid_loss,\n",
" optimizer=optimizer,\n",
" optimizer_kwargs=optimizer_kwargs,\n",
" lr_scheduler=lr_scheduler,\n",
" lr_scheduler_kwargs=lr_scheduler_kwargs,\n",
" futr_exog_list=futr_exog_list,\n",
" hist_exog_list=hist_exog_list,\n",
" stat_exog_list=stat_exog_list,\n",
Expand Down
Loading
Loading