# Source code for brainpy_state._nest_network.base

# Copyright 2026 BrainX Ecosystem Limited. Apache 2.0.
from typing import Dict

import brainstate
import brainunit as u

from brainpy_state._base import Neuron
from brainpy_state._brainpy.projection import Projection
from brainpy_state._nest_base.base import NESTDevice

__all__ = ['Network']


class Network(brainstate.nn.Module):
    """brainpy.state network base class.

    Subclass and define populations, projections, and devices as
    attributes. ``update()`` walks the immediate module-tree children
    in projection-first then dynamics order.
    """
    __module__ = 'brainpy.state'

    def _children_of_type(self, cls) -> dict:
        """Return depth-1 children that are instances of ``cls``.

        Keys are the last element of each node path reported by
        ``self.nodes(allowed_hierarchy=(1, 1))``.
        """
        return {path[-1]: m for path, m in self.nodes(allowed_hierarchy=(1, 1)).items()
                if isinstance(m, cls)}

    def update(self, t=None) -> None:
        """Advance the network by one step.

        Every direct child is called with no arguments: all ``Projection``
        children first, then every remaining child (populations, devices,
        ...), each group in ``self.nodes`` order.

        Parameters
        ----------
        t : optional
            Accepted for interface compatibility; it is not forwarded to the
            children, which are called with no arguments.
        """
        # Depth-1 traversal — matches the existing brainpy.state convention
        # documented in brainpy_state/_brainpy/projection.py:46-51. Networks
        # that need nested projection chains (Projection containing
        # Projection) currently must override update() explicitly; this is
        # tracked as an open question in the design spec §12.
        children = list(self.nodes(allowed_hierarchy=(1, 1)).values())
        # Both groups are fixed before any child runs.
        projections = [m for m in children if isinstance(m, Projection)]
        others = [m for m in children if not isinstance(m, Projection)]
        for m in projections:
            m()
        for m in others:
            m()

    @property
    def populations(self) -> Dict[str, Neuron]:
        """Direct ``Neuron`` children, keyed by attribute name."""
        return self._children_of_type(Neuron)

    @property
    def projections(self) -> Dict[str, Projection]:
        """Direct ``Projection`` children, keyed by attribute name."""
        return self._children_of_type(Projection)

    @property
    def devices(self) -> Dict[str, NESTDevice]:
        """Direct ``NESTDevice`` children, keyed by attribute name."""
        return self._children_of_type(NESTDevice)

    def simulate(self, duration, *, dt=None, monitor=None) -> dict:
        """Run the network for ``duration``.

        Wraps ``brainstate.transform.for_loop`` over ``self.update``.

        Parameters
        ----------
        duration : brainunit.Quantity
            Simulated (model) time span to run. Step times are
            ``arange(0, duration, dt)``, expressed in the unit of ``dt``.
        dt : brainunit.Quantity, optional
            Timestep override. Defaults to ``brainstate.environ.get('dt')``.
            The resolved value is installed as the environment ``dt`` for
            the whole run, so children reading the environment timestep use
            the same value as the time grid.
            NOTE(review): models that precompute dt-dependent constants
            before ``simulate`` is called will not see an override — confirm
            per model.
        monitor : list[str] | dict[str, Callable] | None
            Per-step recording specification:

            - ``None`` (default) — no monitoring; the function returns ``{}``.
            - list (or tuple) of dotted attribute paths
              (e.g. ``['exc.spike']``) — each path is resolved on the network
              every step (unwrapping ``.value`` where present), stacked across
              time, and returned under the path string.
            - dict of ``name -> callable(net)``: callable evaluated each step,
              result stacked under ``name``.

        Returns
        -------
        dict
            ``{name: per-step results stacked across time}`` for every
            monitored entry; ``{}`` when nothing is monitored.

        Raises
        ------
        ValueError
            If ``dt`` is not passed and not set in the environment.
        TypeError
            If ``monitor`` has an unsupported type, or a dict value in
            ``monitor`` is not callable.
        """
        import brainstate.transform as transform

        if dt is None:
            # Explicit default: ``environ.get`` raises KeyError for an unset
            # key, which would bypass the clearer error below.
            dt = brainstate.environ.get('dt', None)
        if dt is None:
            raise ValueError(
                'dt must be set via brainstate.environ.set(dt=...) or '
                'passed explicitly as simulate(..., dt=...)'
            )

        times = u.math.arange(0.0 * u.get_unit(dt), duration, dt)
        indices = u.math.arange(times.size)
        callables = self._normalize_monitor(monitor)

        def step(t, i):
            with brainstate.environ.context(t=t, i=i):
                self.update(t)
                if callables:
                    return {k: fn(self) for k, fn in callables.items()}
                return None

        # Install dt for the whole run. Previously an explicit ``dt`` shaped
        # only the time grid while the models kept reading the environment's
        # (possibly different) timestep.
        with brainstate.environ.context(dt=dt):
            stacked = transform.for_loop(step, times, indices)
        return dict(stacked) if callables else {}

    def _normalize_monitor(self, monitor) -> dict:
        """Convert a ``simulate`` monitor spec into ``{name: callable(net)}``.

        Raises
        ------
        TypeError
            If ``monitor`` is not None/list/tuple/dict, or a dict value is
            not callable.
        """
        if monitor is None:
            return {}
        if isinstance(monitor, (list, tuple)):
            return {path: self._make_path_callable(path) for path in monitor}
        if isinstance(monitor, dict):
            for name, fn in monitor.items():
                if not callable(fn):
                    raise TypeError(
                        f'monitor dict value for {name!r} must be callable, '
                        f'got {type(fn).__name__}'
                    )
            return dict(monitor)
        raise TypeError(
            f'monitor must be list, dict, or None, got {type(monitor).__name__}'
        )

    def _make_path_callable(self, path: str):
        """Build a callable that resolves the dotted attribute ``path`` on a network.

        Each segment is looked up with ``getattr``. After every lookup, an
        object exposing a ``.value`` attribute is replaced by that value.
        """
        parts = path.split('.')

        def fn(net):
            obj = net
            for p in parts:
                obj = getattr(obj, p)
                if hasattr(obj, 'value'):
                    obj = obj.value
            return obj

        return fn