diff --git a/careless/args/optimizer.py b/careless/args/optimizer.py index 2df741c..b56f915 100644 --- a/careless/args/optimizer.py +++ b/careless/args/optimizer.py @@ -26,4 +26,22 @@ "default":0.99, }), + (("--clipnorm",), { + "help":"Optionally clip the norm of the gradient of each weight to be no larger than this value.", + "type": float, + "default": None, + }), + + (("--clipvalue",), { + "help":"Optionally clip the gradients to be no larger than this value.", + "type": float, + "default": None, + }), + + (("--global-clipnorm",), { + "help":"Optionally clip the norm of all the gradients to be no larger than this value.", + "type": float, + "default": None, + }), + ) diff --git a/careless/io/manager.py b/careless/io/manager.py index 4ab7a94..f2afadc 100644 --- a/careless/io/manager.py +++ b/careless/io/manager.py @@ -464,6 +464,9 @@ def build_model(self, parser=None, surrogate_posterior=None, prior=None, likelih parser.learning_rate, parser.beta_1, parser.beta_2, + clipnorm=parser.clipnorm, + clipvalue=parser.clipvalue, + global_clipnorm=parser.global_clipnorm, ) model.compile( diff --git a/tests/test_cli.py b/tests/test_cli.py index a6f899e..452619d 100644 --- a/tests/test_cli.py +++ b/tests/test_cli.py @@ -168,3 +168,18 @@ def test_freeze_scales(off_file): out_file = out + f"_0.mtz" assert exists(out_file) +@pytest.mark.parametrize('clip_type', ['--clipvalue', '--clipnorm', '--global-clipnorm']) +def test_clipping(off_file, clip_type): + """ Test `--freeze-scales` for execution """ + with TemporaryDirectory() as td: + out = td + '/out' + flags = f"mono --disable-gpu --iterations={niter} {clip_type}=1. dHKL,image_id" + command = flags + f" {off_file} {out}" + from careless.parser import parser + parser = parser.parse_args(command.split()) + run_careless(parser) + + out_file = out + f"_0.mtz" + assert exists(out_file) + +