squared_error
braintools.metric.squared_error(predictions, targets=None, axis=None, reduction='none')
Compute element-wise squared error between predictions and targets.
Calculates the squared difference between each prediction and its target. These element-wise values are the building block of Mean Squared Error (MSE) and related regression metrics, and squared error is among the most widely used loss functions for regression tasks.
The squared error is defined as:
\[\text{SE}(y, \hat{y}) = (y - \hat{y})^2\]
where \(y\) denotes the true values and \(\hat{y}\) the predictions.
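To make the formula concrete, the following minimal sketch computes the element-wise squared error directly with jax.numpy. The helper name `squared_error_sketch` is hypothetical and this is not the braintools implementation; it only mirrors the documented behavior (targets broadcastable to predictions, zeros when targets is omitted).

```python
import jax.numpy as jnp


def squared_error_sketch(predictions, targets=None):
    """Hypothetical helper: element-wise (y - y_hat) ** 2, not the braintools code."""
    predictions = jnp.asarray(predictions)
    if targets is None:
        # Omitted targets are treated as zeros, so the result is predictions ** 2.
        targets = jnp.zeros_like(predictions)
    # jax.numpy broadcasting lets targets take any shape compatible with predictions.
    return (jnp.asarray(targets) - predictions) ** 2


pred = jnp.array([[1.0, 2.0], [3.0, 4.0]])
target_row = jnp.array([1.0, 1.0])  # broadcast against each row of pred

print(squared_error_sketch(pred, target_row))        # values [[0, 1], [4, 9]]
print(squared_error_sketch(jnp.array([0.5, -2.0])))  # values [0.25, 4.0]
```

Because the target row is broadcast, each row of `pred` is compared against the same target values.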
- Parameters:
predictions (Array | ndarray | bool | number | int | float | complex | Quantity) – Predicted values of arbitrary shape. Must have a floating-point dtype.
targets (Array | ndarray | bool | number | int | float | complex | Quantity | None) – Ground-truth target values with a shape broadcastable to predictions. If not provided, targets are assumed to be zeros, so the result is the element-wise square of predictions.
axis (int | tuple[int, ...] | None) – Axis or axes over which the reduction is applied. If None and reduction is 'mean' or 'sum', the reduction covers all elements; with reduction='none', element-wise errors are returned regardless of axis.
reduction (str) – Reduction operation to apply:
  'none': return element-wise errors without reduction (default)
  'mean': return the mean of the errors (the MSE when no axis is specified)
  'sum': return the sum of the errors
- Returns:
Squared errors. The shape depends on axis and reduction:
  If reduction='none': same shape as predictions.
  If reduction is 'mean' or 'sum': reduced along axis (a scalar when axis is None).
- Return type:
Array | ndarray | bool | number | int | float | complex | Quantity
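How axis and reduction combine can be sketched with plain jax.numpy reductions applied to an array of already-computed errors. This is an assumption-based illustration of the documented semantics, not the library's source: the helper `reduce_errors_sketch` is hypothetical, and raising ValueError for an unrecognized reduction string is this sketch's own choice.

```python
import jax.numpy as jnp


def reduce_errors_sketch(errors, axis=None, reduction="none"):
    """Hypothetical helper mirroring the documented axis/reduction semantics."""
    if reduction == "none":
        return errors  # element-wise errors; axis is not used
    if reduction == "mean":
        return jnp.mean(errors, axis=axis)  # axis=None averages over all elements
    if reduction == "sum":
        return jnp.sum(errors, axis=axis)  # axis=None sums over all elements
    # Assumption: this sketch rejects other strings; the library's behavior may differ.
    raise ValueError(f"unknown reduction: {reduction!r}")


errors = jnp.array([[1.0, 3.0], [5.0, 7.0]])  # stand-in for squared errors

print(reduce_errors_sketch(errors).shape)                      # (2, 2)
print(reduce_errors_sketch(errors, reduction="mean"))          # 4.0
print(reduce_errors_sketch(errors, axis=1, reduction="mean"))  # values [2, 6]
print(reduce_errors_sketch(errors, axis=0, reduction="sum"))   # values [6, 10]
```

Passing a tuple such as axis=(0, 1) reduces over both axes, which for a 2-D input matches axis=None.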
Notes
This function is closely related to the L2 loss:
\[\text{L2 loss} = \frac{1}{2} \times \text{squared error}\]
The factor of 0.5 is conventional in some textbooks (e.g., Bishop’s “Pattern Recognition and Machine Learning”) but not in others (e.g., “The Elements of Statistical Learning” by Hastie, Tibshirani, and Friedman).
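One common motivation for the 0.5 factor is that it cancels the 2 produced by differentiating the square, so the gradient of the L2 loss with respect to the prediction is simply \(\hat{y} - y\). The sketch below checks this with jax.grad on scalar values; the functions `se` and `l2` are illustrative definitions, not braintools APIs.

```python
import jax
import jax.numpy as jnp

y = jnp.array(3.0)  # true value (illustrative)


def se(y_hat):
    """Squared error (y - y_hat) ** 2."""
    return (y - y_hat) ** 2


def l2(y_hat):
    """L2 loss under the 1/2 convention: 0.5 * (y - y_hat) ** 2."""
    return 0.5 * (y - y_hat) ** 2


y_hat = jnp.array(5.0)
print(se(y_hat), l2(y_hat))  # 4.0 2.0
# d/dy_hat of se is 2 * (y_hat - y); of l2 it is just (y_hat - y).
print(jax.grad(se)(y_hat), jax.grad(l2)(y_hat))  # 4.0 2.0
```

Either convention has the same minimizer; the factor only rescales the loss and its gradient.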
Mean Squared Error (MSE) is computed as squared_error(pred, target, reduction='mean').
Examples
Basic element-wise squared error:
>>> import jax.numpy as jnp
>>> import braintools
>>> predictions = jnp.array([1.0, 2.0, 3.0])
>>> targets = jnp.array([1.1, 1.9, 3.2])
>>> errors = braintools.metric.squared_error(predictions, targets)
>>> print(errors)  # [0.01, 0.01, 0.04]
Mean Squared Error:
>>> mse = braintools.metric.squared_error(predictions, targets, reduction='mean')
>>> print(f"MSE: {mse:.4f}")  # MSE: 0.0200
Squared error with missing targets (assuming zero targets):
>>> pred_only = jnp.array([0.5, -0.3, 0.8])
>>> sq_magnitude = braintools.metric.squared_error(pred_only)
>>> print(sq_magnitude)  # [0.25, 0.09, 0.64]
Batch processing with axis reduction:
>>> batch_pred = jnp.array([[1.0, 2.0], [3.0, 4.0]])
>>> batch_targets = jnp.array([[1.1, 1.9], [2.8, 4.2]])
>>> # MSE per sample
>>> per_sample_mse = braintools.metric.squared_error(batch_pred, batch_targets,
...                                                  axis=1, reduction='mean')
>>> print(per_sample_mse)  # [0.01, 0.04]
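Per-sample MSE relates to the overall MSE in a simple way: when every sample has the same number of elements, the mean of the per-sample MSEs equals the MSE over the whole batch. The sketch below checks this with plain jax.numpy on illustrative values chosen so the float results are exact; it stands in for the braintools calls rather than reproducing them.

```python
import jax.numpy as jnp

# Illustrative batch: 2 samples with 2 outputs each.
batch_pred = jnp.array([[1.0, 2.0], [3.0, 4.0]])
batch_targets = jnp.array([[1.0, 1.0], [1.0, 6.0]])
sq = (batch_targets - batch_pred) ** 2  # values [[0, 1], [4, 4]]

per_sample_mse = jnp.mean(sq, axis=1)  # values [0.5, 4.0]
global_mse = jnp.mean(sq)              # 2.25

# Every row has the same length, so the mean of per-sample MSEs equals the global MSE.
print(jnp.mean(per_sample_mse), global_mse)  # 2.25 2.25
```

If samples had different numbers of valid elements (for example after masking), the two quantities would generally differ.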
See also
braintools.metric.absolute_error: L1 loss alternative
braintools.metric.l2_loss: squared error scaled by 0.5
braintools.metric.huber_loss: robust alternative combining L1 and L2 behavior
References