weighted_correlation
- class braintools.metric.weighted_correlation(x, y, w)
Compute weighted Pearson correlation between two data series.
Each observation contributes to the coefficient in proportion to its weight. This is useful when some data points should count more toward the correlation than others, for example when measurements differ in reliability.
The weighted correlation is computed as:
\[r_w = \frac{\mathrm{Cov}_w(X,Y)}{\sqrt{\mathrm{Var}_w(X) \cdot \mathrm{Var}_w(Y)}}\]
where \(\mathrm{Cov}_w\) is the weighted covariance and \(\mathrm{Var}_w\) is the weighted variance.
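The page does not spell out the weighted quantities. One common convention, assumed here rather than confirmed from the braintools source, defines them through the weighted mean \(\mu_w\):
\[\mu_w(X) = \frac{\sum_i w_i x_i}{\sum_i w_i}, \qquad \mathrm{Cov}_w(X,Y) = \frac{\sum_i w_i \,(x_i - \mu_w(X))\,(y_i - \mu_w(Y))}{\sum_i w_i}, \qquad \mathrm{Var}_w(X) = \mathrm{Cov}_w(X,X)\]
Under this convention the factor \(\sum_i w_i\) cancels in the ratio, so \(r_w\) is unchanged when all weights are multiplied by the same positive constant.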
- Parameters:
x (brainstate.typing.ArrayLike) – First data series. Must be 1-dimensional.
y (brainstate.typing.ArrayLike) – Second data series. Must be 1-dimensional and same length as x.
w (brainstate.typing.ArrayLike) – Weight vector. Must be 1-dimensional and same length as x and y. Higher weights give more importance to corresponding data points.
- Returns:
Weighted Pearson correlation coefficient between -1 and 1.
- Return type:
- Raises:
ValueError – If any input array is not 1-dimensional or if arrays have different lengths.
Notes
The weighted correlation reduces to the standard Pearson correlation when all weights are equal. Weights should be non-negative; zero weights effectively exclude those data points from the calculation.
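The two properties in this note can be checked with a small pure-Python reference sketch. The helper `weighted_pearson` below is hypothetical and written only for illustration. It assumes the conventional weighted-mean definitions of covariance and variance, and it is not the braintools implementation.

```python
import math


def weighted_pearson(x, y, w):
    """Illustrative weighted Pearson correlation (not the braintools source)."""
    if not (len(x) == len(y) == len(w)):
        raise ValueError("x, y and w must have the same length")
    total = sum(w)
    # Weighted means of each series.
    mx = sum(wi * xi for wi, xi in zip(w, x)) / total
    my = sum(wi * yi for wi, yi in zip(w, y)) / total
    # Weighted covariance and variances around those means.
    cov = sum(wi * (xi - mx) * (yi - my) for wi, xi, yi in zip(w, x, y)) / total
    var_x = sum(wi * (xi - mx) ** 2 for wi, xi in zip(w, x)) / total
    var_y = sum(wi * (yi - my) ** 2 for wi, yi in zip(w, y)) / total
    return cov / math.sqrt(var_x * var_y)


x = [1.0, 2.0, 3.0, 4.0, 5.0]
y = [2.0, 1.0, 4.0, 3.0, 9.0]

# Equal weights: any constant weight vector gives the same coefficient,
# which is the ordinary (unweighted) Pearson correlation.
r_ones = weighted_pearson(x, y, [1.0] * 5)
r_sevens = weighted_pearson(x, y, [7.0] * 5)

# A zero weight on the last point matches dropping that point entirely.
r_zeroed = weighted_pearson(x, y, [1.0, 1.0, 1.0, 1.0, 0.0])
r_dropped = weighted_pearson(x[:4], y[:4], [1.0] * 4)

print(math.isclose(r_ones, r_sevens), math.isclose(r_zeroed, r_dropped))
```

The sketch uses plain lists so it runs without JAX. The same arithmetic carries over to `jnp` arrays if an array-based version is preferred.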
For numerical stability, avoid weights that differ by many orders of magnitude, since the weighted sums can then lose floating-point precision.
Examples
>>> import jax.numpy as jnp
>>> import braintools
>>> # Perfect linear relationship
>>> x = jnp.array([1.0, 2.0, 3.0, 4.0, 5.0])
>>> y = jnp.array([2.0, 4.0, 6.0, 8.0, 10.0])
>>> # Weight middle points more heavily
>>> w = jnp.array([1.0, 1.0, 2.0, 2.0, 1.0])
>>> corr = braintools.metric.weighted_correlation(x, y, w)
>>> print(f"Weighted correlation: {corr:.3f}")
>>>
>>> # Compare with unweighted correlation
>>> unweighted = jnp.corrcoef(x, y)[0, 1]
>>> print(f"Unweighted correlation: {unweighted:.3f}")
>>>
>>> # Example with reliability weights (higher for more reliable measurements)
>>> reliability = jnp.array([0.5, 0.8, 0.9, 0.7, 0.6])
>>> corr_reliable = braintools.metric.weighted_correlation(x, y, reliability)
>>> print(f"Reliability-weighted: {corr_reliable:.3f}")