diff --git a/pymbar/mbar_solvers.py b/pymbar/mbar_solvers.py index 9ef56679..dd45f622 100644 --- a/pymbar/mbar_solvers.py +++ b/pymbar/mbar_solvers.py @@ -135,19 +135,23 @@ def jit_or_pass_after_bitsize(jitable_fn): A function which can be jit'd """ - # This will only trigger if JAX is set - if use_jit and not config.x64_enabled: - # Warn that JAX 64-bit will being turned on - logger.warning( - "\n" - "******* JAX 64-bit mode is now on! *******\n" - "* JAX is now set to 64-bit mode! *\n" - "* This MAY cause problems with other *\n" - "* uses of JAX in the same code. *\n" - "******************************************\n" - ) - config.update("jax_enable_x64", True) - return jit_or_passthrough(jitable_fn) + def staggered_jit(*args, **kwargs): + # This will only trigger if JAX is set + if use_jit and not config.x64_enabled: + # Warn that JAX 64-bit will being turned on + logger.warning( + "\n" + "******* JAX 64-bit mode is now on! *******\n" + "* JAX is now set to 64-bit mode! *\n" + "* This MAY cause problems with other *\n" + "* uses of JAX in the same code. *\n" + "******************************************\n" + ) + config.update("jax_enable_x64", True) + jited_fn = jit_or_passthrough(jitable_fn) + return jited_fn(*args, **kwargs) + + return staggered_jit def validate_inputs(u_kn, N_k, f_k):