# Source code for diffsptk.modules.mgcep

# ------------------------------------------------------------------------ #
# Copyright 2022 SPTK Working Group                                        #
#                                                                          #
# Licensed under the Apache License, Version 2.0 (the "License");          #
# you may not use this file except in compliance with the License.         #
# You may obtain a copy of the License at                                  #
#                                                                          #
#     http://www.apache.org/licenses/LICENSE-2.0                           #
#                                                                          #
# Unless required by applicable law or agreed to in writing, software      #
# distributed under the License is distributed on an "AS IS" BASIS,        #
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. #
# See the License for the specific language governing permissions and      #
# limitations under the License.                                           #
# ------------------------------------------------------------------------ #

import torch
from torch import nn

from ..misc.utils import check_size
from ..misc.utils import hankel
from ..misc.utils import symmetric_toeplitz
from ..misc.utils import to
from .b2mc import MLSADigitalFilterCoefficientsToMelCepstrum
from .gnorm import GeneralizedCepstrumGainNormalization
from .ignorm import GeneralizedCepstrumInverseGainNormalization
from .mc2b import MelCepstrumToMLSADigitalFilterCoefficients
from .mcep import MelCepstralAnalysis
from .mgc2mgc import MelGeneralizedCepstrumToMelGeneralizedCepstrum


class CoefficientsFrequencyTransform(nn.Module):
    """Frequency-warp a coefficient vector with a precomputed linear map.

    A ``(out_order + 1, in_order + 1)`` matrix ``T`` is built by recursion and
    stored transposed in the ``A`` buffer, so ``forward`` maps an input whose
    last dimension is ``in_order + 1`` to one whose last dimension is
    ``out_order + 1``.

    The construction sets ``T[0, 0] = 1``, fills row 1 with
    ``alpha**(j - 1) * (1 - alpha**2)`` for ``j >= 1``, and builds every
    further entry from its upper, left and upper-left neighbours. Column 0
    below row 0 stays zero.

    Parameters
    ----------
    in_order : int >= 0
        Order of the input coefficients.

    out_order : int >= 0
        Order of the output coefficients.

    alpha : float
        Frequency warping factor.

    """

    def __init__(self, in_order, out_order, alpha):
        super().__init__()

        n_in = in_order + 1
        n_out = out_order + 1
        beta = 1 - alpha * alpha

        T = torch.zeros((n_out, n_in), dtype=torch.double)
        T[0, 0] = 1

        # Row 1 has a closed form; later rows follow the recursion below.
        if n_out > 1 and n_in > 1:
            powers = torch.arange(n_in - 1, dtype=torch.double)
            T[1, 1:] = alpha**powers * beta

        for row in range(2, n_out):
            above = row - 1
            for col in range(1, n_in):
                left = col - 1
                T[row, col] = T[above, left] + alpha * (T[row, left] - T[above, col])

        self.register_buffer("A", to(T.T))

    def forward(self, x):
        """Apply the warping over the last dimension of ``x``.

        Parameters
        ----------
        x : Tensor [shape=(..., in_order+1)]
            Input coefficients.

        Returns
        -------
        out : Tensor [shape=(..., out_order+1)]
            Warped coefficients.

        """
        return x @ self.A


class PTransform(nn.Module):
    """Apply the fixed linear p-transform over the last dimension.

    With ``M = order``, ``forward(p)`` returns ``y`` where::

        y[0] = (1 - alpha**2) * p[0] + 2 * alpha * p[1]
        y[i] = p[i] + alpha * p[i + 1]        for 0 < i < M
        y[M] = (1 + alpha) * p[M]

    The matrix is built once and kept (transposed) in the ``A`` buffer. The
    construction indexes ``T[0, 1]``, so ``order`` must be at least 1.

    Parameters
    ----------
    order : int >= 1
        Order of the input/output vector, :math:`M`.

    alpha : float
        Frequency warping factor.

    """

    def __init__(self, order, alpha):
        super().__init__()

        size = order + 1
        T = torch.eye(size, dtype=torch.double)

        # Superdiagonal entries T[i, i + 1] = alpha.
        k = torch.arange(size - 1)
        T[k, k + 1] = alpha

        # Boundary corrections for the first row and the last diagonal entry.
        T[0, 0] -= alpha * alpha
        T[0, 1] += alpha
        T[-1, -1] += alpha

        self.register_buffer("A", to(T.T))

    def forward(self, p):
        """Return the p-transformed vector, shape ``(..., order+1)``."""
        return p @ self.A


class QTransform(nn.Module):
    """Apply the fixed linear q-transform over the last dimension.

    With ``M = order``, ``forward(q)`` returns ``y`` where::

        y[0] = q[0]
        y[1] = (1 + alpha) * q[1]
        y[i] = q[i] + alpha * q[i - 1]        for 2 <= i <= M

    The matrix is built once and kept (transposed) in the ``A`` buffer. The
    construction indexes ``T[1, 0]``, so ``order`` must be at least 1.

    Parameters
    ----------
    order : int >= 1
        Order of the input/output vector, :math:`M`.

    alpha : float
        Frequency warping factor.

    """

    def __init__(self, order, alpha):
        super().__init__()

        size = order + 1
        T = torch.eye(size, dtype=torch.double)

        # Subdiagonal entries T[i + 1, i] = alpha ...
        k = torch.arange(size - 1)
        T[k + 1, k] = alpha

        # ... except T[1, 0]; instead alpha is added to the (1, 1) entry.
        T[1, 0] = 0
        T[1, 1] += alpha

        self.register_buffer("A", to(T.T))

    def forward(self, q):
        """Return the q-transformed vector, shape ``(..., order+1)``."""
        return q @ self.A


class MelGeneralizedCepstralAnalysis(nn.Module):
    """See `this page <https://sp-nitech.github.io/sptk/latest/main/mgcep.html>`_
    for details. Note that the current implementation does not use the efficient
    Toeplitz-plus-Hankel system solver.

    Parameters
    ----------
    cep_order : int >= 0
        Order of mel-cepstrum, :math:`M`.

    fft_length : int >= 2M
        Number of FFT bins, :math:`L`.

    alpha : float in (-1, 1)
        Frequency warping factor, :math:`\\alpha`.

    gamma : float in [-1, 0]
        Gamma, :math:`\\gamma`.

    n_iter : int >= 0
        Number of iterations.

    """

    def __init__(self, cep_order, fft_length, alpha=0, gamma=0, n_iter=0):
        super().__init__()

        assert 0 <= cep_order <= fft_length // 2
        # The documented range of gamma is [-1, 0].
        assert -1 <= gamma <= 0
        assert 0 <= n_iter

        self.cep_order = cep_order
        self.fft_length = fft_length
        self.gamma = gamma
        self.n_iter = n_iter

        if gamma == 0:
            self.mcep = MelCepstralAnalysis(cep_order, fft_length, alpha, n_iter=n_iter)
        else:
            self.cfreqt = CoefficientsFrequencyTransform(
                cep_order, fft_length - 1, -alpha
            )
            self.pfreqt = CoefficientsFrequencyTransform(
                fft_length - 1, 2 * cep_order, alpha
            )
            self.rfreqt = CoefficientsFrequencyTransform(
                fft_length - 1, cep_order, alpha
            )
            self.ptrans = PTransform(2 * cep_order, alpha)
            self.qtrans = QTransform(2 * cep_order, alpha)
            # Converts (gain, coefficients) obtained with gamma = -1 into the
            # gain-normalized representation for the target gamma.
            self.b2b = nn.Sequential(
                GeneralizedCepstrumInverseGainNormalization(cep_order, -1),
                MLSADigitalFilterCoefficientsToMelCepstrum(cep_order, alpha),
                MelGeneralizedCepstrumToMelGeneralizedCepstrum(
                    cep_order, cep_order, in_gamma=-1, out_gamma=gamma
                ),
                MelCepstrumToMLSADigitalFilterCoefficients(cep_order, alpha),
                GeneralizedCepstrumGainNormalization(cep_order, gamma),
            )
            # Converts the final (gain, coefficients) into mel-cepstrum.
            self.b2mc = nn.Sequential(
                GeneralizedCepstrumInverseGainNormalization(cep_order, gamma),
                MLSADigitalFilterCoefficientsToMelCepstrum(cep_order, alpha),
            )

    def forward(self, x):
        """Estimate mel-generalized cepstrum from spectrum.

        Parameters
        ----------
        x : Tensor [shape=(..., L/2+1)]
            Power spectrum.

        Returns
        -------
        out : Tensor [shape=(..., M+1)]
            Mel-generalized cepstrum.

        Examples
        --------
        >>> x = diffsptk.ramp(19)
        >>> stft = diffsptk.STFT(frame_length=10, frame_period=10, fft_length=16)
        >>> mgcep = diffsptk.MelGeneralizedCepstralAnalysis(3, 16, 0.1, n_iter=1)
        >>> mc = mgcep(stft(x))
        >>> mc
        tensor([[-0.8851,  0.7917, -0.1737,  0.0175],
                [-0.3522,  4.4222, -1.0882, -0.0511]])

        """
        if self.gamma == 0:
            mc = self.mcep(x)
            return mc

        M = self.cep_order
        H = self.fft_length // 2
        check_size(x.size(-1), H + 1, "dimension of spectrum")

        # Start from all-zero coefficients and solve the gamma = -1 problem.
        # The dtype follows the input so a double-precision module works too.
        b1 = torch.zeros(*x.shape[:-1], M, dtype=x.dtype, device=x.device)
        b0, b1 = self._newton(x, -1, b1)

        if self.gamma != -1:
            b = torch.cat((b0, b1), dim=-1)
            b = self.b2b(b)
            # Keep the gain produced by gain normalization: when n_iter == 0 it
            # is the gain that matches the converted coefficients b1.
            b0, b1 = torch.split(b, [1, M], dim=-1)
            for _ in range(self.n_iter):
                b0, b1 = self._newton(x, self.gamma, b1)

        b = torch.cat((b0, b1), dim=-1)
        mc = self.b2mc(b)
        return mc

    @staticmethod
    def _epsilon(gamma, r, b):
        """Return ``r[..., 0] + gamma * sum(r[..., 1:] * b)`` over the last dim."""
        return r[..., 0] + gamma * (r[..., 1:] * b).sum(-1)

    def _newton(self, x, gamma, b1):
        """Perform one Newton-Raphson update of the coefficients.

        Parameters
        ----------
        x : Tensor [shape=(..., L/2+1)]
            Power spectrum.

        gamma : float in [-1, 0)
            Gamma used for this update.

        b1 : Tensor [shape=(..., M)]
            Current coefficients (without the gain term).

        Returns
        -------
        b0 : Tensor [shape=(..., 1)]
            Square root of the error term ``eps``.

        b1 : Tensor [shape=(..., M)]
            Updated coefficients.

        """
        M = self.cep_order
        L = self.fft_length

        if gamma == -1:
            p_re = x
        else:
            # The gain slot is zeroed, so only b1 shapes the spectrum here.
            b0 = torch.zeros(*b1.shape[:-1], 1, dtype=b1.dtype, device=b1.device)
            b = torch.cat((b0, b1), dim=-1)
            c = self.cfreqt(b)
            C = torch.fft.rfft(c, n=L)

            X = 1 + gamma * C.real
            Y = gamma * C.imag
            XX = X * X
            YY = Y * Y
            D = XX + YY
            E = torch.pow(D, -1 / gamma)
            p = x * E / D
            p_re = p
            q = p / D
            q_re = q * (XX - YY)
            q_im = q * (2 * X * Y)
            r_re = p * X
            r_im = p * Y

        # n=L keeps the time-domain length equal to fft_length, which the
        # frequency transforms expect, also when fft_length is odd.
        p = self.pfreqt(torch.fft.irfft(p_re, n=L))
        if gamma == -1:
            q = p
            r = p[..., : M + 1]
        else:
            q = self.pfreqt(torch.fft.irfft(torch.complex(q_re, q_im), n=L))
            r = self.rfreqt(torch.fft.irfft(torch.complex(r_re, r_im), n=L))

        p = self.ptrans(p)
        q = self.qtrans(q)

        # eps is evaluated with the coefficients before the update when
        # gamma != -1, and with the updated coefficients when gamma == -1.
        if gamma != -1:
            eps = self._epsilon(gamma, r, b1)

        pt = p[..., :M]
        qt = q[..., 2:] * (1 + gamma)
        rt = r[..., 1:]
        R = symmetric_toeplitz(pt)
        Q = hankel(qt)
        gradient = torch.linalg.solve(R + Q, rt)
        b1 = b1 + gradient

        if gamma == -1:
            eps = self._epsilon(gamma, r, b1)

        b0 = torch.sqrt(eps).unsqueeze(-1)
        return b0, b1