brainevent.jaxtype_to_warptype


brainevent.jaxtype_to_warptype(dtype)

Convert a JAX / NumPy dtype to the corresponding Warp scalar type.

Maps standard NumPy data types (which are also used by JAX) to their Warp equivalents. This is needed when constructing Warp kernel signatures or Warp array types from JAX metadata.

Parameters:

dtype (numpy.dtype or type) – The data type to convert. Accepts any object that can be compared with NumPy scalar types (e.g., np.float32, jnp.float32, np.dtype('float32')).
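The NumPy-only snippet below illustrates the accepted spellings. A NumPy scalar type and the matching dtype object normalize to the same dtype and compare equal, so an equality check against np.float32 accepts either form. jnp.float32 is left out so the sketch runs without JAX. It is assumed, not exercised here, to normalize the same way.

```python
import numpy as np

# Two spellings of the same dtype, as listed for the `dtype` parameter.
# (jnp.float32 is omitted to keep this sketch NumPy-only.)
spellings = [np.float32, np.dtype("float32")]

# np.dtype(...) normalizes each spelling to one canonical dtype object.
normalized = [np.dtype(s) for s in spellings]
print(normalized)  # [dtype('float32'), dtype('float32')]

# A NumPy dtype also compares equal to the matching scalar type directly.
print(np.dtype("float32") == np.float32)  # True
```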

Returns:

The corresponding Warp scalar type (e.g., warp.float32, warp.int32, warp.bool).

Return type:

Warp scalar type

Raises:
  • ImportError – If the warp package is not installed.

  • ValueError – If dtype does not correspond to any supported Warp type. Supported types include: float16, float32, float64, int8, int16, int32, int64, uint8, uint16, uint32, uint64, and bool_.

See also

jaxinfo_to_warpinfo

Convert a full jax.ShapeDtypeStruct to a Warp array type.

check_warp_installed

Verify that Warp is available.

Notes

The mapping covers all scalar types supported by both NumPy and Warp: float16, float32, float64, int8 through int64, uint8 through uint64, and bool_. Complex types are not supported by Warp and will raise ValueError.

Examples

>>> import numpy as np
>>> from brainevent import jaxtype_to_warptype
>>> warp_type = jaxtype_to_warptype(np.float32)
>>> warp_type  # warp.float32