MaskedT

class brainstate.nn.MaskedT(mask, transform)

Selective transformation using a boolean mask.

This transformation applies a given transformation only to the elements selected by a boolean mask and leaves all other elements unchanged. This is useful when only a subset of parameters needs to be transformed while the rest should remain in their original domain.

The transformation is defined by:

\[\begin{split}\text{forward}(x)_i = \begin{cases} f(x_i) & \text{if } \text{mask}_i = \text{True} \\ x_i & \text{if } \text{mask}_i = \text{False} \end{cases}\end{split}\]

where \(f\) is the underlying transformation.
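The definition can be sketched in plain NumPy under two assumptions: the selection is done element-wise in the style of `np.where` (the library's actual implementation may differ), and `SoftplusT(lower)` computes \(\log(1 + e^{x}) + \text{lower}\). The helper names `softplus_lower` and `masked_forward` are hypothetical. With the inputs from the Examples section, this sketch reproduces the values quoted there.

```python
import numpy as np


def softplus_lower(x, lower=0.0):
    """Hypothetical stand-in for SoftplusT(lower): log(1 + exp(x)) + lower."""
    # np.logaddexp(0, x) computes log(exp(0) + exp(x)) = log(1 + exp(x)) stably.
    return np.logaddexp(0.0, x) + lower


def masked_forward(x, mask, f):
    """Apply f where mask is True and pass x through where mask is False."""
    x = np.asarray(x, dtype=float)
    mask = np.asarray(mask, dtype=bool)
    # f is evaluated on every element; np.where keeps f(x_i) only where mask_i is True.
    return np.where(mask, f(x), x)


mask = np.array([False, True, False, True])
x = np.array([-1.0, -1.0, 2.0, 2.0])
y = masked_forward(x, mask, softplus_lower)
print(y)  # approximately [-1.0, 0.313, 2.0, 2.127]
```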

The inverse follows the same pattern:

\[\begin{split}\text{inverse}(y)_i = \begin{cases} f^{-1}(y_i) & \text{if } \text{mask}_i = \text{True} \\ y_i & \text{if } \text{mask}_i = \text{False} \end{cases}\end{split}\]
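A matching sketch of the inverse, under the same assumptions (`np.where`-style selection, and a softplus of the form \(\log(1 + e^{x}) + \text{lower}\), whose inverse is \(\log(e^{y - \text{lower}} - 1)\) for \(y > \text{lower}\)), checks that applying the inverse after the forward map recovers the input. All helper names are hypothetical.

```python
import numpy as np


def softplus_lower(x, lower=0.0):
    """Hypothetical stand-in for SoftplusT(lower) forward: log(1 + exp(x)) + lower."""
    return np.logaddexp(0.0, x) + lower


def softplus_lower_inv(y, lower=0.0):
    """Inverse of softplus_lower, defined for y > lower: log(exp(y - lower) - 1)."""
    return np.log(np.expm1(y - lower))


def masked_forward(x, mask, f):
    """Apply f where mask is True; pass x through elsewhere."""
    return np.where(mask, f(x), x)


def masked_inverse(y, mask, f_inv):
    """Apply f_inv where mask is True and pass y through where mask is False."""
    y = np.asarray(y, dtype=float)
    mask = np.asarray(mask, dtype=bool)
    # f_inv is evaluated on every element, so unmasked entries outside its domain
    # (here y <= lower) produce NaN. np.where discards those values, so the
    # resulting invalid-value warnings are silenced.
    with np.errstate(invalid="ignore"):
        return np.where(mask, f_inv(y), y)


mask = np.array([False, True, False, True])
x = np.array([-1.0, -1.0, 2.0, 2.0])
y = masked_forward(x, mask, softplus_lower)
x_back = masked_inverse(y, mask, softplus_lower_inv)
print(np.allclose(x_back, x))  # True
```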

Parameters:
  • mask (Array | ndarray | bool | number | int | float | complex | Quantity) – Boolean array indicating which elements to transform; its shape must be broadcast-compatible with the input.

  • transform (Transform) – The transformation to apply to masked elements.

mask

Boolean mask array.

Type:

array_like

transform

The underlying transformation.

Type:

Transform

Notes

The mask and input arrays must have compatible shapes for broadcasting. This transformation is particularly useful in:

  • Mixed-parameter models where some parameters are bounded and others are unbounded

  • Selective application of constraints in optimization

  • Sparse transformations where only specific elements need modification
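The broadcasting and mixed-parameter cases can be sketched in NumPy, again assuming `np.where`-style selection. Here a mask of shape `(3,)` broadcasts against a batch of shape `(4, 3)` and selects whole columns. The column layout `[mean, std, correlation]`, the helpers `masked_forward`, `softplus` and `bounded_sigmoid`, and the choice to chain two masked transforms with disjoint masks are all illustrative assumptions, not brainstate API.

```python
import numpy as np


def masked_forward(x, mask, f):
    """np.where-based sketch of the masked forward map; mask broadcasts against x."""
    return np.where(mask, f(x), x)


def softplus(x):
    # Maps R -> (0, inf).
    return np.logaddexp(0.0, x)


def bounded_sigmoid(x, lower=-1.0, upper=1.0):
    # Maps R -> (lower, upper) via lower + (upper - lower) * sigmoid(x).
    return lower + (upper - lower) / (1.0 + np.exp(-x))


# Hypothetical model with columns [mean, std, correlation] and a batch of 4 rows.
raw = np.array([
    [0.5, -2.0, 3.0],
    [-1.0, 0.0, -3.0],
    [2.0, 1.5, 0.0],
    [0.0, -0.5, 10.0],
])

# Shape (3,) masks broadcast against the (4, 3) batch and select whole columns.
std_mask = np.array([False, True, False])
corr_mask = np.array([False, False, True])

# Disjoint masks let two masked transforms be applied one after the other.
params = masked_forward(masked_forward(raw, std_mask, softplus), corr_mask, bounded_sigmoid)

assert np.array_equal(params[:, 0], raw[:, 0])  # means untouched
assert np.all(params[:, 1] > 0)                 # std strictly positive
assert np.all(np.abs(params[:, 2]) < 1)         # correlation in (-1, 1)
```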

Examples

>>> import jax.numpy as jnp
>>> from brainstate.nn import MaskedT, SigmoidT, SoftplusT
>>> # Make only the masked entries (indices 1 and 3) positive
>>> mask = jnp.array([False, True, False, True])
>>> softplus = SoftplusT(0)
>>> masked_transform = MaskedT(mask, softplus)
>>> x = jnp.array([-1.0, -1.0, 2.0, 2.0])
>>> y = masked_transform.forward(x)
>>> # y ≈ [-1.0, 0.31, 2.0, 2.13] (only indices 1 and 3 are transformed)

>>> # Bound correlation parameters to (-1, 1) but leave mean parameters free
>>> n_params = 5
>>> corr_mask = jnp.arange(n_params) >= 3  # last 2 entries are correlations
>>> sigmoid = SigmoidT(-1, 1)
>>> transform = MaskedT(corr_mask, sigmoid)

forward(x)

Apply transformation selectively based on mask.

Parameters:

x (Array | ndarray | bool | number | int | float | complex | Quantity) – Input values to transform.

Returns:

Array where masked elements are transformed and unmasked elements remain unchanged.

Return type:

Array

Notes

Uses element-wise conditional selection to apply the transformation only where the mask is True.
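If the element-wise selection is a `where`-style select (an assumption; the Notes do not say how it is implemented), the underlying transformation is still evaluated on the unmasked entries. For transforms with a restricted domain this yields NaNs that are then discarded, and in JAX, NaNs from the unselected branch of `jnp.where` can still leak into gradients. A common workaround, sketched below in NumPy with hypothetical helper names and an assumed `fill` value inside the transform's domain, substitutes a safe value before applying the transform.

```python
import numpy as np


def log_transform(x):
    # Hypothetical f with a restricted domain: only defined for x > 0.
    return np.log(x)


def masked_forward_safe(x, mask, f, fill=1.0):
    """Masked forward map that never evaluates f on the unmasked entries.

    Unmasked entries are replaced by `fill` (assumed to lie in f's domain)
    before f is applied; the outer np.where then restores the original values.
    """
    x = np.asarray(x, dtype=float)
    mask = np.asarray(mask, dtype=bool)
    safe_x = np.where(mask, x, fill)
    return np.where(mask, f(safe_x), x)


mask = np.array([True, False, True])
x = np.array([np.e, -5.0, 1.0])

# Turn invalid-value and divide-by-zero warnings into errors to confirm that
# log never sees a non-positive input.
with np.errstate(invalid="raise", divide="raise"):
    y = masked_forward_safe(x, mask, log_transform)

print(y)  # approximately [1.0, -5.0, 0.0]
```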

inverse(y)

Apply inverse transformation selectively based on mask.

Parameters:

y (Array | ndarray | bool | number | int | float | complex | Quantity) – Transformed values to invert.

Returns:

Array where masked elements are inverse-transformed and unmasked elements remain unchanged.

Return type:

Array

Notes

Applies the inverse transformation only to elements where the mask is True, consistent with the forward operation.