
class

scipy._lib.array_api_extra._lib._at:at

source: /scipy/_lib/array_api_extra/_lib/_at.py:61

Signature

class at(x: Array, idx: SetIndex | Undef = Undef.UNDEF, /) -> None


Summary

Update operations for read-only arrays.

Extended Summary

This implements jax.numpy.ndarray.at for all writeable backends (those that support __setitem__) and routes to the .at[] method for JAX arrays.

Parameters

x : array

Input array.

idx : index, optional

Only array API standard compliant indices are supported.

You may use either of two equivalent syntaxes:

>>> import array_api_extra as xpx
>>> xpx.at(x, idx).set(value)  # or add(value), etc.
>>> xpx.at(x)[idx].set(value)
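
copy : bool, optional

Like xp below, this is an argument of the update methods (set(), add(), etc.) rather than of at() itself, e.g. xpx.at(x, idx).set(value, copy=True).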

None (default)

The array parameter may be modified in place if it is possible and beneficial for performance. You should not reuse it after calling this function.

True

Ensure that the inputs are not modified.

False

Ensure that the update operation writes back to the input. Raise ValueError if a copy cannot be avoided.

xp : array_namespace, optional

The standard-compatible namespace for x. Default: infer.
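The sketch below makes the two indexing syntaxes and the three copy modes concrete on NumPy arrays. ToyAt is a hypothetical class written only for this illustration and is not the array-api-extra implementation. Its _UNSET sentinel stands in for the Undef.UNDEF default shown in the signature, and it implements set() only.

```python
import numpy as np


class ToyAt:
    """Hypothetical stand-in for xpx.at on NumPy arrays (not the library's code)."""

    _UNSET = object()  # plays the role of the Undef.UNDEF default in the signature

    def __init__(self, x, idx=_UNSET):
        self._x = x
        self._idx = idx

    def __getitem__(self, idx):
        # at(x)[idx] is rewritten into at(x, idx); giving a second index is an error
        if self._idx is not ToyAt._UNSET:
            raise ValueError("index given twice")
        return ToyAt(self._x, idx)

    def set(self, value, copy=None):
        if self._idx is ToyAt._UNSET:
            raise ValueError("no index given")
        x = self._x
        writeable = x.flags.writeable
        if copy is False and not writeable:
            # copy=False: the update must write back to x, which is impossible here
            raise ValueError("cannot write back to a read-only array")
        if copy is True or (copy is None and not writeable):
            x = x.copy()  # ndarray.copy() always returns a writeable array
        x[self._idx] = value  # copy=None on a writeable x: updated in place
        return x
```

Because at(x)[idx] only rebinds the index, both syntaxes reach the same set() call. The copy branch is where read-only and writeable inputs part ways.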

Returns

Updated input array.

Warnings

  • When you omit the copy parameter, you should never reuse the parameter array later on; ideally, you should reassign it immediately:

>>> import array_api_extra as xpx
>>> x = xpx.at(x, 0).set(2)

This best-practice pattern ensures that the behaviour does not depend on whether x is writeable. The original x object is dereferenced as soon as the update method returns, so there is no risk of accidentally updating it twice.

Conversely, avoid the anti-pattern below. It results in different behaviour on read-only versus writeable arrays:

>>> x = xp.asarray([0, 0, 0])
>>> y = xpx.at(x, 0).set(2)
>>> z = xpx.at(x, 1).set(3)

In the above example, both calls to xpx.at update x in place if possible. This causes the behaviour to diverge depending on whether x is writeable or not:

  • If x is writeable, then after the snippet above you'll have x == y == z == [2, 3, 0].

  • If x is read-only, then you'll end up with x == [0, 0, 0], y == [2, 0, 0] and z == [0, 3, 0].

If you want diverging outputs from the same input, enforce copies explicitly:

>>> x = xp.asarray([0, 0, 0])
>>> y = xpx.at(x, 0).set(2, copy=True)  # Never updates x
>>> z = xpx.at(x, 1).set(3)  # May or may not update x in place
>>> del x  # avoid accidental reuse of x as we don't know its state anymore

  • The array API standard does not support integer array indices.

The behaviour of update methods is undefined when the index is an array of integers, and it varies between backends. This is especially true when the index contains the same position more than once, e.g.:

>>> import numpy as np
>>> import jax.numpy as jnp
>>> import array_api_extra as xpx
>>> xpx.at(np.asarray([123]), np.asarray([0, 0])).add(1)
array([124])
>>> xpx.at(jnp.asarray([123]), jnp.asarray([0, 0])).add(1)
Array([125], dtype=int32)
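The two results differ because there are two reasonable rules for repeated indices. This NumPy-only sketch does not call xpx.at, but it reproduces both numbers. A buffered fancy-index += applies the repeated position once. np.add.at applies it once per occurrence, which gives the same value as the JAX output above. Which rule a given backend follows inside xpx.at is what the warning leaves undefined.

```python
import numpy as np

idx = np.asarray([0, 0])  # the same position listed twice

# Buffered fancy-index update: x[idx] = x[idx] + 1, so the duplicate lands only once
buffered = np.asarray([123])
buffered[idx] += 1

# Unbuffered ufunc.at: every occurrence of the index contributes
unbuffered = np.asarray([123])
np.add.at(unbuffered, idx, 1)

print(buffered, unbuffered)  # [124] [125]
```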

Notes

Arrays from the sparse library, as well as read-only arrays from libraries not explicitly covered by array-api-compat, are not supported by update methods.

Boolean masks are supported on Dask and jitted JAX arrays only when idx has the same shape as x and the update value y is 0-dimensional. Note that this support is not available in JAX's native x.at[mask].set(y).

This pattern:

>>> mask = m(x)
>>> x[mask] = f(x[mask])

cannot be replaced by at, as it won't work on Dask, or on JAX inside jax.jit:

>>> mask = m(x)
>>> x = xpx.at(x, mask).set(f(x[mask]))  # Crash on Dask and jax.jit

You should instead use:

>>> x = xp.where(m(x), f(x), x)
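Here is a quick NumPy check that the where() form computes the same result as the masked assignment. The functions m and f are hypothetical stand-ins for the mask and update functions above.

```python
import numpy as np


def m(x):
    # Hypothetical mask function: select the negative entries
    return x < 0


def f(x):
    # Hypothetical elementwise update: negate the selected entries
    return -x


x = np.asarray([-2.0, 1.0, -3.0, 4.0])

# In-place masked assignment (needs a writeable array, and a data-dependent shape for x[mask])
masked = x.copy()
mask = m(masked)
masked[mask] = f(masked[mask])

# Out-of-place equivalent: f is computed everywhere, kept only where the mask is True
via_where = np.where(m(x), f(x), x)
```

One trade-off: where() evaluates f on every element, including the entries the mask excludes. f therefore has to be harmless on those entries (for example, it should not divide by zero there).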

Examples

Given either of these equivalent expressions:

import array_api_extra as xpx
x = xpx.at(x)[1].add(2)
x = xpx.at(x, 1).add(2)

If x is a JAX array, they are the same as:

x = x.at[1].add(2)

If x is a read-only NumPy array, they are the same as:

x = x.copy()
x[1] += 2

For other known backends, they are the same as:

x[1] += 2
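The per-backend equivalences above can be summarised as a small dispatch function. toy_add is hypothetical and only mirrors the three listed cases. Detecting JAX-style arrays by the presence of an .at attribute is a simplification made for this sketch. The documented update methods also accept copy and xp, which this sketch ignores.

```python
import numpy as np


def toy_add(x, idx, value):
    """Hypothetical dispatch mirroring the three cases above (not the library's code)."""
    if hasattr(x, "at"):
        # Assumed JAX-style array: functional .at[] update that returns a new array
        return x.at[idx].add(value)
    if isinstance(x, np.ndarray) and not x.flags.writeable:
        # Read-only NumPy array: update a writeable copy instead
        x = x.copy()
    # Other backends: plain in-place update through __setitem__
    x[idx] += value
    return x
```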

See also

jax.numpy.ndarray.at

Equivalent array method in JAX.

Aliases

  • scipy.differentiate.xpx.at