MaskedT
- class brainstate.nn.MaskedT(mask, transform)
Selective transformation using a boolean mask.
This class applies a given transformation only to the elements selected by a boolean mask and leaves all other elements unchanged. This is useful when only a subset of parameters needs to be transformed while the others should remain in their original domain.
The transformation is defined by:
\[\begin{split}\text{forward}(x)_i = \begin{cases} f(x_i) & \text{if } \text{mask}_i = \text{True} \\ x_i & \text{if } \text{mask}_i = \text{False} \end{cases}\end{split}\]
where f is the underlying transformation.
The inverse follows the same pattern:
\[\begin{split}\text{inverse}(y)_i = \begin{cases} f^{-1}(y_i) & \text{if } \text{mask}_i = \text{True} \\ y_i & \text{if } \text{mask}_i = \text{False} \end{cases}\end{split}\]
- Parameters:
- mask
Boolean mask array.
- Type:
array_like
Notes
The mask and input arrays must have compatible shapes for broadcasting. This transformation is particularly useful in:
- Mixed parameter models where some parameters are bounded and others are not
- Selective application of constraints in optimization
- Sparse transformations where only specific elements need modification
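The broadcasting requirement can be illustrated with a minimal JAX sketch. It assumes the selection behaves like `jnp.where`. `masked_forward` is a hypothetical helper written for this illustration, not the brainstate implementation, and `jax.nn.softplus` stands in for the wrapped transformation.

```python
import jax
import jax.numpy as jnp


def masked_forward(x, mask, f):
    """Hypothetical stand-in for a masked forward pass (not brainstate code)."""
    # jnp.where broadcasts mask against x, then picks f(x) where mask is True.
    return jnp.where(mask, f(x), x)


# One mask entry per column; it broadcasts across both rows of x.
mask = jnp.array([False, True, True])   # shape (3,)
x = jnp.array([[-1.0, -1.0, 2.0],
               [0.5, 0.0, -3.0]])       # shape (2, 3)

y = masked_forward(x, mask, jax.nn.softplus)

# Column 0 is left untouched; columns 1 and 2 are mapped to positive values.
print(y.shape)
```

In this sketch a mask of shape (3,) works because it lines up with the trailing axis of the input. A mask of shape (2,) would not be broadcast-compatible with a (2, 3) input.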
Examples
>>> import jax.numpy as jnp
>>> from brainstate.nn import MaskedT, SoftplusT
>>> # Constrain only the masked indices (1 and 3) to be positive
>>> mask = jnp.array([False, True, False, True])
>>> softplus = SoftplusT(0)
>>> masked_transform = MaskedT(mask, softplus)
>>> x = jnp.array([-1.0, -1.0, 2.0, 2.0])
>>> y = masked_transform.forward(x)
>>> # y ≈ [-1.0, 0.31, 2.0, 2.13] (only indices 1 and 3 are transformed)
>>> # Transform correlation parameters but not mean parameters
>>> from brainstate.nn import MaskedT, SigmoidT
>>> n_params = 5
>>> corr_mask = jnp.arange(n_params) >= 3  # Last 2 are correlations
>>> sigmoid = SigmoidT(-1, 1)
>>> transform = MaskedT(corr_mask, sigmoid)
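As a rough illustration of the correlation example, the sketch below keeps three mean-like parameters unconstrained and squashes the last two into (-1, 1). It assumes that a bounded sigmoid maps x to lower + (upper - lower) * sigmoid(x). `bounded_sigmoid` is a hypothetical stand-in written here, and whether `SigmoidT` uses exactly this formula is an assumption.

```python
import jax
import jax.numpy as jnp


def bounded_sigmoid(x, lower, upper):
    """Hypothetical stand-in: squash x into the open interval (lower, upper)."""
    return lower + (upper - lower) * jax.nn.sigmoid(x)


n_params = 5
corr_mask = jnp.arange(n_params) >= 3          # last 2 entries are correlations
raw = jnp.array([10.0, -4.0, 0.0, 3.0, -3.0])  # unconstrained values

# Assumed selection rule: transform where the mask is True, pass through otherwise.
constrained = jnp.where(corr_mask, bounded_sigmoid(raw, -1.0, 1.0), raw)

# The first three values are unchanged; the last two now lie inside (-1, 1).
print(constrained)
```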
- forward(x)
Apply transformation selectively based on mask.
- Parameters:
x (Array | ndarray | bool | number | int | float | complex | Quantity) – Input values to transform.
- Returns:
Array where masked elements are transformed and unmasked elements remain unchanged.
- Return type:
Array
Notes
Uses element-wise conditional logic to apply the transformation only where the mask is True.
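The element-wise conditional logic can be sketched with `jnp.where`. This is an assumption about the mechanism, not a copy of the library source. `softplus_forward` is a hypothetical stand-in for `SoftplusT(0)`, taken here to compute lower + log(1 + exp(x)).

```python
import jax
import jax.numpy as jnp


def softplus_forward(x, lower=0.0):
    # Hypothetical stand-in for a softplus transform with a lower bound.
    return lower + jax.nn.softplus(x)


mask = jnp.array([False, True, False, True])
x = jnp.array([-1.0, -1.0, 2.0, 2.0])

# f is evaluated on every element; jnp.where then keeps f(x) only where mask is True.
y = jnp.where(mask, softplus_forward(x), x)
print(y)  # approximately [-1.0, 0.313, 2.0, 2.127]
```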
- inverse(y)
Apply inverse transformation selectively based on mask.
- Parameters:
y (Array | ndarray | bool | number | int | float | complex | Quantity) – Transformed values to invert.
- Returns:
Array where masked elements are inverse-transformed and unmasked elements remain unchanged.
- Return type:
Array
Notes
Applies the inverse transformation only to elements where the mask is True, so that it stays consistent with the forward operation.
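A round trip makes the consistency between the two directions concrete. The sketch below assumes the same `jnp.where` selection rule in both directions and uses a hand-written softplus inverse. `softplus_inverse` is a hypothetical helper, not part of brainstate.

```python
import jax
import jax.numpy as jnp


def softplus_inverse(y):
    # Inverse of softplus: log(exp(y) - 1), only defined for y > 0.
    return jnp.log(jnp.expm1(y))


mask = jnp.array([False, True, False, True])
x = jnp.array([-1.0, -1.0, 2.0, 2.0])

# Forward pass under the assumed jnp.where selection rule.
y = jnp.where(mask, jax.nn.softplus(x), x)

# Inverse pass with the same mask. softplus_inverse yields NaN for the
# unmasked negative entry, but jnp.where discards that value.
x_back = jnp.where(mask, softplus_inverse(y), y)
print(x_back)
```

Because unmasked entries pass through untouched in both directions, they do not need to lie in the range of the wrapped transformation for the round trip to recover the input.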