set_host_device_count


brainstate.environ.set_host_device_count(n)

Set the number of host (CPU) devices.

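This function sets the XLA flag xla_force_host_platform_device_count so that XLA exposes the host CPU as n separate devices, which enables parallel computation with jax.pmap on CPU. These devices are virtual, so n does not have to match the number of physical CPU cores.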
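The sketch below shows one way such a function could work. It assumes the call rewrites the XLA_FLAGS environment variable, which is how NumPyro's function of the same name works. `set_host_device_count_sketch` is a hypothetical stand-in and is not brainstate's implementation. It only demonstrates the flag handling and the ValueError check.

```python
import os
import re


def set_host_device_count_sketch(n: int) -> None:
    """Hypothetical illustration, not brainstate's actual code.

    Assumes the real function works by editing the XLA_FLAGS environment
    variable, which XLA reads when JAX initializes its CPU backend.
    """
    # Reject non-integers (bool is excluded deliberately) and non-positive values.
    if isinstance(n, bool) or not isinstance(n, int) or n <= 0:
        raise ValueError(f"n must be a positive integer, got {n!r}")
    flags = os.environ.get("XLA_FLAGS", "")
    # Remove any earlier device-count flag so that the new value is the only one.
    flags = re.sub(r"--xla_force_host_platform_device_count=\S+", "", flags)
    # Keep the other flags and append the new device count.
    new_flag = f"--xla_force_host_platform_device_count={n}"
    os.environ["XLA_FLAGS"] = " ".join(flags.split() + [new_flag])
```

Replacing an existing flag, instead of appending a second copy, keeps XLA_FLAGS unambiguous when the function is called more than once before JAX starts.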

Parameters:

n (int) – Number of host devices to configure.

Raises:

ValueError – If n is not a positive integer.

Return type:

None

Examples

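>>> import brainstate.environ as env
>>> import jax
>>> import jax.numpy as jnp
>>>
>>> # Configure 4 CPU devices before running any JAX computation
>>> env.set_host_device_count(4)
>>>
>>> # Use with pmap: the leading axis is split across the devices
>>> def parallel_fn(x):
...     return x * 2
...
>>> pmapped_fn = jax.pmap(parallel_fn)
>>> out = pmapped_fn(jnp.arange(4.0))  # requires 4 devices, one element each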

Warning

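This setting only takes effect when it is applied before JAX initializes its backends, that is, before the first JAX computation or device query. In practice, this means you should call it at program start. The effects of xla_force_host_platform_device_count are not fully understood, and using it may cause unexpected behavior.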
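XLA reads the flag only when JAX first initializes its backends. For that reason, one option is to set XLA_FLAGS in the environment of a fresh process before that process starts. The helper `run_with_host_devices` below is hypothetical, does not call brainstate, and uses only the standard library. This is a minimal sketch under that assumption.

```python
import os
import re
import subprocess
import sys


def run_with_host_devices(script: str, n: int) -> subprocess.CompletedProcess:
    """Hypothetical helper: run `script` in a new Python process whose
    XLA_FLAGS already requests `n` host devices before any JAX import."""
    env = dict(os.environ)  # copy, so the parent's environment is untouched
    flags = re.sub(r"--xla_force_host_platform_device_count=\S+", "",
                   env.get("XLA_FLAGS", ""))
    env["XLA_FLAGS"] = " ".join(
        flags.split() + [f"--xla_force_host_platform_device_count={n}"]
    )
    return subprocess.run(
        [sys.executable, "-c", script],
        env=env,
        capture_output=True,
        text=True,
        check=True,  # raise CalledProcessError if the child fails
    )
```

For example, with JAX installed, `run_with_host_devices("import jax; print(len(jax.devices('cpu')))", 4).stdout` should report the configured device count.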