{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "da45de76",
   "metadata": {},
   "source": [
    "# Modeling resting-state MEG data"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "34972e34",
   "metadata": {},
   "source": [
    "In this example, we simulate whole-brain activity with a network of Wilson–Cowan neural masses and compare its resting-state functional connectivity with that of real MEG recordings."
   ]
  },
  {
   "metadata": {
    "ExecuteTime": {
     "start_time": "2026-03-21T13:55:36.384061800Z"
    }
   },
   "cell_type": "code",
   "source": [
    "import brainstate\n",
    "import braintools\n",
    "import brainmass\n",
    "import brainunit as u\n",
    "import numpy as np"
   ],
   "id": "2f22773ecc8e269c",
   "outputs": [],
   "execution_count": null
  },
  {
   "metadata": {},
   "cell_type": "code",
   "source": [
    "# Global integration time step for every brainstate simulation below (0.1 ms, i.e. 10 kHz).\n",
    "brainstate.environ.set(dt=0.1 * u.ms)"
   ],
   "id": "23ce42cabc324b2",
   "outputs": [],
   "execution_count": null
  },
  {
   "cell_type": "markdown",
   "id": "69f63761",
   "metadata": {},
   "source": [
    "## Perform simulations to generate synthetic data"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "e975acfc",
   "metadata": {},
   "source": [
    "<h1 style=\"font-size:1em;\">Load the HCP connectivity data and use the Wilson–Cowan model to build the brain network.</h1>"
   ]
  },
  {
   "metadata": {},
   "cell_type": "code",
   "source": [
    "import os.path\n",
    "import kagglehub\n",
    "\n",
    "# Fetch the HCP sample dataset. It provides the 'Cmat' (connection weights) and\n",
    "# 'Dmat' (turned into conduction delays by the Network below) matrices.\n",
    "path = kagglehub.dataset_download(\"oujago/hcp-gw-data-samples\")\n",
    "hcp_file = os.path.join(path, \"hcp-data-sample.msgpack\")\n",
    "hcp = braintools.file.msgpack_load(hcp_file)"
   ],
   "id": "e7b4189fa6cc7d67",
   "outputs": [],
   "execution_count": null
  },
  {
   "metadata": {},
   "cell_type": "code",
   "source": [
    "class Network(brainstate.nn.Module):\n",
    "    \"\"\"Whole-brain network of delay-coupled Wilson-Cowan nodes built from ``hcp``.\n",
    "\n",
    "    ``hcp['Cmat']`` supplies the coupling weights and ``hcp['Dmat'] / signal_speed``\n",
    "    the coupling delays (in ms); self-connections and self-delays are zeroed.\n",
    "    The number of nodes is taken from the connectivity matrix.\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    signal_speed : float\n",
    "        Divisor turning ``Dmat`` entries into delays in ms (presumably distance\n",
    "        in mm over speed in m/s -- TODO confirm the units of ``Dmat``).\n",
    "    k : float\n",
    "        Global coupling strength passed to ``brainmass.DiffusiveCoupling``.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, signal_speed=2., k=1.):\n",
    "        super().__init__()\n",
    "\n",
    "        conn_weight = hcp['Cmat'].copy()\n",
    "        np.fill_diagonal(conn_weight, 0)\n",
    "        delay_time = hcp['Dmat'].copy() / signal_speed\n",
    "        np.fill_diagonal(delay_time, 0)\n",
    "\n",
    "        # Size the network from the data instead of hard-coding 80 nodes, so the\n",
    "        # model always matches the parcellation stored in the connectome file.\n",
    "        num_nodes = conn_weight.shape[0]\n",
    "        if conn_weight.shape != (num_nodes, num_nodes) or delay_time.shape != conn_weight.shape:\n",
    "            raise ValueError(\n",
    "                f\"Expected square 'Cmat' and 'Dmat' of equal shape, \"\n",
    "                f\"got {conn_weight.shape} and {delay_time.shape}.\"\n",
    "            )\n",
    "        delay_time = delay_time * u.ms\n",
    "\n",
    "        # delay_time.flatten() is row-major, so entry r * num_nodes + c belongs to\n",
    "        # column c; indices_ holds that column index for every flattened delay.\n",
    "        indices_ = np.tile(np.arange(num_nodes), num_nodes)\n",
    "\n",
    "        self.node = brainmass.WilsonCowanStep(\n",
    "            num_nodes,\n",
    "            noise_E=brainmass.OUProcess(num_nodes, sigma=0.01),\n",
    "            noise_I=brainmass.OUProcess(num_nodes, sigma=0.01),\n",
    "        )\n",
    "        self.coupling = brainmass.DiffusiveCoupling(\n",
    "            self.node.prefetch_delay('rE', (delay_time.flatten(), indices_), init=braintools.init.Uniform(0, 0.05)),\n",
    "            self.node.prefetch('rE'),\n",
    "            conn_weight,\n",
    "            k=k\n",
    "        )\n",
    "\n",
    "    def update(self):\n",
    "        \"\"\"Advance one step: compute the coupling input, then step the nodes; returns ``rE``.\"\"\"\n",
    "        current = self.coupling()\n",
    "        rE = self.node(current)\n",
    "        return rE\n",
    "\n",
    "    def step_run(self, i):\n",
    "        \"\"\"Run step ``i`` with ``i`` and ``t = i * dt`` set in the brainstate environment.\"\"\"\n",
    "        with brainstate.environ.context(i=i, t=i * brainstate.environ.get_dt()):\n",
    "            return self.update()"
   ],
   "id": "315bfde036e77d48",
   "outputs": [],
   "execution_count": null
  },
  {
   "metadata": {},
   "cell_type": "code",
   "source": [
    "net = Network()\n",
    "brainstate.nn.init_all_states(net)\n",
    "\n",
    "# Simulate 6 s (6e3 ms). Round rather than floor-divide: in floating point\n",
    "# 6e3 // 0.1 == 59999.0, which would silently drop the final step.\n",
    "sim_duration_ms = 6e3\n",
    "num_steps = int(round(sim_duration_ms / float(brainstate.environ.get_dt().to_decimal(u.ms))))\n",
    "indices = np.arange(num_steps)\n",
    "exes = brainstate.transform.for_loop(net.step_run, indices)"
   ],
   "id": "5ca15357649a7210",
   "outputs": [],
   "execution_count": null
  },
  {
   "metadata": {},
   "cell_type": "code",
   "source": [
    "import matplotlib.pyplot as plt\n",
    "plt.rcParams['image.cmap'] = 'plasma'\n",
    "\n",
    "fig, gs = braintools.visualize.get_figure(1, 2, 4, 6)\n",
    "\n",
    "# Left: functional connectivity of the raw simulated activity.\n",
    "ax1 = fig.add_subplot(gs[0, 0])\n",
    "fc = braintools.metric.functional_connectivity(exes)\n",
    "fc_image = ax1.imshow(fc)\n",
    "plt.colorbar(fc_image, ax=ax1)\n",
    "ax1.set_title('Simulated FC (raw activity)')\n",
    "ax1.set_xlabel('Region')\n",
    "ax1.set_ylabel('Region')\n",
    "\n",
    "# Right: every 5th region's time course against simulation time.\n",
    "ax2 = fig.add_subplot(gs[0, 1])\n",
    "time_ms = indices * float(brainstate.environ.get_dt().to_decimal(u.ms))\n",
    "ax2.plot(time_ms, exes[:, ::5], alpha=0.8)\n",
    "ax2.set_title('Simulated activity (every 5th region)')\n",
    "ax2.set_xlabel('Time (ms)')\n",
    "ax2.set_ylabel('rE')\n",
    "plt.show()"
   ],
   "id": "fabdb043337c9c6a",
   "outputs": [],
   "execution_count": null
  },
  {
   "cell_type": "markdown",
   "id": "40fdec32",
   "metadata": {},
   "source": [
    "Next, we apply to the raw simulated activity the same signal-processing steps used for MEG—resampling, band-pass filtering, amplitude-envelope extraction, and low-pass filtering—so that the synthetic signal can be compared directly with real MEG data.\n"
   ]
  },
  {
   "metadata": {},
   "cell_type": "code",
   "source": [
    "import xarray as xr\n",
    "from scipy.signal import resample, butter, filtfilt, hilbert\n",
    "import seaborn as sns\n",
    "import ipywidgets as widgets"
   ],
   "id": "68c5e3d40bfd7e",
   "outputs": [],
   "execution_count": null
  },
  {
   "metadata": {},
   "cell_type": "code",
   "source": [
    "# Sampling information of the raw simulation (dt in ms -> sampling rate in Hz).\n",
    "dt_ms = brainstate.environ.get_dt() / u.ms\n",
    "original_fs = 1000. / dt_ms\n",
    "time_points_orig = indices * dt_ms / 1000.  # seconds\n",
    "\n",
    "# Wrap the (time, regions) activity in a labelled xarray container.\n",
    "num_regions = exes.shape[1]\n",
    "region_labels = [f'Region_{idx}' for idx in range(num_regions)]\n",
    "sim_signal_raw = xr.DataArray(\n",
    "    exes,\n",
    "    dims=(\"time\", \"regions\"),\n",
    "    coords=dict(time=time_points_orig, regions=region_labels),\n",
    ")"
   ],
   "id": "120e483b43af8c83",
   "outputs": [],
   "execution_count": null
  },
  {
   "cell_type": "markdown",
   "id": "486f31d2",
   "metadata": {},
   "source": [
    "Downsample the 10 kHz simulated signal (dt = 0.1 ms) to 100 Hz. The Nyquist frequency is then 50 Hz, so subsequent analyses cover the 0–50 Hz range, as with real MEG data."
   ]
  },
  {
   "cell_type": "code",
   "id": "f8a6f260",
   "metadata": {},
   "source": [
    "# Downsample from original_fs to target_fs; resample works along axis 0 (time) by default.\n",
    "target_fs = 100.0\n",
    "num_samples_new = int(len(time_points_orig) * target_fs / original_fs)\n",
    "resampled_data, resampled_time = resample(\n",
    "    sim_signal_raw,\n",
    "    num_samples_new,\n",
    "    t=time_points_orig,\n",
    ")\n",
    "\n",
    "sim_signal = xr.DataArray(\n",
    "    resampled_data,\n",
    "    dims=(\"time\", \"regions\"),\n",
    "    coords=dict(time=resampled_time, regions=region_labels),\n",
    ")\n",
    "print(f\"Original sampling rate: {original_fs} Hz. Resampled to: {target_fs} Hz.\")"
   ],
   "outputs": [],
   "execution_count": null
  },
  {
   "cell_type": "markdown",
   "id": "9ea7245c",
   "metadata": {},
   "source": [
    "Define a band-pass and a low-pass Butterworth filter to prepare for subsequent frequency-specific analyses."
   ]
  },
  {
   "cell_type": "code",
   "id": "5b5de399",
   "metadata": {
    "lines_to_next_cell": 1
   },
   "source": [
    "# Zero-phase Butterworth filters: filtfilt runs the filter forward and backward,\n",
    "# so there is no phase shift and the magnitude response is applied twice.\n",
    "\n",
    "def butter_bandpass_filter(data, lowcut, highcut, fs, order=3, axis=-1):\n",
    "    \"\"\"Zero-phase Butterworth band-pass filter.\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    data : array_like\n",
    "        Signal(s) to filter; filtering runs along ``axis``.\n",
    "    lowcut, highcut : float\n",
    "        Pass-band edges in Hz; require ``0 < lowcut < highcut < fs / 2``.\n",
    "    fs : float\n",
    "        Sampling rate in Hz.\n",
    "    order : int, optional\n",
    "        Order of the Butterworth design (default 3).\n",
    "    axis : int, optional\n",
    "        Time axis of ``data``. Defaults to the last axis (the original behaviour);\n",
    "        pass ``axis=0`` for ``(time, regions)`` arrays.\n",
    "\n",
    "    Returns\n",
    "    -------\n",
    "    numpy.ndarray\n",
    "        Filtered data with the same shape as ``data``.\n",
    "\n",
    "    Raises\n",
    "    ------\n",
    "    ValueError\n",
    "        From ``scipy.signal.butter`` when the normalized band edges are outside (0, 1).\n",
    "    \"\"\"\n",
    "    nyq = 0.5 * fs\n",
    "    b, a = butter(order, [lowcut / nyq, highcut / nyq], btype='band')\n",
    "    return filtfilt(b, a, data, axis=axis)\n",
    "\n",
    "\n",
    "def butter_lowpass_filter(data, highcut, fs, order=3, axis=-1):\n",
    "    \"\"\"Zero-phase Butterworth low-pass filter.\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    data : array_like\n",
    "        Signal(s) to filter; filtering runs along ``axis``.\n",
    "    highcut : float\n",
    "        Cut-off frequency in Hz; require ``0 < highcut < fs / 2``.\n",
    "    fs : float\n",
    "        Sampling rate in Hz.\n",
    "    order : int, optional\n",
    "        Order of the Butterworth design (default 3).\n",
    "    axis : int, optional\n",
    "        Time axis of ``data`` (default: last axis, the original behaviour).\n",
    "\n",
    "    Returns\n",
    "    -------\n",
    "    numpy.ndarray\n",
    "        Filtered data with the same shape as ``data``.\n",
    "\n",
    "    Raises\n",
    "    ------\n",
    "    ValueError\n",
    "        From ``scipy.signal.butter`` when ``highcut / (fs / 2)`` is outside (0, 1).\n",
    "    \"\"\"\n",
    "    nyq = 0.5 * fs\n",
    "    # A low-pass design takes a single scalar critical frequency, not a list.\n",
    "    b, a = butter(order, highcut / nyq, btype='low')\n",
    "    return filtfilt(b, a, data, axis=axis)"
   ],
   "outputs": [],
   "execution_count": null
  },
  {
   "cell_type": "markdown",
   "id": "96b22313",
   "metadata": {},
   "source": [
    "Filter the signal to the 8–12 Hz α-band, apply the Hilbert transform to obtain the analytic signal and extract the instantaneous amplitude envelope, and finally low-pass filter it to reveal the slow dynamics."
   ]
  },
  {
   "cell_type": "code",
   "id": "6824820a",
   "metadata": {},
   "source": [
    "# Analyse the first simulated region.\n",
    "target_region = 'Region_0'\n",
    "signal_to_process = sim_signal.sel(regions=target_region).values\n",
    "\n",
    "# 1) Band-pass to the alpha band (8-12 Hz).\n",
    "freq_band = [8.0, 12.0]\n",
    "filtered_signal = butter_bandpass_filter(\n",
    "    signal_to_process, freq_band[0], freq_band[1], fs=target_fs\n",
    ")\n",
    "\n",
    "# 2) Amplitude envelope = magnitude of the analytic (Hilbert) signal.\n",
    "analytic_signal = hilbert(filtered_signal)\n",
    "signal_envelope = np.abs(analytic_signal)\n",
    "\n",
    "# 3) Keep only the slow (< 4 Hz) fluctuations of the envelope.\n",
    "low_pass_cutoff = 4.0\n",
    "smoothed_envelope = butter_lowpass_filter(signal_envelope, low_pass_cutoff, fs=target_fs)"
   ],
   "outputs": [],
   "execution_count": null
  },
  {
   "cell_type": "code",
   "id": "3b754121",
   "metadata": {
    "lines_to_next_cell": 2
   },
   "source": [
    "# Show the first 4 seconds of every processing stage.\n",
    "plot_duration_s = 4\n",
    "plot_timepoints = int(plot_duration_s * target_fs)\n",
    "time_axis = sim_signal.time.values[:plot_timepoints]\n",
    "\n",
    "fig_signal, ax_signal = plt.subplots(figsize=(15, 5))\n",
    "stage_traces = [\n",
    "    (filtered_signal, dict(label=f'Filtered Signal ({freq_band[0]}-{freq_band[1]} Hz)', alpha=0.8)),\n",
    "    (signal_envelope, dict(label='Signal Envelope', linewidth=2, color='red')),\n",
    "    (smoothed_envelope, dict(label=f'Low-Pass Envelope (<{low_pass_cutoff} Hz)', linewidth=2.5, color='black')),\n",
    "]\n",
    "for stage_values, stage_style in stage_traces:\n",
    "    ax_signal.plot(time_axis, stage_values[:plot_timepoints], **stage_style)\n",
    "ax_signal.set(title=f'Signal Processing for {target_region}', xlabel='Time (s)', ylabel='Amplitude')\n",
    "ax_signal.legend(loc='upper right')\n",
    "for spine_side in ('top', 'right'):\n",
    "    ax_signal.spines[spine_side].set_visible(False)\n",
    "plt.tight_layout()\n",
    "plt.show()"
   ],
   "outputs": [],
   "execution_count": null
  },
  {
   "cell_type": "markdown",
   "id": "e722502d",
   "metadata": {},
   "source": [
    "## Process real MEG data"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "15ae6b71",
   "metadata": {},
   "source": [
    "Load the real resting-state MEG data (one time series per atlas region)."
   ]
  },
  {
   "cell_type": "code",
   "id": "500f57b6",
   "metadata": {},
   "source": [
    "# Fetch the resting-state MEG sample (one time series per atlas region).\n",
    "path = kagglehub.dataset_download(\"oujago/meg-data-samples\")\n",
    "meg_data = xr.open_dataset(os.path.join(path, 'rs-meg.nc'))\n",
    "\n",
    "region_labels = meg_data.regions.values\n",
    "# Sampling rate from the spacing of the first two time stamps (assumes uniform sampling in seconds).\n",
    "sample_period = (meg_data.time[1] - meg_data.time[0]).item()\n",
    "sampling_rate = 1 / sample_period\n",
    "print(f\"Array loaded with {len(region_labels)} regions. Sampling rate: {sampling_rate:.2f} Hz\")"
   ],
   "outputs": [],
   "execution_count": null
  },
  {
   "cell_type": "code",
   "id": "bc7b1467",
   "metadata": {},
   "source": [
    "print('Select a region from the AAL2 atlas and a frequency range')\n",
    "# Target region. The copied placeholder `tooltips` were dropped: that option belongs to\n",
    "# ToggleButtons, and Select has no such trait.\n",
    "target = widgets.Select(options=region_labels, value='PreCG.L', description='Regions',\n",
    "                        layout=widgets.Layout(width='50%', height='150px'))\n",
    "display(target)\n",
    "\n",
    "# Frequency band of interest in Hz.\n",
    "freq = widgets.IntRangeSlider(min=1, max=46, description='Frequency (Hz)', value=[8, 12], layout=widgets.Layout(width='80%'),\n",
    "                              style={'description_width': 'initial'})\n",
    "display(freq)\n",
    "\n",
    "# Number of samples shown in the time-series plots below.\n",
    "plot_timepoints = 1000\n",
    "\n",
    "meg_array = meg_data['__xarray_dataarray_variable__']\n",
    "if meg_array.dims[0] != 'time':\n",
    "    meg_array = meg_array.transpose('time', 'regions')\n",
    "fs = sampling_rate\n",
    "# NOTE: `low` and `high` are snapshots of the slider taken when this cell runs. Cells\n",
    "# using them only see a new band after this cell is re-run (which also resets the\n",
    "# widgets), whereas cells reading `freq.value` directly pick up slider changes.\n",
    "low, high = freq.value\n",
    "plot_len = int(plot_timepoints)\n",
    "time_vals = meg_array.time[:plot_len].values"
   ],
   "outputs": [],
   "execution_count": null
  },
  {
   "cell_type": "markdown",
   "id": "4c977ba2",
   "metadata": {},
   "source": [
    "Filter the time series of the selected brain region and plot both the original and filtered signals."
   ]
  },
  {
   "cell_type": "code",
   "id": "22fd0d83",
   "metadata": {},
   "source": [
    "y_raw_target = meg_array.sel(regions=target.value).values\n",
    "\n",
    "# Band-pass the selected region and take its Hilbert amplitude envelope.\n",
    "y_filt_target = butter_bandpass_filter(y_raw_target, low, high, fs)\n",
    "y_env_target = np.abs(hilbert(y_filt_target))\n",
    "\n",
    "fig, ax = plt.subplots(2, 1, figsize=(12, 8), sharex=True)\n",
    "ax_raw, ax_filt = ax\n",
    "\n",
    "sns.lineplot(x=time_vals, y=y_raw_target[:plot_len], ax=ax_raw, color='k', alpha=0.6)\n",
    "ax_raw.set_title(f'Unfiltered Signal ({target.value})')\n",
    "\n",
    "for filt_series, filt_label in ((y_filt_target, 'Bandpass-Filtered Signal'), (y_env_target, 'Signal Envelope')):\n",
    "    sns.lineplot(x=time_vals, y=filt_series[:plot_len], ax=ax_filt, label=filt_label)\n",
    "ax_filt.set_title(f'Filtered Signal ({target.value})')\n",
    "ax_filt.legend(bbox_to_anchor=(1.2, 1), borderaxespad=0)\n",
    "sns.despine(trim=True)\n",
    "plt.show()"
   ],
   "outputs": [],
   "execution_count": null
  },
  {
   "cell_type": "markdown",
   "id": "b8986368",
   "metadata": {},
   "source": [
    "Next, we orthogonalize the signals. MEG signals are mixtures of many neural sources, and field spread (source leakage) creates spurious functional connectivity: two regions can pick up the same underlying source. To counter this, we remove from each target signal the component that is phase-aligned (zero-lag) with a reference region's signal, keeping only the activity that is not instantaneously shared with the reference. This orthogonalization improves the spatial specificity and reliability of MEG connectivity estimates."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "d05ac9c1",
   "metadata": {},
   "source": [
    "$$\n",
    "Y_{\\perp X}(t, f) = \\text{imag}\\left(Y(t, f)\\,\\frac{X(t, f)^*}{\\lvert X(t, f)\\rvert}\\right)\n",
    "$$"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "a6dec9b8",
   "metadata": {},
   "source": [
    "Here $Y$ is the analytic signal of the target region and $X$ that of the reference region; $^*$ denotes complex conjugation. Taking the imaginary part keeps only the component of $Y$ that is not phase-aligned with $X$."
   ]
  },
  {
   "cell_type": "code",
   "id": "b408b020",
   "metadata": {},
   "source": [
    "# Orthogonalized signal envelope\n",
    "print('Select a reference region for the orthogonalization')\n",
    "# Reference region X. The copied placeholder `tooltips` were dropped: that option\n",
    "# belongs to ToggleButtons, and Select has no such trait.\n",
    "referenz = widgets.Select(options=region_labels, value='PreCG.R', description='Regions',\n",
    "                          layout=widgets.Layout(width='50%', height='150px'))\n",
    "display(referenz)\n",
    "\n",
    "y_raw_reference = meg_array.sel(regions=referenz.value).values\n",
    "y_filt_reference = butter_bandpass_filter(y_raw_reference, low, high, fs)\n",
    "\n",
    "complex_target = hilbert(y_filt_target)\n",
    "complex_reference = hilbert(y_filt_reference)\n",
    "env_reference = np.abs(complex_reference)\n",
    "env_target = np.abs(complex_target)\n",
    "\n",
    "# Y_orth = imag(Y * conj(X) / |X|): the part of the target Y that is not\n",
    "# phase-aligned (zero-lag) with the reference X.\n",
    "signal_conj_div_env = complex_reference.conj() / env_reference\n",
    "orth_signal_imag = (complex_target * signal_conj_div_env).imag\n",
    "orth_env = np.abs(orth_signal_imag)"
   ],
   "outputs": [],
   "execution_count": null
  },
  {
   "cell_type": "code",
   "id": "ae5aaff8",
   "metadata": {},
   "source": [
    "fig_final, ax_final = plt.subplots(2, 1, figsize=(12, 8), sharex=True)\n",
    "ax_ref, ax_tgt = ax_final\n",
    "\n",
    "# Top: reference region X, band-passed signal and its envelope.\n",
    "ax_ref.plot(time_vals, y_filt_reference[:plot_len], label=f'Filtered Signal ({referenz.value})')\n",
    "ax_ref.plot(time_vals, env_reference[:plot_len], label=f'Signal Envelope ({referenz.value})', color='red', linewidth=2)\n",
    "ax_ref.set_title(f'Reference Region X ({referenz.value})')\n",
    "ax_ref.legend()\n",
    "\n",
    "# Bottom: target region Y with its plain and orthogonalized envelopes.\n",
    "target_traces = (\n",
    "    (y_filt_target, dict(label='Bandpass-Filtered Signal', alpha=0.8)),\n",
    "    (env_target, dict(label='Signal Envelope', linewidth=2.5)),\n",
    "    (orth_env, dict(label='Orthogonalized Envelope', linewidth=2.5, linestyle='--')),\n",
    ")\n",
    "for trace_values, trace_style in target_traces:\n",
    "    ax_tgt.plot(time_vals, trace_values[:plot_len], **trace_style)\n",
    "ax_tgt.set_title(f'Target Region Y ({target.value})')\n",
    "ax_tgt.legend(bbox_to_anchor=(1.2, 1), borderaxespad=0)\n",
    "\n",
    "sns.despine(trim=True)\n",
    "plt.tight_layout()\n",
    "plt.show()"
   ],
   "outputs": [],
   "execution_count": null
  },
  {
   "cell_type": "markdown",
   "id": "c047b64c",
   "metadata": {},
   "source": [
    "Apply an identical low-pass filter to both the reference envelope and the orthogonalized envelope to isolate their slow-frequency components."
   ]
  },
  {
   "cell_type": "code",
   "id": "a4b3af73",
   "metadata": {},
   "source": [
    "# Low-pass cut-off for the envelopes. The minimum is 0.1 Hz rather than 0: a\n",
    "# Butterworth design needs a strictly positive cut-off, and later cells read\n",
    "# low_pass.value again, so dragging the slider to 0 would make butter() raise.\n",
    "low_pass = widgets.FloatSlider(value=1.0, min=0.1, max=2.0, step=0.1, description='Low-Pass Frequency (Hz)', \n",
    "                               disabled=False, readout=True, readout_format='.1f', layout=widgets.Layout(width='80%'), \n",
    "                               style={'description_width': 'initial'})\n",
    "display(low_pass)\n",
    "\n",
    "low_orth_env = butter_lowpass_filter(orth_env, highcut=low_pass.value, fs=fs)\n",
    "low_signal_env = butter_lowpass_filter(env_reference, highcut=low_pass.value, fs=fs)\n",
    "\n",
    "# Plot the smoothed envelopes for comparison\n",
    "fig_corr, ax_corr = plt.subplots(1, 2, figsize=(15, 4), sharey=True)\n",
    "\n",
    "# Plot for the reference region\n",
    "ax_corr[0].plot(time_vals, env_reference[:plot_len], alpha=0.5, label='Original Envelope')\n",
    "ax_corr[0].plot(time_vals, low_signal_env[:plot_len], color='black', lw=2, label=f'Smoothed Envelope (<{low_pass.value} Hz)')\n",
    "ax_corr[0].set_title(f'Reference Region X ({referenz.value})')\n",
    "ax_corr[0].legend()\n",
    "\n",
    "# Plot for the target region\n",
    "ax_corr[1].plot(time_vals, orth_env[:plot_len], alpha=0.5, label='Orthogonalized Envelope')\n",
    "ax_corr[1].plot(time_vals, low_orth_env[:plot_len], color='black', lw=2, label='Smoothed Orthogonalized Env.')\n",
    "ax_corr[1].set_title(f'Target Region Y ({target.value})')\n",
    "ax_corr[1].legend()\n",
    "\n",
    "sns.despine(trim=True)\n",
    "plt.tight_layout()\n",
    "plt.show()\n",
    "\n",
    "# Correlate the smoothed orthogonalized target envelope with the smoothed reference envelope.\n",
    "correlation = np.corrcoef(low_orth_env, low_signal_env)[0, 1]\n",
    "print(f'\\nOrthogonalized envelope correlation between {referenz.value} and {target.value}: {correlation:.2f}')"
   ],
   "outputs": [],
   "execution_count": null
  },
  {
   "cell_type": "markdown",
   "id": "de94e7dd",
   "metadata": {},
   "source": [
    "## Compute the whole-brain functional-connectivity matrix"
   ]
  },
  {
   "cell_type": "code",
   "id": "dd42e163",
   "metadata": {
    "lines_to_next_cell": 1
   },
   "source": [
    "def _paired_column_corr(a, b):\n",
    "    \"\"\"Pearson correlation between ``a[:, j]`` and ``b[:, j]`` for every column ``j``.\n",
    "\n",
    "    Same values as the paired entries of ``np.corrcoef``, without building the full\n",
    "    correlation matrix. Zero-variance columns give NaN.\n",
    "    \"\"\"\n",
    "    with np.errstate(invalid='ignore', divide='ignore'):\n",
    "        a_z = (a - a.mean(axis=0)) / a.std(axis=0)\n",
    "        b_z = (b - b.mean(axis=0)) / b.std(axis=0)\n",
    "        return (a_z * b_z).mean(axis=0)\n",
    "\n",
    "\n",
    "def calculate_orthogonalized_fc(meg_data, band, fs, low_pass_cutoff):\n",
    "    \"\"\"Orthogonalized amplitude-envelope functional connectivity.\n",
    "\n",
    "    For every target region ``i`` and reference region ``j``, the band-passed\n",
    "    analytic signal of ``i`` is orthogonalized to that of ``j``\n",
    "    (``imag(Y_i * conj(X_j) / |X_j|)``). The low-passed envelope of the result is\n",
    "    then correlated with the low-passed envelope of ``j``. Finally the matrix is\n",
    "    symmetrized and its diagonal is set to 0.\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    meg_data : xarray.DataArray\n",
    "        Signals with a ``'time'`` dimension and one region dimension.\n",
    "    band : sequence of two numbers\n",
    "        Band-pass edges ``(low, high)`` in Hz.\n",
    "    fs : float\n",
    "        Sampling rate in Hz.\n",
    "    low_pass_cutoff : float\n",
    "        Low-pass cut-off (Hz) applied to the envelopes; must be > 0.\n",
    "\n",
    "    Returns\n",
    "    -------\n",
    "    numpy.ndarray\n",
    "        Symmetric ``(regions, regions)`` FC matrix with a zero diagonal.\n",
    "    \"\"\"\n",
    "    # Put time first so that every axis=0 operation below runs over time.\n",
    "    meg_data = meg_data.transpose('time', ...)\n",
    "    num_regions = meg_data.shape[1]\n",
    "\n",
    "    # Z-score every region over time to ensure robustness of the results.\n",
    "    meg_array_mean = meg_data.mean(dim='time')\n",
    "    meg_array_std = meg_data.std(dim='time')\n",
    "    meg_data_normalized = (meg_data - meg_array_mean) / meg_array_std\n",
    "\n",
    "    # The filter helpers filter along the LAST axis, but time is axis 0 here.\n",
    "    # Filter the (regions, time) transpose and transpose back; otherwise the\n",
    "    # filters would run across regions instead of over time.\n",
    "    y_filt_all = butter_bandpass_filter(meg_data_normalized.values.T, band[0], band[1], fs).T\n",
    "    complex_all = hilbert(y_filt_all, axis=0)\n",
    "    env_all = np.abs(complex_all)\n",
    "    conj_div_env_all = complex_all.conj() / env_all\n",
    "    low_pass_env_all = butter_lowpass_filter(env_all.T, highcut=low_pass_cutoff, fs=fs).T\n",
    "\n",
    "    # Iteratively compute orthogonalization and correlation\n",
    "    fc_matrix = np.zeros((num_regions, num_regions))\n",
    "\n",
    "    progress = widgets.IntProgress(min=0, max=num_regions, description='Computing FC:',\n",
    "                                   layout=widgets.Layout(width='80%'), style={'description_width': 'initial'})\n",
    "    display(progress)\n",
    "\n",
    "    for i in range(num_regions):\n",
    "        # Column j: target i orthogonalized with respect to reference j.\n",
    "        orth_env = np.abs((complex_all[:, i, np.newaxis] * conj_div_env_all).imag)\n",
    "        low_pass_orth_env = butter_lowpass_filter(orth_env.T, highcut=low_pass_cutoff, fs=fs).T\n",
    "        # corr(orthogonalized envelope of i w.r.t. j, envelope of j) for every j.\n",
    "        fc_matrix[i, :] = _paired_column_corr(low_pass_orth_env, low_pass_env_all)\n",
    "        progress.value += 1\n",
    "\n",
    "    fc_matrix = (fc_matrix + fc_matrix.T) / 2\n",
    "    # Orthogonalizing a region against itself is degenerate, so zero the diagonal.\n",
    "    np.fill_diagonal(fc_matrix, 0)\n",
    "\n",
    "    print(\"FC Matrix calculation complete.\")\n",
    "    return fc_matrix"
   ],
   "outputs": [],
   "execution_count": null
  },
  {
   "cell_type": "markdown",
   "id": "279a3542",
   "metadata": {},
   "source": [
    "Helper for plotting FC matrices as heatmaps"
   ]
  },
  {
   "cell_type": "code",
   "id": "cdf9511a",
   "metadata": {},
   "source": [
    "def plot_fc_heatmap(fc_matrix, labels, title, ax):\n",
    "    \"\"\"Draw ``fc_matrix`` on ``ax`` as a heatmap with a colour scale symmetric around 0.\n",
    "\n",
    "    Roughly 20 evenly spaced region labels are shown per axis, with the ``.L`` / ``.R``\n",
    "    substrings removed from each label.\n",
    "    \"\"\"\n",
    "    # Symmetric range around 0. Ignore NaNs, and avoid a degenerate\n",
    "    # vmin == vmax == 0 range when the matrix is all zeros (or all NaN).\n",
    "    vmax = np.nanmax(np.abs(fc_matrix))\n",
    "    if not np.isfinite(vmax) or vmax == 0:\n",
    "        vmax = 1.0\n",
    "\n",
    "    sns.heatmap(\n",
    "        fc_matrix,\n",
    "        square=True,\n",
    "        ax=ax,\n",
    "        cmap='RdBu_r',\n",
    "        vmin=-vmax,\n",
    "        vmax=vmax,\n",
    "        cbar_kws={\"shrink\": .8}\n",
    "    )\n",
    "\n",
    "    num_regions = len(labels)\n",
    "    tick_spacing = max(1, num_regions // 20)\n",
    "\n",
    "    cleaned_labels = [str(label).replace('.L', '').replace('.R', '') for label in labels]\n",
    "\n",
    "    # Heatmap cell k spans [k, k + 1], so ticks belong at the cell centres k + 0.5.\n",
    "    tick_positions = np.arange(0, num_regions, tick_spacing) + 0.5\n",
    "    ax.set_xticks(tick_positions)\n",
    "    ax.set_yticks(tick_positions)\n",
    "    ax.set_xticklabels(cleaned_labels[::tick_spacing], rotation=90, fontsize=8)\n",
    "    ax.set_yticklabels(cleaned_labels[::tick_spacing], rotation=0, fontsize=8)\n",
    "    ax.set_title(title)"
   ],
   "outputs": [],
   "execution_count": null
  },
  {
   "cell_type": "markdown",
   "id": "f17cb85d",
   "metadata": {},
   "source": [
    "Compute the functional-connectivity matrix from the real MEG data"
   ]
  },
  {
   "cell_type": "code",
   "id": "c4ae5929",
   "metadata": {
    "lines_to_next_cell": 2
   },
   "source": [
    "# Orthogonalized FC of the real MEG data, using the band and low-pass cut-off\n",
    "# currently selected in the widgets above.\n",
    "fc_matrix = calculate_orthogonalized_fc(\n",
    "    meg_data=meg_array,\n",
    "    band=freq.value,\n",
    "    fs=fs,\n",
    "    low_pass_cutoff=low_pass.value,\n",
    ")\n",
    "\n",
    "print(\"\\nPlotting FC Matrix for real MEG data (Default Order)...\")\n",
    "\n",
    "fig_real, ax_real = plt.subplots(figsize=(10, 8))\n",
    "band_lo, band_hi = freq.value\n",
    "plot_fc_heatmap(\n",
    "    fc_matrix=fc_matrix,\n",
    "    labels=region_labels,\n",
    "    title=f'Real Array Orthogonalized FC ({band_lo}-{band_hi} Hz)',\n",
    "    ax=ax_real,\n",
    ")\n",
    "plt.tight_layout()\n",
    "plt.show()"
   ],
   "outputs": [],
   "execution_count": null
  },
  {
   "cell_type": "markdown",
   "id": "d98e5069",
   "metadata": {},
   "source": [
    "Compute the functional-connectivity matrix from the synthetic data"
   ]
  },
  {
   "cell_type": "code",
   "id": "423a4146",
   "metadata": {},
   "source": [
    "print(\"\\nPlease select parameters for the simulation FC calculation:\")\n",
    "# NOTE: both values are read right after the widgets are created, and re-running this\n",
    "# cell recreates them with their defaults, so these sliders document the settings used.\n",
    "# Max 45 Hz stays below the 50 Hz Nyquist limit of the 100 Hz resampled signal.\n",
    "sim_freq_band = widgets.IntRangeSlider(\n",
    "    value=[8, 12], min=1, max=45, description='Frequency (Hz)',\n",
    "    layout=widgets.Layout(width='80%'), style={'description_width': 'initial'}\n",
    ")\n",
    "# Minimum 0.1 Hz (not 0): the Butterworth low-pass needs a strictly positive cut-off.\n",
    "sim_low_pass = widgets.FloatSlider(\n",
    "    value=2.0, min=0.1, max=4.0, step=0.1, description='Low-Pass (Hz)',\n",
    "    readout=True, readout_format='.1f',\n",
    "    layout=widgets.Layout(width='80%'), style={'description_width': 'initial'}\n",
    ")\n",
    "display(sim_freq_band)\n",
    "display(sim_low_pass)\n",
    "\n",
    "fc_matrix_sim = calculate_orthogonalized_fc(\n",
    "    meg_data=sim_signal,\n",
    "    band=sim_freq_band.value,\n",
    "    fs=target_fs,\n",
    "    low_pass_cutoff=sim_low_pass.value\n",
    ")\n",
    "\n",
    "sim_labels = sim_signal.regions.values\n",
    "fig_sim, ax_sim = plt.subplots(figsize=(10, 8))\n",
    "plot_fc_heatmap(\n",
    "    fc_matrix=fc_matrix_sim,\n",
    "    labels=sim_labels, \n",
    "    title=f'Simulated Orthogonalized FC ({sim_freq_band.value[0]}-{sim_freq_band.value[1]} Hz)',\n",
    "    ax=ax_sim\n",
    ")\n",
    "plt.tight_layout()\n",
    "plt.show()"
   ],
   "outputs": [],
   "execution_count": null
  },
  {
   "cell_type": "markdown",
   "id": "d48caa48",
   "metadata": {},
   "source": [
    "Flatten the upper triangles of the simulated and real FC matrices into vectors, compute their Pearson correlation, and show the model fit as a bar plot to assess how well the simulated network reproduces the empirical MEG functional connectivity."
   ]
  },
  {
   "metadata": {},
   "cell_type": "code",
   "source": [
    "num_sim_regions = fc_matrix_sim.shape[0]\n",
    "if fc_matrix.shape[0] < num_sim_regions:\n",
    "    raise ValueError(\n",
    "        f\"Real FC has {fc_matrix.shape[0]} regions but the simulation has {num_sim_regions}; \"\n",
    "        \"cannot match them by position.\"\n",
    "    )\n",
    "# NOTE(review): regions are matched purely by position (the first N real regions). This\n",
    "# assumes the HCP connectome and the MEG atlas list regions in the same order -- confirm.\n",
    "fc_matrix_real_matched = fc_matrix[:num_sim_regions, :num_sim_regions]\n",
    "print(f\"\\nComparing the {num_sim_regions}x{num_sim_regions} simulated FC with the corresponding subset of the real MEG FC.\")\n",
    "\n",
    "# Compare only the upper triangle: both matrices are symmetric with a zero diagonal.\n",
    "indices_triu = np.triu_indices(num_sim_regions, k=1)\n",
    "real_fc_flat = fc_matrix_real_matched[indices_triu]\n",
    "sim_fc_flat = fc_matrix_sim[indices_triu]\n",
    "\n",
    "# Calculate the Pearson correlation between the two flattened FC vectors\n",
    "sim_vs_real_corr = np.corrcoef(real_fc_flat, sim_fc_flat)[0, 1]\n",
    "fig_comp, ax_comp = plt.subplots(figsize=(6, 6))\n",
    "\n",
    "splot = sns.barplot(\n",
    "    x=['Simulated vs. Real Array'],\n",
    "    y=[sim_vs_real_corr],\n",
    "    ax=ax_comp\n",
    ")\n",
    "\n",
    "ax_comp.set_ylabel('Pearson Correlation')\n",
    "ax_comp.set_title('Model Fit: Correlation between Simulated and Real FC', pad=15)\n",
    "# Extend the axis below zero when the fit is negative so the bar stays visible.\n",
    "ax_comp.set_ylim(min(0, sim_vs_real_corr * 1.2), max(0.5, sim_vs_real_corr * 1.2))\n",
    "\n",
    "# Add the correlation value as text just inside the end of the bar\n",
    "for p in splot.patches:\n",
    "    bar_height = p.get_height()\n",
    "    splot.annotate(\n",
    "        format(bar_height, '.3f'),\n",
    "        (p.get_x() + p.get_width() / 2., bar_height),\n",
    "        ha='center', va='center',\n",
    "        size=18, color='white',\n",
    "        # Move the label towards zero, whichever direction the bar points.\n",
    "        xytext=(0, -15 if bar_height >= 0 else 15),\n",
    "        textcoords='offset points'\n",
    "    )\n",
    "\n",
    "sns.despine()\n",
    "plt.tight_layout()\n",
    "plt.show()\n",
    "print(f\"The correlation between the simulated FC and the real MEG data FC is: {sim_vs_real_corr:.3f}\")"
   ],
   "id": "aab13165b4b14bfe",
   "outputs": [],
   "execution_count": null
  }
 ],
 "metadata": {
  "jupytext": {
   "cell_metadata_filter": "-all",
   "main_language": "python",
   "notebook_metadata_filter": "-all"
  },
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.12.2"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
