brainstate.random module#
Random number generation module for BrainState.
This module provides a comprehensive set of random number generation functions and utilities for neural network simulations and scientific computing. It wraps JAX’s random number generation capabilities with a stateful interface that simplifies usage while maintaining reproducibility and performance.
The module includes:
Standard random distributions (uniform, normal, exponential, etc.)
Random state management with automatic key splitting
Seed management utilities for reproducible simulations
NumPy-compatible API for easy migration
Key Features#
Stateful random generation: Automatic management of JAX’s PRNG keys
NumPy compatibility: Drop-in replacement for most NumPy random functions
Reproducibility: Robust seed management and state tracking
Performance: JIT-compiled random functions for efficient generation
Thread-safe: Proper handling of random state in parallel computations
Random State Management#
The module uses a global DEFAULT RandomState instance that automatically manages JAX’s PRNG keys. This eliminates the need to manually track and split keys:
>>> import brainstate as bs
>>> import brainstate.random as bsr
>>>
>>> # Set a global seed for reproducibility
>>> bsr.seed(42)
>>>
>>> # Generate random numbers without manual key management
>>> x = bsr.normal(0, 1, size=(3, 3))
>>> y = bsr.uniform(0, 1, size=(100,))
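Conceptually, the stateful interface hides JAX's explicit key-splitting discipline. The sketch below is a hypothetical `TinyRandomState` class written with raw `jax.random` (it is not BrainState's implementation, and its numbers will not match `bsr` output). It stores a key, splits it on every call, and keeps one half for the next call. This is the bookkeeping that the DEFAULT state performs for you:

```python
import jax
import jax.numpy as jnp


class TinyRandomState:
    """Hypothetical minimal stateful wrapper around a JAX PRNG key."""

    def __init__(self, seed: int):
        self.key = jax.random.PRNGKey(seed)

    def split_key(self):
        # Split the stored key: keep one half as the new state, hand out the other.
        self.key, subkey = jax.random.split(self.key)
        return subkey

    def normal(self, loc=0.0, scale=1.0, size=()):
        # Each call consumes a fresh subkey, so successive draws differ.
        return loc + scale * jax.random.normal(self.split_key(), size)


rng = TinyRandomState(42)
a = rng.normal(0, 1, size=(3, 3))
b = rng.normal(0, 1, size=(3, 3))
print(bool(jnp.allclose(a, b)))  # False: different subkeys were used
```

Re-creating the object with the same seed replays the same stream, which is why a single `seed` call is enough for reproducibility.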
Custom Random States#
For more control, you can create custom RandomState instances:
>>> import brainstate.random as bsr
>>>
>>> # Create a custom random state
>>> rng = bsr.RandomState(seed=123)
>>>
>>> # Use it for generation
>>> data = rng.normal(0, 1, size=(10, 10))
>>>
>>> # Get the current key
>>> current_key = rng.value
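Because a custom state owns its own key, drawing from it does not advance any other stream. The illustration below uses two raw JAX keys as stand-ins for two `RandomState(seed=123)` objects. It is an assumption-level analogy, not BrainState's internals, so the values differ from what `rng.normal` returns:

```python
import jax

# Two independent key streams, analogous to two RandomState objects
# created with the same seed (raw JAX keys, not BrainState's internals).
key_a = jax.random.PRNGKey(123)
key_b = jax.random.PRNGKey(123)

# Advance stream A twice; stream B is untouched by this.
key_a, sub = jax.random.split(key_a)
first_a = jax.random.normal(sub, (10, 10))
key_a, sub = jax.random.split(key_a)
second_a = jax.random.normal(sub, (10, 10))

# Stream B, starting from the same seed, reproduces A's first draw.
key_b, sub = jax.random.split(key_b)
first_b = jax.random.normal(sub, (10, 10))
```

Here `key_a` plays the role of `rng.value`: after each draw, it holds the key that the next draw will split from.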
Available Distributions#
The module provides a wide range of probability distributions:
Uniform Distributions:
rand, random, random_sample, ranf, sample - Uniform [0, 1)
randint, random_integers - Uniform integers
choice - Random selection from array
permutation, shuffle - Random ordering
Normal Distributions:
randn, normal - Normal (Gaussian) distribution
standard_normal - Standard normal distribution
multivariate_normal - Multivariate normal distribution
truncated_normal - Truncated normal distribution
Other Continuous Distributions:
beta - Beta distribution
exponential, standard_exponential - Exponential distribution
gamma, standard_gamma - Gamma distribution
gumbel - Gumbel distribution
laplace - Laplace distribution
logistic - Logistic distribution
pareto - Pareto distribution
rayleigh - Rayleigh distribution
standard_cauchy - Cauchy distribution
standard_t - Student’s t-distribution
uniform - Uniform distribution over [low, high)
weibull - Weibull distribution
Discrete Distributions:
bernoulli - Bernoulli distribution
binomial - Binomial distribution
poisson - Poisson distribution
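NumPy-style parameters such as `low`/`high`, `loc`/`scale`, and an exponential `scale` can be obtained by shifting and scaling standard draws. The sketch below shows these transforms with raw `jax.random`. It illustrates the parameterization only and is not a claim about how BrainState implements these functions:

```python
import jax

key = jax.random.PRNGKey(0)
k1, k2, k3 = jax.random.split(key, 3)
n = 100_000

# uniform over [low, high): shift and scale a standard uniform draw
low, high = -2.0, 3.0
u = low + (high - low) * jax.random.uniform(k1, (n,))

# normal(loc, scale): shift and scale a standard normal draw
z = 5.0 + 2.0 * jax.random.normal(k2, (n,))

# exponential(scale): scale a standard exponential draw
e = 0.5 * jax.random.exponential(k3, (n,))
```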
Seed Management#
The module provides utilities for managing random seeds:
>>> import brainstate.random as bsr
>>>
>>> # Set a global seed
>>> bsr.seed(42)
>>>
>>> # Get current seed/key
>>> key = bsr.get_key()
>>>
>>> # Split the key for parallel operations
>>> keys = bsr.split_key(n=4)
>>>
>>> # Use context manager for temporary seed
>>> with bsr.local_seed(123):
...     x = bsr.normal(0, 1, (5,))  # Uses seed 123
>>> y = bsr.normal(0, 1, (5,)) # Uses original seed
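The behaviour of a temporary-seed context manager follows a save, reseed, and restore pattern. The sketch below models it with a hypothetical module-level key and a hypothetical `temporary_seed` helper built on `contextlib`. Neither name is part of BrainState:

```python
import contextlib
import jax

# Hypothetical module-level key standing in for the DEFAULT state's key.
_global_key = jax.random.PRNGKey(42)


def next_key():
    """Split the global key and return a fresh subkey."""
    global _global_key
    _global_key, sub = jax.random.split(_global_key)
    return sub


@contextlib.contextmanager
def temporary_seed(seed):
    """Hypothetical analogue of a local-seed context manager."""
    global _global_key
    saved = _global_key                     # remember the outer stream
    _global_key = jax.random.PRNGKey(seed)  # reseed inside the block
    try:
        yield
    finally:
        _global_key = saved                 # restore, even if the block raises


with temporary_seed(123):
    x = jax.random.normal(next_key(), (5,))  # drawn from seed 123
y = jax.random.normal(next_key(), (5,))      # continues the seed-42 stream
```

Restoring in `finally` means that an exception inside the block cannot leave the global stream reseeded.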
Examples
Basic random number generation:
>>> import brainstate.random as bsr
>>> import jax.numpy as jnp
>>>
>>> # Set seed for reproducibility
>>> bsr.seed(0)
>>>
>>> # Generate uniform random numbers
>>> uniform_data = bsr.random((3, 3))
>>> print(uniform_data.shape)
(3, 3)
>>>
>>> # Generate normal random numbers
>>> normal_data = bsr.normal(loc=0, scale=1, size=(100,))
>>> print(f"Mean: {normal_data.mean():.3f}, Std: {normal_data.std():.3f}")
Mean: -0.045, Std: 0.972
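The printed mean and standard deviation differ from 0 and 1 only by sampling noise. With n = 100 draws, the standard error of the sample mean is 1/sqrt(100) = 0.1. The raw-JAX check below (an illustration, not BrainState code) shows both properties: the same key yields bit-identical draws, and the error scale is 0.1:

```python
import jax

n = 100
sample_1 = jax.random.normal(jax.random.PRNGKey(0), (n,))
sample_2 = jax.random.normal(jax.random.PRNGKey(0), (n,))

# Same key -> bit-identical draws: this is what seeding buys you.
identical = bool((sample_1 == sample_2).all())

# The sample mean fluctuates around 0 with standard error 1/sqrt(n).
standard_error = 1.0 / n ** 0.5
print(identical, standard_error)  # True 0.1
```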
Sampling and shuffling:
>>> import brainstate.random as bsr
>>> import jax.numpy as jnp
>>>
>>> bsr.seed(42)
>>>
>>> # Random choice from array
>>> arr = jnp.array([1, 2, 3, 4, 5])
>>> samples = bsr.choice(arr, size=3, replace=False)
>>> print(samples)
[4 1 5]
>>>
>>> # Random permutation
>>> perm = bsr.permutation(10)
>>> print(perm)
[3 5 1 7 9 0 2 8 4 6]
>>>
>>> # Shuffle (JAX arrays are immutable, so keep the returned array)
>>> data = jnp.arange(5)
>>> data = bsr.shuffle(data)
>>> print(data)
[2 0 4 1 3]
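JAX arrays cannot be modified in place, so any shuffle of a `jnp` array necessarily produces a new array. The raw-JAX illustration below uses `jax.random.permutation`. It shows that the input is left untouched and the result contains the same elements in a new order:

```python
import jax
import jax.numpy as jnp

data = jnp.arange(5)
key = jax.random.PRNGKey(42)

# permutation returns a new array; `data` itself is unchanged.
shuffled = jax.random.permutation(key, data)

print(data)                       # [0 1 2 3 4]
print(sorted(shuffled.tolist()))  # [0, 1, 2, 3, 4]: same elements, new order
```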
Advanced distributions:
>>> import brainstate.random as bsr
>>> import matplotlib.pyplot as plt
>>>
>>> bsr.seed(123)
>>>
>>> # Generate samples from different distributions
>>> normal_samples = bsr.normal(0, 1, 1000)
>>> exponential_samples = bsr.exponential(1.0, 1000)
>>> beta_samples = bsr.beta(2, 5, 1000)
>>>
>>> # Plot histograms
>>> fig, axes = plt.subplots(1, 3, figsize=(12, 4))
>>> axes[0].hist(normal_samples, bins=30, density=True)
>>> axes[0].set_title('Normal Distribution')
>>> axes[1].hist(exponential_samples, bins=30, density=True)
>>> axes[1].set_title('Exponential Distribution')
>>> axes[2].hist(beta_samples, bins=30, density=True)
>>> axes[2].set_title('Beta Distribution')
>>> plt.show()
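Beyond plotting histograms, comparing sample means with theory is a quick sanity check. An exponential distribution with scale s has mean s, and Beta(a, b) has mean a / (a + b). The sketch below draws with raw `jax.random` (not `bsr`, so the values differ from the plot above) and prints each empirical mean next to its theoretical value:

```python
import jax

k1, k2 = jax.random.split(jax.random.PRNGKey(123))
n = 100_000

exp_samples = 1.0 * jax.random.exponential(k1, (n,))  # scale = 1.0
beta_samples = jax.random.beta(k2, 2.0, 5.0, (n,))    # a = 2, b = 5

# Empirical mean next to its theoretical value.
print(float(exp_samples.mean()), 1.0)
print(float(beta_samples.mean()), 2.0 / (2.0 + 5.0))
```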
Using with neural network simulations:
>>> import jax.numpy as jnp
>>> import brainstate as bs
>>> import brainstate.random as bsr
>>> import brainstate.nn as nn
>>>
>>> class NoisyNeuron(nn.Module):
...     def __init__(self, n_neurons, noise_scale=0.1):
...         super().__init__()
...         self.n_neurons = n_neurons
...         self.noise_scale = noise_scale
...         self.membrane = bs.State(jnp.zeros(n_neurons))
...
...     def update(self, input_current):
...         # Add noise to the input current
...         noise = bsr.normal(0, self.noise_scale, self.n_neurons)
...         self.membrane.value += input_current + noise
...         return self.membrane.value
>>>
>>> # Create and run noisy neuron model
>>> bsr.seed(42)
>>> neuron = NoisyNeuron(100)
>>> output = neuron.update(jnp.ones(100) * 0.5)
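For comparison, the same noisy accumulation written in plain functional JAX must thread keys explicitly through a time loop. The sketch below uses hypothetical constants and a hypothetical `simulate` function (not part of BrainState). It splits one key per step and runs the loop with `jax.lax.scan` under `jax.jit`:

```python
import jax
import jax.numpy as jnp

# Hypothetical simulation constants, mirroring NoisyNeuron(100) with input 0.5.
N_NEURONS = 100
N_STEPS = 50
NOISE_SCALE = 0.1
CURRENT = 0.5


@jax.jit
def simulate(key):
    keys = jax.random.split(key, N_STEPS)  # one independent key per time step

    def step(membrane, step_key):
        noise = NOISE_SCALE * jax.random.normal(step_key, (N_NEURONS,))
        membrane = membrane + CURRENT + noise
        return membrane, membrane  # (new carry, per-step output)

    final, trace = jax.lax.scan(step, jnp.zeros(N_NEURONS), keys)
    return final, trace


final, trace = simulate(jax.random.PRNGKey(42))
```

Calling `simulate` again with the same key reproduces the trajectory exactly. The stateful `bsr.normal` call inside `update` removes this key plumbing from model code.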
Notes
This module is designed to work seamlessly with JAX’s functional programming model
Random functions are JIT-compilable for optimal performance
The global DEFAULT state is thread-local to avoid race conditions
For deterministic results, always set a seed before random operations
See also
jax.random - JAX’s random number generation module
numpy.random - NumPy’s random number generation module
RandomState - The stateful random number generator class
Random State Management#
Core components for managing random number generator state and ensuring reproducible computations.
RandomState that tracks the random generator state.
Seed and Key Management#
Functions for controlling the global random state and creating independent random number generators.
Set the global random seed for both JAX and NumPy.
Context manager for temporary random seed changes with automatic restoration.
Get the default random state or create a new one with specified seed.
Create a clone of the random state or a new random state.
Set a new random key for the global random state.
Get the current global random key.
Restore the default random key to its previous state.
Key Splitting and Parallel Generation#
Functions for creating independent random keys for parallel computation.
Create new random key(s) from the current seed.
Create multiple independent random keys from the current seed.
Assign multiple keys to the global random state for parallel access.
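Splitting one key into several independent keys is the standard JAX pattern for parallel sampling. The raw-JAX sketch below is an illustration, not BrainState's key-splitting functions. It splits a key four ways and maps a sampler over the keys with `jax.vmap`:

```python
import jax

# Split one key into four independent keys, one per parallel worker.
keys = jax.random.split(jax.random.PRNGKey(0), 4)

# vmap maps the sampler over the leading (key) axis: 4 independent draws of shape (3,).
draws = jax.vmap(lambda k: jax.random.normal(k, (3,)))(keys)
print(draws.shape)  # (4, 3)
```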
Random Sampling Functions#
Comprehensive collection of probability distributions and sampling functions, providing NumPy-compatible interfaces with JAX backend acceleration.
Basic Random Sampling#
Fundamental random number generation functions for common use cases.
Random values in a given shape.
Return a sample (or samples) from the "standard normal" distribution.
Return random floats in the half-open interval [0.0, 1.0).
Return random floats in the half-open interval [0.0, 1.0).
This is an alias of random_sample.
This is an alias of random_sample.
Return random integers from low (inclusive) to high (exclusive).
Random integers of type np.int_ between low and high, inclusive.
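The two integer conventions above differ only in whether `high` is included. The raw-JAX sketch below is an illustration, not the `bsr` functions. It produces the exclusive form directly with `jax.random.randint` and emulates the inclusive form by passing `high + 1`:

```python
import jax

k1, k2 = jax.random.split(jax.random.PRNGKey(0))
low, high = 1, 6

# randint-style: high is exclusive, so values lie in {1, ..., 5}.
exclusive = jax.random.randint(k1, (10_000,), low, high)

# random_integers-style: high is inclusive, emulated here by passing high + 1.
inclusive = jax.random.randint(k2, (10_000,), low, high + 1)
```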
Array-like Generation (PyTorch compatibility)#
Functions that generate random arrays with shapes matching existing arrays.
Similar to
Similar to
Similar to
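PyTorch-style "like" generators take their shape and dtype from an existing array. The sketch below defines two hypothetical helpers, `rand_like_sketch` and `randn_like_sketch` (they are not the BrainState functions listed above), using raw `jax.random`:

```python
import jax
import jax.numpy as jnp


def rand_like_sketch(key, x):
    """Hypothetical helper: uniform [0, 1) draws matching x's shape and dtype."""
    return jax.random.uniform(key, x.shape, dtype=x.dtype)


def randn_like_sketch(key, x):
    """Hypothetical helper: standard normal draws matching x's shape and dtype."""
    return jax.random.normal(key, x.shape, dtype=x.dtype)


template = jnp.zeros((2, 3), dtype=jnp.float32)
k1, k2 = jax.random.split(jax.random.PRNGKey(0))
u = rand_like_sketch(k1, template)
z = randn_like_sketch(k2, template)
print(u.shape, u.dtype)  # (2, 3) float32
```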
Array Manipulation#
Functions for random permutations and selections.
Generates a random sample from a given 1-D array
Randomly permute a sequence, or return a permuted range.
Modify a sequence in-place by shuffling its contents.
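Sampling from a 1-D array supports two common variants: drawing distinct elements without replacement, and drawing with unequal weights. The raw-JAX sketch below demonstrates both with `jax.random.choice`. It is an illustration only, and its values will not match `bsr.choice`:

```python
import jax
import jax.numpy as jnp

k1, k2 = jax.random.split(jax.random.PRNGKey(42))
arr = jnp.array([1, 2, 3, 4, 5])

# Without replacement: 3 distinct elements of arr.
picked = jax.random.choice(k1, arr, shape=(3,), replace=False)

# Weighted sampling: element 5 is drawn with probability 0.6.
p = jnp.array([0.1, 0.1, 0.1, 0.1, 0.6])
weighted = jax.random.choice(k2, arr, shape=(10_000,), p=p)
```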
Continuous Distributions#
Probability distributions for continuous random variables.
Draw samples from a Beta distribution.
Draw samples from an exponential distribution.
Draw samples from a Gamma distribution.
Draw samples from a Gumbel distribution.
Draw samples from the Laplace or double exponential distribution with specified location (or mean) and scale (decay).
Draw samples from a logistic distribution.
Draw random samples from a normal (Gaussian) distribution.
Draw samples from a Pareto II or Lomax distribution with specified shape.
Draw samples from a standard Cauchy distribution with mode = 0.
Draw samples from the standard exponential distribution.
Draw samples from a standard Gamma distribution.
Draw samples from a standard Normal distribution (mean=0, stdev=1).
Draw samples from a standard Student's t distribution with df degrees of freedom.
Draw samples from a uniform distribution.
Sample truncated standard normal random values with given shape and dtype.
Draw samples from a log-normal distribution.
Draw samples in [0, 1] from a power distribution with positive exponent a - 1.
Draw samples from a Rayleigh distribution.
Draw samples from the triangular distribution over the interval
Draw samples from a von Mises distribution.
Draw samples from a Wald, or inverse Gaussian, distribution.
Draw samples from a Weibull distribution.
Sample from a Weibull distribution.
Sample from a one-sided Maxwell distribution.
Sample Student’s t random values.
Sample log-gamma random values.
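A truncated standard normal on (a, b) has mean (φ(a) - φ(b)) / (Φ(b) - Φ(a)), where φ and Φ are the standard normal pdf and cdf. The sketch below draws with raw `jax.random.truncated_normal` (an illustration, not the BrainState function). It checks that samples respect the bounds and compares their mean with this formula:

```python
import math
import jax

lower, upper = -1.0, 2.0

# Standard normal restricted to (lower, upper).
samples = jax.random.truncated_normal(jax.random.PRNGKey(0), lower, upper, (10_000,))


def phi(x):
    """Standard normal probability density."""
    return math.exp(-0.5 * x * x) / math.sqrt(2.0 * math.pi)


def Phi(x):
    """Standard normal cumulative distribution."""
    return 0.5 * (1.0 + math.erf(x / math.sqrt(2.0)))


theoretical_mean = (phi(lower) - phi(upper)) / (Phi(upper) - Phi(lower))
print(float(samples.mean()), theoretical_mean)
```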
Discrete Distributions#
Probability distributions for discrete random variables.
Sample Bernoulli random values with given shape and mean.
Draw samples from a binomial distribution.
Sample random values from categorical distributions.
Draw samples from the geometric distribution.
Draw samples from a Hypergeometric distribution.
Draw samples from a logarithmic series distribution.
Draw samples from a multinomial distribution.
Draw samples from a negative binomial distribution.
Draw samples from a Poisson distribution.
Draw samples from a Zipf distribution.
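In neural simulations, Bernoulli draws often model spike events and Poisson draws model event counts. The raw-JAX sketch below is an illustration, not the `bsr` functions. It samples both distributions, and the empirical means should approach p and lam:

```python
import jax

k1, k2 = jax.random.split(jax.random.PRNGKey(0))
n = 100_000

spikes = jax.random.bernoulli(k1, p=0.2, shape=(n,))  # boolean array, True with prob 0.2
counts = jax.random.poisson(k2, lam=3.0, shape=(n,))  # non-negative integer counts

print(spikes.dtype)  # bool
```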
Special Distributions#
Specialized distributions for statistical and scientific applications.
Draw samples from a chi-square distribution.
Draw samples from the Dirichlet distribution.
Draw samples from an F distribution.
Draw random samples from a multivariate normal distribution.
Draw samples from a noncentral chi-square distribution.
Draw samples from the noncentral F distribution.
Sample uniformly from the orthogonal group O(n).
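Two of these distributions have structural properties that are easy to verify. A sample from O(n) satisfies Q Qᵀ = I, and each Dirichlet draw lies on the probability simplex. The raw-JAX sketch below (an illustration, not the BrainState functions) uses `jax.random.orthogonal` and `jax.random.dirichlet`:

```python
import jax
import jax.numpy as jnp

k1, k2 = jax.random.split(jax.random.PRNGKey(0))

# A random 4x4 orthogonal matrix: Q @ Q.T is the identity up to float error.
q = jax.random.orthogonal(k1, 4)

# Dirichlet draws live on the simplex: each row is non-negative and sums to 1.
w = jax.random.dirichlet(k2, jnp.array([1.0, 2.0, 3.0]), (5,))
```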