brainstate.random module#

Random number generation module for BrainState.

This module provides a comprehensive set of random number generation functions and utilities for neural network simulations and scientific computing. It wraps JAX’s random number generation capabilities with a stateful interface that simplifies usage while maintaining reproducibility and performance.

The module includes:

  • Standard random distributions (uniform, normal, exponential, etc.)

  • Random state management with automatic key splitting

  • Seed management utilities for reproducible simulations

  • NumPy-compatible API for easy migration

Key Features#

  • Stateful random generation: Automatic management of JAX’s PRNG keys

  • NumPy compatibility: Drop-in replacement for most NumPy random functions

  • Reproducibility: Robust seed management and state tracking

  • Performance: JIT-compiled random functions for efficient generation

  • Thread-safe: Proper handling of random state in parallel computations

Random State Management#

The module uses a global DEFAULT RandomState instance that automatically manages JAX’s PRNG keys. This eliminates the need to manually track and split keys:

>>> import brainstate as bs
>>> import brainstate.random as bsr
>>>
>>> # Set a global seed for reproducibility
>>> bsr.seed(42)
>>>
>>> # Generate random numbers without manual key management
>>> x = bsr.normal(0, 1, size=(3, 3))
>>> y = bsr.uniform(0, 1, size=(100,))
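
For comparison, the bookkeeping that the stateful interface hides looks roughly like this in plain `jax.random`. This is a sketch using JAX's own API, not a description of how BrainState works internally; the variable names are illustrative.

```python
import jax
import jax.numpy as jnp

# Plain JAX: the caller owns the key and must split it before every draw.
key = jax.random.PRNGKey(42)

key, subkey_x = jax.random.split(key)
x = jax.random.normal(subkey_x, (3, 3))

key, subkey_y = jax.random.split(key)
y = jax.random.uniform(subkey_y, (100,), minval=0.0, maxval=1.0)

# Reusing a key reproduces the same draw, which is why a fresh subkey
# is needed for every independent sample.
x_repeat = jax.random.normal(subkey_x, (3, 3))
```

A stateful generator performs the split-and-replace step on each call, so the caller never holds a stale key.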

Custom Random States#

For more control, you can create custom RandomState instances:

>>> import brainstate.random as bsr
>>>
>>> # Create a custom random state
>>> rng = bsr.RandomState(123)
>>>
>>> # Use it for generation
>>> data = rng.normal(0, 1, size=(10, 10))
>>>
>>> # Get the current key
>>> current_key = rng.value

Available Distributions#

The module provides a wide range of probability distributions:

Uniform Distributions:

  • rand, random, random_sample, ranf, sample - Uniform [0, 1)

  • randint, random_integers - Uniform integers

  • choice - Random selection from array

  • permutation, shuffle - Random ordering

Normal Distributions:

  • randn, normal - Normal (Gaussian) distribution

  • standard_normal - Standard normal distribution

  • multivariate_normal - Multivariate normal distribution

  • truncated_normal - Truncated normal distribution

Other Continuous Distributions:

  • beta - Beta distribution

  • exponential, standard_exponential - Exponential distribution

  • gamma, standard_gamma - Gamma distribution

  • gumbel - Gumbel distribution

  • laplace - Laplace distribution

  • logistic - Logistic distribution

  • pareto - Pareto distribution

  • rayleigh - Rayleigh distribution

  • standard_cauchy - Cauchy distribution

  • standard_t - Student’s t-distribution

  • uniform - Uniform distribution over [low, high)

  • weibull - Weibull distribution

Discrete Distributions:

  • bernoulli - Bernoulli distribution

  • binomial - Binomial distribution

  • poisson - Poisson distribution

Seed Management#

The module provides utilities for managing random seeds:

>>> import brainstate.random as bsr
>>>
>>> # Set a global seed
>>> bsr.seed(42)
>>>
>>> # Get current seed/key
>>> key = bsr.get_key()
>>>
>>> # Split the key for parallel operations
>>> keys = bsr.split_key(n=4)
>>>
>>> # Use the seed_context context manager for a temporary seed
>>> with bsr.seed_context(123):
...     x = bsr.normal(0, 1, (5,))  # Uses seed 123
>>> y = bsr.normal(0, 1, (5,))  # Uses original seed
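
The save-and-restore behaviour of a temporary-seed context manager can be modelled with Python's standard library alone. The sketch below is a conceptual analogue, not BrainState's implementation: `temporary_seed` is a hypothetical helper, and it acts on Python's global `random` state rather than on JAX keys.

```python
import random
from contextlib import contextmanager


@contextmanager
def temporary_seed(seed):
    """Hypothetical helper: reseed the global RNG, then restore its prior state."""
    saved_state = random.getstate()
    random.seed(seed)
    try:
        yield
    finally:
        # Restore even if the body raises, so the outer stream is never lost.
        random.setstate(saved_state)


random.seed(42)
expected_next = random.Random(42).random()  # what the outer stream yields next

with temporary_seed(123):
    inside = random.random()  # drawn from the seed-123 stream

after = random.random()  # outer stream continues as if undisturbed
```

The point of the pattern is that draws inside the block do not advance the outer stream, so code after the block is unaffected by whatever ran inside it.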

Examples

Basic random number generation:

>>> import brainstate.random as bsr
>>> import jax.numpy as jnp
>>>
>>> # Set seed for reproducibility
>>> bsr.seed(0)
>>>
>>> # Generate uniform random numbers
>>> uniform_data = bsr.random((3, 3))
>>> print(uniform_data.shape)
(3, 3)
>>>
>>> # Generate normal random numbers
>>> normal_data = bsr.normal(loc=0, scale=1, size=(100,))
>>> print(f"Mean: {normal_data.mean():.3f}, Std: {normal_data.std():.3f}")
Mean: -0.045, Std: 0.972

Sampling and shuffling:

>>> import brainstate.random as bsr
>>> import jax.numpy as jnp
>>>
>>> bsr.seed(42)
>>>
>>> # Random choice from array
>>> arr = jnp.array([1, 2, 3, 4, 5])
>>> samples = bsr.choice(arr, size=3, replace=False)
>>> print(samples)
[4 1 5]
>>>
>>> # Random permutation
>>> perm = bsr.permutation(10)
>>> print(perm)
[3 5 1 7 9 0 2 8 4 6]
>>>
>>> # Shuffle (JAX arrays are immutable, so keep the returned array)
>>> data = jnp.arange(5)
>>> data = bsr.shuffle(data)
>>> print(data)
[2 0 4 1 3]

Advanced distributions:

>>> import brainstate.random as bsr
>>> import matplotlib.pyplot as plt
>>>
>>> bsr.seed(123)
>>>
>>> # Generate samples from different distributions
>>> normal_samples = bsr.normal(0, 1, 1000)
>>> exponential_samples = bsr.exponential(1.0, 1000)
>>> beta_samples = bsr.beta(2, 5, 1000)
>>>
>>> # Plot histograms
>>> fig, axes = plt.subplots(1, 3, figsize=(12, 4))
>>> axes[0].hist(normal_samples, bins=30, density=True)
>>> axes[0].set_title('Normal Distribution')
>>> axes[1].hist(exponential_samples, bins=30, density=True)
>>> axes[1].set_title('Exponential Distribution')
>>> axes[2].hist(beta_samples, bins=30, density=True)
>>> axes[2].set_title('Beta Distribution')
>>> plt.show()

Using with neural network simulations:

>>> import brainstate as bs
>>> import brainstate.random as bsr
>>> import brainstate.nn as nn
>>> import jax.numpy as jnp
>>>
>>> class NoisyNeuron(nn.Module):
...     def __init__(self, n_neurons, noise_scale=0.1):
...         super().__init__()
...         self.n_neurons = n_neurons
...         self.noise_scale = noise_scale
...         self.membrane = bs.State(jnp.zeros(n_neurons))
...
...     def update(self, input_current):
...         # Add noise to input current
...         noise = bsr.normal(0, self.noise_scale, self.n_neurons)
...         self.membrane.value += input_current + noise
...         return self.membrane.value
>>>
>>> # Create and run noisy neuron model
>>> bsr.seed(42)
>>> neuron = NoisyNeuron(100)
>>> output = neuron.update(jnp.ones(100) * 0.5)

Notes

  • This module is designed to work seamlessly with JAX’s functional programming model

  • Random functions are JIT-compilable for optimal performance

  • The global DEFAULT state is thread-local to avoid race conditions

  • For deterministic results, always set a seed before random operations

See also

jax.random

JAX’s random number generation module

numpy.random

NumPy’s random number generation module

RandomState

The stateful random number generator class

Random State Management#

Core components for managing random number generator state and ensuring reproducible computations.

RandomState

A RandomState that tracks the random generator state.

Seed and Key Management#

Functions for controlling the global random state and creating independent random number generators.

seed

Set the global random seed for both JAX and NumPy.

seed_context

Context manager for temporary random seed changes with automatic restoration.

default_rng

Get the default random state or create a new one with specified seed.

clone_rng

Create a clone of the random state or a new random state.

set_key

Set a new random key for the global random state.

get_key

Get the current global random key.

restore_key

Restore the default random key to its previous state.

Key Splitting and Parallel Generation#

Functions for creating independent random keys for parallel computation.

split_key

Create new random key(s) from the current seed.

split_keys

Create multiple independent random keys from the current seed.

self_assign_multi_keys

Assign multiple keys to the global random state for parallel access.
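
As an illustration of why independent keys matter for parallel work, the sketch below splits one key into several and maps a sampling function over them. It uses `jax.random.split` and `jax.vmap` directly; the `draw` function and the sizes are illustrative, and the functions listed in this section may differ in how they update the global state.

```python
import jax
import jax.numpy as jnp

key = jax.random.PRNGKey(0)

# One parent key becomes four child keys, one per parallel worker.
keys = jax.random.split(key, 4)


def draw(k):
    # Each worker draws five standard-normal samples from its own key.
    return jax.random.normal(k, (5,))


# Vectorize over the leading axis of `keys`: one row of samples per key.
samples = jax.vmap(draw)(keys)
```

Because every row comes from a different key, the rows are distinct streams rather than copies of the same draw.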

Random Sampling Functions#

Comprehensive collection of probability distributions and sampling functions, providing NumPy-compatible interfaces with JAX backend acceleration.

Basic Random Sampling#

Fundamental random number generation functions for common use cases.

rand

Random values in a given shape.

randn

Return a sample (or samples) from the "standard normal" distribution.

random

Return random floats in the half-open interval [0.0, 1.0).

random_sample

Return random floats in the half-open interval [0.0, 1.0).

ranf

This is an alias of random_sample.

sample

This is an alias of random_sample.

randint

Return random integers from low (inclusive) to high (exclusive).

random_integers

Random integers of type np.int_ between low and high, inclusive.

Array-like Generation (PyTorch compatibility)#

Functions that generate random arrays with shapes matching existing arrays.

rand_like

Similar to rand_like in torch.

randint_like

Similar to randint_like in torch.

randn_like

Similar to randn_like in torch.
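
A rough picture of what "like" generation means, assuming semantics that mirror the torch functions: the output copies its shape (and, for floating-point draws, its dtype) from a template array. The sketch uses plain `jax.random` calls rather than the functions above, and the `template` array and integer bounds are illustrative.

```python
import jax
import jax.numpy as jnp

template = jnp.zeros((2, 3), dtype=jnp.float32)
key_u, key_n, key_i = jax.random.split(jax.random.PRNGKey(0), 3)

# Uniform on [0, 1) with the template's shape and dtype (cf. rand_like).
u = jax.random.uniform(key_u, template.shape, dtype=template.dtype)

# Standard normal with the template's shape and dtype (cf. randn_like).
n = jax.random.normal(key_n, template.shape, dtype=template.dtype)

# Integers in [0, 10) with the template's shape (cf. randint_like).
i = jax.random.randint(key_i, template.shape, minval=0, maxval=10)
```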

Array Manipulation#

Functions for random permutations and selections.

choice

Generates a random sample from a given 1-D array.

permutation

Randomly permute a sequence, or return a permuted range.

shuffle

Modify a sequence in-place by shuffling its contents.
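
JAX arrays are immutable, so shuffling one has to produce a new array rather than reorder the input. The sketch below shows this with `jax.random.permutation` directly. How `shuffle` in this module treats other sequence types is not stated here, so using the return value is the safe assumption.

```python
import jax
import jax.numpy as jnp

data = jnp.arange(5)
key = jax.random.PRNGKey(0)

# Returns a shuffled copy; `data` itself is left untouched.
shuffled = jax.random.permutation(key, data)
```

The shuffled array holds the same elements in a new order, while the original keeps its order.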

Continuous Distributions#

Probability distributions for continuous random variables.

beta

Draw samples from a Beta distribution.

exponential

Draw samples from an exponential distribution.

gamma

Draw samples from a Gamma distribution.

gumbel

Draw samples from a Gumbel distribution.

laplace

Draw samples from the Laplace or double exponential distribution with specified location (or mean) and scale (decay).

logistic

Draw samples from a logistic distribution.

normal

Draw random samples from a normal (Gaussian) distribution.

pareto

Draw samples from a Pareto II or Lomax distribution with specified shape.

standard_cauchy

Draw samples from a standard Cauchy distribution with mode = 0.

standard_exponential

Draw samples from the standard exponential distribution.

standard_gamma

Draw samples from a standard Gamma distribution.

standard_normal

Draw samples from a standard Normal distribution (mean=0, stdev=1).

standard_t

Draw samples from a standard Student's t distribution with df degrees of freedom.

uniform

Draw samples from a uniform distribution.

truncated_normal

Sample truncated standard normal random values with given shape and dtype.

lognormal

Draw samples from a log-normal distribution.

power

Draws samples in [0, 1] from a power distribution with positive exponent a - 1.

rayleigh

Draw samples from a Rayleigh distribution.

triangular

Draw samples from the triangular distribution over the interval [left, right].

vonmises

Draw samples from a von Mises distribution.

wald

Draw samples from a Wald, or inverse Gaussian, distribution.

weibull

Draw samples from a Weibull distribution.

weibull_min

Sample from a Weibull distribution.

maxwell

Sample from a one sided Maxwell distribution.

t

Sample Student’s t random values.

loggamma

Sample log-gamma random values.

Discrete Distributions#

Probability distributions for discrete random variables.

bernoulli

Sample Bernoulli random values with given shape and mean.

binomial

Draw samples from a binomial distribution.

categorical

Sample random values from categorical distributions.

geometric

Draw samples from the geometric distribution.

hypergeometric

Draw samples from a Hypergeometric distribution.

logseries

Draw samples from a logarithmic series distribution.

multinomial

Draw samples from a multinomial distribution.

negative_binomial

Draw samples from a negative binomial distribution.

poisson

Draw samples from a Poisson distribution.

zipf

Draw samples from a Zipf distribution.

Special Distributions#

Specialized distributions for statistical and scientific applications.

chisquare

Draw samples from a chi-square distribution.

dirichlet

Draw samples from the Dirichlet distribution.

f

Draw samples from an F distribution.

multivariate_normal

Draw random samples from a multivariate normal distribution.

noncentral_chisquare

Draw samples from a noncentral chi-square distribution.

noncentral_f

Draw samples from the noncentral F distribution.

orthogonal

Sample uniformly from the orthogonal group O(n).