senselab.audio.tasks.plotting.plotting
This module contains functions for plotting audio-related data.
1"""This module contains functions for plotting audio-related data.""" 2 3from typing import Any, Dict, Tuple, Union 4 5import matplotlib.pyplot as plt 6import numpy as np 7import torch 8from matplotlib import rc_context 9from matplotlib.figure import Figure 10from mpl_toolkits.axes_grid1 import make_axes_locatable 11 12from senselab.audio.data_structures import Audio 13from senselab.utils.data_structures import logger 14 15# --------------------------- 16# Plot context & scaling 17# --------------------------- 18 19_Context = Union[str, float] # "auto" | "small" | "medium" | "large" | float scale 20 21 22def _detect_screen_resolution() -> Tuple[int, int]: 23 """Best-effort screen resolution detection. Falls back to 1920x1080.""" 24 # Try TkAgg 25 try: 26 mgr = plt.get_current_fig_manager() 27 win = getattr(mgr, "window", None) 28 if win is not None and hasattr(win, "winfo_screenwidth"): 29 return int(win.winfo_screenwidth()), int(win.winfo_screenheight()) 30 except Exception: 31 pass 32 # Try Qt 33 try: 34 from PyQt5 import QtWidgets # type: ignore 35 36 app = QtWidgets.QApplication.instance() or QtWidgets.QApplication([]) 37 screen = app.primaryScreen() 38 size = screen.size() 39 return int(size.width()), int(size.height()) 40 except Exception: 41 pass 42 # Fallback 43 return 1920, 1080 44 45 46def _context_scale_from_resolution() -> float: 47 """Map screen width → a sensible scale factor.""" 48 width, _ = _detect_screen_resolution() 49 # Simple, readable buckets 50 if width <= 1366: 51 return 0.9 52 if width <= 1920: 53 return 1.0 54 if width <= 2560: 55 return 1.25 56 if width <= 3840: 57 return 1.5 58 return 2.0 59 60 61def _resolve_scale(context: _Context) -> float: 62 if isinstance(context, (int, float)): 63 return float(context) 64 ctx = str(context).lower() 65 if ctx == "auto": 66 return _context_scale_from_resolution() 67 if ctx in ("paper", "small"): 68 return 0.9 69 if ctx in ("notebook", "medium"): 70 return 1.0 71 if ctx in ("talk", "large"): 72 return 1.3 73 # Default 74 return 1.0 75 76 77def _rc_for_scale(scale: float) -> Dict[str, Any]: 78 """Return rcParams tuned for the given scale (seaborn-like).""" 79 base = 10.0 * scale 80 return { 81 "font.size": base, 82 "axes.titlesize": base * 1.2, 83 "axes.labelsize": base, 84 "xtick.labelsize": base * 0.9, 85 "ytick.labelsize": base * 0.9, 86 "legend.fontsize": base * 0.95, 87 "lines.linewidth": 1.25 * scale, 88 "grid.linewidth": 0.8 * scale, 89 "axes.linewidth": 0.8 * scale, 90 "figure.titlesize": base * 1.3, 91 } 92 93 94# --------------------------- 95# Helpers 96# --------------------------- 97 98 99def _power_to_db(spectrogram: np.ndarray, ref: float = 1.0, amin: float = 1e-10, top_db: float = 80.0) -> np.ndarray: 100 """Converts a power spectrogram (amplitude squared) to decibel (dB) units.""" 101 S = np.asarray(spectrogram) 102 103 if amin <= 0: 104 raise ValueError("amin must be strictly positive") 105 106 if np.issubdtype(S.dtype, np.complexfloating): 107 logger.warning( 108 "_power_to_db was called on complex input so phase information will be discarded. " 109 "To suppress this warning, call power_to_db(np.abs(D)**2) instead.", 110 stacklevel=2, 111 ) 112 magnitude = np.abs(S) 113 else: 114 magnitude = S 115 116 ref_value = ref(magnitude) if callable(ref) else np.abs(ref) 117 log_spec: np.ndarray = 10.0 * np.log10(np.maximum(amin, magnitude)) 118 log_spec -= 10.0 * np.log10(np.maximum(amin, ref_value)) 119 120 if top_db is not None: 121 if top_db < 0: 122 raise ValueError("top_db must be non-negative") 123 log_spec = np.maximum(log_spec, log_spec.max() - top_db) 124 125 return log_spec 126 127 128# --------------------------- 129# Public API 130# --------------------------- 131 132 133def plot_waveform( 134 audio: Audio, 135 title: str = "Waveform", 136 fast: bool = False, 137 *, 138 context: _Context = "auto", 139 figsize: Tuple[float, float] | None = None, 140) -> Figure: 141 """Plot the time-domain waveform of an `Audio` object and return the Figure. 142 143 The plot is automatically scaled for readability using a *context* scale 144 (similar to seaborn). Use `fast=True` to lightly decimate the signal for 145 quicker rendering on very long waveforms. 146 147 Args: 148 audio (Audio): 149 Input audio containing `.waveform` (shape `[C, T]`) and `.sampling_rate`. 150 title (str, optional): 151 Figure title. Defaults to `"Waveform"`. 152 fast (bool, optional): 153 If `True`, plots a 10× downsampled view for speed. Defaults to `False`. 154 context (_Context, optional): 155 Size preset or numeric scale. Accepted values: 156 * `"auto"` (detect from screen), `"small"`, `"medium"`, `"large"`, 157 * or a float scale factor (e.g., `1.25`). Defaults to `"auto"`. 158 figsize (tuple[float, float] | None, optional): 159 Base `(width, height)` in inches **before** context scaling. 160 Defaults to `(12, 2×channels)`. 161 162 Returns: 163 matplotlib.figure.Figure: The created figure (also displayed). 164 165 Example: 166 >>> from pathlib import Path 167 >>> from senselab.audio.data_structures import Audio 168 >>> a1 = Audio(filepath=Path("sample1.wav").resolve()) 169 >>> fig = plot_waveform(a1, title="Sample 1", fast=True, context="medium") 170 >>> # fig.savefig("waveform.png") # optional 171 """ 172 waveform = audio.waveform 173 sample_rate = audio.sampling_rate 174 175 if fast: 176 waveform = waveform[..., ::10] 177 178 num_channels, num_frames = waveform.shape 179 time_axis = torch.linspace(0, num_frames / sample_rate, num_frames) 180 181 scale = _resolve_scale(context) 182 rc = _rc_for_scale(scale) 183 if figsize is None: 184 base = (12.0, max(2.0 * num_channels, 2.5)) 185 else: 186 base = figsize 187 scaled_size = (base[0] * scale, base[1] * scale) 188 189 with rc_context(rc): 190 fig, axes = plt.subplots(num_channels, 1, figsize=scaled_size, sharex=True) 191 if num_channels == 1: 192 axes = [axes] # ensure iterable 193 for c, ax in enumerate(axes): 194 ax.plot(time_axis.numpy(), waveform[c].cpu().numpy()) 195 ax.set_ylabel(f"Ch {c + 1}") 196 ax.grid(True, alpha=0.3) 197 fig.suptitle(title) 198 axes[-1].set_xlabel("Time [s]") 199 fig.tight_layout(rect=(0, 0, 1, 0.96)) 200 plt.show(block=False) 201 return fig 202 203 204def plot_specgram( 205 audio: Audio, 206 mel_scale: bool = False, 207 title: str = "Spectrogram", 208 *, 209 context: _Context = "auto", 210 figsize: Tuple[float, float] | None = None, 211 **spect_kwargs: Any, # noqa: ANN401 212) -> Figure: 213 """Plot a (mel-)spectrogram for a **mono** `Audio` object and return the Figure. 214 215 Internally calls senselab's torchaudio-based extractors: 216 `extract_spectrogram_from_audios` or `extract_mel_spectrogram_from_audios`. 217 The function expects a 2D spectrogram `[freq_bins, time_frames]`; multi-channel 218 inputs should be downmixed beforehand. 219 220 Args: 221 audio (Audio): 222 Input **mono** audio. If multi-channel, downmix first. 223 mel_scale (bool, optional): 224 If `True`, plots a mel spectrogram; otherwise linear frequency. Defaults to `False`. 225 title (str, optional): 226 Figure title. Defaults to `"Spectrogram"`. 227 context (_Context, optional): 228 Size preset or numeric scale (`"auto"`, `"small"`, `"medium"`, `"large"`, or float). 229 Defaults to `"auto"`. 230 figsize (tuple[float, float] | None, optional): 231 Base `(width, height)` in inches **before** context scaling. Defaults to `(10, 4)`. 232 **spect_kwargs: 233 Passed to the underlying extractor (e.g., `n_fft=1024`, `hop_length=256`, 234 `n_mels=80`, `win_length=1024`, `f_min=0`, `f_max=None`). 235 236 Returns: 237 matplotlib.figure.Figure: The created figure (also displayed). 238 239 Raises: 240 ValueError: If spectrogram extraction fails, contains NaNs, or the result is not 2D. 241 242 Example (linear spectrogram): 243 >>> from pathlib import Path 244 >>> from senselab.audio.data_structures import Audio 245 >>> a1 = Audio(filepath=Path("sample1.wav").resolve()) 246 >>> fig = plot_specgram(a1, mel_scale=False, n_fft=1024, hop_length=256) 247 >>> # fig.savefig("spec.png") 248 249 Example (mel spectrogram): 250 >>> from pathlib import Path 251 >>> from senselab.audio.data_structures import Audio 252 >>> a1 = Audio(filepath=Path("sample1.wav").resolve()) 253 >>> fig = plot_specgram(a1, mel_scale=True, n_mels=80, n_fft=1024, hop_length=256) 254 """ 255 # Extract the spectrogram 256 if mel_scale: 257 from senselab.audio.tasks.features_extraction.torchaudio import ( 258 extract_mel_spectrogram_from_audios, 259 ) 260 261 spectrogram = extract_mel_spectrogram_from_audios([audio], **spect_kwargs)[0]["mel_spectrogram"] 262 y_axis_label = "Mel frequency (bins)" 263 else: 264 from senselab.audio.tasks.features_extraction.torchaudio import ( 265 extract_spectrogram_from_audios, 266 ) 267 268 spectrogram = extract_spectrogram_from_audios([audio], **spect_kwargs)[0]["spectrogram"] 269 y_axis_label = "Frequency [Hz]" 270 271 # ---- Guard against invalid/short-audio outputs (must be exactly this phrase) 272 if not torch.is_tensor(spectrogram): 273 raise ValueError("Spectrogram extraction failed") 274 if spectrogram.ndim == 0 or spectrogram.numel() == 0: 275 raise ValueError("Spectrogram extraction failed") 276 if spectrogram.dtype.is_floating_point and torch.isnan(spectrogram).any(): 277 raise ValueError("Spectrogram extraction failed") 278 279 if spectrogram.dim() != 2: 280 raise ValueError( 281 "Spectrogram must be a 2D tensor. Got shape: {}".format(spectrogram.shape), 282 "Please make sure the input audio is mono.", 283 ) 284 285 # Determine time and frequency scale 286 # num_frames = spectrogram.size(1) 287 num_freq_bins = spectrogram.size(0) 288 289 # Time axis in seconds 290 duration_sec = audio.waveform.size(-1) / audio.sampling_rate 291 time_axis_start = 0.0 292 time_axis_end = float(duration_sec) 293 294 # Frequency axis 295 if mel_scale: 296 freq_start, freq_end = 0.0, float(num_freq_bins - 1) 297 else: 298 freq_start, freq_end = 0.0, float(audio.sampling_rate / 2) 299 300 scale = _resolve_scale(context) 301 rc = _rc_for_scale(scale) 302 if figsize is None: 303 base = (10.0, 4.0) 304 else: 305 base = figsize 306 scaled_size = (base[0] * scale, base[1] * scale) 307 308 with rc_context(rc): 309 fig = plt.figure(figsize=scaled_size) 310 plt.imshow( 311 _power_to_db(spectrogram.cpu().numpy()), 312 aspect="auto", 313 origin="lower", 314 extent=(time_axis_start, time_axis_end, freq_start, freq_end), 315 cmap="viridis", 316 ) 317 plt.colorbar(label="Magnitude (dB)") 318 plt.title(title) 319 plt.ylabel(y_axis_label) 320 plt.xlabel("Time [s]") 321 plt.tight_layout() 322 plt.show(block=False) 323 return fig 324 325 326def plot_waveform_and_specgram( 327 audio: Audio, 328 *, 329 title: str = "Waveform + Spectrogram", 330 mel_scale: bool = False, 331 fast_wave: bool = False, 332 context: "_Context" = "auto", 333 figsize: Tuple[float, float] | None = None, 334 **spect_kwargs: Any, # noqa: ANN401 # forwarded to spectrogram extraction 335) -> Figure: 336 """Stacked layout: waveform (top) and **mono** spectrogram (bottom). Returns the Figure. 337 338 The waveform can be drawn in a faster, lightly decimated mode for long signals. 339 Spectrogram extraction is delegated to senselab's torchaudio-based utilities 340 and requires mono input. 341 342 Args: 343 audio (Audio): 344 Input audio. **Spectrogram requires mono**; downmix multi-channel first. 345 title (str, optional): 346 Overall figure title. Defaults to `"Waveform + Spectrogram"`. 347 mel_scale (bool, optional): 348 If `True`, bottom panel is a mel spectrogram; otherwise linear frequency. Defaults to `False`. 349 fast_wave (bool, optional): 350 If `True`, waveform panel is downsampled for speed. Defaults to `False`. 351 context (_Context, optional): 352 Size preset or numeric scale (`"auto"`, `"small"`, `"medium"`, `"large"`, or float). 353 Defaults to `"auto"`. 354 figsize (tuple[float, float] | None, optional): 355 Base `(width, height)` in inches **before** context scaling. Defaults to a balanced height. 356 **spect_kwargs: 357 Forwarded to the underlying spectrogram extractor (e.g., `n_fft`, `hop_length`, `n_mels`). 358 359 Returns: 360 matplotlib.figure.Figure: The created figure (also displayed). 361 362 Raises: 363 ValueError: If audio is not mono, or spectrogram extraction fails. 364 365 Example: 366 >>> from pathlib import Path 367 >>> from senselab.audio.data_structures import Audio 368 >>> a1 = Audio(filepath=Path("sample1.wav").resolve()) 369 >>> fig = plot_waveform_and_specgram( 370 ... a1, 371 ... mel_scale=True, 372 ... fast_wave=True, 373 ... context="large", 374 ... n_fft=1024, 375 ... hop_length=256, 376 ... n_mels=80, 377 ... ) 378 >>> # fig.savefig("wave_plus_mel.png") 379 """ 380 # ---- Core timing info from ORIGINAL (non-decimated) data 381 sr = audio.sampling_rate 382 orig_num_frames = int(audio.waveform.size(-1)) 383 duration_sec = orig_num_frames / sr 384 t0, t1 = 0.0, float(duration_sec) 385 386 # ---- Prepare waveform (optionally decimated for speed) 387 waveform = audio.waveform 388 if fast_wave: 389 waveform = waveform[..., ::10] # decimate samples 390 num_channels, num_frames = waveform.shape 391 time_axis = np.linspace(0.0, duration_sec, num_frames, endpoint=False) 392 393 # ---- Guardrail: spectrogram plotting requires mono input 394 if audio.waveform.shape[0] != 1: 395 raise ValueError("Only mono audio is supported for spectrogram plotting") 396 397 # ---- Spectrogram (2D tensor: [freq_bins, time_frames]) 398 if mel_scale: 399 from senselab.audio.tasks.features_extraction.torchaudio import ( 400 extract_mel_spectrogram_from_audios, 401 ) 402 403 spec = extract_mel_spectrogram_from_audios([audio], **spect_kwargs)[0]["mel_spectrogram"] 404 ylab = "Mel bins" 405 f0, f1 = 0.0, float(spec.size(0) - 1) if torch.is_tensor(spec) and spec.ndim >= 1 else (0.0, 0.0) 406 spec_title = "Mel Spectrogram" 407 else: 408 from senselab.audio.tasks.features_extraction.torchaudio import ( 409 extract_spectrogram_from_audios, 410 ) 411 412 spec = extract_spectrogram_from_audios([audio], **spect_kwargs)[0]["spectrogram"] 413 ylab = "Frequency [Hz]" 414 f0, f1 = 0.0, float(sr / 2) 415 spec_title = "Spectrogram" 416 417 # ---- Guardrails for short/invalid outputs (exact phrase expected by tests) 418 if not torch.is_tensor(spec): 419 raise ValueError("Spectrogram extraction failed") 420 if spec.ndim == 0 or spec.numel() == 0: 421 raise ValueError("Spectrogram extraction failed") 422 if spec.dtype.is_floating_point and torch.isnan(spec).any(): 423 raise ValueError("Spectrogram extraction failed") 424 425 # We require a 2D (F x T) spectrogram. Anything else → fail (don’t auto-pick channels). 426 if spec.ndim != 2: 427 raise ValueError("Spectrogram extraction failed") 428 429 # ---- Layout & context 430 scale = _resolve_scale(context) 431 rc = _rc_for_scale(scale) 432 if figsize is None: 433 base_h = max(2.0, 0.9 * num_channels) + 4.0 # waveform height + spectrogram 434 base = (12.0, base_h) 435 else: 436 base = figsize 437 size = (base[0] * scale, base[1] * scale) 438 439 with rc_context(rc): 440 fig, (ax_wav, ax_spec) = plt.subplots(2, 1, figsize=size, sharex=True, gridspec_kw={"height_ratios": [1, 2]}) 441 442 # ---- Waveform (top) 443 if num_channels == 1: 444 ax_wav.plot(time_axis, waveform[0].cpu().numpy()) 445 ax_wav.set_ylabel("Amp") 446 else: 447 for c in range(num_channels): 448 ax_wav.plot(time_axis, waveform[c].cpu().numpy(), alpha=0.9 if c == 0 else 0.7) 449 ax_wav.set_ylabel("Amp (multi-ch)") 450 ax_wav.grid(True, alpha=0.3) 451 ax_wav.set_title("Waveform") 452 453 # ---- Spectrogram (bottom) 454 im = ax_spec.imshow( 455 _power_to_db(spec.cpu().numpy()), 456 aspect="auto", 457 origin="lower", 458 extent=(t0, t1, f0, f1), 459 cmap="viridis", 460 ) 461 ax_spec.set_ylabel(ylab) 462 ax_spec.set_xlabel("Time [s]") 463 ax_spec.set_title(spec_title) 464 465 # Keep both axes aligned in time 466 ax_wav.set_xlim(t0, t1) 467 ax_spec.set_xlim(t0, t1) 468 469 # ---- Horizontal colorbar below the spectrogram 470 divider = make_axes_locatable(ax_spec) 471 cax = divider.append_axes("bottom", size="5%", pad=0.6) 472 cbar = fig.colorbar(im, cax=cax, orientation="horizontal") 473 cbar.set_label("Magnitude (dB)") 474 475 fig.suptitle(title) 476 fig.tight_layout(rect=(0, 0, 1, 0.96)) 477 plt.show(block=False) 478 return fig 479 480 481def play_audio(audio: Audio) -> None: 482 """Play an `Audio` object inline (Jupyter/IPython), supporting 1–2 channels. 483 484 Uses `IPython.display.Audio` to render audio widgets in notebooks. For more 485 than two channels, downmix first. 486 487 Args: 488 audio (Audio): 489 Input audio to play (mono or stereo). Sampling rate is preserved. 490 491 Raises: 492 ValueError: If the waveform has more than 2 channels. 493 494 Example: 495 >>> from pathlib import Path 496 >>> from senselab.audio.data_structures import Audio 497 >>> a1 = Audio(filepath=Path("sample1.wav").resolve()) 498 >>> play_audio(a1) 499 """ 500 from IPython.display import Audio as DisplayAudio 501 from IPython.display import display 502 503 waveform = audio.waveform.cpu().numpy() 504 sample_rate = audio.sampling_rate 505 506 num_channels = waveform.shape[0] 507 if num_channels == 1: 508 display(DisplayAudio(waveform[0], rate=sample_rate)) 509 elif num_channels == 2: 510 display(DisplayAudio((waveform[0], waveform[1]), rate=sample_rate)) 511 else: 512 raise ValueError("Waveform with more than 2 channels is not supported.")
134def plot_waveform( 135 audio: Audio, 136 title: str = "Waveform", 137 fast: bool = False, 138 *, 139 context: _Context = "auto", 140 figsize: Tuple[float, float] | None = None, 141) -> Figure: 142 """Plot the time-domain waveform of an `Audio` object and return the Figure. 143 144 The plot is automatically scaled for readability using a *context* scale 145 (similar to seaborn). Use `fast=True` to lightly decimate the signal for 146 quicker rendering on very long waveforms. 147 148 Args: 149 audio (Audio): 150 Input audio containing `.waveform` (shape `[C, T]`) and `.sampling_rate`. 151 title (str, optional): 152 Figure title. Defaults to `"Waveform"`. 153 fast (bool, optional): 154 If `True`, plots a 10× downsampled view for speed. Defaults to `False`. 155 context (_Context, optional): 156 Size preset or numeric scale. Accepted values: 157 * `"auto"` (detect from screen), `"small"`, `"medium"`, `"large"`, 158 * or a float scale factor (e.g., `1.25`). Defaults to `"auto"`. 159 figsize (tuple[float, float] | None, optional): 160 Base `(width, height)` in inches **before** context scaling. 161 Defaults to `(12, 2×channels)`. 162 163 Returns: 164 matplotlib.figure.Figure: The created figure (also displayed). 165 166 Example: 167 >>> from pathlib import Path 168 >>> from senselab.audio.data_structures import Audio 169 >>> a1 = Audio(filepath=Path("sample1.wav").resolve()) 170 >>> fig = plot_waveform(a1, title="Sample 1", fast=True, context="medium") 171 >>> # fig.savefig("waveform.png") # optional 172 """ 173 waveform = audio.waveform 174 sample_rate = audio.sampling_rate 175 176 if fast: 177 waveform = waveform[..., ::10] 178 179 num_channels, num_frames = waveform.shape 180 time_axis = torch.linspace(0, num_frames / sample_rate, num_frames) 181 182 scale = _resolve_scale(context) 183 rc = _rc_for_scale(scale) 184 if figsize is None: 185 base = (12.0, max(2.0 * num_channels, 2.5)) 186 else: 187 base = figsize 188 scaled_size = (base[0] * scale, base[1] * scale) 189 190 with rc_context(rc): 191 fig, axes = plt.subplots(num_channels, 1, figsize=scaled_size, sharex=True) 192 if num_channels == 1: 193 axes = [axes] # ensure iterable 194 for c, ax in enumerate(axes): 195 ax.plot(time_axis.numpy(), waveform[c].cpu().numpy()) 196 ax.set_ylabel(f"Ch {c + 1}") 197 ax.grid(True, alpha=0.3) 198 fig.suptitle(title) 199 axes[-1].set_xlabel("Time [s]") 200 fig.tight_layout(rect=(0, 0, 1, 0.96)) 201 plt.show(block=False) 202 return fig
Plot the time-domain waveform of an Audio object and return the Figure.
The plot is automatically scaled for readability using a context scale
(similar to seaborn). Use fast=True to lightly decimate the signal for
quicker rendering on very long waveforms.
Arguments:
- audio (Audio): Input audio containing
.waveform(shape[C, T]) and.sampling_rate. - title (str, optional): Figure title. Defaults to
"Waveform". - fast (bool, optional): If
True, plots a 10× downsampled view for speed. Defaults toFalse. - context (_Context, optional): Size preset or numeric scale. Accepted values:
"auto"(detect from screen),"small","medium","large",- or a float scale factor (e.g.,
1.25). Defaults to"auto".
- figsize (tuple[float, float] | None, optional): Base
(width, height)in inches before context scaling. Defaults to(12, 2×channels).
Returns:
matplotlib.figure.Figure: The created figure (also displayed).
Example:
>>> from pathlib import Path >>> from senselab.audio.data_structures import Audio >>> a1 = Audio(filepath=Path("sample1.wav").resolve()) >>> fig = plot_waveform(a1, title="Sample 1", fast=True, context="medium") >>> # fig.savefig("waveform.png") # optional
205def plot_specgram( 206 audio: Audio, 207 mel_scale: bool = False, 208 title: str = "Spectrogram", 209 *, 210 context: _Context = "auto", 211 figsize: Tuple[float, float] | None = None, 212 **spect_kwargs: Any, # noqa: ANN401 213) -> Figure: 214 """Plot a (mel-)spectrogram for a **mono** `Audio` object and return the Figure. 215 216 Internally calls senselab's torchaudio-based extractors: 217 `extract_spectrogram_from_audios` or `extract_mel_spectrogram_from_audios`. 218 The function expects a 2D spectrogram `[freq_bins, time_frames]`; multi-channel 219 inputs should be downmixed beforehand. 220 221 Args: 222 audio (Audio): 223 Input **mono** audio. If multi-channel, downmix first. 224 mel_scale (bool, optional): 225 If `True`, plots a mel spectrogram; otherwise linear frequency. Defaults to `False`. 226 title (str, optional): 227 Figure title. Defaults to `"Spectrogram"`. 228 context (_Context, optional): 229 Size preset or numeric scale (`"auto"`, `"small"`, `"medium"`, `"large"`, or float). 230 Defaults to `"auto"`. 231 figsize (tuple[float, float] | None, optional): 232 Base `(width, height)` in inches **before** context scaling. Defaults to `(10, 4)`. 233 **spect_kwargs: 234 Passed to the underlying extractor (e.g., `n_fft=1024`, `hop_length=256`, 235 `n_mels=80`, `win_length=1024`, `f_min=0`, `f_max=None`). 236 237 Returns: 238 matplotlib.figure.Figure: The created figure (also displayed). 239 240 Raises: 241 ValueError: If spectrogram extraction fails, contains NaNs, or the result is not 2D. 242 243 Example (linear spectrogram): 244 >>> from pathlib import Path 245 >>> from senselab.audio.data_structures import Audio 246 >>> a1 = Audio(filepath=Path("sample1.wav").resolve()) 247 >>> fig = plot_specgram(a1, mel_scale=False, n_fft=1024, hop_length=256) 248 >>> # fig.savefig("spec.png") 249 250 Example (mel spectrogram): 251 >>> from pathlib import Path 252 >>> from senselab.audio.data_structures import Audio 253 >>> a1 = Audio(filepath=Path("sample1.wav").resolve()) 254 >>> fig = plot_specgram(a1, mel_scale=True, n_mels=80, n_fft=1024, hop_length=256) 255 """ 256 # Extract the spectrogram 257 if mel_scale: 258 from senselab.audio.tasks.features_extraction.torchaudio import ( 259 extract_mel_spectrogram_from_audios, 260 ) 261 262 spectrogram = extract_mel_spectrogram_from_audios([audio], **spect_kwargs)[0]["mel_spectrogram"] 263 y_axis_label = "Mel frequency (bins)" 264 else: 265 from senselab.audio.tasks.features_extraction.torchaudio import ( 266 extract_spectrogram_from_audios, 267 ) 268 269 spectrogram = extract_spectrogram_from_audios([audio], **spect_kwargs)[0]["spectrogram"] 270 y_axis_label = "Frequency [Hz]" 271 272 # ---- Guard against invalid/short-audio outputs (must be exactly this phrase) 273 if not torch.is_tensor(spectrogram): 274 raise ValueError("Spectrogram extraction failed") 275 if spectrogram.ndim == 0 or spectrogram.numel() == 0: 276 raise ValueError("Spectrogram extraction failed") 277 if spectrogram.dtype.is_floating_point and torch.isnan(spectrogram).any(): 278 raise ValueError("Spectrogram extraction failed") 279 280 if spectrogram.dim() != 2: 281 raise ValueError( 282 "Spectrogram must be a 2D tensor. Got shape: {}".format(spectrogram.shape), 283 "Please make sure the input audio is mono.", 284 ) 285 286 # Determine time and frequency scale 287 # num_frames = spectrogram.size(1) 288 num_freq_bins = spectrogram.size(0) 289 290 # Time axis in seconds 291 duration_sec = audio.waveform.size(-1) / audio.sampling_rate 292 time_axis_start = 0.0 293 time_axis_end = float(duration_sec) 294 295 # Frequency axis 296 if mel_scale: 297 freq_start, freq_end = 0.0, float(num_freq_bins - 1) 298 else: 299 freq_start, freq_end = 0.0, float(audio.sampling_rate / 2) 300 301 scale = _resolve_scale(context) 302 rc = _rc_for_scale(scale) 303 if figsize is None: 304 base = (10.0, 4.0) 305 else: 306 base = figsize 307 scaled_size = (base[0] * scale, base[1] * scale) 308 309 with rc_context(rc): 310 fig = plt.figure(figsize=scaled_size) 311 plt.imshow( 312 _power_to_db(spectrogram.cpu().numpy()), 313 aspect="auto", 314 origin="lower", 315 extent=(time_axis_start, time_axis_end, freq_start, freq_end), 316 cmap="viridis", 317 ) 318 plt.colorbar(label="Magnitude (dB)") 319 plt.title(title) 320 plt.ylabel(y_axis_label) 321 plt.xlabel("Time [s]") 322 plt.tight_layout() 323 plt.show(block=False) 324 return fig
Plot a (mel-)spectrogram for a mono Audio object and return the Figure.
Internally calls senselab's torchaudio-based extractors:
extract_spectrogram_from_audios or extract_mel_spectrogram_from_audios.
The function expects a 2D spectrogram [freq_bins, time_frames]; multi-channel
inputs should be downmixed beforehand.
Arguments:
- audio (Audio): Input mono audio. If multi-channel, downmix first.
- mel_scale (bool, optional): If
True, plots a mel spectrogram; otherwise linear frequency. Defaults toFalse. - title (str, optional): Figure title. Defaults to
"Spectrogram". - context (_Context, optional): Size preset or numeric scale (
"auto","small","medium","large", or float). Defaults to"auto". - figsize (tuple[float, float] | None, optional): Base
(width, height)in inches before context scaling. Defaults to(10, 4). - **spect_kwargs: Passed to the underlying extractor (e.g.,
n_fft=1024,hop_length=256,n_mels=80,win_length=1024,f_min=0,f_max=None).
Returns:
matplotlib.figure.Figure: The created figure (also displayed).
Raises:
- ValueError: If spectrogram extraction fails, contains NaNs, or the result is not 2D.
Example (linear spectrogram):
from pathlib import Path from senselab.audio.data_structures import Audio a1 = Audio(filepath=Path("sample1.wav").resolve()) fig = plot_specgram(a1, mel_scale=False, n_fft=1024, hop_length=256)
fig.savefig("spec.png")
Example (mel spectrogram):
from pathlib import Path from senselab.audio.data_structures import Audio a1 = Audio(filepath=Path("sample1.wav").resolve()) fig = plot_specgram(a1, mel_scale=True, n_mels=80, n_fft=1024, hop_length=256)
327def plot_waveform_and_specgram( 328 audio: Audio, 329 *, 330 title: str = "Waveform + Spectrogram", 331 mel_scale: bool = False, 332 fast_wave: bool = False, 333 context: "_Context" = "auto", 334 figsize: Tuple[float, float] | None = None, 335 **spect_kwargs: Any, # noqa: ANN401 # forwarded to spectrogram extraction 336) -> Figure: 337 """Stacked layout: waveform (top) and **mono** spectrogram (bottom). Returns the Figure. 338 339 The waveform can be drawn in a faster, lightly decimated mode for long signals. 340 Spectrogram extraction is delegated to senselab's torchaudio-based utilities 341 and requires mono input. 342 343 Args: 344 audio (Audio): 345 Input audio. **Spectrogram requires mono**; downmix multi-channel first. 346 title (str, optional): 347 Overall figure title. Defaults to `"Waveform + Spectrogram"`. 348 mel_scale (bool, optional): 349 If `True`, bottom panel is a mel spectrogram; otherwise linear frequency. Defaults to `False`. 350 fast_wave (bool, optional): 351 If `True`, waveform panel is downsampled for speed. Defaults to `False`. 352 context (_Context, optional): 353 Size preset or numeric scale (`"auto"`, `"small"`, `"medium"`, `"large"`, or float). 354 Defaults to `"auto"`. 355 figsize (tuple[float, float] | None, optional): 356 Base `(width, height)` in inches **before** context scaling. Defaults to a balanced height. 357 **spect_kwargs: 358 Forwarded to the underlying spectrogram extractor (e.g., `n_fft`, `hop_length`, `n_mels`). 359 360 Returns: 361 matplotlib.figure.Figure: The created figure (also displayed). 362 363 Raises: 364 ValueError: If audio is not mono, or spectrogram extraction fails. 365 366 Example: 367 >>> from pathlib import Path 368 >>> from senselab.audio.data_structures import Audio 369 >>> a1 = Audio(filepath=Path("sample1.wav").resolve()) 370 >>> fig = plot_waveform_and_specgram( 371 ... a1, 372 ... mel_scale=True, 373 ... fast_wave=True, 374 ... context="large", 375 ... n_fft=1024, 376 ... hop_length=256, 377 ... n_mels=80, 378 ... ) 379 >>> # fig.savefig("wave_plus_mel.png") 380 """ 381 # ---- Core timing info from ORIGINAL (non-decimated) data 382 sr = audio.sampling_rate 383 orig_num_frames = int(audio.waveform.size(-1)) 384 duration_sec = orig_num_frames / sr 385 t0, t1 = 0.0, float(duration_sec) 386 387 # ---- Prepare waveform (optionally decimated for speed) 388 waveform = audio.waveform 389 if fast_wave: 390 waveform = waveform[..., ::10] # decimate samples 391 num_channels, num_frames = waveform.shape 392 time_axis = np.linspace(0.0, duration_sec, num_frames, endpoint=False) 393 394 # ---- Guardrail: spectrogram plotting requires mono input 395 if audio.waveform.shape[0] != 1: 396 raise ValueError("Only mono audio is supported for spectrogram plotting") 397 398 # ---- Spectrogram (2D tensor: [freq_bins, time_frames]) 399 if mel_scale: 400 from senselab.audio.tasks.features_extraction.torchaudio import ( 401 extract_mel_spectrogram_from_audios, 402 ) 403 404 spec = extract_mel_spectrogram_from_audios([audio], **spect_kwargs)[0]["mel_spectrogram"] 405 ylab = "Mel bins" 406 f0, f1 = 0.0, float(spec.size(0) - 1) if torch.is_tensor(spec) and spec.ndim >= 1 else (0.0, 0.0) 407 spec_title = "Mel Spectrogram" 408 else: 409 from senselab.audio.tasks.features_extraction.torchaudio import ( 410 extract_spectrogram_from_audios, 411 ) 412 413 spec = extract_spectrogram_from_audios([audio], **spect_kwargs)[0]["spectrogram"] 414 ylab = "Frequency [Hz]" 415 f0, f1 = 0.0, float(sr / 2) 416 spec_title = "Spectrogram" 417 418 # ---- Guardrails for short/invalid outputs (exact phrase expected by tests) 419 if not torch.is_tensor(spec): 420 raise ValueError("Spectrogram extraction failed") 421 if spec.ndim == 0 or spec.numel() == 0: 422 raise ValueError("Spectrogram extraction failed") 423 if spec.dtype.is_floating_point and torch.isnan(spec).any(): 424 raise ValueError("Spectrogram extraction failed") 425 426 # We require a 2D (F x T) spectrogram. Anything else → fail (don’t auto-pick channels). 427 if spec.ndim != 2: 428 raise ValueError("Spectrogram extraction failed") 429 430 # ---- Layout & context 431 scale = _resolve_scale(context) 432 rc = _rc_for_scale(scale) 433 if figsize is None: 434 base_h = max(2.0, 0.9 * num_channels) + 4.0 # waveform height + spectrogram 435 base = (12.0, base_h) 436 else: 437 base = figsize 438 size = (base[0] * scale, base[1] * scale) 439 440 with rc_context(rc): 441 fig, (ax_wav, ax_spec) = plt.subplots(2, 1, figsize=size, sharex=True, gridspec_kw={"height_ratios": [1, 2]}) 442 443 # ---- Waveform (top) 444 if num_channels == 1: 445 ax_wav.plot(time_axis, waveform[0].cpu().numpy()) 446 ax_wav.set_ylabel("Amp") 447 else: 448 for c in range(num_channels): 449 ax_wav.plot(time_axis, waveform[c].cpu().numpy(), alpha=0.9 if c == 0 else 0.7) 450 ax_wav.set_ylabel("Amp (multi-ch)") 451 ax_wav.grid(True, alpha=0.3) 452 ax_wav.set_title("Waveform") 453 454 # ---- Spectrogram (bottom) 455 im = ax_spec.imshow( 456 _power_to_db(spec.cpu().numpy()), 457 aspect="auto", 458 origin="lower", 459 extent=(t0, t1, f0, f1), 460 cmap="viridis", 461 ) 462 ax_spec.set_ylabel(ylab) 463 ax_spec.set_xlabel("Time [s]") 464 ax_spec.set_title(spec_title) 465 466 # Keep both axes aligned in time 467 ax_wav.set_xlim(t0, t1) 468 ax_spec.set_xlim(t0, t1) 469 470 # ---- Horizontal colorbar below the spectrogram 471 divider = make_axes_locatable(ax_spec) 472 cax = divider.append_axes("bottom", size="5%", pad=0.6) 473 cbar = fig.colorbar(im, cax=cax, orientation="horizontal") 474 cbar.set_label("Magnitude (dB)") 475 476 fig.suptitle(title) 477 fig.tight_layout(rect=(0, 0, 1, 0.96)) 478 plt.show(block=False) 479 return fig
Stacked layout: waveform (top) and mono spectrogram (bottom). Returns the Figure.
The waveform can be drawn in a faster, lightly decimated mode for long signals. Spectrogram extraction is delegated to senselab's torchaudio-based utilities and requires mono input.
Arguments:
- audio (Audio): Input audio. Spectrogram requires mono; downmix multi-channel first.
- title (str, optional): Overall figure title. Defaults to
"Waveform + Spectrogram". - mel_scale (bool, optional): If
True, bottom panel is a mel spectrogram; otherwise linear frequency. Defaults toFalse. - fast_wave (bool, optional): If
True, waveform panel is downsampled for speed. Defaults toFalse. - context (_Context, optional): Size preset or numeric scale (
"auto","small","medium","large", or float). Defaults to"auto". - figsize (tuple[float, float] | None, optional): Base
(width, height)in inches before context scaling. Defaults to a balanced height. - **spect_kwargs: Forwarded to the underlying spectrogram extractor (e.g.,
n_fft,hop_length,n_mels).
Returns:
matplotlib.figure.Figure: The created figure (also displayed).
Raises:
- ValueError: If audio is not mono, or spectrogram extraction fails.
Example:
>>> from pathlib import Path >>> from senselab.audio.data_structures import Audio >>> a1 = Audio(filepath=Path("sample1.wav").resolve()) >>> fig = plot_waveform_and_specgram( ... a1, ... mel_scale=True, ... fast_wave=True, ... context="large", ... n_fft=1024, ... hop_length=256, ... n_mels=80, ... ) >>> # fig.savefig("wave_plus_mel.png")
482def play_audio(audio: Audio) -> None: 483 """Play an `Audio` object inline (Jupyter/IPython), supporting 1–2 channels. 484 485 Uses `IPython.display.Audio` to render audio widgets in notebooks. For more 486 than two channels, downmix first. 487 488 Args: 489 audio (Audio): 490 Input audio to play (mono or stereo). Sampling rate is preserved. 491 492 Raises: 493 ValueError: If the waveform has more than 2 channels. 494 495 Example: 496 >>> from pathlib import Path 497 >>> from senselab.audio.data_structures import Audio 498 >>> a1 = Audio(filepath=Path("sample1.wav").resolve()) 499 >>> play_audio(a1) 500 """ 501 from IPython.display import Audio as DisplayAudio 502 from IPython.display import display 503 504 waveform = audio.waveform.cpu().numpy() 505 sample_rate = audio.sampling_rate 506 507 num_channels = waveform.shape[0] 508 if num_channels == 1: 509 display(DisplayAudio(waveform[0], rate=sample_rate)) 510 elif num_channels == 2: 511 display(DisplayAudio((waveform[0], waveform[1]), rate=sample_rate)) 512 else: 513 raise ValueError("Waveform with more than 2 channels is not supported.")
Play an Audio object inline (Jupyter/IPython), supporting 1–2 channels.
Uses IPython.display.Audio to render audio widgets in notebooks. For more
than two channels, downmix first.
Arguments:
- audio (Audio): Input audio to play (mono or stereo). Sampling rate is preserved.
Raises:
- ValueError: If the waveform has more than 2 channels.
Example:
>>> from pathlib import Path >>> from senselab.audio.data_structures import Audio >>> a1 = Audio(filepath=Path("sample1.wav").resolve()) >>> play_audio(a1)