# ------------------------------------------------------------------------ #
# Copyright 2022 SPTK Working Group                                        #
#                                                                          #
# Licensed under the Apache License, Version 2.0 (the "License");          #
# you may not use this file except in compliance with the License.         #
# You may obtain a copy of the License at                                  #
#                                                                          #
#     http://www.apache.org/licenses/LICENSE-2.0                           #
#                                                                          #
# Unless required by applicable law or agreed to in writing, software      #
# distributed under the License is distributed on an "AS IS" BASIS,        #
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. #
# See the License for the specific language governing permissions and      #
# limitations under the License.                                           #
# ------------------------------------------------------------------------ #
import numpy as np
import torch
import torch.nn as nn
from ..misc.utils import check_size
from ..misc.utils import hankel
from ..misc.utils import is_power_of_two
from ..misc.utils import numpy_to_torch
from ..misc.utils import symmetric_toeplitz
from .b2mc import MLSADigitalFilterCoefficientsToMelCepstrum
from .gnorm import GeneralizedCepstrumGainNormalization
from .ignorm import GeneralizedCepstrumInverseGainNormalization
from .mc2b import MelCepstrumToMLSADigitalFilterCoefficients
from .mcep import MelCepstralAnalysis
from .mgc2mgc import MelGeneralizedCepstrumToMelGeneralizedCepstrum
class CoefficientsFrequencyTransform(nn.Module):
    """Fixed linear transform of a coefficient sequence, parameterized by alpha.

    The transform matrix is built once from ``alpha`` with a two-dimensional
    recursion and applied to the last dimension of the input. (Presumably this
    mirrors SPTK's ``frqtr`` step of mel-generalized cepstral analysis —
    confirm against SPTK before relying on that correspondence.)

    Parameters
    ----------
    in_order : int >= 0 [scalar]
        Order of the input sequence; the input has ``in_order + 1`` entries.
    out_order : int >= 0 [scalar]
        Order of the output sequence; the output has ``out_order + 1`` entries.
    alpha : float [scalar]
        Parameter of the transform.
    """

    def __init__(self, in_order, out_order, alpha):
        super(CoefficientsFrequencyTransform, self).__init__()

        n_in = in_order + 1
        n_out = out_order + 1
        scale = 1 - alpha * alpha

        # Fill the (n_out, n_in) matrix row by row:
        #   T[0, 0] = 1,
        #   T[1, j] = (1 - alpha^2) * alpha^(j-1)          for j >= 1,
        #   T[i, j] = T[i-1, j-1] + alpha * (T[i, j-1] - T[i-1, j])  for i >= 2.
        T = np.zeros((n_out, n_in))
        T[0, 0] = 1
        if n_out > 1 and n_in > 1:
            T[1, 1:] = scale * alpha ** np.arange(n_in - 1)
        for row in range(2, n_out):
            prev = row - 1
            for col in range(1, n_in):
                T[row, col] = T[prev, col - 1] + alpha * (
                    T[row, col - 1] - T[prev, col]
                )

        # Stored transposed, shape (n_in, n_out), so forward() right-multiplies.
        self.register_buffer("A", numpy_to_torch(T.T))

    def forward(self, x):
        """Apply the transform along the last dimension.

        Parameters
        ----------
        x : Tensor [shape=(..., in_order+1)]
            Input coefficients.

        Returns
        -------
        Tensor [shape=(..., out_order+1)]
            Transformed coefficients.
        """
        return x @ self.A
class PTransform(nn.Module):
    """Fixed linear mixing of neighbouring entries used on the ``p`` sequence.

    For an input ``p`` of length ``order + 1`` (``order >= 1``), the output is

    - ``y[0] = (1 - alpha^2) * p[0] + 2 * alpha * p[1]``
    - ``y[i] = p[i] + alpha * p[i + 1]`` for ``0 < i < order``
    - ``y[order] = (1 + alpha) * p[order]``

    ``order == 0`` raises ``IndexError`` during construction, because the
    (0, 1) matrix entry does not exist.

    Parameters
    ----------
    order : int >= 1 [scalar]
        Order of the sequence; the input has ``order + 1`` entries.
    alpha : float [scalar]
        Mixing factor.
    """

    def __init__(self, order, alpha):
        super(PTransform, self).__init__()

        size = order + 1
        upper = np.arange(size - 1)

        T = np.eye(size)
        # Superdiagonal: output i also receives alpha * input (i + 1).
        T[upper, upper + 1] = alpha
        # First row: (1 - alpha^2) on the diagonal, 2 * alpha next to it.
        T[0, 0] -= alpha * alpha
        T[0, 1] += alpha
        # Last row: the diagonal becomes 1 + alpha.
        T[-1, -1] += alpha

        # Stored transposed so that forward() right-multiplies.
        self.register_buffer("A", numpy_to_torch(T.T))

    def forward(self, p):
        """Apply the mixing along the last dimension.

        Parameters
        ----------
        p : Tensor [shape=(..., order+1)]
            Input sequence.

        Returns
        -------
        Tensor [shape=(..., order+1)]
            Mixed sequence.
        """
        return p @ self.A
class QTransform(nn.Module):
    """Fixed linear mixing of neighbouring entries used on the ``q`` sequence.

    For an input ``q`` of length ``order + 1`` (``order >= 1``), the output is

    - ``y[0] = q[0]``
    - ``y[1] = (1 + alpha) * q[1]``
    - ``y[i] = q[i] + alpha * q[i - 1]`` for ``i >= 2``

    ``order == 0`` raises ``IndexError`` during construction, because the
    (1, 0) matrix entry does not exist.

    Parameters
    ----------
    order : int >= 1 [scalar]
        Order of the sequence; the input has ``order + 1`` entries.
    alpha : float [scalar]
        Mixing factor.
    """

    def __init__(self, order, alpha):
        super(QTransform, self).__init__()

        size = order + 1
        lower = np.arange(1, size)

        T = np.eye(size)
        # Subdiagonal: output i also receives alpha * input (i - 1).
        T[lower, lower - 1] = alpha
        # Row 1 does not mix with entry 0; its diagonal becomes 1 + alpha.
        T[1, 0] = 0
        T[1, 1] += alpha

        # Stored transposed so that forward() right-multiplies.
        self.register_buffer("A", numpy_to_torch(T.T))

    def forward(self, q):
        """Apply the mixing along the last dimension.

        Parameters
        ----------
        q : Tensor [shape=(..., order+1)]
            Input sequence.

        Returns
        -------
        Tensor [shape=(..., order+1)]
            Mixed sequence.
        """
        return q @ self.A
class MelGeneralizedCepstralAnalysis(nn.Module):
    """See `this page <https://sp-nitech.github.io/sptk/latest/main/mgcep.html>`_
    for details. Note that the current implementation does not use the efficient
    Toeplitz-plus-Hankel system solver.

    Parameters
    ----------
    cep_order : int >= 0 [scalar]
        Order of mel-cepstrum, :math:`M`.

    fft_length : int >= 2M [scalar]
        Number of FFT bins, :math:`L`.

    alpha : float [-1 < alpha < 1]
        Frequency warping factor, :math:`\\alpha`.

    gamma : float [-1 <= gamma <= 0]
        Gamma, :math:`\\gamma`.

    n_iter : int >= 0 [scalar]
        Number of iterations.

    """

    def __init__(self, cep_order, fft_length, alpha=0, gamma=0, n_iter=0):
        super(MelGeneralizedCepstralAnalysis, self).__init__()

        self.cep_order = cep_order
        self.fft_length = fft_length
        self.gamma = gamma
        self.n_iter = n_iter

        assert 0 <= self.cep_order
        assert self.cep_order <= self.fft_length // 2
        assert is_power_of_two(self.fft_length)
        # The documented valid range of gamma is [-1, 0].
        assert -1 <= self.gamma <= 0
        assert 0 <= self.n_iter

        if gamma == 0:
            # gamma == 0 is delegated entirely to mel-cepstral analysis.
            self.mcep = MelCepstralAnalysis(cep_order, fft_length, alpha, n_iter=n_iter)
        else:
            # Fixed transforms used inside each Newton step (see forward()).
            self.cfreqt = CoefficientsFrequencyTransform(
                cep_order, fft_length - 1, -alpha
            )
            self.pfreqt = CoefficientsFrequencyTransform(
                fft_length - 1, 2 * cep_order, alpha
            )
            self.rfreqt = CoefficientsFrequencyTransform(
                fft_length - 1, cep_order, alpha
            )
            self.ptrans = PTransform(2 * cep_order, alpha)
            self.qtrans = QTransform(2 * cep_order, alpha)

            # Re-expresses gain-normalized coefficients obtained with gamma = -1
            # as gain-normalized coefficients for the target gamma:
            # inverse gain normalization (-1) -> b2mc -> mgc2mgc (-1 -> gamma)
            # -> mc2b -> gain normalization (gamma).
            self.b2b = nn.Sequential(
                GeneralizedCepstrumInverseGainNormalization(cep_order, -1),
                MLSADigitalFilterCoefficientsToMelCepstrum(cep_order, alpha),
                MelGeneralizedCepstrumToMelGeneralizedCepstrum(
                    cep_order, cep_order, in_gamma=-1, out_gamma=gamma
                ),
                MelCepstrumToMLSADigitalFilterCoefficients(cep_order, alpha),
                GeneralizedCepstrumGainNormalization(cep_order, gamma),
            )

            # Final output conversion: undo the gain normalization for gamma,
            # then convert MLSA filter coefficients to mel-cepstrum form.
            self.b2mc = nn.Sequential(
                GeneralizedCepstrumInverseGainNormalization(cep_order, gamma),
                MLSADigitalFilterCoefficientsToMelCepstrum(cep_order, alpha),
            )

    def forward(self, x):
        """Estimate mel-generalized cepstrum from spectrum.

        Parameters
        ----------
        x : Tensor [shape=(..., L/2+1)]
            Power spectrum.

        Returns
        -------
        mc : Tensor [shape=(..., M+1)]
            Mel-generalized cepstrum.

        Examples
        --------
        >>> x = diffsptk.ramp(19)
        >>> stft = diffsptk.STFT(frame_length=10, frame_period=10, fft_length=16)
        >>> mgcep = diffsptk.MelGeneralizedCepstralAnalysis(3, 16, 0.1, n_iter=1)
        >>> mc = mgcep(stft(x))
        >>> mc
        tensor([[-0.8851,  0.7917, -0.1737,  0.0175],
                [-0.3522,  4.4222, -1.0882, -0.0511]])

        """
        if self.gamma == 0:
            mc = self.mcep(x)
            return mc

        M = self.cep_order
        L = self.fft_length
        H = L // 2
        check_size(x.size(-1), H + 1, "dimension of spectrum")

        def epsilon(gamma, r, b):
            # r[0] + gamma * <r[1:], b>; its square root is used as the gain b0.
            eps = r[..., 0] + gamma * (r[..., 1:] * b).sum(-1)
            return eps

        def newton(gamma, b1):
            # One Newton-type update of b1 (the coefficients without the gain
            # term). Returns the gain b0 (shape (..., 1)) and the updated b1.
            b0 = torch.zeros(*b1.shape[:-1], 1, device=b1.device, dtype=b1.dtype)
            b = torch.cat((b0, b1), dim=-1)
            c = self.cfreqt(b)
            C = torch.fft.rfft(c, n=L)

            if gamma == -1:
                p_re = x
            else:
                X = 1 + gamma * C.real
                Y = gamma * C.imag
                XX = X * X
                YY = Y * Y
                D = XX + YY
                E = torch.pow(D, -1 / gamma)
                p = x * E / D
                p_re = p
                q = p / D
                q_re = q * (XX - YY)
                q_im = q * (2 * X * Y)
                r_re = p * X
                r_im = p * Y

            # n=L equals irfft's default for H + 1 bins; stated for clarity.
            p = self.pfreqt(torch.fft.irfft(p_re, n=L))
            if gamma == -1:
                q = p
                r = p[..., : M + 1]
            else:
                q = self.pfreqt(torch.fft.irfft(torch.complex(q_re, q_im), n=L))
                r = self.rfreqt(torch.fft.irfft(torch.complex(r_re, r_im), n=L))

            p = self.ptrans(p)
            q = self.qtrans(q)

            # For gamma != -1 the gain uses the coefficients *before* the update.
            if gamma != -1:
                eps = epsilon(gamma, r, b1)

            pt = p[..., :M]
            qt = q[..., 2:] * (1 + gamma)
            rt = r[..., 1:]

            # Solve the (Toeplitz + Hankel) linear system for the update step.
            R = symmetric_toeplitz(pt)
            Q = hankel(qt)
            gradient = torch.linalg.solve(R + Q, rt)
            b1 = b1 + gradient

            # For gamma == -1 the gain uses the coefficients *after* the update.
            if gamma == -1:
                eps = epsilon(gamma, r, b1)

            b0 = torch.sqrt(eps).unsqueeze(-1)
            return b0, b1

        # Initial estimation with gamma = -1, starting from zero coefficients.
        # Match the input's dtype so double-precision modules/inputs work.
        b1 = torch.zeros(*x.shape[:-1], M, device=x.device, dtype=x.dtype)
        b0, b1 = newton(-1, b1)

        if self.gamma != -1:
            # Convert the gamma = -1 solution to the target gamma. Keep the
            # converted gain as well: the gain computed for gamma = -1 does not
            # belong to the converted b1, and it would otherwise be used as-is
            # when no further iteration is run (n_iter == 0).
            b = torch.cat((b0, b1), dim=-1)
            b = self.b2b(b)
            b0, b1 = torch.split(b, [1, M], dim=-1)
            for _ in range(self.n_iter):
                b0, b1 = newton(self.gamma, b1)

        # Get mel-generalized cepstrum.
        b = torch.cat((b0, b1), dim=-1)
        mc = self.b2mc(b)
        return mc