sparse_sigmoid

class saiunit.math.sparse_sigmoid(x, unit_to_scale=None)

Sparse sigmoid activation function.

Computes the function:

\[\begin{split}\mathrm{sparse\_sigmoid}(x) = \begin{cases} 0, & x \leq -1\\ \frac{1}{2}(x+1), & -1 < x < 1 \\ 1, & 1 \leq x \end{cases}\end{split}\]

This is the sparse twin of the sigmoid activation: the output is exactly 0 for inputs at or below -1, exactly 1 for inputs at or above 1, and linear in x for inputs between -1 and 1. It is the derivative of sparse_plus.
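
The piecewise definition and the derivative relation can be checked numerically. The following is a minimal NumPy sketch, not the library's implementation: the helper names sparse_sigmoid_ref and sparse_plus_ref are hypothetical, and the sparse_plus formula (0 for x <= -1, (x + 1)^2 / 4 for -1 < x < 1, x for x >= 1) is assumed from its usual definition rather than taken from this page.

```python
import numpy as np


def sparse_sigmoid_ref(x):
    """Hypothetical reference: 0 for x <= -1, (x + 1) / 2 in between, 1 for x >= 1."""
    x = np.asarray(x, dtype=float)
    # Clipping the linear ramp to [0, 1] reproduces all three cases.
    return np.clip(0.5 * (x + 1.0), 0.0, 1.0)


def sparse_plus_ref(x):
    """Hypothetical reference (assumed definition of sparse_plus)."""
    x = np.asarray(x, dtype=float)
    return np.where(x <= -1.0, 0.0,
                    np.where(x >= 1.0, x, 0.25 * (x + 1.0) ** 2))


xs = np.array([-2.0, -0.5, 0.0, 0.5, 2.0])
values = sparse_sigmoid_ref(xs)

# A central finite difference of sparse_plus approximates its derivative.
h = 1e-6
numeric_grad = (sparse_plus_ref(xs + h) - sparse_plus_ref(xs - h)) / (2 * h)

print(values)
print(np.allclose(numeric_grad, values, atol=1e-6))
```

Under these assumptions, the finite-difference slope of sparse_plus should agree with sparse_sigmoid at each sample point away from the kinks at -1 and 1.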

For more information, see Learning with Fenchel-Young Losses (section 6.2).

Parameters:
  • x (Quantity | Array | ndarray | number | bool | int | float | complex) – Input array. If a Quantity, it must be unitless unless unit_to_scale is given.

  • unit_to_scale (Unit | None) – Unit used to convert x to a dimensionless number before applying the activation.

Returns:

out – An array with values in the range [0, 1].

Return type:

Array

Examples

>>> import jax.numpy as jnp
>>> import saiunit.math as sumath
>>> sumath.sparse_sigmoid(jnp.array([-2., 0., 2.]))
Array([0. , 0.5, 1. ], dtype=float32)