brainevent.jaxinfo_to_warpinfo


brainevent.jaxinfo_to_warpinfo(jax_info)

Convert a jax.ShapeDtypeStruct to a Warp array type descriptor.

Takes a JAX shape-and-dtype specification and creates the corresponding Warp array type with the same data type and number of dimensions. Only the dimensionality is carried over; the concrete shape is not part of the resulting type. This is useful when building Warp kernel signatures from JAX output specifications.
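To make concrete which parts of the JAX specification survive the conversion, the sketch below models the resulting descriptor as a (dtype name, ndim) pair. The helper `warp_type_key` and the `Spec` stand-in for `jax.ShapeDtypeStruct` are hypothetical illustrations, not part of brainevent; only NumPy is assumed.

```python
from typing import NamedTuple, Tuple

import numpy as np


class Spec(NamedTuple):
    """Hypothetical stand-in for jax.ShapeDtypeStruct (shape + dtype)."""
    shape: Tuple[int, ...]
    dtype: object

    @property
    def ndim(self) -> int:
        # Number of dimensions, derived from the shape.
        return len(self.shape)


def warp_type_key(info) -> Tuple[str, int]:
    """Hypothetical helper: the information a Warp array type keeps.

    Only the dtype and the number of dimensions are retained;
    the concrete shape is dropped.
    """
    return np.dtype(info.dtype).name, info.ndim


a = Spec(shape=(32, 32), dtype=np.float32)
b = Spec(shape=(8, 1024), dtype=np.float32)
c = Spec(shape=(32,), dtype=np.float32)

print(warp_type_key(a))  # ('float32', 2)
```

Here `a` and `b` produce the same key even though their shapes differ, while `c` does not, because it has a different number of dimensions.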

Parameters:

jax_info (ShapeDtypeStruct) – A JAX structure containing shape, dtype, and ndim attributes describing an array.

Returns:

A Warp array type with matching data type and number of dimensions, suitable for use in Warp kernel definitions.

Return type:

warp.types.array

Raises:

See also

jaxtype_to_warptype

Convert a single dtype to a Warp type.

check_warp_installed

Verify that Warp is available.

Notes

The resulting Warp array type is constructed with warp.array(dtype=..., ndim=...), which creates a Warp type descriptor rather than an allocated array. It is typically used in Warp kernel function signatures to declare input and output argument types.

Examples

>>> import jax
>>> import numpy as np
>>> from brainevent import jaxinfo_to_warpinfo
>>> info = jax.ShapeDtypeStruct(shape=(32, 32), dtype=np.float32)
>>> warp_arr_type = jaxinfo_to_warpinfo(info)
>>> warp_arr_type.ndim
2