bundles / scipy 1.17.1 / scipy / _lib / array_api_extra / testing / lazy_xp_function
function
scipy._lib.array_api_extra.testing:lazy_xp_function
Signature
def lazy_xp_function(func: Callable[..., Any], *, allow_dask_compute: bool | int = False, jax_jit: bool = True, static_argnums: Deprecated = Deprecated.DEPRECATED, static_argnames: Deprecated = Deprecated.DEPRECATED) → None
Summary
Tag a function to be tested on lazy backends.
Extended Summary
Tag a function so that, when tests run with xp=jax.numpy, the function is replaced with a jitted version of itself, and when they run with xp=dask.array, the function raises if it attempts to materialize the graph. This will later be expanded to provide test coverage for other lazy backends.
In order for the tag to be effective, the test or a fixture must call patch_lazy_xp_functions.
Parameters
func: callable
    Function to be tested.
allow_dask_compute: bool | int, optional
    Whether func is allowed to internally materialize the Dask graph, or maximum number of times it is allowed to do so. This is typically triggered by bool(), float(), or np.asarray().

    Set to 1 if you are aware that func converts the input parameters to NumPy and want to let it do so at least for the time being, knowing that it is going to be extremely detrimental for performance.

    If a test needs values higher than 1 to pass, it is a canary that the conversion to NumPy/bool/float is happening multiple times, which translates to multiple computations of the whole graph. Short of making the function fully lazy, you should at least add explicit calls to np.asarray() early in the function. Note: the counter of allow_dask_compute resets after each call to func, so a test function that invokes func multiple times should still work with this parameter set to 1.

    Set to True to allow func to materialize the graph an unlimited number of times.

    Default: False, meaning that func must be fully lazy and never materialize the graph.
jax_jit: bool, optional
    Set to True to replace func with a smart variant of jax.jit(func) after calling the patch_lazy_xp_functions test helper with xp=jax.numpy. This is the default behaviour. Set to False if func is only compatible with eager (non-jitted) JAX.

    Unlike with vanilla jax.jit, all arguments and return types that are not JAX arrays are treated as static; the function can accept and return arbitrary wrappers around JAX arrays. This difference is because, in real life, most users won't wrap the function directly with jax.jit but rather they will use it within their own code, which is itself then wrapped by jax.jit, and internally consume the function's outputs.

    In other words, the pattern that is being tested is

    >>> @jax.jit
    ... def user_func(x):
    ...     y = user_prepares_inputs(x)
    ...     z = func(y, some_static_arg=True)
    ...     return user_consumes(z)

    Default: True.
static_argnums
    Deprecated; ignored.
static_argnames
    Deprecated; ignored.
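The counting rules described for allow_dask_compute can be illustrated with a toy model in plain Python. This is only a sketch under stated assumptions: MaterializationError, Budget, FakeLazyArray and limit_materializations are hypothetical names invented for this illustration, not part of array_api_extra or Dask, and the real helper presumably hooks into Dask itself rather than wrapping values this way. The sketch shows the three settings (False, an integer, True) and the per-call reset of the counter.

```python
import math
from functools import wraps


class MaterializationError(AssertionError):
    """Hypothetical error: func materialized more often than allowed."""


class Budget:
    """Counts materializations during a single call to func."""

    def __init__(self, limit):
        self.limit = limit
        self.used = 0

    def spend(self):
        self.used += 1
        if self.used > self.limit:
            raise MaterializationError(
                f"materialized {self.used} times, limit is {self.limit}"
            )


class FakeLazyArray:
    """Stand-in for a lazy array; float() models one graph computation."""

    def __init__(self, value, budget):
        self._value = value
        self._budget = budget

    def __float__(self):
        self._budget.spend()
        return float(self._value)


def limit_materializations(func, allow_compute=False):
    """Toy analogue of tagging func with allow_dask_compute=allow_compute."""
    if allow_compute is True:
        limit = math.inf  # unlimited materializations
    elif allow_compute is False:
        limit = 0  # fully lazy: no materialization at all
    else:
        limit = int(allow_compute)

    @wraps(func)
    def wrapper(value):
        budget = Budget(limit)  # fresh counter on every call to func
        return func(FakeLazyArray(value, budget))

    return wrapper


def converts_once(x):
    return float(x) + 1.0


def converts_twice(x):
    return float(x) + float(x)


strict = limit_materializations(converts_once)  # default: False
once = limit_materializations(converts_once, allow_compute=1)
twice_capped = limit_materializations(converts_twice, allow_compute=1)
unlimited = limit_materializations(converts_twice, allow_compute=True)

# Calling `once` repeatedly works because each call gets its own counter.
results = [once(1), once(2)]
```

In this model, strict and twice_capped raise as soon as the budget is exceeded, which mirrors the "canary" role the parameter plays: needing a limit above 1 signals repeated conversions inside a single call.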
Notes
In order for this tag to be effective, the function under test must be imported into the test module globals without its namespace; alternatively, its namespace must be declared in a lazy_xp_modules list in the test module globals.
Example 1
    from mymodule import myfunc

    lazy_xp_function(myfunc)

    def test_myfunc(xp):
        x = myfunc(xp.asarray([1, 2]))
Example 2
    import mymodule

    lazy_xp_modules = [mymodule]
    lazy_xp_function(mymodule.myfunc)

    def test_myfunc(xp):
        x = mymodule.myfunc(xp.asarray([1, 2]))
A test function can circumvent this monkey-patching system by using a namespace outside of the two above patterns. You need to sanitize your code to make sure this only happens intentionally.
Example 1
    import mymodule
    from mymodule import myfunc

    lazy_xp_function(myfunc)

    def test_myfunc(xp):
        a = xp.asarray([1, 2])
        b = myfunc(a)  # This is wrapped when xp=jax.numpy or xp=dask.array
        c = mymodule.myfunc(a)  # This is not
Example 2
    import mymodule

    class naked:
        myfunc = mymodule.myfunc

    lazy_xp_modules = [mymodule]
    lazy_xp_function(mymodule.myfunc)

    def test_myfunc(xp):
        a = xp.asarray([1, 2])
        b = mymodule.myfunc(a)  # This is wrapped when xp=jax.numpy or xp=dask.array
        c = naked.myfunc(a)  # This is not
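Why some references escape the patching can be seen in a small standard-library model. This is a sketch, not the library's implementation: toy_patch and make_wrapper are hypothetical helpers, and it assumes that patching works by rebinding names in the test module globals and in any modules listed in lazy_xp_modules. Under that assumption, only names found in those namespaces are replaced; a reference held anywhere else keeps pointing at the original function.

```python
import types


def myfunc(x):
    return x * 2


def make_modules():
    """Build a fresh hypothetical library module and test module."""
    mymodule = types.ModuleType("mymodule")
    mymodule.myfunc = myfunc
    test_mod = types.ModuleType("test_mymodule")
    test_mod.mymodule = mymodule       # models: import mymodule
    test_mod.myfunc = mymodule.myfunc  # models: from mymodule import myfunc
    return mymodule, test_mod


def make_wrapper(func):
    """Toy wrapper that marks its result so wrapping is observable."""
    def wrapper(*args, **kwargs):
        return ("wrapped", func(*args, **kwargs))
    return wrapper


def toy_patch(test_mod, tagged):
    """Rebind tagged functions in the test module and declared modules."""
    searched = [test_mod, *getattr(test_mod, "lazy_xp_modules", [])]
    for namespace in searched:
        for name, obj in list(vars(namespace).items()):
            if any(obj is func for func in tagged):
                setattr(namespace, name, make_wrapper(obj))


# Pattern 1: the function is imported without its namespace.
_, test_mod = make_modules()
toy_patch(test_mod, tagged=[myfunc])
direct = test_mod.myfunc(1)               # rebound in test module globals
via_module = test_mod.mymodule.myfunc(1)  # mymodule itself was not searched

# Pattern 2: the namespace is declared in lazy_xp_modules.
mymodule2, test_mod2 = make_modules()
naked = types.SimpleNamespace(myfunc=mymodule2.myfunc)  # reference held elsewhere
test_mod2.lazy_xp_modules = [mymodule2]
toy_patch(test_mod2, tagged=[myfunc])
declared = test_mod2.mymodule.myfunc(1)   # rebound inside mymodule
escaped = naked.myfunc(1)                 # still the original function
```

The model matches the four commented calls in the examples above: `direct` and `declared` go through the wrapper, while `via_module` and `escaped` do not.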
Examples
In ``test_mymodule.py``::

    from array_api_extra.testing import lazy_xp_function
    from mymodule import myfunc

    lazy_xp_function(myfunc)

    def test_myfunc(xp):
        a = xp.asarray([1, 2])
        # When xp=jax.numpy, this is similar to `b = jax.jit(myfunc)(a)`
        # When xp=dask.array, crash on compute() or persist()
        b = myfunc(a)

See also
- jax.jit
JAX function to compile a function for performance.
- patch_lazy_xp_functions
Companion function to call from the test or fixture.
Aliases
- scipy.differentiate.xpx.testing.lazy_xp_function