# Source code for diffsptk.core.mgcep

# ------------------------------------------------------------------------ #
# Copyright 2022 SPTK Working Group                                        #
#                                                                          #
# Licensed under the Apache License, Version 2.0 (the "License");          #
# you may not use this file except in compliance with the License.         #
# You may obtain a copy of the License at                                  #
#                                                                          #
#     http://www.apache.org/licenses/LICENSE-2.0                           #
#                                                                          #
# Unless required by applicable law or agreed to in writing, software      #
# distributed under the License is distributed on an "AS IS" BASIS,        #
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. #
# See the License for the specific language governing permissions and      #
# limitations under the License.                                           #
# ------------------------------------------------------------------------ #

import numpy as np
import torch
import torch.nn as nn

from ..misc.utils import check_size
from ..misc.utils import hankel
from ..misc.utils import is_power_of_two
from ..misc.utils import numpy_to_torch
from ..misc.utils import symmetric_toeplitz
from .b2mc import MLSADigitalFilterCoefficientsToMelCepstrum
from .gnorm import GeneralizedCepstrumGainNormalization
from .ignorm import GeneralizedCepstrumInverseGainNormalization
from .mc2b import MelCepstrumToMLSADigitalFilterCoefficients
from .mcep import MelCepstralAnalysis
from .mgc2mgc import MelGeneralizedCepstrumToMelGeneralizedCepstrum


class CoefficientsFrequencyTransform(nn.Module):
    """Apply a fixed, alpha-dependent linear map along the last axis.

    A ``(out_order + 1, in_order + 1)`` matrix is filled once at construction
    time and registered, transposed, as the buffer ``A``. The forward pass is
    a single matrix product with that buffer.

    Parameters
    ----------
    in_order : int [scalar]
        Order of the input; the matrix has ``in_order + 1`` columns.

    out_order : int [scalar]
        Order of the output; the matrix has ``out_order + 1`` rows.

    alpha : float
        Factor used to fill the matrix.

    """

    def __init__(self, in_order, out_order, alpha):
        super().__init__()

        scale = 1 - alpha * alpha
        n_in = in_order + 1
        n_out = out_order + 1

        mat = np.zeros((n_out, n_in))
        mat[0, 0] = 1

        # Row 1 has a closed form: powers of alpha scaled by (1 - alpha^2).
        if n_out > 1 and n_in > 1:
            mat[1, 1:] = alpha ** np.arange(n_in - 1) * scale

        # Each later row is derived from the row above it and from its own
        # left neighbor, so columns must be visited from left to right.
        for row in range(2, n_out):
            above = mat[row - 1]
            current = mat[row]
            for col in range(1, n_in):
                left = col - 1
                current[col] = above[left] + alpha * (current[left] - above[col])

        # Stored transposed so that forward() can right-multiply the input.
        self.register_buffer("A", numpy_to_torch(mat.T))

    def forward(self, x):
        """Return ``x`` multiplied from the right by the buffer ``A``."""
        return torch.matmul(x, self.A)


class PTransform(nn.Module):
    """Apply a fixed ``(order + 1) x (order + 1)`` linear map along the last axis.

    The matrix starts as the identity with ``alpha`` on the first
    superdiagonal; three corner entries are then adjusted. It is registered,
    transposed, as the buffer ``A``.

    Parameters
    ----------
    order : int [scalar]
        Order of the input; the matrix has ``order + 1`` rows and columns.

    alpha : float
        Factor used to fill the matrix.

    """

    def __init__(self, order, alpha):
        super().__init__()

        mat = np.eye(order + 1)

        # Put alpha on the first superdiagonal.
        rows = np.arange(order)
        mat[rows, rows + 1] = alpha

        # Corner corrections, applied in this exact order.
        mat[0, 0] = mat[0, 0] - alpha * alpha
        mat[0, 1] = mat[0, 1] + alpha
        mat[-1, -1] = mat[-1, -1] + alpha

        # Stored transposed so that forward() can right-multiply the input.
        self.register_buffer("A", numpy_to_torch(mat.T))

    def forward(self, p):
        """Return ``p`` multiplied from the right by the buffer ``A``."""
        return torch.matmul(p, self.A)


class QTransform(nn.Module):
    """Apply a fixed ``(order + 1) x (order + 1)`` linear map along the last axis.

    The matrix starts as the identity with ``alpha`` on the first
    subdiagonal; two entries of row 1 are then adjusted. It is registered,
    transposed, as the buffer ``A``.

    Parameters
    ----------
    order : int [scalar]
        Order of the input; the matrix has ``order + 1`` rows and columns.

    alpha : float
        Factor used to fill the matrix.

    """

    def __init__(self, order, alpha):
        super().__init__()

        mat = np.eye(order + 1)

        # Put alpha on the first subdiagonal.
        cols = np.arange(order)
        mat[cols + 1, cols] = alpha

        # Row 1 is special: its subdiagonal entry is cleared and its diagonal
        # entry receives an extra alpha.
        mat[1, 0] = 0
        mat[1, 1] = mat[1, 1] + alpha

        # Stored transposed so that forward() can right-multiply the input.
        self.register_buffer("A", numpy_to_torch(mat.T))

    def forward(self, q):
        """Return ``q`` multiplied from the right by the buffer ``A``."""
        return torch.matmul(q, self.A)


class MelGeneralizedCepstralAnalysis(nn.Module):
    """See `this page <https://sp-nitech.github.io/sptk/latest/main/mgcep.html>`_
    for details. Note that the current implementation does not use the efficient
    Toeplitz-plus-Hankel system solver.

    Parameters
    ----------
    cep_order : int >= 0 [scalar]
        Order of mel-cepstrum, :math:`M`.

    fft_length : int >= 2M [scalar]
        Number of FFT bins, :math:`L`. Must pass the ``is_power_of_two`` check.

    alpha : float [-1 < alpha < 1]
        Frequency warping factor, :math:`\\alpha`.

    gamma : float [-1 <= gamma <= 0]
        Gamma, :math:`\\gamma`.

    n_iter : int >= 0 [scalar]
        Number of iterations.

    """

    def __init__(self, cep_order, fft_length, alpha=0, gamma=0, n_iter=0):
        super(MelGeneralizedCepstralAnalysis, self).__init__()

        self.cep_order = cep_order
        self.fft_length = fft_length
        self.gamma = gamma
        self.n_iter = n_iter

        assert 0 <= self.cep_order
        assert self.cep_order <= self.fft_length // 2
        assert is_power_of_two(self.fft_length)
        assert self.gamma <= 0
        assert 0 <= self.n_iter

        if gamma == 0:
            # gamma == 0 is delegated entirely to the mel-cepstral analysis
            # module; none of the sub-modules below are needed in that case.
            self.mcep = MelCepstralAnalysis(cep_order, fft_length, alpha, n_iter=n_iter)
        else:
            self.cfreqt = CoefficientsFrequencyTransform(
                cep_order, fft_length - 1, -alpha
            )
            self.pfreqt = CoefficientsFrequencyTransform(
                fft_length - 1, 2 * cep_order, alpha
            )
            self.rfreqt = CoefficientsFrequencyTransform(
                fft_length - 1, cep_order, alpha
            )
            self.ptrans = PTransform(2 * cep_order, alpha)
            self.qtrans = QTransform(2 * cep_order, alpha)

            # Converts coefficients obtained with gamma = -1 into coefficients
            # for the requested gamma.
            self.b2b = nn.Sequential(
                GeneralizedCepstrumInverseGainNormalization(cep_order, -1),
                MLSADigitalFilterCoefficientsToMelCepstrum(cep_order, alpha),
                MelGeneralizedCepstrumToMelGeneralizedCepstrum(
                    cep_order, cep_order, in_gamma=-1, out_gamma=gamma
                ),
                MelCepstrumToMLSADigitalFilterCoefficients(cep_order, alpha),
                GeneralizedCepstrumGainNormalization(cep_order, gamma),
            )

            # Final conversion from the internal coefficients to the output.
            self.b2mc = nn.Sequential(
                GeneralizedCepstrumInverseGainNormalization(cep_order, gamma),
                MLSADigitalFilterCoefficientsToMelCepstrum(cep_order, alpha),
            )

    def forward(self, x):
        """Estimate mel-generalized cepstrum from spectrum.

        Parameters
        ----------
        x : Tensor [shape=(..., L/2+1)]
            Power spectrum.

        Returns
        -------
        mc : Tensor [shape=(..., M+1)]
            Mel-generalized cepstrum.

        Examples
        --------
        >>> x = diffsptk.ramp(19)
        >>> stft = diffsptk.STFT(frame_length=10, frame_period=10, fft_length=16)
        >>> mgcep = diffsptk.MelGeneralizedCepstralAnalysis(3, 16, 0.1, n_iter=1)
        >>> mc = mgcep(stft(x))
        >>> mc
        tensor([[-0.8851,  0.7917, -0.1737,  0.0175],
                [-0.3522,  4.4222, -1.0882, -0.0511]])

        """
        if self.gamma == 0:
            mc = self.mcep(x)
            return mc

        M = self.cep_order
        H = self.fft_length // 2
        check_size(x.size(-1), H + 1, "dimension of spectrum")

        def newton(gamma, b1):
            """Run one update of ``b1`` for the given ``gamma``.

            Returns the pair ``(b0, b1)`` where ``b0`` has a trailing axis of
            size 1 and ``b1`` is the updated tensor.
            """

            def epsilon(gamma, r, b):
                eps = r[..., 0] + gamma * (r[..., 1:] * b).sum(-1)
                return eps

            # Prepend a zero 0-th term; it inherits dtype and device from b1.
            b0 = b1.new_zeros((*b1.shape[:-1], 1))
            b = torch.cat((b0, b1), dim=-1)
            c = self.cfreqt(b)
            C = torch.fft.rfft(c, n=self.fft_length)

            if gamma == -1:
                # Only p is needed in this case and it is the input itself.
                p_re = x
            else:
                X = 1 + gamma * C.real
                Y = gamma * C.imag
                XX = X * X
                YY = Y * Y
                D = XX + YY
                E = torch.pow(D, -1 / gamma)
                p = x * E / D
                p_re = p
                q = p / D
                q_re = q * (XX - YY)
                q_im = q * (2 * X * Y)
                r_re = p * X
                r_im = p * Y

            p = self.pfreqt(torch.fft.irfft(p_re))
            if gamma == -1:
                # q and r are derived from p instead of separate spectra.
                q = p
                r = p[..., : M + 1]
            else:
                q = self.pfreqt(torch.fft.irfft(torch.complex(q_re, q_im)))
                r = self.rfreqt(torch.fft.irfft(torch.complex(r_re, r_im)))

            p = self.ptrans(p)
            q = self.qtrans(q)

            # For gamma != -1, epsilon uses b1 *before* the update below.
            if gamma != -1:
                eps = epsilon(gamma, r, b1)

            pt = p[..., :M]
            qt = q[..., 2:] * (1 + gamma)
            rt = r[..., 1:]

            # Solve the (Toeplitz + Hankel) system with a general solver.
            R = symmetric_toeplitz(pt)
            Q = hankel(qt)
            gradient = torch.linalg.solve(R + Q, rt)
            b1 = b1 + gradient

            # For gamma == -1, epsilon uses b1 *after* the update.
            if gamma == -1:
                eps = epsilon(gamma, r, b1)

            b0 = torch.sqrt(eps).unsqueeze(-1)
            return b0, b1

        # Follow the input's floating dtype so the module also works after
        # ``.double()``; a non-floating input keeps the former behavior of
        # starting from the global default dtype.
        if x.is_floating_point():
            dtype = x.dtype
        else:
            dtype = torch.get_default_dtype()
        b1 = torch.zeros(*x.shape[:-1], M, dtype=dtype, device=x.device)

        # Initial step, always performed with gamma = -1.
        b0, b1 = newton(-1, b1)

        if self.gamma != -1:
            # Convert the gamma = -1 result to the target gamma, then refine.
            b = torch.cat((b0, b1), dim=-1)
            b = self.b2b(b)
            _, b1 = torch.split(b, [1, M], dim=-1)
            for _ in range(self.n_iter):
                b0, b1 = newton(self.gamma, b1)

        b = torch.cat((b0, b1), dim=-1)
        mc = self.b2mc(b)
        return mc