# Source code for brainpy_state._nest_network.nodeview
# Copyright 2026 BrainX Ecosystem Limited. Apache 2.0.
"""NodeView — NEST NodeCollection-style view over (population, local indices)."""
from __future__ import annotations
from dataclasses import dataclass
from typing import List
import jax.numpy as jnp
# Public API: only ``NodeView`` is exported by ``from <this module> import *``.
__all__ = ['NodeView']
def _flat_size(module) -> int:
    """Return the flattened element count of a population or device.

    The shape is taken from ``module.in_size`` when that attribute exists and
    from ``module.varshape`` otherwise. This mirrors
    ``brainpy_state._nest_network.projections._size`` so the network layer
    agrees on one notion of population size.

    A tuple shape yields the product of its entries (each coerced with
    ``int``; an empty tuple gives 1). Any other value is coerced directly
    with ``int``.
    """
    if hasattr(module, 'in_size'):
        shape = module.in_size
    else:
        shape = module.varshape
    if not isinstance(shape, tuple):
        return int(shape)
    total = 1
    for extent in shape:
        total = total * int(extent)
    return total
@dataclass(frozen=True, eq=False)
class _Segment:
    """An immutable reference into one population: ``(population, local idx)``.

    The indices need not be contiguous; slicing with a step or indexing with
    an array produces arbitrary index sets.

    ``eq=False`` keeps the inherited identity-based ``__eq__``/``__hash__``.
    With the dataclass default ``eq=True`` combined with ``frozen=True``, the
    generated methods would compare and hash the ``indices`` JAX array. JAX
    arrays are unhashable, and their ``==`` is elementwise, so ``hash()`` and
    ``==`` on segments would raise instead of returning a bool.
    """

    population: object  # the population/device the indices refer into
    indices: jnp.ndarray  # 1-D int array of local indices into the population
class NodeView:
    """A view over one or more slices of populations/devices (NEST-style).

    Mimics NEST ``NodeCollection`` algebra: ``a + b`` concatenates two views
    (preserving segment boundaries) and ``a[sl]`` slices a single-segment view.
    Each segment references a population and a 1-D array of local indices into
    that population's flattened neuron dimension.

    Parameters
    ----------
    segments
        Segments making up the view, in order. They are copied into a new
        list, so later mutation of the caller's container does not affect
        the view.
    """

    __module__ = 'brainpy.state'

    def __init__(self, segments: List[_Segment]):
        self._segments = list(segments)

    @classmethod
    def of(cls, population) -> 'NodeView':
        """Build a full-population view over ``population``.

        The result has a single segment whose indices are ``0 .. N-1``, where
        ``N`` is the population's flattened size.
        """
        return cls([_Segment(population, jnp.arange(_flat_size(population)))])

    @property
    def segments(self) -> List[_Segment]:
        """The view's segments, in order.

        This is the view's internal list, not a copy; mutating it mutates the
        view.
        """
        return self._segments

    @property
    def size(self) -> int:
        """Total number of indices across all segments (0 for an empty view)."""
        return sum(int(seg.indices.shape[0]) for seg in self._segments)

    def __add__(self, other: 'NodeView') -> 'NodeView':
        """Concatenate two views, keeping each operand's segments intact.

        Returns ``NotImplemented`` for non-``NodeView`` operands so Python
        falls back to the reflected operation or raises ``TypeError``.
        """
        if not isinstance(other, NodeView):
            return NotImplemented
        return NodeView(self._segments + other._segments)

    def __getitem__(self, item) -> 'NodeView':
        """Index or slice a single-segment view.

        ``item`` is applied to the segment's index array. A scalar result is
        promoted to a length-1 array, so the result is always a 1-D view.

        Raises
        ------
        NotImplementedError
            If the view does not consist of exactly one segment.
        IndexError
            If ``item`` is a Python ``int`` outside ``[-len(self), len(self))``.
        """
        if len(self._segments) != 1:
            raise NotImplementedError('slicing is supported on single-segment views only')
        seg = self._segments[0]
        # JAX clamps out-of-bounds indices on gather instead of raising, so an
        # out-of-range integer would silently select the first/last neuron.
        # Reject it explicitly, as Python sequences and NEST do. ``bool`` is
        # an ``int`` subclass but array indexing treats it as a mask, so it is
        # left to JAX unchanged. Array-valued indices are not checked here.
        if isinstance(item, int) and not isinstance(item, bool):
            n = int(seg.indices.shape[0])
            if not -n <= item < n:
                raise IndexError(f'index {item} is out of bounds for NodeView of size {n}')
        idx = seg.indices[item]
        idx = idx[None] if idx.ndim == 0 else idx
        return NodeView([_Segment(seg.population, idx)])

    def __len__(self) -> int:
        """Number of indices in the view (same as :attr:`size`)."""
        return self.size