UnitVectorT

class brainstate.nn.UnitVectorT

Transformation to unit vectors (L2 norm = 1).

Projects input vectors onto the unit sphere by dividing each vector by its L2 norm. This is useful for directional data, or whenever parameters must be constrained to lie on a sphere.
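To illustrate the constrained-parameter use case, the sketch below keeps an unconstrained "raw" vector and reads the constrained value through the forward map, so any update to the raw vector still yields a point on the unit sphere. It uses NumPy in place of JAX, a hypothetical helper `to_unit`, and assumes normalization along the last axis; it is not brainstate's implementation.

```python
import numpy as np

def to_unit(x: np.ndarray) -> np.ndarray:
    """Hypothetical stand-in for the forward map: divide by the L2 norm along the last axis."""
    return x / np.linalg.norm(x, axis=-1, keepdims=True)

rng = np.random.default_rng(0)
raw = rng.normal(size=3)  # unconstrained parameter, free to take any value

norms = []
for _ in range(5):
    # Arbitrary update to the raw parameter (stands in for an optimizer step).
    raw = raw + 0.5 * rng.normal(size=3)
    direction = to_unit(raw)  # constrained value actually used by the model
    norms.append(np.linalg.norm(direction))

# Every constrained value has unit norm, up to floating-point error.
print(norms)
```

The optimizer never has to know about the constraint; it only ever touches `raw`.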

The transformation is defined by:

\[\text{forward}(x) = \frac{x}{\|x\|_2}\]
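For instance, with \(x = (3, 4)\) the formula gives:

\[\|x\|_2 = \sqrt{3^2 + 4^2} = 5, \qquad \text{forward}(x) = \left(\tfrac{3}{5},\ \tfrac{4}{5}\right) = (0.6,\ 0.8)\]

and indeed \(0.6^2 + 0.8^2 = 0.36 + 0.64 = 1\). The formula is undefined at \(x = 0\), where \(\|x\|_2 = 0\).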

Notes

This transformation is not bijective: all nonzero vectors along the same ray from the origin map to the same unit vector, so the original norm cannot be recovered. The inverse therefore returns its input unchanged, assuming it already lies on the unit sphere.
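The loss of information described above can be checked numerically. The sketch below uses hypothetical NumPy stand-ins `forward` and `inverse` that follow the documented formula and the documented identity inverse; they are not the brainstate methods themselves, and normalizing along the last axis is an assumption.

```python
import numpy as np

def forward(x):
    # Hypothetical stand-in for UnitVectorT.forward: x / ||x||_2 along the last axis.
    return x / np.linalg.norm(x, axis=-1, keepdims=True)

def inverse(y):
    # Hypothetical stand-in for UnitVectorT.inverse: returns the input unchanged.
    return y

x = np.array([3.0, 4.0])

# Every positive multiple of x lies on the same ray and maps to the same unit vector.
same = [forward(c * x) for c in (0.5, 1.0, 10.0)]

# The inverse cannot undo the scaling: the norm of x (here 5) is lost.
recovered = inverse(forward(x))
```

Here `recovered` is `x` scaled down to unit length, not `x` itself; only its direction survives the round trip.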

Examples

>>> import jax.numpy as jnp
>>> from brainstate.nn import UnitVectorT
>>> transform = UnitVectorT()
>>> x = jnp.array([3.0, 4.0])
>>> y = transform.forward(x)
>>> # y has unit norm
>>> assert jnp.allclose(jnp.linalg.norm(y), 1.0)

forward(x)

Project the input onto the unit sphere.

Return type:

Array
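This page does not state which axis is normalized or how a zero vector is handled. The sketch below assumes normalization along the last axis, so a batch of row vectors is normalized row by row, and adds a hypothetical `eps` floor on the norm to avoid division by zero. Both choices are assumptions for illustration, not documented brainstate behavior.

```python
import numpy as np

def unit_forward(x: np.ndarray, eps: float = 1e-12) -> np.ndarray:
    """Hypothetical batched forward map: normalize each vector along the last axis.

    `eps` is an assumed safeguard: norms below it are clamped, so a zero vector
    maps to zero (which is NOT on the unit sphere) instead of producing NaNs.
    """
    norm = np.linalg.norm(x, axis=-1, keepdims=True)
    return x / np.maximum(norm, eps)

batch = np.array([[3.0, 4.0],
                  [0.0, -2.0],
                  [0.0, 0.0]])
out = unit_forward(batch)  # rows: [0.6, 0.8], [0.0, -1.0], [0.0, 0.0]
```

Without such a guard, the plain formula returns NaNs for the zero row.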

inverse(y)

Return the input unchanged (assumes it already lies on the unit sphere).

Return type:

Array
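Because `inverse(y)` returns `y` as-is, it satisfies `forward(inverse(y)) == y` only when `y` already has unit norm. A caller that cannot guarantee this might check the precondition first. The sketch below does so with a hypothetical `checked_inverse` helper in NumPy; neither it nor the `atol` tolerance is part of brainstate.

```python
import numpy as np

def forward(x):
    # Hypothetical stand-in for the forward map (normalize along the last axis).
    return x / np.linalg.norm(x, axis=-1, keepdims=True)

def checked_inverse(y, atol=1e-6):
    """Hypothetical identity inverse that first verifies ||y||_2 == 1."""
    norms = np.linalg.norm(y, axis=-1)
    if not np.allclose(norms, 1.0, atol=atol):
        raise ValueError("inverse expects inputs on the unit sphere")
    return y

y = np.array([0.6, 0.8])  # already unit norm
round_trip = forward(checked_inverse(y))  # equals y

try:
    checked_inverse(np.array([3.0, 4.0]))  # norm 5: precondition violated
    raised = False
except ValueError:
    raised = True
```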