Skip to content

Colab   Kaggle   Binder

Reactivation

Here, we will show how to use the AssemblyReact class to identify assemblies and assess reactivation during post-task sleep


Setup

%reload_ext autoreload
%autoreload 2

# from neuro_py
# plotting
import matplotlib.pyplot as plt

# core tools
import nelpy as nel
import numpy as np
import seaborn as sns

from neuro_py.ensemble.assembly_reactivation import AssemblyReact
from neuro_py.io import loading
from neuro_py.process.peri_event import event_triggered_average_fast
import neuro_py as npy

Section 1: Pick basepath and initialize AssemblyReact class

Here we will use CA1 pyramidal cells.

# Session folder containing the recording to analyze
basepath = r"S:\data\HMC\HMC1\day8"

# Restrict detection to putative CA1 pyramidal cells.
# z_mat_dt: bin size for the z-scored activity matrix (presumably seconds — confirm)
assembly_react = AssemblyReact(
    basepath=basepath,
    brainRegion="CA1",
    putativeCellType="Pyr",
    z_mat_dt=0.01,
)

Also, load brain states for later use.

# load detected sleep/brain-state intervals for this session
state_dict = loading.load_SleepState_states(basepath)
# theta-state intervals: used below to restrict assembly detection
theta_epochs = nel.EpochArray(
    state_dict["THETA"],
)
# NREM-sleep intervals: used below to restrict ripples for pre/post comparison
nrem_epochs = nel.EpochArray(
    state_dict["NREMstate"],
)
# display both EpochArrays (notebook cell output)
theta_epochs, nrem_epochs
(<EpochArray at 0x2206cd20ad0: 125 epochs> of length 35:04 minutes,
 <EpochArray at 0x2206caea810: 88 epochs> of length 2:16:25 hours)

Section 2: Load spike data, session epochs, and ripple events

You can see that there are nice printouts that display important information about the class

# load needed data (spikes, ripples, epochs) into the AssemblyReact object
assembly_react.load_data()
assembly_react  # summary printout: unit count and total duration
<AssemblyReact: 75 units> of length 6:36:57:689 hours

Locate the session from which you want to detect assemblies.

Here we can see a novel linear track is the second epoch.

assembly_react.epoch_df  # session epoch table: pre-sleep, linear track, post-sleep
name startTime stopTime environment behavioralParadigm notes manipulation stimuli basepath
0 preSleep_210411_064951 0.0 9544.56315 sleep NaN NaN NaN NaN S:\data\HMC\HMC1\day8
1 maze_210411_095201 9544.5632 11752.80635 linear 1 novel NaN NaN S:\data\HMC\HMC1\day8
2 postSleep_210411_103522 11752.8064 23817.68955 sleep NaN NaN NaN NaN S:\data\HMC\HMC1\day8

Section 3: Detect assemblies in linear track during theta

You can see we have detected 15 assemblies

# fit assembly weights on the linear-track epoch (index 1) restricted to theta
assembly_react.get_weights(epoch=assembly_react.epochs[1] & theta_epochs)
assembly_react  # summary now also reports the number of detected assemblies
<AssemblyReact: 75 units, 15 assemblies> of length 6:36:57:689 hours

Section 4: Analyze the obtained assemblies

Section 4.1: Visualize assembly weights

Each column is an assembly and each row is a cell

The color indicates if the cell was a significant contributor (members) to that assembly * you can find these members with assembly_members = assembly_react.find_members()

# weight matrix plot: one column per assembly, one row per cell
assembly_react.plot()
plt.show()

png

Section 4.2: Compute time-resolved activations for each assembly

Will take around a minute to run.

# time-resolved activation strength for every assembly over the whole session
assembly_act = assembly_react.get_assembly_act()
assembly_act  # AnalogSignalArray with one signal per assembly
<AnalogSignalArray at 0x22080b60d10: 15 signals> for a total of 6:36:57:680 hours

Section 4.3: Get assembly strengths around ripples in pre-sleep, the task, and in post-sleep epochs

# restrict ripples to NREM sleep for the pre/post sleep comparisons
nrem_ripples = assembly_react.ripples & nrem_epochs


def _swr_psth(event_starts):
    """Peri-event average of assembly activity around ripple starts.

    Parameters
    ----------
    event_starts : array-like
        Ripple start times (s) to trigger on.

    Returns
    -------
    pandas.DataFrame
        Average assembly activity, time (rows, -0.5 to 0.5 s) by assembly (columns).
    """
    return event_triggered_average_fast(
        assembly_act.data,
        event_starts,
        sampling_rate=assembly_act.fs,
        window=[-0.5, 0.5],
        return_average=True,
        return_pandas=True,
    )


# pre- and post-sleep trigger on NREM-restricted ripples;
# note: the task epoch triggers on all ripples (not NREM-restricted)
psth_swr_pre = _swr_psth(nrem_ripples[assembly_react.epochs[0]].starts)
psth_swr_task = _swr_psth(assembly_react.ripples[assembly_react.epochs[1]].starts)
psth_swr_post = _swr_psth(nrem_ripples[assembly_react.epochs[2]].starts)

# round time index to 3 decimals so the three frames align for plotting/subtraction
psth_swr_pre.index = np.round(psth_swr_pre.index, 3)
psth_swr_task.index = np.round(psth_swr_task.index, 3)
psth_swr_post.index = np.round(psth_swr_post.index, 3)

Section 4.4: Visualize reactivation dynamics during post-task ripples

Here, we have plotted Pre, Post, and Post subtracted by Pre to estimate the difference.

You can see that many of the assemblies have a higher reactivation during the post-task ripples compared to the pre-task ripples.

# 2x3 figure: top row = line plots (Pre, Post, Post-Pre), bottom row = heatmaps
fig, ax = plt.subplots(2, 3, figsize=(15, 8), sharey=False, sharex=False)
ax = ax.flatten()

# share y axis of first row
# (re-create panels 1 and 3 so all three top-row panels share one y-scale)
ax[0] = plt.subplot(231, sharey=ax[1])
ax[2] = plt.subplot(233, sharey=ax[0])

# plot assembly ripple psth (one colored trace per assembly)
psth_swr_pre.plot(ax=ax[0], legend=False)
psth_swr_post.plot(ax=ax[1], legend=False)
(psth_swr_post - psth_swr_pre).plot(ax=ax[2])

# plot mean assembly ripple psth (black trace = mean over assemblies)
psth_swr_pre.mean(axis=1).plot(ax=ax[0], color="k", legend=False)
psth_swr_post.mean(axis=1).plot(ax=ax[1], color="k", legend=False)
(psth_swr_post - psth_swr_pre).mean(axis=1).plot(ax=ax[2], color="k")

# plot assembly ripple psth heatmap (assemblies x time; shared color scale)
sns.heatmap(psth_swr_pre.T, ax=ax[3], cbar=False, vmin=0, vmax=5)
sns.heatmap(psth_swr_post.T, ax=ax[4], cbar=False, vmin=0, vmax=5)
sns.heatmap(
    (psth_swr_post - psth_swr_pre).T,
    ax=ax[5],
    cbar=False,
    vmin=-5,
    vmax=5,
    cmap="coolwarm",
)

for ax_ in ax[:3]:
    # dashed line at zero (ripple start)
    ax_.axvline(0, linestyle="--", color="k", linewidth=1)
    # set x axis limits
    ax_.set_xlim(-0.5, 0.5)
    # add grid lines
    ax_.grid()

ax[0].set_title("Pre")
ax[1].set_title("Post")
ax[2].set_title("Post - Pre")

# move legend outside the third panel
ax[2].legend(
    bbox_to_anchor=(1.05, 1), loc=2, borderaxespad=0.0, frameon=False, title="assembly"
)

# add labels
ax[0].set_ylabel("assembly activity")
ax[3].set_ylabel("assembly")
ax[3].set_xlabel("time from SWR start (s)")

# clean axis using seaborn
sns.despine()

plt.show()

png


Section 5: Cross-Structural Assembly Detection

The AssemblyReact class now supports cross-structural assembly detection, which allows you to identify assemblies that span across different brain regions, cell types, or any other categorical grouping. This is particularly useful for studying cross-regional coordination.

Section 5.1: Load data from multiple brain regions

For this demonstration, we'll detect assemblies that span across CA1 and PFC regions.

# session with simultaneous CA1 and PFC recordings
basepath = r"U:\data\hpc_ctx_project\HP18\hp18_day40_20250514"
# load theta and NREM epochs (same as Section 1)
state_dict = loading.load_SleepState_states(basepath)
theta_epochs = nel.EpochArray(
    state_dict["THETA"],
)
nrem_epochs = nel.EpochArray(
    state_dict["NREMstate"],
)
ripples = npy.io.load_ripples_events(basepath, return_epoch_array=True)

# behavioral epoch table -> EpochArray; find (pre, task, post) epoch indices
epoch_df = loading.load_epoch(basepath)
epoch_df = npy.session.compress_repeated_epochs(epoch_df)
beh_epochs = nel.EpochArray(np.array([epoch_df.startTime, epoch_df.stopTime]).T)
pre_task_post = npy.session.find_multitask_pre_post(
    epoch_df.environment, post_sleep_flank=True, pre_sleep_common=True
)

# Load spike data from both CA1 and PFC regions
st, cell_metrics = loading.load_spikes(
    basepath, brainRegion="CA1|PFC", putativeCellType="Pyr"
)
# per-neuron group labels; cells matching neither pattern stay "unknown"
brain_regions = np.array(["unknown"] * st.n_active)
brain_regions[cell_metrics.brainRegion.str.contains("CA1")] = "CA1"
brain_regions[cell_metrics.brainRegion.str.contains("PFC")] = "PFC"

# sort by brain region for easier visualization
idx = np.argsort(brain_regions)

# NOTE(review): writes a private attribute; relies on SpikeTrain internals — confirm
st._data = st.data[idx]
brain_regions = brain_regions[idx]
cell_metrics = cell_metrics.iloc[idx].reset_index(drop=True)

Section 5.2: Standard vs Cross-Region Detection (cross_ica and cross_svd)

Let's compare three approaches:

  • Standard (method="ica" or "pca"): detects within-region and cross-region assemblies.
  • Cross-ICA (method="ica" + cross_structural): keeps only cross-region covariance structure and filters to multi-group assemblies.
  • Cross-SVD (method="cross_svd" + cross_structural): directly decomposes cross-area covariance with SVD, producing strictly bipartite CA1↔PFC assembly pairs.

For cross-structural analyses, significance is shuffle-based (bin) rather than Marčenko–Pastur.

Use these as practical starting points for cross-region analyses:

  • Fast exploration (quick sanity check)

  • method='ica' with cross_structural=...

  • nullhyp='bin', nshu=100, percentile=95

  • weight_dt=0.05, z_mat_dt=0.005

  • More robust inference (preferred for results)

  • method='ica' or method='cross_svd' with cross_structural=...

  • nullhyp='bin', nshu=500-1000, percentile=99-99.5

  • Keep weight_dt matched to your assembly timescale and z_mat_dt for activity resolution

  • Method-specific notes (both cross-regional methods)

  • Cross-ICA (method='ica' + cross_structural)

    • Uses cross-group covariance structure, then filters to keep assemblies active in at least two groups.

    • Filtering is controlled by cross_group_threshold_mode (absolute, relative, percentile) and its threshold parameters.

  • Cross-SVD (method='cross_svd')

    • Returns strictly bipartite two-group components.

    • Significance thresholding is controlled by cross_svd_threshold_mode:

    • per_rank (default): compares each singular value to its rank-matched null threshold.

    • max_stat: uses one global null threshold from shuffled max singular values (more conservative).

  • General guidance

  • Ensure both groups have enough active neurons.

  • If no assemblies are detected, first increase data duration, then adjust weight_dt, then increase nshu.

  • For conservative cross-SVD counting, prefer cross_svd_threshold_mode='max_stat'.

  • Speed note

  • Shuffle controls support parallel workers via n_jobs (for example n_jobs=-1 to use all cores).

  • Use FAST_DEMO=True for quick iteration, then rerun with larger nshu/higher percentile for final analyses.

Section 5.2.2: Mathematical Details (Cross-ICA and Cross-SVD)

Let $X \in \mathbb{R}^{N \times T}$ be the neuron-by-time matrix after binning and per-neuron z-scoring.

Let groups be indexed by $g \in {1,\dots,G}$ with group sizes $n_g$.

Cross-ICA path (method='ica' + cross_structural)

  1. Group-size normalization

$$

\tilde X_i(t) = \frac{X_i(t)}{\sqrt{n_{g(i)}}}

$$

where $g(i)$ is the group of neuron $i$. Note that since $X$ is already per-neuron z-scored, this normalization does not re-standardise individual neurons — it scales each group's collective contribution to the cross-group covariance by $1/n_g$.

  2. Block cross-group covariance matrix

$$

C_{ij}=
\begin{cases}
\frac{1}{T}\sum_{t=1}^T \tilde X_i(t)\tilde X_j(t), & g(i)\neq g(j)\\
0, & g(i)=g(j)
\end{cases}

$$

So within-group blocks are zero and only cross-group blocks are retained.

  3. Subspace + ICA

Eigendecompose $C = Q\Lambda Q^\top$. Let $Q_k$ and $\Lambda_k$ denote the eigenvectors and eigenvalues corresponding to the $k$ significant dimensions (those exceeding the shuffle-based threshold). Project the group-normalized data into the significant subspace with eigenvalue weighting:

$$ P = \left(Q_k \odot \sqrt{\Lambda_k}\right)^\top \tilde{X} $$

where $\odot$ denotes column-wise scaling of $Q_k$ by $\sqrt{\Lambda_k}$. This whitens the projected dimensions by their coupling strength, so that ICA operates on variance-equalized components rather than raw projections whose scale reflects eigenvalue magnitude. ICA is then run on $P^\top$, and the resulting components are mapped back to the full neuron space and normalized to unit norm.

  4. Cross-group membership filter

    • A pattern is retained only if it is active in at least two groups.

    • Activity is determined by cross_group_threshold_mode and its threshold parameters.

Cross-SVD path (method='cross_svd')

For exactly two groups, split data into $X_1 \in \mathbb{R}^{N_1\times T}$ and $X_2 \in \mathbb{R}^{N_2\times T}$.

  1. Cross-area covariance

$$

M = \frac{1}{T}X_1X_2^\top

$$

  2. SVD decomposition

$$

M = U\Sigma V^\top

$$

where columns of $U$ and $V$ are group-specific spatial weights and diagonal entries of $\Sigma$ are coupling strengths.

  3. Shuffle-based significance

    • Build a null distribution by independently permuting time indices in each group.

    • Compare observed singular values to null using cross_svd_threshold_mode:

    • per_rank: $\sigma_k > q_k$ (rank-wise null threshold).

    • max_stat: $\sigma_k > q_{\max}$ (single global threshold; stricter).

  4. Time-resolved cross-area activity

For component $k$, compute projections of each group onto their respective spatial weights, z-score each projection across time, and take the element-wise product:

$$

a_k(t)=z\!\left(\mathbf{u}_k^\top X_1\right)(t)\cdot z\!\left(\mathbf{v}_k^\top X_2\right)(t)

$$

Z-scoring centres each projection (removing the positive bias that a raw dot product accumulates even under no assembly activity) and normalises scale across components. Positive values indicate synchronous co-expression; negative values indicate anti-coactivation. Note that if a projection is constant across time (e.g. due to a silent group), its z-score is undefined and is set to zero, silencing that component's coactivation score.

A short runnable comparison of per_rank vs max_stat is shown in the next cell.

# Toggle for speed vs robustness (see parameter guidance above)
FAST_DEMO = True
if FAST_DEMO:
    nshu_demo = 100
    percentile_demo = 95
else:
    nshu_demo = 1000
    percentile_demo = 99

# Standard assembly detection (finds both within-region and cross-region assemblies)
assembly_react_standard = AssemblyReact(weight_dt=0.05, z_mat_dt=0.005)
assembly_react_standard.add_st(st)
assembly_react_standard.epochs = beh_epochs
# fit weights on the task epoch restricted to theta periods
assembly_react_standard.get_weights(
    epoch=assembly_react_standard.epochs[pre_task_post[0][1].item()] & theta_epochs
)

print(f"Standard detection found {assembly_react_standard.n_assemblies()} assemblies")

# Cross-ICA assembly detection (cross-structural covariance + multi-group filtering)
assembly_react_cross = AssemblyReact(
    method="ica",
    weight_dt=0.05,
    z_mat_dt=0.005,
    cross_structural=brain_regions,  # per-neuron group labels (CA1/PFC)
    nullhyp="bin",  # shuffle-based significance (used for cross-structural)
    nshu=nshu_demo,
    percentile=percentile_demo,
    n_jobs=-1,  # parallelize shuffles across all cores
)
assembly_react_cross.add_st(st)
assembly_react_cross.epochs = beh_epochs
assembly_react_cross.get_weights(
    epoch=assembly_react_cross.epochs[pre_task_post[0][1].item()] & theta_epochs
)

print(f"Cross-ICA detection found {assembly_react_cross.n_assemblies()} assemblies")

# Cross-SVD assembly detection (strictly bipartite cross-area assemblies)
assembly_react_cross_svd = AssemblyReact(
    method="cross_svd",
    weight_dt=0.05,
    z_mat_dt=0.005,
    cross_structural=brain_regions,
    nullhyp="bin",
    nshu=nshu_demo,
    percentile=percentile_demo,
    n_jobs=-1,
)
assembly_react_cross_svd.add_st(st)
assembly_react_cross_svd.epochs = beh_epochs
assembly_react_cross_svd.get_weights(
    epoch=assembly_react_cross_svd.epochs[pre_task_post[0][1].item()] & theta_epochs
)

print(f"Cross-SVD detection found {assembly_react_cross_svd.n_assemblies()} assemblies")

print(f"(Used nshu={nshu_demo}, percentile={percentile_demo})")
Standard detection found 37 assemblies
Cross-ICA detection found 9 assemblies
Cross-SVD detection found 62 assemblies
(Used nshu=100, percentile=95)
# Cross-SVD significance mode comparison: per_rank vs max_stat
# same task-theta epoch used for detection above
epoch_demo = beh_epochs[pre_task_post[0][1].item()] & theta_epochs

assembly_react_cross_svd_per_rank = AssemblyReact(
    method="cross_svd",
    weight_dt=0.05,
    z_mat_dt=0.005,
    cross_structural=brain_regions,
    nullhyp="bin",
    nshu=nshu_demo,
    percentile=percentile_demo,
    cross_svd_threshold_mode="per_rank",  # rank-matched null threshold per component
    random_state=42,  # fixed seed so both runs use identical shuffles
    n_jobs=-1,
)
assembly_react_cross_svd_per_rank.add_st(st)
assembly_react_cross_svd_per_rank.epochs = beh_epochs
assembly_react_cross_svd_per_rank.get_weights(epoch=epoch_demo)

assembly_react_cross_svd_max_stat = AssemblyReact(
    method="cross_svd",
    weight_dt=0.05,
    z_mat_dt=0.005,
    cross_structural=brain_regions,
    nullhyp="bin",
    nshu=nshu_demo,
    percentile=percentile_demo,
    cross_svd_threshold_mode="max_stat",  # single global null threshold (stricter)
    random_state=42,
    n_jobs=-1,
)
assembly_react_cross_svd_max_stat.add_st(st)
assembly_react_cross_svd_max_stat.epochs = beh_epochs
assembly_react_cross_svd_max_stat.get_weights(epoch=epoch_demo)

n_per_rank = assembly_react_cross_svd_per_rank.n_assemblies()
n_max_stat = assembly_react_cross_svd_max_stat.n_assemblies()

print("Cross-SVD threshold mode comparison")
print(f"  per_rank: {n_per_rank} assemblies")
print(f"  max_stat: {n_max_stat} assemblies")
print(f"  reduction: {n_per_rank - n_max_stat}")
Cross-SVD threshold mode comparison
  per_rank: 62 assemblies
  max_stat: 9 assemblies
  reduction: 53

Section 5.3: Visualize Cross-Structural Assembly Weights

The cross-structural assemblies should show weights in both CA1 and PFC regions. The vertical red line separates the two brain regions.

# Plot cross-structural assembly weights and count members per region.
# Cross-structural detection can legitimately return zero assemblies, so guard.
if assembly_react_cross.n_assemblies() > 0:
    # Plot cross-structural assembly weights
    fig, axes = assembly_react_cross.plot(figsize=(12, 6))

    # Add vertical line to separate brain regions (cells are sorted CA1 then PFC)
    ca1_neurons = np.sum(brain_regions == "CA1")
    for ax in axes.flat:
        ax.axhline(
            ca1_neurons - 0.5, color="red", linestyle="--", alpha=0.7, linewidth=2
        )
        # plain string (no placeholders) — the f-prefix was unnecessary
        ax.text(
            0.02,
            1.05,
            "PFC (top)\nCA1 (bottom)",
            transform=ax.transAxes,
            va="top",
            ha="left",
            bbox=dict(boxstyle="round", facecolor="white", alpha=0.8),
        )

    plt.suptitle("Cross-Structural Assemblies (CA1-PFC)", fontsize=14, y=1.02)
    plt.tight_layout()
    plt.show()

    # Analyze which regions participate in each assembly
    # (presumably populates assembly_react_cross.assembly_members used below — confirm)
    assembly_react_cross.find_members()

    print("\n Cross-structural assembly analysis:")
    for i in range(assembly_react_cross.n_assemblies()):
        # count significant members on each side of the CA1/PFC boundary
        ca1_active = assembly_react_cross.assembly_members[i, :ca1_neurons].sum()
        pfc_active = assembly_react_cross.assembly_members[i, ca1_neurons:].sum()
        print(
            f"  Assembly {i + 1}: CA1 neurons: {ca1_active}, PFC neurons: {pfc_active}"
        )

else:
    print("No cross-structural assemblies detected in this dataset.")
    print("This could mean:")
    print("1. There are no assemblies spanning both regions")
    print("2. The assemblies are primarily within-region")
    print("3. More data or different parameters may be needed")

png

 Cross-structural assembly analysis:
  Assembly 1: CA1 neurons: 15, PFC neurons: 17
  Assembly 2: CA1 neurons: 25, PFC neurons: 19
  Assembly 3: CA1 neurons: 20, PFC neurons: 21
  Assembly 4: CA1 neurons: 8, PFC neurons: 9
  Assembly 5: CA1 neurons: 13, PFC neurons: 14
  Assembly 6: CA1 neurons: 16, PFC neurons: 18
  Assembly 7: CA1 neurons: 3, PFC neurons: 1
  Assembly 8: CA1 neurons: 26, PFC neurons: 18
  Assembly 9: CA1 neurons: 25, PFC neurons: 19

Section 5.4: Compare Assembly Activity Across Methods

Now compare post-task ripple activity for three methods:

  1. Standard assemblies

  2. Cross-ICA assemblies

  3. Cross-SVD assemblies

For cross_svd, activity reflects time-resolved cross-area coactivation computed as

z-scored CA1 projection × z-scored PFC projection for each component, which directly targets inter-regional coupling while keeping the score centered.

# Assembly activity restricted to the post-task sleep epoch, for each method
assembly_act_standard = assembly_react_standard.get_assembly_act(
    epoch=assembly_react_standard.epochs[pre_task_post[0][2].item()]
)

# Compute assembly activity for cross-ICA and cross-SVD assemblies
assembly_act_cross = assembly_react_cross.get_assembly_act(
    epoch=assembly_react_cross.epochs[pre_task_post[0][2].item()]
)
assembly_act_cross_svd = assembly_react_cross_svd.get_assembly_act(
    epoch=assembly_react_cross_svd.epochs[pre_task_post[0][2].item()]
)

# ripples during NREM sleep only
nrem_ripples = ripples & nrem_epochs

# peri-ripple average of standard-assembly activity in post-task sleep
psth_standard_post = npy.process.event_triggered_average(
    timestamps=assembly_act_standard.abscissa_vals,
    signal=assembly_act_standard.data.T,
    events=nrem_ripples[
        assembly_react_standard.epochs[pre_task_post[0][2].item()]
    ].starts,
    sampling_rate=assembly_act_standard.fs,
    window=[-0.5, 0.5],
    return_average=True,
    return_pandas=True,
)

# cross-structural methods may return zero assemblies, so guard each PSTH
if assembly_react_cross.n_assemblies() > 0:
    psth_cross_post = npy.process.event_triggered_average(
        timestamps=assembly_act_cross.abscissa_vals,
        signal=assembly_act_cross.data.T,
        events=nrem_ripples[
            assembly_react_cross.epochs[pre_task_post[0][2].item()]
        ].starts,
        sampling_rate=assembly_act_cross.fs,
        window=[-0.5, 0.5],
        return_average=True,
        return_pandas=True,
    )
else:
    psth_cross_post = None

if assembly_react_cross_svd.n_assemblies() > 0:
    psth_cross_svd_post = npy.process.event_triggered_average(
        timestamps=assembly_act_cross_svd.abscissa_vals,
        signal=assembly_act_cross_svd.data.T,
        events=nrem_ripples[
            assembly_react_cross_svd.epochs[pre_task_post[0][2].item()]
        ].starts,
        sampling_rate=assembly_act_cross_svd.fs,
        window=[-0.5, 0.5],
        return_average=True,
        return_pandas=True,
    )
else:
    psth_cross_svd_post = None

# Plot comparison: one panel per detection method
fig, axes = plt.subplots(1, 3, figsize=(15, 4), sharey=False)

# Plot standard assemblies
psth_standard_post.plot(ax=axes[0], legend=False)
axes[0].set_ylabel("Assembly Activity")
axes[0].axvline(0, linestyle="--", color="k", alpha=0.7)
axes[0].grid(True, alpha=0.3)
axes[0].set_title(f"Standard (n={assembly_react_standard.n_assemblies()})")

# Plot cross-ICA assemblies
if psth_cross_post is not None:
    psth_cross_post.plot(ax=axes[1], legend=False)
    axes[1].set_title(f"Cross-ICA (n={assembly_react_cross.n_assemblies()})")
else:
    axes[1].text(0.5, 0.5, "No cross-ICA assemblies", ha="center", va="center")
    axes[1].set_title("Cross-ICA (n=0)")
axes[1].axvline(0, linestyle="--", color="k", alpha=0.7)
axes[1].grid(True, alpha=0.3)

# Plot cross-SVD assemblies
if psth_cross_svd_post is not None:
    psth_cross_svd_post.plot(ax=axes[2], legend=False)
    axes[2].set_title(f"Cross-SVD (n={assembly_react_cross_svd.n_assemblies()})")
else:
    axes[2].text(0.5, 0.5, "No cross-SVD assemblies", ha="center", va="center")
    axes[2].set_title("Cross-SVD (n=0)")
axes[2].axvline(0, linestyle="--", color="k", alpha=0.7)
axes[2].grid(True, alpha=0.3)

# hide per-assembly legends (too many traces)
for ax in axes:
    ax.legend().set_visible(False)

plt.suptitle("Assembly Activity During Post-Task Ripples", fontsize=14)
plt.tight_layout()
sns.despine()
plt.show()

# Compare peak activation (mean over assemblies of each assembly's peak)
standard_peak = psth_standard_post.max().mean()
cross_peak = psth_cross_post.max().mean() if psth_cross_post is not None else np.nan
cross_svd_peak = (
    psth_cross_svd_post.max().mean() if psth_cross_svd_post is not None else np.nan
)

print("Peak activation comparison:")
print(f"Standard assemblies (mean): {standard_peak:.2f}")
print(f"Cross-ICA assemblies (mean): {cross_peak:.2f}")
print(f"Cross-SVD assemblies (mean): {cross_svd_peak:.2f}")

png

Peak activation comparison:
Standard assemblies (mean): 3.52
Cross-ICA assemblies (mean): 7.22
Cross-SVD assemblies (mean): 0.15

Section 5.5: Visualizing the Cross-Region Matrix Used for Detection

For cross-structural detection, the algorithm uses a group-normalized, block-structured cross-region covariance matrix:

$$

C = \begin{bmatrix}
0 & C_{AB} \\
C_{BA} & 0
\end{bmatrix}

$$

where within-region blocks are zero and only cross-region covariance terms are retained.

The visualization below reproduces this matrix directly from the same z-scored activity used for assembly detection, so it is aligned with the actual model input rather than a generic masked map.

# Visualize the cross-region matrix used by cross-structural detection


# Recompute z-scored binned spike matrix using AssemblyReact helper
# (same task-theta restriction used when fitting the weights)
zmat, _ = assembly_react_cross.get_z_mat(
    assembly_react_cross.st[
        assembly_react_cross.epochs[pre_task_post[0][1].item()] & theta_epochs
    ]
)

groups = np.asarray(brain_regions)
unique_groups = np.unique(groups)

# Group-size normalization: divide each group's rows by sqrt(group size)
zmat_norm = zmat.copy().astype(float)
for group in unique_groups:
    mask = groups == group
    zmat_norm[mask, :] /= np.sqrt(mask.sum())

# Build explicit block matrix with only cross-group covariance
# (within-group blocks are left at zero, matching the detection model)
cross_matrix = np.zeros((zmat_norm.shape[0], zmat_norm.shape[0]), dtype=float)

for i, group_a in enumerate(unique_groups):
    idx_a = np.where(groups == group_a)[0]
    Xa = zmat_norm[idx_a, :]

    # only pairs (a, b) with b after a — the transpose fills the other block
    for group_b in unique_groups[i + 1 :]:
        idx_b = np.where(groups == group_b)[0]

        Xb = zmat_norm[idx_b, :]

        cov_ab = Xa @ Xb.T / Xa.shape[1]

        cross_matrix[np.ix_(idx_a, idx_b)] = cov_ab

        cross_matrix[np.ix_(idx_b, idx_a)] = cov_ab.T


# Use a robust data-driven color scale; the absolute values are often much
# smaller than 0.02 after z-scoring and group-size normalization.
nonzero_vals = cross_matrix[~np.isclose(cross_matrix, 0.0)]
if nonzero_vals.size:
    vmax = np.percentile(np.abs(nonzero_vals), 99)
    vmax = max(vmax, np.max(np.abs(nonzero_vals)) / 5)
else:
    vmax = 1e-6


print(
    f"Cross-matrix summary: min={cross_matrix.min():.3e}, "
    f"max={cross_matrix.max():.3e}, vmax_used={vmax:.3e}, "
    f"nonzero_entries={nonzero_vals.size}"
)


# Mask within-group blocks to white so cross-region structure stands out.
within_group_mask = groups[:, None] == groups[None, :]
cmap = plt.get_cmap("coolwarm").copy()
cmap.set_bad(color="white")
cross_matrix_masked = np.ma.masked_where(within_group_mask, cross_matrix)

fig, ax = plt.subplots(figsize=(7, 6))
im = ax.imshow(cross_matrix_masked, cmap=cmap, vmin=-vmax, vmax=vmax)


# Add lines to separate regions
ca1_neurons = np.sum(brain_regions == "CA1")
ax.axhline(ca1_neurons - 0.5, color="red", linestyle="--", linewidth=2, alpha=0.7)
ax.axvline(ca1_neurons - 0.5, color="red", linestyle="--", linewidth=2, alpha=0.7)


# Axis labels and ticks
ax.set_xlabel("Neuron (sorted by region)")
ax.set_ylabel("Neuron (sorted by region)")
ax.set_xticks([0, ca1_neurons, len(brain_regions)])
ax.set_yticks([0, ca1_neurons, len(brain_regions)])
ax.set_xticklabels(["0", f"{ca1_neurons}", f"{len(brain_regions)}"])
ax.set_yticklabels(["0", f"{ca1_neurons}", f"{len(brain_regions)}"])

# Add region labels above the matrix
ax.text(
    ca1_neurons / 2,
    -10,
    "CA1",
    ha="center",
    va="bottom",
    fontsize=14,
    fontweight="bold",
    color="black",
    clip_on=False,
)

ax.text(
    ca1_neurons + (len(brain_regions) - ca1_neurons) / 2,
    -10,
    "PFC",
    ha="center",
    va="bottom",
    fontsize=14,
    fontweight="bold",
    color="black",
    clip_on=False,
)

fig.suptitle(
    "Cross-Region Block Covariance Matrix Used for Detection",
    fontsize=12,
    y=0.94,
)

cbar = plt.colorbar(im, ax=ax, label="Cross-region covariance", pad=0.02)

plt.tight_layout(rect=[0, 0, 1, 0.93])

plt.show()
Cross-matrix summary: min=-4.801e-04, max=6.484e-04, vmax_used=2.742e-04, nonzero_entries=20634

png

Section 5.6: Key Points About Cross-Region Assembly Detection

What it does: - Identifies assemblies that span across different groups (regions, cell types, etc.) - Filters out assemblies that are confined to a single group - Enables study of cross-regional coordination and communication

Methods: - method='ica' + cross_structural: cross-ICA (cross-covariance constrained ICA) - method='cross_svd' + cross_structural: cross-SVD (strictly bipartite cross-area assemblies)

Parameters: - cross_structural: categorical array with one label per neuron - method='cross_svd' requires exactly two groups - Cross-structural analyses use shuffle-based significance (nullhyp='bin')

Important considerations: - Requires sufficient neurons from multiple groups - May detect fewer assemblies than standard detection (by design) - cross_svd activity is interpreted as cross-area coactivation over time

Take-home: Use cross_ica when you want flexible cross-group assemblies with ICA structure; use cross_svd when you want the cleanest, strictly two-group cross-area coupling representation.


Section 5.7: Simulation — Why does cross-SVD detect so many assemblies?

This simulation generates synthetic data with a known number of embedded cross-area assemblies (3), then shows:

  1. Why cross-SVD with loose settings (low nshu, low percentile) over-detects
  2. How the singular value spectrum compares to the full null distribution
  3. What parameter choices bring detection in line with the true number
# private helpers from neuro_py used to run the cross-SVD pipeline by hand
from neuro_py.ensemble.assembly import _compute_cross_svd, _cross_svd_significance

np.random.seed(42)

# ── Simulation parameters ──────────────────────────────────────────────────────
n_ca1 = 56  # match your real data: min(n_CA1, n_PFC) → max testable components
n_pfc = 56
T = 2000  # time bins (same order of magnitude as task epoch)
N_true = 3  # ground-truth cross-area assemblies embedded in the data
noise_std = 1.0  # background noise level
signal_amp = 3.0  # extra amplitude for assembly events (above noise)
rate_event = 0.05  # fraction of bins that are assembly events

# ── Generate background noise ──────────────────────────────────────────────────
X_ca1 = np.random.randn(n_ca1, T) * noise_std
X_pfc = np.random.randn(n_pfc, T) * noise_std

# ── Embed N_true cross-area assemblies ─────────────────────────────────────────
# Each assembly: random unit-norm vectors in CA1 and PFC space
true_u = np.random.randn(n_ca1, N_true)
true_u /= np.linalg.norm(true_u, axis=0)
true_v = np.random.randn(n_pfc, N_true)
true_v /= np.linalg.norm(true_v, axis=0)

# on each event bin, add the rank-1 pattern to BOTH regions simultaneously,
# creating genuine cross-area coactivation
for k in range(N_true):
    event_bins = np.where(np.random.rand(T) < rate_event)[0]
    X_ca1[:, event_bins] += signal_amp * true_u[:, k : k + 1]
    X_pfc[:, event_bins] += signal_amp * true_v[:, k : k + 1]

# z-score (replicates what runPatterns does before calling SVD)
from scipy import stats as scipy_stats

X_ca1_z = scipy_stats.zscore(X_ca1, axis=1)
X_pfc_z = scipy_stats.zscore(X_pfc, axis=1)
# stack into one neuron-by-time matrix with matching per-neuron region labels
zactmat_sim = np.vstack([X_ca1_z, X_pfc_z])
labels_sim = np.array(["CA1"] * n_ca1 + ["PFC"] * n_pfc)

print(f"Simulation: {n_ca1} CA1 neurons, {n_pfc} PFC neurons, {T} time bins")
print(f"Ground truth: {N_true} embedded cross-area assemblies")
print(f"Testable SVD components: {min(n_ca1, n_pfc)} (= min(n_CA1, n_PFC))")
Simulation: 56 CA1 neurons, 56 PFC neurons, 2000 time bins
Ground truth: 3 embedded cross-area assemblies
Testable SVD components: 56 (= min(n_CA1, n_PFC))

Section 5.7.1: Singular value spectrum vs null distribution

The key plot: actual singular values (sorted) overlaid on the null distribution from shuffles. With 56 testable components and a loose, noisily estimated threshold, far more components than the true number can exceed the null.

nshu_test = 100  # same as FAST_DEMO
percentile_test = 95

# Run cross-SVD significance once with the loose demo settings; `keep` marks
# components whose singular value exceeds the rank-matched null threshold.
U, S, Vt, keep, null_thresholds, _, _ = _cross_svd_significance(
    zactmat_sim,
    labels_sim,
    nshu=nshu_test,
    percentile=percentile_test,
)

n_detected = keep.sum()
component_idx = np.arange(1, len(S) + 1)  # 1-based rank for plotting

fig, axes = plt.subplots(1, 2, figsize=(14, 4))

# ── Left: spectrum vs null threshold ──────────────────────────────────────────

ax = axes[0]

ax.plot(component_idx, S, "o-", color="steelblue", label="Actual singular values", ms=4)

ax.plot(
    component_idx,
    null_thresholds,
    "r--",
    label=f"Null {percentile_test}th pct (nshu={nshu_test})",
    lw=1.5,
)

ax.axvline(
    N_true + 0.5,
    color="green",
    linestyle=":",
    lw=2,
    label=f"True N assemblies ({N_true})",
)

# Shade the gap between the spectrum and the null wherever a component is
# detected (actual singular value above its null threshold).
ax.fill_between(
    component_idx,
    null_thresholds,
    S,
    where=(S > null_thresholds),
    alpha=0.25,
    color="steelblue",
    label=f"Detected: {n_detected}",
)

ax.set_xlabel("SVD component rank")
ax.set_ylabel("Singular value")

# FIX: the " → " separator was missing from this title format string, so the
# percentile and the detected count printed run together (compare the intact
# arrows in the sweep printout, e.g. "nshu=  100, pct= 95.0  →  9 detected").
ax.set_title(
    f"nshu={nshu_test}, pct={percentile_test} → {n_detected}/{len(S)} detected"
)
ax.legend(fontsize=8)
ax.grid(True, alpha=0.3)


# ── Right: histogram of null SV for component 1 (signal) and component 20 (noise) ──
ax = axes[1]

# Recompute only to obtain the row indices of each group (idx_g1 / idx_g2);
# the actual singular values S_real do not depend on nshu.
_, S_real, _, _, _, idx_g1, idx_g2 = _cross_svd_significance(
    zactmat_sim, labels_sim, nshu=500, percentile=percentile_test
)


# Build a fuller null histogram using the same symmetric group-level shuffle:
# independently permute time in each area, then take the SVD of the resulting
# cross-correlation matrix.

X1 = zactmat_sim[idx_g1, :]
X2 = zactmat_sim[idx_g2, :]

null_s = []

rng = np.random.default_rng(0)

T_sim = X1.shape[1]

for _ in range(500):
    pi1 = rng.permutation(T_sim)
    pi2 = rng.permutation(T_sim)

    # cross-correlation of independently time-shuffled areas → null spectrum
    cc = X1[:, pi1] @ X2[:, pi2].T / T_sim
    _, sv, _ = np.linalg.svd(cc, full_matrices=False)

    null_s.append(sv)

null_s = np.array(null_s)  # (500, n_components)


ax.hist(
    null_s[:, 0],
    bins=30,
    alpha=0.6,
    color="salmon",
    label="Null SV rank-1 (signal comp)",
    density=True,
)

ax.hist(
    null_s[:, 19],
    bins=30,
    alpha=0.6,
    color="lightblue",
    label="Null SV rank-20 (noise comp)",
    density=True,
)

ax.axvline(
    S_real[0],
    color="darkred",
    lw=2,
    linestyle="-",
    label=f"Actual SV rank-1 = {S_real[0]:.3f}",
)

ax.axvline(
    S_real[19],
    color="steelblue",
    lw=2,
    linestyle="-",
    label=f"Actual SV rank-20 = {S_real[19]:.3f}",
)

ax.set_xlabel("Singular value")

ax.set_ylabel("Density")

ax.set_title("Null distribution comparison: signal vs noise component")

ax.legend(fontsize=8)

ax.grid(True, alpha=0.3)

plt.suptitle("Cross-SVD: Why are so many components detected?", fontsize=13)
plt.tight_layout()
sns.despine()
plt.show()


print(f"\nTrue embedded assemblies : {N_true}")
print(f"Detected (nshu={nshu_test}, pct={percentile_test}) : {n_detected}")
print(f"Total testable components : {len(S)}")
print(f"Note: with {percentile_test}th percentile and only {nshu_test} shuffles,")
print(
    f"      the null threshold for each rank is estimated from only "
    f"~{int((1-percentile_test/100)*nshu_test)} samples — very noisy."
)

png

True embedded assemblies : 3
Detected (nshu=100, pct=95) : 8
Total testable components : 56
Note: with 95th percentile and only 100 shuffles,
      the null threshold for each rank is estimated from only ~5 samples — very noisy.

Section 5.7.2: Parameter sweep — nshu and percentile

Sweep both axes to find what settings recover the true N=3. The heatmap shows detected assembly count; the green cell is closest to ground truth.

nshu_grid = [50, 100, 200, 500, 1000]
percentile_grid = [90, 95, 99, 99.5]

# rows: percentile, cols: nshu; cell value = number of detected assemblies
results = np.zeros((len(percentile_grid), len(nshu_grid)), dtype=int)

for pi, pct in enumerate(percentile_grid):
    for ni, nshu in enumerate(nshu_grid):
        _, _, _, keep_i, _, _, _ = _cross_svd_significance(
            zactmat_sim, labels_sim, nshu=nshu, percentile=pct
        )
        # Compute the count once and reuse it for both the table and printout.
        n_det = int(keep_i.sum())
        results[pi, ni] = n_det
        # FIX: the "  →  " separator was missing from this format string, so
        # pct and the detected count printed run together (the captured output
        # below shows the intended arrow).
        print(
            f"  nshu={nshu:5d}, pct={pct:5.1f}  →  {n_det} detected", flush=True
        )

fig, ax = plt.subplots(figsize=(8, 4))
im = ax.imshow(results, cmap="RdYlGn_r", aspect="auto", vmin=0, vmax=min(n_ca1, n_pfc))
plt.colorbar(im, ax=ax, label="# detected assemblies")

# annotate cells (white text on the dark high-count cells for contrast)
for pi in range(len(percentile_grid)):
    for ni in range(len(nshu_grid)):
        val = results[pi, ni]
        color = "white" if val > 40 else "black"
        ax.text(
            ni,
            pi,
            str(val),
            ha="center",
            va="center",
            color=color,
            fontsize=10,
            fontweight="bold",
        )

# green border around cells that match ground truth ±1
for pi in range(len(percentile_grid)):
    for ni in range(len(nshu_grid)):
        if abs(results[pi, ni] - N_true) <= 1:
            rect = plt.Rectangle(
                (ni - 0.5, pi - 0.5), 1, 1, fill=False, edgecolor="lime", lw=3
            )
            ax.add_patch(rect)

ax.set_xticks(range(len(nshu_grid)))
ax.set_xticklabels(nshu_grid)
ax.set_yticks(range(len(percentile_grid)))
ax.set_yticklabels(percentile_grid)
ax.set_xlabel("nshu")
ax.set_ylabel("percentile")
ax.set_title(
    f"Detected assemblies (ground truth = {N_true}, green = within ±1)\n"
    f"Data: {n_ca1} CA1 x {n_pfc} PFC neurons, {T} bins"
)

plt.tight_layout()
plt.show()

print(f"\nTake-home: with {min(n_ca1,n_pfc)} testable components and loose settings,")
print("almost all singular values exceed a noisy null threshold.")
print("Only high nshu + high percentile recovers the true assembly count.")
  nshu=   50, pct= 90.0  →  9 detected
  nshu=  100, pct= 90.0  →  9 detected
  nshu=  200, pct= 90.0  →  9 detected
  nshu=  500, pct= 90.0  →  9 detected
  nshu= 1000, pct= 90.0  →  9 detected
  nshu=   50, pct= 95.0  →  8 detected
  nshu=  100, pct= 95.0  →  9 detected
  nshu=  200, pct= 95.0  →  7 detected
  nshu=  500, pct= 95.0  →  8 detected
  nshu= 1000, pct= 95.0  →  7 detected
  nshu=   50, pct= 99.0  →  4 detected
  nshu=  100, pct= 99.0  →  5 detected
  nshu=  200, pct= 99.0  →  6 detected
  nshu=  500, pct= 99.0  →  6 detected
  nshu= 1000, pct= 99.0  →  6 detected
  nshu=   50, pct= 99.5  →  5 detected
  nshu=  100, pct= 99.5  →  6 detected
  nshu=  200, pct= 99.5  →  6 detected
  nshu=  500, pct= 99.5  →  5 detected
  nshu= 1000, pct= 99.5  →  5 detected

png

Take-home: with 56 testable components and loose settings,
almost all singular values exceed a noisy null threshold.
Only high nshu + high percentile recovers the true assembly count.