# Source code for brainpy_state._brainpy.readout

# Copyright 2024 BrainX Ecosystem Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

# -*- coding: utf-8 -*-


import numbers
from typing import Callable

import brainstate
import braintools
import saiunit as u
from brainstate.typing import Size, ArrayLike

__all__ = [
    'LeakyRateReadout',
]


class LeakyRateReadout(brainstate.nn.Module):
    r"""
    Leaky dynamics for the read-out module.

    This module implements a leaky integrator with the following dynamics:

    .. math::
        r_{t} = \alpha r_{t-1} + x_{t} W

    where:

      - :math:`r_{t}` is the output at time t
      - :math:`\alpha = e^{-\Delta t / \tau}` is the decay factor
      - :math:`x_{t}` is the input at time t
      - :math:`W` is the weight matrix

    The leaky integrator acts as a low-pass filter, allowing the network
    to maintain memory of past inputs with an exponential decay determined
    by the time constant tau.

    Parameters
    ----------
    in_size : int or sequence of int
        Size of the input dimension(s)
    out_size : int or sequence of int
        Size of the output dimension(s)
    tau : ArrayLike, optional
        Time constant of the leaky dynamics, by default 5ms
    w_init : Callable, optional
        Weight initialization function, by default KaimingNormal()
    name : str, optional
        Name of the module, by default None

    Attributes
    ----------
    decay : float
        Decay factor computed as exp(-dt/tau)
    weight : ParamState
        Weight matrix connecting input to output
    r : HiddenState
        Hidden state representing the output values

    Notes
    -----
    - The decay factor :math:`\alpha = e^{-\Delta t / \tau}` is computed once
      at construction time using the current environment ``dt``.
    - The weight matrix is initialized with Kaiming Normal initialization,
      which is suitable for layers that follow ReLU-like activations.
    - This module does not produce spikes; it outputs continuous rate values,
      making it suitable as the final readout layer in spiking networks
      trained with surrogate gradients.

    References
    ----------
    .. [1] Neftci, E. O., Mostafa, H., & Zenke, F. (2019). Surrogate gradient
           learning in spiking neural networks. IEEE Signal Processing
           Magazine, 36(6), 51-63.

    Examples
    --------
    .. code-block:: python

        >>> import brainpy
        >>> import brainstate
        >>> import saiunit as u
        >>> with brainstate.environ.context(dt=1. * u.ms):
        ...     readout = brainpy.state.LeakyRateReadout(128, 10, tau=5.*u.ms)
        ...     readout.init_state(batch_size=1)
    """
    __module__ = 'brainpy.state'

    def __init__(
        self,
        in_size: Size,
        out_size: Size,
        tau: ArrayLike = 5. * u.ms,
        w_init: Callable = braintools.init.KaimingNormal(),
        name: str = None,
    ):
        super().__init__(name=name)

        # parameters: normalize scalar sizes to 1-tuples so shapes are uniform
        self.in_size = (in_size,) if isinstance(in_size, numbers.Integral) else tuple(in_size)
        self.out_size = (out_size,) if isinstance(out_size, numbers.Integral) else tuple(out_size)
        self.tau = braintools.init.param(tau, self.in_size)
        # Decay factor alpha = exp(-dt / tau), fixed at construction using the
        # current environment dt (changing dt later does not update `decay`).
        # NOTE(review): `tau` (and therefore `decay`) is parameterized over
        # ``in_size`` but is multiplied with the hidden state ``r`` of shape
        # ``out_size`` in ``update``; a non-scalar tau with in_size != out_size
        # would not broadcast — confirm the intended shape.
        self.decay = u.math.exp(-brainstate.environ.get_dt() / self.tau)

        # weights: dense (in, out) matrix mapping inputs to readout units
        self.weight = brainstate.ParamState(brainstate.init.param(w_init, (self.in_size[0], self.out_size[0])))

    def init_state(self, batch_size=None, **kwargs):
        """Create the hidden state ``r`` (shape ``out_size``) initialized to zeros."""
        self.r = brainstate.HiddenState(
            brainstate.init.param(brainstate.init.Constant(0.), self.out_size, batch_size)
        )

    def reset_state(self, batch_size=None, **kwargs):
        """Reset the existing hidden state ``r`` back to zeros in place."""
        self.r.value = brainstate.init.param(
            brainstate.init.Constant(0.), self.out_size, batch_size
        )

    def update(self, x):
        """One integration step: ``r_t = decay * r_{t-1} + x @ W``; returns ``r_t``."""
        self.r.value = self.decay * self.r.value + x @ self.weight.value
        return self.r.value