neuro_py.plotting.replay
make_replay(nx=50, ny=50, T=15, kind='linear', seed=42)
kind : "linear", "curved", or "diffuse"
- "linear": straight trajectory across the arena
- "curved": arc-shaped trajectory
- "diffuse": wide, uncertain posteriors (stress-tests saturation)
Notes
- This is a helper function that generates demo replay matrices for the tutorial examples in the `plot_2d_replay` docstring.
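For readers who want to build their own test inputs, here is a minimal sketch of how a "linear" demo replay could be constructed. It is not the neuro_py implementation: the function name `make_linear_demo_replay`, the Gaussian bump shape, the `sigma` width, and the small noise floor are all assumptions made for illustration.

```python
import numpy as np


def make_linear_demo_replay(nx=50, ny=50, T=15, sigma=2.0, seed=42):
    """Hypothetical stand-in for a "linear" demo replay (not neuro_py's code).

    A Gaussian bump moves in a straight line across the arena; each time
    slice is normalized to sum to 1 so it behaves like a decoded posterior.
    """
    rng = np.random.default_rng(seed)
    # xs[i, j] == i and ys[i, j] == j, matching the (nx, ny) layout.
    xs, ys = np.meshgrid(np.arange(nx), np.arange(ny), indexing="ij")
    # Straight path from near one corner to near the opposite corner.
    cx = np.linspace(0.1 * nx, 0.9 * nx, T)
    cy = np.linspace(0.1 * ny, 0.9 * ny, T)
    replay = np.empty((nx, ny, T))
    for t in range(T):
        bump = np.exp(-((xs - cx[t]) ** 2 + (ys - cy[t]) ** 2) / (2 * sigma**2))
        # Small non-negative noise floor, like an imperfect decoder.
        bump += 1e-3 * rng.random((nx, ny))
        replay[:, :, t] = bump / bump.sum()
    return replay
```

The result has the same `(nx, ny, T)` layout that `plot_2d_replay` expects, with every time slice non-negative and summing to 1.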
Source code in neuro_py/plotting/replay.py
plot_2d_replay(replay_matrix, ax=None, cmap='cool', extent=None, saturation=3, percentile_threshold=99, abs_threshold=None, per_frame_alpha_normalization=True)
Plot a single 2D replay event.
Each time bin is drawn as a separate RGBA layer, and matplotlib alpha-composites the layers in drawing order. Color encodes elapsed time within the replay (early→late following the chosen colormap). Alpha is power-scaled from frame probabilities using either per-frame normalization (default) or global normalization across all frames.
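The layering idea above can be sketched with NumPy and matplotlib. This is a hypothetical illustration, not neuro_py's code: `replay_rgba_layers` and `draw_layers` are made-up names, thresholding is left out, alpha uses only per-frame normalization, and the transpose assumes the first array axis is x.

```python
import numpy as np
import matplotlib.pyplot as plt
from matplotlib import colormaps


def replay_rgba_layers(replay, cmap="cool", saturation=3.0):
    """Build one RGBA image per time bin (hypothetical helper).

    replay: array of shape (nx, ny, T) with non-negative probabilities.
    Returns an array of shape (T, nx, ny, 4).
    """
    nx, ny, T = replay.shape
    colormap = colormaps[cmap] if isinstance(cmap, str) else cmap
    layers = np.zeros((T, nx, ny, 4))
    for t in range(T):
        frame = replay[:, :, t]
        # Color encodes elapsed time: bin 0 maps to the start of the colormap,
        # the last bin to its end.
        layers[t, :, :, :3] = colormap(t / max(T - 1, 1))[:3]
        # Per-frame power-scaled alpha; values stay in [0, 1].
        peak = frame.max()
        if peak > 0:
            layers[t, :, :, 3] = (frame / peak) ** (1.0 / saturation)
    return layers


def draw_layers(layers, ax=None):
    """Draw layers in time order so later bins are composited over earlier ones."""
    if ax is None:
        _, ax = plt.subplots()
    for layer in layers:
        # imshow expects (rows, cols, 4), so swap to (ny, nx, 4).
        ax.imshow(layer.transpose(1, 0, 2), origin="lower")
    return ax
```

Because each layer has a single solid color and varies only in alpha, overlapping bins blend toward the color of the later bin wherever its alpha is high, which is what makes the temporal progression visible.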
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `replay_matrix` | `ndarray`, shape `(nx, ny, T)` | Decoded probability distributions over space. Each `[:, :, t]` slice should be non-negative and ideally sum to ~1. | *required* |
| `ax` | `Axes` or None | Axes to draw on. If None, a new figure is created. | `None` |
| `cmap` | `str` or `Colormap` | Colormap used to color time bins. The default `"cool"` runs from cyan (early) to magenta (late). | `'cool'` |
| `extent` | array-like `[xmin, xmax, ymin, ymax]` | Spatial extent in data coordinates. If None, bin indices are used. | `None` |
| `saturation` | `float`, > 0 | Controls how much of the probability distribution is visible via `alpha = (p / norm_max) ** (1 / saturation)`, where `norm_max` is `frame.max()` when `per_frame_alpha_normalization=True` and the global max across all frames when it is False. `saturation=1` → exponent 1: alpha scales linearly with probability. `saturation<1` → exponent >1: low-probability regions fade faster (sparse). `saturation>1` → exponent <1: low-probability regions are boosted (dense/flat). | `3` |
| `percentile_threshold` | `float` | Per-frame values below this percentile are zeroed out. Combined with `abs_threshold`: a value must pass both thresholds to be drawn. | `99` |
| `abs_threshold` | `float` or None | Absolute floor applied alongside `percentile_threshold`. Prevents near-zero values in sparse frames from leaking through. | `None` |
| `per_frame_alpha_normalization` | `bool` | If True (default), alpha is normalized by each frame's max. If False, alpha is normalized by the global max across all frames, which preserves relative intensity across frames but may make low-probability frames very faint. | `True` |
Returns:
| Type | Description |
|---|---|
| `(fig, ax)` | Matplotlib figure and axes containing the replay plot. |
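The thresholding and alpha rules in the parameter table can be written out as a short NumPy function. This is a sketch under stated assumptions: `compute_alpha` is a hypothetical name, values exactly at the percentile are kept, and the global max is taken after thresholding; the actual `plot_2d_replay` code may differ in these details.

```python
import numpy as np


def compute_alpha(replay, saturation=3.0, percentile_threshold=99,
                  abs_threshold=None, per_frame_alpha_normalization=True):
    """Hypothetical sketch of the documented thresholding and alpha scaling.

    replay: array of shape (nx, ny, T). Returns alpha values of the same shape.
    """
    if saturation <= 0:
        raise ValueError("saturation must be > 0")
    p = np.asarray(replay, dtype=float).copy()
    T = p.shape[2]
    for t in range(T):
        frame = p[:, :, t]  # a view, so edits below modify p in place
        # A value survives only if it clears the per-frame percentile...
        keep = frame >= np.percentile(frame, percentile_threshold)
        # ...and, when given, the absolute floor as well.
        if abs_threshold is not None:
            keep &= frame >= abs_threshold
        frame[~keep] = 0.0
    global_max = p.max()
    alpha = np.zeros_like(p)
    for t in range(T):
        frame = p[:, :, t]
        norm_max = frame.max() if per_frame_alpha_normalization else global_max
        if norm_max > 0:
            # alpha = (p / norm_max) ** (1 / saturation)
            alpha[:, :, t] = (frame / norm_max) ** (1.0 / saturation)
    return alpha
```

As a worked check of the saturation rule: a bin at 25% of the normalizing maximum gets alpha 0.25 with `saturation=1`, 0.5 with `saturation=2`, and 0.0625 with `saturation=0.5`.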
Examples:
Figure 1: three replay types
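>>> import matplotlib.pyplot as plt
>>> from neuro_py.plotting.replay import make_replay, plot_2d_replay
>>> fig, axes = plt.subplots(1, 3, figsize=(12, 4))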
>>> for ax, kind, title in zip(
... axes, ["linear", "curved", "diffuse"], ["Linear", "Curved", "Diffuse (wide)"]
... ):
... plot_2d_replay(make_replay(kind=kind), ax=ax, saturation=0.5)
... ax.set_title(title)
>>> fig.suptitle("Replay types", y=1.02)
>>> fig.tight_layout()
>>> plt.show()
Figure 2: saturation comparison
>>> mat = make_replay(kind="curved")
>>> sat_values = [0.1, 0.5, 1.0, 2.0]
>>> fig, axes = plt.subplots(1, 4, figsize=(14, 4))
>>> for ax, sat in zip(axes, sat_values):
... plot_2d_replay(mat, ax=ax, saturation=sat)
... ax.set_title(f"saturation={sat}")
>>> fig.suptitle("Saturation comparison", y=1.02)
>>> fig.tight_layout()
>>> plt.show()
Figure 3: per-frame vs global alpha normalization
>>> fig, axes = plt.subplots(1, 2, figsize=(8, 4))
>>> plot_2d_replay(mat, ax=axes[0], per_frame_alpha_normalization=True)
>>> axes[0].set_title("Per-frame normalization")
>>> plot_2d_replay(mat, ax=axes[1], per_frame_alpha_normalization=False)
>>> axes[1].set_title("Global normalization")
>>> fig.tight_layout()
>>> plt.show()
Source code in neuro_py/plotting/replay.py