Skip to content

Commit

Permalink
Improve Numerical Stability (#167)
Browse files Browse the repository at this point in the history
* revert Exp to Softplus for scale bijector

* oops remove silly cruft

* add choice scale bijector, set softplus default

---------

Co-authored-by: Kevin Dalton <[email protected]>
  • Loading branch information
kmdalton and Kevin Dalton authored Sep 7, 2024
1 parent e99fbd0 commit 62c6859
Show file tree
Hide file tree
Showing 5 changed files with 51 additions and 9 deletions.
8 changes: 8 additions & 0 deletions careless/args/scaling.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,4 +43,12 @@
"dest" : "use_image_scales",
"default": True,
}),

(("--scale-bijector",), {
"help": "What function to use to ensure positivity of the standard deviation of scales. "
"Supported functions are --scale-bijector=exp and the default is --scale-bijector=softplus",
"type": str,
"default": "softplus",
"choices" : ["exp", "softplus"],
}),
)
18 changes: 17 additions & 1 deletion careless/io/manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -438,6 +438,21 @@ def build_model(self, parser=None, surrogate_posterior=None, prior=None, likelih
if mlp_width is None:
mlp_width = BaseModel.get_metadata(self.inputs).shape[-1]

if parser.scale_bijector.lower() == 'softplus':
from tensorflow_probability import bijectors as tfb
scale_bijector = tfb.Chain([
tfb.Shift(parser.epsilon),
tfb.Softplus(),
])
elif parser.scale_bijector.lower() == 'exp':
from tensorflow_probability import bijectors as tfb
scale_bijector = tfb.Chain([
tfb.Shift(parser.epsilon),
tfb.Exp(),
])
else:
raise ValueError(f"Unsupported scale bijector type, {parser.scale_bijector}")

if parser.image_layers > 0:
from careless.models.scaling.image import NeuralImageScaler
n_images = np.max(BaseModel.get_image_id(self.inputs)) + 1
Expand All @@ -447,9 +462,10 @@ def build_model(self, parser=None, surrogate_posterior=None, prior=None, likelih
parser.mlp_layers,
mlp_width,
epsilon=parser.epsilon,
scale_bijector=scale_bijector
)
else:
mlp_scaler = MLPScaler(parser.mlp_layers, mlp_width, epsilon=parser.epsilon)
mlp_scaler = MLPScaler(parser.mlp_layers, mlp_width, epsilon=parser.epsilon, scale_bijector=scale_bijector)
if parser.use_image_scales:
n_images = np.max(BaseModel.get_image_id(self.inputs)) + 1
image_scaler = ImageScaler(n_images)
Expand Down
6 changes: 2 additions & 4 deletions careless/models/scaling/image.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,7 @@ def call(self, metadata_and_image_id, *args, **kwargs):
return result

class NeuralImageScaler(Scaler):
def __init__(self, image_layers, max_images, mlp_layers, mlp_width, leakiness=0.01, epsilon=1e-7):
def __init__(self, image_layers, max_images, mlp_layers, mlp_width, leakiness=0.01, epsilon=1e-7, scale_bijector=None):
super().__init__()
layers = []
if leakiness is None:
Expand All @@ -111,15 +111,13 @@ def __init__(self, image_layers, max_images, mlp_layers, mlp_width, leakiness=0.

self.image_layers = layers
from careless.models.scaling.nn import MetadataScaler
self.metadata_scaler = MetadataScaler(mlp_layers, mlp_width, leakiness, epsilon=epsilon)
self.metadata_scaler = MetadataScaler(mlp_layers, mlp_width, leakiness, epsilon=epsilon, scale_bijector=scale_bijector)

def call(self, inputs):
result = self.get_metadata(inputs)
image_id = self.get_image_id(inputs),

result = self.metadata_scaler.network(result)
# One could use this line to add a skip connection here
#result = result + self.get_metadata(inputs)

for layer in self.image_layers:
result = layer((result, image_id))
Expand Down
6 changes: 3 additions & 3 deletions careless/models/scaling/nn.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ def __init__(self, scale_bijector=None, epsilon=1e-7, **kwargs):
if scale_bijector is None:
self.scale_bijector = tfb.Chain([
tfb.Shift(epsilon),
tfb.Exp(),
tfb.Softplus(),
])
else:
self.scale_bijector = scale_bijector
Expand All @@ -29,7 +29,7 @@ class MetadataScaler(Scaler):
Neural network based scaler with simple dense layers.
This neural network outputs a normal distribution.
"""
def __init__(self, n_layers, width, leakiness=0.01, epsilon=1e-7):
def __init__(self, n_layers, width, leakiness=0.01, epsilon=1e-7, scale_bijector=None):
"""
Parameters
----------
Expand Down Expand Up @@ -73,7 +73,7 @@ def __init__(self, n_layers, width, leakiness=0.01, epsilon=1e-7):

#The final layer converts the output to a Normal distribution
#tfp_layers.append(tfp.layers.IndependentNormal())
tfp_layers.append(NormalLayer(epsilon=epsilon))
tfp_layers.append(NormalLayer(epsilon=epsilon, scale_bijector=scale_bijector))

self.network = tfk.Sequential(mlp_layers)
self.distribution = tfk.Sequential(tfp_layers)
Expand Down
22 changes: 21 additions & 1 deletion tests/test_cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -170,7 +170,7 @@ def test_freeze_scales(off_file):

@pytest.mark.parametrize('clip_type', ['--clipvalue', '--clipnorm', '--global-clipnorm'])
def test_clipping(off_file, clip_type):
""" Test `--freeze-scales` for execution """
""" Test gradient clipping settings """
with TemporaryDirectory() as td:
out = td + '/out'
flags = f"mono --disable-gpu --iterations={niter} {clip_type}=1. dHKL,image_id"
Expand All @@ -183,3 +183,23 @@ def test_clipping(off_file, clip_type):
assert exists(out_file)


@pytest.mark.parametrize('scale_bijector', ['exp', 'softplus'])
@pytest.mark.parametrize('image_layers', [None, 2])
def test_scale_bijector(off_file, scale_bijector, image_layers):
    """ Test `--scale-bijector={exp,softplus}` for execution, with and without `--image-layers`. """
    with TemporaryDirectory() as td:
        out = td + '/out'
        # Build the flag string once; only the optional --image-layers part varies.
        image_flag = f"--image-layers={image_layers} " if image_layers is not None else ""
        flags = f"mono --disable-gpu {image_flag}--iterations={niter} --scale-bijector={scale_bijector} dHKL,image_id"

        command = flags + f" {off_file} {out}"
        from careless.parser import parser
        parser = parser.parse_args(command.split())
        run_careless(parser)

        # run_careless writes one mtz per output dataset; dataset 0 must exist.
        out_file = out + "_0.mtz"
        assert exists(out_file)


0 comments on commit 62c6859

Please sign in to comment.