function
scipy._lib.array_api_extra.testing:patch_lazy_xp_functions
Signature
def patch_lazy_xp_functions(request: pytest.FixtureRequest, monkeypatch: pytest.MonkeyPatch | None = None, *, xp: ModuleType) -> contextlib.AbstractContextManager[None]
Summary
Test lazy execution of functions tagged with lazy_xp_function.
Extended Summary
If xp==jax.numpy, search for all functions which have been tagged with lazy_xp_function in the globals of the module that defines the current test, as well as in the lazy_xp_modules list in the globals of the same module, and wrap them with jax.jit. Unwrap them at the end of the test.
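The search-and-unwrap step can be illustrated with a standard-library-only sketch. It is not the library's implementation: `tag`, `_TAG`, `namespaces_for`, `patch_tagged` and `counting_wrapper` are hypothetical names. `tag` stands in for lazy_xp_function, and `counting_wrapper` stands in for jax.jit so that the example runs without JAX. The sketch undoes the patch with try/finally rather than with pytest's monkeypatching.

```python
import contextlib
import types

# Hypothetical marker attribute; the real lazy_xp_function records its own metadata.
_TAG = "_hypothetical_lazy_tag"


def tag(func):
    """Hypothetical stand-in for lazy_xp_function: mark func for wrapping."""
    setattr(func, _TAG, True)
    return func


def namespaces_for(test_module):
    """The test module's globals plus those of every module in its lazy_xp_modules list."""
    extra = getattr(test_module, "lazy_xp_modules", [])
    return [vars(test_module)] + [vars(m) for m in extra]


@contextlib.contextmanager
def patch_tagged(namespaces, wrapper):
    """Swap each tagged function for wrapper(func); restore the originals on exit."""
    saved = []  # (namespace, name, original) triples used to undo the patch
    try:
        for ns in namespaces:
            for name, obj in list(ns.items()):
                if callable(obj) and getattr(obj, _TAG, False):
                    saved.append((ns, name, obj))
                    ns[name] = wrapper(obj)
        yield
    finally:
        for ns, name, original in reversed(saved):
            ns[name] = original


def counting_wrapper(func):
    """Stand-in for jax.jit: forwards the call and counts invocations."""
    def wrapped(*args, **kwargs):
        wrapped.calls += 1
        return func(*args, **kwargs)
    wrapped.calls = 0
    return wrapped


# Demo: a fake library module and a fake test module that lists it in lazy_xp_modules.
lib = types.ModuleType("lib")
lib.double = tag(lambda x: 2 * x)
test_mod = types.ModuleType("test_mod")
test_mod.double = lib.double
test_mod.lazy_xp_modules = [lib]

original = test_mod.double
with patch_tagged(namespaces_for(test_mod), counting_wrapper):
    result = test_mod.double(3)  # goes through the wrapper
    calls = test_mod.double.calls
# After the block, both namespaces hold the untouched original again.
restored = test_mod.double is original and lib.double is original
```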
If xp==dask.array, wrap the functions with a decorator that disables compute() and persist() and ensures that exceptions and warnings are raised eagerly.
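The Dask behavior can be modeled in the same spirit. In the sketch below, `LazyArray`, `forbid_materialization` and `lazy_checked` are hypothetical stand-ins and no Dask is involved: calling compute() or persist() while the wrapped function runs raises, and the result is evaluated as soon as the function returns, so a deferred error surfaces at the call site instead of later in the test. Warnings, which the real decorator also surfaces eagerly, are not modeled here.

```python
import contextlib
import functools


class LazyArray:
    """Hypothetical lazy array: the work is deferred inside a zero-argument thunk."""

    def __init__(self, thunk):
        self._thunk = thunk

    def compute(self):
        return self._thunk()

    def persist(self):
        value = self._thunk()  # evaluate now and keep the result
        return LazyArray(lambda: value)


@contextlib.contextmanager
def forbid_materialization(cls):
    """Temporarily make cls.compute and cls.persist raise AssertionError."""
    originals = {name: getattr(cls, name) for name in ("compute", "persist")}

    def _forbidden(self, *args, **kwargs):
        raise AssertionError("compute()/persist() called inside a lazily tested function")

    try:
        for name in originals:
            setattr(cls, name, _forbidden)
        yield
    finally:
        for name, method in originals.items():
            setattr(cls, name, method)


def lazy_checked(func):
    """Forbid materialization while func runs, then evaluate its result eagerly."""
    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        with forbid_materialization(LazyArray):
            out = func(*args, **kwargs)
        # Outside the forbidden region: evaluate now so deferred errors raise here.
        if isinstance(out, LazyArray):
            out = out.persist()
        return out
    return wrapper


def add_one(x):
    return LazyArray(lambda: x.compute() + 1)  # stays lazy while add_one runs


def peeks(x):
    value = x.compute()  # materializes inside the function: not allowed
    return LazyArray(lambda: value)


def divides_by_zero(x):
    return LazyArray(lambda: x.compute() / 0)  # the error is deferred
```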
This function should typically be called by your library's xp fixture, which runs tests on multiple backends:
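import array_api_strict
import dask.array
import jax.numpy
import numpy
import pytest
from scipy._lib.array_api_extra.testing import patch_lazy_xp_functions

@pytest.fixture(params=[
    numpy,
    array_api_strict,
    pytest.param(jax.numpy, marks=pytest.mark.thread_unsafe),
    pytest.param(dask.array, marks=pytest.mark.thread_unsafe),
])
def xp(request):
    with patch_lazy_xp_functions(request, xp=request.param):
        yield request.param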
It can also be called directly by the test itself.
Parameters
request: pytest.FixtureRequest
Pytest fixture, as acquired by the test itself or by one of its fixtures.
monkeypatch: pytest.MonkeyPatch
Deprecated.
xp: array_namespace
Array namespace to be tested.
Notes
This context manager monkey-patches modules and is therefore thread-unsafe on the Dask and JAX backends. If you run your test suite with pytest-run-parallel, mark these backends with @pytest.mark.thread_unsafe, as shown in the example above.
See also
- lazy_xp_function
Tag a function to be tested on lazy backends.
- pytest.FixtureRequest
The request test function parameter.
Aliases
- scipy.differentiate.xpx.testing.patch_lazy_xp_functions