squareplus

saiunit.math.squareplus(x, b=4, unit_to_scale=None)

Squareplus activation function.

Computes the element-wise function

\[\mathrm{squareplus}(x) = \frac{x + \sqrt{x^2 + b}}{2}\]

as described in https://arxiv.org/abs/2112.11687.
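A minimal sketch of the formula in plain JAX, assuming `jax` is installed. `squareplus_ref` is a hypothetical helper that transcribes the equation directly. It is compared against `jax.nn.squareplus`, and it also shows that setting b = 0 collapses the formula to (x + |x|)/2, which is ReLU. This illustrates the math only; it is not a claim about how saiunit implements the function internally.

```python
import jax
import jax.numpy as jnp


def squareplus_ref(x, b=4.0):
    # Hypothetical helper: direct transcription of (x + sqrt(x^2 + b)) / 2
    return (x + jnp.sqrt(jnp.square(x) + b)) / 2


# x^2 + 4 = 6.25 for x = +-1.5, so the square root is exactly 2.5
x = jnp.array([-1.5, 0.0, 1.5])

default_b = squareplus_ref(x)          # expected 0.5, 1.0, 2.0
reference = jax.nn.squareplus(x, b=4)  # JAX's built-in squareplus, same formula
relu_case = squareplus_ref(x, b=0.0)   # b = 0 gives (x + |x|) / 2 = max(x, 0)

print(default_b)
print(reference)
print(relu_case, jax.nn.relu(x))
```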

Parameters:
  • x (Quantity | Array | ndarray | bool | number | int | float | complex) – Input array. If x is a Quantity, it must be unitless unless unit_to_scale is given.

  • b (Array | ndarray | bool | number | int | float | complex) – Smoothness parameter. Setting b = 0 recovers ReLU; larger values give a smoother transition around zero. Default is 4.

  • unit_to_scale (Unit | None) – Unit used to convert x to a dimensionless number before applying the activation.
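Because b is a fixed number, squareplus is not scale-equivariant: the unit chosen for unit_to_scale changes the numbers that come out, not just their interpretation. The sketch below mimics that conversion with plain floats instead of saiunit objects. `VOLT`, `MILLIVOLT`, and `to_dimensionless` are hypothetical stand-ins, and the assumption is that converting with a unit means dividing the magnitude by that unit's scale factor.

```python
import jax
import jax.numpy as jnp

# Hypothetical stand-ins for units: scale factors relative to volts
VOLT = 1.0
MILLIVOLT = 1e-3


def to_dimensionless(magnitude_in_volts, unit_scale):
    # Assumed behaviour of unit_to_scale: express x as a plain number in that unit
    return magnitude_in_volts / unit_scale


# -1.5 mV, 0 mV, 1.5 mV, stored as magnitudes in volts
x_volts = jnp.array([-1.5e-3, 0.0, 1.5e-3])

in_mV = jax.nn.squareplus(to_dimensionless(x_volts, MILLIVOLT))  # approximately [0.5, 1.0, 2.0]
in_V = jax.nn.squareplus(to_dimensionless(x_volts, VOLT))        # all close to sqrt(4) / 2 = 1.0

print(in_mV)
print(in_V)
```

Expressed in millivolts, the inputs sit on the curved part of the activation; expressed in volts, they are so close to zero that every output is about sqrt(b)/2.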

Returns:

out – An array with non-negative values.

Return type:

Array
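A short derivation of the non-negativity claim, assuming b ≥ 0 (the parameter description above does not state this restriction explicitly). Since \(\sqrt{x^2 + b} \ge \sqrt{x^2} = |x| \ge -x\),

\[\mathrm{squareplus}(x) = \frac{x + \sqrt{x^2 + b}}{2} \ge \frac{x + |x|}{2} = \max(x, 0) \ge 0,\]

with equality to ReLU when b = 0. For b > 0 the inequality \(\sqrt{x^2 + b} > |x|\) is strict, so the output is strictly positive, and the derivative

\[\frac{d}{dx}\,\mathrm{squareplus}(x) = \frac{1}{2}\left(1 + \frac{x}{\sqrt{x^2 + b}}\right)\]

lies strictly between 0 and 1.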

Examples

>>> import jax.numpy as jnp
>>> import saiunit.math as sumath
>>> sumath.squareplus(jnp.array([-1.5, 0., 1.5]))
Array([0.5, 1. , 2. ], dtype=float32)