ChainT

class brainstate.nn.ChainT(*transforms)

Composition of multiple transformations applied sequentially.

This class implements the mathematical composition of functions, allowing multiple transformations to be chained together. The transformations are applied in the order specified during initialization for the forward pass, and in reverse order for the inverse pass.

For transformations f₁, f₂, …, fₙ, the chain implements:

\[\text{forward}(x) = f_n(f_{n-1}(...f_2(f_1(x))...))\]
\[\text{inverse}(y) = f_1^{-1}(f_2^{-1}(...f_{n-1}^{-1}(f_n^{-1}(y))...))\]

Parameters:

*transforms (Transform) – Variable number of Transform objects to chain together.

transforms

Tuple of transformations in the order they will be applied.

Type:

sequence of Transform

Notes

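The chain is bijective when every component transformation is bijective. By the chain rule, the Jacobian of the chain is the product of the Jacobians of the individual transformations, each evaluated at the corresponding intermediate value, so the log absolute determinant of the chain is the sum of the components' log absolute determinants.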
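As a worked form of the Jacobian statement, assume intermediate values \(x_0 = x\) and \(x_i = f_i(x_{i-1})\) (symbols introduced here for illustration only). The chain rule then gives:

\[J_{\text{forward}}(x) = J_{f_n}(x_{n-1}) \cdots J_{f_2}(x_1)\, J_{f_1}(x_0)\]
\[\log\left|\det J_{\text{forward}}(x)\right| = \sum_{i=1}^{n} \log\left|\det J_{f_i}(x_{i-1})\right|\]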

Chain transformations are particularly useful for:

  • Complex parameter constraints requiring multiple steps

  • Modular transformation design

  • Combining simple transformations to achieve complex mappings

Examples

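>>> # Assumes the sibling transforms are exported from brainstate.nn like ChainT
>>> from brainstate.nn import AffineT, ChainT, SigmoidT, SoftplusT
>>> # Transform to (0, 1) then scale to (a, b)
>>> a, b = 2.0, 5.0  # illustrative bounds
>>> sigmoid = SigmoidT(0, 1)
>>> affine = AffineT(scale=b-a, shift=a)
>>> chain = ChainT(sigmoid, affine)
>>> # Standardize then apply softplus
>>> mu, sigma = 1.0, 0.5  # illustrative mean and standard deviation
>>> standardize = AffineT(1/sigma, -mu/sigma)
>>> softplus = SoftplusT(0)
>>> chain = ChainT(standardize, softplus)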

forward(x)

Apply all transformations sequentially in forward order.

Parameters:

x (Array | ndarray | numpy.bool_ | numpy.number | bool | int | float | complex | Quantity) – Input values to transform.

Returns:

Values after applying all transformations in sequence.

Return type:

Array

Notes

Transformations are applied left-to-right as specified in initialization.
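The left-to-right order can be illustrated with a minimal, library-independent sketch. The helpers `sigmoid`, `make_affine`, and `chain_forward` below are hypothetical stand-ins written for this example, not part of brainstate; they only mirror the folding behaviour described above.

```python
import math
from functools import reduce


def sigmoid(x):
    # Maps the real line to (0, 1).
    return 1.0 / (1.0 + math.exp(-x))


def make_affine(scale, shift):
    # Returns the map x -> scale * x + shift.
    return lambda x: scale * x + shift


def chain_forward(transforms, x):
    # Feed the output of each transform into the next, left to right.
    return reduce(lambda value, f: f(value), transforms, x)


a, b = 2.0, 5.0
affine = make_affine(b - a, a)

# sigmoid first, then affine: the result lands in (a, b).
y = chain_forward([sigmoid, affine], 0.0)  # 3.5

# Swapping the order gives a different map: the result lands in (0, 1).
y_swapped = chain_forward([affine, sigmoid], 0.0)
```

Because composition is not commutative, the order passed at initialization determines which constraint the final output satisfies.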

inverse(y)

Apply all inverse transformations sequentially in reverse order.

Parameters:

y (Array | ndarray | numpy.bool_ | numpy.number | bool | int | float | complex | Quantity) – Transformed values to invert.

Returns:

Original values before all transformations were applied.

Return type:

Array

Notes

Transformations are inverted right-to-left (reverse order) to properly undo the forward chain.
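A small round-trip sketch shows why the inverses must run in reverse order. The classes `Sigmoid` and `Affine` and the two helper functions are hypothetical scalar stand-ins written for this example; they are not the brainstate implementations.

```python
import math


class Sigmoid:
    def forward(self, x):
        return 1.0 / (1.0 + math.exp(-x))

    def inverse(self, y):
        # Logit: the inverse of the sigmoid on (0, 1).
        return math.log(y) - math.log1p(-y)


class Affine:
    def __init__(self, scale, shift):
        self.scale = scale
        self.shift = shift

    def forward(self, x):
        return self.scale * x + self.shift

    def inverse(self, y):
        return (y - self.shift) / self.scale


def chain_forward(transforms, x):
    for t in transforms:
        x = t.forward(x)
    return x


def chain_inverse(transforms, y):
    # Undo the last-applied transform first.
    for t in reversed(transforms):
        y = t.inverse(y)
    return y


transforms = [Sigmoid(), Affine(scale=3.0, shift=2.0)]
x = 0.7
y = chain_forward(transforms, x)
x_back = chain_inverse(transforms, y)  # recovers x up to rounding
```

Applying the inverses in forward order instead would pass a value from (2, 5) to the logit, which is only defined on (0, 1).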

log_abs_det_jacobian(x, y)

Sum of the log absolute Jacobian determinants of all transforms in the chain.

Return type:

Array
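The summation can be sketched for scalar transforms, where the Jacobian determinant reduces to an ordinary derivative. Everything below is a hypothetical illustration, not the library code: the helper `chain_log_abs_det_jacobian` takes only `x` and recomputes the intermediate values itself, which is an assumption about one possible way to obtain them.

```python
import math


class Sigmoid:
    def forward(self, x):
        return 1.0 / (1.0 + math.exp(-x))

    def log_abs_det_jacobian(self, x, y):
        # d/dx sigmoid(x) = y * (1 - y), with y = sigmoid(x).
        return math.log(y) + math.log1p(-y)


class Affine:
    def __init__(self, scale, shift):
        self.scale = scale
        self.shift = shift

    def forward(self, x):
        return self.scale * x + self.shift

    def log_abs_det_jacobian(self, x, y):
        # The derivative of an affine map is its constant scale.
        return math.log(abs(self.scale))


def chain_log_abs_det_jacobian(transforms, x):
    total = 0.0
    for t in transforms:
        y = t.forward(x)
        # Each term is evaluated at that transform's own input and output.
        total += t.log_abs_det_jacobian(x, y)
        x = y
    return total


transforms = [Sigmoid(), Affine(scale=3.0, shift=2.0)]
x0 = 0.7
total = chain_log_abs_det_jacobian(transforms, x0)


def composed(x):
    # The whole chain written as a single function, for comparison.
    return 3.0 * (1.0 / (1.0 + math.exp(-x))) + 2.0


# Central finite difference of the composed map as an independent check.
h = 1e-6
numeric = (composed(x0 + h) - composed(x0 - h)) / (2.0 * h)
```

The value of `total` should agree with `math.log(abs(numeric))` up to finite-difference error, which is the scalar case of the product-of-Jacobians rule.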