Orchestration

The orchestration layer sits on top of the neural mass models. It provides the reusable run loop, connectome builder, and loss/score builders that previously had to be hand-written in every example and tutorial:

  • Network wires a node model into a delay-coupled whole-brain network.

  • Simulator drives any model (single node or whole-brain network) and collects monitored trajectories into a unit-aware result.

  • brainmass.objectives composes braintools.metric into small, jit / grad / vmap-safe objective callables over those trajectories.

  • Fitter fits a model’s trainable parameters to data behind one .fit call, swapping between gradient (optax), Nevergrad, and SciPy backends without rewriting the objective.

Network

  • Network – Wire a node model into a delay-coupled whole-brain network.

class brainmass.Network(*args, **kwargs)

Wire a node model into a delay-coupled whole-brain network.

The network takes a single node (a Dynamics sized for N regions) and couples its regions through a structural-connectivity matrix, optionally with distance-dependent conduction delays. update() computes the coupling current and passes it to the node as its first positional input, the same current = coupling(); node(current) idiom that the examples write by hand.

Parameters:
  • node (brainstate.nn.Dynamics) – The per-region dynamics, already sized for N regions (node.varshape[0] == N) and carrying any per-region noise. Exposed afterwards as node; its states are reachable as network.node.<coupled_var>.

  • conn (array_like or brainstate.nn.Param) – Structural connectivity. A plain (N, N) (or flattened (N * N,)) array has its diagonal zeroed unless self_connection is True; a Param / LaplacianConnParam is passed through untouched (the caller owns its structure) and may be trained.

  • distance (array_like, optional) – Inter-region distance matrix (N, N). Combined with speed into conduction delays distance / speed. If either distance or speed is None, the coupling is instantaneous (zero delays).

  • speed (float or brainunit.Quantity, optional) – Conduction speed. If distance carries length units and speed carries length / time units, the delay is unit-correct; if both are plain numbers, the quotient is interpreted as milliseconds (matching the examples).

  • coupling ({'diffusive', 'additive', 'laplacian', 'sigmoidal', 'tanh', 'sigmoidal_jansen_rit'}, default 'diffusive') – Coupling kernel. 'diffusive' uses DiffusiveCoupling (k * sum_j conn_ij (x_j - x_i)); 'additive' uses AdditiveCoupling (k * sum_j conn_ij x_j); 'laplacian' wraps conn in a LaplacianConnParam and applies it additively. The nonlinear forms 'sigmoidal' (SigmoidalCoupling), 'tanh' (HyperbolicTangentCoupling) and 'sigmoidal_jansen_rit' (SigmoidalJansenRitCoupling) apply their nonlinearity to the same delayed source read, with k as the global strength (the role played by G in TVB).

  • coupled_var (str) – Name of the node state variable to couple (e.g. 'rE', 'x', 'V'). Validated at initialisation; an unknown name raises ValueError.

  • k (float, brainunit.Quantity or brainstate.nn.Param, default 1.0) – Global coupling strength. Pass a trainable Param to fit it.

  • delay_init (Callable, default braintools.init.Uniform(0., 0.05)) – Initializer for the delay buffer’s history.

  • self_connection (bool, default False) – If False (default), the connectivity diagonal is zeroed (no self-coupling). Only applies to a plain-array conn.

  • noise (brainstate.nn.Module, optional) – Optional network-level noise process (e.g. an OUProcess sized for N). Its output is added to the coupling current at each step.
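The two linear kernels listed under coupling can be sketched in NumPy for the instantaneous case (no delays) on a dense (N, N) matrix. prepare_conn, diffusive and additive are hypothetical helpers written for illustration, not brainmass functions; the real kernels also handle units, delays and Param-wrapped connectivity.

```python
import numpy as np


def prepare_conn(conn, self_connection=False):
    """Reshape a plain (N, N) or flattened (N * N,) array and zero its diagonal."""
    conn = np.asarray(conn, dtype=float)
    n = int(round(np.sqrt(conn.size)))
    conn = conn.reshape(n, n).copy()
    if not self_connection:
        np.fill_diagonal(conn, 0.0)  # no self-coupling
    return conn


def diffusive(conn, x, k):
    """k * sum_j conn_ij * (x_j - x_i), vectorised over the targets i."""
    return k * (conn @ x - conn.sum(axis=1) * x)


def additive(conn, x, k):
    """k * sum_j conn_ij * x_j."""
    return k * (conn @ x)


conn = prepare_conn(np.ones((3, 3)) * 0.1)
x = np.array([1.0, 2.0, 3.0])
print(diffusive(conn, x, k=0.5))  # approximately [0.15, 0.0, -0.15]
print(additive(conn, x, k=0.5))   # approximately [0.25, 0.2, 0.15]
```

For a symmetric conn the diffusive current sums to zero across regions, because every pair exchanges equal and opposite terms; the additive form has no such property.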

See also

brainmass.Simulator

drives the network and collects trajectories.

brainmass.DiffusiveCoupling

the underlying diffusive coupling kernel.

Notes

The delay-buffer shape (max_delay_steps, N) does not depend on whether the (delay, index) pair is supplied flattened (as in examples/100) or as (N, N) matrices (as here), so a seeded run reproduces the hand-wired examples bit-for-bit. The self-delay (delay-matrix diagonal) is always zeroed.
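The delay handling can be pictured as a ring buffer read at integer delays. This is a minimal sketch under stated assumptions: plain-number distances and speeds give milliseconds (as the speed parameter describes), delays are rounded to the nearest step, and the buffer keeps max_delay + 1 rows so the longest delay is still readable. delay_steps, DelayBuffer and delayed_diffusive are hypothetical names; the real buffer's length and rounding rule may differ.

```python
import numpy as np


def delay_steps(distance, speed, dt):
    """Integer delays in steps (assumption: plain numbers give ms, rounded to nearest step)."""
    delay_ms = np.asarray(distance, dtype=float) / speed
    steps = np.rint(delay_ms / dt).astype(int)
    np.fill_diagonal(steps, 0)  # the self-delay is always zeroed
    return steps


class DelayBuffer:
    """Hypothetical ring buffer holding the last `length` values of an N-vector."""

    def __init__(self, history):
        self.buf = np.array(history, dtype=float)  # (length, N), oldest row first
        self.head = self.buf.shape[0] - 1          # row holding the newest value

    def push(self, x):
        self.head = (self.head + 1) % self.buf.shape[0]
        self.buf[self.head] = x

    def now(self):
        return self.buf[self.head]

    def read(self, steps):
        """D[i, j] = value of source j, steps[i, j] steps in the past."""
        rows = (self.head - steps) % self.buf.shape[0]
        cols = np.broadcast_to(np.arange(self.buf.shape[1]), steps.shape)
        return self.buf[rows, cols]


def delayed_diffusive(conn, buffer, steps, k):
    """k * sum_j conn_ij * (x_j(t - tau_ij) - x_i(t))."""
    delayed = buffer.read(steps)
    return k * np.sum(conn * (delayed - buffer.now()[:, None]), axis=1)


steps = delay_steps([[0.0, 2.0], [2.0, 0.0]], speed=1.0, dt=1.0)
rng = np.random.default_rng(0)
# History drawn like delay_init = Uniform(0, 0.05), one row per readable lag.
history_buffer = DelayBuffer(rng.uniform(0.0, 0.05, size=(steps.max() + 1, 2)))
```

Because the buffer is indexed by source region and lag, its shape is the same whichever layout the delays were supplied in, which is the point the Notes make about reproducibility.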

Examples

A four-region Hopf network driven by the simulator:

>>> import brainmass
>>> import brainunit as u
>>> import numpy as np
>>> N = 4
>>> conn = np.ones((N, N)) * 0.1
>>> node = brainmass.HopfStep(N, a=0.1)
>>> net = brainmass.Network(node, conn=conn, coupled_var='x', k=0.5)
>>> sim = brainmass.Simulator(net, dt=0.1 * u.ms)
>>> res = sim.run(5.0 * u.ms, monitors=lambda m: m.node.x.value)
>>> res['output'].shape
(50, 4)

Return type:

Any

__init__(node, *, conn, distance=None, speed=None, coupling='diffusive', coupled_var, k=1.0, delay_init=Uniform(low=0.0, high=0.05), self_connection=False, noise=None)

init_state(*args, **kwargs)

Validate coupled_var once the node’s states exist.

update(*node_inputs)

Advance one step: coupling current (+ noise) -> node.

Parameters:

*node_inputs – Extra inputs forwarded to the node after the coupling current (e.g. a second drive supplied by brainmass.Simulator.run()). The coupling current is always the node’s first positional input.

Returns:

The node’s update return value.

Return type:

Any

Simulator

  • Simulator – Drive a Dynamics / Module and collect monitored trajectories.

class brainmass.Simulator(model, dt=None)

Drive a Dynamics / Module and collect monitored trajectories.

The simulator wraps the standard brainmass run loop – set dt, initialise states, step the model inside environ.context(i=, t=) with brainstate.transform.for_loop(), and stack the recorded values – in a single run() call. It reimplements none of that machinery; it only composes it.

Parameters:
  • model (brainstate.nn.Dynamics or brainstate.nn.Module) – The model to drive. Exposed afterwards as model. May be a single node or a network whose update reads its own internal coupling.

  • dt (brainunit.Quantity, optional) – Integration time step. If None (default), the step is read from the global environment (brainstate.environ.get('dt')) at run() time; if that is also unset, run() raises ValueError.

See also

brainmass.objectives

composable loss builders over simulated trajectories.

Notes

The result of run() is a plain dict (a valid JAX pytree, so it is safe to return through jit / grad / vmap) mapping each monitor name to its stacked trajectory, plus a 'ts' time axis. ts[k] is the simulation time at the end of the k-th recorded step (the recorded value is the post-update state).
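The run loop that run() composes can be sketched in plain Python. This sketch assumes a pure step(x, i, t, dt) function in place of a stateful brainstate model and a Python for loop in place of brainstate.transform.for_loop(), so nothing is jitted; run_loop and euler_decay are hypothetical names, and passing the start-of-step time as t is an assumed convention. It shows the two conventions stated in the Notes: the recorded value is the post-update state, and ts[k] is the time at the end of step k.

```python
import numpy as np


def run_loop(step, x0, duration, dt, monitor=lambda x: x):
    """Step a pure update function and stack the monitored values (sketch)."""
    n_steps = int(duration / dt)  # a non-integer ratio is floored
    x = np.asarray(x0, dtype=float)
    records = []
    for i in range(n_steps):
        t = i * dt                  # time at the start of step i (assumed convention)
        x = step(x, i, t, dt)       # advance the state by one step
        records.append(monitor(x))  # record the post-update value
    ts = (np.arange(n_steps) + 1) * dt  # time at the end of each recorded step
    return {"output": np.stack(records), "ts": ts}


def euler_decay(x, i, t, dt, tau=2.0):
    """Forward-Euler step of dx/dt = -x / tau."""
    return x + dt * (-x / tau)


res = run_loop(euler_decay, x0=[1.0, -1.0], duration=10.0, dt=0.1)
print(res["output"].shape, res["ts"].shape)  # (100, 2) (100,)
```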

Examples

>>> import brainmass
>>> import brainunit as u
>>> node = brainmass.HopfStep(2, a=-0.2)
>>> sim = brainmass.Simulator(node, dt=0.1 * u.ms)
>>> res = sim.run(10.0 * u.ms, monitors=['x'])
>>> res['x'].shape
(100, 2)
>>> res['ts'].shape
(100,)

A derived observable (here E - I) is monitored with a callable:

>>> jr = brainmass.JansenRitStep(in_size=1)
>>> sim = brainmass.Simulator(jr, dt=0.1 * u.ms)
>>> res = sim.run(5.0 * u.ms, monitors=lambda m: m.eeg(), transient=1.0 * u.ms)
>>> res['output'].shape
(40, 1)

__init__(model, dt=None)

run(duration, *, inputs=None, monitors=None, transient=None, sample_every=None, batch_size=None, init_states=True, jit=True)

Simulate model for duration and return the recorded trajectories.

Parameters:
  • duration (brainunit.Quantity) – Total simulated time. The number of integration steps is int(duration / dt); a non-integer multiple is floored with a warning.

  • inputs (None, array-like or callable, optional) –

    External drive forwarded to model.update:

    • None (default): call model.update() with no arguments.

    • array-like of shape (n_steps, ...): row i is passed as the single argument model.update(inputs[i]).

    • callable(i, t): its return is splatted – model.update(*r) if it returns a tuple, else model.update(r).

  • monitors (None, list of str, callable or dict, optional) –

    What to record each step:

    • None (default): the return value of update(), under key 'output'.

    • list of state-attribute names (e.g. ['rE']): each getattr(model, name).value.

    • callable(model) -> value: a derived observable, under key 'output' (e.g. lambda m: m.eeg()).

    • dict mapping output name to a state-attribute name or a callable(model).

  • transient (None, brainunit.Quantity or int, optional) – Leading transient to discard from the outputs, as a duration or a step count. Must be shorter than the run.

  • sample_every (None or int, optional) – Record every k-th step (output downsampling; generalises the Jansen-Rit TR loop). None records every step.

  • batch_size (None or int, optional) – If given, initialise states with this batch size (batched initial conditions); outputs gain a leading batch axis.

  • init_states (bool, default True) – Call init_all_states before running. Set False to continue from the model’s current state.

  • jit (bool, default True) – Compile the run with brainstate.transform.jit().
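The accepted forms of inputs and monitors amount to a small dispatch. The sketch below mirrors the rules listed above in plain Python; call_update, normalize_monitors and attr_getter are hypothetical helpers, and state attributes are assumed to expose .value, as the list-of-names form states.

```python
def attr_getter(name):
    """Getter for a state attribute: getattr(model, name).value."""
    return lambda model: getattr(model, name).value


def call_update(update, inputs, i, t):
    """Forward the step-i drive to update() following the documented rules."""
    if inputs is None:
        return update()
    if callable(inputs):
        r = inputs(i, t)
        return update(*r) if isinstance(r, tuple) else update(r)  # splat tuples
    return update(inputs[i])  # array-like: row i is the single argument


def normalize_monitors(monitors):
    """Turn any monitors spec into {name: getter}; None means 'record update()'."""
    if monitors is None:
        return {"output": None}
    if callable(monitors):
        return {"output": monitors}
    if isinstance(monitors, dict):
        return {name: (spec if callable(spec) else attr_getter(spec))
                for name, spec in monitors.items()}
    return {name: attr_getter(name) for name in monitors}
```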

Returns:

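Maps each monitor name (or 'output') to its stacked trajectory and 'ts' to the time axis. The time axis of each trajectory has length (n_steps - transient) // sample_every, with transient counted in steps; it is the leading axis unless batch_size adds a leading batch axis.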

Return type:

dict

Raises:

ValueError – If dt is unset, transient is not shorter than the run, a monitored name is not a model attribute, sample_every < 1, or the inputs length does not match the step count.
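The output-length rule above follows if, after dropping the transient, one sample is kept from every complete block of sample_every steps. Keeping the last sample of each block is an assumption of this sketch (postprocess is a hypothetical helper), but any rule that keeps one sample per complete block gives the same (n_steps - transient) // sample_every length. The checks mirror two of the documented ValueError conditions.

```python
import numpy as np


def postprocess(traj, ts, transient=0, sample_every=1):
    """Drop a leading transient (in steps), then keep one sample per block."""
    n_steps = traj.shape[0]
    if sample_every < 1:
        raise ValueError("sample_every must be >= 1")
    if transient >= n_steps:
        raise ValueError("transient must be shorter than the run")
    # Index of the last step in each complete block of `sample_every` steps.
    kept = slice(transient + sample_every - 1, None, sample_every)
    return traj[kept], ts[kept]


traj = np.arange(50.0)[:, None]  # 50 recorded steps of a 1-D state
ts = (np.arange(50) + 1) * 0.1   # end-of-step times for dt = 0.1
y, t = postprocess(traj, ts, transient=10, sample_every=3)
print(y.shape)  # (13, 1), i.e. (50 - 10) // 3 rows
```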

Fitter

  • Fitter – Fit a model’s trainable parameters to data behind one .fit call.

  • FitResult – Outcome of a Fitter.fit() call.

class brainmass.Fitter(model, optimizer=None, *, loss_fn=None, objective=None, predict=None, backend='grad', callbacks=None, transient=None, search_space=None)

Fit a model’s trainable parameters to data behind one .fit call.

Parameters:
  • model (brainstate.nn.Module) – The model to fit. Its trainable Param parameters (fit=True) are the optimization variables.

  • optimizer (optional) –

    Interpreted by backend:

    • 'grad': a braintools.optim.OptaxOptimizer instance (e.g. braintools.optim.Adam(lr=0.05)). Defaults to Adam(lr=1e-2).

    • 'nevergrad' / 'scipy': an options dict forwarded to the optimizer constructor (e.g. {'method': 'DE', 'n_sample': 8} or {'method': 'L-BFGS-B'}), a method-name str, or None. The actual optimizer is constructed inside fit() (it needs the loss and bounds) and exposed afterwards as optimizer.

  • loss_fn (callable, optional) – loss_fn(model) -> (scalar_loss, aux). When given it is the entire loss (you own any regularization) and objective / predict / target are unused. Mutually exclusive with predict.

  • objective (callable, optional) – objective(prediction, target) -> scalar (e.g. from brainmass.objectives). Used with predict. Defaults to brainmass.objectives.timeseries_rmse().

  • predict (callable, optional) – predict(model) -> prediction; typically a brainmass.Simulator closure. Required unless loss_fn is given. The objective-path loss is objective(prediction[transient:], target) + model.reg_loss().

  • backend ({'grad', 'nevergrad', 'scipy'}, default 'grad') – Which optimizer backend to use.

  • callbacks (list of callable, optional) – Each callback(info) -> bool | None is called once per step with info = {'step', 'loss', 'best_loss', 'model'}. Returning True stops the run early (grad backend only).

  • transient (int, optional) – Number of leading samples discarded from the prediction (axis 0) before the objective is applied. None keeps the whole prediction.

  • search_space (dict, optional) – {name: (low, high)} constrained bounds for the derivative-free backends, overriding/augmenting the bounds derived from each parameter’s transform.
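The gradient path, the callback protocol and the constrained-space convention of best_params can be illustrated on the one-parameter Toy problem from the Examples of this class, with the derivative written by hand instead of taken by brainstate.transform.grad. For illustration it assumes that a SigmoidT(low, high)-style transform maps an unconstrained z to k = low + (high - low) * sigmoid(z) and that the optimizer updates z; it uses plain gradient descent rather than Adam and a squared-error loss rather than the default RMSE. fit_bounded_scalar is a hypothetical name, and info carries no 'model' entry because there is no model object here.

```python
import numpy as np


def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))


def fit_bounded_scalar(target, init, low, high, lr=0.5, n_steps=100, callbacks=()):
    """Fit k = low + (high - low) * sigmoid(z) to a scalar target by descending on z."""
    s0 = (init - low) / (high - low)
    z = np.log(s0 / (1.0 - s0))  # inverse transform of the initial constrained value
    history, best_loss, best_k, steps_run = [], np.inf, None, 0
    for step in range(n_steps):
        s = sigmoid(z)
        k = low + (high - low) * s  # constrained (physical) value, cf. Param.value()
        loss = float((k - target) ** 2)
        history.append(loss)
        if loss < best_loss:
            best_loss, best_k = loss, float(k)
        steps_run += 1
        info = {"step": step, "loss": loss, "best_loss": best_loss}
        if any([cb(info) is True for cb in callbacks]):  # call every callback
            break  # a callback requested early stopping
        dk_dz = (high - low) * s * (1.0 - s)  # chain rule through the transform
        z = z - lr * 2.0 * (k - target) * dk_dz
    return {"best_loss": best_loss, "best_params": {"k": best_k},
            "history": history, "n_steps": steps_run}


result = fit_bounded_scalar(target=2.0, init=1.0, low=0.5, high=3.0)
print(result["best_params"])  # k close to 2.0
```

Optimising z rather than k keeps every iterate inside (low, high) without clipping, which is why best_params is reported in the constrained space.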

See also

brainmass.Simulator

builds the predict closure.

brainmass.objectives

composable objective callables.

Notes

The grad backend reproduces the canonical hand-rolled loop exactly: model.states(ParamState) are registered as trainable weights, the loss is evaluated inside model.param_precompute(), and brainstate.transform.grad(..., has_aux=True, return_value=True) feeds optimizer.step. model.reg_loss() is added automatically. The derivative-free backends evaluate one candidate at a time, setting parameters via Param.set_value and then running predict. Because vmap over parameter writes is unsupported, candidates are evaluated in a sequential loop, so these backends are best suited to a small number of scalar parameters.
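The one-candidate-at-a-time evaluation described in the Notes can be sketched with plain uniform random search inside box bounds. This is a stand-in for illustration only, not Nevergrad's DE or SciPy's L-BFGS-B; random_search is a hypothetical name, bounds plays the role of search_space, and loss_fn stands for "write the parameters, run predict, apply the objective". As with the nevergrad backend's history, every evaluated candidate's loss is recorded.

```python
import numpy as np


def random_search(loss_fn, bounds, n_steps=20, n_sample=8, seed=0):
    """Sequentially evaluate uniform candidates inside box bounds; keep the best."""
    rng = np.random.default_rng(seed)
    names = list(bounds)
    low = np.array([bounds[name][0] for name in names], dtype=float)
    high = np.array([bounds[name][1] for name in names], dtype=float)
    history, best_loss, best_params = [], np.inf, None
    for _ in range(n_steps):           # one "generation" per step
        for _ in range(n_sample):      # candidates are looped, not vmapped
            cand = dict(zip(names, rng.uniform(low, high).tolist()))
            loss = float(loss_fn(cand))  # e.g. write params, run predict, score
            history.append(loss)
            if loss < best_loss:
                best_loss, best_params = loss, cand
    return best_loss, best_params, history


bounds = {"k": (0.5, 3.0), "g": (0.0, 1.0)}
best_loss, best_params, history = random_search(
    lambda p: (p["k"] - 2.0) ** 2 + (p["g"] - 0.1) ** 2, bounds)
```

The cost is n_steps * n_sample sequential model runs, which grows quickly with the number of parameters a search needs to cover.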

Examples

>>> import braintools, brainstate, brainunit as u
>>> import jax.numpy as jnp
>>> import brainmass
>>> from brainstate.nn import Param, SigmoidT
>>> class Toy(brainstate.nn.Module):
...     def __init__(self):
...         super().__init__()
...         self.k = Param(1.0, t=SigmoidT(0.5, 3.0))
...     def update(self):
...         return self.k.value()
>>> def predict(m):
...     sim = brainmass.Simulator(m, dt=0.1 * u.ms)
...     return jnp.mean(sim.run(1.0 * u.ms, monitors=None)['output'])
>>> fitter = brainmass.Fitter(Toy(), braintools.optim.Adam(lr=0.1), predict=predict)
>>> result = fitter.fit(target=jnp.asarray(2.0), n_steps=30)
>>> result.backend
'grad'
>>> list(result.best_params)
['k']
>>> bool(result.best_loss <= result.history[0])
True

__init__(model, optimizer=None, *, loss_fn=None, objective=None, predict=None, backend='grad', callbacks=None, transient=None, search_space=None)

fit(target=None, n_steps=100, *, verbose=False)

Run the optimization and return a FitResult.

Parameters:
  • target (Any, optional) – The data the objective compares the prediction against. Unused when a loss_fn was supplied (it consumes data itself).

  • n_steps (int, default 100) – Optimization iterations: optax steps (grad), generations (nevergrad), or random restarts (scipy). Must be >= 1.

  • verbose (bool, default False) – Print per-iteration progress.

Returns:

The best-seen loss, parameters, history, and the fitted model.

Return type:

FitResult

property optimizer

The underlying optimizer (constructed lazily for derivative-free backends).

class brainmass.FitResult(*, backend, best_loss, best_params, history, n_steps, prediction=None, optimizer=None, raw=None, model=None)

Outcome of a Fitter.fit() call.

backend

The optimizer backend that produced this result ('grad' / 'nevergrad' / 'scipy').

Type:

str

best_loss

Lowest loss observed over the run.

Type:

float

best_params

{name: value} of the trainable parameters at the best-seen point, in the constrained (physical) space (i.e. Param.value()).

Type:

dict

history

Per-iteration loss. For grad this is one entry per optimization step; for scipy it is the best loss per restart; for nevergrad it is the loss of every evaluated candidate.

Type:

list of float

n_steps

Number of optimization iterations actually run (may be less than the requested n_steps if a callback requested early stopping).

Type:

int

prediction

The model prediction at the best-seen point (objective path), else None.

Type:

Any or None

optimizer

The underlying braintools.optim optimizer object.

Type:

Any

raw

Backend-specific raw result (a SciPy OptimizeResult for scipy, the best-parameter mapping for nevergrad, None for grad).

Type:

Any

model

The fitted model, holding the best-seen parameters.

Type:

brainstate.nn.Module

__init__(*, backend, best_loss, best_params, history, n_steps, prediction=None, optimizer=None, raw=None, model=None)

Objectives

Each function is a builder: it takes configuration and returns a small callable(prediction, target) that wraps braintools.metric without reimplementing any metric maths. The callables are designed to be composed via combine() into the loss a fitter minimises.
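The builder pattern can be sketched in NumPy: each builder captures its configuration and returns a closure with the (prediction, target) signature, and combine sums weighted closures. These are illustrative re-implementations, not the brainmass code, which delegates the metric maths to braintools.metric and works on JAX arrays and unit-aware quantities.

```python
import numpy as np


def timeseries_rmse():
    """Builder: returns loss(prediction, target) = sqrt(mean((p - t) ** 2))."""
    def loss(prediction, target):
        diff = np.asarray(prediction) - np.asarray(target)
        return np.sqrt(np.mean(diff ** 2))
    return loss


def combine(*weighted_objectives):
    """Builder: returns loss(prediction, target) = sum(w * objective(prediction, target))."""
    def loss(prediction, target=None):
        return sum(w * obj(prediction, target) for w, obj in weighted_objectives)
    return loss


x = np.zeros((10, 3))
loss = combine((2.0, timeseries_rmse()), (0.5, timeseries_rmse()))
print(float(loss(x + 1.0, x)))  # 2.5
```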

  • timeseries_rmse – Build a root-mean-square-error loss between two time series.

  • fc_corr – Build a functional-connectivity correlation score.

  • fc_rmse – Build a functional-connectivity RMSE loss.

  • cosine_sim – Build a cosine-similarity score between two (flattened) time series.

  • fcd – Build a functional-connectivity-dynamics (FCD) objective.

  • fcd_distribution – Kernel-density estimate of the FCD off-diagonal value distribution.

  • ks_distance – Kolmogorov-Smirnov statistic between two 1-D densities / histograms.

  • wasserstein_1d – Wasserstein-1 distance between two 1-D densities on a shared grid.

  • fcd_ks – Build a Kolmogorov-Smirnov FCD-distribution loss.

  • fcd_wasserstein – Build a Wasserstein FCD-distribution loss (smooth, grad-friendly).

  • combine – Combine weighted objective callables into a single objective.

brainmass.objectives.timeseries_rmse()

Build a root-mean-square-error loss between two time series.

Returns:

loss(prediction, target) -> scalar computing sqrt(mean((prediction - target) ** 2)). The subtraction is unit-checked: incompatible units raise before the magnitude is taken.

Return type:

callable

See also

fc_rmse

RMSE between functional-connectivity matrices.

Examples

>>> import jax.numpy as jnp
>>> from brainmass import objectives
>>> loss = objectives.timeseries_rmse()
>>> x = jnp.zeros((10, 3))
>>> float(loss(x + 2.0, x))
2.0

brainmass.objectives.fc_corr(as_loss=False)

Build a functional-connectivity correlation score.

Computes the correlation between the static functional-connectivity (FC) matrices of the prediction and the target via braintools.metric.functional_connectivity() and braintools.metric.matrix_correlation().

Parameters:

as_loss (bool, default False) – If True, return 1 - corr (a quantity to minimise); otherwise return the raw correlation (in [-1, 1], to maximise).

Returns:

score(prediction, target) -> scalar.

Return type:

callable

Examples

>>> import numpy as np, jax.numpy as jnp
>>> from brainmass import objectives
>>> rng = np.random.default_rng(0)
>>> x = jnp.asarray(rng.standard_normal((200, 5)))
>>> score = objectives.fc_corr()
>>> float(score(x, x))
1.0
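Both FC objectives reduce to two steps: a Pearson correlation matrix per input, then a comparison of the two matrices. The sketch assumes that static FC is the region-by-region correlation over time and that the matrix correlation uses the off-diagonal upper triangle (for symmetric FC matrices the lower triangle gives the same value). Whether brainmass's fc_rmse includes the diagonal is not stated; this sketch compares the whole matrices. static_fc and matrix_corr are hypothetical helpers.

```python
import numpy as np


def static_fc(ts):
    """(T, N) time series -> (N, N) Pearson correlation between regions."""
    return np.corrcoef(np.asarray(ts), rowvar=False)


def matrix_corr(a, b):
    """Pearson correlation of the off-diagonal upper triangles of a and b."""
    iu = np.triu_indices_from(a, k=1)
    return np.corrcoef(a[iu], b[iu])[0, 1]


def fc_corr(as_loss=False):
    def score(prediction, target):
        c = matrix_corr(static_fc(prediction), static_fc(target))
        return 1.0 - c if as_loss else c  # loss form is minimised, raw form maximised
    return score


def fc_rmse():
    def loss(prediction, target):
        d = static_fc(prediction) - static_fc(target)
        return np.sqrt(np.mean(d ** 2))
    return loss


rng = np.random.default_rng(0)
x = rng.standard_normal((200, 5))
y = x + 0.5 * rng.standard_normal((200, 5))
print(fc_corr()(x, x), fc_corr()(x, y), fc_rmse()(x, y))
```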
brainmass.objectives.fc_rmse()

Build a functional-connectivity RMSE loss.

Returns:

loss(prediction, target) -> scalar computing the RMSE between the prediction’s and target’s static FC matrices.

Return type:

callable

See also

fc_corr

correlation between FC matrices.

brainmass.objectives.cosine_sim(as_loss=False, epsilon=0.0)

Build a cosine-similarity score between two (flattened) time series.

Thin wrapper over braintools.metric.cosine_similarity(); the inputs are flattened to a single vector so the result is a scalar.

Parameters:
  • as_loss (bool, default False) – If True, return 1 - cos (to minimise); otherwise the raw cosine similarity (to maximise).

  • epsilon (float, default 0.0) – Numerical floor forwarded to braintools.metric.cosine_similarity().

Returns:

score(prediction, target) -> scalar.

Return type:

callable
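A NumPy sketch of the flatten-then-compare behaviour. Where epsilon enters braintools.metric.cosine_similarity() is not stated here, so adding it to the denominator is an assumption of this sketch, which is not the brainmass implementation.

```python
import numpy as np


def cosine_sim(as_loss=False, epsilon=0.0):
    def score(prediction, target):
        p = np.ravel(prediction)  # flatten so the result is a scalar
        t = np.ravel(target)
        # Assumed placement of the numerical floor: added to the norm product.
        cos = np.dot(p, t) / (np.linalg.norm(p) * np.linalg.norm(t) + epsilon)
        return 1.0 - cos if as_loss else cos
    return score


x = np.arange(1.0, 7.0).reshape(3, 2)
print(cosine_sim()(x, 2.0 * x))         # 1.0 up to rounding: same direction
print(cosine_sim(as_loss=True)(x, -x))  # 2.0 up to rounding: opposite direction
```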

brainmass.objectives.fcd(window_size=30, step_size=5, as_loss=False)

Build a functional-connectivity-dynamics (FCD) objective.

Exposes braintools.metric.functional_connectivity_dynamics() as an objective. The FCD matrix captures how the sliding-window functional connectivity itself evolves over time.

Parameters:
  • window_size (int, default 30) – Sliding-window length (in samples) forwarded to braintools.metric.functional_connectivity_dynamics().

  • step_size (int, default 5) – Stride between consecutive windows (in samples) forwarded to braintools.metric.functional_connectivity_dynamics().

  • as_loss (bool, default False) – If True, return 1 - corr (to minimise); otherwise the raw FCD matrix correlation (to maximise).

Returns:

fn(prediction, target=None). With target=None it returns the prediction’s FCD matrix (surfacing the metric); with a target it returns the correlation between the two FCD matrices.

Return type:

callable
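The sliding-window construction can be sketched in NumPy using the common FCD definition: compute the FC matrix of each window, then correlate the windows' upper-triangle FC vectors with one another. Whether braintools.metric.functional_connectivity_dynamics() uses exactly this windowing (windows starting at 0, step_size, 2 * step_size and so on, kept only if they fit entirely in the series) and this upper-triangle convention is an assumption; sliding_fc, fcd_matrix and fcd_sketch are hypothetical names.

```python
import numpy as np


def sliding_fc(ts, window_size=30, step_size=5):
    """One row per window: the upper triangle of that window's FC matrix."""
    ts = np.asarray(ts)
    n_t, n = ts.shape
    iu = np.triu_indices(n, k=1)
    starts = range(0, n_t - window_size + 1, step_size)
    return np.stack([np.corrcoef(ts[s:s + window_size], rowvar=False)[iu] for s in starts])


def fcd_matrix(ts, window_size=30, step_size=5):
    """FCD[a, b] = Pearson correlation between the FC of windows a and b."""
    return np.corrcoef(sliding_fc(ts, window_size, step_size))


def fcd_sketch(window_size=30, step_size=5, as_loss=False):
    def fn(prediction, target=None):
        a = fcd_matrix(prediction, window_size, step_size)
        if target is None:
            return a  # surface the prediction's FCD matrix
        b = fcd_matrix(target, window_size, step_size)
        iu = np.triu_indices_from(a, k=1)
        c = np.corrcoef(a[iu], b[iu])[0, 1]
        return 1.0 - c if as_loss else c
    return fn


x = np.random.default_rng(0).standard_normal((200, 6))
print(fcd_sketch()(x).shape)  # (35, 35): windows start at 0, 5, ..., 170
```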

brainmass.objectives.combine(*weighted_objectives)

Combine weighted objective callables into a single objective.

Parameters:

*weighted_objectives (tuple of (float, callable)) – Pairs of (weight, objective) where each objective is a callable returned by the builders in this module (signature objective(prediction, target)).

Returns:

loss(prediction, target=None) -> scalar equal to sum(weight * objective(prediction, target)).

Return type:

callable

Examples

>>> import jax.numpy as jnp
>>> from brainmass import objectives
>>> loss = objectives.combine(
...     (2.0, objectives.timeseries_rmse()),
...     (0.5, objectives.timeseries_rmse()),
... )
>>> x = jnp.zeros((10, 3))
>>> float(loss(x + 1.0, x))   # (2.0 + 0.5) * 1.0
2.5

FCD-distribution objectives

The standard FCD fitting target is the distribution of the FCD matrix’s off-diagonal values, not the matrix correlation (fcd()). These builders compare that distribution between prediction and target: each computes the FCD of both, kernel-density-estimates the off-diagonal values onto a shared grid, and returns a distributional distance.

wasserstein_1d() is smooth and differentiable, so fcd_wasserstein() is the recommended FCD objective for gradient-based fitting; ks_distance() (the literature-standard Kolmogorov-Smirnov statistic) is a non-smooth supremum, so fcd_ks() is best for evaluation / reporting. A degenerate (constant, zero-variance) input yields a singular KDE and a nan distance.

brainmass.objectives.fcd_distribution(fcd_matrix, midpoints=None, n_diag=1, bw_method=None, normalize=True)

Kernel-density estimate of the FCD off-diagonal value distribution.

The standard FCD fitting target is the distribution of the upper-triangle (off-diagonal) values of the FCD matrix – not the matrix itself. This surfaces that distribution as a smooth density on a fixed grid via jax.scipy.stats.gaussian_kde() (delegated; braintools provides no KDE).

Parameters:
  • fcd_matrix (array) – Square FCD matrix (e.g. from fcd() or braintools.metric.functional_connectivity_dynamics()).

  • midpoints (array, optional) – Evaluation grid. Default: 100 points on [-0.99, 0.99] (FCD values are correlations).

  • n_diag (int, default 1) – Diagonal offset for the upper-triangle extraction (1 excludes the main diagonal).

  • bw_method (optional) – Bandwidth selector forwarded to jax.scipy.stats.gaussian_kde() (default None = Scott’s rule). Smaller values give sharper densities.

  • normalize (bool, default True) – Renormalise the evaluated density to integrate to 1 over midpoints (a KDE evaluated on a finite grid generally does not integrate to exactly 1).

Returns:

Density evaluated on midpoints.

Return type:

jax.Array
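The density step can be sketched without a KDE library. For bw_method None or a scalar, scipy.stats.gaussian_kde (whose interface jax.scipy.stats.gaussian_kde follows) uses a Gaussian kernel whose width is the bandwidth factor times the sample standard deviation, with Scott's factor n ** (-1/5) in one dimension; the sketch reproduces that rule by hand. Renormalising with a Riemann sum over the uniform grid is an assumption about how normalize is applied, and fcd_density is a hypothetical name. In this sketch a constant input gives zero kernel width and therefore nan, in line with the degenerate case noted in the section introduction.

```python
import numpy as np


def fcd_density(fcd_matrix, midpoints=None, n_diag=1, bw_method=None, normalize=True):
    """Gaussian KDE of the upper-triangle FCD values, evaluated on a grid (sketch)."""
    if midpoints is None:
        midpoints = np.linspace(-0.99, 0.99, 100)  # FCD values are correlations
    a = np.asarray(fcd_matrix, dtype=float)
    values = a[np.triu_indices_from(a, k=n_diag)]
    n = values.size
    factor = n ** (-1.0 / 5.0) if bw_method is None else float(bw_method)  # Scott's rule
    width = factor * np.std(values, ddof=1)  # kernel standard deviation
    z = (midpoints[:, None] - values[None, :]) / width
    density = np.exp(-0.5 * z ** 2).sum(axis=1) / (n * width * np.sqrt(2.0 * np.pi))
    if normalize:
        dx = midpoints[1] - midpoints[0]
        density = density / (density.sum() * dx)  # integrate to 1 over the grid (assumed rule)
    return density


rng = np.random.default_rng(0)
m = rng.uniform(-0.5, 0.5, size=(30, 30))
m = (m + m.T) / 2.0          # symmetric, like an FCD matrix
np.fill_diagonal(m, 1.0)
d = fcd_density(m)
```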

brainmass.objectives.ks_distance(p, q)

Kolmogorov-Smirnov statistic between two 1-D densities / histograms.

Each cumulative sum is normalised to a proper CDF before the supremum is taken, so the result lies in [0, 1] independent of bin width and is directly comparable to scipy.stats.ks_2samp().

\[D = \sup_x \left| F_p(x) - F_q(x) \right|.\]

Parameters:
  • p (array) – Densities or (unnormalised) histograms on a shared, ordered grid.

  • q (array) – Densities or (unnormalised) histograms on a shared, ordered grid.

Returns:

Scalar KS statistic in [0, 1].

Return type:

jax.Array

Notes

The supremum (max) makes this non-smooth: its gradient is the indicator at the argmax, so prefer wasserstein_1d() (and fcd_wasserstein()) when the distance is a fitting loss. Use KS for evaluation / reporting, where literature comparability matters.

See also

wasserstein_1d

smooth, gradient-friendly distributional distance.
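The normalised-CDF construction takes a few lines of NumPy. This is an illustrative sketch of the formula above, not the brainmass source.

```python
import numpy as np


def ks_distance(p, q):
    cp = np.cumsum(p, dtype=float)
    cq = np.cumsum(q, dtype=float)
    cp, cq = cp / cp[-1], cq / cq[-1]  # normalise each to a proper CDF ending at 1
    return np.max(np.abs(cp - cq))     # supremum of the CDF gap


print(ks_distance([1, 1, 1, 1], [2, 0, 0, 2]))  # 0.25
print(ks_distance([1, 0, 0, 0], [0, 0, 0, 5]))  # 1.0; the scale of q does not matter
```

The max selects a single grid point, so a gradient flows only through that point's CDF gap, which is the non-smoothness the Notes describe.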

brainmass.objectives.wasserstein_1d(p, q, x)

Wasserstein-1 distance between two 1-D densities on a shared grid.

\[W_1(p, q) = \int \left| F_p(x) - F_q(x) \right| \, dx,\]

discretised on the uniform grid x. Both inputs are normalised internally (CDFs run 0..1), so p and q may be densities or unnormalised histograms; the returned value lives in the same units as x. Unlike ks_distance() this is smooth and differentiable in both inputs, making it the preferred distributional loss for gradient-based fitting. Matches scipy.stats.wasserstein_distance() as the grid is refined.

Parameters:
  • p (array) – Densities or histograms on the shared grid x.

  • q (array) – Densities or histograms on the shared grid x.

  • x (array) – Uniform evaluation grid (same length as p / q).

Returns:

Scalar Wasserstein-1 distance.

Return type:

jax.Array

See also

ks_distance

the non-smooth KS counterpart.
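On a uniform grid the integral becomes a sum of CDF gaps times the grid spacing. The sketch assumes the rectangle rule sum(|F_p - F_q|) * dx, one natural discretisation; the brainmass implementation may use a different quadrature.

```python
import numpy as np


def wasserstein_1d(p, q, x):
    cp = np.cumsum(p, dtype=float)
    cq = np.cumsum(q, dtype=float)
    cp, cq = cp / cp[-1], cq / cq[-1]    # CDFs run from 0 to 1
    dx = x[1] - x[0]                     # uniform grid spacing
    return np.sum(np.abs(cp - cq)) * dx  # same units as x


x = np.linspace(0.0, 1.0, 11)
p = np.zeros(11); p[2] = 1.0  # all mass at x = 0.2
q = np.zeros(11); q[5] = 3.0  # all mass at x = 0.5 (unnormalised)
print(wasserstein_1d(p, q, x))  # about 0.3, the distance the mass moves
```

Unlike the max in ks_distance, every grid point contributes to the sum, so gradients reach the whole density rather than a single argmax point.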

brainmass.objectives.fcd_ks(window_size=30, step_size=5, midpoints=None, bw_method=None, n_diag=1)

Build a Kolmogorov-Smirnov FCD-distribution loss.

Like fcd_wasserstein() but using the ks_distance() statistic between the two FCD off-diagonal distributions. The KS statistic is the common literature metric for FCD comparison, but it is non-smooth (a supremum); prefer fcd_wasserstein() when the loss drives a gradient optimiser, and use this for evaluation / reporting.

Parameters:
  • window_size (int, default 30) – Sliding-window length (in samples) for braintools.metric.functional_connectivity_dynamics().

  • step_size (int, default 5) – Stride between consecutive windows (in samples) for braintools.metric.functional_connectivity_dynamics().

  • midpoints (array, optional) – FCD-value evaluation grid (default 100 points on [-0.99, 0.99]).

  • bw_method (optional) – KDE bandwidth forwarded to fcd_distribution().

  • n_diag (int, default 1) – Upper-triangle diagonal offset.

Returns:

loss(prediction, target) -> scalar KS distance between the two FCD distributions (a quantity to minimise; 0 on identity).

Return type:

callable

See also

fcd_wasserstein

the smooth, gradient-friendly counterpart.

brainmass.objectives.fcd_wasserstein(window_size=30, step_size=5, midpoints=None, bw_method=None, n_diag=1)

Build a Wasserstein FCD-distribution loss (smooth, grad-friendly).

Compares the distribution of FCD off-diagonal values of the prediction and target – the standard FCD fitting target – rather than the FCD matrix correlation (fcd()). The wasserstein_1d() distance is smooth, so this is the recommended FCD objective for gradient-based fitting.

Parameters:
  • window_size (int, default 30) – Sliding-window length (in samples) for braintools.metric.functional_connectivity_dynamics().

  • step_size (int, default 5) – Stride between consecutive windows (in samples) for braintools.metric.functional_connectivity_dynamics().

  • midpoints (array, optional) – FCD-value evaluation grid (default 100 points on [-0.99, 0.99]).

  • bw_method (optional) – KDE bandwidth forwarded to fcd_distribution().

  • n_diag (int, default 1) – Upper-triangle diagonal offset.

Returns:

loss(prediction, target) -> scalar Wasserstein-1 distance between the two FCD distributions (a quantity to minimise; 0 on identity).

Return type:

callable

See also

fcd_ks

the non-smooth KS counterpart.

fcd

FCD matrix correlation objective.

Examples

>>> import numpy as np, jax.numpy as jnp
>>> from brainmass import objectives
>>> rng = np.random.default_rng(0)
>>> x = jnp.asarray(rng.standard_normal((200, 6)))
>>> loss = objectives.fcd_wasserstein(window_size=30, step_size=5)
>>> float(loss(x, x))
0.0

See Also