# Source code for brainpy_state._brainpy.projection

# Copyright 2024 BrainX Ecosystem Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

from typing import Callable, Union
from typing import Optional

import brainevent
import brainstate
from brainstate import State
from brainstate.mixin import JointTypes, ParamDescriber
from brainstate.nn import init_maybe_prefetch
from brainstate.util import get_unique_name

from brainpy_state._base import Dynamics
from brainpy_state._mixin import BindCondData, AlignPost
from .synouts import SynOut

__all__ = [
    'Projection',
    'AlignPostProj',
    'DeltaProj',
    'CurrentProj',
    'align_pre_projection',
    'align_post_projection',
]


class Projection(brainstate.nn.Module):
    r"""
    Abstract base for modules that project activity between populations.

    A projection links a pre-synaptic population to a post-synaptic one and
    takes care of synaptic transmission, applying weights, and delivering
    the resulting input. BrainState updates projections *before* dynamics
    modules, mirroring the direction of information flow: projections
    produce inputs first, and neurons integrate them afterwards.

    Parameters
    ----------
    *args
        Positional arguments passed through to :class:`brainstate.nn.Module`.
    **kwargs
        Keyword arguments passed through to :class:`brainstate.nn.Module`.

    Raises
    ------
    ValueError
        From :meth:`update`, when the projection has no direct child nodes
        to delegate to.

    See Also
    --------
    AlignPostProj : Post-synaptic alignment projection.
    DeltaProj : Delta-input projection (direct voltage changes).
    CurrentProj : Current-based projection.
    align_pre_projection : Pre-synaptic alignment convenience wrapper.
    align_post_projection : Post-synaptic alignment convenience wrapper.

    Notes
    -----
    Concrete subclasses usually combine a communication module (connection
    weights), a synapse model and a synaptic output module. The default
    :meth:`update` forwards its arguments to every direct child node
    (hierarchy level 1) of the module tree, one after another.
    """
    __module__ = 'brainpy.state'

    def update(self, *args, **kwargs):
        # Only direct children (level 1 of the module tree) are delegated to.
        children = self.nodes(allowed_hierarchy=(1, 1)).values()
        children = tuple(children)
        if not children:
            raise ValueError('Do not implement the update() function.')
        for child in children:
            child(*args, **kwargs)


def _check_modules(*modules):
    # checking modules
    for module in modules:
        if not callable(module) and not isinstance(module, State):
            raise TypeError(
                f'The module should be a callable function or a brainstate.State, but got {module}.'
            )
    return tuple(modules)


def call_module(module, *args, **kwargs):
    """Evaluate one pipeline stage.

    Callables are invoked with ``*args``/``**kwargs`` and their result is
    returned. A :class:`brainstate.State` ignores the arguments and yields
    its current ``value``. Anything else raises :class:`TypeError`.
    """
    if not callable(module):
        if isinstance(module, State):
            return module.value
        raise TypeError(
            f'The module should be a callable function or a brainstate.State, but got {module}.'
        )
    return module(*args, **kwargs)


def is_instance(x, cls) -> bool:
    """Thin wrapper around :func:`isinstance`; returns whether ``x`` is a ``cls``."""
    outcome = isinstance(x, cls)
    return outcome


def get_post_repr(label, syn, out):
    """Build the key under which a post-aligned synapse/output pair is stored.

    The key is ``"<label><syn.identifier> // <out.identifier>"``; the label
    prefix is omitted when ``label`` is ``None``.
    """
    prefix = '' if label is None else label
    return f'{prefix}{syn.identifier} // {out.identifier}'


def align_post_add_bef_update(
    syn_desc: ParamDescriber[AlignPost],
    out_desc: ParamDescriber[BindCondData],
    post: Dynamics,
    proj_name: str,
    label: Optional[str],
):
    """Get (or lazily create) the synapse/output pair shared on ``post``.

    The pair is keyed by :func:`get_post_repr` of ``label``, ``syn_desc`` and
    ``out_desc``. If ``post`` has no ``before_update`` hook under that key
    yet, both descriptors are instantiated. The output is registered as a
    current input named ``proj_name`` (with ``label``), and an
    :class:`_AlignPost` wrapper is installed as the hook. Later projections
    with the same key reuse the existing pair instead of creating a new one.

    Parameters
    ----------
    syn_desc : ParamDescriber[AlignPost]
        Descriptor used to build the synapse when none exists yet.
    out_desc : ParamDescriber[BindCondData]
        Descriptor used to build the output module when none exists yet.
    post : Dynamics
        Post-synaptic population that owns the shared pair.
    proj_name : str
        Name under which the output is registered as a current input.
    label : str or None
        Optional label prefix for the key and the current input.

    Returns
    -------
    tuple
        ``(syn, out)`` taken from the (possibly pre-existing) hook.
    """
    key = get_post_repr(label, syn_desc, out_desc)
    if not post.has_before_update(key):
        syn_inst = syn_desc()
        out_inst = out_desc()

        # Register the output once; projections merged later share it.
        post.add_current_input(proj_name, out_inst, label=label)
        post.add_before_update(key, _AlignPost(syn_inst, out_inst))

    # Single lookup of the shared hook, then unpack its synapse/output.
    shared = post.get_before_update(key)
    return shared.syn, shared.out


class _AlignPost(brainstate.nn.Module):
    """Pairs a post-aligned synapse with the output module it feeds.

    On each update the synapse is called with the given arguments and its
    result is bound to the output via ``bind_cond``.
    """

    def __init__(self, syn: Dynamics, out: BindCondData):
        super().__init__()
        self.syn = syn
        self.out = out

    def update(self, *args, **kwargs):
        conductance = self.syn(*args, **kwargs)
        self.out.bind_cond(conductance)


class AlignPostProj(Projection):
    r"""
    Post-synaptic alignment projection.

    In this projection pattern, the synapse dynamics and synaptic output
    are aligned with (owned by) the post-synaptic neuron. Multiple
    projections targeting the same post-synaptic population with the same
    synapse/output descriptor share a single synapse and output instance,
    enabling efficient event-driven updates.

    The update pipeline is:

    1. Optional pre-processing modules transform the input.
    2. The communication module (``comm``) maps pre-synaptic signals
       to post-synaptic space.
    3. The result is added as a delta input to the shared synapse.
    4. The synapse and output are updated by the post-synaptic neuron's
       ``before_update`` hook (if using descriptor merging), or directly by
       this projection otherwise.

    Parameters
    ----------
    *modules
        Optional pre-processing modules applied sequentially to the input
        before the communication step.
    comm : Callable
        Communication module (e.g., ``brainevent.nn.FixedProb``) that
        maps pre-synaptic activity to post-synaptic space.
    syn : ParamDescriber[AlignPost] or AlignPost
        Synapse model or its descriptor. When a descriptor is provided,
        the synapse is created lazily and shared across projections
        targeting the same post-synaptic neuron.
    out : ParamDescriber[SynOut] or SynOut
        Synaptic output module or its descriptor.
    post : Dynamics
        Post-synaptic neural population.
    label : str, optional
        Label for identifying this projection's contribution in the
        post-synaptic neuron's input dictionary.

    Raises
    ------
    TypeError
        If ``comm`` is not callable, if ``syn``/``out`` types are
        inconsistent, or if ``post`` is not a :class:`Dynamics` instance.

    See Also
    --------
    DeltaProj : Direct delta-input projection.
    CurrentProj : Current-based projection.
    align_post_projection : Convenience wrapper with spike generation.

    Notes
    -----
    - When both ``syn`` and ``out`` are descriptors (``ParamDescriber``),
      the projection attempts to merge with existing synapse/output
      instances on the post-synaptic neuron, avoiding duplicate state.
    - When ``syn`` is an already-instantiated ``AlignPost`` object, no
      merging occurs and ``out`` must also be an instantiated ``SynOut``.
    - ``label`` is passed to ``post.add_current_input`` whenever this
      projection registers an output module, in both the merging and the
      non-merging case.

    References
    ----------
    .. [1] Brette, R., et al. (2007). Simulation of networks of spiking
           neurons: a review of tools and strategies. Journal of
           Computational Neuroscience, 23(3), 349-398.

    Examples
    --------
    .. code-block:: python

        >>> import brainpy
        >>> import brainstate
        >>> import saiunit as u
        >>> n_pre, n_post = 800, 200
        >>> post_pop = brainpy.state.LIF(n_post, tau=20.*u.ms)
        >>> post_pop.init_state()
        >>> proj = brainpy.state.AlignPostProj(
        ...     comm=brainstate.nn.Linear(n_pre, n_post),
        ...     syn=brainpy.state.Expon.desc(n_post, tau=5.*u.ms),
        ...     out=brainpy.state.CUBA.desc(scale=u.volt),
        ...     post=post_pop,
        ... )
    """
    __module__ = 'brainpy.state'

    def __init__(
        self,
        *modules,
        comm: Callable,
        syn: Union[ParamDescriber[AlignPost], AlignPost],
        out: Union[ParamDescriber[SynOut], SynOut],
        post: Dynamics,
        label: Optional[str] = None,
    ):
        super().__init__(name=get_unique_name(self.__class__.__name__))

        # checking modules
        self.modules = _check_modules(*modules)

        # checking communication model
        if not callable(comm):
            raise TypeError(
                f'The communication should be an instance of callable function, but got {comm}.'
            )

        # checking synapse and output models
        if is_instance(syn, ParamDescriber[AlignPost]):
            if not is_instance(out, ParamDescriber[SynOut]):
                if is_instance(out, ParamDescriber):
                    raise TypeError(
                        f'The output should be an instance of describer {ParamDescriber[SynOut]} when '
                        f'the synapse is an instance of describer {ParamDescriber[AlignPost]}, '
                        f'but got {out}.'
                    )
                raise TypeError(
                    f'The output should be an instance of describer {ParamDescriber[SynOut]} when '
                    f'the synapse is a describer, but we got {out}.'
                )
            merging = True
        else:
            if is_instance(syn, ParamDescriber):
                raise TypeError(
                    f'The synapse should be an instance of describer {ParamDescriber[AlignPost]}, but got {syn}.'
                )
            if not is_instance(out, SynOut):
                raise TypeError(
                    f'The output should be an instance of {SynOut} when the synapse is '
                    f'not a describer, but we got {out}.'
                )
            merging = False
        self.merging = merging

        # checking post model
        if not is_instance(post, Dynamics):
            raise TypeError(
                f'The post should be an instance of {Dynamics}, but got {post}.'
            )

        if merging:
            # synapse and output initialization (shared on the post population)
            syn, out = align_post_add_bef_update(syn_desc=syn,
                                                 out_desc=out,
                                                 post=post,
                                                 proj_name=self.name,
                                                 label=label)
        else:
            # Forward ``label`` here too, exactly as the merging path does, so
            # label-filtered input sums on ``post`` see this projection.
            post.add_current_input(self.name, out, label=label)

        # references
        self.comm = comm
        self.syn: JointTypes[Dynamics, AlignPost] = syn
        self.out: BindCondData = out
        self.post: Dynamics = post

    @brainstate.nn.call_order(2)
    def init_state(self, *args, **kwargs):
        for module in self.modules:
            init_maybe_prefetch(module, *args, **kwargs)

    def update(self, *args):
        # call all modules; each stage's output becomes the next stage's input
        for module in self.modules:
            x = call_module(module, *args)
            args = (x,)
        # communication module
        x = self.comm(*args)
        # add synapse input
        self.syn.add_delta_input(self.name, x)
        if not self.merging:
            # Without merging, no before_update hook drives the synapse,
            # so the projection performs the synapse/output interaction itself.
            conductance = self.syn()
            self.out.bind_cond(conductance)


class DeltaProj(Projection):
    r"""
    Delta-input projection.

    Applies pre-synaptic signals directly as delta (voltage) inputs to the
    post-synaptic population, bypassing synapse dynamics entirely. Useful
    for instantaneous coupling or simplified network models.

    The update pipeline is:

    1. Optional pre-fetch modules transform the input.
    2. The communication module (``comm``) maps the signal to post-synaptic
       space.
    3. The result is added as a delta input to the post-synaptic neuron's
       membrane potential.

    Parameters
    ----------
    *prefetch
        Optional modules applied sequentially to the input before the
        communication step.
    comm : Callable
        Communication module mapping pre-synaptic activity to post-synaptic
        delta inputs.
    post : Dynamics
        Post-synaptic neural population.
    label : str, optional
        Label for identifying this projection's delta input.

    Raises
    ------
    TypeError
        If ``comm`` is not callable or ``post`` is not a :class:`Dynamics`
        instance.

    See Also
    --------
    AlignPostProj : Projection with full synapse dynamics.
    CurrentProj : Current-based projection.

    Notes
    -----
    - Delta projections add directly to the membrane potential rather than
      to the current, making them suitable for modeling gap junctions or
      abstract coupling.
    - The ``label`` parameter allows multiple delta projections to coexist
      on the same post-synaptic population with distinct labels.

    Examples
    --------
    .. code-block:: python

        >>> import brainpy
        >>> import brainstate
        >>> import saiunit as u
        >>> n_neurons = 100
        >>> pop = brainpy.state.LIF(n_neurons, tau=10.*u.ms)
        >>> pop.init_state()
        >>> delta_proj = brainpy.state.DeltaProj(
        ...     comm=brainstate.nn.Linear(n_neurons, n_neurons),
        ...     post=pop,
        ... )
    """
    __module__ = 'brainpy.state'

    def __init__(
        self,
        *prefetch,
        comm: Callable,
        post: Dynamics,
        label: Optional[str] = None,
    ):
        super().__init__(name=get_unique_name(self.__class__.__name__))

        self.label = label

        # checking modules
        self.prefetches = _check_modules(*prefetch)

        # checking communication model
        if not callable(comm):
            raise TypeError(
                f'The communication should be an instance of callable function, but got {comm}.'
            )
        self.comm = comm

        # post model
        if not isinstance(post, Dynamics):
            raise TypeError(
                f'The post should be an instance of {Dynamics}, but got {post}.'
            )
        self.post = post

    @brainstate.nn.call_order(2)
    def init_state(self, *args, **kwargs):
        for prefetch in self.prefetches:
            init_maybe_prefetch(prefetch, *args, **kwargs)

    def update(self, *x):
        for module in self.prefetches:
            x = (call_module(module, *x),)
        assert len(x) == 1, f'The output of the modules should be a single value, but got {x}.'
        x = self.comm(x[0])
        self.post.add_delta_input(self.name, x, label=self.label)


class CurrentProj(Projection):
    r"""
    Current-based projection.

    Delivers current input to post-synaptic neurons by passing the
    communication output through a synaptic output module (e.g.,
    :class:`COBA` or :class:`CUBA`) and registering it as a current input on
    the post-synaptic population.

    The update pipeline is:

    1. Optional pre-fetch modules transform the input.
    2. The communication module (``comm``) maps pre-synaptic activity.
    3. The synaptic output module (``out``) converts conductance to current
       and binds it to the post-synaptic neuron.

    Parameters
    ----------
    *prefetch
        Optional pre-fetch modules. If provided, the last element must be a
        :class:`brainstate.nn.Prefetch` or
        :class:`brainstate.nn.PrefetchDelayAt` instance.
    comm : Callable
        Communication module mapping pre-synaptic activity.
    out : SynOut
        Synaptic output module that converts the communication result to
        post-synaptic current.
    post : Dynamics
        Post-synaptic neural population.

    Raises
    ------
    TypeError
        If ``comm`` is not callable, ``out`` is not a :class:`SynOut`,
        ``post`` is not a :class:`Dynamics`, or the last prefetch module has
        an incorrect type.

    See Also
    --------
    AlignPostProj : Projection with aligned synapse dynamics.
    DeltaProj : Direct delta-input projection.
    align_pre_projection : Convenience wrapper with pre-synaptic alignment.

    Notes
    -----
    - The output module is immediately registered as a current input on the
      post-synaptic population at construction time, after all arguments
      have been validated.
    - Unlike :class:`AlignPostProj`, this projection does not use synapse
      dynamics -- the communication result is directly converted to current.

    Examples
    --------
    .. code-block:: python

        >>> import brainpy
        >>> import brainstate
        >>> import saiunit as u
        >>> n_neurons = 100
        >>> pop = brainpy.state.LIF(n_neurons, tau=10.*u.ms)
        >>> pop.init_state()
        >>> proj = brainpy.state.CurrentProj(
        ...     comm=brainstate.nn.Linear(n_neurons, n_neurons),
        ...     out=brainpy.state.CUBA(scale=u.volt),
        ...     post=pop,
        ... )
    """
    __module__ = 'brainpy.state'

    def __init__(
        self,
        *prefetch,
        comm: Callable,
        out: SynOut,
        post: Dynamics,
    ):
        super().__init__(name=get_unique_name(self.__class__.__name__))

        # check prefetch
        self.prefetch = prefetch
        if len(self.prefetch) > 0 and not isinstance(
            prefetch[-1], (brainstate.nn.Prefetch, brainstate.nn.PrefetchDelayAt)
        ):
            raise TypeError(
                f'The last element of prefetch should be an instance '
                f'of {brainstate.nn.Prefetch} or {brainstate.nn.PrefetchDelayAt}, '
                f'but got {prefetch[-1]}.'
            )

        # check comm: documented in ``Raises`` and checked the same way as
        # the sibling projections; done before ``post`` is mutated below.
        if not callable(comm):
            raise TypeError(
                f'The communication should be an instance of callable function, but got {comm}.'
            )

        # check out
        if not isinstance(out, SynOut):
            raise TypeError(f'The out should be a SynOut, but got {out}.')
        self.out = out

        # check post
        if not isinstance(post, Dynamics):
            raise TypeError(f'The post should be a Dynamics, but got {post}.')
        self.post = post
        post.add_current_input(self.name, out)

        # output initialization
        self.comm = comm

    @brainstate.nn.call_order(2)
    def init_state(self, *args, **kwargs):
        for prefetch in self.prefetch:
            init_maybe_prefetch(prefetch, *args, **kwargs)

    def update(self, *x):
        for prefetch in self.prefetch:
            x = (call_module(prefetch, *x),)
        x = self.comm(*x)
        self.out.bind_cond(x)


class align_pre_projection(Projection):
    r"""
    Pre-synaptic alignment projection with spike generation.

    A convenience wrapper that combines spike generation, optional
    short-term plasticity (STP), pre-synaptic synapse dynamics, and a
    :class:`CurrentProj` into a single module. The synapse operates in
    pre-synaptic space, processing spikes before the communication step
    transmits them to post-synaptic neurons.

    The update pipeline is:

    1. Spike generator modules produce binary spike signals.
    2. If STP is provided, spikes are modulated by short-term plasticity
       dynamics.
    3. The pre-synaptic synapse model filters the (modulated) spikes.
    4. The :class:`CurrentProj` maps the filtered signal to post-synaptic
       current.

    Parameters
    ----------
    *spike_generator
        One or more modules that produce spike signals from the input.
        Each may be a callable or a :class:`brainstate.State` (whose
        ``value`` is read).
    syn : Dynamics
        Pre-synaptic synapse dynamics (e.g., :class:`Expon`).
    comm : Callable
        Communication module mapping pre-synaptic to post-synaptic space.
    out : SynOut
        Synaptic output module.
    post : Dynamics
        Post-synaptic neural population.
    stp : Dynamics, optional
        Short-term plasticity module applied after spike generation.

    See Also
    --------
    align_post_projection : Post-synaptic alignment variant.
    CurrentProj : Underlying current projection used internally.
    AlignPostProj : Alternative projection with post-synaptic alignment.

    Notes
    -----
    - Pre-synaptic alignment means the synapse state lives in pre-synaptic
      space. This is natural for models where synaptic filtering (e.g.,
      exponential decay) should happen before the communication step.
    - Spike signals are wrapped in ``brainevent.BinaryArray`` for efficient
      event-driven processing.
    - When STP is used, its output is wrapped in ``brainevent.MaskedFloat``
      to preserve sparsity.

    Examples
    --------
    .. code-block:: python

        >>> import brainpy
        >>> import brainstate
        >>> import saiunit as u
        >>> pre = brainpy.state.LIF(800, tau=20.*u.ms)
        >>> post = brainpy.state.LIF(200, tau=20.*u.ms)
        >>> pre.init_state()
        >>> post.init_state()
        >>> proj = brainpy.state.align_pre_projection(
        ...     pre,
        ...     syn=brainpy.state.Expon(800, tau=5.*u.ms),
        ...     comm=brainstate.nn.Linear(800, 200),
        ...     out=brainpy.state.CUBA(scale=u.volt),
        ...     post=post,
        ... )
    """
    __module__ = 'brainpy.state'

    def __init__(
        self,
        *spike_generator,
        syn: Dynamics,
        comm: Callable,
        out: SynOut,
        post: Dynamics,
        stp: Optional[Dynamics] = None,
    ):
        super().__init__()
        self.spike_generator = _check_modules(*spike_generator)
        self.projection = CurrentProj(comm=comm, out=out, post=post)
        self.syn = syn
        self.stp = stp

    @brainstate.nn.call_order(2)
    def init_state(self, *args, **kwargs):
        for module in self.spike_generator:
            init_maybe_prefetch(module, *args, **kwargs)

    def update(self, *x):
        for fun in self.spike_generator:
            # ``call_module`` also supports State generators, which
            # ``_check_modules`` accepts but cannot be called directly.
            x = call_module(fun, *x)
            if isinstance(x, (tuple, list)):
                x = tuple(x)
            else:
                x = (x,)
        assert len(x) == 1, "Spike generator must return a single value or a tuple/list of values"
        x = brainevent.BinaryArray(x[0])  # wrap spikes as an event-driven binary array
        if self.stp is not None:
            x = brainevent.MaskedFloat(self.stp(x))  # wrap STP output as a masked float
        x = self.syn(x)  # apply pre-synaptic alignment
        return self.projection(x)


class align_post_projection(Projection):
    r"""
    Post-synaptic alignment projection with spike generation.

    A convenience wrapper that combines spike generation, optional
    short-term plasticity (STP), and an :class:`AlignPostProj` into a single
    module. The synapse operates in post-synaptic space, sharing state
    across projections that target the same post-synaptic neuron.

    The update pipeline is:

    1. Spike generator modules produce binary spike signals.
    2. If STP is provided, spikes are modulated by short-term plasticity
       dynamics.
    3. The :class:`AlignPostProj` handles communication and post-aligned
       synapse/output updates.

    Parameters
    ----------
    *spike_generator
        One or more modules that produce spike signals from the input.
        Each may be a callable or a :class:`brainstate.State` (whose
        ``value`` is read).
    comm : Callable
        Communication module mapping pre-synaptic to post-synaptic space.
    syn : AlignPost or ParamDescriber[AlignPost]
        Post-synaptic synapse model or its descriptor.
    out : SynOut or ParamDescriber[SynOut]
        Synaptic output module or its descriptor.
    post : Dynamics
        Post-synaptic neural population.
    stp : Dynamics, optional
        Short-term plasticity module applied after spike generation.

    See Also
    --------
    align_pre_projection : Pre-synaptic alignment variant.
    AlignPostProj : Underlying post-aligned projection used internally.

    Notes
    -----
    - Post-synaptic alignment enables synapse state sharing: if multiple
      projections target the same post-synaptic population with the same
      synapse/output descriptor, they share a single synapse instance.
    - Spike signals are wrapped in ``brainevent.BinaryArray`` for efficient
      event-driven processing.

    Examples
    --------
    .. code-block:: python

        >>> import brainpy
        >>> import brainstate
        >>> import saiunit as u
        >>> pre = brainpy.state.LIF(800, tau=20.*u.ms)
        >>> post = brainpy.state.LIF(200, tau=20.*u.ms)
        >>> pre.init_state()
        >>> post.init_state()
        >>> proj = brainpy.state.align_post_projection(
        ...     pre,
        ...     comm=brainstate.nn.Linear(800, 200),
        ...     syn=brainpy.state.Expon.desc(200, tau=5.*u.ms),
        ...     out=brainpy.state.CUBA.desc(scale=u.volt),
        ...     post=post,
        ... )
    """
    __module__ = 'brainpy.state'

    def __init__(
        self,
        *spike_generator,
        comm: Callable,
        syn: Union[AlignPost, ParamDescriber[AlignPost]],
        out: Union[SynOut, ParamDescriber[SynOut]],
        post: Dynamics,
        stp: Optional[Dynamics] = None,
    ):
        super().__init__()
        self.spike_generator = _check_modules(*spike_generator)
        self.projection = AlignPostProj(comm=comm, syn=syn, out=out, post=post)
        self.stp = stp

    @brainstate.nn.call_order(2)
    def init_state(self, *args, **kwargs):
        for module in self.spike_generator:
            init_maybe_prefetch(module, *args, **kwargs)

    def update(self, *x):
        for fun in self.spike_generator:
            # ``call_module`` also supports State generators, which
            # ``_check_modules`` accepts but cannot be called directly.
            x = call_module(fun, *x)
            if isinstance(x, (tuple, list)):
                x = tuple(x)
            else:
                x = (x,)
        assert len(x) == 1, "Spike generator must return a single value or a tuple/list of values"
        x = brainevent.BinaryArray(x[0])  # wrap spikes as an event-driven binary array
        if self.stp is not None:
            x = brainevent.MaskedFloat(self.stp(x))  # wrap STP output as a masked float
        return self.projection(x)