SimplexT

class brainstate.nn.SimplexT

Stick-breaking transformation for a simplex constraint.

Maps unconstrained vectors in ℝⁿ⁻¹ to points on the probability simplex in ℝⁿ, whose elements are all positive and sum to 1. This is useful for parameterizing probability distributions and categorical parameters.

The stick-breaking process works as follows:

\[\begin{split}z_i = \sigma(x_i) \\ y_i = z_i \cdot \prod_{j<i} (1 - z_j) \quad \text{for } i < n \\ y_n = \prod_{j<n} (1 - z_j)\end{split}\]

where \(\sigma\) is the logistic sigmoid, \(\sigma(x) = 1/(1 + e^{-x})\), and the empty product (for \(i = 1\)) equals 1.
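The formula can be sketched directly in JAX. The helper name `stick_breaking_forward` is hypothetical, and the sketch follows the equations exactly as written above; it is not necessarily how `brainstate.nn.SimplexT.forward` is implemented (some stick-breaking implementations add per-coordinate offsets before the sigmoid).

```python
import jax
import jax.numpy as jnp


def stick_breaking_forward(x):
    """Hypothetical helper: map x in R^(n-1) to a length-n simplex vector.

    Follows the stick-breaking equations literally (no per-coordinate offset).
    """
    z = jax.nn.sigmoid(x)  # z_i = sigma(x_i), each in (0, 1)
    # Stick remaining before break i is prod_{j<i} (1 - z_j).
    # Prepend 1.0 so the first break uses the whole stick (empty product).
    remaining = jnp.concatenate([jnp.ones(1, dtype=z.dtype), jnp.cumprod(1.0 - z)])
    # y_i = z_i * remaining_i for i < n; y_n is whatever stick is left over.
    return jnp.concatenate([z * remaining[:-1], remaining[-1:]])
```

The cumulative product yields every remaining stick length in one vectorized pass, so the sketch needs no Python loop and works under `jax.jit`.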

Notes

The input should have n-1 elements to produce an n-element simplex output.
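As a worked check of the dimension rule, applying the stick-breaking equations literally to \(x = (0, 0)\) (so n-1 = 2 and n = 3) gives a three-element output that sums to 1. This assumes the formula exactly as written; an implementation that adds offsets could return different values for the same input.

\[\begin{split}z = (\sigma(0), \sigma(0)) = (\tfrac{1}{2}, \tfrac{1}{2}) \\ y_1 = \tfrac{1}{2}, \quad y_2 = \tfrac{1}{2} \cdot (1 - \tfrac{1}{2}) = \tfrac{1}{4}, \quad y_3 = (1 - \tfrac{1}{2})(1 - \tfrac{1}{2}) = \tfrac{1}{4} \\ y_1 + y_2 + y_3 = 1\end{split}\]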

Examples

>>> import jax.numpy as jnp
>>> from brainstate.nn import SimplexT
>>> transform = SimplexT()
>>> x = jnp.array([0.0, 0.0])  # length-2 input -> length-3 simplex output
>>> y = transform.forward(x)
>>> # y sums to 1 and all elements are positive
>>> assert jnp.allclose(jnp.sum(y), 1.0)
>>> assert jnp.all(y > 0)

forward(x)

Transform unconstrained input onto the simplex.

Return type:

Array

inverse(y)

Transform a simplex vector back to the unconstrained domain.

Return type:

Array
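The inverse follows from the fact that the stick remaining before break i equals \(1 - \sum_{j<i} y_j\), so \(z_i = y_i / (1 - \sum_{j<i} y_j)\) and \(x_i = \mathrm{logit}(z_i)\). A minimal JAX sketch, assuming the forward map is exactly the formula above (the helper name `stick_breaking_inverse` is hypothetical, not part of brainstate):

```python
import jax.numpy as jnp


def stick_breaking_inverse(y):
    """Hypothetical helper: map a length-n simplex vector back to R^(n-1).

    Assumes the forward map is the literal stick-breaking formula (no offsets).
    """
    pieces = y[:-1]  # y_1, ..., y_{n-1}; y_n is implied by the sum-to-one constraint
    # Stick consumed before break i: sum_{j<i} y_j (0 for the first break).
    consumed = jnp.concatenate([jnp.zeros(1, dtype=y.dtype), jnp.cumsum(pieces)[:-1]])
    z = pieces / (1.0 - consumed)  # fraction of the remaining stick taken at break i
    # logit(z) = log(z) - log(1 - z), the inverse of the sigmoid.
    return jnp.log(z) - jnp.log1p(-z)
```

For example, the simplex vector (1/2, 1/4, 1/4) maps back to (0, 0). Near the boundary of the simplex (some \(y_i\) close to 0), \(z_i\) approaches 0 or 1 and the logit diverges, so a production implementation would likely clip or otherwise guard these values.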