# Copyright 2026 BrainX Ecosystem Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
from dataclasses import dataclass, field
from typing import Any, Mapping, TYPE_CHECKING
import numpy as np
from brainstate.typing import ArrayLike
from .config import (
alpha_for_2d as _alpha_for_2d,
alpha_for_2d_line as _alpha_for_2d_line,
alpha_for_2d_poly as _alpha_for_2d_poly,
alpha_for_3d_tube as _alpha_for_3d_tube,
color_for_2d_branch_type as _color_for_2d_branch_type,
color_for_branch_type as _color_for_branch_type,
edge_color_for_2d_branch_type as _edge_color_for_2d_branch_type,
frustum_edge_linewidth_2d as _frustum_edge_linewidth_2d,
)
if TYPE_CHECKING:
from braincell.filter import LocsetMask, RegionMask
from braincell import Morphology
# ---------------------------------------------------------------------------
# Value spec — styling for color-by-values overlays (M6 Phase 3)
# ---------------------------------------------------------------------------
[docs]
@dataclass(frozen=True)
class ValueSpec:
"""Styling parameters for a color-by-values overlay.
Parameters
----------
values : ArrayLike
Per-element scalar array. The number of elements is interpreted
against the morphology:
* ``n_branches`` — one scalar per branch (the whole branch is
shaded with that single colour);
* total segment count — one scalar per segment (matplotlib
``LineCollection`` / ``PolyCollection`` uses the corresponding
per-segment colour);
* total centerline-point count (``segments + branches``) — one
scalar per polyline point (used directly as
``polydata.point_data``).
Values may carry ``brainunit`` units; the units are used to
generate a default colourbar label when ``unit_label`` is unset.
cmap : str
Matplotlib colormap name, forwarded to
:class:`matplotlib.cm.ScalarMappable` and PyVista ``add_mesh``.
vmin, vmax : float or None
Fixed colour-scale bounds. When either is ``None`` the missing
bound is derived from the data range.
norm : object or None
Optional matplotlib ``Normalize``-compatible object. Takes
precedence over *vmin*/*vmax* for the 2D backend.
label : str or None
Colourbar title. When ``None`` no title is drawn.
unit_label : str or None
Optional unit string appended to *label* on the colourbar.
show_colorbar : bool
Whether to render a colourbar alongside the scene (matplotlib
only; PyVista always draws its own scalar bar when scalars are
present).
"""
values: ArrayLike
cmap: str = "viridis"
vmin: float | None = None
vmax: float | None = None
norm: Any | None = None
label: str | None = None
unit_label: str | None = None
show_colorbar: bool = True
# ---------------------------------------------------------------------------
# Resolved per-branch value arrays
# ---------------------------------------------------------------------------
@dataclass(frozen=True)
class BranchValues:
"""Per-branch centerline-point scalar array.
The scene builders use :class:`BranchValues` as the canonical in-scene
representation of a color-by-values request. One ``BranchValues``
entry exists per branch; ``point_values`` stores a scalar for every
centerline point (one more than the number of segments).
"""
branch_index: int
point_values: np.ndarray # shape (n_points,)
@property
def segment_values(self) -> np.ndarray:
"""Per-segment midpoint scalars: ``0.5 * (v[i] + v[i+1])``."""
if self.point_values.size <= 1:
return self.point_values.copy()
return 0.5 * (self.point_values[:-1] + self.point_values[1:])
def color_for_branch_type(branch_type: str) -> tuple[int, int, int]:
return _color_for_branch_type(branch_type)
def color_for_2d_branch_type(branch_type: str) -> tuple[int, int, int]:
return _color_for_2d_branch_type(branch_type)
def edge_color_for_2d_branch_type(branch_type: str) -> tuple[int, int, int]:
return _edge_color_for_2d_branch_type(branch_type)
def alpha_for_2d() -> float:
return _alpha_for_2d()
def alpha_for_2d_line() -> float:
return _alpha_for_2d_line()
def alpha_for_2d_poly() -> float:
return _alpha_for_2d_poly()
def frustum_edge_linewidth_2d() -> float:
return _frustum_edge_linewidth_2d()
def alpha_for_3d_tube() -> float:
return _alpha_for_3d_tube()
# ---------------------------------------------------------------------------
# Overlay input spec (what the user passes to plot2d / plot3d)
# ---------------------------------------------------------------------------
[docs]
@dataclass(frozen=True)
class OverlaySpec:
"""User-facing overlay request passed through ``plot2d`` / ``plot3d``.
The scene builders translate this into concrete overlay *primitives*
(``HighlightStroke2D`` / ``Marker2D`` / ``HighlightStroke3D`` /
``Marker3D``) that the backends then render on top of the base scene.
Fields are plain masks so callers can build them with
``region_expr.evaluate(morpho)`` / ``locset_expr.evaluate(morpho)``.
``values`` may be either a bare array (interpreted with default
styling) or a :class:`ValueSpec` carrying colormap / bounds / label
information.
"""
region: "RegionMask | None" = None
locset: "LocsetMask | None" = None
values: "ValueSpec | ArrayLike | None" = None
[docs]
def values_spec(self) -> "ValueSpec | None":
"""Return the normalized :class:`ValueSpec`, or ``None``.
Accepts either a bare array (upgraded to ``ValueSpec`` with
default styling) or an already-constructed :class:`ValueSpec`.
"""
if self.values is None:
return None
if isinstance(self.values, ValueSpec):
return self.values
return ValueSpec(values=self.values)
# ---------------------------------------------------------------------------
# 3D scene primitives
# ---------------------------------------------------------------------------
@dataclass(frozen=True)
class BranchPolyline3D:
branch_index: int
branch_name: str
branch_type: str
points_um: np.ndarray
radii_um: np.ndarray
@dataclass(frozen=True)
class BranchTypeBatch3D:
branch_type: str
color_rgb: tuple[int, int, int]
opacity: float
branch_indices: tuple[int, ...]
branch_names: tuple[str, ...]
points_um: np.ndarray
radii_um: np.ndarray
lines: np.ndarray
@dataclass(frozen=True)
class ValueBatch3D:
"""Scalar-valued PolyData batch for a color-by-values 3D scene.
Mirrors :class:`BranchTypeBatch3D` but carries a per-point scalar
array consumed by ``PyVista.add_mesh(scalars=...)``. One
``ValueBatch3D`` is emitted per branch type so that the batch's
geometry can still be grouped the same way the base renderer does.
"""
branch_type: str
branch_indices: tuple[int, ...]
branch_names: tuple[str, ...]
points_um: np.ndarray
radii_um: np.ndarray
lines: np.ndarray
point_values: np.ndarray # shape (n_points,)
opacity: float
@dataclass(frozen=True)
class HighlightStroke3D:
"""Polyline fragment emitted for a region interval overlay in 3D.
The backend renders this as an accent-colored stroke on top of the
base tube/skeleton for the affected branch.
"""
branch_index: int
branch_name: str
branch_type: str
points_um: np.ndarray
radii_um: np.ndarray
color_rgb: tuple[int, int, int]
opacity: float = 1.0
@dataclass(frozen=True)
class Marker3D:
"""Scatter marker emitted from a locset point in 3D."""
branch_index: int
x: float
position_um: np.ndarray
color_rgb: tuple[int, int, int]
radius_um: float = 1.5
# ---------------------------------------------------------------------------
# 2D scene primitives
# ---------------------------------------------------------------------------
@dataclass(frozen=True)
class Polyline2D:
branch_index: int
branch_name: str
branch_type: str
points_um: np.ndarray
widths_um: np.ndarray
color_rgb: tuple[int, int, int]
alpha: float = 1.0
draw_order: int = 0
@dataclass(frozen=True)
class Polygon2D:
branch_index: int
branch_name: str
branch_type: str
points_um: np.ndarray
color_rgb: tuple[int, int, int]
edge_color_rgb: tuple[int, int, int] | None = None
edge_linewidth: float = 1.0
alpha: float = 1.0
draw_order: int = 0
@dataclass(frozen=True)
class Circle2D:
center_um: np.ndarray
radius_um: float
color_rgb: tuple[int, int, int]
draw_order: int = 0
@dataclass(frozen=True)
class Label2D:
text: str
position_um: np.ndarray
color_rgb: tuple[int, int, int] = (0, 0, 0)
draw_order: int = 0
@dataclass(frozen=True)
class PolylineValues2D:
"""Per-segment scalar-valued polyline for a single branch in 2D.
Emitted by the scene builder when the caller supplies ``values=``
and ``shape='line'``. The matplotlib backend vectorizes rendering
via :class:`matplotlib.collections.LineCollection`, with one
segment drawn per consecutive pair of points and one scalar per
segment from :attr:`segment_values`.
"""
branch_index: int
branch_name: str
branch_type: str
points_um: np.ndarray # shape (n_points, 2)
segment_values: np.ndarray # shape (n_points - 1,)
widths_um: np.ndarray # shape (n_points,) — per-point centerline diameter
draw_order: int = 0
@dataclass(frozen=True)
class PolygonValuesBatch2D:
"""Batched scalar-valued quad polygons for a single branch (frustum).
For ``shape='frustum'`` each segment is drawn as a trapezoid with
a per-polygon scalar; using a batched primitive lets the
matplotlib backend materialise the whole branch as a single
:class:`matplotlib.collections.PolyCollection`.
"""
branch_index: int
branch_name: str
branch_type: str
polygons_um: np.ndarray # shape (n_segments, 4, 2)
polygon_values: np.ndarray # shape (n_segments,)
edge_color_rgb: tuple[int, int, int] | None = None
edge_linewidth: float = 0.0
draw_order: int = 0
@dataclass(frozen=True)
class HighlightStroke2D:
"""Polyline fragment emitted for a region interval overlay in 2D."""
branch_index: int
branch_name: str
branch_type: str
points_um: np.ndarray
color_rgb: tuple[int, int, int]
linewidth: float
alpha: float = 1.0
draw_order: int = 0
@dataclass(frozen=True)
class Marker2D:
"""Scatter marker emitted from a locset point in 2D."""
branch_index: int
x: float
position_um: np.ndarray
color_rgb: tuple[int, int, int]
size: float = 30.0
draw_order: int = 0
# ---------------------------------------------------------------------------
# Scene containers
# ---------------------------------------------------------------------------
@dataclass(frozen=True)
class RenderScene3D:
branches: tuple[BranchPolyline3D, ...]
batches: tuple[BranchTypeBatch3D, ...]
highlight_strokes: tuple[HighlightStroke3D, ...] = ()
markers: tuple[Marker3D, ...] = ()
value_batches: tuple[ValueBatch3D, ...] = ()
value_spec: ValueSpec | None = None
mode: str = "geometry"
@dataclass(frozen=True)
class RenderScene2D:
polylines: tuple[Polyline2D, ...] = ()
polygons: tuple[Polygon2D, ...] = ()
circles: tuple[Circle2D, ...] = ()
labels: tuple[Label2D, ...] = ()
highlight_strokes: tuple[HighlightStroke2D, ...] = ()
markers: tuple[Marker2D, ...] = ()
polyline_values: tuple[PolylineValues2D, ...] = ()
polygon_value_batches: tuple[PolygonValuesBatch2D, ...] = ()
value_spec: ValueSpec | None = None
draw_order: tuple[int, ...] = ()
projection_plane: str | None = None
layout: str = "projected"
shape: str = "line"
# ---------------------------------------------------------------------------
# Render request — neutral schema with a backend_options escape hatch
# ---------------------------------------------------------------------------
@dataclass(frozen=True)
class RenderRequest:
"""Dispatched to a backend's ``render`` method.
Backend-specific parameters (matplotlib ``ax``, pyvista ``notebook``,
``jupyter_backend``, ``return_plotter``) live in ``backend_options``
so that adding a new backend does not require editing this schema.
"""
morpho: "Morphology"
scene: RenderScene2D | RenderScene3D | None = None
overlay: OverlaySpec = field(default_factory=OverlaySpec)
dimensionality: str = "3d"
mode: str | None = None
layout: str | None = None
shape: str | None = None
backend_options: Mapping[str, Any] = field(default_factory=dict)