{
  "__type": "IngestedDoc",
  "__tag": 4010,
  "_content": {
    "Notes": {
      "__type": "Section",
      "__tag": 4015,
      "children": [
        {
          "__type": "Paragraph",
          "__tag": 4045,
          "children": [
            {
              "__type": "Text",
              "__tag": 4046,
              "value": "In order for this tag to be effective, the tested function must be imported into the test module globals without its namespace; alternatively its namespace must be declared in a "
            },
            {
              "__type": "InlineCode",
              "__tag": 4051,
              "value": "lazy_xp_modules"
            },
            {
              "__type": "Text",
              "__tag": 4046,
              "value": " list in the test module globals."
            }
          ]
        },
        {
          "__type": "Paragraph",
          "__tag": 4045,
          "children": [
            {
              "__type": "Text",
              "__tag": 4046,
              "value": "Example 1:"
            }
          ]
        },
        {
          "__type": "Code",
          "__tag": 4050,
          "value": "from mymodule import myfunc\n\nlazy_xp_function(myfunc)\n\ndef test_myfunc(xp):\n    x = myfunc(xp.asarray([1, 2]))",
          "execution_status": null
        },
        {
          "__type": "Paragraph",
          "__tag": 4045,
          "children": [
            {
              "__type": "Text",
              "__tag": 4046,
              "value": "Example 2:"
            }
          ]
        },
        {
          "__type": "Code",
          "__tag": 4050,
          "value": "import mymodule\n\nlazy_xp_modules = [mymodule]\nlazy_xp_function(mymodule.myfunc)\n\ndef test_myfunc(xp):\n    x = mymodule.myfunc(xp.asarray([1, 2]))",
          "execution_status": null
        },
        {
          "__type": "Paragraph",
          "__tag": 4045,
          "children": [
            {
              "__type": "Text",
              "__tag": 4046,
              "value": "A test function can circumvent this monkey-patching system by using a namespace outside the two patterns above. You need to sanitize your code to make sure this only happens intentionally."
            }
          ]
        },
        {
          "__type": "Paragraph",
          "__tag": 4045,
          "children": [
            {
              "__type": "Text",
              "__tag": 4046,
              "value": "Example 1:"
            }
          ]
        },
        {
          "__type": "Code",
          "__tag": 4050,
          "value": "import mymodule\nfrom mymodule import myfunc\n\nlazy_xp_function(myfunc)\n\ndef test_myfunc(xp):\n    a = xp.asarray([1, 2])\n    b = myfunc(a)  # This is wrapped when xp=jax.numpy or xp=dask.array\n    c = mymodule.myfunc(a)  # This is not",
          "execution_status": null
        },
        {
          "__type": "Paragraph",
          "__tag": 4045,
          "children": [
            {
              "__type": "Text",
              "__tag": 4046,
              "value": "Example 2:"
            }
          ]
        },
        {
          "__type": "Code",
          "__tag": 4050,
          "value": "import mymodule\n\nclass naked:\n    myfunc = mymodule.myfunc\n\nlazy_xp_modules = [mymodule]\nlazy_xp_function(mymodule.myfunc)\n\ndef test_myfunc(xp):\n    a = xp.asarray([1, 2])\n    b = mymodule.myfunc(a)  # This is wrapped when xp=jax.numpy or xp=dask.array\n    c = naked.myfunc(a)  # This is not",
          "execution_status": null
        }
      ],
      "title": [],
      "level": 0,
      "target": null
    },
    "Warns": {
      "__type": "Section",
      "__tag": 4015,
      "children": [],
      "title": [],
      "level": 0,
      "target": null
    },
    "Raises": {
      "__type": "Section",
      "__tag": 4015,
      "children": [],
      "title": [],
      "level": 0,
      "target": null
    },
    "Yields": {
      "__type": "Section",
      "__tag": 4015,
      "children": [],
      "title": [],
      "level": 0,
      "target": null
    },
    "Methods": {
      "__type": "Section",
      "__tag": 4015,
      "children": [],
      "title": [],
      "level": 0,
      "target": null
    },
    "Returns": {
      "__type": "Section",
      "__tag": 4015,
      "children": [],
      "title": [],
      "level": 0,
      "target": null
    },
    "Summary": {
      "__type": "Section",
      "__tag": 4015,
      "children": [
        {
          "__type": "Paragraph",
          "__tag": 4045,
          "children": [
            {
              "__type": "Text",
              "__tag": 4046,
              "value": "Tag a function to be tested on lazy backends."
            }
          ]
        }
      ],
      "title": [],
      "level": 0,
      "target": null
    },
    "Receives": {
      "__type": "Section",
      "__tag": 4015,
      "children": [],
      "title": [],
      "level": 0,
      "target": null
    },
    "Warnings": {
      "__type": "Section",
      "__tag": 4015,
      "children": [],
      "title": [],
      "level": 0,
      "target": null
    },
    "Attributes": {
      "__type": "Section",
      "__tag": 4015,
      "children": [],
      "title": [],
      "level": 0,
      "target": null
    },
    "Parameters": {
      "__type": "Section",
      "__tag": 4015,
      "children": [
        {
          "__type": "Parameters",
          "__tag": 4026,
          "children": [
            {
              "__type": "DocParam",
              "__tag": 4016,
              "name": "func",
              "annotation": "callable",
              "desc": [
                {
                  "__type": "Paragraph",
                  "__tag": 4045,
                  "children": [
                    {
                      "__type": "Text",
                      "__tag": 4046,
                      "value": "Function to be tested."
                    }
                  ]
                }
              ]
            },
            {
              "__type": "DocParam",
              "__tag": 4016,
              "name": "allow_dask_compute",
              "annotation": "bool | int, optional",
              "desc": [
                {
                  "__type": "Paragraph",
                  "__tag": 4045,
                  "children": [
                    {
                      "__type": "Text",
                      "__tag": 4046,
                      "value": "Whether "
                    },
                    {
                      "__type": "ParamRef",
                      "__tag": 4071,
                      "name": "func"
                    },
                    {
                      "__type": "Text",
                      "__tag": 4046,
                      "value": " is allowed to internally materialize the Dask graph, or the maximum number of times it is allowed to do so. This is typically triggered by "
                    },
                    {
                      "__type": "InlineCode",
                      "__tag": 4051,
                      "value": "bool()"
                    },
                    {
                      "__type": "Text",
                      "__tag": 4046,
                      "value": ", "
                    },
                    {
                      "__type": "InlineCode",
                      "__tag": 4051,
                      "value": "float()"
                    },
                    {
                      "__type": "Text",
                      "__tag": 4046,
                      "value": ", or "
                    },
                    {
                      "__type": "InlineCode",
                      "__tag": 4051,
                      "value": "np.asarray()"
                    },
                    {
                      "__type": "Text",
                      "__tag": 4046,
                      "value": "."
                    }
                  ]
                },
                {
                  "__type": "Paragraph",
                  "__tag": 4045,
                  "children": [
                    {
                      "__type": "Text",
                      "__tag": 4046,
                      "value": "Set to 1 if you are aware that "
                    },
                    {
                      "__type": "ParamRef",
                      "__tag": 4071,
                      "name": "func"
                    },
                    {
                      "__type": "Text",
                      "__tag": 4046,
                      "value": " converts the input parameters to NumPy and want to let it do so at least for the time being, knowing that it is going to be extremely detrimental for performance."
                    }
                  ]
                },
                {
                  "__type": "Paragraph",
                  "__tag": 4045,
                  "children": [
                    {
                      "__type": "Text",
                      "__tag": 4046,
                      "value": "If a test needs values higher than 1 to pass, it is a canary that the conversion to NumPy/bool/float is happening multiple times, which translates to multiple computations of the whole graph. Short of making the function fully lazy, you should at least add explicit calls to "
                    },
                    {
                      "__type": "InlineCode",
                      "__tag": 4051,
                      "value": "np.asarray()"
                    },
                    {
                      "__type": "Text",
                      "__tag": 4046,
                      "value": " early in the function. "
                    },
                    {
                      "__type": "Emphasis",
                      "__tag": 4047,
                      "children": [
                        {
                          "__type": "Text",
                          "__tag": 4046,
                          "value": "Note:"
                        }
                      ]
                    },
                    {
                      "__type": "Text",
                      "__tag": 4046,
                      "value": " the counter of "
                    },
                    {
                      "__type": "ParamRef",
                      "__tag": 4071,
                      "name": "allow_dask_compute"
                    },
                    {
                      "__type": "Text",
                      "__tag": 4046,
                      "value": " resets after each call to "
                    },
                    {
                      "__type": "ParamRef",
                      "__tag": 4071,
                      "name": "func"
                    },
                    {
                      "__type": "Text",
                      "__tag": 4046,
                      "value": ", so a test function that invokes "
                    },
                    {
                      "__type": "ParamRef",
                      "__tag": 4071,
                      "name": "func"
                    },
                    {
                      "__type": "Text",
                      "__tag": 4046,
                      "value": " multiple times should still work with this parameter set to 1."
                    }
                  ]
                },
                {
                  "__type": "Paragraph",
                  "__tag": 4045,
                  "children": [
                    {
                      "__type": "Text",
                      "__tag": 4046,
                      "value": "Set to True to allow "
                    },
                    {
                      "__type": "ParamRef",
                      "__tag": 4071,
                      "name": "func"
                    },
                    {
                      "__type": "Text",
                      "__tag": 4046,
                      "value": " to materialize the graph an unlimited number of times."
                    }
                  ]
                },
                {
                  "__type": "Paragraph",
                  "__tag": 4045,
                  "children": [
                    {
                      "__type": "Text",
                      "__tag": 4046,
                      "value": "Default: False, meaning that "
                    },
                    {
                      "__type": "ParamRef",
                      "__tag": 4071,
                      "name": "func"
                    },
                    {
                      "__type": "Text",
                      "__tag": 4046,
                      "value": " must be fully lazy and never materialize the graph."
                    }
                  ]
                }
              ]
            },
            {
              "__type": "DocParam",
              "__tag": 4016,
              "name": "jax_jit",
              "annotation": "bool, optional",
              "desc": [
                {
                  "__type": "Paragraph",
                  "__tag": 4045,
                  "children": [
                    {
                      "__type": "Text",
                      "__tag": 4046,
                      "value": "Set to True to replace "
                    },
                    {
                      "__type": "ParamRef",
                      "__tag": 4071,
                      "name": "func"
                    },
                    {
                      "__type": "Text",
                      "__tag": 4046,
                      "value": " with a smart variant of "
                    },
                    {
                      "__type": "InlineCode",
                      "__tag": 4051,
                      "value": "jax.jit(func)"
                    },
                    {
                      "__type": "Text",
                      "__tag": 4046,
                      "value": " after calling the "
                    },
                    {
                      "__type": "CrossRef",
                      "__tag": 4002,
                      "value": "patch_lazy_xp_functions",
                      "reference": {
                        "__type": "LocalRef",
                        "__tag": 4022,
                        "kind": "module",
                        "path": "scipy._lib.array_api_extra.testing:patch_lazy_xp_functions"
                      },
                      "kind": "module"
                    },
                    {
                      "__type": "Text",
                      "__tag": 4046,
                      "value": " test helper with "
                    },
                    {
                      "__type": "InlineCode",
                      "__tag": 4051,
                      "value": "xp=jax.numpy"
                    },
                    {
                      "__type": "Text",
                      "__tag": 4046,
                      "value": ". This is the default behaviour. Set to False if "
                    },
                    {
                      "__type": "ParamRef",
                      "__tag": 4071,
                      "name": "func"
                    },
                    {
                      "__type": "Text",
                      "__tag": 4046,
                      "value": " is only compatible with eager (non-jitted) JAX."
                    }
                  ]
                },
                {
                  "__type": "Paragraph",
                  "__tag": 4045,
                  "children": [
                    {
                      "__type": "Text",
                      "__tag": 4046,
                      "value": "Unlike with vanilla "
                    },
                    {
                      "__type": "InlineCode",
                      "__tag": 4051,
                      "value": "jax.jit"
                    },
                    {
                      "__type": "Text",
                      "__tag": 4046,
                      "value": ", all arguments and return types that are not JAX arrays are treated as static; the function can accept and return arbitrary wrappers around JAX arrays. This difference is because, in real life, most users won't wrap the function directly with "
                    },
                    {
                      "__type": "InlineCode",
                      "__tag": 4051,
                      "value": "jax.jit"
                    },
                    {
                      "__type": "Text",
                      "__tag": 4046,
                      "value": " but rather they will use it within their own code, which is itself then wrapped by "
                    },
                    {
                      "__type": "InlineCode",
                      "__tag": 4051,
                      "value": "jax.jit"
                    },
                    {
                      "__type": "Text",
                      "__tag": 4046,
                      "value": ", and internally consume the function's outputs."
                    }
                  ]
                },
                {
                  "__type": "Paragraph",
                  "__tag": 4045,
                  "children": [
                    {
                      "__type": "Text",
                      "__tag": 4046,
                      "value": "In other words, the pattern that is being tested is:"
                    }
                  ]
                },
                {
                  "__type": "Code",
                  "__tag": 4050,
                  "value": ">>> @jax.jit\n... def user_func(x):\n...     y = user_prepares_inputs(x)\n...     z = func(y, some_static_arg=True)\n...     return user_consumes(z)",
                  "execution_status": null
                },
                {
                  "__type": "Paragraph",
                  "__tag": 4045,
                  "children": [
                    {
                      "__type": "Text",
                      "__tag": 4046,
                      "value": "Default: True."
                    }
                  ]
                }
              ]
            },
            {
              "__type": "DocParam",
              "__tag": 4016,
              "name": "static_argnums",
              "annotation": "",
              "desc": [
                {
                  "__type": "Paragraph",
                  "__tag": 4045,
                  "children": [
                    {
                      "__type": "Text",
                      "__tag": 4046,
                      "value": "Deprecated; ignored."
                    }
                  ]
                }
              ]
            },
            {
              "__type": "DocParam",
              "__tag": 4016,
              "name": "static_argnames",
              "annotation": "",
              "desc": [
                {
                  "__type": "Paragraph",
                  "__tag": 4045,
                  "children": [
                    {
                      "__type": "Text",
                      "__tag": 4046,
                      "value": "Deprecated; ignored."
                    }
                  ]
                }
              ]
            }
          ]
        }
      ],
      "title": [],
      "level": 0,
      "target": null
    },
    "Extended Summary": {
      "__type": "Section",
      "__tag": 4015,
      "children": [
        {
          "__type": "Paragraph",
          "__tag": 4045,
          "children": [
            {
              "__type": "Text",
              "__tag": 4046,
              "value": "Tag a function so that when any tests are executed with "
            },
            {
              "__type": "InlineCode",
              "__tag": 4051,
              "value": "xp=jax.numpy"
            },
            {
              "__type": "Text",
              "__tag": 4046,
              "value": " the function is replaced with a jitted version of itself, and when it is executed with "
            },
            {
              "__type": "InlineCode",
              "__tag": 4051,
              "value": "xp=dask.array"
            },
            {
              "__type": "Text",
              "__tag": 4046,
              "value": " the function will raise if it attempts to materialize the graph. This will later be expanded to provide test coverage for other lazy backends."
            }
          ]
        },
        {
          "__type": "Paragraph",
          "__tag": 4045,
          "children": [
            {
              "__type": "Text",
              "__tag": 4046,
              "value": "In order for the tag to be effective, the test or a fixture must call "
            },
            {
              "__type": "CrossRef",
              "__tag": 4002,
              "value": "patch_lazy_xp_functions",
              "reference": {
                "__type": "LocalRef",
                "__tag": 4022,
                "kind": "module",
                "path": "scipy._lib.array_api_extra.testing:patch_lazy_xp_functions"
              },
              "kind": "module"
            },
            {
              "__type": "Text",
              "__tag": 4046,
              "value": "."
            }
          ]
        }
      ],
      "title": [],
      "level": 0,
      "target": null
    },
    "Other Parameters": {
      "__type": "Section",
      "__tag": 4015,
      "children": [],
      "title": [],
      "level": 0,
      "target": null
    }
  },
  "_ordered_sections": [
    "Summary",
    "Extended Summary",
    "Parameters",
    "Attributes",
    "Methods",
    "Returns",
    "Yields",
    "Receives",
    "Other Parameters",
    "Raises",
    "Warns",
    "Warnings",
    "Notes"
  ],
  "item_file": "/scipy/_lib/array_api_extra/testing.py",
  "item_line": 51,
  "item_type": "function",
  "aliases": [
    "scipy.differentiate.xpx.testing.lazy_xp_function"
  ],
  "example_section_data": {
    "__type": "Section",
    "__tag": 4015,
    "children": [
      {
        "__type": "Text",
        "__tag": 4046,
        "value": "In ``test_mymodule.py``::\n\n  from array_api_extra.testing import lazy_xp_function\n  from mymodule import myfunc\n\n  lazy_xp_function(myfunc)\n\n  def test_myfunc(xp):\n      a = xp.asarray([1, 2])\n      # When xp=jax.numpy, this is similar to `b = jax.jit(myfunc)(a)`\n      # When xp=dask.array, crash on compute() or persist()\n      b = myfunc(a)"
      }
    ],
    "title": [],
    "level": 0,
    "target": null
  },
  "see_also": [
    {
      "__type": "SeeAlsoItem",
      "__tag": 4028,
      "name": {
        "__type": "CrossRef",
        "__tag": 4002,
        "value": "jax.jit",
        "reference": {
          "__type": "RefInfo",
          "__tag": 4000,
          "module": "current-module",
          "version": "current-version",
          "kind": "to-resolve",
          "path": "jax.jit"
        },
        "kind": "module"
      },
      "descriptions": [
        {
          "__type": "Paragraph",
          "__tag": 4045,
          "children": [
            {
              "__type": "Text",
              "__tag": 4046,
              "value": "JAX function to compile a function for performance."
            }
          ]
        }
      ],
      "type": null
    },
    {
      "__type": "SeeAlsoItem",
      "__tag": 4028,
      "name": {
        "__type": "CrossRef",
        "__tag": 4002,
        "value": "patch_lazy_xp_functions",
        "reference": {
          "__type": "LocalRef",
          "__tag": 4022,
          "kind": "module",
          "path": "scipy._lib.array_api_extra.testing:patch_lazy_xp_functions"
        },
        "kind": "module"
      },
      "descriptions": [
        {
          "__type": "Paragraph",
          "__tag": 4045,
          "children": [
            {
              "__type": "Text",
              "__tag": 4046,
              "value": "Companion function to call from the test or fixture."
            }
          ]
        }
      ],
      "type": null
    }
  ],
  "signature": {
    "__type": "SignatureNode",
    "__tag": 4029,
    "kind": "function",
    "parameters": [
      {
        "__type": "SigParam",
        "__tag": 4030,
        "name": "func",
        "annotation": "Callable[..., Any]",
        "kind": "POSITIONAL_OR_KEYWORD",
        "default": {
          "__type": "Empty",
          "__tag": 4031
        }
      },
      {
        "__type": "SigParam",
        "__tag": 4030,
        "name": "allow_dask_compute",
        "annotation": "bool | int",
        "kind": "KEYWORD_ONLY",
        "default": "False"
      },
      {
        "__type": "SigParam",
        "__tag": 4030,
        "name": "jax_jit",
        "annotation": "bool",
        "kind": "KEYWORD_ONLY",
        "default": "True"
      },
      {
        "__type": "SigParam",
        "__tag": 4030,
        "name": "static_argnums",
        "annotation": "Deprecated",
        "kind": "KEYWORD_ONLY",
        "default": "Deprecated.DEPRECATED"
      },
      {
        "__type": "SigParam",
        "__tag": 4030,
        "name": "static_argnames",
        "annotation": "Deprecated",
        "kind": "KEYWORD_ONLY",
        "default": "Deprecated.DEPRECATED"
      }
    ],
    "return_annotation": "None",
    "target_name": "lazy_xp_function"
  },
  "references": null,
  "qa": "scipy._lib.array_api_extra.testing:lazy_xp_function",
  "arbitrary": [],
  "local_refs": [
    "allow_dask_compute",
    "func",
    "jax_jit",
    "static_argnames",
    "static_argnums"
  ]
}