Skip to content

Commit

Permalink
Add constant lr scheduler
Browse files Browse the repository at this point in the history
  • Loading branch information
crowsonkb committed Feb 6, 2023
1 parent 6a151de commit 6e90ce8
Show file tree
Hide file tree
Showing 6 changed files with 7 additions and 20 deletions.
5 changes: 1 addition & 4 deletions configs/config_32x32_small.json
Original file line number Diff line number Diff line change
Expand Up @@ -33,10 +33,7 @@
"weight_decay": 1e-3
},
"lr_sched": {
"type": "inverse",
"inv_gamma": 20000.0,
"power": 1.0,
"warmup": 0.99
"type": "constant"
},
"ema_sched": {
"type": "inverse",
Expand Down
5 changes: 1 addition & 4 deletions configs/config_32x32_small_butterflies.json
Original file line number Diff line number Diff line change
Expand Up @@ -34,10 +34,7 @@
"weight_decay": 1e-3
},
"lr_sched": {
"type": "inverse",
"inv_gamma": 20000.0,
"power": 1.0,
"warmup": 0.99
"type": "constant"
},
"ema_sched": {
"type": "inverse",
Expand Down
5 changes: 1 addition & 4 deletions configs/config_cifar10.json
Original file line number Diff line number Diff line change
Expand Up @@ -33,10 +33,7 @@
"weight_decay": 1e-3
},
"lr_sched": {
"type": "inverse",
"inv_gamma": 20000.0,
"power": 1.0,
"warmup": 0.99
"type": "constant"
},
"ema_sched": {
"type": "inverse",
Expand Down
5 changes: 1 addition & 4 deletions configs/config_mnist.json
Original file line number Diff line number Diff line change
Expand Up @@ -33,10 +33,7 @@
"weight_decay": 1e-3
},
"lr_sched": {
"type": "inverse",
"inv_gamma": 20000.0,
"power": 1.0,
"warmup": 0.99
"type": "constant"
},
"ema_sched": {
"type": "inverse",
Expand Down
5 changes: 1 addition & 4 deletions k_diffusion/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,10 +34,7 @@ def load_config(file):
'weight_decay': 1e-3,
},
'lr_sched': {
'type': 'inverse',
'inv_gamma': 20000.,
'power': 1.,
'warmup': 0.99,
'type': 'constant',
},
'ema_sched': {
'type': 'inverse',
Expand Down
2 changes: 2 additions & 0 deletions train.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,6 +130,8 @@ def main():
num_steps=sched_config['num_steps'],
decay=sched_config['decay'],
warmup=sched_config['warmup'])
elif sched_config['type'] == 'constant':
sched = optim.lr_scheduler.LambdaLR(opt, lambda _: 1.0)
else:
raise ValueError('Invalid schedule type')

Expand Down

0 comments on commit 6e90ce8

Please sign in to comment.