Source code for diffsptk.modules.mgcep

# ------------------------------------------------------------------------ #
# Copyright 2022 SPTK Working Group                                        #
#                                                                          #
# Licensed under the Apache License, Version 2.0 (the "License");          #
# you may not use this file except in compliance with the License.         #
# You may obtain a copy of the License at                                  #
#                                                                          #
#     http://www.apache.org/licenses/LICENSE-2.0                           #
#                                                                          #
# Unless required by applicable law or agreed to in writing, software      #
# distributed under the License is distributed on an "AS IS" BASIS,        #
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. #
# See the License for the specific language governing permissions and      #
# limitations under the License.                                           #
# ------------------------------------------------------------------------ #

import torch
from torch import nn

from ..utils.private import check_size, get_gamma, hankel, symmetric_toeplitz, to
from .b2mc import MLSADigitalFilterCoefficientsToMelCepstrum
from .base import BaseNonFunctionalModule
from .gnorm import GeneralizedCepstrumGainNormalization
from .ignorm import GeneralizedCepstrumInverseGainNormalization
from .mc2b import MelCepstrumToMLSADigitalFilterCoefficients
from .mcep import MelCepstralAnalysis
from .mgc2mgc import MelGeneralizedCepstrumToMelGeneralizedCepstrum


class MelGeneralizedCepstralAnalysis(BaseNonFunctionalModule):
    """See `this page <https://sp-nitech.github.io/sptk/latest/main/mgcep.html>`_
    for details. Note that the current implementation does not use the efficient
    Toeplitz-plus-Hankel system solver.

    Parameters
    ----------
    fft_length : int >= 2M
        The number of FFT bins, :math:`L`.

    cep_order : int >= 0
        The order of the mel-cepstrum, :math:`M`.

    alpha : float in (-1, 1)
        The frequency warping factor, :math:`\\alpha`.

    gamma : float in [-1, 0]
        The gamma parameter, :math:`\\gamma`.

    c : int >= 1 or None
        The number of stages.

    n_iter : int >= 0
        The number of iterations.

    device : torch.device or None
        The device of this module.

    dtype : torch.dtype or None
        The data type of this module.

    """

    def __init__(
        self,
        *,
        fft_length: int,
        cep_order: int,
        alpha: float = 0,
        gamma: float = 0,
        c: int | None = None,
        n_iter: int = 0,
        device: torch.device | None = None,
        dtype: torch.dtype | None = None,
    ) -> None:
        super().__init__()

        # get_gamma derives gamma from the number of stages c when c is given.
        gamma = get_gamma(gamma, c)

        if fft_length <= 1:
            raise ValueError("fft_length must be greater than 1.")
        if cep_order < 0:
            raise ValueError("cep_order must be non-negative.")
        if fft_length < 2 * cep_order:
            raise ValueError(
                "cep_order must be less than or equal to fft_length // 2."
            )
        if 1 <= abs(alpha):
            raise ValueError("alpha must be in (-1, 1).")
        if gamma < -1 or 0 < gamma:
            raise ValueError("gamma must be in [-1, 0].")
        if n_iter < 0:
            raise ValueError("n_iter must be non-negative.")

        self.fft_length = fft_length
        self.cep_order = cep_order
        self.gamma = gamma
        self.n_iter = n_iter

        if gamma == 0:
            self.mcep = MelCepstralAnalysis(
                fft_length=fft_length,
                cep_order=cep_order,
                alpha=alpha,
                n_iter=n_iter,
                device=device,
                dtype=dtype,
            )
        else:
            self.cfreqt = CoefficientsFrequencyTransform(
                cep_order, fft_length - 1, -alpha, device=device, dtype=dtype
            )
            self.pfreqt = CoefficientsFrequencyTransform(
                fft_length - 1, 2 * cep_order, alpha, device=device, dtype=dtype
            )
            self.rfreqt = CoefficientsFrequencyTransform(
                fft_length - 1, cep_order, alpha, device=device, dtype=dtype
            )
            self.ptrans = PTransform(2 * cep_order, alpha, device=device, dtype=dtype)
            self.qtrans = QTransform(2 * cep_order, alpha, device=device, dtype=dtype)
            self.b2b = nn.Sequential(
                GeneralizedCepstrumInverseGainNormalization(cep_order, -1),
                MLSADigitalFilterCoefficientsToMelCepstrum(
                    cep_order, alpha, device=device, dtype=dtype
                ),
                MelGeneralizedCepstrumToMelGeneralizedCepstrum(
                    cep_order,
                    cep_order,
                    in_gamma=-1,
                    out_gamma=gamma,
                    device=device,
                    dtype=dtype,
                ),
                MelCepstrumToMLSADigitalFilterCoefficients(
                    cep_order, alpha, device=device, dtype=dtype
                ),
                GeneralizedCepstrumGainNormalization(cep_order, gamma),
            )
            self.b2mc = nn.Sequential(
                GeneralizedCepstrumInverseGainNormalization(cep_order, gamma),
                MLSADigitalFilterCoefficientsToMelCepstrum(
                    cep_order, alpha, device=device, dtype=dtype
                ),
            )
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Estimate mel-generalized cepstrum from spectrum.

        Parameters
        ----------
        x : Tensor [shape=(..., L/2+1)]
            The power spectrum.

        Returns
        -------
        out : Tensor [shape=(..., M+1)]
            The mel-generalized cepstrum.

        Examples
        --------
        >>> import diffsptk
        >>> stft = diffsptk.STFT(frame_length=10, frame_period=10, fft_length=16)
        >>> mgcep = diffsptk.MelGeneralizedCepstralAnalysis(
        ...     fft_length=16, cep_order=3, alpha=0.1, gamma=0.0, n_iter=1
        ... )
        >>> x = diffsptk.ramp(19)
        >>> mc = mgcep(stft(x))
        >>> mc
        tensor([[-0.8851,  0.7917, -0.1737,  0.0175],
                [-0.3522,  4.4222, -1.0882, -0.0510]])

        """
        if self.gamma == 0:
            mc = self.mcep(x)
            return mc

        M = self.cep_order
        H = self.fft_length // 2
        check_size(x.size(-1), H + 1, "dimension of spectrum")

        # One Newton update of the gain-normalized filter coefficients b.
        def newton(gamma, b1):
            def epsilon(gamma, r, b):
                eps = r[..., 0] + gamma * (r[..., 1:] * b).sum(-1)
                return eps

            b0 = torch.zeros(*b1.shape[:-1], 1, device=b1.device, dtype=b1.dtype)
            b = torch.cat((b0, b1), dim=-1)
            c = self.cfreqt(b)
            C = torch.fft.rfft(c, n=self.fft_length)

            if gamma == -1:
                p_re = x
            else:
                X = 1 + gamma * C.real
                Y = gamma * C.imag
                XX = X * X
                YY = Y * Y
                D = XX + YY
                E = torch.pow(D, -1 / gamma)
                p = x * E / D
                p_re = p
                q = p / D
                q_re = q * (XX - YY)
                q_im = q * (2 * X * Y)
                r_re = p * X
                r_im = p * Y

            p = self.pfreqt(torch.fft.irfft(p_re))
            if gamma == -1:
                q = p
                r = p[..., : M + 1]
            else:
                q = self.pfreqt(torch.fft.irfft(torch.complex(q_re, q_im)))
                r = self.rfreqt(torch.fft.irfft(torch.complex(r_re, r_im)))
            p = self.ptrans(p)
            q = self.qtrans(q)

            if gamma != -1:
                eps = epsilon(gamma, r, b1)

            # Solve the Toeplitz-plus-Hankel system directly (see the note in
            # the class docstring).
            pt = p[..., :M]
            qt = q[..., 2:] * (1 + gamma)
            rt = r[..., 1:]
            R = symmetric_toeplitz(pt)
            Q = hankel(qt)
            gradient = torch.linalg.solve(R + Q, rt)
            b1 = b1 + gradient

            if gamma == -1:
                eps = epsilon(gamma, r, b1)
            b0 = torch.sqrt(eps).unsqueeze(-1)
            return b0, b1

        # Initialize via the gamma = -1 solution, then refine at the target
        # gamma.
        b1 = torch.zeros(*x.shape[:-1], M, device=x.device, dtype=x.dtype)
        b0, b1 = newton(-1, b1)
        if self.gamma != -1:
            b = torch.cat((b0, b1), dim=-1)
            b = self.b2b(b)
            _, b1 = torch.split(b, [1, M], dim=-1)
            for _ in range(self.n_iter):
                b0, b1 = newton(self.gamma, b1)
        b = torch.cat((b0, b1), dim=-1)
        mc = self.b2mc(b)
        return mc
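# The Newton step in forward() solves the M-dimensional linear system
# (R + Q) g = r, where R = symmetric_toeplitz(p[..., :M]) is symmetric
# Toeplitz and Q = hankel(q[..., 2:] * (1 + gamma)) is Hankel. The sketch
# below mirrors the assumed semantics of those private helpers using plain
# torch; the function names and example values are illustrative only, not
# part of the library.


def _symmetric_toeplitz_sketch(v: torch.Tensor) -> torch.Tensor:
    # Entry (i, j) is v[..., |i - j|].
    n = v.size(-1)
    idx = (torch.arange(n).unsqueeze(0) - torch.arange(n).unsqueeze(1)).abs()
    return v[..., idx]


def _hankel_sketch(v: torch.Tensor) -> torch.Tensor:
    # v has odd length 2n - 1; entry (i, j) is v[..., i + j].
    n = (v.size(-1) + 1) // 2
    idx = torch.arange(n).unsqueeze(0) + torch.arange(n).unsqueeze(1)
    return v[..., idx]


# For example, with M = 3:
#     R = _symmetric_toeplitz_sketch(torch.tensor([4.0, 1.0, 0.5]))
#     Q = _hankel_sketch(torch.tensor([0.3, 0.2, 0.1, 0.0, 0.0]))
#     g = torch.linalg.solve(R + Q, torch.tensor([1.0, 0.5, 0.25]))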
class CoefficientsFrequencyTransform(nn.Module):
    def __init__(
        self,
        in_order: int,
        out_order: int,
        alpha: float,
        device: torch.device | None = None,
        dtype: torch.dtype | None = None,
    ) -> None:
        super().__init__()

        beta = 1 - alpha * alpha
        L1 = in_order + 1
        L2 = out_order + 1

        # Make transform matrix.
        A = torch.zeros((L2, L1), device=device, dtype=torch.double)
        A[0, 0] = 1
        if 1 < L2 and 1 < L1:
            A[1, 1:] = alpha ** torch.arange(L1 - 1, dtype=torch.double) * beta
        for i in range(2, L2):
            i1 = i - 1
            for j in range(1, L1):
                j1 = j - 1
                A[i, j] = A[i1, j1] + alpha * (A[i, j1] - A[i1, j])

        self.register_buffer("A", to(A.T, dtype=dtype))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        y = torch.matmul(x, self.A)
        return y


class PTransform(nn.Module):
    def __init__(
        self,
        order: int,
        alpha: float,
        device: torch.device | None = None,
        dtype: torch.dtype | None = None,
    ) -> None:
        super().__init__()

        # Make transform matrix.
        A = torch.eye(order + 1, device=device, dtype=torch.double)
        A[:, 1:].fill_diagonal_(alpha)
        A[0, 0] -= alpha * alpha
        A[0, 1] += alpha
        A[-1, -1] += alpha
        self.register_buffer("A", to(A.T, dtype=dtype))

    def forward(self, p: torch.Tensor) -> torch.Tensor:
        p = torch.matmul(p, self.A)
        return p


class QTransform(nn.Module):
    def __init__(
        self,
        order: int,
        alpha: float,
        device: torch.device | None = None,
        dtype: torch.dtype | None = None,
    ) -> None:
        super().__init__()

        # Make transform matrix.
        A = torch.eye(order + 1, device=device, dtype=torch.double)
        A[1:].fill_diagonal_(alpha)
        A[1, 0] = 0
        A[1, 1] += alpha
        self.register_buffer("A", to(A.T, dtype=dtype))

    def forward(self, q: torch.Tensor) -> torch.Tensor:
        q = torch.matmul(q, self.A)
        return q
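
# A minimal usage sketch of CoefficientsFrequencyTransform (the orders and
# warping factor below are illustrative, not taken from the library's tests):
# the transform is a single matrix product, so it batches and differentiates
# like any linear layer.
if __name__ == "__main__":
    freqt = CoefficientsFrequencyTransform(in_order=3, out_order=5, alpha=0.42)
    c = torch.randn(2, 4)   # batch of two cepstra of order 3
    print(freqt(c).shape)   # torch.Size([2, 6]): order-5 warped cepstra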