set_host_device_count
- brainstate.environ.set_host_device_count(n)
Set the number of host (CPU) devices.
By default, XLA exposes the host CPU as a single device. This function configures XLA to expose n separate host devices instead, so that jax.pmap can run parallel computations across them on CPU.
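On CPU, this kind of configuration is commonly done through the XLA_FLAGS environment variable and the --xla_force_host_platform_device_count flag. The sketch below assumes that approach. The helper name force_host_device_count is hypothetical and is not necessarily how brainstate implements set_host_device_count. It validates n, drops any earlier setting of the device-count flag, and prepends the new one while keeping other flags intact.

```python
import os
import re


def force_host_device_count(n):
    """Hypothetical helper: ask XLA to expose n CPU devices via XLA_FLAGS.

    Assumption: XLA reads XLA_FLAGS once, when JAX initializes its backends,
    so this must run before the first JAX computation or device query.
    """
    # bool is a subclass of int, so reject it explicitly.
    if isinstance(n, bool) or not isinstance(n, int) or n <= 0:
        raise ValueError(f"n must be a positive integer, got {n!r}")
    flag = "--xla_force_host_platform_device_count"
    # Strip any previous device-count setting, keeping all other flags.
    existing = os.environ.get("XLA_FLAGS", "")
    others = re.sub(rf"{flag}=\S+", "", existing).split()
    os.environ["XLA_FLAGS"] = " ".join([f"{flag}={n}"] + others)


# Usage: an earlier device count is replaced, unrelated flags are preserved.
os.environ["XLA_FLAGS"] = (
    "--xla_dump_to=/tmp/xla_dump --xla_force_host_platform_device_count=2"
)
force_host_device_count(4)
print(os.environ["XLA_FLAGS"])
```

Replacing the flag rather than appending a second copy avoids leaving two conflicting device counts in XLA_FLAGS.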
- Parameters:
n (int) – Number of host devices to configure.
- Raises:
ValueError – If n is not a positive integer.
- Return type:
None
Examples
>>> import brainstate.environ as env
>>> import jax
>>>
>>> # Configure 4 CPU devices
>>> env.set_host_device_count(4)
>>>
>>> # Use with pmap
>>> def parallel_fn(x):
...     return x * 2
>>>
>>> # This will work with 4 devices
>>> pmapped_fn = jax.pmap(parallel_fn)
Warning
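This setting only takes effect if applied at program start, before JAX initializes its backends (that is, before the first JAX computation or device query). The effects of the xla_force_host_platform_device_count flag are not fully understood, and using it may cause unexpected behavior.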
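Because the flag is read only when JAX initializes, a process in which JAX has already run cannot change its CPU device count. One workaround, sketched here as an assumption rather than a brainstate feature, is to start a fresh Python interpreter whose environment already carries the flag. The helper run_with_host_devices is hypothetical. In this sketch the child script only echoes XLA_FLAGS so it runs without JAX installed; in practice it would be the JAX program to parallelize.

```python
import os
import re
import subprocess
import sys


def run_with_host_devices(n, script):
    """Hypothetical helper: run `script` in a fresh interpreter with n CPU devices.

    The flag is placed in the child's environment before the child starts,
    so it is in effect from program start regardless of the parent's state.
    """
    env = dict(os.environ)  # copy; the parent's environment is not modified
    flag = f"--xla_force_host_platform_device_count={n}"
    # Drop any inherited device-count setting, keep the other flags.
    others = re.sub(
        r"--xla_force_host_platform_device_count=\S+", "", env.get("XLA_FLAGS", "")
    ).split()
    env["XLA_FLAGS"] = " ".join([flag] + others)
    result = subprocess.run(
        [sys.executable, "-c", script],
        env=env,
        capture_output=True,
        text=True,
        check=True,
    )
    return result.stdout


out = run_with_host_devices(4, "import os; print(os.environ['XLA_FLAGS'])")
print(out.strip())
```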