brainevent.jaxinfo_to_warpinfo
- brainevent.jaxinfo_to_warpinfo(jax_info)
Convert a jax.ShapeDtypeStruct to a Warp array type descriptor.
Takes a JAX shape-and-dtype specification and creates the corresponding Warp array type with the same data type and number of dimensions. The concrete shape is not encoded in the resulting type. This is useful when building Warp kernel signatures from JAX output specifications.
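Only two pieces of the JAX specification carry over into the Warp type: the dtype and the number of dimensions. The sketch below illustrates this. spec_to_array_type_params is a hypothetical helper, not part of brainevent, and a SimpleNamespace stands in for jax.ShapeDtypeStruct so the example runs without JAX or Warp installed. It returns the two values that would be passed to warp.array(dtype=..., ndim=...):

```python
from types import SimpleNamespace

import numpy as np


def spec_to_array_type_params(spec):
    """Return the (dtype, ndim) pair an array-type descriptor is built from.

    Hypothetical helper for illustration only. The concrete shape is
    dropped; only its length (the number of dimensions) is kept.
    """
    dtype = np.dtype(spec.dtype)  # normalize np.float32, "float32", etc.
    ndim = len(spec.shape)        # equals spec.ndim on a jax.ShapeDtypeStruct
    return dtype, ndim


# Stand-in for jax.ShapeDtypeStruct(shape=(32, 32), dtype=np.float32).
spec = SimpleNamespace(shape=(32, 32), dtype=np.float32)
print(spec_to_array_type_params(spec))  # (dtype('float32'), 2)
```

Because the shape is dropped, a (32, 32) spec and a (4, 8) spec with the same dtype describe the same Warp array type.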
- Parameters:
jax_info (ShapeDtypeStruct) – A JAX structure containing shape, dtype, and ndim attributes describing an array.
- Returns:
A Warp array type with matching data type and number of dimensions, suitable for use in Warp kernel definitions.
- Return type:
warp.types.array
- Raises:
ImportError – If the warp package is not installed (propagated from jaxtype_to_warptype()).
ValueError – If the dtype in jax_info is not supported by Warp (propagated from jaxtype_to_warptype()).
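The ValueError comes from the dtype lookup. Below is a minimal sketch of a lookup with that failure mode. lookup_warp_scalar is a hypothetical helper, and the table is an illustrative subset of Warp scalar type names held as strings. The actual mapping lives in jaxtype_to_warptype() and is not reproduced here:

```python
import numpy as np

# Illustrative subset only: Warp scalar type names keyed by NumPy dtype.
# This table is an assumption for the sketch, not Warp's full type list.
_WARP_SCALAR_NAMES = {
    np.dtype(np.float16): "float16",
    np.dtype(np.float32): "float32",
    np.dtype(np.float64): "float64",
    np.dtype(np.int32): "int32",
    np.dtype(np.int64): "int64",
    np.dtype(np.bool_): "bool",
}


def lookup_warp_scalar(dtype):
    """Map a dtype to a Warp scalar type name, or raise ValueError."""
    key = np.dtype(dtype)  # accept np.float32, "float32", dtype objects
    try:
        return _WARP_SCALAR_NAMES[key]
    except KeyError:
        raise ValueError(f"unsupported dtype for Warp: {key}") from None


print(lookup_warp_scalar(np.float32))  # float32
try:
    lookup_warp_scalar(np.dtype("U4"))  # a unicode string dtype, absent from the table
except ValueError as err:
    print(err)  # unsupported dtype for Warp: <U4
```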
See also
jaxtype_to_warptype – Convert a single dtype to a Warp type.
check_warp_installed – Verify that Warp is available.
Notes
The resulting Warp array type is constructed via warp.array(dtype=..., ndim=...), which creates a Warp type descriptor rather than an allocated array. It is typically used in Warp kernel function signatures to declare input and output types.
Examples
>>> import jax
>>> import numpy as np
>>> from brainevent import jaxinfo_to_warpinfo
>>> info = jax.ShapeDtypeStruct(shape=(32, 32), dtype=np.float32)
>>> warp_arr_type = jaxinfo_to_warpinfo(info)