brainevent.jaxtype_to_warptype
- brainevent.jaxtype_to_warptype(dtype)
Convert a JAX / NumPy dtype to the corresponding Warp scalar type.
Maps standard NumPy data types (which are also used by JAX) to their Warp equivalents. This is needed when constructing Warp kernel signatures or Warp array types from JAX metadata.
- Parameters:
dtype (numpy.dtype or type) – The data type to convert. Accepts any object that can be compared with NumPy scalar types (e.g., np.float32, jnp.float32, np.dtype('float32')).
- Returns:
The corresponding Warp scalar type (e.g., warp.float32, warp.int32, warp.bool).
- Return type:
warp type
- Raises:
ImportError – If the warp package is not installed.
ValueError – If dtype does not correspond to any supported Warp type. Supported types are float16, float32, float64, int8, int16, int32, int64, uint8, uint16, uint32, uint64, and bool_.
See also
jaxinfo_to_warpinfo – Convert a full jax.ShapeDtypeStruct to a Warp array type.
check_warp_installed – Verify that Warp is available.
Notes
The mapping covers all scalar types supported by both NumPy and Warp: float16, float32, float64, int8 through int64, uint8 through uint64, and bool_. Complex types are not supported by Warp and raise ValueError.
Examples
>>> import numpy as np
>>> from brainevent import jaxtype_to_warptype
>>> warp_type = jaxtype_to_warptype(np.float32)
>>> warp_type  # warp.float32