ArrayLike
- brainstate.typing.ArrayLike
Union of all objects that can be implicitly converted to a JAX array.
This type is designed for JAX compatibility. Unlike numpy.typing.ArrayLike, it excludes arbitrary sequences (such as lists and tuples) and string data, so every accepted value can be converted to an array without ambiguity.
Components
- jax.Array
Native JAX arrays.
- np.ndarray
NumPy arrays that can be converted to JAX arrays.
- np.bool_, np.number
NumPy scalar types (bool, int8, float32, etc.).
- bool, int, float, complex
Python built-in scalar types.
- u.Quantity
BrainUnit quantities with physical units.
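The membership rules above can be sketched as a runtime check. This is a minimal illustration, not part of brainstate: the helper name `is_array_like` is hypothetical, and the sketch assumes `jax` and `numpy` are installed, with `brainunit` treated as optional. `ArrayLike` itself is meant for static type annotations, not `isinstance` checks.

```python
import jax
import jax.numpy as jnp
import numpy as np

# Hypothetical runtime mirror of the component types listed above
# (not brainstate's own definition of ArrayLike).
_ACCEPTED = (jax.Array, np.ndarray, np.bool_, np.number, bool, int, float, complex)

try:
    import brainunit as u  # optional: include Quantity only if brainunit is installed
    _ACCEPTED = _ACCEPTED + (u.Quantity,)
except ImportError:
    pass


def is_array_like(x) -> bool:
    """Hypothetical helper: True if x is an instance of one of the listed component types."""
    return isinstance(x, _ACCEPTED)


# Arrays and scalars are accepted.
print(is_array_like(jnp.array([1, 2, 3])))  # True
print(is_array_like(np.float32(3.14)))      # True
print(is_array_like(42))                    # True

# Arbitrary sequences and strings are rejected.
print(is_array_like([1, 2, 3]))             # False
print(is_array_like("abc"))                 # False
```

Because Python's `bool` subclasses `int`, `True` passes the check too; `np.bool_` has to be listed separately because it is not a subclass of `np.number`.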
Examples
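>>> import jax
>>> import jax.numpy as jnp
>>> import numpy as np
>>> import brainunit as u
>>> from brainstate.typing import ArrayLike
>>>
>>> def process_data(data: ArrayLike) -> jax.Array:
...     '''Convert input to JAX array and process it.'''
...     array = jnp.asarray(data)
...     return array * 2
>>>
>>> # Valid inputs
>>> process_data(jnp.array([1, 2, 3]))   # JAX array
>>> process_data(np.array([1, 2, 3]))    # NumPy array
>>> process_data(42)                     # Python scalar
>>> process_data(np.float32(3.14))       # NumPy scalar
>>> process_data(1.5 * u.second)         # Quantity with units
>>>
>>> # jnp.asarray accepts a list at runtime, but a list is not an
>>> # ArrayLike, so a static type checker may flag this call
>>> process_data([1, 2, 3])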