# ------------------------------------------------------------------------ #
# Copyright 2022 SPTK Working Group                                        #
#                                                                          #
# Licensed under the Apache License, Version 2.0 (the "License");          #
# you may not use this file except in compliance with the License.         #
# You may obtain a copy of the License at                                  #
#                                                                          #
#     http://www.apache.org/licenses/LICENSE-2.0                           #
#                                                                          #
# Unless required by applicable law or agreed to in writing, software      #
# distributed under the License is distributed on an "AS IS" BASIS,        #
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. #
# See the License for the specific language governing permissions and      #
# limitations under the License.                                           #
# ------------------------------------------------------------------------ #
# ------------------------------------------------------------------------ #
# Copyright (c) 2013--2022, librosa development team.                      #
#                                                                          #
# Permission to use, copy, modify, and/or distribute this software for any #
# purpose with or without fee is hereby granted, provided that the above   #
# copyright notice and this permission notice appear in all copies.        #
#                                                                          #
# THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES #
# WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF         #
# MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR  #
# ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES   #
# WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN    #
# ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF  #
# OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.           #
# ------------------------------------------------------------------------ #

import numpy as np
import torch
import torchaudio
from torch import nn

from ..third_party.librosa import (
    cqt_frequencies,
    early_downsample_count,
    et_relative_bw,
    relative_bandwidth,
    vqt_filter_fft,
    wavelet_lengths,
)
from ..utils.private import Lambda, get_resample_params, numpy_to_torch
from .base import BaseNonFunctionalModule
from .stft import ShortTimeFourierTransform as STFT

class ConstantQTransform(BaseNonFunctionalModule):
    """Perform constant-Q transform based on the librosa implementation.

    Parameters
    ----------
    frame_period : int >= 1
        The frame period in samples, :math:`P`.

    sample_rate : int >= 1
        The sample rate in Hz.

    f_min : float > 0
        The minimum center frequency in Hz.

    n_bin : int >= 1
        The number of CQ-bins, :math:`K`.

    n_bin_per_octave : int >= 1
        The number of bins per octave, :math:`B`.

    tuning : float
        The tuning offset in fractions of a bin.

    filter_scale : float > 0
        The filter scale factor.

    norm : float
        The type of norm used in the basis function normalization.

    sparsity : float in [0, 1)
        The sparsification factor.

    window : str
        The window function for the basis.

    scale : bool
        If True, scale the CQT response by the square root of the length of
        each filter.

    res_type : ['kaiser_best', 'kaiser_fast'] or None
        The resampling type.

    **kwargs : additional keyword arguments
        See `torchaudio.transforms.Resample
        <https://pytorch.org/audio/main/generated/torchaudio.transforms.Resample.html>`_.

    """
    def __init__(
        self,
        frame_period: int,
        sample_rate: int,
        *,
        f_min: float = 32.7,
        n_bin: int = 84,
        n_bin_per_octave: int = 12,
        tuning: float = 0,
        filter_scale: float = 1,
        norm: float = 1,
        sparsity: float = 1e-2,
        window: str = "hann",
        scale: bool = True,
        res_type: str | None = "kaiser_best",
        **kwargs,
    ) -> None:
        super().__init__()

        if frame_period <= 0:
            raise ValueError("frame_period must be positive.")

        K = n_bin
        B = n_bin_per_octave
        n_octave = int(np.ceil(K / B))
        n_filter = min(B, K)

        freqs = cqt_frequencies(
            n_bins=K,
            fmin=f_min,
            bins_per_octave=B,
            tuning=tuning,
        )
        if K == 1:
            alpha = et_relative_bw(B)
        else:
            alpha = relative_bandwidth(freqs=freqs)
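
        # The center frequencies are geometrically spaced: librosa's
        # cqt_frequencies gives f_k = f_min * 2^((k + tuning) / B) for
        # k = 0, ..., K - 1, so each octave spans B consecutive bins. The
        # relative bandwidth alpha is derived from the spacing of adjacent
        # center frequencies; for a single bin, the equal-tempered value is
        # used instead.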
        lengths, filter_cutoff = wavelet_lengths(
            freqs=freqs,
            sr=sample_rate,
            window=window,
            filter_scale=filter_scale,
            alpha=alpha,
        )
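
        # Each wavelet's support is roughly Q * sr / f_k samples, where
        # Q = filter_scale / alpha, and filter_cutoff estimates the highest
        # frequency with significant filter energy; it bounds how far the
        # input may be safely downsampled below.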
        early_downsample = []
        downsample_count = early_downsample_count(
            sample_rate * 0.5, filter_cutoff, frame_period, n_octave
        )
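
        # If the top filters occupy only a small fraction of the Nyquist
        # band, the input can be decimated by a power of two up front, as in
        # librosa's early downsampling; the gain applied after the resampler
        # compensates for the energy removed by decimation.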
        if res_type is not None:
            kwargs.update(get_resample_params(res_type))
        if 0 < downsample_count:
            downsample_factor = 2**downsample_count
            early_downsample.append(
                torchaudio.transforms.Resample(
                    orig_freq=downsample_factor,
                    new_freq=1,
                    dtype=torch.get_default_dtype(),
                    **kwargs,
                )
            )
            if scale:
                downsample_scale = np.sqrt(downsample_factor)
            else:
                downsample_scale = downsample_factor
            early_downsample.append(Lambda(lambda x: x * downsample_scale))

            # Update frame period and sample rate.
            frame_period //= downsample_factor
            sample_rate /= downsample_factor

            # Update lengths for scaling.
            if scale:
                lengths, _ = wavelet_lengths(
                    freqs=freqs,
                    sr=sample_rate,
                    window=window,
                    filter_scale=filter_scale,
                    alpha=alpha,
                )
        self.early_downsample = nn.Sequential(*early_downsample)
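
        # With scale=True, bin k is normalized by 1 / sqrt(N_k), where N_k is
        # the wavelet length computed above (analogous to an orthonormal
        # transform); otherwise a flat scale of one is used.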
        if scale:
            cqt_scale = np.reciprocal(np.sqrt(lengths))
        else:
            cqt_scale = np.ones(K)
        self.register_buffer("cqt_scale", numpy_to_torch(cqt_scale))

        fp = [frame_period]
        sr = [sample_rate]
        for i in range(n_octave - 1):
            if fp[i] % 2 == 0:
                fp.append(fp[i] // 2)
                sr.append(sr[i] * 0.5)
            else:
                fp.append(fp[i])
                sr.append(sr[i])
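
        # Build one filterbank per octave, from the highest octave (i = 0)
        # downward. Each basis is constructed at that octave's rate sr[i] and
        # rescaled by sqrt(sample_rate / sr[i]) so all octaves share a common
        # reference level; between octaves the signal is resampled by two
        # (with a sqrt(2) gain to preserve energy) whenever the current frame
        # period is even.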
        transforms = []
        resamplers = []
        for i in range(n_octave):
            sl = slice(-n_filter * (i + 1), None if i == 0 else (-n_filter * i))
            fft_basis, fft_length, _ = vqt_filter_fft(
                sr[i],
                freqs[sl],
                filter_scale,
                norm,
                sparsity,
                window=window,
                alpha=alpha[sl],
            )
            fft_basis *= np.sqrt(sample_rate / sr[i])
            self.register_buffer(
                f"fft_basis_{i}", numpy_to_torch(fft_basis.todense()).T
            )
            transforms.append(
                STFT(
                    frame_length=fft_length,
                    frame_period=fp[i],
                    fft_length=fft_length,
                    center=True,
                    window="rectangular",
                    norm="none",
                    eps=0,
                    out_format="complex",
                )
            )
            if fp[i] % 2 == 0:
                resample_scale = np.sqrt(2)
                resamplers.append(
                    nn.Sequential(
                        torchaudio.transforms.Resample(
                            orig_freq=2,
                            new_freq=1,
                            dtype=torch.get_default_dtype(),
                            **kwargs,
                        ),
                        Lambda(lambda x: x * resample_scale),
                    )
                )
            else:
                resamplers.append(Lambda(lambda x: x))
        self.transforms = nn.ModuleList(transforms)
        self.resamplers = nn.ModuleList(resamplers)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Compute constant-Q transform.

        Parameters
        ----------
        x : Tensor [shape=(..., T)]
            The input waveform.

        Returns
        -------
        out : Tensor [shape=(..., T/P, K)]
            The CQT complex output.

        Examples
        --------
        >>> x = diffsptk.sin(99)
        >>> cqt = diffsptk.CQT(100, 8000, n_bin=4)
        >>> c = cqt(x).abs()
        >>> c
        tensor([[1.1259, 1.2069, 1.3008, 1.3885]])

        """
        x = self.early_downsample(x)
        cs = []
        for i in range(len(self.transforms)):
            X = self.transforms[i](x)
            W = getattr(self, f"fft_basis_{i}")
            cs.append(torch.matmul(X, W))
            if i != len(self.transforms) - 1:
                x = self.resamplers[i](x)
        c = self._trim_stack(len(self.cqt_scale), cs) * self.cqt_scale
        return c
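
    # Stacking note: cs[0] holds the highest octave, so its bins fill the
    # tail of the bin axis and `end` walks backward toward bin 0; all octave
    # responses are first trimmed to the shortest frame count, max_col.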
    @staticmethod
    def _trim_stack(n_bin: int, cqt_response: list[torch.Tensor]) -> torch.Tensor:
        max_col = min(c.shape[-2] for c in cqt_response)
        shape = list(cqt_response[0].shape)
        shape[-2] = max_col
        shape[-1] = n_bin
        output = torch.empty(
            shape, dtype=cqt_response[0].dtype, device=cqt_response[0].device
        )
        end = n_bin
        for c in cqt_response:
            n_octave = c.shape[-1]
            if end < n_octave:
                output[..., :end] = c[..., :max_col, -end:]
            else:
                output[..., end - n_octave : end] = c[..., :max_col, :]
            end -= n_octave
        return output
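

# A minimal usage sketch, mirroring the Examples section in forward() (the
# class is exposed as diffsptk.CQT):
#
#     import diffsptk
#     x = diffsptk.sin(99)
#     cqt = diffsptk.CQT(100, 8000, n_bin=4)
#     c = cqt(x)  # complex tensor of shape (T/P, K) = (1, 4)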