Axes

brainstate.typing.Axes

Type for specifying axes along which operations should be performed.

Can be a single axis (integer) or multiple axes (sequence of integers). Used in reduction operations, reshaping, and other array manipulations.

Examples

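>>> import jax.numpy as jnp
>>> from brainstate.typing import ArrayLike, Axes
>>>
>>> # Single axis
>>> axis1: Axes = 0
>>>
>>> # Multiple axes
>>> axis2: Axes = (0, 2)
>>>
>>> # All axes, for a global operation over a 3-dimensional array
>>> ndim = 3
>>> axis3: Axes = tuple(range(ndim))
>>>
>>> def sum_along_axes(array: ArrayLike, axes: Axes) -> ArrayLike:
...     # Reduce over one axis or over several axes at once
...     return jnp.sum(array, axis=axes)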
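
Because an Axes value may be either a bare integer or a sequence of integers, code that consumes one often normalizes it to a single form first. The following is a minimal sketch of that idea in plain Python. The local `Axes` alias and the `normalize_axes` helper are hypothetical and are not part of brainstate; the library's actual definition of the type may differ.

```python
from typing import Sequence, Tuple, Union

# Hypothetical local stand-in for brainstate.typing.Axes; the library's
# actual definition may differ.
Axes = Union[int, Sequence[int]]


def normalize_axes(axes: Axes, ndim: int) -> Tuple[int, ...]:
    """Return `axes` as a sorted tuple of unique, non-negative axis indices."""
    # Treat a bare integer as a one-element sequence.
    if isinstance(axes, int):
        axes = (axes,)
    normalized = []
    for axis in axes:
        if not -ndim <= axis < ndim:
            raise ValueError(f"axis {axis} is out of bounds for {ndim} dimensions")
        # Map negative indices (counted from the end) to non-negative ones.
        normalized.append(axis % ndim)
    if len(set(normalized)) != len(normalized):
        raise ValueError(f"repeated axis in {tuple(axes)}")
    return tuple(sorted(normalized))
```

With this sketch, `normalize_axes(0, 3)`, `normalize_axes((0, 2), 3)`, and `normalize_axes(-1, 3)` all return tuples, so downstream code only has to handle one shape of input.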