senselab.audio.tasks.plotting.plotting

This module contains functions for plotting audio-related data.

  1"""This module contains functions for plotting audio-related data."""
  2
  3from typing import Any, Dict, Tuple, Union
  4
  5import matplotlib.pyplot as plt
  6import numpy as np
  7import torch
  8from matplotlib import rc_context
  9from matplotlib.figure import Figure
 10from mpl_toolkits.axes_grid1 import make_axes_locatable
 11
 12from senselab.audio.data_structures import Audio
 13from senselab.utils.data_structures import logger
 14
 15# ---------------------------
 16# Plot context & scaling
 17# ---------------------------
 18
# Plot sizing context accepted by the public plotting functions: a preset name
# ("auto", "small"/"paper", "medium"/"notebook", "large"/"talk", case-insensitive)
# or a numeric scale factor. _resolve_scale maps each value to a float scale.
_Context = Union[str, float]
 20
 21
 22def _detect_screen_resolution() -> Tuple[int, int]:
 23    """Best-effort screen resolution detection. Falls back to 1920x1080."""
 24    # Try TkAgg
 25    try:
 26        mgr = plt.get_current_fig_manager()
 27        win = getattr(mgr, "window", None)
 28        if win is not None and hasattr(win, "winfo_screenwidth"):
 29            return int(win.winfo_screenwidth()), int(win.winfo_screenheight())
 30    except Exception:
 31        pass
 32    # Try Qt
 33    try:
 34        from PyQt5 import QtWidgets  # type: ignore
 35
 36        app = QtWidgets.QApplication.instance() or QtWidgets.QApplication([])
 37        screen = app.primaryScreen()
 38        size = screen.size()
 39        return int(size.width()), int(size.height())
 40    except Exception:
 41        pass
 42    # Fallback
 43    return 1920, 1080
 44
 45
 46def _context_scale_from_resolution() -> float:
 47    """Map screen width → a sensible scale factor."""
 48    width, _ = _detect_screen_resolution()
 49    # Simple, readable buckets
 50    if width <= 1366:
 51        return 0.9
 52    if width <= 1920:
 53        return 1.0
 54    if width <= 2560:
 55        return 1.25
 56    if width <= 3840:
 57        return 1.5
 58    return 2.0
 59
 60
 61def _resolve_scale(context: _Context) -> float:
 62    if isinstance(context, (int, float)):
 63        return float(context)
 64    ctx = str(context).lower()
 65    if ctx == "auto":
 66        return _context_scale_from_resolution()
 67    if ctx in ("paper", "small"):
 68        return 0.9
 69    if ctx in ("notebook", "medium"):
 70        return 1.0
 71    if ctx in ("talk", "large"):
 72        return 1.3
 73    # Default
 74    return 1.0
 75
 76
 77def _rc_for_scale(scale: float) -> Dict[str, Any]:
 78    """Return rcParams tuned for the given scale (seaborn-like)."""
 79    base = 10.0 * scale
 80    return {
 81        "font.size": base,
 82        "axes.titlesize": base * 1.2,
 83        "axes.labelsize": base,
 84        "xtick.labelsize": base * 0.9,
 85        "ytick.labelsize": base * 0.9,
 86        "legend.fontsize": base * 0.95,
 87        "lines.linewidth": 1.25 * scale,
 88        "grid.linewidth": 0.8 * scale,
 89        "axes.linewidth": 0.8 * scale,
 90        "figure.titlesize": base * 1.3,
 91    }
 92
 93
 94# ---------------------------
 95# Helpers
 96# ---------------------------
 97
 98
 99def _power_to_db(spectrogram: np.ndarray, ref: float = 1.0, amin: float = 1e-10, top_db: float = 80.0) -> np.ndarray:
100    """Converts a power spectrogram (amplitude squared) to decibel (dB) units."""
101    S = np.asarray(spectrogram)
102
103    if amin <= 0:
104        raise ValueError("amin must be strictly positive")
105
106    if np.issubdtype(S.dtype, np.complexfloating):
107        logger.warning(
108            "_power_to_db was called on complex input so phase information will be discarded. "
109            "To suppress this warning, call power_to_db(np.abs(D)**2) instead.",
110            stacklevel=2,
111        )
112        magnitude = np.abs(S)
113    else:
114        magnitude = S
115
116    ref_value = ref(magnitude) if callable(ref) else np.abs(ref)
117    log_spec: np.ndarray = 10.0 * np.log10(np.maximum(amin, magnitude))
118    log_spec -= 10.0 * np.log10(np.maximum(amin, ref_value))
119
120    if top_db is not None:
121        if top_db < 0:
122            raise ValueError("top_db must be non-negative")
123        log_spec = np.maximum(log_spec, log_spec.max() - top_db)
124
125    return log_spec
126
127
128# ---------------------------
129# Public API
130# ---------------------------
131
132
def plot_waveform(
    audio: Audio,
    title: str = "Waveform",
    fast: bool = False,
    *,
    context: _Context = "auto",
    figsize: Tuple[float, float] | None = None,
) -> Figure:
    """Plot the time-domain waveform of an `Audio` object and return the Figure.

    The plot is automatically scaled for readability using a *context* scale
    (similar to seaborn). Use `fast=True` to lightly decimate the signal for
    quicker rendering on very long waveforms; the time axis still spans the
    full duration of the audio.

    Args:
        audio (Audio):
            Input audio containing `.waveform` (shape `[C, T]`) and `.sampling_rate`.
        title (str, optional):
            Figure title. Defaults to `"Waveform"`.
        fast (bool, optional):
            If `True`, plots every 10th sample for speed. Defaults to `False`.
        context (_Context, optional):
            Size preset or numeric scale. Accepted values:
              * `"auto"` (detect from screen), `"small"`, `"medium"`, `"large"`,
              * or a float scale factor (e.g., `1.25`). Defaults to `"auto"`.
        figsize (tuple[float, float] | None, optional):
            Base `(width, height)` in inches **before** context scaling.
            Defaults to `(12, max(2 × channels, 2.5))`.

    Returns:
        matplotlib.figure.Figure: The created figure (also displayed).

    Example:
        >>> from pathlib import Path
        >>> from senselab.audio.data_structures import Audio
        >>> a1 = Audio(filepath=Path("sample1.wav").resolve())
        >>> fig = plot_waveform(a1, title="Sample 1", fast=True, context="medium")
        >>> # fig.savefig("waveform.png")  # optional
    """
    waveform = audio.waveform
    sample_rate = audio.sampling_rate

    decimation = 1
    if fast:
        decimation = 10
        waveform = waveform[..., ::decimation]

    num_channels, num_frames = waveform.shape
    # Sample k of the (possibly decimated) signal sits at k * decimation / sample_rate
    # seconds. Deriving the axis from the decimated frame count alone would shrink
    # the time range 10x in fast mode.
    time_axis = np.arange(num_frames) * (decimation / sample_rate)

    scale = _resolve_scale(context)
    rc = _rc_for_scale(scale)
    if figsize is None:
        base = (12.0, max(2.0 * num_channels, 2.5))
    else:
        base = figsize
    scaled_size = (base[0] * scale, base[1] * scale)

    with rc_context(rc):
        # squeeze=False always yields a 2D array of axes, even for a single channel.
        fig, axes_grid = plt.subplots(num_channels, 1, figsize=scaled_size, sharex=True, squeeze=False)
        axes = axes_grid[:, 0]
        for c, ax in enumerate(axes):
            # detach() so tensors that track gradients can still be converted.
            ax.plot(time_axis, waveform[c].detach().cpu().numpy())
            ax.set_ylabel(f"Ch {c + 1}")
            ax.grid(True, alpha=0.3)
        fig.suptitle(title)
        axes[-1].set_xlabel("Time [s]")
        fig.tight_layout(rect=(0, 0, 1, 0.96))
        plt.show(block=False)
        return fig
202
203
def plot_specgram(
    audio: Audio,
    mel_scale: bool = False,
    title: str = "Spectrogram",
    *,
    context: _Context = "auto",
    figsize: Tuple[float, float] | None = None,
    **spect_kwargs: Any,  # noqa: ANN401
) -> Figure:
    """Plot a (mel-)spectrogram for a **mono** `Audio` object and return the Figure.

    Internally calls senselab's torchaudio-based extractors:
    `extract_spectrogram_from_audios` or `extract_mel_spectrogram_from_audios`.
    The function expects a 2D spectrogram `[freq_bins, time_frames]`; multi-channel
    inputs should be downmixed beforehand.

    Args:
        audio (Audio):
            Input **mono** audio. If multi-channel, downmix first.
        mel_scale (bool, optional):
            If `True`, plots a mel spectrogram; otherwise linear frequency. Defaults to `False`.
        title (str, optional):
            Figure title. Defaults to `"Spectrogram"`.
        context (_Context, optional):
            Size preset or numeric scale (`"auto"`, `"small"`, `"medium"`, `"large"`, or float).
            Defaults to `"auto"`.
        figsize (tuple[float, float] | None, optional):
            Base `(width, height)` in inches **before** context scaling. Defaults to `(10, 4)`.
        **spect_kwargs:
            Passed to the underlying extractor (e.g., `n_fft=1024`, `hop_length=256`,
            `n_mels=80`, `win_length=1024`, `f_min=0`, `f_max=None`).

    Returns:
        matplotlib.figure.Figure: The created figure (also displayed).

    Raises:
        ValueError: If spectrogram extraction fails, contains NaNs, or the result is not 2D.

    Example (linear spectrogram):
        >>> from pathlib import Path
        >>> from senselab.audio.data_structures import Audio
        >>> a1 = Audio(filepath=Path("sample1.wav").resolve())
        >>> fig = plot_specgram(a1, mel_scale=False, n_fft=1024, hop_length=256)
        >>> # fig.savefig("spec.png")

    Example (mel spectrogram):
        >>> from pathlib import Path
        >>> from senselab.audio.data_structures import Audio
        >>> a1 = Audio(filepath=Path("sample1.wav").resolve())
        >>> fig = plot_specgram(a1, mel_scale=True, n_mels=80, n_fft=1024, hop_length=256)
    """
    # Extract the spectrogram
    if mel_scale:
        from senselab.audio.tasks.features_extraction.torchaudio import (
            extract_mel_spectrogram_from_audios,
        )

        spectrogram = extract_mel_spectrogram_from_audios([audio], **spect_kwargs)[0]["mel_spectrogram"]
        y_axis_label = "Mel frequency (bins)"
    else:
        from senselab.audio.tasks.features_extraction.torchaudio import (
            extract_spectrogram_from_audios,
        )

        spectrogram = extract_spectrogram_from_audios([audio], **spect_kwargs)[0]["spectrogram"]
        y_axis_label = "Frequency [Hz]"

    # ---- Guard against invalid/short-audio outputs (must be exactly this phrase)
    if not torch.is_tensor(spectrogram):
        raise ValueError("Spectrogram extraction failed")
    if spectrogram.ndim == 0 or spectrogram.numel() == 0:
        raise ValueError("Spectrogram extraction failed")
    if spectrogram.dtype.is_floating_point and torch.isnan(spectrogram).any():
        raise ValueError("Spectrogram extraction failed")

    if spectrogram.dim() != 2:
        # A single message string: passing two positional args to ValueError would
        # render the message as a tuple.
        raise ValueError(
            f"Spectrogram must be a 2D tensor. Got shape: {spectrogram.shape}. "
            "Please make sure the input audio is mono."
        )

    num_freq_bins = spectrogram.size(0)

    # Time axis in seconds
    duration_sec = audio.waveform.size(-1) / audio.sampling_rate
    time_axis_start = 0.0
    time_axis_end = float(duration_sec)

    # Frequency axis: mel bins are labelled by index, linear bins span 0..Nyquist.
    if mel_scale:
        freq_start, freq_end = 0.0, float(num_freq_bins - 1)
    else:
        freq_start, freq_end = 0.0, float(audio.sampling_rate / 2)

    scale = _resolve_scale(context)
    rc = _rc_for_scale(scale)
    if figsize is None:
        base = (10.0, 4.0)
    else:
        base = figsize
    scaled_size = (base[0] * scale, base[1] * scale)

    with rc_context(rc):
        # Explicit figure/axes objects instead of pyplot's implicit "current axes".
        fig, ax = plt.subplots(figsize=scaled_size)
        im = ax.imshow(
            _power_to_db(spectrogram.detach().cpu().numpy()),
            aspect="auto",
            origin="lower",
            extent=(time_axis_start, time_axis_end, freq_start, freq_end),
            cmap="viridis",
        )
        fig.colorbar(im, ax=ax, label="Magnitude (dB)")
        ax.set_title(title)
        ax.set_ylabel(y_axis_label)
        ax.set_xlabel("Time [s]")
        fig.tight_layout()
        plt.show(block=False)
        return fig
324
325
def plot_waveform_and_specgram(
    audio: Audio,
    *,
    title: str = "Waveform + Spectrogram",
    mel_scale: bool = False,
    fast_wave: bool = False,
    context: "_Context" = "auto",
    figsize: Tuple[float, float] | None = None,
    **spect_kwargs: Any,  # noqa: ANN401  # forwarded to spectrogram extraction
) -> Figure:
    """Stacked layout: waveform (top) and **mono** spectrogram (bottom). Returns the Figure.

    The waveform can be drawn in a faster, lightly decimated mode for long signals.
    Spectrogram extraction is delegated to senselab's torchaudio-based utilities
    and requires mono input.

    Args:
        audio (Audio):
            Input audio. **Spectrogram requires mono**; downmix multi-channel first.
        title (str, optional):
            Overall figure title. Defaults to `"Waveform + Spectrogram"`.
        mel_scale (bool, optional):
            If `True`, bottom panel is a mel spectrogram; otherwise linear frequency. Defaults to `False`.
        fast_wave (bool, optional):
            If `True`, the waveform panel plots every 10th sample for speed. Defaults to `False`.
        context (_Context, optional):
            Size preset or numeric scale (`"auto"`, `"small"`, `"medium"`, `"large"`, or float).
            Defaults to `"auto"`.
        figsize (tuple[float, float] | None, optional):
            Base `(width, height)` in inches **before** context scaling. Defaults to `(12, 6)`.
        **spect_kwargs:
            Forwarded to the underlying spectrogram extractor (e.g., `n_fft`, `hop_length`, `n_mels`).

    Returns:
        matplotlib.figure.Figure: The created figure (also displayed).

    Raises:
        ValueError: If audio is not mono, or spectrogram extraction fails.

    Example:
        >>> from pathlib import Path
        >>> from senselab.audio.data_structures import Audio
        >>> a1 = Audio(filepath=Path("sample1.wav").resolve())
        >>> fig = plot_waveform_and_specgram(
        ...     a1,
        ...     mel_scale=True,
        ...     fast_wave=True,
        ...     context="large",
        ...     n_fft=1024,
        ...     hop_length=256,
        ...     n_mels=80,
        ... )
        >>> # fig.savefig("wave_plus_mel.png")
    """
    # ---- Guardrail first: spectrogram plotting requires mono input, so nothing
    # else is computed for unsupported audio.
    if audio.waveform.shape[0] != 1:
        raise ValueError("Only mono audio is supported for spectrogram plotting")

    # ---- Core timing info from ORIGINAL (non-decimated) data
    sr = audio.sampling_rate
    duration_sec = int(audio.waveform.size(-1)) / sr
    t0, t1 = 0.0, float(duration_sec)

    # ---- Prepare waveform (optionally decimated for speed)
    decimation = 10 if fast_wave else 1
    waveform = audio.waveform[..., ::decimation]
    _, num_frames = waveform.shape
    # Sample k of the decimated signal sits at k * decimation / sr seconds.
    time_axis = np.arange(num_frames) * (decimation / sr)

    # ---- Spectrogram (2D tensor: [freq_bins, time_frames])
    if mel_scale:
        from senselab.audio.tasks.features_extraction.torchaudio import (
            extract_mel_spectrogram_from_audios,
        )

        spec = extract_mel_spectrogram_from_audios([audio], **spect_kwargs)[0]["mel_spectrogram"]
        ylab = "Mel bins"
        spec_title = "Mel Spectrogram"
    else:
        from senselab.audio.tasks.features_extraction.torchaudio import (
            extract_spectrogram_from_audios,
        )

        spec = extract_spectrogram_from_audios([audio], **spect_kwargs)[0]["spectrogram"]
        ylab = "Frequency [Hz]"
        spec_title = "Spectrogram"

    # ---- Guardrails for short/invalid outputs (exact phrase expected by tests)
    if not torch.is_tensor(spec):
        raise ValueError("Spectrogram extraction failed")
    if spec.ndim == 0 or spec.numel() == 0:
        raise ValueError("Spectrogram extraction failed")
    if spec.dtype.is_floating_point and torch.isnan(spec).any():
        raise ValueError("Spectrogram extraction failed")

    # We require a 2D (F x T) spectrogram. Anything else → fail (don't auto-pick channels).
    if spec.ndim != 2:
        raise ValueError("Spectrogram extraction failed")

    # ---- Frequency extent, computed only after validation so spec.size(0) is safe.
    # Mel bins are labelled by index; linear bins span 0..Nyquist.
    f0 = 0.0
    f1 = float(spec.size(0) - 1) if mel_scale else float(sr / 2)

    # ---- Layout & context
    scale = _resolve_scale(context)
    rc = _rc_for_scale(scale)
    if figsize is None:
        base = (12.0, 6.0)  # 2.0 in for the waveform + 4.0 in for the spectrogram
    else:
        base = figsize
    size = (base[0] * scale, base[1] * scale)

    with rc_context(rc):
        fig, (ax_wav, ax_spec) = plt.subplots(2, 1, figsize=size, sharex=True, gridspec_kw={"height_ratios": [1, 2]})

        # ---- Waveform (top); input is guaranteed mono by the guard above.
        ax_wav.plot(time_axis, waveform[0].detach().cpu().numpy())
        ax_wav.set_ylabel("Amp")
        ax_wav.grid(True, alpha=0.3)
        ax_wav.set_title("Waveform")

        # ---- Spectrogram (bottom)
        im = ax_spec.imshow(
            _power_to_db(spec.detach().cpu().numpy()),
            aspect="auto",
            origin="lower",
            extent=(t0, t1, f0, f1),
            cmap="viridis",
        )
        ax_spec.set_ylabel(ylab)
        ax_spec.set_xlabel("Time [s]")
        ax_spec.set_title(spec_title)

        # Keep both axes aligned in time
        ax_wav.set_xlim(t0, t1)
        ax_spec.set_xlim(t0, t1)

        # ---- Horizontal colorbar below the spectrogram
        divider = make_axes_locatable(ax_spec)
        cax = divider.append_axes("bottom", size="5%", pad=0.6)
        cbar = fig.colorbar(im, cax=cax, orientation="horizontal")
        cbar.set_label("Magnitude (dB)")

        fig.suptitle(title)
        fig.tight_layout(rect=(0, 0, 1, 0.96))
        plt.show(block=False)
        return fig
479
480
def play_audio(audio: Audio) -> None:
    """Render an inline audio player for an `Audio` object in Jupyter/IPython.

    Mono audio is handed to `IPython.display.Audio` as one 1-D array and stereo
    audio as a pair of channel arrays, at the audio's own sampling rate. Audio
    with any other channel count must be downmixed first.

    Args:
        audio (Audio):
            Mono or stereo audio to play.

    Raises:
        ValueError: If the waveform does not have exactly 1 or 2 channels.

    Example:
        >>> from pathlib import Path
        >>> from senselab.audio.data_structures import Audio
        >>> a1 = Audio(filepath=Path("sample1.wav").resolve())
        >>> play_audio(a1)
    """
    from IPython.display import Audio as DisplayAudio
    from IPython.display import display

    channels = audio.waveform.cpu().numpy()
    rate = audio.sampling_rate

    channel_count = channels.shape[0]
    if channel_count == 1:
        player_input = channels[0]
    elif channel_count == 2:
        player_input = (channels[0], channels[1])
    else:
        raise ValueError("Waveform with more than 2 channels is not supported.")

    display(DisplayAudio(player_input, rate=rate))
def plot_waveform( audio: senselab.audio.data_structures.audio.Audio, title: str = 'Waveform', fast: bool = False, *, context: Union[str, float] = 'auto', figsize: Optional[Tuple[float, float]] = None) -> matplotlib.figure.Figure:
def plot_waveform(
    audio: Audio,
    title: str = "Waveform",
    fast: bool = False,
    *,
    context: _Context = "auto",
    figsize: Tuple[float, float] | None = None,
) -> Figure:
    """Plot the time-domain waveform of an `Audio` object and return the Figure.

    The plot is automatically scaled for readability using a *context* scale
    (similar to seaborn). Use `fast=True` to lightly decimate the signal for
    quicker rendering on very long waveforms; the time axis still spans the
    full duration of the audio.

    Args:
        audio (Audio):
            Input audio containing `.waveform` (shape `[C, T]`) and `.sampling_rate`.
        title (str, optional):
            Figure title. Defaults to `"Waveform"`.
        fast (bool, optional):
            If `True`, plots every 10th sample for speed. Defaults to `False`.
        context (_Context, optional):
            Size preset or numeric scale. Accepted values:
              * `"auto"` (detect from screen), `"small"`, `"medium"`, `"large"`,
              * or a float scale factor (e.g., `1.25`). Defaults to `"auto"`.
        figsize (tuple[float, float] | None, optional):
            Base `(width, height)` in inches **before** context scaling.
            Defaults to `(12, max(2 × channels, 2.5))`.

    Returns:
        matplotlib.figure.Figure: The created figure (also displayed).

    Example:
        >>> from pathlib import Path
        >>> from senselab.audio.data_structures import Audio
        >>> a1 = Audio(filepath=Path("sample1.wav").resolve())
        >>> fig = plot_waveform(a1, title="Sample 1", fast=True, context="medium")
        >>> # fig.savefig("waveform.png")  # optional
    """
    waveform = audio.waveform
    sample_rate = audio.sampling_rate

    decimation = 1
    if fast:
        decimation = 10
        waveform = waveform[..., ::decimation]

    num_channels, num_frames = waveform.shape
    # Sample k of the (possibly decimated) signal sits at k * decimation / sample_rate
    # seconds. Deriving the axis from the decimated frame count alone would shrink
    # the time range 10x in fast mode.
    time_axis = np.arange(num_frames) * (decimation / sample_rate)

    scale = _resolve_scale(context)
    rc = _rc_for_scale(scale)
    if figsize is None:
        base = (12.0, max(2.0 * num_channels, 2.5))
    else:
        base = figsize
    scaled_size = (base[0] * scale, base[1] * scale)

    with rc_context(rc):
        # squeeze=False always yields a 2D array of axes, even for a single channel.
        fig, axes_grid = plt.subplots(num_channels, 1, figsize=scaled_size, sharex=True, squeeze=False)
        axes = axes_grid[:, 0]
        for c, ax in enumerate(axes):
            # detach() so tensors that track gradients can still be converted.
            ax.plot(time_axis, waveform[c].detach().cpu().numpy())
            ax.set_ylabel(f"Ch {c + 1}")
            ax.grid(True, alpha=0.3)
        fig.suptitle(title)
        axes[-1].set_xlabel("Time [s]")
        fig.tight_layout(rect=(0, 0, 1, 0.96))
        plt.show(block=False)
        return fig

Plot the time-domain waveform of an Audio object and return the Figure.

The plot is automatically scaled for readability using a context scale (similar to seaborn). Use fast=True to lightly decimate the signal for quicker rendering on very long waveforms.

Arguments:
  • audio (Audio): Input audio containing .waveform (shape [C, T]) and .sampling_rate.
  • title (str, optional): Figure title. Defaults to "Waveform".
  • fast (bool, optional): If True, plots a 10× downsampled view for speed. Defaults to False.
  • context (_Context, optional): Size preset or numeric scale. Accepted values:
    • "auto" (detect from screen), "small", "medium", "large",
    • or a float scale factor (e.g., 1.25). Defaults to "auto".
  • figsize (tuple[float, float] | None, optional): Base (width, height) in inches before context scaling. Defaults to (12, max(2×channels, 2.5)).
Returns:

matplotlib.figure.Figure: The created figure (also displayed).

Example:
>>> from pathlib import Path
>>> from senselab.audio.data_structures import Audio
>>> a1 = Audio(filepath=Path("sample1.wav").resolve())
>>> fig = plot_waveform(a1, title="Sample 1", fast=True, context="medium")
>>> # fig.savefig("waveform.png")  # optional
def plot_specgram( audio: senselab.audio.data_structures.audio.Audio, mel_scale: bool = False, title: str = 'Spectrogram', *, context: Union[str, float] = 'auto', figsize: Optional[Tuple[float, float]] = None, **spect_kwargs: Any) -> matplotlib.figure.Figure:
def plot_specgram(
    audio: Audio,
    mel_scale: bool = False,
    title: str = "Spectrogram",
    *,
    context: _Context = "auto",
    figsize: Tuple[float, float] | None = None,
    **spect_kwargs: Any,  # noqa: ANN401
) -> Figure:
    """Plot a (mel-)spectrogram for a **mono** `Audio` object and return the Figure.

    Internally calls senselab's torchaudio-based extractors:
    `extract_spectrogram_from_audios` or `extract_mel_spectrogram_from_audios`.
    The function expects a 2D spectrogram `[freq_bins, time_frames]`; multi-channel
    inputs should be downmixed beforehand.

    Args:
        audio (Audio):
            Input **mono** audio. If multi-channel, downmix first.
        mel_scale (bool, optional):
            If `True`, plots a mel spectrogram; otherwise linear frequency. Defaults to `False`.
        title (str, optional):
            Figure title. Defaults to `"Spectrogram"`.
        context (_Context, optional):
            Size preset or numeric scale (`"auto"`, `"small"`, `"medium"`, `"large"`, or float).
            Defaults to `"auto"`.
        figsize (tuple[float, float] | None, optional):
            Base `(width, height)` in inches **before** context scaling. Defaults to `(10, 4)`.
        **spect_kwargs:
            Passed to the underlying extractor (e.g., `n_fft=1024`, `hop_length=256`,
            `n_mels=80`, `win_length=1024`, `f_min=0`, `f_max=None`).

    Returns:
        matplotlib.figure.Figure: The created figure (also displayed).

    Raises:
        ValueError: If spectrogram extraction fails, contains NaNs, or the result is not 2D.

    Example (linear spectrogram):
        >>> from pathlib import Path
        >>> from senselab.audio.data_structures import Audio
        >>> a1 = Audio(filepath=Path("sample1.wav").resolve())
        >>> fig = plot_specgram(a1, mel_scale=False, n_fft=1024, hop_length=256)
        >>> # fig.savefig("spec.png")

    Example (mel spectrogram):
        >>> from pathlib import Path
        >>> from senselab.audio.data_structures import Audio
        >>> a1 = Audio(filepath=Path("sample1.wav").resolve())
        >>> fig = plot_specgram(a1, mel_scale=True, n_mels=80, n_fft=1024, hop_length=256)
    """
    # Extract the spectrogram
    if mel_scale:
        from senselab.audio.tasks.features_extraction.torchaudio import (
            extract_mel_spectrogram_from_audios,
        )

        spectrogram = extract_mel_spectrogram_from_audios([audio], **spect_kwargs)[0]["mel_spectrogram"]
        y_axis_label = "Mel frequency (bins)"
    else:
        from senselab.audio.tasks.features_extraction.torchaudio import (
            extract_spectrogram_from_audios,
        )

        spectrogram = extract_spectrogram_from_audios([audio], **spect_kwargs)[0]["spectrogram"]
        y_axis_label = "Frequency [Hz]"

    # ---- Guard against invalid/short-audio outputs (must be exactly this phrase)
    if not torch.is_tensor(spectrogram):
        raise ValueError("Spectrogram extraction failed")
    if spectrogram.ndim == 0 or spectrogram.numel() == 0:
        raise ValueError("Spectrogram extraction failed")
    if spectrogram.dtype.is_floating_point and torch.isnan(spectrogram).any():
        raise ValueError("Spectrogram extraction failed")

    if spectrogram.dim() != 2:
        # A single message string: passing two positional args to ValueError would
        # render the message as a tuple.
        raise ValueError(
            f"Spectrogram must be a 2D tensor. Got shape: {spectrogram.shape}. "
            "Please make sure the input audio is mono."
        )

    num_freq_bins = spectrogram.size(0)

    # Time axis in seconds
    duration_sec = audio.waveform.size(-1) / audio.sampling_rate
    time_axis_start = 0.0
    time_axis_end = float(duration_sec)

    # Frequency axis: mel bins are labelled by index, linear bins span 0..Nyquist.
    if mel_scale:
        freq_start, freq_end = 0.0, float(num_freq_bins - 1)
    else:
        freq_start, freq_end = 0.0, float(audio.sampling_rate / 2)

    scale = _resolve_scale(context)
    rc = _rc_for_scale(scale)
    if figsize is None:
        base = (10.0, 4.0)
    else:
        base = figsize
    scaled_size = (base[0] * scale, base[1] * scale)

    with rc_context(rc):
        # Explicit figure/axes objects instead of pyplot's implicit "current axes".
        fig, ax = plt.subplots(figsize=scaled_size)
        im = ax.imshow(
            _power_to_db(spectrogram.detach().cpu().numpy()),
            aspect="auto",
            origin="lower",
            extent=(time_axis_start, time_axis_end, freq_start, freq_end),
            cmap="viridis",
        )
        fig.colorbar(im, ax=ax, label="Magnitude (dB)")
        ax.set_title(title)
        ax.set_ylabel(y_axis_label)
        ax.set_xlabel("Time [s]")
        fig.tight_layout()
        plt.show(block=False)
        return fig

Plot a (mel-)spectrogram for a mono Audio object and return the Figure.

Internally calls senselab's torchaudio-based extractors: extract_spectrogram_from_audios or extract_mel_spectrogram_from_audios. The function expects a 2D spectrogram [freq_bins, time_frames]; multi-channel inputs should be downmixed beforehand.

Arguments:
  • audio (Audio): Input mono audio. If multi-channel, downmix first.
  • mel_scale (bool, optional): If True, plots a mel spectrogram; otherwise linear frequency. Defaults to False.
  • title (str, optional): Figure title. Defaults to "Spectrogram".
  • context (_Context, optional): Size preset or numeric scale ("auto", "small", "medium", "large", or float). Defaults to "auto".
  • figsize (tuple[float, float] | None, optional): Base (width, height) in inches before context scaling. Defaults to (10, 4).
  • **spect_kwargs: Passed to the underlying extractor (e.g., n_fft=1024, hop_length=256, n_mels=80, win_length=1024, f_min=0, f_max=None).
Returns:

matplotlib.figure.Figure: The created figure (also displayed).

Raises:
  • ValueError: If spectrogram extraction fails, contains NaNs, or the result is not 2D.

Example (linear spectrogram):

>>> from pathlib import Path
>>> from senselab.audio.data_structures import Audio
>>> a1 = Audio(filepath=Path("sample1.wav").resolve())
>>> fig = plot_specgram(a1, mel_scale=False, n_fft=1024, hop_length=256)
>>> # fig.savefig("spec.png")

Example (mel spectrogram):

>>> from pathlib import Path
>>> from senselab.audio.data_structures import Audio
>>> a1 = Audio(filepath=Path("sample1.wav").resolve())
>>> fig = plot_specgram(a1, mel_scale=True, n_mels=80, n_fft=1024, hop_length=256)

def plot_waveform_and_specgram( audio: senselab.audio.data_structures.audio.Audio, *, title: str = 'Waveform + Spectrogram', mel_scale: bool = False, fast_wave: bool = False, context: Union[str, float] = 'auto', figsize: Optional[Tuple[float, float]] = None, **spect_kwargs: Any) -> matplotlib.figure.Figure:
def plot_waveform_and_specgram(
    audio: Audio,
    *,
    title: str = "Waveform + Spectrogram",
    mel_scale: bool = False,
    fast_wave: bool = False,
    context: "_Context" = "auto",
    figsize: Tuple[float, float] | None = None,
    **spect_kwargs: Any,  # noqa: ANN401  # forwarded to spectrogram extraction
) -> Figure:
    """Stacked layout: waveform (top) and **mono** spectrogram (bottom). Returns the Figure.

    The waveform can be drawn in a faster, decimated mode (every 10th sample) for
    long signals; the time axis stays accurate either way. Spectrogram extraction
    is delegated to senselab's torchaudio-based utilities and requires mono input.

    Args:
        audio (Audio):
            Input audio. **Spectrogram requires mono**; downmix multi-channel first.
        title (str, optional):
            Overall figure title. Defaults to `"Waveform + Spectrogram"`.
        mel_scale (bool, optional):
            If `True`, bottom panel is a mel spectrogram with a mel-bin y-axis;
            otherwise linear frequency from 0 Hz to `sampling_rate / 2`.
            Defaults to `False`.
        fast_wave (bool, optional):
            If `True`, only every 10th sample is drawn in the waveform panel.
            Defaults to `False`.
        context (_Context, optional):
            Size preset or numeric scale (`"auto"`, `"small"`, `"medium"`, `"large"`, or float).
            Defaults to `"auto"`.
        figsize (tuple[float, float] | None, optional):
            Base `(width, height)` in inches **before** context scaling.
            Defaults to `(12.0, 6.0)`.
        **spect_kwargs:
            Forwarded to the underlying spectrogram extractor (e.g., `n_fft`, `hop_length`, `n_mels`).

    Returns:
        matplotlib.figure.Figure: The created figure (also displayed via
        `plt.show(block=False)`).

    Raises:
        ValueError: If audio is not mono, or spectrogram extraction fails
            (result is not a tensor, is empty, contains NaNs, or is not 2D).

    Example:
        >>> from pathlib import Path
        >>> from senselab.audio.data_structures import Audio
        >>> a1 = Audio(filepath=Path("sample1.wav").resolve())
        >>> fig = plot_waveform_and_specgram(
        ...     a1,
        ...     mel_scale=True,
        ...     fast_wave=True,
        ...     context="large",
        ...     n_fft=1024,
        ...     hop_length=256,
        ...     n_mels=80,
        ... )
        >>> # fig.savefig("wave_plus_mel.png")
    """
    # ---- Guardrail first: spectrogram plotting requires mono input, so fail
    # before doing any extraction or plotting work.
    if audio.waveform.shape[0] != 1:
        raise ValueError("Only mono audio is supported for spectrogram plotting")

    # ---- Core timing info from ORIGINAL (non-decimated) data
    sr = audio.sampling_rate
    orig_num_frames = int(audio.waveform.size(-1))
    duration_sec = orig_num_frames / sr
    t0, t1 = 0.0, float(duration_sec)

    # ---- Prepare waveform (optionally decimated for speed)
    step = 10 if fast_wave else 1
    samples = audio.waveform[0, ::step].detach().cpu().numpy()
    # Decimated sample k is original sample k*step, so it sits at k*step/sr seconds.
    # (A linspace over the full duration drifts when the length isn't a multiple of step.)
    time_axis = np.arange(samples.shape[-1]) * (step / sr)

    # ---- Spectrogram (2D tensor: [freq_bins, time_frames])
    if mel_scale:
        from senselab.audio.tasks.features_extraction.torchaudio import (
            extract_mel_spectrogram_from_audios,
        )

        spec = extract_mel_spectrogram_from_audios([audio], **spect_kwargs)[0]["mel_spectrogram"]
        ylab = "Mel bins"
        spec_title = "Mel Spectrogram"
    else:
        from senselab.audio.tasks.features_extraction.torchaudio import (
            extract_spectrogram_from_audios,
        )

        spec = extract_spectrogram_from_audios([audio], **spect_kwargs)[0]["spectrogram"]
        ylab = "Frequency [Hz]"
        spec_title = "Spectrogram"

    # ---- Guardrails for short/invalid outputs (exact phrase expected by tests).
    # We require a non-empty, NaN-free 2D (F x T) tensor; anything else fails
    # (don't auto-pick channels).
    if (
        not torch.is_tensor(spec)
        or spec.ndim != 2
        or spec.numel() == 0
        or (spec.dtype.is_floating_point and bool(torch.isnan(spec).any()))
    ):
        raise ValueError("Spectrogram extraction failed")

    # ---- Y extent, computed only once `spec` is known to be a valid 2D tensor:
    # mel rows are bin indices 0..n_mels-1; linear rows span 0 Hz..Nyquist.
    f0 = 0.0
    f1 = float(spec.size(0) - 1) if mel_scale else float(sr / 2)

    # ---- Layout & context
    scale = _resolve_scale(context)
    rc = _rc_for_scale(scale)
    # Default: 12" wide; 2" for the waveform row + 4" for the spectrogram row.
    base = figsize if figsize is not None else (12.0, 6.0)
    size = (base[0] * scale, base[1] * scale)

    with rc_context(rc):
        fig, (ax_wav, ax_spec) = plt.subplots(2, 1, figsize=size, sharex=True, gridspec_kw={"height_ratios": [1, 2]})

        # ---- Waveform (top); input is guaranteed mono by the guard above.
        ax_wav.plot(time_axis, samples)
        ax_wav.set_ylabel("Amp")
        ax_wav.grid(True, alpha=0.3)
        ax_wav.set_title("Waveform")

        # ---- Spectrogram (bottom)
        im = ax_spec.imshow(
            _power_to_db(spec.detach().cpu().numpy()),
            aspect="auto",
            origin="lower",
            extent=(t0, t1, f0, f1),
            cmap="viridis",
        )
        ax_spec.set_ylabel(ylab)
        ax_spec.set_xlabel("Time [s]")
        ax_spec.set_title(spec_title)

        # Keep both axes aligned in time
        ax_wav.set_xlim(t0, t1)
        ax_spec.set_xlim(t0, t1)

        # ---- Horizontal colorbar below the spectrogram
        divider = make_axes_locatable(ax_spec)
        cax = divider.append_axes("bottom", size="5%", pad=0.6)
        cbar = fig.colorbar(im, cax=cax, orientation="horizontal")
        cbar.set_label("Magnitude (dB)")

        fig.suptitle(title)
        fig.tight_layout(rect=(0, 0, 1, 0.96))
        plt.show(block=False)
        return fig

Stacked layout: waveform (top) and mono spectrogram (bottom). Returns the Figure.

The waveform can be drawn in a faster, lightly decimated mode for long signals. Spectrogram extraction is delegated to senselab's torchaudio-based utilities and requires mono input.

Arguments:
  • audio (Audio): Input audio. Spectrogram requires mono; downmix multi-channel first.
  • title (str, optional): Overall figure title. Defaults to "Waveform + Spectrogram".
  • mel_scale (bool, optional): If True, bottom panel is a mel spectrogram; otherwise linear frequency. Defaults to False.
  • fast_wave (bool, optional): If True, waveform panel is downsampled for speed. Defaults to False.
  • context (_Context, optional): Size preset or numeric scale ("auto", "small", "medium", "large", or float). Defaults to "auto".
  • figsize (tuple[float, float] | None, optional): Base (width, height) in inches before context scaling. Defaults to (12, 6) for mono input, leaving room for both the waveform and spectrogram panels.
  • **spect_kwargs: Forwarded to the underlying spectrogram extractor (e.g., n_fft, hop_length, n_mels).
Returns:

matplotlib.figure.Figure: The created figure (also displayed).

Raises:
  • ValueError: If audio is not mono, or spectrogram extraction fails.
Example:
>>> from pathlib import Path
>>> from senselab.audio.data_structures import Audio
>>> a1 = Audio(filepath=Path("sample1.wav").resolve())
>>> fig = plot_waveform_and_specgram(
...     a1,
...     mel_scale=True,
...     fast_wave=True,
...     context="large",
...     n_fft=1024,
...     hop_length=256,
...     n_mels=80,
... )
>>> # fig.savefig("wave_plus_mel.png")
def play_audio(audio: senselab.audio.data_structures.audio.Audio) -> None:
def play_audio(audio: Audio) -> None:
    """Render an `Audio` object as an inline player widget in Jupyter/IPython.

    Mono and stereo waveforms are supported; the object's sampling rate is
    passed through unchanged. Anything else must be downmixed by the caller.

    Args:
        audio (Audio):
            The audio to play. Its waveform must have one or two channels.

    Raises:
        ValueError: If the waveform's channel count is not 1 or 2.

    Example:
        >>> from pathlib import Path
        >>> from senselab.audio.data_structures import Audio
        >>> a1 = Audio(filepath=Path("sample1.wav").resolve())
        >>> play_audio(a1)
    """
    # Imported lazily so the module does not require IPython unless playback is used.
    from IPython.display import Audio as DisplayAudio
    from IPython.display import display

    samples = audio.waveform.cpu().numpy()
    rate = audio.sampling_rate
    channel_count = samples.shape[0]

    if channel_count not in (1, 2):
        raise ValueError("Waveform with more than 2 channels is not supported.")

    # A single 1D array plays as mono; a (left, right) pair plays as stereo.
    payload = samples[0] if channel_count == 1 else (samples[0], samples[1])
    display(DisplayAudio(payload, rate=rate))

Play an Audio object inline (Jupyter/IPython), supporting 1–2 channels.

Uses IPython.display.Audio to render audio widgets in notebooks. For more than two channels, downmix first.

Arguments:
  • audio (Audio): Input audio to play (mono or stereo). Sampling rate is preserved.
Raises:
  • ValueError: If the waveform has more than 2 channels.
Example:
>>> from pathlib import Path
>>> from senselab.audio.data_structures import Audio
>>> a1 = Audio(filepath=Path("sample1.wav").resolve())
>>> play_audio(a1)