glu
- saiunit.math.glu(x, axis=-1, unit_to_scale=None)
Gated linear unit activation function.
Computes the function:
\[\mathrm{glu}(x) = x\left[\ldots, 0:\frac{n}{2}, \ldots\right] \cdot \mathrm{sigmoid}\left( x\left[\ldots, \frac{n}{2}:n, \ldots\right] \right)\]
where the array is split into two halves along axis, and n is the size of x along axis. The size of the axis dimension must be divisible by two.
- Parameters:
  - x (Quantity | Array | ndarray | bool | number | int | float | complex) – Input array. Must be unitless if a Quantity. The size of the dimension specified by axis must be even.
  - axis (int) – The axis along which to split the input. Default is -1.
  - unit_to_scale (Unit | None) – Unit used to convert x to a dimensionless number before applying the activation.
- Returns:
out – An array whose size along axis is half that of the input.
- Return type:
Array
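As a reference for the formula, the following NumPy sketch splits the input into two halves along axis and multiplies the first half by the logistic sigmoid of the second half. It is an illustration written for this page (glu_reference is a hypothetical name), not saiunit's implementation, and it ignores Quantity handling and unit_to_scale.

```python
import numpy as np


def glu_reference(x, axis=-1):
    """Hypothetical NumPy sketch of GLU: first half times sigmoid(second half)."""
    x = np.asarray(x, dtype=float)
    n = x.shape[axis]
    if n % 2 != 0:
        raise ValueError(f"size of axis {axis} must be even, got {n}")
    # Split into x[..., 0:n/2, ...] and x[..., n/2:n, ...] along `axis`.
    a, b = np.split(x, 2, axis=axis)
    # Gate the first half with the logistic sigmoid of the second half.
    return a * (1.0 / (1.0 + np.exp(-b)))


x = np.array([[1.0, 2.0, 3.0, 4.0]])
out = glu_reference(x)
print(out.shape)  # (1, 2): the split axis is halved
print(out)        # approximately [[0.95257413 1.96402758]]

# Splitting along axis 0 instead halves the first dimension: (4, 2) -> (2, 2).
print(glu_reference(np.arange(8.0).reshape(4, 2), axis=0).shape)
```

On the example input this reproduces the values in the Examples section up to float32 rounding, since 1 · sigmoid(3) ≈ 0.9526 and 2 · sigmoid(4) ≈ 1.9640.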
Examples
>>> import jax.numpy as jnp
>>> import saiunit.math as sumath
>>> x = jnp.array([[1., 2., 3., 4.]])
>>> sumath.glu(x)
Array([[0.95257413, 1.9640275 ]], dtype=float32)
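The unit_to_scale parameter converts x to a dimensionless number before the activation is applied. The sketch below models that step without saiunit, under stated assumptions: magnitudes are stored in volts, and the target unit is represented by a hypothetical scale factor MILLIVOLT (1 mV = 1e-3 V), so dividing by it yields a plain number of millivolts. The names MILLIVOLT, to_dimensionless and glu_reference are illustrative only and are not part of saiunit.

```python
import numpy as np

# Hypothetical stand-in for a unit: its size expressed in volts (1 mV = 1e-3 V).
MILLIVOLT = 1e-3


def to_dimensionless(magnitude_in_volts, unit_scale):
    """Express a voltage as a plain number of `unit_scale` units (assumed model)."""
    return np.asarray(magnitude_in_volts, dtype=float) / unit_scale


def glu_reference(x, axis=-1):
    """Hypothetical NumPy sketch of GLU: first half times sigmoid(second half)."""
    a, b = np.split(np.asarray(x, dtype=float), 2, axis=axis)
    return a * (1.0 / (1.0 + np.exp(-b)))


# 1 mV, 2 mV, 3 mV, 4 mV stored as magnitudes in volts.
v = np.array([[0.001, 0.002, 0.003, 0.004]])
scaled = to_dimensionless(v, MILLIVOLT)  # approximately [[1, 2, 3, 4]], unitless
out = glu_reference(scaled)              # approximately [[0.95257413, 1.96402758]]
print(out)
```

The choice of unit matters: scaling the same voltages by volts instead of millivolts would feed values near 0.001 to the sigmoid and give a very different result.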