# Source code for brainpy_state._base

# Copyright 2025 BrainX Ecosystem Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================


from typing import Any, Union, TypeVar, Callable, Optional

import brainstate
import braintools
import numpy as np
from brainstate.mixin import ParamDescriber
from brainstate.typing import ArrayLike, Size

# Generic type variable used by ``Dynamics.align_pre`` so the return type
# matches the concrete dynamics type passed in.
T = TypeVar('T')

# Public API of this module.
__all__ = [
    'Dynamics', 'Neuron', 'Synapse',
]


def _input_label_start(label: str):
    # unify the input label repr.
    return f'{label} // '


def _input_label_repr(name: str, label: Optional[str] = None):
    # unify the input label repr.
    return name if label is None else (_input_label_start(label) + str(name))


class Dynamics(brainstate.nn.Dynamics):
    """Base class for dynamics models with current/delta input registries.

    Extends ``brainstate.nn.Dynamics`` with two lazily created registries of
    external inputs:

    - *current inputs*: direct input currents, applied and summed through
      :meth:`sum_current_inputs`;
    - *delta inputs*: instantaneous state changes (dX/dt contributions),
      applied and summed through :meth:`sum_delta_inputs`.

    Parameters
    ----------
    in_size : Size
        Size of the input to the model.
    name : str, optional
        Name identifier; auto-generated when ``None``.
    """
    __module__ = 'brainpy.state'

    def __init__(self, in_size: Size, name: Optional[str] = None):
        super().__init__(name=name, in_size=in_size)
        # Both registries stay ``None`` until the first input is added.
        self._current_inputs = None
        self._delta_inputs = None

    @property
    def current_inputs(self):
        r"""Dictionary of registered current inputs, or ``None`` when empty.

        Current inputs represent direct input currents flowing into the model.

        Returns
        -------
        dict or None
            Mapping from registry keys to current-input functions or values;
            ``None`` if no current input has been registered.

        See Also
        --------
        add_current_input : Register a new current input
        sum_current_inputs : Apply and sum all current inputs
        delta_inputs : Dictionary of instantaneous change inputs
        """
        return self._current_inputs

    @property
    def delta_inputs(self):
        r"""Dictionary of registered delta inputs, or ``None`` when empty.

        Delta inputs represent instantaneous changes to state variables
        (dX/dt contributions).

        Returns
        -------
        dict or None
            Mapping from registry keys to delta-input functions or values;
            ``None`` if no delta input has been registered.

        See Also
        --------
        add_delta_input : Register a new delta input
        sum_delta_inputs : Apply and sum all delta inputs
        current_inputs : Dictionary of direct current inputs
        """
        return self._delta_inputs

[docs] def add_current_input( self, key: str, inp: Union[Callable, ArrayLike], label: Optional[str] = None ): r""" Add a current input function or array to the dynamics model. Current inputs represent direct input currents that can be accessed during model updates through the `sum_current_inputs()` method. Parameters ---------- key : str Unique identifier for this current input. Used to retrieve or reference the input later. inp : Union[Callable, ArrayLike] The input data or function that generates input data. - If callable: Will be called during updates with arguments passed to `sum_current_inputs()` - If array-like: Will be applied once and then automatically removed from available inputs label : Optional[str], default=None Optional grouping label for the input. When provided, allows selective processing of inputs by label in `sum_current_inputs()`. Raises ------ ValueError If the key has already been used for a different current input. Notes ----- - Inputs with the same label can be processed together using the `label` parameter in `sum_current_inputs()`. - Non-callable inputs are consumed when used (removed after first use). - Callable inputs persist and can be called repeatedly. See Also -------- sum_current_inputs : Sum all current inputs matching a given label add_delta_input : Add a delta input function or array """ key = _input_label_repr(key, label) if self._current_inputs is None: self._current_inputs = dict() if key in self._current_inputs: if id(self._current_inputs[key]) != id(inp): raise ValueError(f'Key "{key}" has been defined and used in the current inputs of {self}.') self._current_inputs[key] = inp
[docs] def add_delta_input( self, key: str, inp: Union[Callable, ArrayLike], label: Optional[str] = None ): r""" Add a delta input function or array to the dynamics model. Delta inputs represent instantaneous changes to the model state (i.e., dX/dt contributions). This method registers a function or array that provides delta inputs which will be accessible during model updates through the `sum_delta_inputs()` method. Parameters ---------- key : str Unique identifier for this delta input. Used to retrieve or reference the input later. inp : Union[Callable, ArrayLike] The input data or function that generates input data. - If callable: Will be called during updates with arguments passed to `sum_delta_inputs()` - If array-like: Will be applied once and then automatically removed from available inputs label : Optional[str], default=None Optional grouping label for the input. When provided, allows selective processing of inputs by label in `sum_delta_inputs()`. Raises ------ ValueError If the key has already been used for a different delta input. Notes ----- - Inputs with the same label can be processed together using the `label` parameter in `sum_delta_inputs()`. - Non-callable inputs are consumed when used (removed after first use). - Callable inputs persist and can be called repeatedly. See Also -------- sum_delta_inputs : Sum all delta inputs matching a given label add_current_input : Add a current input function or array """ key = _input_label_repr(key, label) if self._delta_inputs is None: self._delta_inputs = dict() if key in self._delta_inputs: if id(self._delta_inputs[key]) != id(inp): raise ValueError(f'Key "{key}" has been defined and used.') self._delta_inputs[key] = inp
[docs] def get_input(self, key: str): r""" Get a registered input function by its key. Retrieves either a current input or a delta input function that was previously registered with the given key. This method checks both current_inputs and delta_inputs dictionaries for the specified key. Parameters ---------- key : str The unique identifier used when the input function was registered. Returns ------- Callable or ArrayLike The input function or array associated with the given key. Raises ------ ValueError If no input function is found with the specified key in either current_inputs or delta_inputs. See Also -------- add_current_input : Register a current input function add_delta_input : Register a delta input function Examples -------- >>> model = Dynamics(10) >>> model.add_current_input('stimulus', lambda t: np.sin(t)) >>> input_func = model.get_input('stimulus') >>> input_func(0.5) # Returns sin(0.5) """ if self._current_inputs is not None and key in self._current_inputs: return self._current_inputs[key] elif self._delta_inputs is not None and key in self._delta_inputs: return self._delta_inputs[key] else: raise ValueError(f'Input key {key} is not in current/delta inputs of the module {self}.')
[docs] def sum_current_inputs( self, init: Any, *args, label: Optional[str] = None, pop: bool = True, **kwargs ): r""" Summarize all current inputs by applying and summing all registered current input functions. This method iterates through all registered current input functions (from `.current_inputs`) and applies them to calculate the total input current for the dynamics model. It adds all results to the initial value provided. Parameters ---------- init : Any The initial value to which all current inputs will be added. *args Variable length argument list passed to each current input function. label : Optional[str], default=None If provided, only process current inputs with this label prefix. When None, process all current inputs regardless of label. **kwargs Arbitrary keyword arguments passed to each current input function. Returns ------- Any The initial value plus all applicable current inputs summed together. Notes ----- - Non-callable current inputs are applied once and then automatically removed from the current_inputs dictionary. - Callable current inputs remain registered for subsequent calls. - When a label is provided, only current inputs with keys starting with that label are applied. """ if self._current_inputs is None: return init if label is None: filter_fn = lambda k: True else: label_repr = _input_label_start(label) filter_fn = lambda k: k.startswith(label_repr) for key in tuple(self._current_inputs.keys()): if filter_fn(key): out = self._current_inputs[key] if callable(out): try: init = init + out(*args, **kwargs) except Exception as e: raise ValueError( f'Error in current input value {key}: {out}\n' f'Error: {e}' ) from e else: try: init = init + out except Exception as e: raise ValueError( f'Error in current input value {key}: {out}\n' f'Error: {e}' ) from e if pop: self._current_inputs.pop(key) return init
[docs] def sum_delta_inputs( self, init: Any, *args, label: Optional[str] = None, pop: bool = True, **kwargs ): r""" Summarize all delta inputs by applying and summing all registered delta input functions. This method iterates through all registered delta input functions (from `.delta_inputs`) and applies them to calculate instantaneous changes to model states. It adds all results to the initial value provided. Parameters ---------- init : Any The initial value to which all delta inputs will be added. *args Variable length argument list passed to each delta input function. label : Optional[str], default=None If provided, only process delta inputs with this label prefix. When None, process all delta inputs regardless of label. **kwargs Arbitrary keyword arguments passed to each delta input function. Returns ------- Any The initial value plus all applicable delta inputs summed together. Notes ----- - Non-callable delta inputs are applied once and then automatically removed from the delta_inputs dictionary. - Callable delta inputs remain registered for subsequent calls. - When a label is provided, only delta inputs with keys starting with that label are applied. """ if self._delta_inputs is None: return init if label is None: filter_fn = lambda k: True else: label_repr = _input_label_start(label) filter_fn = lambda k: k.startswith(label_repr) for key in tuple(self._delta_inputs.keys()): if filter_fn(key): out = self._delta_inputs[key] if callable(out): try: init = init + out(*args, **kwargs) except Exception as e: raise ValueError( f'Error in delta input function {key}: {out}\n' f'Error: {e}' ) from e else: try: init = init + out except Exception as e: raise ValueError( f'Error in delta input value {key}: {out}\n' f'Error: {e}' ) from e if pop: self._delta_inputs.pop(key) return init
[docs] def align_pre(self, dyn: Union[ParamDescriber[T], T]) -> T: r""" Registers a dynamics module to execute after this module. This method establishes a sequential execution relationship where the specified dynamics module will be called after this module completes its update. This creates a feed-forward connection in the computational graph. Parameters ---------- dyn : Union[ParamDescriber[T], T] The dynamics module to be executed after this module. Can be either: - An instance of Dynamics - A ParamDescriber that can instantiate a Dynamics object Returns ------- T The dynamics module that was registered, allowing for method chaining. Raises ------ TypeError If the input is not a Dynamics instance or a ParamDescriber that creates a Dynamics instance. Examples -------- >>> import brainstate >>> n1 = brainpy.state.LIF(10) >>> n1.align_pre(brainpy.state.Expon.desc(n1.varshape)) # n2 will run after n1 """ if isinstance(dyn, Dynamics): self.add_after_update(id(dyn), dyn) return dyn elif isinstance(dyn, ParamDescriber): if not issubclass(dyn.cls, Dynamics): raise TypeError(f'The input {dyn} should be an instance of {Dynamics}.') if not self.has_after_update(dyn.identifier): self.add_after_update( dyn.identifier, dyn() if ('in_size' in dyn.kwargs or len(dyn.args) > 0) else dyn(in_size=self.varshape) ) return self.get_after_update(dyn.identifier) else: raise TypeError(f'The input {dyn} should be an instance of {Dynamics} or a delayed initializer.')
class Neuron(Dynamics):
    r"""Base class for all spiking neuron models.

    Extends :class:`Dynamics` with the two ingredients shared by spiking
    neuron models (e.g. IF, LIF, LIFRef, ALIF): a surrogate-gradient spike
    function and a spike-reset mechanism. Subclasses implement
    :meth:`get_spike` (and typically ``init_state``, ``reset_state`` and
    ``update``).

    Parameters
    ----------
    in_size : Size
        Size of the input to the neuron layer; an integer for 1D input or a
        tuple for multi-dimensional input (e.g. ``100`` or ``(28, 28)``).
    spk_fun : Callable, optional
        Surrogate gradient function for the non-differentiable spike
        generation, enabling gradient-based training (e.g. BPTT). Default is
        ``braintools.surrogate.InvSquareGrad()``; alternatives include
        ``ReluGrad``, ``SigmoidGrad``, ``GaussianGrad`` and ``ATan``.
    spk_reset : str, optional
        Reset mechanism applied after spike generation (default ``'soft'``):

        - ``'soft'``: subtract the threshold (``V = V - V_th``); preserves
          the residual voltage and usually gives better gradient flow.
        - ``'hard'``: reset to a fixed value using ``jax.lax.stop_gradient``
          (``V = V_reset``).
    name : str, optional
        Name identifier; auto-generated when ``None``.

    Attributes
    ----------
    spk_fun : Callable
        The surrogate gradient function used for spike generation.
    spk_reset : str
        The reset mechanism used by the neuron.

    Notes
    -----
    Neuron models usually keep state variables (membrane potential ``V``,
    adaptation currents, ...) as ``brainstate.HiddenState`` objects,
    initialized via ``init_state``, reset via ``reset_state``, and updated
    via ``update(x)`` which returns the spikes for the current timestep.

    References
    ----------
    .. [1] Neftci, E. O., Mostafa, H., & Zenke, F. (2019). Surrogate gradient
       learning in spiking neural networks. IEEE Signal Processing Magazine,
       36(6), 51-63.
    .. [2] Zenke, F., & Ganguli, S. (2018). SuperSpike: Supervised learning
       in multilayer spiking neural networks. Neural Computation, 30(6),
       1514-1541.
    .. [3] Gerstner, W., Kistler, W. M., Naud, R., & Paninski, L. (2014).
       Neuronal dynamics. Cambridge University Press.
    """
    __module__ = 'brainpy.state'

    def __init__(
        self,
        in_size: brainstate.typing.Size,
        spk_fun: Callable = braintools.surrogate.InvSquareGrad(),
        spk_reset: str = 'soft',
        name: Optional[str] = None,
    ):
        super().__init__(in_size, name=name)
        self.spk_reset = spk_reset
        self.spk_fun = spk_fun

    def get_spike(self, *args, **kwargs):
        r"""Generate spikes from the neuron's state variables.

        Subclasses must implement this using ``self.spk_fun`` so that spike
        generation remains differentiable through the surrogate gradient.

        Parameters
        ----------
        *args
            Positional arguments (typically state variables such as the
            membrane potential).
        **kwargs
            Keyword arguments.

        Returns
        -------
        ArrayLike
            Binary spike tensor (1 = spike, 0 = no spike).

        Raises
        ------
        NotImplementedError
            Always, unless overridden by a subclass.
        """
        raise NotImplementedError
class Synapse(Dynamics):
    r"""Base class for synapse dynamics.

    Foundation for all synapse models in the framework (e.g. ``Expon``,
    ``Alpha``, ``DualExpon``, ``AMPA``, ``GABAa``), which model the temporal
    filtering of presynaptic signals — synaptic conductance or current
    evolving over time in response to presynaptic spikes. Subclasses provide
    state management (``init_state``/``reset_state``) and the dynamics
    ``update``.

    Parameters
    ----------
    in_size : Size
        Size of the presynaptic input; an integer for 1D input or a tuple
        for multi-dimensional input (e.g. ``100`` or ``(10, 10)``).
    name : str, optional
        Name identifier; auto-generated when ``None``.

    Attributes
    ----------
    varshape : tuple
        Shape of the synaptic state variables, derived from ``in_size``.

    See Also
    --------
    Expon : Simple first-order exponential decay synapse model
    DualExpon : Dual exponential synapse model with separate rise and decay
    Alpha : Alpha function synapse model
    AMPA : AMPA receptor-mediated excitatory synapse
    GABAa : GABAa receptor-mediated inhibitory synapse

    Notes
    -----
    Synapse state is usually held in ``brainstate.HiddenState`` or
    ``brainstate.ShortTermState`` objects depending on whether it must
    persist across simulation episodes. Some synapse models also inherit
    from ``AlignPost`` to align synaptic variables with postsynaptic
    neurons for event-driven computation.

    References
    ----------
    .. [1] Destexhe, A., Mainen, Z. F., & Sejnowski, T. J. (1994). Synthesis
       of models for excitable membranes, synaptic transmission and
       neuromodulation using a common kinetic formalism. Journal of
       Computational Neuroscience, 1(3), 195-230.
    .. [2] Dayan, P., & Abbott, L. F. (2001). Theoretical neuroscience.
       MIT Press.
    .. [3] Gerstner, W., Kistler, W. M., Naud, R., & Paninski, L. (2014).
       Neuronal dynamics. Cambridge University Press.
    """
    __module__ = 'brainpy.state'