JAX JIT and the Case of the Disappearing Cache: Understanding `PyTreeNodes` and `functools.cached_property`
When working with JAX, especially when using jax.jit for performance, you might encounter some surprising behaviors if you’re not careful about how JAX “sees” and handles your objects. One such area is the interaction between jax.jit and Python’s @functools.cached_property. You might find that your property isn’t as “cached” as you expect within JIT-compiled functions.
Let’s dive into this with some test code.
The Core Players
- jax.jit: This powerful decorator compiles your Python functions into highly optimized XLA code. A key aspect of JIT is tracing. When a function is JIT-compiled, JAX calls it once with abstract representations of the inputs (called “Tracers”). It records all operations performed on these Tracers. The sequence of operations is then compiled.
- flax.struct.PyTreeNode: These are classes that JAX knows how to treat as “PyTrees.” This means JAX can iterate over their attributes (specified as type-annotated fields) and treat them as children in the PyTree structure. This is crucial for JIT, as it needs to know which parts of your object contain JAX arrays or other JAX-compatible data.
- @functools.cached_property: This decorator transforms a method of a class into a property whose value is computed once and then cached as an ordinary attribute in the instance’s __dict__. Subsequent accesses retrieve the value from the __dict__.
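To make the last point concrete, here is a minimal sketch in plain Python (no JAX involved) of where functools.cached_property keeps its value. The Circle class and its calls counter are hypothetical names used only for illustration.

```python
import functools
import math


class Circle:
    def __init__(self, radius: float):
        self.radius = radius
        self.calls = 0  # counts how often the property body actually runs

    @functools.cached_property
    def area(self) -> float:
        self.calls += 1
        return math.pi * self.radius**2


circle = Circle(2.0)
cached_before = "area" in circle.__dict__  # nothing has been cached yet
first = circle.area   # runs the method and writes the result into __dict__
second = circle.area  # served straight from __dict__, the method is not called
cached_after = "area" in circle.__dict__   # the cache entry now exists
```

The cache is nothing more than an entry in the instance’s __dict__, which is why anything that rebuilds the instance from scratch starts without it.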
Scenario 1: jax.jit on @functools.cached_property
    import functools

    import flax.struct
    import jax
    import jax.numpy as jnp


    class Foo(flax.struct.PyTreeNode):
        a: jax.Array

        @functools.cached_property
        @jax.jit  # The property itself is JIT-compiled
        def b(self) -> jax.Array:
            jax.debug.print("Computing b ...")
            return jnp.sum(self.a)


    @jax.jit
    def use_b(foo: Foo) -> jax.Array:
        # print(foo.__dict__)  # For debugging, would show only 'a' as a Tracer here
        return foo.b

Let’s look at the behaviors:
foo.b, foo.b (Outside JIT context)
    foo = Foo(a=jnp.ones((3, 3)))
    print(foo.b)  # Prints "Computing b ...", b is computed
    # `b` is now cached in `foo.__dict__`
    print(foo.b)  # Does NOT print "Computing b ...", uses cached value

This works as expected. The first call to foo.b computes the sum, prints the message, and stores the result in foo.__dict__['b']. The second call finds b in __dict__ and returns it directly.
foo.b, use_b(foo) (Mixing JIT and non-JIT)
    foo = Foo(a=jnp.ones((3, 3)))
    print(foo.b)  # Prints "Computing b ...", b is computed and cached on original foo
    # foo.__dict__ now contains 'b'
    print(use_b(foo))  # Prints "Computing b ..." AGAIN!

Why the recomputation?
- The first foo.b caches b in the __dict__ of the original foo object.
- When use_b(foo) is called, JAX traces it. The foo object passed into use_b is treated as a PyTreeNode. JAX creates an internal representation of foo for tracing purposes. This internal foo only contains the JAX-registered fields (in this case, a, which becomes a Tracer).
- Crucially, the cached b in the __dict__ of the original foo is not one of the JAX-registered fields. So the internal foo inside use_b has only a in its __dict__, not b.
- When foo.b is accessed inside use_b, it’s operating on this internal, traced foo. Since b isn’t cached there, the b property’s code (which is JIT-compiled itself) executes again, printing “Computing b …”.
use_b(foo), use_b(foo) (Multiple JIT calls)
    foo = Foo(a=jnp.ones((3, 3)))
    print(use_b(foo))  # Prints "Computing b ..."
    print(use_b(foo))  # Prints "Computing b ..." AGAIN!

Each call to use_b(foo) either traces and compiles the function (on the first call) or reuses the compiled program. The jax.debug.print call is staged into that compiled program, so the message appears on every execution. The caching that @functools.cached_property performs (writing to __dict__) happens only during tracing, on an internal, temporary reconstruction of foo. That cache does not persist across separate calls to use_b because the __dict__ is neither a returned value nor JAX-managed state.
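The difference between trace time and run time is what makes the message reappear. The following sketch uses a hypothetical total function and a plain Python list as a trace counter; it illustrates general jax.jit behavior and is not code from the scenario above.

```python
import jax
import jax.numpy as jnp

trace_log = []  # Python-side record, only touched while JAX traces the function


@jax.jit
def total(x: jax.Array) -> jax.Array:
    trace_log.append("traced")  # Python side effect: runs during tracing only
    jax.debug.print("Running compiled code")  # staged: runs on every execution
    return jnp.sum(x)


x = jnp.ones((3, 3))
first = total(x)   # traces, compiles, then runs
second = total(x)  # reuses the compiled program, no new trace
```

After both calls trace_log holds a single entry, while the jax.debug.print message is emitted once per call. Python-level caching lives on the trace_log side of that divide.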
use_b(foo), foo.b (JIT call then outside JIT)
    foo = Foo(a=jnp.ones((3, 3)))
    print(use_b(foo))  # Prints "Computing b ..."
    print(foo.b)  # Prints "Computing b ..." AGAIN!

- use_b(foo) computes b internally, as explained. This does not affect the __dict__ of the original, external foo object.
- When foo.b is called on the original foo object, its __dict__ does not yet contain b (assuming it wasn’t called before use_b), so it computes again.
The Root Cause for Scenario 1
@functools.cached_property relies on Python’s standard object attribute storage (__dict__). JAX’s tracing mechanism for PyTreeNodes only considers the explicitly defined fields (like a). It doesn’t know about or track changes to the __dict__. Every time a Foo crosses a jax.jit boundary it is flattened into its fields and rebuilt as a fresh object, so any entry that cached_property wrote to __dict__ is dropped on the way in and is never carried back out.
Scenario 2: Manual Caching with a PyTreeNode Field
Now, let’s look at a second approach, where the cache is managed manually in a field that JAX does know about.
    from typing import Optional


    class Foo(flax.struct.PyTreeNode, frozen=False):  # frozen=False allows modification
        a: jax.Array
        # JAX is aware of this field; None means "not computed yet"
        _c_cache: Optional[jax.Array] = flax.struct.field(default=None)

        @property  # NOT JIT-compiled itself
        def c(self) -> jax.Array:
            if self._c_cache is not None:
                return self._c_cache
            jax.debug.print("Computing c ...")
            self._c_cache = jnp.sum(self.a)  # Modify the registered field
            return self._c_cache


    @jax.jit
    def use_c(foo: Foo) -> jax.Array:
        # print(foo.__dict__)  # For debugging: 'a' is a Tracer; '_c_cache' is a Tracer or None
        return foo.c

foo.c, foo.c (Outside JIT context)
    foo = Foo(a=jnp.ones((3, 3)))
    print(foo.c)  # Prints "Computing c ...", _c_cache is populated
    print(foo.c)  # Does NOT print "Computing c ...", uses _c_cache

This works because _c_cache is a regular attribute on the foo instance. The frozen=False on PyTreeNode allows this attribute to be modified.
foo.c, use_c(foo) (Mixing non-JIT property and JIT function)
    foo = Foo(a=jnp.ones((3, 3)))
    print(foo.c)  # Prints "Computing c ...", original foo._c_cache is populated
    # Now, original foo has _c_cache set.
    print(use_c(foo))  # Does NOT print "Computing c ..."

This is a key difference!
- foo.c is called on the original foo. It computes c and stores it in foo._c_cache.
- When use_c(foo) is called, foo is passed to the JIT-compiled function. JAX traces foo, including its registered fields a and _c_cache. The value of _c_cache (which is now the computed sum) is part of the traced inputs.
- Inside use_c, when foo.c is accessed, self is the traced foo. Its _c_cache field already holds the sum (as a traced value). The if self._c_cache is not None: check is an ordinary Python identity test that runs at trace time; a Tracer is not None, so the check passes and the cached value is returned. No recomputation.
use_c(foo), use_c(foo) (Multiple JIT calls)
    foo = Foo(a=jnp.ones((3, 3)))
    print(use_c(foo))  # Prints "Computing c ..."
    print(use_c(foo))  # Prints "Computing c ..." AGAIN!

This recomputes, similar to the use_b case, but for a slightly different reason regarding state.
- First call to use_c(foo): The original foo object (where _c_cache is initially None) is passed. Inside use_c, foo.c is called.
- The if self._c_cache is not None: check sees a plain Python None inside the trace (None is a PyTree with no leaves, so it is never turned into a Tracer). So the sum and the “Computing c …” print are staged into the compiled program.
- The line self._c_cache = jnp.sum(self.a) executes. Since Foo is not frozen, this assignment updates the _c_cache of the internal, traced foo object, and only for the duration of this trace.
- However, this modification to the internal _c_cache within use_c does not affect the _c_cache of the original, external foo object unless use_c were to return the modified foo and you were to use that returned instance.
- Second call to use_c(foo): The same original foo object is passed in again. Its _c_cache is still None (because the first call to use_c didn’t change the external foo), so JAX reuses the program compiled on the first call, and the computation happens again.
use_c(foo), foo.c (JIT call then outside JIT)
    foo = Foo(a=jnp.ones((3, 3)))
    print(use_c(foo))  # Prints "Computing c ..."
    print(foo.c)  # Prints "Computing c ..." AGAIN!

- use_c(foo) is called. As above, “Computing c …” is printed. The _c_cache of the external foo object remains None.
- When foo.c is called on the original, external foo object, its _c_cache is still None, so it computes again.
Why the Manual Cache (_c_cache) Behaves This Way
- Visibility to JAX: Because _c_cache is a registered field in the PyTreeNode, JAX includes it in its tracing process. When foo is an input to a JITted function, the current value of foo._c_cache is made available to the traced function.
- Side Effects in JIT: JITted functions are ideally pure from JAX’s perspective. Modifying an input object’s attribute (like _c_cache = ... inside the property c when called from use_c) is a side effect.
  - When c is called from use_c, the property c itself is not JIT-compiled. It executes as regular Python code while use_c is being traced.
  - The assignment self._c_cache = ... happens on the traced representation of foo. This update is visible only within the current trace of use_c.
  - However, this internal update doesn’t automatically propagate back to the original Python object foo that was passed into use_c. JAX requires you to explicitly return modified objects if you want to see their changes outside the JITted function.
Key Takeaways & Best Practices
- JIT and Instance State: jax.jit primarily cares about the data in the registered fields of your PyTreeNodes. Python’s internal mechanisms like __dict__ (used by functools.cached_property) are generally opaque to JIT and are not part of the traced state of a PyTreeNode.
- Caching Scope:
  - @functools.cached_property: Caching works reliably on a concrete object outside JIT, even when the property itself is JITted. Once the object is passed into a JITted function, the cache is lost, because __dict__ isn’t a JAX-traced field.
  - Manual Cache Field: If the cache is a JAX-registered field (like _c_cache):
    - If populated before calling the JITted function, the JITted function will see the cached value.
    - If the caching logic (the property getter) is called from within a JITted function, it modifies an internal, traced version of the object. This change doesn’t affect the original external object or persist to subsequent independent calls to the JITted function with the same original object, unless you explicitly return the modified object from the JITted function and use that returned instance for further operations.
- Purity and Side Effects: JAX prefers pure functions (functions that don’t have side effects). Modifying an object’s state within a JITted function is a side effect. Although an update to a registered field (like _c_cache) is visible for the rest of the trace, it’s clearer to design JITted functions to take inputs and produce outputs, with state updates handled by passing modified objects out of the function.
In Conclusion
When you want to cache computations in JAX objects that interact with jax.jit:
- Avoid @functools.cached_property if the property is accessed within JIT-compiled code or if the property itself is JIT-compiled, as its caching mechanism is invisible to JAX’s tracing.
- Using a manual cache field (like _c_cache) makes the cached data visible to JAX.
- Be mindful that modifications to such fields within a JITted function call are typically on an internal representation and won’t update the original external object unless that object is returned by the JITted function.
- For robust state management within JIT, especially for things like model parameters or optimizer states, consider Flax’s Module system, which has its own mechanisms for handling state (variables, sow) that are designed to work correctly with JAX transformations.
Understanding how JAX traces and handles PyTreeNodes is key to avoiding these caching conundrums and writing effective JAX code!