When working with JAX, especially when using jax.jit for performance, you might encounter some surprising behaviors if you’re not careful about how JAX “sees” and handles your objects. One such area is the interaction between jax.jit and Python’s @functools.cached_property. You might find that your property isn’t as “cached” as you expect within JIT-compiled functions.
6/4/25 · About 7 min read
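A minimal sketch of the symptom follows. It assumes the object reaches the jitted function as an argument, via a hypothetical `Model` class registered as a pytree with `jax.tree_util.register_pytree_node_class`. The names `Model`, `w_squared`, `calls`, and `f` are illustrative and do not come from the original. When `jax.jit` traces `f`, it flattens `m` into its leaves and calls `tree_unflatten` to build a fresh `Model` around tracers. The value that `cached_property` stored in the original instance's `__dict__` is therefore not carried into the trace.

```python
import functools

import jax
import jax.numpy as jnp


# Hypothetical example class, registered as a pytree so it can be passed to jax.jit.
@jax.tree_util.register_pytree_node_class
class Model:
    calls = 0  # counts how many times the "expensive" property body actually runs

    def __init__(self, w):
        self.w = w

    @functools.cached_property
    def w_squared(self):
        Model.calls += 1
        return self.w ** 2

    def tree_flatten(self):
        # Only `w` is a leaf; anything cached in self.__dict__ is not part of the pytree.
        return (self.w,), None

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        # Builds a brand-new instance, so its cached_property cache starts empty.
        return cls(*children)


@jax.jit
def f(m):
    # Within a single trace, the second access hits the cache on the traced instance.
    return m.w_squared + m.w_squared


m = Model(jnp.array(3.0))
m.w_squared                      # eager access: computes and caches on `m`
calls_after_eager = Model.calls

out = f(m)                       # first call traces: `m` is rebuilt from tracers, cache lost
calls_after_first_jit = Model.calls

f(m)                             # same shapes/dtypes: compiled-cache hit, Python body skipped
calls_after_second_jit = Model.calls

print(calls_after_eager, calls_after_first_jit, calls_after_second_jit, float(out))
```

In this sketch the counter advances once per trace rather than once per call. The property's cache is effectively scoped to a single trace of the jitted function, not to the object you passed in.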