squared_error


class braintools.metric.squared_error(predictions, targets=None, axis=None, reduction='none')

Compute element-wise squared error between predictions and targets.

Calculates the squared differences between predicted and target values, which form the basis of Mean Squared Error (MSE) and related regression metrics. It is one of the most widely used loss functions for regression. Because residuals are squared, large errors are penalized much more heavily than small ones, which also makes the metric sensitive to outliers.

The squared error is defined as:

\[\text{SE}(y, \hat{y}) = (y - \hat{y})^2\]

where \(y\) are the true values and \(\hat{y}\) are the predictions.
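As a quick check of the formula, the snippet below evaluates \((y - \hat{y})^2\) directly on the numbers used in the Examples section. It uses plain NumPy rather than braintools so that it runs standalone. It also shows that the sign of the residual does not matter.

```python
import numpy as np

# True values y and predictions y_hat (same numbers as in the Examples section)
y = np.array([1.1, 1.9, 3.2])
y_hat = np.array([1.0, 2.0, 3.0])

# SE(y, y_hat) = (y - y_hat)^2, evaluated element-wise
se = (y - y_hat) ** 2
print(se)  # approximately [0.01 0.01 0.04]

# Squaring removes the sign, so swapping the arguments gives the same result
assert np.allclose(se, (y_hat - y) ** 2)
```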

Parameters:
  • predictions (Array | ndarray | bool | number | bool | int | float | complex | Quantity) – Predicted values with arbitrary shape. Must be floating-point type.

  • targets (Array | ndarray | bool | number | bool | int | float | complex | Quantity | None) – Ground truth target values with shape broadcastable to predictions. If not provided, targets are assumed to be zeros, so the result is the element-wise square of predictions.

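  • axis (int | tuple[int, ...] | None) – Axis or axes along which to reduce the error when reduction is 'mean' or 'sum'. If None, those reductions are taken over all elements. When reduction='none', no reduction is performed and element-wise errors are returned.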

  • reduction (str) –

    Reduction operation to apply:

    • 'none': Return element-wise errors without reduction

    • 'mean': Return mean of errors (MSE when no axis specified)

    • 'sum': Return sum of errors
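The parameter semantics above can be summarized in a short NumPy sketch. `squared_error_sketch` is a hypothetical helper written for illustration, not braintools' implementation. It assumes two behaviors suggested by the Returns section and the MSE note:

- `axis` is ignored when `reduction='none'`.
- `axis=None` with `'mean'` or `'sum'` reduces over every element.

```python
import numpy as np

def squared_error_sketch(predictions, targets=None, axis=None, reduction="none"):
    """Hypothetical re-statement of the documented semantics (not braintools' code)."""
    predictions = np.asarray(predictions, dtype=float)
    # Documented behavior: missing targets are treated as zeros
    if targets is None:
        targets = np.zeros_like(predictions)
    # targets only need to broadcast against predictions
    errors = (np.asarray(targets, dtype=float) - predictions) ** 2
    if reduction == "none":
        return errors  # assumption: axis is ignored here
    if reduction == "mean":
        return np.mean(errors, axis=axis)  # axis=None -> mean over all elements
    if reduction == "sum":
        return np.sum(errors, axis=axis)
    raise ValueError(f"unknown reduction: {reduction!r}")

preds = np.array([[1.0, 2.0], [3.0, 4.0]])
print(squared_error_sketch(preds).shape)                     # (2, 2); targets default to zeros
print(squared_error_sketch(preds, 1.0, reduction="sum"))      # 14.0; scalar target broadcasts
print(squared_error_sketch(preds, axis=1, reduction="mean"))  # per-row means 2.5 and 12.5
```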

Returns:

Squared errors. Shape depends on axis and reduction parameters:

  • If reduction='none': same shape as predictions

  • If reduction is 'mean' or 'sum': the axes given by axis are reduced away; with axis=None the result is a scalar

Return type:

Array | ndarray | bool | number | bool | int | float | complex | Quantity
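To make the shape rules concrete, the sketch below mimics the reductions with NumPy's `mean` and `sum`. It assumes braintools follows the same axis semantics as NumPy, which is not confirmed here. The random arrays are placeholder data for a batch of 4 samples with 3 outputs each.

```python
import numpy as np

rng = np.random.default_rng(0)
predictions = rng.normal(size=(4, 3))  # placeholder: 4 samples, 3 outputs
targets = rng.normal(size=(4, 3))
errors = (targets - predictions) ** 2   # reduction='none': same shape as predictions

print(errors.shape)                   # (4, 3)
print(errors.mean(axis=1).shape)      # (4,)  one value per sample
print(errors.mean(axis=0).shape)      # (3,)  one value per output
print(errors.sum(axis=(0, 1)).shape)  # ()    a tuple of axes reduces both
print(np.mean(errors).shape)          # ()    axis=None reduces to a scalar
```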

Notes

This function is closely related to L2 loss, with the relationship:

\[\text{L2 loss} = \frac{1}{2} \times \text{squared error}\]

The factor of 0.5 is conventional in some textbooks (e.g., Bishop’s “Pattern Recognition and Machine Learning”) but not others (e.g., “The Elements of Statistical Learning” by Hastie, Tibshirani, and Friedman).
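One practical reason for the 0.5 convention is that it cancels the factor of 2 produced by differentiation:

\[\frac{\partial}{\partial \hat{y}} \tfrac{1}{2}(y - \hat{y})^2 = \hat{y} - y\]

The sketch below checks this with central finite differences. `se` and `l2` are hypothetical helpers defined only for this illustration.

```python
def se(y, y_hat):
    # Squared error (y - y_hat)^2, no 1/2 factor
    return (y - y_hat) ** 2

def l2(y, y_hat):
    # L2 loss with the conventional 1/2 factor
    return 0.5 * se(y, y_hat)

y, y_hat = 3.0, 3.5
eps = 1e-6

# Central finite differences with respect to the prediction y_hat
d_se = (se(y, y_hat + eps) - se(y, y_hat - eps)) / (2 * eps)
d_l2 = (l2(y, y_hat + eps) - l2(y, y_hat - eps)) / (2 * eps)

print(d_se)  # close to 2 * (y_hat - y) = 1.0
print(d_l2)  # close to y_hat - y = 0.5
```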

Mean Squared Error (MSE) is computed as squared_error(pred, target, reduction='mean').
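The MSE relation can be checked numerically. The NumPy sketch below reuses the batch values from the Examples section. It assumes, as documented above, that `reduction='mean'` with no axis averages over every element. It also shows that averaging the per-sample MSEs reproduces the global MSE here only because every row has the same number of elements.

```python
import numpy as np

predictions = np.array([[1.0, 2.0], [3.0, 4.0]])
targets = np.array([[1.1, 1.9], [2.8, 4.2]])
errors = (targets - predictions) ** 2

# Mean over all elements: MSE = sum of squared errors / number of elements
mse = errors.mean()
assert np.isclose(mse, errors.sum() / errors.size)

# Per-sample MSE (mean along axis 1); rows have equal length, so their mean equals mse
per_sample = errors.mean(axis=1)
assert np.isclose(per_sample.mean(), mse)

print(mse)  # approximately 0.025
```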

Examples

Basic element-wise squared error:

>>> import jax.numpy as jnp
>>> import braintools
>>> predictions = jnp.array([1.0, 2.0, 3.0])
>>> targets = jnp.array([1.1, 1.9, 3.2])
>>> errors = braintools.metric.squared_error(predictions, targets)
>>> print(errors)  # approximately [0.01, 0.01, 0.04]

Mean Squared Error:

>>> mse = braintools.metric.squared_error(predictions, targets, reduction='mean')
>>> print(f"MSE: {mse:.4f}")

Squared error with missing targets (assuming zero targets):

>>> pred_only = jnp.array([0.5, -0.3, 0.8])
>>> sq_magnitude = braintools.metric.squared_error(pred_only)
>>> print(sq_magnitude)  # [0.25, 0.09, 0.64]

Batch processing with axis reduction:

>>> batch_pred = jnp.array([[1.0, 2.0], [3.0, 4.0]])
>>> batch_targets = jnp.array([[1.1, 1.9], [2.8, 4.2]])
>>> # MSE per sample
>>> per_sample_mse = braintools.metric.squared_error(batch_pred, batch_targets,
...                                                  axis=1, reduction='mean')
>>> print(per_sample_mse)  # approximately [0.01, 0.04]

See also

braintools.metric.absolute_error

L1 loss alternative

braintools.metric.l2_loss

Squared error scaled by 0.5

braintools.metric.huber_loss

Robust alternative combining L1 and L2
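The sketch below illustrates the trade-off behind these alternatives. It compares element-wise squared, absolute, and Huber penalties on a small residual and a large one. It uses the textbook Huber definition with a hypothetical threshold `delta = 1.0`. The actual signature and defaults of `braintools.metric.huber_loss` are not shown on this page and may differ.

```python
import numpy as np

residuals = np.array([0.5, 3.0])  # a small error and an outlier-sized error

squared = residuals ** 2      # grows quadratically: [0.25, 9.0]
absolute = np.abs(residuals)  # grows linearly:    [0.5, 3.0]

delta = 1.0  # hypothetical threshold for the textbook Huber loss
huber = np.where(
    np.abs(residuals) <= delta,
    0.5 * residuals ** 2,                       # quadratic near zero
    delta * (np.abs(residuals) - 0.5 * delta),  # linear in the tails
)                             # [0.125, 2.5]
```

The outlier contributes 9.0 under squared error but only 2.5 under this Huber loss. This difference is the robustness the See also entry refers to.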

References