squareplus
- class brainunit.math.squareplus(x, b=4, unit_to_scale=None)
Squareplus activation function.
Computes the element-wise function
\[\mathrm{squareplus}(x) = \frac{x + \sqrt{x^2 + b}}{2}\]
as described in https://arxiv.org/abs/2112.11687.
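As a sanity check on the formula, here is a minimal JAX transcription of it. `squareplus_ref` is a hypothetical helper written for illustration, not brainunit's implementation; it is compared against JAX's own `jax.nn.squareplus`, whose default is also b = 4. The inputs ±1.5 are chosen so that x² + 4 = 6.25 is a perfect square, which makes the float32 results exact.

```python
import jax
import jax.numpy as jnp


def squareplus_ref(x, b=4.0):
    # Direct transcription of (x + sqrt(x^2 + b)) / 2, applied element-wise.
    x = jnp.asarray(x, dtype=jnp.float32)
    return (x + jnp.sqrt(x * x + b)) / 2


# With b = 4, x = +/-1.5 gives x^2 + b = 6.25 = 2.5^2, so the results are exact.
print(squareplus_ref(jnp.array([-1.5, 0.0, 1.5])).tolist())  # [0.5, 1.0, 2.0]

# Cross-check against JAX's built-in squareplus on a wider range of inputs.
xs = jnp.linspace(-5.0, 5.0, 11)
print(bool(jnp.allclose(squareplus_ref(xs), jax.nn.squareplus(xs))))  # True
```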
- Parameters:
x (Quantity|Array|ndarray|bool|number|bool|int|float|complex) – Input array. Must be unitless if a Quantity.
b (Array|ndarray|bool|number|bool|int|float|complex) – Smoothness parameter. Default is 4.
unit_to_scale (Unit|None) – Unit used to convert x to a dimensionless number before applying the activation.
- Returns:
out – An array with non-negative values.
- Return type:
Array
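The smoothness parameter b controls how closely squareplus tracks ReLU, max(x, 0). For b > 0 the gap squareplus(x) - relu(x) is largest at x = 0, where it equals sqrt(b)/2 (the gap decreases for x > 0 and increases for x < 0). The function therefore approaches ReLU as b → 0 and becomes smoother as b grows. The sketch below checks this numerically on a grid; it uses plain JAX arrays without units, and `squareplus_ref` and `max_gap_to_relu` are hypothetical helpers, not part of brainunit.

```python
import jax.numpy as jnp


def squareplus_ref(x, b):
    # (x + sqrt(x^2 + b)) / 2, element-wise.
    return (x + jnp.sqrt(x * x + b)) / 2


def max_gap_to_relu(b, lo=-3.0, hi=3.0, n=601):
    # Hypothetical helper: largest value of squareplus(x) - relu(x) on a grid.
    xs = jnp.linspace(lo, hi, n)  # odd n puts a grid point at (or very near) x = 0
    gap = squareplus_ref(xs, b) - jnp.maximum(xs, 0.0)
    return float(jnp.max(gap))


for b in (4.0, 1.0, 0.01):
    # The maximum gap is sqrt(b) / 2: 1.0, 0.5 and 0.05 for these values of b.
    print(b, round(max_gap_to_relu(b), 4))
```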
Examples
>>> import jax.numpy as jnp
>>> import brainunit as u
>>> u.math.squareplus(jnp.array([-1.5, 0., 1.5]))
Array([0.5, 1. , 2. ], dtype=float32)
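The return value is documented as non-negative rather than strictly positive. Mathematically, sqrt(x² + b) > |x| for b > 0, so squareplus(x) > 0 and its derivative (1 + x / sqrt(x² + b)) / 2 lies strictly between 0 and 1. In float32, however, x + sqrt(x² + b) cancels for large negative x and can round to exactly 0. A minimal sketch of both effects in plain JAX, using a hypothetical `squareplus_ref` helper rather than brainunit itself:

```python
import jax
import jax.numpy as jnp


def squareplus_ref(x, b=4.0):
    # (x + sqrt(x^2 + b)) / 2, element-wise.
    return (x + jnp.sqrt(x * x + b)) / 2


def squareplus_grad(x, b=4.0):
    # Closed-form derivative: (1 + x / sqrt(x^2 + b)) / 2, strictly between 0 and 1.
    return (1 + x / jnp.sqrt(x * x + b)) / 2


xs = jnp.linspace(-10.0, 10.0, 201)
ys = squareplus_ref(xs)
autodiff = jax.vmap(jax.grad(squareplus_ref))(xs)  # per-element derivative via autodiff

print(bool(jnp.all(ys > 0)))  # True on this moderate range
print(bool(jnp.allclose(autodiff, squareplus_grad(xs), atol=1e-6)))  # True

# Cancellation: in float32, (-1e4)^2 + 4 rounds back to 1e8, so the output is exactly 0.
print(float(squareplus_ref(jnp.float32(-1e4))))  # 0.0
```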