Axes
brainstate.typing.Axes
Type for specifying axes along which operations should be performed.
Can be a single axis (integer) or multiple axes (sequence of integers). Used in reduction operations, reshaping, and other array manipulations.
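Because an `Axes` value may arrive either as a bare integer or as a sequence, code that consumes it usually normalizes it into a single form first. The sketch below is illustrative only. It assumes `Axes` means "an int or a sequence of ints", as described above, and uses a local stand-in alias `AxesLike` and a hypothetical helper `normalize_axes`. Neither is part of brainstate. The helper converts the value to a sorted tuple of non-negative axes, following the NumPy convention that negative axes count from the end.

```python
from typing import Sequence, Tuple, Union

# Hypothetical stand-in for brainstate.typing.Axes: one int or a sequence of ints.
AxesLike = Union[int, Sequence[int]]


def normalize_axes(axes: AxesLike, ndim: int) -> Tuple[int, ...]:
    """Convert an Axes-style value into a sorted tuple of non-negative axes."""
    # Wrap a single integer so both accepted forms are handled the same way.
    seq = (axes,) if isinstance(axes, int) else tuple(axes)
    result = []
    for ax in seq:
        # Valid axes lie in [-ndim, ndim); negative values count from the end.
        if not -ndim <= ax < ndim:
            raise ValueError(f"axis {ax} is out of bounds for ndim={ndim}")
        result.append(ax % ndim)
    # The same axis must not be listed twice.
    if len(set(result)) != len(result):
        raise ValueError(f"repeated axis in {axes!r}")
    return tuple(sorted(result))


print(normalize_axes(-1, 3))       # (2,)
print(normalize_axes((0, -1), 3))  # (0, 2)
```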
Examples
>>> import jax.numpy as jnp
>>> from brainstate.typing import ArrayLike, Axes
>>> # Single axis
>>> axis1: Axes = 0
>>> # Multiple axes
>>> axis2: Axes = (0, 2)
>>> # All axes, for global operations (here for a 3-D array)
>>> ndim = 3
>>> axis3: Axes = tuple(range(ndim))
>>> def sum_along_axes(array: ArrayLike, axes: Axes) -> ArrayLike:
...     return jnp.sum(array, axis=axes)
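To show how each kind of `Axes` value changes the result of a reduction like `sum_along_axes`, the sketch below uses NumPy instead of JAX so it runs without a JAX install. It assumes that `jnp.sum` follows NumPy's `axis` convention (an int or a tuple of ints). `AxesLike` is again a hypothetical local stand-in for `brainstate.typing.Axes`. Lists are converted to tuples because `np.sum` does not accept a list for `axis`.

```python
from typing import Sequence, Union

import numpy as np

# Hypothetical stand-in for brainstate.typing.Axes: one int or a sequence of ints.
AxesLike = Union[int, Sequence[int]]


def sum_along_axes(array: np.ndarray, axes: AxesLike) -> np.ndarray:
    # np.sum takes an int or a tuple of ints, so turn other sequences into a tuple.
    axis = axes if isinstance(axes, int) else tuple(axes)
    return np.sum(array, axis=axis)


x = np.ones((2, 3, 4))
print(sum_along_axes(x, 0).shape)                # (3, 4): axis 0 is reduced away
print(sum_along_axes(x, (0, 2)).shape)           # (3,): axes 0 and 2 are reduced
print(sum_along_axes(x, tuple(range(x.ndim))))   # 24.0: every axis is reduced
```

Each axis listed in `axes` is removed from the output shape. Passing every axis collapses the array to a scalar.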