Activation Functions
brainunit.math provides 20+ unit-aware activation functions for neural networks.
These functions fall into two categories:
Piecewise-linear (keep unit): relu, leaky_relu. These work directly with Quantity inputs.
Nonlinear (require unitless): sigmoid, gelu, softplus, tanh, etc. These require dimensionless input but support unit_to_scale for automatic conversion.
import brainunit as u
import jax.numpy as jnp
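Activations That Keep Unit
Piecewise-linear activations preserve the input unit because they only compare the input with zero and multiply it by dimensionless slopes; no transcendental functions are involved. Whether a voltage is positive does not depend on whether it is written in volts or millivolts, so the output can carry the input's unit.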
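One way to see why the unit survives: for any positive factor c, relu(c·x) = c·relu(x), so changing the unit (for example volts to millivolts, c = 1000) commutes with the activation. The plain NumPy sketch below checks this on bare magnitudes; relu_ref is an illustrative helper, not part of brainunit.

```python
import numpy as np

def relu_ref(x):
    # Reference ReLU on plain magnitudes: max(x, 0)
    return np.maximum(x, 0.0)

x_volts = np.array([-2.0, -1.0, 0.0, 1.0, 2.0])  # magnitudes in V
x_millivolts = 1000.0 * x_volts                  # the same values expressed in mV

# Converting V -> mV before or after the activation gives the same result,
# which is what lets relu return a Quantity in the input's unit.
before = relu_ref(x_millivolts)
after = 1000.0 * relu_ref(x_volts)
print(np.allclose(before, after))  # True
```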
relu (Rectified Linear Unit)
x = jnp.array([-2., -1., 0., 1., 2.]) * u.volt
print('relu:', u.math.relu(x)) # [0, 0, 0, 1, 2] V
relu: [0. 0. 0. 1. 2.] V
leaky_relu (Leaky ReLU)
# Allows small negative values (default slope = 0.01)
print('leaky_relu:', u.math.leaky_relu(x))
leaky_relu: [-0.02 -0.01 0. 1. 2. ] V
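The printed values follow the usual leaky ReLU rule: x for x ≥ 0 and slope · x otherwise, with the 0.01 default slope noted in the comment above. Because the slope is a plain number, the unit is untouched. A NumPy sketch of that rule on the voltage magnitudes (leaky_relu_ref and its negative_slope argument are hypothetical names for illustration):

```python
import numpy as np

def leaky_relu_ref(x, negative_slope=0.01):
    # Pass non-negative inputs through; scale negative inputs by a small,
    # dimensionless slope, so the output keeps the input's unit.
    return np.where(x >= 0, x, negative_slope * x)

x = np.array([-2.0, -1.0, 0.0, 1.0, 2.0])  # magnitudes in V
print(leaky_relu_ref(x))  # values -0.02, -0.01, 0, 1, 2, as in the output above
```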
Activations Requiring Dimensionless Input
Most activation functions (sigmoid, tanh, gelu, etc.) are built from exponentials or other transcendental functions, which are only meaningful for dimensionless arguments. For example, the series exp(x) = 1 + x + x²/2 + … would add terms with units V, V², and so on if x carried volts.
There are two ways to use them:
Dimensionless input: pass plain arrays, or convert the Quantity to a dimensionless value first.
unit_to_scale parameter: the Quantity is converted to a dimensionless value automatically, using the given unit as the scale.
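Why the scaling unit has to be stated explicitly: unlike ReLU, a sigmoid does not commute with rescaling, so the same physical voltage produces a different output depending on which unit its magnitude is taken in. A NumPy sketch with a hypothetical sigmoid_ref helper:

```python
import numpy as np

def sigmoid_ref(z):
    # Logistic sigmoid on a plain (dimensionless) number: 1 / (1 + exp(-z))
    return 1.0 / (1.0 + np.exp(-z))

v_in_volts = 0.5         # 0.5 V, magnitude taken in volts
v_in_millivolts = 500.0  # the same 0.5 V, magnitude taken in millivolts

print(sigmoid_ref(v_in_volts))       # about 0.622
print(sigmoid_ref(v_in_millivolts))  # about 1.0 (saturated)
```

Choosing unit_to_scale (for example u.volt versus u.mV) therefore decides which dimensionless number the activation actually sees.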
sigmoid
# Method 1: Dimensionless input
x_unitless = jnp.array([-2., -1., 0., 1., 2.])
print('sigmoid (unitless):', u.math.sigmoid(x_unitless))
sigmoid (unitless): [0.11920292 0.26894143 0.5 0.7310586 0.880797 ]
# Method 2: unit_to_scale — auto-converts Quantity to dimensionless
x_volts = jnp.array([-2., -1., 0., 1., 2.]) * u.volt
print('sigmoid (with unit_to_scale):', u.math.sigmoid(x_volts, unit_to_scale=u.volt))
sigmoid (with unit_to_scale): [0.11920292 0.26894143 0.5 0.7310586 0.880797 ]
# Without unit_to_scale, passing a Quantity raises an error
try:
    u.math.sigmoid(x_volts)
except Exception as e:
    print(type(e).__name__, ':', str(e)[:120])
TypeError : sigmoid requires a dimensionless "x" when "unit_to_scale" is not provided. Got Quantity(unit=V, dim=m^2 kg s^-3 A^-1). P
tanh
print('tanh (unitless):', u.math.tanh(x_unitless))
print('tanh (unit_to_scale):', u.math.tanh(x_volts, unit_to_scale=u.volt))
tanh (unitless): [-0.9640276 -0.7615942 0. 0.7615942 0.9640276]
tanh (unit_to_scale): [-0.9640276 -0.7615942 0. 0.7615942 0.9640276]
gelu (Gaussian Error Linear Unit)
print('gelu:', u.math.gelu(x_unitless))
print('gelu (unit_to_scale):', u.math.gelu(x_volts, unit_to_scale=u.volt))
gelu: [-0.04540235 -0.15880796 0. 0.841192 1.9545977 ]
gelu (unit_to_scale): [-0.04540235 -0.15880796 0. 0.841192 1.9545977 ]
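The gelu values above agree with the tanh approximation of GELU, 0.5·x·(1 + tanh(√(2/π)·(x + 0.044715·x³))), rather than the exact form x·Φ(x), where Φ is the standard normal CDF; on these inputs the two differ by roughly 1e-4. Which form is used is inferred here from the printed numbers, not from the brainunit documentation. A NumPy sketch of both (gelu_tanh and gelu_exact are illustrative names):

```python
import math
import numpy as np

def gelu_tanh(x):
    # Tanh approximation of GELU
    c = math.sqrt(2.0 / math.pi)
    return 0.5 * x * (1.0 + np.tanh(c * (x + 0.044715 * x**3)))

def gelu_exact(x):
    # Exact GELU: x * Phi(x), with Phi the standard normal CDF via math.erf
    return np.array([xi * 0.5 * (1.0 + math.erf(xi / math.sqrt(2.0))) for xi in x])

x = np.array([-2.0, -1.0, 0.0, 1.0, 2.0])
print(gelu_tanh(x))   # close to the gelu output above
print(gelu_exact(x))  # differs from it by roughly 1e-4
```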
softplus
print('softplus:', u.math.softplus(x_unitless))
print('softplus (unit_to_scale):', u.math.softplus(x_volts, unit_to_scale=u.volt))
softplus: [0.126928 0.3132617 0.6931472 1.3132617 2.126928 ]
softplus (unit_to_scale): [0.126928 0.3132617 0.6931472 1.3132617 2.126928 ]
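softplus(x) = log(1 + eˣ) is a smooth version of ReLU; at x = 0 it equals log 2 ≈ 0.6931, matching the middle value above. A NumPy sketch using np.logaddexp(0, x), which computes the same quantity without overflowing for large x (softplus_ref is an illustrative name):

```python
import numpy as np

def softplus_ref(x):
    # log(1 + exp(x)) computed as log(exp(0) + exp(x)) for numerical stability
    return np.logaddexp(0.0, x)

x = np.array([-2.0, -1.0, 0.0, 1.0, 2.0])
print(softplus_ref(x))
# For large inputs the stable form stays finite (about 1000.0),
# whereas np.log1p(np.exp(1000.0)) would overflow to inf.
print(softplus_ref(np.array([1000.0])))
```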
silu / swish (Sigmoid Linear Unit)
print('silu:', u.math.silu(x_unitless))
print('silu (unit_to_scale):', u.math.silu(x_volts, unit_to_scale=u.volt))
silu: [-0.23840584 -0.26894143 0. 0.7310586 1.761594 ]
silu (unit_to_scale): [-0.23840584 -0.26894143 0. 0.7310586 1.761594 ]
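SiLU multiplies the input by its own sigmoid, silu(x) = x · sigmoid(x). Since the sigmoid factor needs a dimensionless argument, so does silu. A NumPy sketch that reproduces the values above (silu_ref is an illustrative name):

```python
import numpy as np

def silu_ref(x):
    # x * sigmoid(x), written as x / (1 + exp(-x)), on dimensionless values
    return x / (1.0 + np.exp(-x))

x = np.array([-2.0, -1.0, 0.0, 1.0, 2.0])
print(silu_ref(x))
```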
elu (Exponential Linear Unit)
print('elu:', u.math.elu(x_unitless))
print('elu (unit_to_scale):', u.math.elu(x_volts, unit_to_scale=u.volt))
elu: [-0.86466473 -0.63212055 0. 1. 2. ]
elu (unit_to_scale): [-0.86466473 -0.63212055 0. 1. 2. ]
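ELU is the identity for positive inputs and alpha·(eˣ - 1) for the rest, so negative outputs approach -alpha smoothly; alpha = 1 reproduces the values above. A NumPy sketch with a hypothetical elu_ref helper:

```python
import numpy as np

def elu_ref(x, alpha=1.0):
    # Identity for x > 0, alpha * (exp(x) - 1) otherwise.
    # np.minimum keeps expm1 from overflowing in the branch np.where discards.
    return np.where(x > 0, x, alpha * np.expm1(np.minimum(x, 0.0)))

x = np.array([-2.0, -1.0, 0.0, 1.0, 2.0])
print(elu_ref(x))
```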
More Activations
# All these work the same way with unit_to_scale
print('celu:', u.math.celu(x_unitless))
print('selu:', u.math.selu(x_unitless))
print('mish:', u.math.mish(x_unitless))
print('hard_sigmoid:', u.math.hard_sigmoid(x_unitless))
print('hard_tanh:', u.math.hard_tanh(x_unitless))
print('squareplus:', u.math.squareplus(x_unitless))
print('soft_sign:', u.math.soft_sign(x_unitless))
print('log_sigmoid:', u.math.log_sigmoid(x_unitless))
celu: [-0.86466473 -0.63212055 0. 1. 2. ]
selu: [-1.5201665 -1.1113307 0. 1.050701 2.101402 ]
mish: [-0.25250146 -0.30340144 0. 0.8650985 1.943959 ]
hard_sigmoid: [0.16666667 0.33333334 0.5 0.6666667 0.8333334 ]
hard_tanh: [-1. -1. 0. 1. 1.]
squareplus: [0.41421354 0.618034 1. 1.618034 2.4142137 ]
soft_sign: [-0.6666667 -0.5 0. 0.5 0.6666667]
log_sigmoid: [-2.126928 -1.3132617 -0.6931472 -0.3132617 -0.126928 ]
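For reference, the outputs above agree with the following standard formulas (celu with alpha = 1 coincides with elu, squareplus uses b = 4, and selu uses the usual fixed constants). The sketch is plain NumPy on dimensionless values; the helper names are illustrative only.

```python
import numpy as np

x = np.array([-2.0, -1.0, 0.0, 1.0, 2.0])  # dimensionless inputs

def softplus_ref(z):
    # log(1 + exp(z)), computed stably
    return np.logaddexp(0.0, z)

def elu_ref(z, alpha=1.0):
    # Identity for z > 0, alpha * (exp(z) - 1) otherwise
    return np.where(z > 0, z, alpha * np.expm1(np.minimum(z, 0.0)))

# Standard SELU constants
SELU_ALPHA = 1.6732632423543772
SELU_SCALE = 1.0507009873554805

refs = {
    "celu": elu_ref(x, alpha=1.0),                     # equals elu when alpha = 1
    "selu": SELU_SCALE * elu_ref(x, alpha=SELU_ALPHA),
    "mish": x * np.tanh(softplus_ref(x)),              # x * tanh(softplus(x))
    "hard_sigmoid": np.clip(x + 3.0, 0.0, 6.0) / 6.0,  # relu6(x + 3) / 6
    "hard_tanh": np.clip(x, -1.0, 1.0),                # clamp to [-1, 1]
    "squareplus": (x + np.sqrt(x**2 + 4.0)) / 2.0,     # b = 4
    "soft_sign": x / (1.0 + np.abs(x)),
    "log_sigmoid": -softplus_ref(-x),                  # log(sigmoid(x))
}
for name, values in refs.items():
    print(name, values)
```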
Practical Example: Unit-Aware Neural Network Layer
In scientific applications, you may want a neural network layer that respects physical units, for example one that passes membrane voltages through an activation function.
# Simulate a simple neural activation
# Membrane potentials in millivolts
V_membrane = jnp.array([-80., -65., -50., -30., 0., 20.]) * u.mV
# ReLU-based firing rate (keeps unit)
# Neurons fire only for positive membrane potential
firing_rate_relu = u.math.relu(V_membrane)
print('ReLU output:', firing_rate_relu)
# Sigmoid-based firing probability (dimensionless, 0 to 1)
# Convert mV to dimensionless using unit_to_scale
firing_prob = u.math.sigmoid(V_membrane, unit_to_scale=u.mV)
print('Sigmoid probability:', firing_prob)
ReLU output:
[ 0. 0. 0. 0. 0. 20.] mV
Sigmoid probability: [1.8048513e-35 5.9000906e-29 1.9287499e-22 9.3576236e-14 5.0000000e-01
1.0000000e+00]
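With unit_to_scale=u.mV, the sigmoid sees raw millivolt numbers such as -65 and 20, so it saturates to almost exactly 0 or 1 everywhere except at 0 mV. A more graded response usually shifts and scales the voltage first, sigmoid((V - V_half) / k). The sketch below does this on plain millivolt magnitudes with NumPy; V_half = -50 mV and k = 5 mV are hypothetical values chosen for illustration, not taken from brainunit.

```python
import numpy as np

V_mV = np.array([-80.0, -65.0, -50.0, -30.0, 0.0, 20.0])  # membrane potentials, in mV

V_half_mV = -50.0  # hypothetical half-activation voltage, in mV
k_mV = 5.0         # hypothetical slope factor, in mV

# (V - V_half) / k is a ratio of two millivolt values, hence dimensionless
z = (V_mV - V_half_mV) / k_mV
firing_prob = 1.0 / (1.0 + np.exp(-z))
print(firing_prob)  # 0.5 at -50 mV, increasing smoothly with voltage
```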
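Summary
| Category | Functions | With Units? | With unit_to_scale? |
|---|---|---|---|
| Keep unit | relu, leaky_relu | Yes | N/A |
| Require unitless | sigmoid, tanh, gelu, softplus, silu, elu, celu, selu, mish, hard_sigmoid, hard_tanh, squareplus, soft_sign, log_sigmoid, etc. | No (dimensionless only) | Yes |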