assembly_reactivation

AssemblyReact

Class for running assembly reactivation analysis

Core assembly methods come from assembly.py by Vítor Lopes dos Santos https://doi.org/10.1016/j.jneumeth.2013.04.010

Parameters:

basepath : str, default None
    Path to the session folder.
brainRegion : str, default 'CA1'
    Brain region to restrict to. Multiple regions can be given, e.g. "CA1|CA2".
putativeCellType : str, default 'Pyramidal Cell'
    Cell type to restrict to.
weight_dt : float, default 0.025
    Time resolution of the weight matrix.
z_mat_dt : float, default 0.002
    Time resolution of the z matrix.
method : str, default 'ica'
    Defines how to extract assembly patterns ('ica' or 'pca').
nullhyp : str, default 'mp'
    Defines how to generate the statistical threshold for assembly detection ('bin', 'circ', or 'mp').
nshu : int, default 1000
    Number of shuffles for the 'bin' and 'circ' null hypotheses.
percentile : int, default 99
    Percentile for the 'mp' null hypothesis.
tracywidom : bool, default False
    If True, uses the Tracy-Widom distribution for the 'mp' null hypothesis.
whiten : str, default 'unit-variance'
    Whitening option passed to assembly.runPatterns.

Attributes:

st : SpikeTrainArray
    Spike train.
cell_metrics : DataFrame
    Cell metrics.
ripples : EpochArray
    Ripples.
patterns : ndarray
    Assembly patterns.
assembly_act : AnalogSignalArray
    Assembly activity.
Methods:

load_data
    Load data (st, ripples, epochs).
restrict_to_epoch
    Restrict to a specific epoch.
get_z_mat
    Get z matrix.
get_weights
    Get assembly weights.
get_assembly_act
    Get assembly activity.
n_assemblies
    Number of detected assemblies.
isempty
    Check if empty.
copy
    Returns copy of class.
plot
    Stem plot of assembly weights.
find_members
    Find members of an assembly.

Examples:

>>> # create the object assembly_react
>>> assembly_react = assembly_reactivation.AssemblyReact(
...    basepath=basepath,
...    )
>>> # load needed data (spikes, ripples, epochs)
>>> assembly_react.load_data()
>>> # detect assemblies
>>> assembly_react.get_weights()
>>> # visually inspect weights for each assembly
>>> assembly_react.plot()
>>> # compute time resolved signal for each assembly
>>> assembly_act = assembly_react.get_assembly_act()
>>> # locate members of assemblies
>>> assembly_members = assembly_react.find_members()
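
A fuller workflow (an illustrative sketch, not from the original docs) restricts detection to one epoch and then measures reactivation elsewhere; the pre/task/post ordering of epochs is an assumption here:

>>> assembly_react = assembly_reactivation.AssemblyReact(basepath=basepath)
>>> assembly_react.load_data()
>>> # keep only pre/task/post epochs (see restrict_epochs_to_pre_task_post below)
>>> assembly_react.restrict_epochs_to_pre_task_post()
>>> task_epoch = assembly_react.epochs[1]   # assumed ordering: pre, task, post
>>> # detect assemblies from task activity only
>>> assembly_react.get_weights(epoch=task_epoch)
>>> # time-resolved assembly activity during ripples (e.g., for reactivation strength)
>>> ripple_act = assembly_react.get_assembly_act(epoch=assembly_react.ripples)
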
Source code in neuro_py/ensemble/assembly_reactivation.py
class AssemblyReact:
    """
    Class for running assembly reactivation analysis

    Core assembly methods come from assembly.py by Vítor Lopes dos Santos
        https://doi.org/10.1016/j.jneumeth.2013.04.010

    Parameters
    ----------
    basepath : str
        Path to the session folder
    brainRegion : str
        Brain region to restrict to. Can be multiple, e.g. "CA1|CA2"
    putativeCellType : str
        Cell type to restrict to
    weight_dt : float
        Time resolution of the weight matrix
    z_mat_dt : float
        Time resolution of the z matrix
    method : str
        Defines how to extract assembly patterns (ica,pca).
    nullhyp : str
        Defines how to generate statistical threshold for assembly detection (bin,circ,mp).
    nshu : int
        Number of shuffles for bin and circ null hypothesis.
    percentile : int
        Percentile for mp null hypothesis.
    tracywidom : bool
        If true, uses Tracy-Widom distribution for mp null hypothesis.

    Attributes
    ----------
    st : nelpy.SpikeTrainArray
        Spike train
    cell_metrics : pd.DataFrame
        Cell metrics
    ripples : nelpy.EpochArray
        Ripples
    patterns : np.ndarray
        Assembly patterns
    assembly_act : nelpy.AnalogSignalArray
        Assembly activity

    Methods
    -------
    load_data()
        Load data (st, ripples, epochs)
    restrict_to_epoch(epoch)
        Restrict to a specific epoch
    get_z_mat(st)
        Get z matrix
    get_weights(epoch=None)
        Get assembly weights
    get_assembly_act(epoch=None)
        Get assembly activity
    n_assemblies()
        Number of detected assemblies
    isempty()
        Check if empty
    copy()
        Returns copy of class
    plot()
        Stem plot of assembly weights
    find_members()
        Find members of an assembly


    Examples
    --------
    >>> # create the object assembly_react
    >>> assembly_react = assembly_reactivation.AssemblyReact(
    ...    basepath=basepath,
    ...    )

    >>> # load needed data (spikes, ripples, epochs)
    >>> assembly_react.load_data()

    >>> # detect assemblies
    >>> assembly_react.get_weights()

    >>> # visually inspect weights for each assembly
    >>> assembly_react.plot()

    >>> # compute time resolved signal for each assembly
    >>> assembly_act = assembly_react.get_assembly_act()

    >>> # locate members of assemblies
    >>> assembly_members = assembly_react.find_members()

    """

    def __init__(
        self,
        basepath: Union[str, None] = None,
        brainRegion: str = "CA1",
        putativeCellType: str = "Pyramidal Cell",
        weight_dt: float = 0.025,
        z_mat_dt: float = 0.002,
        method: str = "ica",
        nullhyp: str = "mp",
        nshu: int = 1000,
        percentile: int = 99,
        tracywidom: bool = False,
        whiten: str = "unit-variance",
    ):
        self.basepath = basepath
        self.brainRegion = brainRegion
        self.putativeCellType = putativeCellType
        self.weight_dt = weight_dt
        self.z_mat_dt = z_mat_dt
        self.method = method
        self.nullhyp = nullhyp
        self.nshu = nshu
        self.percentile = percentile
        self.tracywidom = tracywidom
        self.whiten = whiten
        self.type_name = self.__class__.__name__

    def add_st(self, st: nel.SpikeTrainArray) -> None:
        self.st = st

    def add_ripples(self, ripples: nel.EpochArray) -> None:
        self.ripples = ripples

    def add_epoch_df(self, epoch_df: pd.DataFrame) -> None:
        self.epoch_df = epoch_df

    def load_spikes(self) -> None:
        """
        loads spikes from the session folder
        """
        self.st, self.cell_metrics = loading.load_spikes(
            self.basepath,
            brainRegion=self.brainRegion,
            putativeCellType=self.putativeCellType,
            support=self.time_support,
        )

    def load_ripples(self) -> None:
        """
        loads ripples from the session folder
        """
        ripples = loading.load_ripples_events(self.basepath)
        self.ripples = nel.EpochArray(
            [np.array([ripples.start, ripples.stop]).T], domain=self.time_support
        )

    def load_epoch(self) -> None:
        """
        loads epochs from the session folder
        """
        epoch_df = loading.load_epoch(self.basepath)
        epoch_df = compress_repeated_epochs(epoch_df)
        self.time_support = nel.EpochArray(
            [epoch_df.iloc[0].startTime, epoch_df.iloc[-1].stopTime]
        )
        self.epochs = nel.EpochArray(
            [np.array([epoch_df.startTime, epoch_df.stopTime]).T],
            domain=self.time_support,
        )
        self.epoch_df = epoch_df

    def load_data(self) -> None:
        """
        loads data (spikes,ripples,epochs) from the session folder
        """
        self.load_epoch()
        self.load_spikes()
        self.load_ripples()

    def restrict_epochs_to_pre_task_post(self) -> None:
        """
        Restricts the epochs to the specified epochs
        """
        # fetch data
        epoch_df = loading.load_epoch(self.basepath)
        # compress back to back sleep epochs (an issue further up the pipeline)
        epoch_df = compress_repeated_epochs(epoch_df)
        # restrict to pre task post epochs
        idx = find_pre_task_post(epoch_df.environment)
        self.epoch_df = epoch_df[idx[0]]
        # convert to epoch array and add to object
        self.epochs = nel.EpochArray(
            [np.array([self.epoch_df.startTime, self.epoch_df.stopTime]).T],
            label="session_epochs",
            domain=self.time_support,
        )

    def restrict_to_epoch(self, epoch) -> None:
        """
        Restricts the spike data to a specific epoch.

        Parameters
        ----------
        epoch : nel.EpochArray
            The epoch to restrict to.
        """
        self.st_resticted = self.st[epoch]

    def get_z_mat(self, st: nel.SpikeTrainArray) -> Tuple[np.ndarray, np.ndarray]:
        """
        Get z matrix.

        Parameters
        ----------
        st : nel.SpikeTrainArray
            Spike train array.

        Returns
        -------
        Tuple[np.ndarray, np.ndarray]
            Z-scored binned spike train and bin centers.
        """
        # binning the spike train
        z_t = st.bin(ds=self.z_mat_dt)
        # gaussian kernel to match the bin-size used to identify the assembly patterns
        sigma = self.weight_dt / np.sqrt(int(1000 * self.weight_dt / 2))
        z_t.smooth(sigma=sigma, inplace=True)
        # zscore the z matrix
        z_scored_bst = stats.zscore(z_t.data, axis=1)
        # make sure there are no nans, important as strengths will all be nan otherwise
        z_scored_bst[np.isnan(z_scored_bst).any(axis=1)] = 0

        return z_scored_bst, z_t.bin_centers

    def get_weights(self, epoch: Optional[nel.EpochArray] = None) -> None:
        """
        Gets the assembly weights.

        Parameters
        ----------
        epoch : nel.EpochArray, optional
            The epoch to restrict to, by default None.
        """

        # check if st has any neurons
        if self.st.isempty:
            self.patterns = None
            return

        if epoch is not None:
            bst = self.st[epoch].bin(ds=self.weight_dt).data
        else:
            bst = self.st.bin(ds=self.weight_dt).data

        if (bst == 0).all():
            self.patterns = None
        else:
            patterns, _, _ = assembly.runPatterns(
                bst,
                method=self.method,
                nullhyp=self.nullhyp,
                nshu=self.nshu,
                percentile=self.percentile,
                tracywidom=self.tracywidom,
                whiten=self.whiten,
            )
            # flip patterns to have positive max
            self.patterns = np.array(
                [
                    (
                        patterns[i, :]
                        if patterns[i, np.argmax(np.abs(patterns[i, :]))] > 0
                        else -patterns[i, :]
                    )
                    for i in range(patterns.shape[0])
                ]
            )

    def get_assembly_act(
        self, epoch: Optional[nel.EpochArray] = None
    ) -> nel.AnalogSignalArray:
        """
        Get assembly activity.

        Parameters
        ----------
        epoch : nel.EpochArray, optional
            The epoch to restrict to, by default None.

        Returns
        -------
        nel.AnalogSignalArray
            Assembly activity.
        """
        # check for num of assemblies first
        if self.n_assemblies() == 0:
            return nel.AnalogSignalArray(empty=True)

        if epoch is not None:
            zactmat, ts = self.get_z_mat(self.st[epoch])
        else:
            zactmat, ts = self.get_z_mat(self.st)

        assembly_act = nel.AnalogSignalArray(
            data=assembly.computeAssemblyActivity(self.patterns, zactmat),
            timestamps=ts,
            fs=1 / self.z_mat_dt,
        )
        return assembly_act

    def plot(
        self,
        plot_members: bool = True,
        central_line_color: str = "grey",
        marker_color: str = "k",
        member_color: Union[str, list] = "#6768ab",
        line_width: float = 1.25,
        markersize: float = 4,
        x_padding: float = 0.2,
        figsize: Union[tuple, None] = None,
    ) -> Union[Tuple[plt.Figure, np.ndarray], str, None]:
        """
        Plots basic stem plot to display assembly weights.

        Parameters
        ----------
        plot_members : bool, optional
            Whether to plot assembly members, by default True.
        central_line_color : str, optional
            Color of the central line, by default "grey".
        marker_color : str, optional
            Color of the markers, by default "k".
        member_color : Union[str, List[str]], optional
            Color of the members, by default "#6768ab".
        line_width : float, optional
            Width of the lines, by default 1.25.
        markersize : float, optional
            Size of the markers, by default 4.
        x_padding : float, optional
            Padding on the x-axis, by default 0.2.
        figsize : Optional[Tuple[float, float]], optional
            Size of the figure, by default None.

        Returns
        -------
        Union[Tuple[plt.Figure, np.ndarray], str, None]
            The figure and axes if successful, otherwise a message or None.
        """
        if not hasattr(self, "patterns"):
            return "run get_weights first"
        else:
            if self.patterns is None:
                return None, None
            if plot_members:
                self.find_members()
            if figsize is None:
                figsize = (self.n_assemblies() + 1, np.round(self.n_assemblies() / 2))
            # set up figure with size relative to assembly matrix
            fig, axes = plt.subplots(
                1,
                self.n_assemblies(),
                figsize=figsize,
                sharey=True,
                sharex=True,
            )
            # iter over each assembly and plot the weight per cell
            for i in range(self.n_assemblies()):
                markerline, stemlines, baseline = axes[i].stem(
                    self.patterns[i, :], orientation="horizontal"
                )
                markerline._color = marker_color
                baseline._color = central_line_color
                baseline.zorder = -1000
                plt.setp(stemlines, "color", plt.getp(markerline, "color"))
                plt.setp(stemlines, linewidth=line_width)
                plt.setp(markerline, markersize=markersize)

                if plot_members:
                    current_pattern = self.patterns[i, :].copy()
                    current_pattern[~self.assembly_members[i, :]] = np.nan
                    markerline, stemlines, baseline = axes[i].stem(
                        current_pattern, orientation="horizontal"
                    )
                    if isinstance(
                        member_color, sns.palettes._ColorPalette
                    ) or isinstance(member_color, list):
                        markerline._color = member_color[i]
                    else:
                        markerline._color = member_color
                    baseline._color = "#00000000"
                    baseline.zorder = -1000
                    plt.setp(stemlines, "color", plt.getp(markerline, "color"))
                    plt.setp(stemlines, linewidth=line_width)
                    plt.setp(markerline, markersize=markersize)

                axes[i].spines["top"].set_visible(False)
                axes[i].spines["right"].set_visible(False)

            # give room for marker
            axes[0].set_xlim(
                -self.patterns.max() - x_padding, self.patterns.max() + x_padding
            )

            axes[0].set_ylabel("Neurons #")
            axes[0].set_xlabel("Weights (a.u.)")

            return fig, axes

    def n_assemblies(self) -> int:
        """
        Get the number of detected assemblies.

        Returns
        -------
        int
            Number of detected assemblies.
        """
        if hasattr(self, "patterns"):
            if self.patterns is None:
                return 0
            return self.patterns.shape[0]

    @property
    def isempty(self) -> bool:
        """
        Check if the object is empty.

        Returns
        -------
        bool
            True if empty, False otherwise.
        """
        if hasattr(self, "st"):
            return False
        elif not hasattr(self, "st"):
            return True

    def copy(self) -> "AssemblyReact":
        """
        Returns a copy of the current class.

        Returns
        -------
        AssemblyReact
            A copy of the current class.
        """
        newcopy = copy.deepcopy(self)
        return newcopy

    def __repr__(self) -> str:
        if self.isempty:
            return f"<{self.type_name}: empty>"

        # if st data has been loaded and patterns have been computed
        if hasattr(self, "patterns"):
            n_units = f"{self.st.n_active} units"
            n_patterns = f"{self.n_assemblies()} assemblies"
            dstr = f"of length {self.st.support.length}"
            return "<%s: %s, %s> %s" % (self.type_name, n_units, n_patterns, dstr)

        # if st data has been loaded
        if hasattr(self, "st"):
            n_units = f"{self.st.n_active} units"
            dstr = f"of length {self.st.support.length}"
            return "<%s: %s> %s" % (self.type_name, n_units, dstr)

    def find_members(self) -> np.ndarray:
        """
        Finds significant assembly patterns and significant assembly members.

        Returns
        -------
        np.ndarray
            A ndarray of booleans indicating whether each unit is a significant member of an assembly.

        Notes
        -----
        also, sets self.assembly_members and self.valid_assembly

        self.valid_assembly: a ndarray of booleans indicating an assembly has members with the same sign (Boucly et al. 2022)
        """

        def Otsu(vector: np.ndarray) -> Tuple[np.ndarray, float, float]:
            """
            The Otsu method for splitting data into two groups.

            Parameters
            ----------
            vector : np.ndarray
                Arbitrary vector.

            Returns
            -------
            Tuple[np.ndarray, float, float]
                Group, threshold used for classification, and effectiveness metric.
            """
            sorted = np.sort(vector)
            n = len(vector)
            intraClassVariance = [np.nan] * n
            for i in np.arange(n):
                p = (i + 1) / n
                p0 = 1 - p
                if i + 1 == n:
                    intraClassVariance[i] = np.nan
                else:
                    intraClassVariance[i] = p * np.var(sorted[0 : i + 1]) + p0 * np.var(
                        sorted[i + 1 :]
                    )

            minIntraVariance = np.nanmin(intraClassVariance)
            idx = np.nanargmin(intraClassVariance)
            threshold = sorted[idx]
            group = vector > threshold

            em = 1 - (minIntraVariance / np.var(vector))

            return group, threshold, em

        is_member = []
        keep_assembly = []
        for pat in self.patterns:
            isMember, _, _ = Otsu(np.abs(pat))
            is_member.append(isMember)

            if np.any(pat[isMember] < 0) & np.any(pat[isMember] > 0):
                keep_assembly.append(False)
            elif sum(isMember) == 0:
                keep_assembly.append(False)
            else:
                keep_assembly.append(True)

        self.assembly_members = np.array(is_member)
        self.valid_assembly = np.array(keep_assembly)

        return self.assembly_members

isempty: bool property

Check if the object is empty.

Returns:

bool
    True if empty, False otherwise.

copy()

Returns a copy of the current class.

Returns:

AssemblyReact
    A copy of the current class.

Source code in neuro_py/ensemble/assembly_reactivation.py
def copy(self) -> "AssemblyReact":
    """
    Returns a copy of the current class.

    Returns
    -------
    AssemblyReact
        A copy of the current class.
    """
    newcopy = copy.deepcopy(self)
    return newcopy

find_members()

Finds significant assembly patterns and significant assembly members.

Returns:

ndarray
    An ndarray of booleans indicating whether each unit is a significant member of an assembly.

Notes

Also sets self.assembly_members and self.valid_assembly.

self.valid_assembly: an ndarray of booleans indicating whether all members of an assembly share the same sign (Boucly et al. 2022).

Source code in neuro_py/ensemble/assembly_reactivation.py
def find_members(self) -> np.ndarray:
    """
    Finds significant assembly patterns and significant assembly members.

    Returns
    -------
    np.ndarray
        A ndarray of booleans indicating whether each unit is a significant member of an assembly.

    Notes
    -----
    also, sets self.assembly_members and self.valid_assembly

    self.valid_assembly: a ndarray of booleans indicating an assembly has members with the same sign (Boucly et al. 2022)
    """

    def Otsu(vector: np.ndarray) -> Tuple[np.ndarray, float, float]:
        """
        The Otsu method for splitting data into two groups.

        Parameters
        ----------
        vector : np.ndarray
            Arbitrary vector.

        Returns
        -------
        Tuple[np.ndarray, float, float]
            Group, threshold used for classification, and effectiveness metric.
        """
        sorted = np.sort(vector)
        n = len(vector)
        intraClassVariance = [np.nan] * n
        for i in np.arange(n):
            p = (i + 1) / n
            p0 = 1 - p
            if i + 1 == n:
                intraClassVariance[i] = np.nan
            else:
                intraClassVariance[i] = p * np.var(sorted[0 : i + 1]) + p0 * np.var(
                    sorted[i + 1 :]
                )

        minIntraVariance = np.nanmin(intraClassVariance)
        idx = np.nanargmin(intraClassVariance)
        threshold = sorted[idx]
        group = vector > threshold

        em = 1 - (minIntraVariance / np.var(vector))

        return group, threshold, em

    is_member = []
    keep_assembly = []
    for pat in self.patterns:
        isMember, _, _ = Otsu(np.abs(pat))
        is_member.append(isMember)

        if np.any(pat[isMember] < 0) & np.any(pat[isMember] > 0):
            keep_assembly.append(False)
        elif sum(isMember) == 0:
            keep_assembly.append(False)
        else:
            keep_assembly.append(True)

    self.assembly_members = np.array(is_member)
    self.valid_assembly = np.array(keep_assembly)

    return self.assembly_members
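
As a usage sketch (not part of the original docstring), the member matrix can be combined with valid_assembly to list the member units of assemblies whose member weights share a single sign:

>>> members = assembly_react.find_members()   # boolean, shape (n_assemblies, n_units)
>>> valid = assembly_react.valid_assembly     # True where member weights share one sign
>>> member_ids = [np.where(row)[0] for row in members[valid]]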

get_assembly_act(epoch=None)

Get assembly activity.

Parameters:

epoch : EpochArray, optional
    The epoch to restrict to, by default None.

Returns:

AnalogSignalArray
    Assembly activity.

Source code in neuro_py/ensemble/assembly_reactivation.py
def get_assembly_act(
    self, epoch: Optional[nel.EpochArray] = None
) -> nel.AnalogSignalArray:
    """
    Get assembly activity.

    Parameters
    ----------
    epoch : nel.EpochArray, optional
        The epoch to restrict to, by default None.

    Returns
    -------
    nel.AnalogSignalArray
        Assembly activity.
    """
    # check for num of assemblies first
    if self.n_assemblies() == 0:
        return nel.AnalogSignalArray(empty=True)

    if epoch is not None:
        zactmat, ts = self.get_z_mat(self.st[epoch])
    else:
        zactmat, ts = self.get_z_mat(self.st)

    assembly_act = nel.AnalogSignalArray(
        data=assembly.computeAssemblyActivity(self.patterns, zactmat),
        timestamps=ts,
        fs=1 / self.z_mat_dt,
    )
    return assembly_act
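
For example (illustrative), activity can be computed over the full session or restricted to ripple events:

>>> assembly_act = assembly_react.get_assembly_act()
>>> ripple_act = assembly_react.get_assembly_act(epoch=assembly_react.ripples)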

get_weights(epoch=None)

Gets the assembly weights.

Parameters:

epoch : EpochArray, optional
    The epoch to restrict to, by default None.
Source code in neuro_py/ensemble/assembly_reactivation.py
def get_weights(self, epoch: Optional[nel.EpochArray] = None) -> None:
    """
    Gets the assembly weights.

    Parameters
    ----------
    epoch : nel.EpochArray, optional
        The epoch to restrict to, by default None.
    """

    # check if st has any neurons
    if self.st.isempty:
        self.patterns = None
        return

    if epoch is not None:
        bst = self.st[epoch].bin(ds=self.weight_dt).data
    else:
        bst = self.st.bin(ds=self.weight_dt).data

    if (bst == 0).all():
        self.patterns = None
    else:
        patterns, _, _ = assembly.runPatterns(
            bst,
            method=self.method,
            nullhyp=self.nullhyp,
            nshu=self.nshu,
            percentile=self.percentile,
            tracywidom=self.tracywidom,
            whiten=self.whiten,
        )
        # flip patterns to have positive max
        self.patterns = np.array(
            [
                (
                    patterns[i, :]
                    if patterns[i, np.argmax(np.abs(patterns[i, :]))] > 0
                    else -patterns[i, :]
                )
                for i in range(patterns.shape[0])
            ]
        )
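
A common pattern (sketch; the epoch indexing below assumes a pre, task, post ordering) is to estimate weights from a single behavioral epoch so that reactivation can be evaluated in other epochs:

>>> assembly_react.restrict_epochs_to_pre_task_post()
>>> assembly_react.get_weights(epoch=assembly_react.epochs[1])  # assumed task epoch
>>> assembly_react.n_assemblies()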

get_z_mat(st)

Get z matrix.

Parameters:

st : SpikeTrainArray
    Spike train array (required).

Returns:

Tuple[ndarray, ndarray]
    Z-scored binned spike train and bin centers.

Source code in neuro_py/ensemble/assembly_reactivation.py
def get_z_mat(self, st: nel.SpikeTrainArray) -> Tuple[np.ndarray, np.ndarray]:
    """
    Get z matrix.

    Parameters
    ----------
    st : nel.SpikeTrainArray
        Spike train array.

    Returns
    -------
    Tuple[np.ndarray, np.ndarray]
        Z-scored binned spike train and bin centers.
    """
    # binning the spike train
    z_t = st.bin(ds=self.z_mat_dt)
    # gaussian kernel to match the bin-size used to identify the assembly patterns
    sigma = self.weight_dt / np.sqrt(int(1000 * self.weight_dt / 2))
    z_t.smooth(sigma=sigma, inplace=True)
    # zscore the z matrix
    z_scored_bst = stats.zscore(z_t.data, axis=1)
    # make sure there are no nans, important as strengths will all be nan otherwise
    z_scored_bst[np.isnan(z_scored_bst).any(axis=1)] = 0

    return z_scored_bst, z_t.bin_centers
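
Illustrative call (not from the original docs); the returned matrix is expected to have one row per unit and one column per z_mat_dt bin:

>>> zactmat, bin_centers = assembly_react.get_z_mat(assembly_react.st)
>>> zactmat.shape   # (n_units, n_bins), assuming nelpy's binned data layout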

load_data()

loads data (spikes,ripples,epochs) from the session folder

Source code in neuro_py/ensemble/assembly_reactivation.py
def load_data(self) -> None:
    """
    loads data (spikes,ripples,epochs) from the session folder
    """
    self.load_epoch()
    self.load_spikes()
    self.load_ripples()
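
If the spike train, ripples, and epoch table are already in memory, the add_* helpers on the class can be used instead of load_data (sketch; the variables st, ripples, and epoch_df are assumed to already exist):

>>> assembly_react = assembly_reactivation.AssemblyReact(basepath=basepath)
>>> assembly_react.add_st(st)              # nelpy SpikeTrainArray
>>> assembly_react.add_ripples(ripples)    # nelpy EpochArray
>>> assembly_react.add_epoch_df(epoch_df)  # pandas DataFrame of session epochs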

load_epoch()

loads epochs from the session folder

Source code in neuro_py/ensemble/assembly_reactivation.py
def load_epoch(self) -> None:
    """
    loads epochs from the session folder
    """
    epoch_df = loading.load_epoch(self.basepath)
    epoch_df = compress_repeated_epochs(epoch_df)
    self.time_support = nel.EpochArray(
        [epoch_df.iloc[0].startTime, epoch_df.iloc[-1].stopTime]
    )
    self.epochs = nel.EpochArray(
        [np.array([epoch_df.startTime, epoch_df.stopTime]).T],
        domain=self.time_support,
    )
    self.epoch_df = epoch_df

load_ripples()

loads ripples from the session folder

Source code in neuro_py/ensemble/assembly_reactivation.py
def load_ripples(self) -> None:
    """
    loads ripples from the session folder
    """
    ripples = loading.load_ripples_events(self.basepath)
    self.ripples = nel.EpochArray(
        [np.array([ripples.start, ripples.stop]).T], domain=self.time_support
    )

load_spikes()

loads spikes from the session folder

Source code in neuro_py/ensemble/assembly_reactivation.py
def load_spikes(self) -> None:
    """
    loads spikes from the session folder
    """
    self.st, self.cell_metrics = loading.load_spikes(
        self.basepath,
        brainRegion=self.brainRegion,
        putativeCellType=self.putativeCellType,
        support=self.time_support,
    )

n_assemblies()

Get the number of detected assemblies.

Returns:

int
    Number of detected assemblies.

Source code in neuro_py/ensemble/assembly_reactivation.py
def n_assemblies(self) -> int:
    """
    Get the number of detected assemblies.

    Returns
    -------
    int
        Number of detected assemblies.
    """
    if hasattr(self, "patterns"):
        if self.patterns is None:
            return 0
        return self.patterns.shape[0]

plot(plot_members=True, central_line_color='grey', marker_color='k', member_color='#6768ab', line_width=1.25, markersize=4, x_padding=0.2, figsize=None)

Plots basic stem plot to display assembly weights.

Parameters:

plot_members : bool, default True
    Whether to plot assembly members.
central_line_color : str, default 'grey'
    Color of the central line.
marker_color : str, default 'k'
    Color of the markers.
member_color : Union[str, List[str]], default '#6768ab'
    Color of the members.
line_width : float, default 1.25
    Width of the lines.
markersize : float, default 4
    Size of the markers.
x_padding : float, default 0.2
    Padding on the x-axis.
figsize : Optional[Tuple[float, float]], default None
    Size of the figure.

Returns:

Union[Tuple[Figure, ndarray], str, None]
    The figure and axes if successful, otherwise a message or None.

Source code in neuro_py/ensemble/assembly_reactivation.py
def plot(
    self,
    plot_members: bool = True,
    central_line_color: str = "grey",
    marker_color: str = "k",
    member_color: Union[str, list] = "#6768ab",
    line_width: float = 1.25,
    markersize: float = 4,
    x_padding: float = 0.2,
    figsize: Union[tuple, None] = None,
) -> Union[Tuple[plt.Figure, np.ndarray], str, None]:
    """
    Plots basic stem plot to display assembly weights.

    Parameters
    ----------
    plot_members : bool, optional
        Whether to plot assembly members, by default True.
    central_line_color : str, optional
        Color of the central line, by default "grey".
    marker_color : str, optional
        Color of the markers, by default "k".
    member_color : Union[str, List[str]], optional
        Color of the members, by default "#6768ab".
    line_width : float, optional
        Width of the lines, by default 1.25.
    markersize : float, optional
        Size of the markers, by default 4.
    x_padding : float, optional
        Padding on the x-axis, by default 0.2.
    figsize : Optional[Tuple[float, float]], optional
        Size of the figure, by default None.

    Returns
    -------
    Union[Tuple[plt.Figure, np.ndarray], str, None]
        The figure and axes if successful, otherwise a message or None.
    """
    if not hasattr(self, "patterns"):
        return "run get_weights first"
    else:
        if self.patterns is None:
            return None, None
        if plot_members:
            self.find_members()
        if figsize is None:
            figsize = (self.n_assemblies() + 1, np.round(self.n_assemblies() / 2))
        # set up figure with size relative to assembly matrix
        fig, axes = plt.subplots(
            1,
            self.n_assemblies(),
            figsize=figsize,
            sharey=True,
            sharex=True,
        )
        # iter over each assembly and plot the weight per cell
        for i in range(self.n_assemblies()):
            markerline, stemlines, baseline = axes[i].stem(
                self.patterns[i, :], orientation="horizontal"
            )
            markerline._color = marker_color
            baseline._color = central_line_color
            baseline.zorder = -1000
            plt.setp(stemlines, "color", plt.getp(markerline, "color"))
            plt.setp(stemlines, linewidth=line_width)
            plt.setp(markerline, markersize=markersize)

            if plot_members:
                current_pattern = self.patterns[i, :].copy()
                current_pattern[~self.assembly_members[i, :]] = np.nan
                markerline, stemlines, baseline = axes[i].stem(
                    current_pattern, orientation="horizontal"
                )
                if isinstance(
                    member_color, sns.palettes._ColorPalette
                ) or isinstance(member_color, list):
                    markerline._color = member_color[i]
                else:
                    markerline._color = member_color
                baseline._color = "#00000000"
                baseline.zorder = -1000
                plt.setp(stemlines, "color", plt.getp(markerline, "color"))
                plt.setp(stemlines, linewidth=line_width)
                plt.setp(markerline, markersize=markersize)

            axes[i].spines["top"].set_visible(False)
            axes[i].spines["right"].set_visible(False)

        # give room for marker
        axes[0].set_xlim(
            -self.patterns.max() - x_padding, self.patterns.max() + x_padding
        )

        axes[0].set_ylabel("Neurons #")
        axes[0].set_xlabel("Weights (a.u.)")

        return fig, axes
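
For instance (illustrative), each assembly's members can be given their own color by passing a list or seaborn palette with one color per assembly:

>>> fig, axes = assembly_react.plot(
...     member_color=sns.color_palette("husl", assembly_react.n_assemblies())
... )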

restrict_epochs_to_pre_task_post()

Restricts the epochs to the specified epochs

Source code in neuro_py/ensemble/assembly_reactivation.py
def restrict_epochs_to_pre_task_post(self) -> None:
    """
    Restricts the epochs to the specified epochs
    """
    # fetch data
    epoch_df = loading.load_epoch(self.basepath)
    # compress back to back sleep epochs (an issue further up the pipeline)
    epoch_df = compress_repeated_epochs(epoch_df)
    # restrict to pre task post epochs
    idx = find_pre_task_post(epoch_df.environment)
    self.epoch_df = epoch_df[idx[0]]
    # convert to epoch array and add to object
    self.epochs = nel.EpochArray(
        [np.array([self.epoch_df.startTime, self.epoch_df.stopTime]).T],
        label="session_epochs",
        domain=self.time_support,
    )
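
After this call, epoch_df and epochs hold only the pre/task/post rows, so individual epochs can be pulled out by index (sketch; pre, task, post ordering assumed):

>>> assembly_react.load_data()
>>> assembly_react.restrict_epochs_to_pre_task_post()
>>> pre, task, post = (assembly_react.epochs[0],
...                    assembly_react.epochs[1],
...                    assembly_react.epochs[2])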

restrict_to_epoch(epoch)

Restricts the spike data to a specific epoch.

Parameters:

epoch : EpochArray
    The epoch to restrict to (required).
Source code in neuro_py/ensemble/assembly_reactivation.py
def restrict_to_epoch(self, epoch) -> None:
    """
    Restricts the spike data to a specific epoch.

    Parameters
    ----------
    epoch : nel.EpochArray
        The epoch to restrict to.
    """
    self.st_resticted = self.st[epoch]