Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Correctly stagger JIT until first call #505

Merged
merged 1 commit into from
Jun 16, 2023
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 17 additions & 13 deletions pymbar/mbar_solvers.py
Original file line number Diff line number Diff line change
Expand Up @@ -135,19 +135,23 @@ def jit_or_pass_after_bitsize(jitable_fn):
A function which can be jit'd
"""

# This will only trigger if JAX is set
if use_jit and not config.x64_enabled:
# Warn that JAX 64-bit will being turned on
logger.warning(
"\n"
"******* JAX 64-bit mode is now on! *******\n"
"* JAX is now set to 64-bit mode! *\n"
"* This MAY cause problems with other *\n"
"* uses of JAX in the same code. *\n"
"******************************************\n"
)
config.update("jax_enable_x64", True)
return jit_or_passthrough(jitable_fn)
def staggered_jit(*args, **kwargs):
# This will only trigger if JAX is set
if use_jit and not config.x64_enabled:
# Warn that JAX 64-bit will being turned on
logger.warning(
"\n"
"******* JAX 64-bit mode is now on! *******\n"
"* JAX is now set to 64-bit mode! *\n"
"* This MAY cause problems with other *\n"
"* uses of JAX in the same code. *\n"
"******************************************\n"
)
config.update("jax_enable_x64", True)
jited_fn = jit_or_passthrough(jitable_fn)
return jited_fn(*args, **kwargs)

return staggered_jit


def validate_inputs(u_kn, N_k, f_k):
Expand Down