braintools.metric module

Metrics and Loss Functions for Neural Networks and Neuroscience.

This module provides comprehensive metrics and loss functions for both machine learning and neuroscience applications, including classification, regression, ranking, spike train analysis, and local field potential (LFP) analysis.

Key Features:

  • Classification Losses: Binary and multi-class cross-entropy, hinge loss, focal loss

  • Regression Losses: MSE, MAE, Huber loss, cosine similarity

  • Ranking Losses: Softmax ranking loss for learning to rank

  • Spike Train Metrics: Firing rate, synchrony, distance measures

  • LFP Analysis: Power spectral density, coherence, phase-amplitude coupling

  • Correlation Analysis: Cross-correlation, functional connectivity

  • Pairwise Metrics: Cosine similarity for pairwise comparisons

Quick Start - Classification:

import jax.numpy as jnp
from braintools.metric import softmax_cross_entropy, sigmoid_focal_loss

# Multi-class classification
logits = jnp.array([[2.0, 1.0, 0.1], [0.5, 2.5, 0.3]])
labels = jnp.array([[1.0, 0.0, 0.0], [0.0, 1.0, 0.0]])
loss = softmax_cross_entropy(logits, labels)

# Binary classification with focal loss (inputs are logits:
# roughly the inverse sigmoid of probabilities 0.9, 0.1, 0.7)
logits = jnp.array([2.2, -2.2, 0.85])
targets = jnp.array([1.0, 0.0, 1.0])
focal_loss = sigmoid_focal_loss(logits, targets)

Quick Start - Regression:

import jax.numpy as jnp
from braintools.metric import squared_error, huber_loss

predictions = jnp.array([1.5, 2.3, 3.1])
targets = jnp.array([1.0, 2.0, 3.0])

# Mean squared error (squared_error is element-wise, so average it)
mse = squared_error(predictions, targets).mean()

# Huber loss (robust to outliers)
huber = huber_loss(predictions, targets, delta=1.0)

Quick Start - Spike Train Analysis:

import brainunit as u
import jax.numpy as jnp
from braintools.metric import (
    firing_rate, victor_purpura_distance,
    spike_train_synchrony
)

# Smoothed firing rate from a binary spike array sampled every dt
spikes = jnp.array([1, 0, 1, 1, 0, 0, 1, 0])
rate = firing_rate(spikes, window_size=100 * u.ms, dt=10 * u.ms)

# Victor-Purpura distance between spike trains
train1 = jnp.array([0.1, 0.3, 0.5]) * u.second
train2 = jnp.array([0.12, 0.31, 0.52]) * u.second
distance = victor_purpura_distance(train1, train2, cost=1.0)

# Spike train synchrony
spike_matrix = jnp.array([[1, 0, 1, 0], [0, 1, 1, 0], [1, 1, 0, 0]])
synchrony = spike_train_synchrony(spike_matrix)

Classification Losses:

import jax.numpy as jnp
from braintools.metric import (
    sigmoid_binary_cross_entropy,
    softmax_cross_entropy_with_integer_labels,
    hinge_loss,
    multiclass_hinge_loss,
    kl_divergence,
    sigmoid_focal_loss
)

# Binary cross-entropy
logits = jnp.array([2.0, -1.0, 0.5])
labels = jnp.array([1.0, 0.0, 1.0])
bce = sigmoid_binary_cross_entropy(logits, labels)

# Multi-class with integer labels
logits = jnp.array([[2.0, 1.0, 0.1], [0.5, 2.5, 0.3]])
labels = jnp.array([0, 1])  # Class indices
ce = softmax_cross_entropy_with_integer_labels(logits, labels)

# Hinge loss for SVM-style classification
predictions = jnp.array([0.9, -0.5, 0.3])
targets = jnp.array([1.0, -1.0, 1.0])
hinge = hinge_loss(predictions, targets)

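# KL divergence KL(p || q); predictions are passed as log-probabilities
p = jnp.array([0.5, 0.3, 0.2])  # target distribution
q = jnp.array([0.4, 0.4, 0.2])  # predicted distribution
kl = kl_divergence(jnp.log(q), p)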

# Focal loss for imbalanced datasets (inputs are logits, not probabilities)
logits = jnp.array([2.2, -2.2, 0.4])
targets = jnp.array([1.0, 0.0, 1.0])
focal = sigmoid_focal_loss(logits, targets, alpha=0.25, gamma=2.0)

Regression Losses:

import jax.numpy as jnp
from braintools.metric import (
    squared_error,
    absolute_error,
    l1_loss,
    l2_loss,
    huber_loss,
    log_cosh,
    cosine_similarity,
    cosine_distance
)

predictions = jnp.array([1.5, 2.3, 3.1, 4.2])
targets = jnp.array([1.0, 2.0, 3.0, 4.0])

# Various regression losses (squared_error and absolute_error are element-wise)
mse = squared_error(predictions, targets).mean()
mae = absolute_error(predictions, targets).mean()
l1 = l1_loss(predictions, targets)
l2 = l2_loss(predictions, targets)
huber = huber_loss(predictions, targets, delta=1.0)
log_cosh_loss = log_cosh(predictions, targets)

# Cosine similarity/distance
x = jnp.array([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])
y = jnp.array([[1.0, 2.0, 3.0], [7.0, 8.0, 9.0]])
similarity = cosine_similarity(x, y)
distance = cosine_distance(x, y)

Spike Train Analysis:

import brainunit as u
import jax.numpy as jnp
from braintools.metric import (
    raster_plot,
    firing_rate,
    victor_purpura_distance,
    van_rossum_distance,
    spike_train_synchrony,
    burst_synchrony_index,
    phase_locking_value,
    spike_time_tiling_coefficient,
    correlation_index
)

# Raster plot data extraction
spike_matrix = jnp.array([[1, 0, 1, 0], [0, 1, 1, 0]])
times = jnp.arange(4) * 0.1 * u.second
neuron_ids, spike_times = raster_plot(spike_matrix, times)

# Firing rate calculation
spikes = jnp.array([1, 0, 1, 1, 0, 0, 1, 0])
rate = firing_rate(spikes, window_size=100 * u.ms, dt=10 * u.ms)

# Distance metrics between spike trains
train1 = jnp.array([0.1, 0.3, 0.5]) * u.second
train2 = jnp.array([0.12, 0.31, 0.52]) * u.second
vp_dist = victor_purpura_distance(train1, train2, cost=1.0)
vr_dist = van_rossum_distance(train1, train2, tau=0.01 * u.second)

# Synchrony measures
spike_matrix = jnp.array([[1, 0, 1, 0], [0, 1, 1, 0], [1, 1, 0, 0]])
synchrony = spike_train_synchrony(spike_matrix)
burst_sync = burst_synchrony_index(spike_matrix)
plv = phase_locking_value(spike_matrix)
sttc = spike_time_tiling_coefficient(train1, train2, dt=0.05 * u.second)
corr_idx = correlation_index(spike_matrix)

LFP Analysis:

import brainunit as u
import jax
import jax.numpy as jnp
from braintools.metric import (
    unitary_LFP,
    power_spectral_density,
    coherence_analysis,
    phase_amplitude_coupling,
    theta_gamma_coupling,
    current_source_density,
    spectral_entropy,
    lfp_phase_coherence
)

# Synthetic data is drawn with jax.random (jax.numpy has no random submodule)
key_spikes, key_lfp = jax.random.split(jax.random.PRNGKey(0))

# Unitary LFP from spike trains
times = jnp.arange(1000) * 0.001 * u.second
spikes = jax.random.randint(key_spikes, (100, 1000), 0, 2)
ulfp = unitary_LFP(times, spikes, spike_type='excitatory')

# Power spectral density
lfp_signal = jnp.sin(2 * jnp.pi * 10 * times.mantissa)  # 10 Hz signal
freqs, psd = power_spectral_density(lfp_signal, fs=1000 * u.Hz)

# Coherence analysis
signal1 = jnp.sin(2 * jnp.pi * 10 * times.mantissa)
signal2 = jnp.sin(2 * jnp.pi * 10 * times.mantissa + 0.1)
freqs, coherence = coherence_analysis(signal1, signal2, fs=1000 * u.Hz)

# Phase-amplitude coupling
pac = phase_amplitude_coupling(
    lfp_signal,
    phase_freq=(4, 8),  # Theta band
    amp_freq=(30, 80),  # Gamma band
    fs=1000 * u.Hz
)

# Theta-gamma coupling
tgc = theta_gamma_coupling(lfp_signal, fs=1000 * u.Hz)

# Current source density
lfp_channels = jax.random.normal(key_lfp, (16, 1000))  # 16 channels
csd = current_source_density(lfp_channels, spacing=100 * u.um)

# Spectral entropy
entropy = spectral_entropy(lfp_signal, fs=1000 * u.Hz)

# Phase coherence
phase_coh = lfp_phase_coherence(signal1, signal2, freq_band=(8, 12))

Correlation Analysis:

import jax
import jax.numpy as jnp
from braintools.metric import (
    cross_correlation,
    voltage_fluctuation,
    matrix_correlation,
    weighted_correlation,
    functional_connectivity,
    functional_connectivity_dynamics
)

# Synthetic data is drawn with jax.random (jax.numpy has no random submodule)
keys = jax.random.split(jax.random.PRNGKey(0), 4)

# Cross-correlation between spike trains
spikes = jnp.array([[1, 0, 1, 0], [0, 1, 1, 0], [1, 1, 0, 0]])
cc = cross_correlation(spikes, bin=10, dt=1)

# Voltage fluctuation correlation
voltages = jax.random.normal(keys[0], (100, 1000))  # 100 neurons, 1000 time points
vf_corr = voltage_fluctuation(voltages)

# Correlation between the upper triangles of two square matrices
# (e.g. empirical vs. simulated connectivity)
mat_a = jnp.corrcoef(jax.random.normal(keys[1], (10, 200)))
mat_b = jnp.corrcoef(jax.random.normal(keys[2], (10, 200)))
mat_corr = matrix_correlation(mat_a, mat_b)

# Weighted correlation
x = jnp.array([1, 2, 3, 4, 5])
y = jnp.array([2, 4, 5, 4, 5])
weights = jnp.array([1, 1, 2, 2, 1])
w_corr = weighted_correlation(x, y, weights)

# Functional connectivity
time_series = jax.random.normal(keys[3], (10, 1000))  # 10 regions, 1000 time points
fc = functional_connectivity(time_series, method='pearson')

# Dynamic functional connectivity
fc_dynamics = functional_connectivity_dynamics(
    time_series,
    window_size=100,
    step_size=50
)

Advanced: Fenchel-Young Losses:

import jax.numpy as jnp
from braintools.metric import make_fenchel_young_loss

# Create a custom loss from a max function; max_fun must reduce
# the last axis of the scores to a scalar
def max_fun(scores):
    return jnp.max(scores, axis=-1)

loss_fn = make_fenchel_young_loss(max_fun)
scores = jnp.array([2.0, 1.0, 3.0])
targets = jnp.array([1.0, 0.0, 0.0])
loss = loss_fn(scores, targets)

Ranking Losses:

import jax.numpy as jnp
from braintools.metric import ranking_softmax_loss

# Ranking loss for learning to rank
scores = jnp.array([[2.0, 1.0, 3.0], [1.0, 0.5, 1.5]])
labels = jnp.array([[1.0, 0.0, 0.0], [0.0, 0.0, 1.0]])
loss = ranking_softmax_loss(scores, labels)

Utilities:

import jax.numpy as jnp
from braintools.metric import smooth_labels

# Label smoothing for regularization
labels = jnp.array([[1.0, 0.0, 0.0], [0.0, 1.0, 0.0]])
smoothed = smooth_labels(labels, alpha=0.1)

Comprehensive metric collection covering spiking activity, statistical analysis, and supervised learning objectives for neural modeling.

Classification Losses

Objective functions for training classifiers on neural or behavioral labels.

sigmoid_binary_cross_entropy

Compute element-wise sigmoid cross entropy given logits and labels.
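The element-wise loss can be computed from the logits without evaluating the sigmoid directly. The plain NumPy sketch below uses the standard numerically stable rearrangement \(\max(x, 0) - xy + \log(1 + e^{-|x|})\). The name `sigmoid_bce` is hypothetical and is not part of braintools.

```python
import numpy as np


def sigmoid_bce(logits, labels):
    """Element-wise sigmoid binary cross-entropy computed from logits (sketch)."""
    logits = np.asarray(logits, dtype=float)
    labels = np.asarray(labels, dtype=float)
    # Equivalent to -[y*log(sigmoid(x)) + (1 - y)*log(1 - sigmoid(x))],
    # rearranged so that exp() never overflows for large |x|.
    return np.maximum(logits, 0.0) - logits * labels + np.log1p(np.exp(-np.abs(logits)))


logits = np.array([2.0, -1.0, 0.5])
labels = np.array([1.0, 0.0, 1.0])
loss = sigmoid_bce(logits, labels)
```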

hinge_loss

Compute the hinge loss for binary classification.
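The hinge loss assumes targets in {-1, +1}. It penalizes any output whose margin \(y f(x)\) falls below 1. Below is a minimal NumPy sketch using the hypothetical name `hinge` (this is not the braintools function), evaluated on the values from the classification example above.

```python
import numpy as np


def hinge(outputs, targets):
    """Element-wise hinge loss; targets must be -1 or +1 (sketch)."""
    outputs = np.asarray(outputs, dtype=float)
    targets = np.asarray(targets, dtype=float)
    # Zero loss once the margin y * f(x) reaches 1, linear penalty below that.
    return np.maximum(0.0, 1.0 - outputs * targets)


loss = hinge([0.9, -0.5, 0.3], [1.0, -1.0, 1.0])  # ≈ [0.1, 0.5, 0.7]
```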

perceptron_loss

Compute the binary perceptron loss.

softmax_cross_entropy

Compute the softmax cross entropy between logits and labels.

softmax_cross_entropy_with_integer_labels

Compute softmax cross entropy between logits and integer labels.

multiclass_hinge_loss

Compute multiclass hinge loss for classification.

multiclass_perceptron_loss

Compute multiclass perceptron loss for classification.

poly_loss_cross_entropy

Compute PolyLoss cross entropy between logits and labels.

kl_divergence

Compute the Kullback-Leibler divergence (relative entropy) loss.
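KL divergence measures how far a predicted distribution \(q\) diverges from a target distribution \(p\): \(\mathrm{KL}(p \| q) = \sum_i p_i \log(p_i / q_i)\). The NumPy sketch below takes plain probabilities for both arguments. The braintools function may instead expect log-probabilities for the predictions, as the separate `kl_divergence_with_log_targets` entry suggests, so check its docstring for the argument order and convention.

```python
import numpy as np


def kl(p, q):
    """KL(p || q) for two discrete probability vectors (sketch)."""
    p = np.asarray(p, dtype=float)
    q = np.asarray(q, dtype=float)
    # Terms with p_i == 0 contribute nothing (0 * log 0 is taken as 0).
    mask = p > 0
    return float(np.sum(p[mask] * np.log(p[mask] / q[mask])))


p = np.array([0.5, 0.3, 0.2])  # target distribution
q = np.array([0.4, 0.4, 0.2])  # predicted distribution
d_pq = kl(p, q)
d_qp = kl(q, p)  # KL is not symmetric, so this differs from d_pq
```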

kl_divergence_with_log_targets

Compute KL divergence when both predictions and targets are in log-space.

convex_kl_divergence

Compute a convex version of the Kullback-Leibler divergence loss.

ctc_loss

Compute Connectionist Temporal Classification (CTC) loss.

ctc_loss_with_forward_probs

Compute CTC loss and forward probabilities for sequence alignment.

sigmoid_focal_loss

Compute sigmoid focal loss for addressing class imbalance.
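Focal loss (Lin et al., 2017) scales binary cross-entropy by \((1 - p_t)^\gamma\), where \(p_t\) is the predicted probability of the true class. This makes confidently classified examples contribute less. The NumPy sketch below assumes logits as input and an optional class weight `alpha`. Setting `gamma=0` and `alpha=None` recovers plain binary cross-entropy. It is an illustration, not the braintools code.

```python
import numpy as np


def focal(logits, labels, alpha=0.25, gamma=2.0):
    """Element-wise sigmoid focal loss computed from logits (sketch)."""
    logits = np.asarray(logits, dtype=float)
    labels = np.asarray(labels, dtype=float)
    p = 1.0 / (1.0 + np.exp(-logits))
    # Stable binary cross-entropy on logits.
    ce = np.maximum(logits, 0.0) - logits * labels + np.log1p(np.exp(-np.abs(logits)))
    # p_t is the predicted probability of the true class.
    p_t = p * labels + (1.0 - p) * (1.0 - labels)
    loss = ce * (1.0 - p_t) ** gamma
    if alpha is not None:
        # alpha weights the positive class, 1 - alpha the negative class.
        alpha_t = alpha * labels + (1.0 - alpha) * (1.0 - labels)
        loss = alpha_t * loss
    return loss


logits = np.array([2.2, -2.2, 0.4])
labels = np.array([1.0, 0.0, 1.0])
loss = focal(logits, labels)
```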

nll_loss

Compute negative log likelihood loss for classification.

Correlation

Tools for measuring synchrony, functional connectivity, and aggregated correlations between neural signals.

cross_correlation

Calculate cross-correlation index between neurons.

voltage_fluctuation

Calculate neuronal synchronization via voltage variance analysis.
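One standard voltage-based synchrony measure (Golomb and Rinzel) compares the variance of the population-averaged membrane potential with the average variance of individual neurons: \(\chi^2 = \mathrm{Var}_t[\bar V(t)] / \langle \mathrm{Var}_t[V_i(t)] \rangle_i\). It is near 1 for fully synchronized traces and near 0 for asynchronous ones. The NumPy sketch assumes an input of shape `(num_time, num_neuron)`. Whether `voltage_fluctuation` uses exactly this normalization, or returns \(\chi\) rather than \(\chi^2\), is an assumption to verify.

```python
import numpy as np


def synchrony_chi2(voltages):
    """Golomb-Rinzel synchrony measure chi^2 (sketch).

    voltages: array of shape (num_time, num_neuron), an assumed layout.
    """
    v = np.asarray(voltages, dtype=float)
    # Variance over time of the population-averaged potential.
    population_var = np.var(v.mean(axis=1))
    # Variance over time of each neuron, averaged over neurons.
    individual_var = np.var(v, axis=0).mean()
    return population_var / individual_var


t = np.linspace(0.0, 1.0, 1000)
trace = np.sin(2 * np.pi * 5 * t)
in_phase = np.stack([trace, trace, trace], axis=1)  # identical neurons
anti_phase = np.stack([trace, -trace], axis=1)      # population mean is flat
```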

matrix_correlation

Compute Pearson correlation of upper triangular elements of two matrices.
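A correlation or connectivity matrix is symmetric, so comparing two of them usually means correlating only the entries above the diagonal. The NumPy sketch below assumes the diagonal is excluded (`k=1`). The name `matrix_corr` is hypothetical.

```python
import numpy as np


def matrix_corr(a, b):
    """Pearson correlation of the strictly upper-triangular entries (sketch)."""
    a = np.asarray(a, dtype=float)
    b = np.asarray(b, dtype=float)
    iu = np.triu_indices_from(a, k=1)  # skip the diagonal (all ones for FC matrices)
    return float(np.corrcoef(a[iu], b[iu])[0, 1])


fc_a = np.array([[1.0, 0.2, 0.5],
                 [0.2, 1.0, 0.1],
                 [0.5, 0.1, 1.0]])
fc_b = 1.0 - fc_a + np.eye(3)  # off-diagonal entries become 1 - a_ij
```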

weighted_correlation

Compute weighted Pearson correlation between two data series.
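Weighted Pearson correlation replaces every mean and (co)variance with its weighted counterpart, for example \(m_x = \sum_i w_i x_i / \sum_i w_i\). Below is a NumPy sketch with the hypothetical name `weighted_corr`, using the values from the correlation example above.

```python
import numpy as np


def weighted_corr(x, y, w):
    """Weighted Pearson correlation of two 1-D series (sketch)."""
    x, y, w = (np.asarray(v, dtype=float) for v in (x, y, w))
    w = w / w.sum()
    mx, my = np.sum(w * x), np.sum(w * y)      # weighted means
    cov_xy = np.sum(w * (x - mx) * (y - my))   # weighted covariance
    var_x = np.sum(w * (x - mx) ** 2)
    var_y = np.sum(w * (y - my) ** 2)
    return float(cov_xy / np.sqrt(var_x * var_y))


x = np.array([1, 2, 3, 4, 5])
y = np.array([2, 4, 5, 4, 5])
weights = np.array([1, 1, 2, 2, 1])
r_w = weighted_corr(x, y, weights)
```

With integer weights, the result equals the ordinary Pearson correlation of the data with each sample repeated `w_i` times. This makes a convenient sanity check.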

functional_connectivity

Compute functional connectivity matrix from time series data.

functional_connectivity_dynamics

Compute functional connectivity dynamics (FCD) matrix.
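Functional connectivity dynamics slides a window along the time series and computes a functional connectivity matrix in each window. It then correlates the upper-triangular parts of those matrices across windows. The NumPy sketch assumes an input of shape `(num_time, num_region)` and Pearson FC in each window. The axis order and FC method that braintools expects may differ.

```python
import numpy as np


def fc_dynamics(time_series, window_size, step_size):
    """Functional connectivity dynamics matrix (sketch).

    time_series: array of shape (num_time, num_region), an assumed layout.
    Returns a (num_windows, num_windows) matrix.
    """
    ts = np.asarray(time_series, dtype=float)
    num_time, num_region = ts.shape
    iu = np.triu_indices(num_region, k=1)
    fc_vectors = []
    for start in range(0, num_time - window_size + 1, step_size):
        window = ts[start:start + window_size]
        fc = np.corrcoef(window.T)        # Pearson FC within this window
        fc_vectors.append(fc[iu])         # keep the unique region pairs
    fc_vectors = np.array(fc_vectors)     # (num_windows, num_pairs)
    # Correlate every pair of windows' FC patterns.
    return np.corrcoef(fc_vectors)


rng = np.random.default_rng(0)
ts = rng.standard_normal((1000, 10))  # 1000 time points, 10 regions
fcd = fc_dynamics(ts, window_size=100, step_size=50)
```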

Fenchel-Young Loss

Generalized convex losses derived from Fenchel-Young duality for structured prediction problems.

make_fenchel_young_loss

Create a Fenchel-Young loss function from a max function.
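A Fenchel-Young loss built from a max function has the form \(\ell(\theta; y) = \mathrm{max\_fun}(\theta) - \langle y, \theta \rangle\). With `logsumexp` and a one-hot target, this is the softmax cross-entropy. With a hard `max`, it is the perceptron loss. The NumPy sketch below illustrates the construction for a single score vector. The helper names are hypothetical, and the sketch assumes the braintools factory follows the same formula as Optax's `make_fenchel_young_loss`.

```python
import numpy as np


def logsumexp(x):
    """Numerically stable log(sum(exp(x))) of a 1-D array."""
    m = np.max(x)
    return m + np.log(np.sum(np.exp(x - m)))


def make_fy_loss(max_fun):
    """Return loss(scores, targets) = max_fun(scores) - <targets, scores> (sketch)."""
    def loss(scores, targets):
        scores = np.asarray(scores, dtype=float)
        targets = np.asarray(targets, dtype=float)
        return max_fun(scores) - np.dot(targets, scores)
    return loss


scores = np.array([2.0, 1.0, 3.0])
targets = np.array([1.0, 0.0, 0.0])  # one-hot: class 0 is correct

perceptron = make_fy_loss(np.max)(scores, targets)     # 3 - 2 = 1
softmax_ce = make_fy_loss(logsumexp)(scores, targets)  # -log softmax(scores)[0]
```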

Spike Firing

Spike-train measures: firing-rate estimation, spike-train distances, and synchrony indices.

raster_plot

Extract spike times and neuron indices for raster plot visualization.

firing_rate

Calculate the smoothed population firing rate from spike data.

victor_purpura_distance

Calculate Victor-Purpura distance between two spike trains.
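The Victor-Purpura distance is the minimum cost of turning one spike train into another. Inserting or deleting a spike costs 1, and shifting a spike by \(\Delta t\) costs \(q|\Delta t|\). It can be computed with an edit-distance style dynamic program. The sketch uses plain floats in seconds instead of `brainunit` quantities, with `cost` playing the role of \(q\) (in 1/second). It is not the braintools implementation.

```python
import numpy as np


def victor_purpura(train1, train2, cost):
    """Victor-Purpura distance between two sorted spike-time arrays (sketch)."""
    t1 = np.asarray(train1, dtype=float)
    t2 = np.asarray(train2, dtype=float)
    n, m = len(t1), len(t2)
    # d[i, j] = distance between the first i spikes of t1 and first j spikes of t2.
    d = np.zeros((n + 1, m + 1))
    d[:, 0] = np.arange(n + 1)  # delete all i spikes
    d[0, :] = np.arange(m + 1)  # insert all j spikes
    for i in range(1, n + 1):
        for j in range(1, m + 1):
            d[i, j] = min(
                d[i - 1, j] + 1.0,                                    # delete a spike
                d[i, j - 1] + 1.0,                                    # insert a spike
                d[i - 1, j - 1] + cost * abs(t1[i - 1] - t2[j - 1]),  # shift a spike
            )
    return float(d[n, m])


train1 = [0.1, 0.3, 0.5]     # seconds
train2 = [0.12, 0.31, 0.52]  # seconds
dist = victor_purpura(train1, train2, cost=1.0)
```

With the trains from the spike-train example above, shifting each spike is cheaper than deleting and reinserting it. The distance is therefore the summed shift cost, 0.02 + 0.01 + 0.02 = 0.05. With a very large `cost`, every spike is instead deleted and reinserted, giving 6.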

van_rossum_distance

Calculate van Rossum distance between two spike trains.
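The van Rossum distance convolves each spike train with a causal exponential kernel \(e^{-t/\tau}\). It then measures the \(L^2\) distance between the resulting traces, \(D^2 = \frac{1}{\tau}\int (f_1 - f_2)^2\,dt\). The NumPy sketch discretizes the integral on a grid of step `dt`. The \(1/\tau\) normalization is one common convention (it gives \(D = 1/\sqrt{2}\) for a single unmatched spike), and braintools may use another.

```python
import numpy as np


def van_rossum(train1, train2, tau, dt=1e-4, t_end=None):
    """van Rossum distance with a causal exponential kernel (sketch)."""
    t1 = np.asarray(train1, dtype=float)
    t2 = np.asarray(train2, dtype=float)
    if t_end is None:
        # Run long enough for the last kernel to decay (10 time constants).
        t_end = max(np.max(t1, initial=0.0), np.max(t2, initial=0.0)) + 10 * tau
    grid = np.arange(0.0, t_end, dt)

    def trace(spikes):
        f = np.zeros_like(grid)
        for s in spikes:
            lag = grid - s
            # exp(-lag / tau) after the spike, zero before it.
            f += np.where(lag >= 0, np.exp(-np.maximum(lag, 0.0) / tau), 0.0)
        return f

    diff = trace(t1) - trace(t2)
    # D^2 = (1 / tau) * integral of (f1 - f2)^2 dt, as a Riemann sum.
    return float(np.sqrt(np.sum(diff ** 2) * dt / tau))


train1 = [0.1, 0.3, 0.5]     # seconds
train2 = [0.12, 0.31, 0.52]  # seconds
dist = van_rossum(train1, train2, tau=0.01)
```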

spike_train_synchrony

Calculate spike train synchrony using the SPIKE-synchronization measure.

burst_synchrony_index

Calculate burst synchrony index based on co-occurring burst events.

phase_locking_value

Calculate phase-locking value (PLV) for spike synchronization.
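For two phase time series, the phase-locking value is the length of the mean unit phasor of their phase difference, \(\mathrm{PLV} = |\frac{1}{T}\sum_t e^{i(\phi_1(t) - \phi_2(t))}|\). It ranges from 0 (no consistent relation) to 1 (constant phase lag). The sketch assumes the phases are already available, for example from a Hilbert transform. How braintools derives phases from a spike matrix is not shown here.

```python
import numpy as np


def plv(phase1, phase2):
    """Phase-locking value between two phase series in radians (sketch)."""
    dphi = np.asarray(phase1, dtype=float) - np.asarray(phase2, dtype=float)
    # Length of the mean unit vector of the phase difference.
    return float(np.abs(np.mean(np.exp(1j * dphi))))


t = np.linspace(0.0, 1.0, 500, endpoint=False)
phase_a = 2 * np.pi * 8 * t
locked = plv(phase_a, phase_a - 0.7)                                # constant lag
spread = plv(np.zeros(4), np.array([0.0, 0.5, 1.0, 1.5]) * np.pi)  # evenly spread
```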

spike_time_tiling_coefficient

Calculate Spike Time Tiling Coefficient (STTC).

correlation_index

Calculate correlation index for spike train synchrony.

Local Field Potential

Metrics tailored to local field potential (LFP) analysis such as spectral characteristics and connectivity.

unitary_LFP

Calculate unitary local field potentials (uLFP) from spike train data.

power_spectral_density

Compute power spectral density (PSD) of LFP signals using Welch's method.

coherence_analysis

Compute coherence between two LFP signals.

phase_amplitude_coupling

Compute phase-amplitude coupling (PAC) using the modulation index.
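The modulation index of Tort et al. (2010) bins the low-frequency phase and averages the high-frequency amplitude in each bin. It then measures how far that amplitude distribution \(P\) is from uniform: \(\mathrm{MI} = (\log N - H(P)) / \log N\) for \(N\) bins. The NumPy sketch assumes the phase (in \([-\pi, \pi)\)) and the amplitude envelope have already been extracted from band-passed signals. The filtering step and the bin count of 18 are assumptions, not details taken from braintools.

```python
import numpy as np


def modulation_index(phase, amplitude, n_bins=18):
    """Tort et al. (2010) modulation index from phase and amplitude series (sketch)."""
    phase = np.asarray(phase, dtype=float)          # radians in [-pi, pi)
    amplitude = np.asarray(amplitude, dtype=float)  # amplitude envelope
    edges = np.linspace(-np.pi, np.pi, n_bins + 1)
    bins = np.clip(np.digitize(phase, edges) - 1, 0, n_bins - 1)
    # Mean amplitude in each phase bin, normalized into a distribution P.
    mean_amp = np.array([amplitude[bins == k].mean() for k in range(n_bins)])
    p = mean_amp / mean_amp.sum()
    entropy = -np.sum(p * np.log(p))
    # Normalized distance of P from the uniform distribution.
    return float((np.log(n_bins) - entropy) / np.log(n_bins))


phase = np.linspace(-np.pi, np.pi, 3600, endpoint=False)
mi_flat = modulation_index(phase, np.ones_like(phase))       # no coupling
mi_coupled = modulation_index(phase, 1.0 + np.cos(phase))    # amplitude peaks at phase 0
```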

theta_gamma_coupling

Compute theta-gamma coupling strength using standard frequency bands.

current_source_density

Compute current source density (CSD) from laminar LFP recordings.
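Standard CSD estimation takes the negative second spatial derivative of the potential along the electrode axis: \(\mathrm{CSD}(z) = -\sigma\,[\phi(z+h) - 2\phi(z) + \phi(z-h)] / h^2\). Here \(h\) is the contact spacing and \(\sigma\) the extracellular conductivity. The NumPy sketch drops the two boundary channels, ignores units, and uses a hypothetical default of \(\sigma = 0.3\). Braintools may handle boundaries and conductivity differently.

```python
import numpy as np


def csd(lfp, spacing, conductivity=0.3):
    """Second-difference current source density (sketch).

    lfp: array of shape (num_channels, num_time), channels ordered by depth.
    spacing: distance between neighboring contacts (consistent units assumed).
    conductivity: hypothetical default extracellular conductivity.
    Returns an array of shape (num_channels - 2, num_time).
    """
    phi = np.asarray(lfp, dtype=float)
    second_diff = phi[2:] - 2.0 * phi[1:-1] + phi[:-2]
    return -conductivity * second_diff / spacing ** 2


depth = np.arange(16) * 0.1                       # 16 contacts, 0.1 apart
linear = np.tile(depth[:, None], (1, 5))          # potential linear in depth
quadratic = np.tile(depth[:, None] ** 2, (1, 5))  # constant curvature of 2
```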

spectral_entropy

Compute spectral entropy of LFP signal as a complexity measure.
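Spectral entropy is the Shannon entropy of the normalized power spectrum. It is often divided by the log of the number of frequency bins so that it lies in [0, 1]: a pure tone scores near 0 and white noise near 1. For brevity, the sketch uses a plain periodogram from `np.fft.rfft` rather than Welch's method, so its values may differ from those returned by braintools.

```python
import numpy as np


def spectral_entropy(signal, normalize=True):
    """Shannon entropy of the normalized periodogram (sketch)."""
    x = np.asarray(signal, dtype=float)
    psd = np.abs(np.fft.rfft(x - x.mean())) ** 2  # periodogram, arbitrary scale
    p = psd / psd.sum()                           # treat the spectrum as a distribution
    p = p[p > 0]                                  # 0 * log 0 is taken as 0
    h = -np.sum(p * np.log(p))
    # Divide by log(number of frequency bins) to map the result into [0, 1].
    return float(h / np.log(psd.size)) if normalize else float(h)


fs = 1000.0
t = np.arange(1000) / fs
tone = np.sin(2 * np.pi * 10 * t)                       # single 10 Hz component
noise = np.random.default_rng(0).standard_normal(1000)  # broadband
```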

lfp_phase_coherence

Compute phase coherence between multiple LFP signals in a frequency band.

Ranking Losses

Losses for ordered prediction tasks including pairwise and list-wise ranking setups.

ranking_softmax_loss

Compute ranking softmax loss for learning-to-rank applications.

Regression Losses

Continuous-valued error metrics for fitting neural or behavioral signals.

squared_error

Compute element-wise squared error between predictions and targets.

absolute_error

Compute element-wise absolute error between predictions and targets.

l1_loss

Compute the mean absolute error (MAE) between the predictions \(x\) and the targets \(y\).

l2_loss

Calculates the L2 loss for a set of predictions.

l2_norm

Computes the L2 norm of the difference between predictions and targets.

huber_loss

Compute Huber loss combining L1 and L2 properties for robust regression.
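The Huber loss is quadratic for errors up to `delta` and linear beyond it, which bounds the influence of outliers: \(L_\delta(e) = \frac{1}{2}e^2\) for \(|e| \le \delta\) and \(\delta(|e| - \frac{1}{2}\delta)\) otherwise. Below is a NumPy sketch using the hypothetical name `huber`; it is not the braintools function.

```python
import numpy as np


def huber(predictions, targets, delta=1.0):
    """Element-wise Huber loss (sketch)."""
    err = np.abs(np.asarray(predictions, dtype=float) - np.asarray(targets, dtype=float))
    quadratic = 0.5 * err ** 2             # used where |e| <= delta
    linear = delta * (err - 0.5 * delta)   # used where |e| > delta
    return np.where(err <= delta, quadratic, linear)


small = huber([1.5, 2.3, 3.1, 4.2], [1.0, 2.0, 3.0, 4.0])  # all errors below delta
outlier = huber(4.0, 1.0)                                  # error of 3 -> 1 * (3 - 0.5) = 2.5
```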

log_cosh

Calculates the log-cosh loss for a set of predictions.

cosine_similarity

Compute cosine similarity between predicted and target vectors.

cosine_distance

Computes the cosine distance between targets and predictions.

Smoothing Losses

Regularizers that smooth target label distributions.

smooth_labels

Apply label smoothing regularization to one-hot encoded labels.
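Label smoothing mixes each one-hot vector with the uniform distribution over \(K\) classes, \(\tilde y = (1 - \alpha)\,y + \alpha / K\). The largest target therefore becomes \(1 - \alpha + \alpha/K\). Below is a NumPy sketch of that formula. It assumes `alpha` has the same meaning as in the `smooth_labels(labels, alpha=0.1)` example above.

```python
import numpy as np


def smooth(labels, alpha):
    """Label smoothing toward the uniform distribution (sketch)."""
    labels = np.asarray(labels, dtype=float)
    num_classes = labels.shape[-1]
    return (1.0 - alpha) * labels + alpha / num_classes


labels = np.array([[1.0, 0.0, 0.0], [0.0, 1.0, 0.0]])
smoothed = smooth(labels, alpha=0.1)  # largest entry 0.9 + 0.1/3 ≈ 0.9333
```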