scipy 1.17.1
function
scipy._lib.array_api_extra._lib._utils._helpers:jax_autojit
source: /scipy/_lib/array_api_extra/_lib/_utils/_helpers.py:560
Signature
def jax_autojit(func: Callable[P, T]) -> Callable[P, T]
Summary
Wrap func with jax.jit, with the following differences:
Extended Summary
- Python scalar arguments and return values are not automatically converted to jax.Array objects.
- All non-array arguments are automatically treated as static. Unlike jax.jit, static arguments must be either hashable or serializable with pickle.
- Unlike jax.jit, non-array arguments and return values are not limited to tuple/list/dict, but can be any object serializable with pickle.
- Automatically descend into non-array arguments and find jax.Array objects inside them, then rebuild the arguments when entering func, swapping the JAX concrete arrays with tracer objects.
- Automatically descend into non-array return values and find jax.Array objects inside them, then rebuild them downstream of exiting the JIT, swapping the JAX tracer objects with concrete arrays.
- Returned iterators are immediately and completely consumed.
Notes
These choices are useful for testing purposes only, which is how this function is intended to be used. The output of jax.jit is a C++-level callable that, after the initial call, dispatches directly to the compiled kernel. In comparison, jax_autojit incurs a much higher per-call dispatch overhead.
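Additionally, consider:

    import jax
    from jax import Array
    from scipy._lib.array_api_extra._lib._utils._helpers import jax_autojit

    def f(x: Array, y: float, plus: bool) -> Array:
        return x + y if plus else x - y

    j1 = jax.jit(f, static_argnames="plus")
    j2 = jax_autojit(f)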
In the above example, j2 requires much less setup than j1 to be tested effectively. On the flip side, because the Python float y is treated as static, j2 is re-traced for every distinct value of y, which likely makes it unfit for production use.
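The re-tracing cost can be illustrated with a toy model of a JIT cache. This is a hypothetical, stdlib-only sketch: ToyJit compiles nothing and is not how JAX is implemented; it only counts one "trace" per distinct combination of static argument values. Keying on plus alone mimics j1 (y traced), while keying on both y and plus mimics j2 (y static).

```python
from typing import Any, Callable


class ToyJit:
    """Toy model of a JIT cache: one 'trace' per distinct tuple of static values.

    Hypothetical helper for illustration only; nothing is compiled.
    """

    def __init__(self, func: Callable[..., Any], static: set[str]) -> None:
        self.func = func
        self.static = static
        self.cache: dict[tuple[Any, ...], Callable[..., Any]] = {}
        self.traces = 0

    def __call__(self, **kwargs: Any) -> Any:
        # Static arguments form the cache key, so they must be hashable.
        key = tuple(sorted((k, v) for k, v in kwargs.items() if k in self.static))
        if key not in self.cache:
            self.traces += 1  # a new combination of static values forces a re-trace
            self.cache[key] = self.func  # stand-in for the traced/compiled kernel
        return self.cache[key](**kwargs)


def f(x: float, y: float, plus: bool) -> float:
    return x + y if plus else x - y


like_j1 = ToyJit(f, static={"plus"})       # y is traced, as with j1
like_j2 = ToyJit(f, static={"y", "plus"})  # y is static, as with jax_autojit

for y in (1.0, 2.0, 3.0):
    like_j1(x=10.0, y=y, plus=True)
    like_j2(x=10.0, y=y, plus=True)

print(like_j1.traces, like_j2.traces)  # 1 3
```

Three distinct values of y cost one trace when y is traced but three traces when it is static; with real JIT compilation each extra trace also implies a recompilation.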
See also
- jax.jit
JAX JIT compilation function.
Aliases
- scipy.differentiate.xpx.testing.jax_autojit