# Copyright 2024 BrainX Ecosystem Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""
Shared neural network activations and other functions.
"""
from typing import Any, Union, Sequence
import brainunit as u
import jax
from jax.scipy.special import logsumexp
from brainstate import random
from brainstate.typing import ArrayLike
# Public API of this module. ``logsumexp`` is re-exported from
# ``jax.scipy.special``; ``swish`` and ``hard_swish`` are aliases bound to
# ``silu`` and ``hard_silu`` further down in this file.
__all__ = [
    "tanh",
    "relu",
    "squareplus",
    "softplus",
    "soft_sign",
    "sigmoid",
    "silu",
    "swish",
    "log_sigmoid",
    "elu",
    "leaky_relu",
    "hard_tanh",
    "celu",
    "selu",
    "gelu",
    "glu",
    "logsumexp",
    "log_softmax",
    "softmax",
    "standardize",
    "one_hot",
    "relu6",
    "hard_sigmoid",
    "hard_silu",
    "hard_swish",
    'hard_shrink',
    'rrelu',
    'mish',
    'soft_shrink',
    'prelu',
    'tanh_shrink',
    'softmin',
    'sparse_plus',
    'sparse_sigmoid',
]
[docs]
def tanh(x: ArrayLike) -> Union[jax.Array, u.Quantity]:
    r"""
    Apply the hyperbolic tangent element-wise.

    .. math::

        \mathrm{tanh}(x) = \frac{e^x - e^{-x}}{e^x + e^{-x}}

    Parameters
    ----------
    x : ArrayLike
        Values to transform.

    Returns
    -------
    jax.Array or Quantity
        Transformed values, shaped like ``x``.
    """
    squashed = u.math.tanh(x)
    return squashed
[docs]
def softmin(x, axis=-1):
    r"""
    Softmin activation function.

    Rescales the input so that the entries along ``axis`` lie in ``[0, 1]``
    and sum to 1, giving the largest weight to the smallest input.

    .. math::

        \text{Softmin}(x_{i}) = \frac{\exp(-x_i)}{\sum_j \exp(-x_j)}

    Parameters
    ----------
    x : ArrayLike
        Input array of any shape.
    axis : int, optional
        The axis along which Softmin is computed. Every slice along this
        dimension sums to 1. Default is -1.

    Returns
    -------
    jax.Array or Quantity
        Output array with the same shape as the input.

    Notes
    -----
    The negated input is shifted by its maximum along ``axis`` before the
    exponential is taken. Softmin is invariant to such a shift, so the result
    is mathematically unchanged, but ``exp`` can no longer overflow for
    strongly negative inputs (which previously produced ``inf / inf = nan``).
    """
    negated = -x
    # Largest exponent along the axis becomes exactly 0, so exp(...) <= 1.
    shifted = negated - negated.max(axis=axis, keepdims=True)
    unnormalized = u.math.exp(shifted)
    return unnormalized / unnormalized.sum(axis, keepdims=True)
[docs]
def tanh_shrink(x):
    r"""
    Subtract the hyperbolic tangent of the input from the input itself.

    .. math::

        \text{Tanhshrink}(x) = x - \tanh(x)

    Parameters
    ----------
    x : ArrayLike
        Values to transform.

    Returns
    -------
    jax.Array or Quantity
        Element-wise result, shaped like ``x``.
    """
    squashed = u.math.tanh(x)
    return x - squashed
[docs]
def prelu(x, a=0.25):
    r"""
    Parametric Rectified Linear Unit activation function.

    Applies the element-wise function:

    .. math::

        \text{PReLU}(x) = \max(0,x) + a * \min(0,x)

    or equivalently:

    .. math::

        \text{PReLU}(x) =
        \begin{cases}
        x, & \text{ if } x \geq 0 \\
        ax, & \text{ otherwise }
        \end{cases}

    Parameters
    ----------
    x : ArrayLike
        Input array.
    a : float or ArrayLike, optional
        The negative slope coefficient. Can be a learnable parameter.
        Default is 0.25.

    Returns
    -------
    jax.Array or Quantity
        Output array with the same shape as the input.

    Notes
    -----
    When used in neural network layers, :math:`a` can be a learnable parameter
    that is optimized during training.

    The sign test is performed on the mantissa of ``x`` (as ``rrelu`` in this
    module does), so it does not depend on the unit carried by ``x``.
    """
    # Only the sign matters for the branch choice; compare the bare numbers.
    non_negative = u.get_mantissa(x) >= 0.
    return u.math.where(non_negative, x, a * x)
[docs]
def soft_shrink(x, lambd=0.5):
    r"""
    Shrink every element towards zero by ``lambd``, zeroing the middle band.

    .. math::

        \text{SoftShrinkage}(x) =
        \begin{cases}
        x - \lambda, & \text{ if } x > \lambda \\
        x + \lambda, & \text{ if } x < -\lambda \\
        0, & \text{ otherwise }
        \end{cases}

    Parameters
    ----------
    x : ArrayLike
        Input array of any shape.
    lambd : float, optional
        The :math:`\lambda` threshold; must be non-negative. Default is 0.5.

    Returns
    -------
    jax.Array or Quantity
        Shrunk values, shaped like ``x``.
    """
    # The zero filler carries the unit of ``lambd``.
    zero = u.Quantity(0., unit=u.get_unit(lambd))
    lower_branch = u.math.where(x < -lambd, x + lambd, zero)
    return u.math.where(x > lambd, x - lambd, lower_branch)
[docs]
def mish(x):
    r"""
    Mish: a self-regularized, non-monotonic activation.

    .. math::

        \text{Mish}(x) = x * \text{Tanh}(\text{Softplus}(x))

    Parameters
    ----------
    x : ArrayLike
        Input array of any shape.

    Returns
    -------
    jax.Array or Quantity
        Element-wise result, shaped like ``x``.

    References
    ----------
    .. [1] Misra, D. (2019). "Mish: A Self Regularized Non-Monotonic Activation Function."
       arXiv:1908.08681
    """
    gate = u.math.tanh(softplus(x))
    return x * gate
[docs]
def rrelu(x, lower=0.125, upper=0.3333333333333333):
    r"""
    Randomized Leaky Rectified Linear Unit activation function.

    The function is defined as:

    .. math::

        \text{RReLU}(x) =
        \begin{cases}
        x & \text{if } x \geq 0 \\
        ax & \text{ otherwise }
        \end{cases}

    where :math:`a` is randomly sampled from uniform distribution
    :math:`\mathcal{U}(\text{lower}, \text{upper})`. One slope is drawn per
    element of ``x`` on every call.

    Parameters
    ----------
    x : ArrayLike
        Input array of any shape.
    lower : float, optional
        Lower bound of the uniform distribution for sampling the negative slope.
        Default is 1/8.
    upper : float, optional
        Upper bound of the uniform distribution for sampling the negative slope.
        Default is 1/3.

    Returns
    -------
    jax.Array or Quantity
        Output array with the same shape as the input.

    References
    ----------
    .. [1] Xu, B., et al. (2015). "Empirical Evaluation of Rectified Activations
       in Convolutional Network." arXiv:1505.00853
    """
    mantissa = u.get_mantissa(x)
    # ``x.dtype`` does not exist on Python scalars; derive the dtype through
    # jax so every input accepted by ``u.math.shape`` also works here.
    slope_dtype = jax.numpy.asarray(mantissa).dtype
    a = random.uniform(lower, upper, size=u.math.shape(x), dtype=slope_dtype)
    return u.math.where(mantissa >= 0., x, a * x)
[docs]
def hard_shrink(x, lambd=0.5):
    r"""
    Keep elements whose magnitude exceeds ``lambd`` and zero out the rest.

    .. math::

        \text{HardShrink}(x) =
        \begin{cases}
        x, & \text{ if } x > \lambda \\
        x, & \text{ if } x < -\lambda \\
        0, & \text{ otherwise }
        \end{cases}

    Parameters
    ----------
    x : ArrayLike
        Input array of any shape.
    lambd : float, optional
        The :math:`\lambda` threshold. Default is 0.5.

    Returns
    -------
    jax.Array or Quantity
        Thresholded values, shaped like ``x``.
    """
    # The zero filler carries the unit of ``x``.
    zero = u.Quantity(0., unit=u.get_unit(x))
    lower_branch = u.math.where(x < -lambd, x, zero)
    return u.math.where(x > lambd, x, lower_branch)
[docs]
def relu(x: ArrayLike) -> Union[jax.Array, u.Quantity]:
    r"""
    Rectified Linear Unit: clamp negative entries to zero.

    .. math::

        \mathrm{relu}(x) = \max(x, 0)

    For differentiation the convention is:

    .. math::

        \nabla \mathrm{relu}(0) = 0

    Parameters
    ----------
    x : ArrayLike
        Values to transform.

    Returns
    -------
    jax.Array or Quantity
        Rectified values, shaped like ``x``.

    Examples
    --------
    .. code-block:: python

        >>> import jax.numpy as jnp
        >>> import brainstate
        >>> brainstate.nn.relu(jnp.array([-2., -1., -0.5, 0, 0.5, 1., 2.]))
        Array([0. , 0. , 0. , 0. , 0.5, 1. , 2. ], dtype=float32)

    See Also
    --------
    relu6 : ReLU6 activation function.
    leaky_relu : Leaky ReLU activation function.

    References
    ----------
    .. [1] For more information see "Numerical influence of ReLU'(0) on backpropagation"
       https://openreview.net/forum?id=urrcVI-_jRm
    """
    rectified = u.math.relu(x)
    return rectified
[docs]
def squareplus(x: ArrayLike, b: ArrayLike = 4) -> Union[jax.Array, u.Quantity]:
    r"""
    Squareplus: an algebraic, softplus-like rectifier.

    .. math::

        \mathrm{squareplus}(x) = \frac{x + \sqrt{x^2 + b}}{2}

    Parameters
    ----------
    x : ArrayLike
        Values to transform.
    b : ArrayLike, optional
        Smoothness parameter. Default is 4.

    Returns
    -------
    jax.Array or Quantity
        Element-wise result, shaped like ``x``.

    References
    ----------
    .. [1] So, D., et al. (2021). "Primer: Searching for Efficient Transformers
       for Language Modeling." arXiv:2112.11687
    """
    smoothed = u.math.squareplus(x, b=b)
    return smoothed
[docs]
def softplus(x: ArrayLike) -> Union[jax.Array, u.Quantity]:
    r"""
    Softplus: a smooth approximation of the rectifier.

    .. math::

        \mathrm{softplus}(x) = \log(1 + e^x)

    Parameters
    ----------
    x : ArrayLike
        Values to transform.

    Returns
    -------
    jax.Array or Quantity
        Element-wise result, shaped like ``x``.
    """
    smoothed = u.math.softplus(x)
    return smoothed
[docs]
def soft_sign(x: ArrayLike) -> Union[jax.Array, u.Quantity]:
    r"""
    Soft-sign: scale each element by one plus its magnitude.

    .. math::

        \mathrm{soft\_sign}(x) = \frac{x}{|x| + 1}

    Parameters
    ----------
    x : ArrayLike
        Values to transform.

    Returns
    -------
    jax.Array or Quantity
        Element-wise result, shaped like ``x``.
    """
    scaled = u.math.soft_sign(x)
    return scaled
[docs]
def sigmoid(x: ArrayLike) -> Union[jax.Array, u.Quantity]:
    r"""
    Logistic sigmoid, applied element-wise.

    .. math::

        \mathrm{sigmoid}(x) = \frac{1}{1 + e^{-x}}

    Parameters
    ----------
    x : ArrayLike
        Values to transform.

    Returns
    -------
    jax.Array or Quantity
        Element-wise result, shaped like ``x``.

    See Also
    --------
    log_sigmoid : Logarithm of the sigmoid function.
    """
    squashed = u.math.sigmoid(x)
    return squashed
[docs]
def silu(x: ArrayLike) -> Union[jax.Array, u.Quantity]:
    r"""
    SiLU (Sigmoid Linear Unit): the input gated by its own sigmoid.

    .. math::

        \mathrm{silu}(x) = x \cdot \mathrm{sigmoid}(x) = \frac{x}{1 + e^{-x}}

    Parameters
    ----------
    x : ArrayLike
        Values to transform.

    Returns
    -------
    jax.Array or Quantity
        Element-wise result, shaped like ``x``.

    See Also
    --------
    sigmoid : The sigmoid function.
    swish : Alias for silu.

    Notes
    -----
    ``swish`` is bound to this very function, so both names behave identically.
    """
    gated = u.math.silu(x)
    return gated


# ``swish`` is just another name for ``silu``.
swish = silu
[docs]
def log_sigmoid(x: ArrayLike) -> Union[jax.Array, u.Quantity]:
    r"""
    Logarithm of the logistic sigmoid, applied element-wise.

    .. math::

        \mathrm{log\_sigmoid}(x) = \log(\mathrm{sigmoid}(x)) = -\log(1 + e^{-x})

    Parameters
    ----------
    x : ArrayLike
        Values to transform.

    Returns
    -------
    jax.Array or Quantity
        Element-wise result, shaped like ``x``.

    See Also
    --------
    sigmoid : The sigmoid function.
    """
    log_prob = u.math.log_sigmoid(x)
    return log_prob
[docs]
def elu(x: ArrayLike, alpha: ArrayLike = 1.0) -> Union[jax.Array, u.Quantity]:
    r"""
    Exponential Linear Unit, applied element-wise.

    .. math::

        \mathrm{elu}(x) = \begin{cases}
          x, & x > 0\\
          \alpha \left(\exp(x) - 1\right), & x \le 0
        \end{cases}

    Parameters
    ----------
    x : ArrayLike
        Values to transform.
    alpha : ArrayLike, optional
        Scalar or array scaling the negative branch. Default is 1.0.

    Returns
    -------
    jax.Array or Quantity
        Element-wise result, shaped like ``x``.

    See Also
    --------
    selu : Scaled ELU activation function.
    celu : Continuously-differentiable ELU activation function.
    """
    activated = u.math.elu(x, alpha=alpha)
    return activated
[docs]
def leaky_relu(x: ArrayLike, negative_slope: ArrayLike = 1e-2) -> Union[jax.Array, u.Quantity]:
    r"""
    Leaky ReLU: identity for non-negative inputs, a small slope otherwise.

    .. math::

        \mathrm{leaky\_relu}(x) = \begin{cases}
          x, & x \ge 0\\
          \alpha x, & x < 0
        \end{cases}

    where :math:`\alpha` = :code:`negative_slope`.

    Parameters
    ----------
    x : ArrayLike
        Values to transform.
    negative_slope : ArrayLike, optional
        Scalar or array giving the slope of the negative branch.
        Default is 0.01.

    Returns
    -------
    jax.Array or Quantity
        Element-wise result, shaped like ``x``.

    See Also
    --------
    relu : Standard ReLU activation function.
    prelu : Parametric ReLU with learnable slope.
    """
    activated = u.math.leaky_relu(x, negative_slope=negative_slope)
    return activated
def _hard_tanh(x, min_val=- 1.0, max_val=1.0):
    """Clamp ``x`` element-wise: below ``min_val`` -> ``min_val``, above ``max_val`` -> ``max_val``."""
    jnp = jax.numpy
    # Apply the lower bound first; the upper bound then takes precedence,
    # which matches a nested ``where`` that tests the upper bound outermost.
    floored = jnp.where(x < min_val, min_val, x)
    return jnp.where(x > max_val, max_val, floored)
[docs]
def hard_tanh(
    x: ArrayLike,
    min_val: float = - 1.0,
    max_val: float = 1.0
) -> Union[jax.Array, u.Quantity]:
    r"""
    Hard hyperbolic tangent: clamp the input to ``[min_val, max_val]``.

    With the default bounds this is:

    .. math::

        \mathrm{hard\_tanh}(x) = \begin{cases}
          -1, & x < -1\\
          x, & -1 \le x \le 1\\
          1, & 1 < x
        \end{cases}

    Parameters
    ----------
    x : ArrayLike
        Values to transform.
    min_val : float, optional
        Lower end of the linear region. Default is -1.
    max_val : float, optional
        Upper end of the linear region. Default is 1.

    Returns
    -------
    jax.Array or Quantity
        Clamped values, shaped like ``x``.
    """
    quantity = u.Quantity(x)
    unit = quantity.unit
    # Express both bounds in the unit of ``x`` so the clamp can run on
    # bare mantissas.
    lower = u.Quantity(min_val).to(unit).mantissa
    upper = u.Quantity(max_val).to(unit).mantissa
    clamped = _hard_tanh(quantity.mantissa, min_val=lower, max_val=upper)
    return u.maybe_decimal(clamped * unit)
[docs]
def celu(x: ArrayLike, alpha: ArrayLike = 1.0) -> Union[jax.Array, u.Quantity]:
    r"""
    Continuously-differentiable Exponential Linear Unit.

    .. math::

        \mathrm{celu}(x) = \begin{cases}
          x, & x > 0\\
          \alpha \left(\exp(\frac{x}{\alpha}) - 1\right), & x \le 0
        \end{cases}

    Parameters
    ----------
    x : ArrayLike
        Values to transform.
    alpha : ArrayLike, optional
        Scalar or array controlling the smoothness. Default is 1.0.

    Returns
    -------
    jax.Array or Quantity
        Element-wise result, shaped like ``x``.

    References
    ----------
    .. [1] Barron, J. T. (2017). "Continuously Differentiable Exponential Linear Units."
       arXiv:1704.07483
    """
    activated = u.math.celu(x, alpha=alpha)
    return activated
[docs]
def selu(x: ArrayLike) -> Union[jax.Array, u.Quantity]:
    r"""
    Scaled Exponential Linear Unit, applied element-wise.

    .. math::

        \mathrm{selu}(x) = \lambda \begin{cases}
          x, & x > 0\\
          \alpha e^x - \alpha, & x \le 0
        \end{cases}

    where :math:`\lambda = 1.0507009873554804934193349852946` and
    :math:`\alpha = 1.6732632423543772848170429916717`.

    Parameters
    ----------
    x : ArrayLike
        Values to transform.

    Returns
    -------
    jax.Array or Quantity
        Element-wise result, shaped like ``x``.

    See Also
    --------
    elu : Exponential Linear Unit activation function.

    References
    ----------
    .. [1] Klambauer, G., et al. (2017). "Self-Normalizing Neural Networks."
       NeurIPS 2017.
    """
    activated = u.math.selu(x)
    return activated
[docs]
def gelu(x: ArrayLike, approximate: bool = True) -> Union[jax.Array, u.Quantity]:
    r"""
    Gaussian Error Linear Unit, applied element-wise.

    With ``approximate=False`` the exact form is used:

    .. math::

        \mathrm{gelu}(x) = \frac{x}{2} \left(1 + \mathrm{erf} \left(
          \frac{x}{\sqrt{2}} \right) \right)

    With ``approximate=True`` the tanh approximation is used:

    .. math::

        \mathrm{gelu}(x) = \frac{x}{2} \left(1 + \mathrm{tanh} \left(
          \sqrt{\frac{2}{\pi}} \left(x + 0.044715 x^3 \right) \right) \right)

    Parameters
    ----------
    x : ArrayLike
        Values to transform.
    approximate : bool, optional
        Select the approximate (True) or exact (False) formulation.
        Default is True.

    Returns
    -------
    jax.Array or Quantity
        Element-wise result, shaped like ``x``.

    References
    ----------
    .. [1] Hendrycks, D., & Gimpel, K. (2016). "Gaussian Error Linear Units (GELUs)."
       arXiv:1606.08415
    """
    activated = u.math.gelu(x, approximate=approximate)
    return activated
[docs]
def glu(x: ArrayLike, axis: int = -1) -> Union[jax.Array, u.Quantity]:
    r"""
    Gated Linear Unit: first half of the input gated by the sigmoid of the second.

    .. math::

        \mathrm{glu}(x) =  x\left[\ldots, 0:\frac{n}{2}, \ldots\right] \cdot
          \mathrm{sigmoid} \left( x\left[\ldots, \frac{n}{2}:n, \ldots\right]
            \right)

    The input is split in two along ``axis``, whose size must be divisible
    by two.

    Parameters
    ----------
    x : ArrayLike
        Input array; its ``axis`` dimension must be divisible by 2.
    axis : int, optional
        Axis along which the input is split. Default is -1.

    Returns
    -------
    jax.Array or Quantity
        Array shaped like the input, except that the ``axis`` dimension
        is halved.

    See Also
    --------
    sigmoid : The sigmoid activation function.
    """
    gated = u.math.glu(x, axis=axis)
    return gated
[docs]
def log_softmax(x: ArrayLike,
                axis: int | tuple[int, ...] | None = -1,
                where: ArrayLike | None = None) -> Union[jax.Array, u.Quantity]:
    r"""
    Logarithm of the softmax function.

    The outputs lie in the range :math:`[-\infty, 0)`.

    .. math ::

        \mathrm{log\_softmax}(x)_i = \log \left( \frac{\exp(x_i)}{\sum_j \exp(x_j)}
        \right)

    Parameters
    ----------
    x : ArrayLike
        Input array.
    axis : int or tuple of int, optional
        Axis or axes over which the log-softmax is taken. Default is -1.
    where : ArrayLike, optional
        Elements to include in the log-softmax computation.

    Returns
    -------
    jax.Array or Quantity
        Array shaped like the input.

    See Also
    --------
    softmax : The softmax function.
    """
    options = dict(axis=axis, where=where)
    return jax.nn.log_softmax(x, **options)
[docs]
def softmax(x: ArrayLike,
            axis: int | tuple[int, ...] | None = -1,
            where: ArrayLike | None = None) -> Union[jax.Array, u.Quantity]:
    r"""
    Softmax: map the input to values in :math:`[0, 1]` that sum to :math:`1`
    along :code:`axis`.

    .. math ::

        \mathrm{softmax}(x) = \frac{\exp(x_i)}{\sum_j \exp(x_j)}

    Parameters
    ----------
    x : ArrayLike
        Input array.
    axis : int or tuple of int, optional
        Axis or axes over which the softmax is taken; the output summed
        across these dimensions equals :math:`1`. Default is -1.
    where : ArrayLike, optional
        Elements to include in the softmax computation.

    Returns
    -------
    jax.Array or Quantity
        Array shaped like the input.

    See Also
    --------
    log_softmax : Logarithm of the softmax function.
    softmin : Softmin activation function.
    """
    options = dict(axis=axis, where=where)
    return jax.nn.softmax(x, **options)
[docs]
def standardize(x: ArrayLike,
                axis: int | tuple[int, ...] | None = -1,
                variance: ArrayLike | None = None,
                epsilon: ArrayLike = 1e-5,
                where: ArrayLike | None = None) -> Union[jax.Array, u.Quantity]:
    r"""
    Standardize an array to zero mean and unit variance.

    The mean is subtracted and the result is divided by the standard
    deviation :math:`\sqrt{\mathrm{variance}}`.

    Parameters
    ----------
    x : ArrayLike
        Input array.
    axis : int or tuple of int, optional
        Axis or axes over which mean and variance are computed.
        Default is -1.
    variance : ArrayLike, optional
        Pre-computed variance. When None, it is computed from ``x``.
    epsilon : ArrayLike, optional
        Small constant added to the variance to avoid division by zero.
        Default is 1e-5.
    where : ArrayLike, optional
        Elements to include in the computation.

    Returns
    -------
    jax.Array or Quantity
        Standardized array shaped like the input.
    """
    options = dict(axis=axis, variance=variance, epsilon=epsilon, where=where)
    return jax.nn.standardize(x, **options)
[docs]
def one_hot(x: Any,
            num_classes: int, *,
            dtype: Any = jax.numpy.float_,
            axis: Union[int, Sequence[int]] = -1) -> Union[jax.Array, u.Quantity]:
    """
    One-hot encode integer indices.

    Every index in ``x`` becomes a length-``num_classes`` vector of zeros
    with a single one at position ``index``. Indices outside
    ``[0, num_classes)`` are encoded as all zeros.

    Parameters
    ----------
    x : ArrayLike
        A tensor of indices.
    num_classes : int
        Size of the one-hot dimension.
    dtype : dtype, optional
        The dtype of the returned values. Default is ``jnp.float_``.
    axis : int or Sequence of int, optional
        The axis or axes along which the function should be computed.
        Default is -1.

    Returns
    -------
    jax.Array or Quantity
        One-hot encoded array.

    Examples
    --------
    .. code-block:: python

        >>> import jax.numpy as jnp
        >>> import brainstate
        >>> brainstate.nn.one_hot(jnp.array([0, 1, 2]), 3)
        Array([[1., 0., 0.],
               [0., 1., 0.],
               [0., 0., 1.]], dtype=float32)

        >>> # Indices outside the range are encoded as zeros
        >>> brainstate.nn.one_hot(jnp.array([-1, 3]), 3)
        Array([[0., 0., 0.],
               [0., 0., 0.]], dtype=float32)
    """
    encoded = jax.nn.one_hot(x, num_classes, dtype=dtype, axis=axis)
    return encoded
[docs]
def relu6(x: ArrayLike) -> Union[jax.Array, u.Quantity]:
    r"""
    ReLU6: a rectifier whose output is capped at 6.

    .. math::

        \mathrm{relu6}(x) = \min(\max(x, 0), 6)

    For differentiation the conventions are:

    .. math::

        \nabla \mathrm{relu}(0) = 0

    and

    .. math::

        \nabla \mathrm{relu}(6) = 0

    Parameters
    ----------
    x : ArrayLike
        Values to transform.

    Returns
    -------
    jax.Array or Quantity
        Element-wise result, shaped like ``x``.

    See Also
    --------
    relu : Standard ReLU activation function.
    """
    capped = u.math.relu6(x)
    return capped
[docs]
def hard_sigmoid(x: ArrayLike) -> Union[jax.Array, u.Quantity]:
    r"""
    Hard sigmoid: a piecewise-linear stand-in for the sigmoid.

    .. math::

        \mathrm{hard\_sigmoid}(x) = \frac{\mathrm{relu6}(x + 3)}{6}

    Parameters
    ----------
    x : ArrayLike
        Values to transform.

    Returns
    -------
    jax.Array or Quantity
        Element-wise result, shaped like ``x``.

    See Also
    --------
    relu6 : ReLU6 activation function.
    sigmoid : Standard sigmoid function.
    """
    squashed = u.math.hard_sigmoid(x)
    return squashed
[docs]
def hard_silu(x: ArrayLike) -> Union[jax.Array, u.Quantity]:
    r"""
    Hard SiLU (Swish): the input gated by its hard sigmoid.

    .. math::

        \mathrm{hard\_silu}(x) = x \cdot \mathrm{hard\_sigmoid}(x)

    Parameters
    ----------
    x : ArrayLike
        Values to transform.

    Returns
    -------
    jax.Array or Quantity
        Element-wise result, shaped like ``x``.

    See Also
    --------
    hard_sigmoid : Hard sigmoid activation function.
    silu : Standard SiLU activation function.
    hard_swish : Alias for hard_silu.

    Notes
    -----
    ``hard_swish`` is bound to this very function, so both names behave
    identically.
    """
    gated = u.math.hard_silu(x)
    return gated


# ``hard_swish`` is just another name for ``hard_silu``.
hard_swish = hard_silu
[docs]
def sparse_plus(x: ArrayLike) -> Union[jax.Array, u.Quantity]:
    r"""
    Sparse plus: the sparse twin of softplus.

    .. math::

        \mathrm{sparse\_plus}(x) = \begin{cases}
          0, & x \leq -1\\
          \frac{1}{4}(x+1)^2, & -1 < x < 1 \\
          x, & 1 \leq x
        \end{cases}

    The output is exactly zero for inputs below -1 and linear for inputs
    above 1, while staying smooth, convex, and monotonic in between.

    Parameters
    ----------
    x : ArrayLike
        Values to transform.

    Returns
    -------
    jax.Array or Quantity
        Element-wise result, shaped like ``x``.

    See Also
    --------
    sparse_sigmoid : Derivative of sparse_plus.
    softplus : Standard softplus activation function.
    """
    activated = u.math.sparse_plus(x)
    return activated
[docs]
def sparse_sigmoid(x: ArrayLike) -> Union[jax.Array, u.Quantity]:
    r"""
    Sparse sigmoid: the sparse twin of the sigmoid.

    .. math::

        \mathrm{sparse\_sigmoid}(x) = \begin{cases}
          0, & x \leq -1\\
          \frac{1}{2}(x+1), & -1 < x < 1 \\
          1, & 1 \leq x
        \end{cases}

    The output is exactly 0 for inputs below -1, exactly 1 for inputs above 1,
    and linear in between. It is the derivative of `sparse_plus`.

    Parameters
    ----------
    x : ArrayLike
        Values to transform.

    Returns
    -------
    jax.Array or Quantity
        Element-wise result, shaped like ``x``.

    See Also
    --------
    sigmoid : Standard sigmoid activation function.
    sparse_plus : Sparse plus activation function.

    References
    ----------
    .. [1] Martins, A. F. T., & Astudillo, R. F. (2016). "From Softmax to Sparsemax:
       A Sparse Model of Attention and Multi-Label Classification."
       In ICML. See also "Learning with Fenchel-Young Losses", arXiv:1901.02324
    """
    activated = u.math.sparse_sigmoid(x)
    return activated