From 14a16b617f272156d47d5ca1123f4443fc96c58d Mon Sep 17 00:00:00 2001
From: Levi Naden <lnaden@vt.edu>
Date: Thu, 15 Jun 2023 16:03:41 -0400
Subject: [PATCH] Correctly stagger jit until first call

---
 pymbar/mbar_solvers.py | 30 +++++++++++++++++-------------
 1 file changed, 17 insertions(+), 13 deletions(-)

diff --git a/pymbar/mbar_solvers.py b/pymbar/mbar_solvers.py
index 9ef56679..dd45f622 100644
--- a/pymbar/mbar_solvers.py
+++ b/pymbar/mbar_solvers.py
@@ -135,19 +135,23 @@ def jit_or_pass_after_bitsize(jitable_fn):
         A function which can be jit'd
     """
 
-    # This will only trigger if JAX is set
-    if use_jit and not config.x64_enabled:
-        # Warn that JAX 64-bit will being turned on
-        logger.warning(
-            "\n"
-            "******* JAX 64-bit mode is now on! *******\n"
-            "*     JAX is now set to 64-bit mode!     *\n"
-            "*   This MAY cause problems with other   *\n"
-            "*      uses of JAX in the same code.     *\n"
-            "******************************************\n"
-        )
-        config.update("jax_enable_x64", True)
-    return jit_or_passthrough(jitable_fn)
+    def staggered_jit(*args, **kwargs):
+        # This will only trigger if JAX is set
+        if use_jit and not config.x64_enabled:
+            # Warn that JAX 64-bit will being turned on
+            logger.warning(
+                "\n"
+                "******* JAX 64-bit mode is now on! *******\n"
+                "*     JAX is now set to 64-bit mode!     *\n"
+                "*   This MAY cause problems with other   *\n"
+                "*      uses of JAX in the same code.     *\n"
+                "******************************************\n"
+            )
+            config.update("jax_enable_x64", True)
+        jited_fn = jit_or_passthrough(jitable_fn)
+        return jited_fn(*args, **kwargs)
+
+    return staggered_jit
 
 
 def validate_inputs(u_kn, N_k, f_k):