UnitVectorT

class brainstate.nn.UnitVectorT
Transformation to unit vectors (L2 norm = 1).
Projects input vectors onto the unit sphere by dividing each vector by its L2 norm. This is useful for directional data, or when parameters are constrained to lie on a sphere.
The transformation is defined by:

\[\text{forward}(x) = \frac{x}{\|x\|_2}\]

Notes
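This transformation is not bijective: every positive multiple of a vector (all points on the same ray from the origin) maps to the same unit vector, so the original norm cannot be recovered. The inverse therefore returns its input unchanged, assuming that input already lies on the unit sphere.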
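For illustration, the forward and inverse maps described above can be sketched in plain JAX. This is a minimal sketch, not the brainstate implementation: the function names `unit_vector_forward` and `unit_vector_inverse` are hypothetical, and normalizing along the last axis is an assumption. It shows that rescaling the input by a positive constant leaves the output unchanged, and that an identity inverse cannot recover the original norm.

```python
import jax.numpy as jnp


def unit_vector_forward(x, axis=-1):
    # Hypothetical helper: divide by the L2 norm along `axis` (assumed last axis).
    # A zero vector gives 0/0, i.e. NaN, because the formula is undefined there.
    return x / jnp.linalg.norm(x, axis=axis, keepdims=True)


def unit_vector_inverse(y):
    # Hypothetical helper: identity map. It assumes y is already a unit vector,
    # so the scale of the original input is not recovered.
    return y


x = jnp.array([3.0, 4.0])
y = unit_vector_forward(x)
y_scaled = unit_vector_forward(10.0 * x)  # same ray from the origin as x

print(y)
print(jnp.allclose(y, y_scaled))                # True: same ray, same unit vector
print(jnp.allclose(unit_vector_inverse(y), y))  # True: identity on the sphere
print(jnp.allclose(unit_vector_inverse(y), x))  # False: the norm 5 is lost
```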
Examples

>>> import jax.numpy as jnp
>>> from brainstate.nn import UnitVectorT
>>> transform = UnitVectorT()
>>> x = jnp.array([3.0, 4.0])
>>> y = transform.forward(x)
>>> # y has unit norm
>>> assert jnp.allclose(jnp.linalg.norm(y), 1.0)
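Because the output of this normalization always lies on the unit sphere, one common pattern for sphere-constrained parameters is to optimize an unconstrained raw vector and read the constrained parameter off its normalization. The sketch below is a minimal illustration in plain JAX, not brainstate's API: the helper `unit_vector_forward`, the `target` direction, the loss, and the learning rate are all hypothetical, and normalizing along the last axis is an assumption. It also checks that the gradient with respect to the raw vector has no radial component, since moving along the vector's own ray does not change the output.

```python
import jax
import jax.numpy as jnp


def unit_vector_forward(x):
    # Hypothetical helper: x / ||x||_2 along the last axis (assumed axis).
    return x / jnp.linalg.norm(x, axis=-1, keepdims=True)


# Hypothetical target direction the parameter should align with.
target = jnp.array([0.0, 1.0])


def loss(raw):
    # `raw` is unconstrained; the constrained parameter is its unit vector.
    direction = unit_vector_forward(raw)
    # Negative cosine similarity: minimized when direction equals target.
    return -jnp.dot(direction, target)


grad_fn = jax.jit(jax.grad(loss))

raw0 = jnp.array([3.0, 4.0])
g0 = grad_fn(raw0)
# Rescaling raw0 leaves the loss unchanged, so the gradient is orthogonal to raw0.
print(jnp.dot(g0, raw0))  # approximately 0

learning_rate = 1.0  # hypothetical step size
raw = raw0
for _ in range(500):
    # Plain gradient descent on the unconstrained vector.
    raw = raw - learning_rate * grad_fn(raw)

direction = unit_vector_forward(raw)
print(direction)                   # close to [0., 1.]
print(jnp.linalg.norm(direction))  # 1 up to float rounding
```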