Skip to content

Commit

Permalink
Merge pull request #505 from Lnaden/jax_stagger_jit
Browse files Browse the repository at this point in the history
Correctly stagger JIT until first call
  • Loading branch information
Lnaden authored Jun 16, 2023
2 parents a5fa114 + 14a16b6 commit 63084ad
Showing 1 changed file with 17 additions and 13 deletions.
30 changes: 17 additions & 13 deletions pymbar/mbar_solvers.py
Original file line number Diff line number Diff line change
Expand Up @@ -135,19 +135,23 @@ def jit_or_pass_after_bitsize(jitable_fn):
A function which can be jit'd
"""

# This will only trigger if JAX is set
if use_jit and not config.x64_enabled:
# Warn that JAX 64-bit will being turned on
logger.warning(
"\n"
"******* JAX 64-bit mode is now on! *******\n"
"* JAX is now set to 64-bit mode! *\n"
"* This MAY cause problems with other *\n"
"* uses of JAX in the same code. *\n"
"******************************************\n"
)
config.update("jax_enable_x64", True)
return jit_or_passthrough(jitable_fn)
def staggered_jit(*args, **kwargs):
# This will only trigger if JAX is set
if use_jit and not config.x64_enabled:
# Warn that JAX 64-bit will being turned on
logger.warning(
"\n"
"******* JAX 64-bit mode is now on! *******\n"
"* JAX is now set to 64-bit mode! *\n"
"* This MAY cause problems with other *\n"
"* uses of JAX in the same code. *\n"
"******************************************\n"
)
config.update("jax_enable_x64", True)
jited_fn = jit_or_passthrough(jitable_fn)
return jited_fn(*args, **kwargs)

return staggered_jit


def validate_inputs(u_kn, N_k, f_k):
Expand Down

0 comments on commit 63084ad

Please sign in to comment.