ArrayLike

brainstate.typing.ArrayLike

Union of all objects that can be implicitly converted to a JAX array.

This type is designed for JAX compatibility. Unlike numpy.typing.ArrayLike, it excludes arbitrary sequences (lists, tuples, nested sequences) and string data, and covers only data that can be converted to a JAX array without ambiguity.
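A small illustration of why string data is excluded, assuming only `numpy` and `jax` are installed: NumPy turns a string into a Unicode array, while `jax.numpy.asarray` rejects it because JAX arrays hold only numeric and boolean dtypes.

```python
import numpy as np
import jax.numpy as jnp

# NumPy builds a Unicode string array (dtype <U3) from "abc".
np_result = np.asarray("abc")
print(np_result.dtype)

# JAX supports only numeric and boolean dtypes, so the same input raises.
try:
    jnp.asarray("abc")
    jax_accepts_strings = True
except (TypeError, ValueError):
    jax_accepts_strings = False

print(jax_accepts_strings)  # False
```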

Components

jax.Array

Native JAX arrays.

np.ndarray

NumPy arrays that can be converted to JAX arrays.

np.bool_, np.number

NumPy scalar types (np.bool_, np.int8, np.float32, etc.).

bool, int, float, complex

Python built-in scalar types.

u.Quantity

BrainUnit quantities with physical units.
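As a sketch of how a union with these components can be spelled and checked at runtime, the block below builds a hypothetical `ArrayLikeSketch` with `typing.Union`. It is not the actual `brainstate.typing` definition, and it deliberately omits `u.Quantity` so that it runs without BrainUnit installed. `typing.get_args` turns the union into a tuple of classes that `isinstance` accepts.

```python
from typing import Union, get_args

import jax
import jax.numpy as jnp
import numpy as np

# Hypothetical stand-in for brainstate.typing.ArrayLike (u.Quantity omitted).
ArrayLikeSketch = Union[jax.Array, np.ndarray, np.bool_, np.number, bool, int, float, complex]

# Tuple of member classes, usable directly with isinstance.
_MEMBERS = get_args(ArrayLikeSketch)


def is_array_like(x) -> bool:
    """Return True if x is an instance of one of the union's members (hypothetical helper)."""
    return isinstance(x, _MEMBERS)


print(is_array_like(jnp.array([1, 2, 3])))  # True
print(is_array_like(np.float32(3.14)))      # True
print(is_array_like([1, 2, 3]))             # False: sequences are excluded
print(is_array_like("abc"))                 # False: strings are excluded
```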

Examples

>>> import jax
>>> import jax.numpy as jnp
>>> import numpy as np
>>> import brainunit as u
>>> from brainstate.typing import ArrayLike
>>>
>>> def process_data(data: ArrayLike) -> jax.Array:
...     '''Convert input to JAX array and process it.'''
...     array = jnp.asarray(data)
...     return array * 2
>>>
>>> # Valid inputs
>>> process_data(jnp.array([1, 2, 3]))      # JAX array
>>> process_data(np.array([1, 2, 3]))       # NumPy array
>>> process_data(42)                        # Python scalar
>>> process_data(np.float32(3.14))          # NumPy scalar
>>> process_data(1.5 * u.second)            # Quantity with units
>>>
>>> # Not an ArrayLike: jnp.asarray accepts a list at runtime,
>>> # but a static type checker will flag this call
>>> process_data([1, 2, 3])
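The conversion in `process_data` keeps the dtype of NumPy inputs, and because the Python literal `2` is weakly typed in JAX, the multiplication does not promote the result. The sketch below checks this for a few inputs; it reuses the body of `process_data` but drops the `ArrayLike` annotation and the Quantity case so that only `numpy` and `jax` are needed.

```python
import jax
import jax.numpy as jnp
import numpy as np


def process_data(data) -> jax.Array:
    """Convert input to a JAX array and double it (same body as the example above)."""
    array = jnp.asarray(data)
    return array * 2  # the Python literal 2 is weakly typed, so the dtype is unchanged


int8_result = process_data(np.int8(3))
f32_result = process_data(np.float32(3.14))
int16_result = process_data(np.array([1, 2, 3], dtype=np.int16))

print(int8_result.dtype)   # int8
print(f32_result.dtype)    # float32
print(int16_result.dtype)  # int16
```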