cross_correlation

class braintools.metric.cross_correlation(spikes, bin, dt=None, method='loop')

Calculate cross-correlation index between neurons.

The coherence between two neurons i and j is measured by the zero-lag cross-correlation of their binned spike trains. This function averages these pairwise values into a population synchronization index.

The coherence measure for a pair is defined as:

\[\kappa_{ij}(\tau) = \frac{\sum_{l=1}^{K} X(l) Y(l)} {\sqrt{\sum_{l=1}^{K} X(l) \sum_{l=1}^{K} Y(l)}}\]

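where the recording interval is divided into K bins of size \(\Delta t = \tau\), and \(X(l)\), \(Y(l)\) are binary spike indicators for neurons i and j: \(X(l) = 1\) if neuron i fires at least once in bin l and 0 otherwise, and likewise \(Y(l)\) for neuron j.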
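The binning step and the pairwise formula can be sketched in NumPy as follows. This is a minimal reference sketch, not the library's implementation: the helper names `bin_indicators` and `pair_coherence` are hypothetical, `steps_per_bin` stands for bin / dt, and two behaviours are assumptions (a trailing partial bin is zero-padded, and a pair in which either neuron never fires scores 0 instead of dividing by zero).

```python
import numpy as np


def bin_indicators(spikes, steps_per_bin):
    """Turn a (num_time, num_neurons) spike matrix into 0/1 bin indicators.

    A bin is 1 if the neuron fired at least once inside it.
    Assumption: a trailing partial bin is zero-padded rather than dropped.
    """
    spikes = np.asarray(spikes, dtype=float)
    num_time, num_neurons = spikes.shape
    num_bins = int(np.ceil(num_time / steps_per_bin))
    padded = np.zeros((num_bins * steps_per_bin, num_neurons))
    padded[:num_time] = spikes
    # Split the time axis into (num_bins, steps_per_bin) and count spikes per bin.
    counts = padded.reshape(num_bins, steps_per_bin, num_neurons).sum(axis=1)
    return (counts > 0).astype(float)  # shape (K, num_neurons)


def pair_coherence(x, y):
    """kappa_ij for two binary indicator vectors X(l), Y(l), l = 1..K."""
    x = np.asarray(x, dtype=float)
    y = np.asarray(y, dtype=float)
    denom = np.sqrt(x.sum() * y.sum())
    # Assumption: a pair with a silent neuron (zero denominator) scores 0.
    return 0.0 if denom == 0 else float((x * y).sum() / denom)
```

For example, with `steps_per_bin=2` the matrix `[[1, 1], [0, 0], [1, 0], [0, 0], [0, 1], [1, 1]]` bins to X = (1, 1, 1) and Y = (1, 0, 1), so the pair scores 2 / √6 ≈ 0.816.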

The population coherence measure \(\kappa(\tau)\) is the average of \(\kappa_{ij}(\tau)\) over all distinct pairs of neurons (\(i \neq j\)).
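Averaging over pairs does not require an explicit double loop: the matrix of coincidence counts \(\sum_l X_i(l) X_j(l)\) is the Gram matrix of the binned indicators. The sketch below assumes the spikes have already been binned into a (K, num_neurons) 0/1 matrix; `population_coherence` is a hypothetical name, silent pairs contribute 0 by assumption, and at least two neurons are required.

```python
import numpy as np


def population_coherence(states):
    """Average kappa_ij over all unordered neuron pairs i < j.

    `states` is a (K, num_neurons) binary matrix of bin indicators.
    """
    states = np.asarray(states, dtype=float)
    num_neurons = states.shape[1]
    coincidences = states.T @ states   # entry (i, j) = sum_l X_i(l) X_j(l)
    counts = states.sum(axis=0)        # sum_l X_i(l) for each neuron
    denom = np.sqrt(np.outer(counts, counts))
    # Assumption: pairs involving a silent neuron keep the value 0.
    kappa = np.divide(coincidences, denom, out=np.zeros_like(denom), where=denom > 0)
    i, j = np.triu_indices(num_neurons, k=1)  # each pair once, no self-pairs
    return float(kappa[i, j].mean())
```

Only the upper triangle is averaged, because \(\kappa_{ij} = \kappa_{ji}\) and the diagonal (a neuron paired with itself) is always 1 for an active neuron.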

Parameters:
  • spikes (Array | ndarray | bool | number | int | float | complex | Quantity) – Spike history matrix with shape (num_time, num_neurons), with binary values indicating spike occurrences (1 for a spike, 0 otherwise).

  • bin (int | float) – Width of each time bin, in the same time units as dt, so each bin covers bin / dt time steps.

  • dt (int | float, optional) – Time step of the spike matrix, i.e. the duration of one row. If None, brainstate.environ.get_dt() is used.

  • method (str) –

    Method for computing cross-correlations:

    • 'loop': Memory-efficient iterative approach

    • 'vmap': Vectorized approach (uses more memory)
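The memory trade-off between the two methods can be illustrated with two NumPy sketches that compute the same average. They are illustrative only: the function names are hypothetical, and they do not show how braintools implements 'loop' and 'vmap' internally. The loop holds one pair of columns at a time, while the vectorized version gathers the columns for all N(N-1)/2 pairs at once, allocating arrays of shape (K, num_pairs). Both assume already-binned 0/1 input and score silent pairs as 0.

```python
import numpy as np


def coherence_loop(states):
    """Visit one pair at a time: extra memory stays O(K)."""
    states = np.asarray(states, dtype=float)
    n = states.shape[1]
    total, num_pairs = 0.0, 0
    for i in range(n):
        for j in range(i + 1, n):
            x, y = states[:, i], states[:, j]
            denom = np.sqrt(x.sum() * y.sum())
            total += 0.0 if denom == 0 else (x * y).sum() / denom
            num_pairs += 1
    return float(total / num_pairs)


def coherence_vectorized(states):
    """Gather every pair at once: fewer Python steps, O(K * num_pairs) memory."""
    states = np.asarray(states, dtype=float)
    i, j = np.triu_indices(states.shape[1], k=1)
    x, y = states[:, i], states[:, j]  # each of shape (K, num_pairs)
    num = (x * y).sum(axis=0)
    denom = np.sqrt(x.sum(axis=0) * y.sum(axis=0))
    kappa = np.divide(num, denom, out=np.zeros_like(denom), where=denom > 0)
    return float(kappa.mean())
```

On the same input the two functions agree to floating-point precision; the choice only changes how much intermediate memory is allocated.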

Returns:

Cross-correlation index representing the population synchronization level. Each pairwise value, and therefore their average, lies between 0 and 1; values closer to 1 indicate higher synchronization.

Return type:

float
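The bound on each pairwise value follows from the Cauchy-Schwarz inequality. Because \(X(l)\) and \(Y(l)\) are 0 or 1, \(X(l)^2 = X(l)\) and \(Y(l)^2 = Y(l)\), so

\[\sum_{l=1}^{K} X(l) Y(l) \le \sqrt{\sum_{l=1}^{K} X(l)^2 \sum_{l=1}^{K} Y(l)^2} = \sqrt{\sum_{l=1}^{K} X(l) \sum_{l=1}^{K} Y(l)}\]

Hence \(0 \le \kappa_{ij}(\tau) \le 1\) whenever both neurons fire at least once, with \(\kappa_{ij}(\tau) = 1\) exactly when the two binned trains are identical, and the population average \(\kappa(\tau)\) lies in the same range.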

Notes

To JIT-compile this function, bin, dt, and method must be static. For example, bind them with functools.partial and wrap the result with jax.jit: jax.jit(partial(cross_correlation, bin=10, method='loop')).
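A plausible reason these arguments must be static is that the bin size fixes array shapes (the number of bins), and JAX needs shapes to be concrete at trace time. The sketch below illustrates the pattern on a hypothetical helper, `binned_indicators`, not on braintools itself; zero-padding the trailing partial bin is an assumption. Both `functools.partial` and `jax.jit(..., static_argnames=...)` keep `steps_per_bin` a Python constant, so only `spikes` is traced.

```python
from functools import partial

import jax
import jax.numpy as jnp


def binned_indicators(spikes, steps_per_bin):
    # steps_per_bin determines array shapes below, so it must be a Python int.
    num_time, num_neurons = spikes.shape            # shapes are static under jit
    num_bins = -(-num_time // steps_per_bin)        # ceiling division on Python ints
    pad = num_bins * steps_per_bin - num_time
    padded = jnp.pad(spikes, ((0, pad), (0, 0)))    # zero-pad the last partial bin
    counts = padded.reshape(num_bins, steps_per_bin, num_neurons).sum(axis=1)
    return (counts > 0).astype(jnp.float32)


# Option 1: bind the shape-determining argument before jitting.
jitted = jax.jit(partial(binned_indicators, steps_per_bin=2))

# Option 2: declare the argument static explicitly.
jitted_static = jax.jit(binned_indicators, static_argnames="steps_per_bin")

spikes = jnp.array([[1, 0], [0, 0], [0, 1], [1, 1], [0, 1]])
print(jitted(spikes))
print(jitted_static(spikes, steps_per_bin=2))
```

Both calls return the indicators `[[1, 0], [1, 1], [0, 1]]`. With option 2, each new value of `steps_per_bin` triggers a fresh compilation, which is the usual cost of static arguments.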

Examples

>>> import jax
>>> import jax.numpy as jnp
>>> import braintools
>>> # Small hand-written spike matrix: 3 time steps x 3 neurons
>>> spikes = jnp.array([[1, 0, 1], [0, 1, 0], [1, 1, 0]])
>>> # Pass dt explicitly so the result does not depend on the global environment
>>> sync_index = braintools.metric.cross_correlation(spikes, bin=1.0, dt=1.0)
>>> print(f"Synchronization index: {sync_index:.3f}")
>>>
>>> # The vectorized method trades memory for speed
>>> key = jax.random.PRNGKey(0)
>>> large_spikes = jax.random.bernoulli(key, 0.1, (1000, 50)).astype(jnp.float32)
>>> sync_fast = braintools.metric.cross_correlation(large_spikes, bin=10.0, dt=1.0, method='vmap')
>>> print(f"Population synchronization: {sync_fast:.3f}")

References