glu

class brainunit.math.glu(x, axis=-1, unit_to_scale=None)

Gated linear unit activation function.

Computes the function:

\[\mathrm{glu}(x) = x\left[\ldots, 0:\frac{n}{2}, \ldots\right] \cdot \mathrm{sigmoid} \left( x\left[\ldots, \frac{n}{2}:n, \ldots\right] \right)\]

where x is split into two equal halves along axis and n is the size of x along that axis. n must therefore be even.
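The formula can be sketched in plain JAX as a reference. This is an illustrative sketch, not brainunit's implementation: `glu_reference` is a hypothetical helper that splits the input with `jnp.split` and gates the first half with `jax.nn.sigmoid` of the second half, then cross-checks the result against JAX's built-in `jax.nn.glu`.

```python
import jax
import jax.numpy as jnp


def glu_reference(x, axis=-1):
    """Hypothetical reference GLU: first half times sigmoid of the second half."""
    n = x.shape[axis]
    if n % 2 != 0:
        raise ValueError(f"size {n} along axis {axis} must be even")
    # a = x[..., 0:n/2, ...], b = x[..., n/2:n, ...]
    a, b = jnp.split(x, 2, axis=axis)
    return a * jax.nn.sigmoid(b)


x = jnp.array([[1., 2., 3., 4.]])
y = glu_reference(x)  # [[1 * sigmoid(3), 2 * sigmoid(4)]]

# Cross-check against JAX's built-in GLU
print(jnp.allclose(y, jax.nn.glu(x)))  # True
```

Raising on an odd-sized axis up front makes the "must be even" requirement explicit rather than relying on `jnp.split` to fail.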

Parameters:
  • x (Quantity | Array | ndarray | bool | number | int | float | complex) – Input array. If x is a Quantity, it must be dimensionless unless unit_to_scale is given. The size of the dimension specified by axis must be even.

  • axis (int) – The axis along which to split the input. Default is -1.

  • unit_to_scale (Unit | None) – Unit used to convert x to a dimensionless number before applying the activation.
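The conversion performed via unit_to_scale can be illustrated without brainunit's Quantity type. This is a conceptual sketch under a stated assumption: a quantity is modeled (hypothetically) as a magnitude in SI base units, and the target unit as its SI scale factor (e.g. `1e-3` for millivolts). Dividing one by the other yields the dimensionless number that the activation is then applied to. `glu_scaled` is a hypothetical helper, not part of brainunit.

```python
import jax
import jax.numpy as jnp


def glu_scaled(value_si, unit_scale, axis=-1):
    """Hypothetical sketch: convert to a dimensionless number, then apply GLU.

    value_si   -- magnitudes expressed in SI base units (assumption)
    unit_scale -- SI size of the target unit, e.g. 1e-3 for mV (assumption)
    """
    dimensionless = value_si / unit_scale
    return jax.nn.glu(dimensionless, axis=axis)


# 1, 2, 3, 4 millivolts, stored as volts
v = jnp.array([[1e-3, 2e-3, 3e-3, 4e-3]])
out = glu_scaled(v, 1e-3)  # same as applying GLU to [[1, 2, 3, 4]]
```

The choice of unit matters: scaling the same values by volts instead of millivolts would feed numbers 1000 times smaller into the sigmoid and change the result.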

Returns:

out – An array whose size along axis is half that of the input.

Return type:

Array
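The effect of axis on the output shape can be checked with a short sketch. It uses JAX's `jax.nn.glu` as a stand-in, assuming that brainunit.math.glu behaves the same way on plain unitless arrays (which the formula above suggests): only the split axis is halved and every other dimension is preserved.

```python
import jax
import jax.numpy as jnp

# A (4, 6) input: both dimensions are even, so either axis can be split.
x = jnp.arange(24.0).reshape(4, 6)

out_last = jax.nn.glu(x, axis=-1)  # splits the 6 columns into 3 + 3
out_first = jax.nn.glu(x, axis=0)  # splits the 4 rows into 2 + 2

print(out_last.shape)   # (4, 3)
print(out_first.shape)  # (2, 6)
```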

Examples

>>> import jax.numpy as jnp
>>> import brainunit as u
>>> x = jnp.array([[1., 2., 3., 4.]])
>>> u.math.glu(x)
Array([[0.95257413, 1.9640275 ]], dtype=float32)