Skip to content

Preprocessing Reference

Crane provides various preprocessing utilities to prepare neural data for modeling and analysis.

Preprocessors

Spectrogram

Bases: Module

Spectrogram Preprocessor, computes the spectrogram of iEEG data using STFT.

Parameters:

Name Type Description Default
segment_length float

The length of each segment (in seconds) for the STFT.

required
p_overlap float

The proportion of overlap between segments (between 0 and 1).

required
min_frequency float

The minimum frequency (in Hz) to include in the spectrogram.

required
max_frequency float

The maximum frequency (in Hz) to include in the spectrogram.

required
window Literal['hann', 'boxcar']

The type of window to use for the STFT.

required
remove_line_noise bool

Whether to remove line noise frequencies (e.g., 50/60 Hz).

required
output_dim int, default=-1

The dimension of the output features. If -1, the output feature dimension will be the same as the number of frequency bins. Otherwise, they will be projected to this dimension.

-1
Source code in crane/preprocess/spectrogram.py
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
class Spectrogram(nn.Module):
    """
    Spectrogram Preprocessor, computes the spectrogram of iEEG data using STFT.

    All arguments after ``p_overlap`` are keyword-only.

    Args:
        segment_length (float): The length of each segment (in seconds) for the STFT.
        p_overlap (float): The proportion of overlap between segments (between 0 and 1).
        min_frequency (float): The minimum frequency (in Hz) to include in the spectrogram.
        max_frequency (float): The maximum frequency (in Hz) to include in the spectrogram.
        window (Literal["hann", "boxcar"]): The type of window to use for the STFT.
        remove_line_noise (bool): Whether to remove line noise frequencies (e.g., 50/60 Hz).
        output_dim (int, default=-1): The dimension of the output features. If -1, the output feature dimension will be the same as the number of frequency bins. Otherwise, they will be projected to this dimension.
    """

    def __init__(
        self,
        segment_length: float,
        p_overlap: float,
        *,
        min_frequency: float,
        max_frequency: float,
        window: Literal["hann", "boxcar"],
        remove_line_noise: bool,
        output_dim: int = -1,
    ):
        super().__init__()
        self.segment_length = segment_length
        self.p_overlap = p_overlap

        self.min_frequency = min_frequency
        self.max_frequency = max_frequency
        self.window = window
        self.remove_line_noise = remove_line_noise

        self.output_dim = output_dim

        # from https://docs.pytorch.org/docs/stable/generated/torch.fft.rfftfreq.html
        # if n is nperseg, and d is 1/sampling_rate, then f = torch.arange((n + 1) // 2) / (d * n)
        # note: nperseg is always going to be even, so it simplifies to torch.arange(n/2) / n * sampling_rate
        # note: n = sampling_rate * tperseg, so it simplifies to torch.arange(sampling_rate * tperseg / 2) / tperseg
        #    which is a list that goes from 0 to sampling_rate / 2 in increments of sampling_rate / nperseg = 1 / tperseg
        # so max frequency bin is max_frequency * tperseg + 1 (adding one to make the endpoint inclusive)
        self.max_frequency_bin = round(self.max_frequency * self.segment_length + 1)
        self.min_frequency_bin = round(self.min_frequency * self.segment_length)
        self.n_freqs = self.max_frequency_bin - self.min_frequency_bin

        # Transform FFT output to match expected output dimension
        self.output_transform = nn.Identity() if self.output_dim == -1 else nn.Linear(self.n_freqs, self.output_dim)

        if self.remove_line_noise:
            # Per the derivation above, the bin spacing is 1 / segment_length regardless of the
            # sampling rate, so a fixed example rate is enough to compute the bin frequencies here.
            example_sampling_rate = 2048
            nperseg = round(self.segment_length * example_sampling_rate)
            freq_bins = torch.fft.rfftfreq(nperseg, d=1.0 / example_sampling_rate)[
                self.min_frequency_bin : self.max_frequency_bin
            ]  # Calculate frequency bins (in Hz)
            self.line_noise_mask = self._compute_line_noise_mask(
                freq_bins=freq_bins, line_noise_freqs=[50, 60], margin=2.0
            )
        else:
            self.line_noise_mask = None

    @allow_inplace
    def forward(self, data: CraneFeature, *, z_score: bool = True) -> CraneFeature:
        """
        Perform the forward pass of the Spectrogram preprocessor.

        Args:
            data (CraneFeature): Feature whose ``signals`` tensor has shape
                (batch_size, n_electrodes, n_samples) when ``data.batched``, or
                (n_electrodes, n_samples) otherwise. The sampling rate is read from
                ``data.sampling_rate`` (it is not a separate parameter).
            z_score (bool): Whether to apply z-score normalization to the spectrogram. Default is True.

        Returns:
            CraneFeature: A copy of ``data`` whose ``signals`` hold the spectrogram with shape
                (batch_size, n_electrodes, n_timebins, n_freqs or output_dim);
                batch_size is 1 for unbatched input.
        """
        if data.batched:
            batch_size, n_electrodes, _ = data.signals.shape
        else:
            batch_size = 1
            n_electrodes = data.signals.shape[0]

        # Reshape for STFT: fold batch and electrode dims so each row is one 1-D signal
        x = data.signals.reshape(batch_size * n_electrodes, -1)
        x = x.to(dtype=torch.float32)  # Convert to float32 for STFT

        # STFT parameters
        nperseg = round(self.segment_length * data.sampling_rate)
        noverlap = round(self.p_overlap * nperseg)
        hop_length = nperseg - noverlap

        window = {
            "hann": torch.hann_window,
            "boxcar": torch.ones,
        }[self.window](nperseg, device=x.device)

        # Compute STFT
        x = torch.stft(
            x,
            n_fft=nperseg,
            hop_length=hop_length,
            win_length=nperseg,
            window=window,
            return_complex=True,
            normalized=False,
            center=True,
        )

        # Take magnitude
        x = torch.abs(x)

        # Calculate frequency bins (in Hz)
        # These represent the center frequency of each frequency bin in the spectrogram
        freq_bins = torch.fft.rfftfreq(nperseg, d=1.0 / data.sampling_rate, device=x.device)

        # Calculate time bins (in seconds)
        # These represent the center time of each time window in the spectrogram
        # NOTE(review): n_times here is only consumed by the disabled time_bins block
        # below; it is recomputed after the frequency trim.
        n_times = x.shape[2]
        # time_bins = (
        #     torch.arange(n_times, device=x.device, dtype=torch.float32)
        #     * hop_length
        #     / sampling_rate
        # )

        # Trim to max frequency (using a pre-calculated max frequency bin)
        x = x[:, self.min_frequency_bin : self.max_frequency_bin, :]
        freq_bins = freq_bins[self.min_frequency_bin : self.max_frequency_bin]

        # Reshape back
        _, n_freqs, n_times = x.shape
        x = x.reshape(batch_size, n_electrodes, n_freqs, n_times)
        x = x.transpose(2, 3)  # (batch_size, n_electrodes, n_timebins, n_freqs)

        # Z-score normalization: statistics are per (electrode, frequency),
        # taken across the batch and time dimensions
        if z_score:
            x = x - x.mean(dim=[0, 2], keepdim=True)
            x = x / (x.std(dim=[0, 2], keepdim=True) + 1e-5)

        if self.line_noise_mask is not None:  # If removing line noise, set line noise to 0
            # NOTE(review): caches the mask on the current device by mutating module state
            self.line_noise_mask = self.line_noise_mask.to(x.device)
            x = x.masked_fill(self.line_noise_mask.view(1, 1, 1, -1), 0)

        # Transform to match expected output dimension
        x = self.output_transform(x)  # shape: (batch_size, n_electrodes, n_timebins, output_dim)
        x = x.to(dtype=data.signals.dtype)

        out = data.copy()
        out.signals = x
        return out

    def _compute_line_noise_mask(
        self,
        freq_bins: torch.Tensor,
        line_noise_freqs: list | None = None,
        margin: float = 2.0,
    ) -> torch.Tensor:
        """
        Compute a mask for line noise frequencies in the spectrogram.

        Args:
            freq_bins (torch.Tensor): The frequency bins of the spectrogram (in Hz).
            line_noise_freqs (list, optional): The line noise frequencies to mask. If None, defaults to [50, 60].
            margin (float, optional): The margin (in Hz) around the line noise frequencies to include in the mask. Defaults to 2.0.

        Returns:
            torch.Tensor: A boolean mask (same length as ``freq_bins``), True for line-noise bins.
        """
        # Mark every bin within `margin` Hz of any requested line-noise frequency
        line_noise_mask = torch.zeros(freq_bins.shape[0], device=freq_bins.device, dtype=torch.bool)

        if line_noise_freqs is None:
            line_noise_freqs = [50, 60]

        for freq in line_noise_freqs:
            line_noise_mask |= torch.abs(freq_bins - freq) <= margin

        return line_noise_mask

forward(data, *, z_score=True)

Perform the forward pass of the SpectrogramPreprocessor.

Parameters:

Name Type Description Default
data CraneFeature

A CraneFeature whose signals tensor has shape (batch_size, n_electrodes, n_samples), representing the iEEG data.

required
Note: the sampling rate is not a separate parameter; it is read from data.sampling_rate.
z_score bool

Whether to apply z-score normalization to the spectrogram. Default is True.

True

Returns:

Type Description
CraneFeature

A copy of the input feature whose signals hold the processed spectrogram with shape (batch_size, n_electrodes, n_timebins, n_freqs or output_dim).

Source code in crane/preprocess/spectrogram.py
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
@allow_inplace
def forward(self, data: CraneFeature, *, z_score: bool = True) -> CraneFeature:
    """
    Perform the forward pass of the Spectrogram preprocessor.

    Args:
        data (CraneFeature): Feature whose ``signals`` tensor has shape
            (batch_size, n_electrodes, n_samples) when ``data.batched``, or
            (n_electrodes, n_samples) otherwise. The sampling rate is read from
            ``data.sampling_rate`` (it is not a separate parameter).
        z_score (bool): Whether to apply z-score normalization to the spectrogram. Default is True.

    Returns:
        CraneFeature: A copy of ``data`` whose ``signals`` hold the spectrogram with shape
            (batch_size, n_electrodes, n_timebins, n_freqs or output_dim);
            batch_size is 1 for unbatched input.
    """
    if data.batched:
        batch_size, n_electrodes, _ = data.signals.shape
    else:
        batch_size = 1
        n_electrodes = data.signals.shape[0]

    # Reshape for STFT: fold batch and electrode dims so each row is one 1-D signal
    x = data.signals.reshape(batch_size * n_electrodes, -1)
    x = x.to(dtype=torch.float32)  # Convert to float32 for STFT

    # STFT parameters
    nperseg = round(self.segment_length * data.sampling_rate)
    noverlap = round(self.p_overlap * nperseg)
    hop_length = nperseg - noverlap

    window = {
        "hann": torch.hann_window,
        "boxcar": torch.ones,
    }[self.window](nperseg, device=x.device)

    # Compute STFT
    x = torch.stft(
        x,
        n_fft=nperseg,
        hop_length=hop_length,
        win_length=nperseg,
        window=window,
        return_complex=True,
        normalized=False,
        center=True,
    )

    # Take magnitude
    x = torch.abs(x)

    # Calculate frequency bins (in Hz)
    # These represent the center frequency of each frequency bin in the spectrogram
    freq_bins = torch.fft.rfftfreq(nperseg, d=1.0 / data.sampling_rate, device=x.device)

    # Calculate time bins (in seconds)
    # These represent the center time of each time window in the spectrogram
    # NOTE(review): n_times here is only consumed by the disabled time_bins block
    # below; it is recomputed after the frequency trim.
    n_times = x.shape[2]
    # time_bins = (
    #     torch.arange(n_times, device=x.device, dtype=torch.float32)
    #     * hop_length
    #     / sampling_rate
    # )

    # Trim to max frequency (using a pre-calculated max frequency bin)
    x = x[:, self.min_frequency_bin : self.max_frequency_bin, :]
    freq_bins = freq_bins[self.min_frequency_bin : self.max_frequency_bin]

    # Reshape back
    _, n_freqs, n_times = x.shape
    x = x.reshape(batch_size, n_electrodes, n_freqs, n_times)
    x = x.transpose(2, 3)  # (batch_size, n_electrodes, n_timebins, n_freqs)

    # Z-score normalization: statistics are per (electrode, frequency),
    # taken across the batch and time dimensions
    if z_score:
        x = x - x.mean(dim=[0, 2], keepdim=True)
        x = x / (x.std(dim=[0, 2], keepdim=True) + 1e-5)

    if self.line_noise_mask is not None:  # If removing line noise, set line noise to 0
        # NOTE(review): caches the mask on the current device by mutating module state
        self.line_noise_mask = self.line_noise_mask.to(x.device)
        x = x.masked_fill(self.line_noise_mask.view(1, 1, 1, -1), 0)

    # Transform to match expected output dimension
    x = self.output_transform(x)  # shape: (batch_size, n_electrodes, n_timebins, output_dim)
    x = x.to(dtype=data.signals.dtype)

    out = data.copy()
    out.signals = x
    return out

laplacian_rereference(data, remove_non_laplacian=True)

Apply Laplacian rereferencing to a batch of neural data (subtract the mean of the neighbors, as determined by the electrode labels)

Parameters:

Name Type Description Default
data CraneFeature

CraneFeature containing the neural data and channel information

required
remove_non_laplacian bool

if True, remove the non-laplacian electrodes from the data; if false, keep them without rereferencing

True

Returns:

Type Description
CraneFeature

The rereferenced CraneFeature: its signals tensor holds the Laplacian-rereferenced data of shape (batch_size, n_electrodes_rereferenced, n_samples) or (n_electrodes_rereferenced, n_samples), and its channel labels and coordinates are restricted to the rereferenced electrodes (n_electrodes_rereferenced can differ from n_electrodes when remove_non_laplacian is True).

Source code in crane/preprocess/rereference.py
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
@allow_inplace
def laplacian_rereference(data: CraneFeature, remove_non_laplacian: bool = True) -> CraneFeature:
    """
    Apply Laplacian rereferencing to a batch of neural data
    (subtract the mean of the neighbors, as determined by the electrode labels)

    Args:
        data (CraneFeature): CraneFeature containing the neural data and channel information
        remove_non_laplacian (bool): if True, remove the non-laplacian electrodes from the data; if false, keep them without rereferencing

    Returns:
        CraneFeature: The feature with rereferenced ``signals``, and with
            ``channel_labels`` / ``channel_coordinates`` restricted to the
            rereferenced electrodes (possibly fewer than the input electrodes
            when ``remove_non_laplacian`` is True).
    """

    electrode_labels = data.channel_labels

    # _rereference_electrodes expects (batch_size, n_electrodes, n_samples) or (n_electrodes, n_samples)
    rereferenced_data, rereferenced_labels, _ = _rereference_electrodes(
        data.signals, electrode_labels, remove_non_laplacian=remove_non_laplacian
    )

    # Update with rereferenced data; keep coordinates only for channels whose label survived.
    # NOTE(review): assumes electrode labels are unique — confirm upstream.
    label_set = set(rereferenced_labels)
    indices = [i for i, label in enumerate(electrode_labels) if label in label_set]

    data.signals = rereferenced_data
    data.channel_labels = rereferenced_labels
    data.channel_coordinates = data.channel_coordinates[indices]

    return data

subset_electrodes(data, *, max_n_electrodes=None, subset=None)

Subset channel electrodes, consistent across a batch.

Exactly one of max_n_electrodes or subset must be provided.

Parameters:

Name Type Description Default
data CraneFeature

The input feature data.

required
max_n_electrodes int

Maximum number of randomly selected electrodes to subset to.

None
subset Sequence[int | str] | None

Optional list of electrode indices or IDs to subset to. If None, a random subset is chosen.

None

Returns:

Name Type Description
CraneFeature CraneFeature

The subsetted feature data.

Source code in crane/preprocess/subset.py
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
@allow_inplace
def subset_electrodes(
    data: CraneFeature,
    *,
    max_n_electrodes: int | None = None,
    subset: Sequence[int] | Sequence[str] | None = None,
) -> CraneFeature:
    """
    Subset channel electrodes, consistent across a batch.

    Exactly one of `max_n_electrodes` or `subset` must be provided: either a
    random subset of at most `max_n_electrodes` channels is drawn, or the
    explicitly listed channels (integer indices or string labels) are kept.

    Args:
        data (CraneFeature): The input feature data.
        max_n_electrodes (int): Maximum number of randomly selected electrodes to subset to.
        subset (Sequence[int | str] | None): Optional list of electrode indices or IDs to subset to. If None, a random subset is chosen.

    Returns:
        CraneFeature: The subsetted feature data.
    """

    if (max_n_electrodes is None) == (subset is None):
        raise ValueError("Provide exactly one of max_n_electrodes or subset")

    if subset is None:
        # Random selection: keep at most max_n_electrodes channels, chosen without replacement.
        n_channels = data.signals.shape[data.channel_dim]
        if max_n_electrodes is None or n_channels <= max_n_electrodes:
            # Already within the limit — nothing to do.
            return data
        keep = torch.randperm(n_channels, device=data.device)[:max_n_electrodes]
    elif all(isinstance(entry, str) for entry in subset):
        # String labels: resolve each label to its channel index; unknown labels are skipped.
        labels = cast(Sequence[str], subset)
        label_to_index = {label: position for position, label in enumerate(data.channel_labels)}
        keep = torch.tensor(
            [label_to_index[label] for label in labels if label in label_to_index],
            dtype=torch.long,
            device=data.device,
        )
    else:
        # Integer indices: use them directly.
        keep = torch.as_tensor(subset, dtype=torch.long, device=data.device)

    data.signals = torch.index_select(data.signals, data.channel_dim, keep)
    data.channel_coordinates = torch.index_select(data.channel_coordinates, 0, keep)
    data.channel_labels = [data.channel_labels[i] for i in keep.tolist()]

    return data