cross_correlation
- braintools.metric.cross_correlation(spikes, bin, dt=None, method='loop')
Calculate the cross-correlation index between neurons.
The coherence between two neurons i and j is measured by the cross-correlation of their spike trains at zero time lag, evaluated over time bins. This function computes a population synchronization index from these pairwise cross-correlations.
The coherence measure for a pair is defined as:
\[\kappa_{ij}(\tau) = \frac{\sum_{l=1}^{K} X(l)\, Y(l)}{\sqrt{\sum_{l=1}^{K} X(l) \sum_{l=1}^{K} Y(l)}}\]
where the time interval is divided into \(K\) bins of width \(\Delta t = \tau\), and \(X(l)\), \(Y(l)\) are the binary spike indicators of neurons i and j in bin \(l\) (1 if the neuron fires at least once in that bin, 0 otherwise).
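As a concrete check of the formula, the following NumPy sketch evaluates \(\kappa_{ij}\) for two already-binned indicator vectors. The helper name `pair_coherence` is hypothetical and not part of braintools; it only mirrors the equation term by term.

```python
import numpy as np


def pair_coherence(x, y) -> float:
    """kappa_ij for two binary bin-indicator vectors X(l), Y(l), l = 1..K.

    Hypothetical helper that mirrors the formula; not braintools' implementation.
    """
    x = np.asarray(x, dtype=float)
    y = np.asarray(y, dtype=float)
    both = np.sum(x * y)                    # bins in which both neurons fire
    norm = np.sqrt(np.sum(x) * np.sum(y))   # sqrt((active bins of i) * (active bins of j))
    return float(both / norm)


X = [1, 0, 1, 1, 0]
Y = [1, 1, 1, 0, 0]
print(pair_coherence(X, Y))  # 2 / sqrt(3 * 3) = 0.666...
```

Because the numerator can never exceed either neuron's count of active bins, \(\kappa_{ij}\) lies between 0 (no shared active bins) and 1 (identical binned trains); it is undefined when a neuron has no active bins.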
The population coherence measure \(\kappa(\tau)\) is the average of \(\kappa_{ij}(\tau)\) over all pairs of neurons.
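A minimal sketch of the population average, assuming the spikes have already been binned into a (K, num_neurons) indicator matrix and that the average runs over unordered pairs i < j (since \(\kappa_{ij} = \kappa_{ji}\), averaging over ordered pairs i ≠ j gives the same value). `population_coherence` is a hypothetical name. The usage lines contrast a perfectly synchronous population with independent neurons.

```python
import itertools

import numpy as np


def population_coherence(binned) -> float:
    """Mean of kappa_ij over all pairs i < j of a (K, num_neurons) 0/1 matrix.

    Hypothetical sketch; assumes every neuron is active in at least one bin.
    """
    b = np.asarray(binned, dtype=float)
    kappas = []
    for i, j in itertools.combinations(range(b.shape[1]), 2):
        x, y = b[:, i], b[:, j]
        kappas.append(np.sum(x * y) / np.sqrt(np.sum(x) * np.sum(y)))
    return float(np.mean(kappas))


# Three neurons, each time step treated as its own bin (K = 3).
small = [[1, 0, 1],
         [0, 1, 1],
         [1, 1, 0]]
print(population_coherence(small))  # every pair shares 1 of its 2 active bins -> 0.5

rng = np.random.default_rng(0)
# Perfect synchrony: all 20 neurons copy the same binned train.
train = (rng.random(5000) < 0.3).astype(float)
print(population_coherence(np.tile(train[:, None], (1, 20))))  # 1.0
# Independent neurons, each active in a bin with probability 0.3.
print(population_coherence(rng.random((5000, 20)) < 0.3))  # close to 0.3, not 0
```

For independent neurons the index is roughly the per-bin probability of being active, so the chance level rises with the bin width; values computed with different `bin` settings are not directly comparable.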
- Parameters:
spikes (Array | ndarray | bool | number | int | float | complex | Quantity) – Spike history matrix with shape (num_time, num_neurons). Binary values (0/1 or boolean) indicating spike occurrences.
bin (int | float) – Width of each time bin, \(\tau\), in the same time units as dt.
dt (int | float, optional) – Time precision (simulation time step). If None, uses brainstate.environ.get_dt().
method (str) – Method for computing the pairwise cross-correlations:
'loop': memory-efficient iterative approach.
'vmap': vectorized approach (typically faster, but uses more memory).
- Returns:
Scalar cross-correlation index between 0 and 1 representing the population synchronization level. Values closer to 1 indicate stronger synchronization.
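The role of `bin`, `dt`, and `method` can be sketched in JAX as follows. This is an illustrative reimplementation under stated assumptions, not braintools' source: it assumes the bin width in time steps is `int(bin / dt)`, that trailing steps which do not fill a whole bin are dropped, and that a neuron counts as active in a bin if it spikes at least once there. The functions `bin_spikes`, `coherence_loop`, and `coherence_vmap` are hypothetical names.

```python
import itertools

import jax
import jax.numpy as jnp


def bin_spikes(spikes, bin, dt):
    """(num_time, num_neurons) spikes -> (K, num_neurons) float 0/1 bin indicators."""
    spikes = jnp.asarray(spikes)
    bin_size = int(bin / dt)                 # time steps per bin; needs concrete bin, dt
    num_bins = spikes.shape[0] // bin_size   # assumption: an incomplete last bin is dropped
    blocks = spikes[: num_bins * bin_size].reshape(num_bins, bin_size, spikes.shape[1])
    return jnp.any(blocks, axis=1).astype(jnp.float32)  # 1 if any spike in the bin


def pair_kappa(x, y):
    return jnp.sum(x * y) / jnp.sqrt(jnp.sum(x) * jnp.sum(y))


def coherence_loop(binned):
    """Visit pairs one at a time; only two columns are sliced per step."""
    n = binned.shape[1]
    kappas = [pair_kappa(binned[:, i], binned[:, j])
              for i, j in itertools.combinations(range(n), 2)]
    return jnp.mean(jnp.stack(kappas))


def coherence_vmap(binned):
    """Materialize two (num_pairs, K) arrays and evaluate all pairs in one batch."""
    i, j = jnp.triu_indices(binned.shape[1], k=1)
    return jnp.mean(jax.vmap(pair_kappa)(binned[:, i].T, binned[:, j].T))


spikes = jax.random.bernoulli(jax.random.PRNGKey(0), 0.1, (1000, 20))
binned = bin_spikes(spikes, bin=10.0, dt=1.0)  # 10 steps per bin -> K = 100
print(binned.shape)                            # (100, 20)
print(float(coherence_loop(binned)), float(coherence_vmap(binned)))
```

Both functions return the same value; they differ only in how many pairs are held in memory at once, which is the trade-off the `method` parameter describes.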
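Notes
To JIT compile this function, make bin, dt, and method static, since they determine array shapes and Python-level control flow. For example, bind them with functools.partial before calling jax.jit: partial(cross_correlation, bin=10, method='loop').
Examples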
>>> import jax
>>> import jax.numpy as jnp
>>> import braintools
>>> # Small spike matrix: 3 time steps x 3 neurons
>>> spikes = jnp.array([[1, 0, 1], [0, 1, 0], [1, 1, 0]])
>>> sync_index = braintools.metric.cross_correlation(spikes, bin=1.0, dt=1.0)
>>> print(f"Synchronization index: {float(sync_index):.3f}")
>>>
>>> # For larger datasets, use the vectorized method
>>> key = jax.random.PRNGKey(0)
>>> large_spikes = jax.random.bernoulli(key, 0.1, (1000, 50))
>>> sync_fast = braintools.metric.cross_correlation(large_spikes, bin=10.0, dt=1.0, method='vmap')
>>> print(f"Population synchronization: {float(sync_fast):.3f}")
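The JIT note can be illustrated with a small self-contained stand-in. `toy_cross_correlation` below is hypothetical and only mimics the documented signature: because `int(bin / dt)` fixes array shapes and `method` selects a Python branch, these arguments must be concrete Python values at trace time. Binding them with `functools.partial`, or listing them in `static_argnames`, both achieve that.

```python
from functools import partial

import jax
import jax.numpy as jnp


def toy_cross_correlation(spikes, bin, dt, method="loop"):
    """Hypothetical stand-in mimicking the documented signature (not braintools' code)."""
    bin_size = int(bin / dt)                 # would fail on a traced bin or dt
    num_bins = spikes.shape[0] // bin_size
    blocks = spikes[: num_bins * bin_size].reshape(num_bins, bin_size, -1)
    b = jnp.any(blocks, axis=1).astype(jnp.float32)
    i, j = jnp.triu_indices(b.shape[1], k=1)

    def pair(x, y):
        return jnp.sum(x * y) / jnp.sqrt(jnp.sum(x) * jnp.sum(y))

    if method == "vmap":                     # Python branch: method must be static
        kappas = jax.vmap(pair)(b[:, i].T, b[:, j].T)
    elif method == "loop":
        # lax.map evaluates one pair per iteration instead of batching them all.
        kappas = jax.lax.map(lambda ij: pair(b[:, ij[0]], b[:, ij[1]]),
                             jnp.stack([i, j], axis=1))
    else:
        raise ValueError(f"unknown method: {method!r}")
    return jnp.mean(kappas)


spikes = jax.random.bernoulli(jax.random.PRNGKey(1), 0.1, (1000, 20))

# Option 1: bind the static arguments with partial, then jit the result.
jit_loop = jax.jit(partial(toy_cross_correlation, bin=10.0, dt=1.0, method="loop"))

# Option 2: declare them static by name.
jit_any = jax.jit(toy_cross_correlation, static_argnames=("bin", "dt", "method"))

print(float(jit_loop(spikes)), float(jit_any(spikes, bin=10.0, dt=1.0, method="vmap")))
```

Passing `method` through `jax.jit` without marking it static would raise an error, because a string is not a valid JAX array type.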