AffineT

class brainstate.nn.AffineT(scale, shift)

Affine transformation with scaling and shifting.

This transformation applies an affine map of the form y = ax + b, where a is the scale factor and b is the shift. It is the simplest transformation of this family: it rescales and translates its inputs, preserving their ordering when a > 0 and reversing it when a < 0.

The transformation is defined by:

\[\text{forward}(x) = a \cdot x + b\]

The inverse transformation is:

\[\text{inverse}(y) = \frac{y - b}{a}\]
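The two formulas can be checked numerically with a short NumPy sketch. The helpers `affine_forward` and `affine_inverse` are hypothetical stand-ins written for illustration, not brainstate functions; they only show that applying the inverse after the forward map recovers the input when a is non-zero.

```python
import numpy as np


def affine_forward(x, a, b):
    # Hypothetical helper: forward(x) = a * x + b (NumPy broadcasting applies)
    return a * x + b


def affine_inverse(y, a, b):
    # Hypothetical helper: inverse(y) = (y - b) / a; every element of a must be non-zero
    return (y - b) / a


a, b = 2.0, -1.0
x = np.array([0.0, 1.5, 4.0])
y = affine_forward(x, a, b)       # [-1.0, 2.0, 7.0]
x_back = affine_inverse(y, a, b)  # recovers x
```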
Parameters:
  • scale (Array | ndarray | bool | number | int | float | complex | Quantity) – Scaling factor a. Must be non-zero for the transformation to be invertible.

  • shift (Array | ndarray | bool | number | int | float | complex | Quantity) – Additive shift b.

a

Scaling factor.

Type:

array_like

b

Shift parameter.

Type:

array_like

Raises:

ValueError – If scale is zero or numerically close to zero, making the transformation non-invertible.
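A minimal sketch of the kind of check this Raises clause describes, using a hypothetical `check_invertible` helper. It rejects a scale if any element is close to zero, because a single zero element already makes an elementwise map non-invertible. The any-element rule and the tolerance `atol=1e-8` are assumptions for illustration, not brainstate's documented behavior.

```python
import numpy as np


def check_invertible(scale, atol=1e-8):
    """Hypothetical check: raise ValueError if any element of `scale` is near zero.

    The any-element rule and the `atol` tolerance are illustrative
    assumptions, not brainstate's documented behavior.
    """
    scale = np.asarray(scale, dtype=float)
    if np.any(np.isclose(scale, 0.0, atol=atol)):
        raise ValueError("scale must be non-zero for the transform to be invertible")
    return scale


ok = check_invertible([2.0, -0.5])  # every element is safely non-zero

message = None
try:
    check_invertible([1.0, 1e-12])  # second element is effectively zero
except ValueError as err:
    message = str(err)
```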

Notes

Affine transformations underlie many statistical transformations. They map straight lines to straight lines and preserve ratios of distances along a line, which makes them particularly useful for:

  • Standardization: (x - μ) / σ

  • Normalization: (x - min) / (max - min)

  • Unit conversion: x * conversion_factor + offset
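Each use case in the list reduces to choosing a and b. The sketch below derives them for standardization and min-max normalization, and uses a Celsius-to-Kelvin conversion (a = 1, b = 273.15) as the unit-conversion case. The helper names are hypothetical and the arithmetic is plain NumPy.

```python
import numpy as np


def standardize_params(mu, sigma):
    # Hypothetical helper: (x - mu) / sigma == (1/sigma) * x + (-mu/sigma)
    return 1.0 / sigma, -mu / sigma


def normalize_params(x_min, x_max):
    # Hypothetical helper: (x - min) / (max - min) == (1/span) * x + (-min/span)
    span = x_max - x_min
    return 1.0 / span, -x_min / span


x = np.array([2.0, 4.0, 10.0])

a_std, b_std = standardize_params(mu=4.0, sigma=2.0)
z = a_std * x + b_std  # [-1.0, 0.0, 3.0]

a_norm, b_norm = normalize_params(x.min(), x.max())
u = a_norm * x + b_norm  # [0.0, 0.25, 1.0]

# Unit conversion: Celsius to Kelvin is a = 1.0, b = 273.15
kelvin = 1.0 * np.array([0.0, 25.0]) + 273.15  # [273.15, 298.15]
```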

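The Jacobian of this transformation is constant (it does not depend on x). For a scalar input, |det(J)| = |a|; for an n-dimensional input transformed elementwise with a scalar scale a, J = a·I and |det(J)| = |a|^n.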
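The constant Jacobian can be confirmed with central finite differences. This NumPy sketch (an illustration, not part of brainstate) builds the n × n Jacobian of x ↦ a·x + b for a scalar a applied elementwise, then compares its absolute determinant with |a|^n.

```python
import numpy as np

a, b = -1.5, 0.3
n = 4
x0 = np.linspace(-1.0, 1.0, n)


def f(x):
    # Elementwise affine map with a scalar scale and shift
    return a * x + b


# Central-difference estimate of the Jacobian, one column per input coordinate
eps = 1e-6
J = np.empty((n, n))
for j in range(n):
    e = np.zeros(n)
    e[j] = eps
    J[:, j] = (f(x0 + e) - f(x0 - e)) / (2 * eps)

det_numeric = abs(np.linalg.det(J))
det_formula = abs(a) ** n  # 1.5**4 = 5.0625
```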

Examples

>>> import brainstate
>>> import jax.numpy as jnp
>>> # Standardization transform (z-score)
>>> mu, sigma = 5.0, 2.0
>>> transform = brainstate.nn.AffineT(1 / sigma, -mu / sigma)
>>> x = jnp.array([3.0, 5.0, 7.0])
>>> z = transform.forward(x)
>>> # z ≈ [-1.0, 0.0, 1.0]

>>> # Temperature conversion: Celsius to Fahrenheit
>>> transform = brainstate.nn.AffineT(9 / 5, 32)
>>> celsius = jnp.array([0.0, 100.0])
>>> fahrenheit = transform.forward(celsius)
>>> # fahrenheit ≈ [32.0, 212.0]

forward(x)

Apply the affine transformation y = ax + b.
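Assuming forward evaluates a * x + b with standard array broadcasting (the parameter types allow arrays, but broadcasting behavior is not stated here), a per-feature scale and shift of shape (features,) act column-wise on a batch of shape (batch, features). This plain NumPy sketch standardizes each column with its own μ and σ.

```python
import numpy as np

# Per-feature statistics, shape (features,)
mu = np.array([0.0, 10.0, -5.0])
sigma = np.array([1.0, 2.0, 0.5])
a = 1.0 / sigma
b = -mu / sigma

# Batch of inputs, shape (batch, features)
x = np.array([[1.0, 12.0, -4.0],
              [-1.0, 8.0, -6.0]])

# Broadcasting applies a[k] and b[k] to column k: (x[:, k] - mu[k]) / sigma[k]
y = a * x + b  # [[1.0, 1.0, 2.0], [-1.0, -1.0, -2.0]]
```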

Parameters:

x (Array | ndarray | bool | number | int | float | complex | Quantity) – Input values to transform.

Returns:

Transformed values after scaling and shifting.

Return type:

Array

inverse(x)

Apply the inverse affine transformation x = (y - b) / a.
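The inverse is itself an affine map, with scale 1/a and shift -b/a. The sketch below uses a hypothetical `inverse_params` helper to derive these from the Celsius-to-Fahrenheit parameters in the Examples (a = 9/5, b = 32) and converts Fahrenheit back to Celsius.

```python
import numpy as np


def inverse_params(a, b):
    # Hypothetical helper: (y - b) / a == (1/a) * y + (-b/a)
    return 1.0 / a, -b / a


a, b = 9 / 5, 32.0
a_inv, b_inv = inverse_params(a, b)  # (5/9, -160/9)

fahrenheit = np.array([32.0, 212.0, -40.0])
celsius = a_inv * fahrenheit + b_inv  # [0.0, 100.0, -40.0]
```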

Parameters:

x (Array | ndarray | bool | number | int | float | complex | Quantity) – Transformed values y to invert. The parameter is named x for consistency with the other transform methods.

Returns:

Original values before transformation.

Return type:

Array

log_abs_det_jacobian(x, y)

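Compute the log absolute determinant of the Jacobian. For the affine map, d/dx[ax + b] = a, so for an n-dimensional input transformed elementwise with a scalar scale a, log|det J| = n · log|a|.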
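A NumPy sketch of the log-determinant for both a scalar scale and a per-element scale vector, cross-checked against `numpy.linalg.slogdet` on the explicit diagonal Jacobian. Whether brainstate returns this total or a per-element array is not specified here, so the sketch computes the total. Working in log space avoids the overflow or underflow that computing |a|^n directly can hit for large n.

```python
import numpy as np

n = 5

# Scalar scale: J = a * I, so log|det J| = n * log|a|
a_scalar = 0.5
logdet_scalar = n * np.log(abs(a_scalar))

# Per-element scale: J = diag(a_i), so log|det J| = sum_i log|a_i|
a_vec = np.array([2.0, -0.5, 3.0, 1.0, -4.0])
logdet_vec = np.sum(np.log(np.abs(a_vec)))

# Cross-check against NumPy's slogdet of the explicit Jacobians
_, ref_scalar = np.linalg.slogdet(a_scalar * np.eye(n))
_, ref_vec = np.linalg.slogdet(np.diag(a_vec))
```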

Return type:

Array