senselab.audio.tasks.preprocessing.preprocessing

This module implements some utilities for the preprocessing task.

  1"""This module implements some utilities for the preprocessing task."""
  2
  3try:
  4    from speechbrain.augment.time_domain import Resample
  5
  6    SPEECHBRAIN_AVAILABLE = True
  7except ModuleNotFoundError:
  8    SPEECHBRAIN_AVAILABLE = False
  9
 10
 11from typing import List, Optional, Tuple
 12
 13import torch
 14from scipy import signal
 15
 16from senselab.audio.data_structures import Audio
 17
 18
def resample_audios(
    audios: List[Audio],
    resample_rate: int,
    lowcut: Optional[float] = None,
    order: int = 4,
) -> List[Audio]:
    """Resamples a list of audio signals to a given sampling rate.

    Args:
        audios (List[Audio]): List of audio objects to resample.
        resample_rate (int): Target sampling rate.
        lowcut (float, optional): Low-cut frequency for the anti-aliasing IIR filter.
            Defaults to `resample_rate / 2 - 100`.
        order (int, optional): Order of the IIR filter. Defaults to 4.

    Returns:
        List[Audio]: Resampled audio objects.
    """
    if not SPEECHBRAIN_AVAILABLE:
        raise ModuleNotFoundError(
            "`speechbrain` is not installed. "
            "Please install senselab audio dependencies using `pip install 'senselab[audio]'`."
        )

    # The anti-aliasing filter depends only on the target rate, so design it once
    # instead of on every loop iteration.
    if lowcut is None:
        lowcut = resample_rate / 2 - 100
    sos = signal.butter(order, lowcut, btype="low", output="sos", fs=resample_rate)

    resampled_audios = []
    for audio in audios:
        channels = []
        for channel in audio.waveform:
            filtered_channel = torch.from_numpy(signal.sosfiltfilt(sos, channel.numpy()).copy()).float()
            resampler = Resample(orig_freq=audio.sampling_rate, new_freq=resample_rate)
            resampled_channel = resampler(filtered_channel.unsqueeze(0)).squeeze(0)
            channels.append(resampled_channel)

        resampled_waveform = torch.stack(channels)
        resampled_audios.append(
            Audio(
                waveform=resampled_waveform,
                sampling_rate=resample_rate,
                metadata=audio.metadata.copy(),
            )
        )
    return resampled_audios


def downmix_audios_to_mono(audios: List[Audio]) -> List[Audio]:
    """Downmixes a list of Audio objects to mono by averaging all channels.

    Args:
        audios (List[Audio]): A list of Audio objects with a tensor representing the audio waveform.
            Shape: (num_channels, num_samples).

    Returns:
        List[Audio]: The list of audio objects with a mono waveform averaged from all channels.
            Shape: (1, num_samples).
    """
    down_mixed_audios = []
    for audio in audios:
        down_mixed_audios.append(
            Audio(
                waveform=audio.waveform.mean(dim=0, keepdim=True),
                sampling_rate=audio.sampling_rate,
                metadata=audio.metadata.copy(),
            )
        )

    return down_mixed_audios


def select_channel_from_audios(audios: List[Audio], channel_index: int) -> List[Audio]:
    """Selects a specific channel from a list of Audio objects.

    Args:
        audios (List[Audio]): A list of Audio objects with a tensor representing the audio waveform.
            Shape: (num_channels, num_samples).
        channel_index (int): The index of the channel to select.

    Returns:
        List[Audio]: The list of audio objects with the selected channel. Shape: (1, num_samples).
    """
    mono_channel_audios = []
    for audio in audios:
        if audio.waveform.size(0) <= channel_index:  # should consider how much sense negative values make
            raise ValueError(
                f"channel_index {channel_index} is out of range for an audio with "
                f"{audio.waveform.size(0)} channels."
            )

        mono_channel_audios.append(
            Audio(
                waveform=audio.waveform[channel_index, :],
                sampling_rate=audio.sampling_rate,
                metadata=audio.metadata.copy(),
            )
        )
    return mono_channel_audios


def chunk_audios(data: List[Tuple[Audio, Tuple[float, float]]]) -> List[Audio]:
    """Chunks the input audios based on the start and end timestamps.

    Args:
        data: List of tuples containing an Audio object and a tuple with start and end (in seconds) for chunking.

    Returns:
        List of Audios that have been chunked based on the provided timestamps.

    Todo:
        Do we really need both chunk_audios and extract_segments?
    """
    chunked_audios = []

    for audio, timestamps in data:
        start, end = timestamps
        if start < 0:
            raise ValueError("Start time must be greater than or equal to 0.")
        duration = audio.waveform.shape[1] / audio.sampling_rate
        if end > duration:
            raise ValueError(
                f"End time must be less than or equal to the duration of the audio file ({duration} seconds)."
            )
        start_sample = int(start * audio.sampling_rate)
        end_sample = int(end * audio.sampling_rate)
        chunked_waveform = audio.waveform[:, start_sample:end_sample]

        chunked_audios.append(
            Audio(
                waveform=chunked_waveform,
                sampling_rate=audio.sampling_rate,
                metadata=audio.metadata.copy(),
            )
        )
    return chunked_audios


def extract_segments(data: List[Tuple[Audio, List[Tuple[float, float]]]]) -> List[List[Audio]]:
    """Extracts segments from audio files.

    Args:
        data: List of tuples containing an Audio object and a list of tuples with start
            and end (in seconds) for chunking.

    Returns:
        List of lists of Audios that have been chunked based on the provided timestamps.
    """
    extracted_segments = []
    for audio, timestamps in data:
        segments_data = [(audio, ts) for ts in timestamps]
        single_audio_segments = chunk_audios(segments_data)
        extracted_segments.append(single_audio_segments)
    return extracted_segments


def pad_audios(audios: List[Audio], desired_samples: int) -> List[Audio]:
    """Pads audio waveforms to the desired length.

    Args:
        audios: The list of audio objects to be padded.
        desired_samples: The desired length (in samples) for the padded audio.

    Returns:
        The list of Audio objects with waveforms padded to the desired length.
    """
    padded_audios = []
    for audio in audios:
        current_samples = audio.waveform.shape[1]

        if current_samples >= desired_samples:
            # Already long enough: keep it unchanged and move on, rather than
            # returning early and dropping the remaining audios.
            padded_audios.append(audio)
            continue

        padding_needed = desired_samples - current_samples
        padded_waveform = torch.nn.functional.pad(audio.waveform, (0, padding_needed))
        padded_audio = Audio(
            waveform=padded_waveform,
            sampling_rate=audio.sampling_rate,
            metadata=audio.metadata.copy(),
        )
        padded_audios.append(padded_audio)
    return padded_audios


def evenly_segment_audios(
    audios: List[Audio], segment_length: float, pad_last_segment: bool = True
) -> List[List[Audio]]:
    """Segments multiple audio files into evenly sized segments with optional padding for the last segment.

    Args:
        audios: The list of Audio objects to be segmented.
        segment_length: The desired length of each segment in seconds.
        pad_last_segment: Whether to pad the last segment to the full segment length (default is True).

    Returns:
        List of lists of Audio objects, one list of segments per input audio.
    """
    audios_and_segment_timestamps = []
    for audio in audios:
        total_duration = audio.waveform.shape[1] / audio.sampling_rate

        # Create a list of tuples with start and end times for each segment
        timestamps = [
            (i * segment_length, (i + 1) * segment_length) for i in range(int(total_duration // segment_length))
        ]
        if total_duration % segment_length != 0:
            timestamps.append((total_duration - (total_duration % segment_length), total_duration))
        audios_and_segment_timestamps.append((audio, timestamps))

    # Extract segments for all audios, not just the last one from the loop above.
    audio_segment_lists = extract_segments(audios_and_segment_timestamps)

    for i, audio_segment_list in enumerate(audio_segment_lists):
        if pad_last_segment and len(audio_segment_list) > 0:
            last_segment = audio_segment_list[-1]
            segment_samples = int(segment_length * last_segment.sampling_rate)
            if last_segment.waveform.shape[1] < segment_samples:
                audio_segment_lists[i][-1] = pad_audios([last_segment], segment_samples)[0]

    return audio_segment_lists


def concatenate_audios(audios: List[Audio]) -> Audio:
    """Concatenates all audios in the list, ensuring they have the same sampling rate and shape.

    Args:
        audios: List of Audio objects to concatenate.

    Returns:
        A single Audio object that is the concatenation of all input audios.

    Raises:
        ValueError: If the audios do not all have the same sampling rate or shape.
    """
    if not audios:
        raise ValueError("The input list is empty. Please provide a list with at least one Audio object.")

    sampling_rate = audios[0].sampling_rate
    num_channels = audios[0].waveform.shape[0]

    for audio in audios:
        if audio.sampling_rate != sampling_rate:
            raise ValueError("All audios must have the same sampling rate to concatenate.")
        if audio.waveform.shape[0] != num_channels:
            raise ValueError("All audios must have the same number of channels (mono or stereo) to concatenate.")

    concatenated_waveform = torch.cat([audio.waveform.cpu() for audio in audios], dim=1)

    # TODO: do we want to concatenate metadata? TBD

    return Audio(
        waveform=concatenated_waveform,
        sampling_rate=sampling_rate,
    )
def resample_audios( audios: List[senselab.audio.data_structures.audio.Audio], resample_rate: int, lowcut: Optional[float] = None, order: int = 4) -> List[senselab.audio.data_structures.audio.Audio]:

Resamples a list of audio signals to a given sampling rate.

Arguments:
  • audios (List[Audio]): List of audio objects to resample.
  • resample_rate (int): Target sampling rate.
  • lowcut (float, optional): Low-cut frequency for the anti-aliasing IIR filter. Defaults to resample_rate / 2 - 100.
  • order (int, optional): Order of the IIR filter. Defaults to 4.
Returns:
  List[Audio]: Resampled audio objects.
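
A minimal usage sketch (requires the `senselab[audio]` extras, since resampling uses speechbrain; the random waveform is placeholder data, while the constructor and import paths follow the source above):

import torch

from senselab.audio.data_structures import Audio
from senselab.audio.tasks.preprocessing.preprocessing import resample_audios

# One second of placeholder stereo noise at 44.1 kHz.
audio = Audio(waveform=torch.randn(2, 44100), sampling_rate=44100)

resampled = resample_audios([audio], resample_rate=16000)
print(resampled[0].sampling_rate)   # 16000
print(resampled[0].waveform.shape)  # approximately torch.Size([2, 16000])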

def downmix_audios_to_mono( audios: List[senselab.audio.data_structures.audio.Audio]) -> List[senselab.audio.data_structures.audio.Audio]:

Downmixes a list of Audio objects to mono by averaging all channels.

Arguments:
  • audios (List[Audio]): A list of Audio objects with a tensor representing the audio waveform. Shape: (num_channels, num_samples).
Returns:
  List[Audio]: The list of audio objects with a mono waveform averaged from all channels. Shape: (1, num_samples).
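
For example (a sketch; the random waveform is placeholder data):

import torch

from senselab.audio.data_structures import Audio
from senselab.audio.tasks.preprocessing.preprocessing import downmix_audios_to_mono

# Placeholder stereo audio: 1 second at 16 kHz.
stereo = Audio(waveform=torch.randn(2, 16000), sampling_rate=16000)

mono = downmix_audios_to_mono([stereo])[0]
print(mono.waveform.shape)  # torch.Size([1, 16000]): the mean of the two channels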

def select_channel_from_audios( audios: List[senselab.audio.data_structures.audio.Audio], channel_index: int) -> List[senselab.audio.data_structures.audio.Audio]:

Selects a specific channel from a list of Audio objects.

Arguments:
  • audios (List[Audio]): A list of Audio objects with a tensor representing the audio waveform. Shape: (num_channels, num_samples).
  • channel_index (int): The index of the channel to select.
Returns:
  List[Audio]: The list of audio objects with the selected channel. Shape: (1, num_samples).
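
For example (a sketch with placeholder data):

import torch

from senselab.audio.data_structures import Audio
from senselab.audio.tasks.preprocessing.preprocessing import select_channel_from_audios

# Placeholder stereo audio with channels 0 and 1.
stereo = Audio(waveform=torch.randn(2, 16000), sampling_rate=16000)

left_only = select_channel_from_audios([stereo], channel_index=0)
# Passing channel_index=2 would raise a ValueError, since only channels 0 and 1 exist.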

def chunk_audios( data: List[Tuple[senselab.audio.data_structures.audio.Audio, Tuple[float, float]]]) -> List[senselab.audio.data_structures.audio.Audio]:

Chunks the input audios based on the start and end timestamps.

Arguments:
  • data: List of tuples containing an Audio object and a tuple with start and end (in seconds) for chunking.
Returns:
  List of Audios that have been chunked based on the provided timestamps.

Todo:
  Do we really need both chunk_audios and extract_segments?
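
A minimal usage sketch (placeholder data):

import torch

from senselab.audio.data_structures import Audio
from senselab.audio.tasks.preprocessing.preprocessing import chunk_audios

# Two seconds of placeholder mono audio at 16 kHz.
audio = Audio(waveform=torch.randn(1, 32000), sampling_rate=16000)

chunks = chunk_audios([(audio, (0.5, 1.5))])
print(chunks[0].waveform.shape)  # torch.Size([1, 16000]): the audio between 0.5 s and 1.5 s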

def extract_segments( data: List[Tuple[senselab.audio.data_structures.audio.Audio, List[Tuple[float, float]]]]) -> List[List[senselab.audio.data_structures.audio.Audio]]:

Extracts segments from audio files.

Arguments:
  • data: List of tuples containing an Audio object and a list of tuples with start and end (in seconds) for chunking.
Returns:
  List of lists of Audios that have been chunked based on the provided timestamps.
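
For example (a sketch with placeholder data):

import torch

from senselab.audio.data_structures import Audio
from senselab.audio.tasks.preprocessing.preprocessing import extract_segments

# Three seconds of placeholder mono audio at 16 kHz.
audio = Audio(waveform=torch.randn(1, 48000), sampling_rate=16000)

segments = extract_segments([(audio, [(0.0, 1.0), (1.5, 2.5)])])
print(len(segments), len(segments[0]))  # 1 input audio, 2 segments of 1 second each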

def pad_audios( audios: List[senselab.audio.data_structures.audio.Audio], desired_samples: int) -> List[senselab.audio.data_structures.audio.Audio]:

Pads audio waveforms to the desired length.

Arguments:
  • audios: The list of audio objects to be padded.
  • desired_samples: The desired length (in samples) for the padded audio.
Returns:
  The list of Audio objects with waveforms padded to the desired length.
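
For example (a sketch with placeholder data):

import torch

from senselab.audio.data_structures import Audio
from senselab.audio.tasks.preprocessing.preprocessing import pad_audios

# Half a second of placeholder mono audio at 16 kHz.
short = Audio(waveform=torch.randn(1, 8000), sampling_rate=16000)

padded = pad_audios([short], desired_samples=16000)[0]
print(padded.waveform.shape)  # torch.Size([1, 16000]): zeros appended at the end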

def evenly_segment_audios( audios: List[senselab.audio.data_structures.audio.Audio], segment_length: float, pad_last_segment: bool = True) -> List[List[senselab.audio.data_structures.audio.Audio]]:

Segments multiple audio files into evenly sized segments with optional padding for the last segment.

Arguments:
  • audios: The list of Audio objects to be segmented.
  • segment_length: The desired length of each segment in seconds.
  • pad_last_segment: Whether to pad the last segment to the full segment length (default is True).
Returns:
  List of lists of Audio objects, one list of segments per input audio.
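
For example (a sketch with placeholder data):

import torch

from senselab.audio.data_structures import Audio
from senselab.audio.tasks.preprocessing.preprocessing import evenly_segment_audios

# Ten seconds of placeholder mono audio at 16 kHz.
audio = Audio(waveform=torch.randn(1, 160000), sampling_rate=16000)

segment_lists = evenly_segment_audios([audio], segment_length=3.0, pad_last_segment=True)
# 4 segments: three full 3-second segments plus a 1-second remainder padded to 3 seconds.
print(len(segment_lists[0]))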

def concatenate_audios( audios: List[senselab.audio.data_structures.audio.Audio]) -> senselab.audio.data_structures.audio.Audio:

Concatenates all audios in the list, ensuring they have the same sampling rate and shape.

Arguments:
  • audios: List of Audio objects to concatenate.
Returns:
  A single Audio object that is the concatenation of all input audios.

Raises:
  • ValueError: If the audios do not all have the same sampling rate or shape.
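
For example (a sketch with placeholder data):

import torch

from senselab.audio.data_structures import Audio
from senselab.audio.tasks.preprocessing.preprocessing import concatenate_audios

# Two placeholder mono audios at the same sampling rate.
a = Audio(waveform=torch.randn(1, 16000), sampling_rate=16000)
b = Audio(waveform=torch.randn(1, 8000), sampling_rate=16000)

combined = concatenate_audios([a, b])
print(combined.waveform.shape)  # torch.Size([1, 24000])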