SimplexT
class brainstate.nn.SimplexT
Stick-breaking transformation for the simplex constraint.
Maps an unconstrained vector in ℝⁿ⁻¹ to a point on the probability simplex in ℝⁿ: all n output elements are positive and sum to 1. This makes it useful for parameterizing probability distributions and categorical parameters.
The stick-breaking process works as follows:
\[\begin{split}
z_i &= \sigma(x_i) \\
y_i &= z_i \cdot \prod_{j<i} (1 - z_j) \quad \text{for } i < n \\
y_n &= \prod_{j<n} (1 - z_j)
\end{split}\]
where \(\sigma\) is the sigmoid function.
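Why the outputs are positive and sum to 1 can be checked directly from these equations. Write \(r_i = \prod_{j<i} (1 - z_j)\) for the length of stick left before the \(i\)-th break, so \(r_1 = 1\) and \(y_n = r_n\). The sum then telescopes:

\[\begin{split}
y_i &= z_i r_i = r_i - r_i (1 - z_i) = r_i - r_{i+1} \quad \text{for } i < n \\
\sum_{i=1}^{n} y_i &= \sum_{i=1}^{n-1} (r_i - r_{i+1}) + r_n = r_1 - r_n + r_n = r_1 = 1
\end{split}\]

Since each \(z_j \in (0, 1)\), every factor \(1 - z_j\) also lies in \((0, 1)\), so \(r_i > 0\) and \(y_i = z_i r_i > 0\). This holds in exact arithmetic; in floating point the sigmoid saturates for large \(|x_i|\), so some \(y_i\) can round to exactly 0.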
Notes
The input dimension should be n-1 for an n-dimensional simplex output.
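To make the n-1 to n relationship concrete, here is a minimal pure-JAX sketch of the stick-breaking formula above, together with its inverse derived by hand from the same equations. The function names `stick_breaking_forward` and `stick_breaking_inverse` are hypothetical and not part of the brainstate API; `brainstate.nn.SimplexT` may be implemented differently (for example, with extra numerical safeguards). The inverse relies on the identity \(1 - \sum_{j<i} y_j = \prod_{j<i} (1 - z_j)\), which shows that the n-1 inputs determine all n outputs and can be recovered from the first n-1 of them.

```python
import jax
import jax.numpy as jnp


def stick_breaking_forward(x):
    """Hypothetical helper: map x in R^(n-1) (last axis) to y on the simplex in R^n."""
    z = jax.nn.sigmoid(x)  # z_i = sigma(x_i), each in (0, 1)
    # remaining = [1, r_2, ..., r_n] with r_i = prod_{j<i} (1 - z_j)
    one = jnp.ones(x.shape[:-1] + (1,), dtype=z.dtype)
    remaining = jnp.concatenate([one, jnp.cumprod(1.0 - z, axis=-1)], axis=-1)
    # y_i = z_i * r_i for i < n; y_n = r_n is the leftover stick
    return jnp.concatenate([z * remaining[..., :-1], remaining[..., -1:]], axis=-1)


def stick_breaking_inverse(y):
    """Hypothetical helper: recover x in R^(n-1) from y on the open simplex in R^n."""
    cum = jnp.cumsum(y[..., :-1], axis=-1)
    # r_i = 1 - sum_{j<i} y_j: prepend 0 so that r_1 = 1
    prev = jnp.concatenate([jnp.zeros_like(cum[..., :1]), cum[..., :-1]], axis=-1)
    remaining = 1.0 - prev
    z = y[..., :-1] / remaining  # fraction of the remaining stick taken at break i
    return jnp.log(z) - jnp.log1p(-z)  # logit(z), the inverse of the sigmoid


x = jnp.array([0.0, 0.0])          # n - 1 = 2 unconstrained inputs
y = stick_breaking_forward(x)      # n = 3 simplex coordinates: 0.5, 0.25, 0.25
x_back = stick_breaking_inverse(y) # approximately [0.0, 0.0]
```

Only the first n-1 coordinates of `y` are needed by the inverse, because the last one is fixed by the sum-to-1 constraint; this is why the unconstrained input has one fewer dimension than the output.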
Examples
>>> import jax.numpy as jnp
>>> from brainstate.nn import SimplexT
>>> transform = SimplexT()
>>> x = jnp.array([0.0, 0.0])  # length-2 input -> length-3 simplex output
>>> y = transform.forward(x)
>>> # y sums to 1 and all elements are positive
>>> assert jnp.allclose(jnp.sum(y), 1.0)
>>> assert jnp.all(y > 0)