Transform

class brainstate.nn.Transform

Abstract base class for bijective parameter transformations.

This class provides the interface for implementing bijective (one-to-one and onto) transformations that map parameters between different domains. These transformations are essential in optimization and statistical inference where parameters need to be constrained to specific domains (e.g., positive values, bounded intervals).

A bijective transformation \(f: \mathcal{X} \rightarrow \mathcal{Y}\) must satisfy:

  1. Injectivity (one-to-one): \(f(x_1) = f(x_2) \Rightarrow x_1 = x_2\)

  2. Surjectivity (onto): \(\forall y \in \mathcal{Y}, \exists x \in \mathcal{X} : f(x) = y\)

  3. Invertibility: \(f^{-1}(f(x)) = x\) and \(f(f^{-1}(y)) = y\)
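As a small illustration of the injectivity condition, the sketch below contrasts two maps. It is standalone plain Python on scalar floats and does not use brainstate; the helper name `square` is hypothetical. Squaring fails on the whole real line because two inputs share one output, while exp is one-to-one and is undone by log on the positive reals.

```python
import math

def square(x: float) -> float:
    return x * x

# Squaring is not injective on the reals: -2 and 2 map to the same value,
# so no inverse can recover the sign of the input.
assert square(-2.0) == square(2.0)

# exp is injective and maps the real line onto (0, inf); log inverts it.
x = 1.5
y = math.exp(x)
x_back = math.log(y)
assert abs(x_back - x) < 1e-12
```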

forward(x)

Apply the forward transformation \(y = f(x)\).

inverse(y)

Apply the inverse transformation \(x = f^{-1}(y)\).

log_abs_det_jacobian(x, y)

Compute the log absolute determinant of the Jacobian.

Notes

Subclasses must implement both forward and inverse, and the two must be exact inverses of each other on their respective domains; implementing both methods does not by itself make the map bijective. Implementations should remain numerically stable near the boundaries of the constrained domain.

Examples

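>>> import jax.numpy as jnp
>>> from brainstate.nn import Transform
>>> class SquareTransform(Transform):
...     # Bijective only when restricted to x >= 0 (maps [0, inf) onto [0, inf)).
...     def forward(self, x):
...         return x**2
...     def inverse(self, y):
...         return jnp.sqrt(y)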

abstractmethod forward(x)

Apply the forward transformation.

Transforms input from the unconstrained domain to the constrained domain. This method implements the mathematical function \(y = f(x)\) where \(x\) is in the unconstrained space and \(y\) is in the target domain.

Parameters:

x (Array | ndarray | bool | number | bool | int | float | complex | Quantity) – Input array in the unconstrained domain.

Returns:

Transformed output in the constrained domain.

Return type:

Array

Notes

Implementations must ensure numerical stability and handle boundary conditions appropriately.
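One way to read the stability note: a softplus map \(y = \log(1 + e^x)\) sends the real line to positive values, but the naive formula overflows for large \(x\). The sketch below is standalone plain Python on scalar floats, with no brainstate import, and `softplus_forward` is a hypothetical name. An actual subclass would presumably put the same rearrangement in forward using array operations.

```python
import math

def softplus_forward(x: float) -> float:
    """Map an unconstrained float to the positive reals via log(1 + exp(x))."""
    # Identity: log(1 + exp(x)) == max(x, 0) + log1p(exp(-|x|)).
    # The exponent passed to exp is never positive, so exp cannot overflow.
    return max(x, 0.0) + math.log1p(math.exp(-abs(x)))

mid_value = softplus_forward(0.0)       # equals log(2)
large_value = softplus_forward(1000.0)  # naive math.exp(1000.0) would raise OverflowError
small_value = softplus_forward(-30.0)   # tiny but still strictly positive
```

For extremely negative inputs the result still underflows to 0.0 in floating point, which sits on the boundary of the target domain; that is the kind of boundary condition an implementation has to decide how to handle.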

abstractmethod inverse(y)

Apply the inverse transformation.

Transforms input from the constrained domain back to the unconstrained domain. This method implements the mathematical function \(x = f^{-1}(y)\) where \(y\) is in the constrained space and \(x\) is in the unconstrained domain.

Parameters:

y (Array | ndarray | bool | number | bool | int | float | complex | Quantity) – Input array in the constrained domain.

Returns:

Transformed output in the unconstrained domain.

Return type:

Array

Notes

Implementations must ensure that inverse(forward(x)) = x for all valid x, and forward(inverse(y)) = y for all y in the target domain.
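The round-trip requirement can be checked numerically. The following is a minimal sketch under stated assumptions: `IntervalSketch` is a hypothetical, duck-typed class in plain Python (it does not subclass brainstate.nn.Transform and works on scalar floats) that maps the real line to an open interval through a logistic sigmoid.

```python
import math

class IntervalSketch:
    """Hypothetical sigmoid-based map from the real line to (lower, upper)."""

    def __init__(self, lower: float, upper: float):
        self.lower = lower
        self.upper = upper

    def forward(self, x: float) -> float:
        # Logistic sigmoid written via tanh, which does not overflow for large |x|.
        s = 0.5 * (1.0 + math.tanh(0.5 * x))
        return self.lower + (self.upper - self.lower) * s

    def inverse(self, y: float) -> float:
        # Logit of the rescaled value; only valid for lower < y < upper.
        s = (y - self.lower) / (self.upper - self.lower)
        return math.log(s) - math.log1p(-s)

t = IntervalSketch(-1.0, 3.0)

# inverse(forward(x)) should reproduce x up to floating-point error.
xs = [-4.0, -0.5, 0.0, 2.0, 4.0]
round_trip_error = max(abs(t.inverse(t.forward(x)) - x) for x in xs)

# forward(inverse(y)) should reproduce y for y inside the target interval.
ys = [-0.9, 0.0, 1.0, 2.9]
reverse_trip_error = max(abs(t.forward(t.inverse(y)) - y) for y in ys)
```

A check like this only samples a few points, so it can reveal a broken inverse but cannot prove bijectivity.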

log_abs_det_jacobian(x, y)

Compute the log absolute determinant of the Jacobian of the forward transformation.

For a bijective transformation \(f: \mathcal{X} \rightarrow \mathcal{Y}\), this computes:

\[\log \left| \det \frac{\partial f(x)}{\partial x} \right|\]

This is essential for computing probability densities under change of variables and is widely used in normalizing flows and variational inference.
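To make the change-of-variables use concrete: if \(x\) has density \(p_X\) and \(y = f(x)\), then \(\log p_Y(y) = \log p_X(x) - \log |\det \partial f(x) / \partial x|\) with \(x = f^{-1}(y)\). The sketch below applies this to \(y = e^x\) with a standard normal \(x\), which should reproduce the log-normal density. It is plain Python on scalar floats; `ExpSketch`, `std_normal_logpdf`, `transformed_logpdf`, and `lognormal_logpdf` are hypothetical names, and the class only mimics the three-method interface rather than subclassing brainstate.nn.Transform.

```python
import math

def std_normal_logpdf(x: float) -> float:
    """Log density of a standard normal at x."""
    return -0.5 * x * x - 0.5 * math.log(2.0 * math.pi)

class ExpSketch:
    """Hypothetical stand-in with the forward/inverse/log_abs_det_jacobian interface."""

    def forward(self, x: float) -> float:
        return math.exp(x)

    def inverse(self, y: float) -> float:
        return math.log(y)

    def log_abs_det_jacobian(self, x: float, y: float) -> float:
        # d/dx exp(x) = exp(x), so log|dy/dx| = x.
        return x

def transformed_logpdf(t, y: float) -> float:
    """Log density of y = t.forward(x) when x is standard normal."""
    x = t.inverse(y)
    return std_normal_logpdf(x) - t.log_abs_det_jacobian(x, y)

def lognormal_logpdf(y: float) -> float:
    """Closed-form log density of a standard log-normal, for comparison."""
    return -math.log(y) - 0.5 * math.log(2.0 * math.pi) - 0.5 * math.log(y) ** 2

y = 2.5
via_transform = transformed_logpdf(ExpSketch(), y)
closed_form = lognormal_logpdf(y)
```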

Parameters:
  • x (Array | ndarray | bool | number | bool | int | float | complex | Quantity) – Input in the unconstrained domain.

  • y (Array | ndarray | bool | number | bool | int | float | complex | Quantity) – Output in the constrained domain (i.e., y = forward(x)). This parameter is provided for efficiency since it may already be computed.

Returns:

Log absolute determinant of the Jacobian.

Return type:

Array

Notes

The default implementation raises NotImplementedError. Subclasses should override this method to provide an efficient implementation.

For element-wise transformations the Jacobian is diagonal, so its determinant is the product of the per-element derivatives and the log determinant reduces to a sum of log absolute derivatives:

\[\log \left| \det J \right| = \sum_i \log \left| \frac{\partial f(x_i)}{\partial x_i} \right|\]
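The element-wise formula can be sketched for a logistic sigmoid, whose derivative is \(s(1 - s)\) with \(s = \sigma(x)\). This is a standalone illustration in plain Python over a list of floats, not brainstate code; `sigmoid` and `sigmoid_log_abs_det_jacobian` are hypothetical names. The result is cross-checked against the product of finite-difference derivatives, which is the determinant of the diagonal Jacobian.

```python
import math

def sigmoid(x: float) -> float:
    # Logistic sigmoid via tanh, which is safe for large |x|.
    return 0.5 * (1.0 + math.tanh(0.5 * x))

def sigmoid_log_abs_det_jacobian(xs):
    """Sum of log|d sigmoid(x_i) / d x_i| over all elements."""
    total = 0.0
    for x in xs:
        s = sigmoid(x)
        # log(s * (1 - s)) split into two logs to avoid forming the product.
        total += math.log(s) + math.log1p(-s)
    return total

xs = [-1.0, 0.0, 2.0]
log_det = sigmoid_log_abs_det_jacobian(xs)

# Cross-check: determinant of the diagonal Jacobian from central differences.
eps = 1e-6
det_fd = 1.0
for x in xs:
    det_fd *= (sigmoid(x + eps) - sigmoid(x - eps)) / (2.0 * eps)
```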