# ------------------------------------------------------------------------ #
# Copyright 2022 SPTK Working Group #
# #
# Licensed under the Apache License, Version 2.0 (the "License"); #
# you may not use this file except in compliance with the License. #
# You may obtain a copy of the License at #
# #
# http://www.apache.org/licenses/LICENSE-2.0 #
# #
# Unless required by applicable law or agreed to in writing, software #
# distributed under the License is distributed on an "AS IS" BASIS, #
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. #
# See the License for the specific language governing permissions and #
# limitations under the License. #
# ------------------------------------------------------------------------ #
import torch
import torch.nn.functional as F
from torch import nn
from ..utils.private import Lambda
from ..utils.private import check_size
from ..utils.private import get_gamma
from ..utils.private import remove_gain
from .b2mc import MLSADigitalFilterCoefficientsToMelCepstrum
from .base import BaseNonFunctionalModule
from .c2mpir import CepstrumToMinimumPhaseImpulseResponse
from .gnorm import GeneralizedCepstrumGainNormalization
from .istft import InverseShortTimeFourierTransform
from .linear_intpl import LinearInterpolation
from .mc2b import MelCepstrumToMLSADigitalFilterCoefficients
from .mgc2mgc import MelGeneralizedCepstrumToMelGeneralizedCepstrum
from .mgc2sp import MelGeneralizedCepstrumToSpectrum
from .stft import ShortTimeFourierTransform
def is_array_like(x):
    """Return True if ``x`` is a tuple or a list, and False otherwise."""
    return isinstance(x, tuple) or isinstance(x, list)
def mirror(x, half=False):
    """Reflect the last axis of a tensor around its first element.

    Parameters
    ----------
    x : Tensor [shape=(..., L)]
        The one-sided sequence, x(0), ..., x(L-1).

    half : bool
        If True, every element except the first one is multiplied by 0.5 before
        being mirrored.

    Returns
    -------
    out : Tensor [shape=(..., 2L-1)]
        The two-sided sequence, x(L-1), ..., x(1), x(0), x(1), ..., x(L-1).

    """
    center = x[..., :1]
    tail = x[..., 1:]
    if half:
        tail = tail * 0.5
    return torch.cat((tail.flip(-1), center, tail), dim=-1)
class PseudoMGLSADigitalFilter(BaseNonFunctionalModule):
    """See `this page <https://sp-nitech.github.io/sptk/latest/main/mglsadf.html>`_
    for details.

    Parameters
    ----------
    filter_order : int >= 0 or tuple[int, int]
        The order of the filter coefficients, :math:`M` or :math:`(N, M)`. A tuple input
        is allowed only if **phase** is 'mixed'.

    frame_period : int >= 1
        The frame period, :math:`P`.

    alpha : float in (-1, 1)
        The frequency warping factor, :math:`\\alpha`.

    gamma : float in [-1, 1]
        The gamma parameter, :math:`\\gamma`.

    c : int >= 1 or None
        The number of stages.

    ignore_gain : bool
        If True, filtering is performed without gain.

    phase : ['minimum', 'maximum', 'zero', 'mixed']
        The filter type.

    mode : ['multi-stage', 'single-stage', 'freq-domain']
        'multi-stage' approximates the MLSA filter by cascading FIR filters based on the
        Taylor series expansion. 'single-stage' uses an FIR filter with the coefficients
        derived from the impulse response converted from the input mel-cepstral
        coefficients using FFT. 'freq-domain' performs filtering in the frequency domain
        rather than the time domain.

    n_fft : int >= 1
        The number of FFT bins used for conversion. Higher values result in increased
        conversion accuracy.

    taylor_order : int >= 0
        The order of the Taylor series expansion (valid only if **mode** is
        'multi-stage').

    cep_order : int >= 0 or tuple[int, int]
        The order of the linear cepstrum (valid only if **mode** is 'multi-stage').

    ir_length : int >= 1 or tuple[int, int]
        The length of the impulse response (valid only if **mode** is 'single-stage').

    **kwargs : additional keyword arguments
        See :func:`~diffsptk.ShortTimeFourierTransform` (valid only if **mode** is
        'freq-domain').

    Raises
    ------
    ValueError
        If **mode** is not supported, or if a tuple **filter_order** is given although
        **phase** is not 'mixed'.

    References
    ----------
    .. [1] T. Yoshimura et al., "Embedding a differentiable mel-cepstral synthesis
           filter to a neural speech synthesis system," *Proceedings of ICASSP*, 2023.

    """

    def __init__(
        self,
        filter_order,
        frame_period,
        *,
        alpha=0,
        gamma=0,
        c=None,
        ignore_gain=False,
        phase="minimum",
        mode="multi-stage",
        **kwargs,
    ):
        super().__init__()

        self.frame_period = frame_period

        # Format parameters.
        if phase == "mixed" and not is_array_like(filter_order):
            filter_order = (filter_order, filter_order)
        gamma = get_gamma(gamma, c)

        # Number of coefficients expected in the last axis of the input, per part:
        # (N, M + 1) for the mixed-phase case and (M + 1,) otherwise.
        if phase == "mixed":
            self.split_sections = (filter_order[0], filter_order[1] + 1)
        elif is_array_like(filter_order):
            # Fail with an explicit message instead of the TypeError that
            # ``filter_order + 1`` would raise on a tuple or a list.
            raise ValueError(
                "A tuple filter_order is allowed only if phase is 'mixed'."
            )
        else:
            self.split_sections = (filter_order + 1,)

        def flip(x):
            # Swap a pair; scalars are passed through untouched.
            if is_array_like(x):
                return x[1], x[0]
            return x

        # The public interface orders pairs as (maximum-phase, minimum-phase), whereas
        # the underlying filters receive (minimum-phase, maximum-phase) inputs (see
        # forward), so every pair-valued option is swapped here.
        flip_keys = ("cep_order", "ir_length")
        modified_kwargs = kwargs.copy()
        for key in flip_keys:
            if key in kwargs:
                modified_kwargs[key] = flip(kwargs[key])
        flipped_filter_order = flip(filter_order)

        if mode == "multi-stage":
            filter_class = MultiStageFIRFilter
        elif mode == "single-stage":
            filter_class = SingleStageFIRFilter
        elif mode == "freq-domain":
            filter_class = FrequencyDomainFIRFilter
        else:
            raise ValueError(f"mode {mode} is not supported.")

        self.mglsadf = filter_class(
            flipped_filter_order,
            frame_period,
            alpha=alpha,
            gamma=gamma,
            ignore_gain=ignore_gain,
            phase=phase,
            **modified_kwargs,
        )

    def forward(self, x, mc):
        """Apply an MGLSA digital filter.

        Parameters
        ----------
        x : Tensor [shape=(..., T)]
            The excitation signal.

        mc : Tensor [shape=(..., T/P, M+1)] or [shape=(..., T/P, N+M+1)]
            The mel-generalized cepstrum, not MLSA digital filter coefficients. Note
            that the mixed-phase case assumes that the coefficients are of the form
            c_{-N}, ..., c_{0}, ..., c_{M}, where M is the order of the minimum-phase
            part and N is the order of the maximum-phase part.

        Returns
        -------
        out : Tensor [shape=(..., T)]
            The output signal.

        Examples
        --------
        >>> M = 4
        >>> x = diffsptk.step(3)
        >>> mc = diffsptk.nrand(2, M)
        >>> mc
        tensor([[-0.9134, -0.5774, -0.4567,  0.7423, -0.5782],
                [ 0.6904,  0.5175,  0.8765,  0.1677,  2.4624]])
        >>> mglsadf = diffsptk.MLSA(M, frame_period=2)
        >>> y = mglsadf(x.view(1, -1), mc.view(1, 2, M + 1))
        >>> y
        tensor([[0.4011, 0.8760, 3.5677, 4.8725]])

        """
        check_size(mc.size(-1), sum(self.split_sections), "dimension of mel-cepstrum")
        check_size(x.size(-1), mc.size(-2) * self.frame_period, "sequence length")

        if len(self.split_sections) != 1:
            # Mixed phase: separate the anti-causal and causal parts, then reverse the
            # anti-causal part and prepend a zero in place of its 0th coefficient.
            mc_max, mc_min = torch.split(mc, self.split_sections, dim=-1)
            mc_max = F.pad(mc_max.flip(-1), (1, 0))
            mc = (mc_min, mc_max)  # (c0, c1, ..., cM), (0, c-1, ..., c-N)

        y = self.mglsadf(x, mc)
        return y
class MultiStageFIRFilter(nn.Module):
    """Filter a signal with a cascade of time-varying FIR stages.

    The input mel-generalized cepstrum is converted to a cepstrum with
    ``MelGeneralizedCepstrumToMelGeneralizedCepstrum`` and the same FIR filter is
    applied ``taylor_order`` times; the k-th stage output is divided by k and all
    stage outputs are accumulated onto the input signal.

    Parameters
    ----------
    filter_order : int >= 0 or tuple[int, int]
        The order of the input coefficients. A pair is indexed as
        (minimum-phase part, maximum-phase part) when **phase** is 'mixed'.

    frame_period : int >= 1
        The frame period handed to ``LinearInterpolation``.

    alpha : float
        The frequency warping factor of the input coefficients.

    gamma : float
        The gamma parameter of the input coefficients.

    ignore_gain : bool
        If True, the output is not scaled by the exponential of the 0th cepstral
        coefficient.

    phase : ['minimum', 'maximum', 'zero', 'mixed']
        The filter type.

    taylor_order : int >= 0
        The number of cascaded stages.

    cep_order : int >= 0 or tuple[int, int]
        The order of the converted cepstrum. It is replaced by **filter_order** when
        both **alpha** and **gamma** are zero.

    n_fft : int >= 1
        The number of FFT bins used for the cepstrum conversion.

    """

    def __init__(
        self,
        filter_order,
        frame_period,
        *,
        alpha=0,
        gamma=0,
        ignore_gain=False,
        phase="minimum",
        taylor_order=20,
        cep_order=199,
        n_fft=512,
    ):
        super().__init__()

        if taylor_order < 0:
            raise ValueError("taylor_order must be non-negative.")

        self.ignore_gain = ignore_gain
        self.phase = phase
        self.taylor_order = taylor_order

        # No conversion is needed in this case, so the input order is kept as it is.
        if alpha == 0 and gamma == 0:
            cep_order = filter_order

        # Prepare padding module.
        if phase not in ("minimum", "maximum", "zero", "mixed"):
            raise ValueError(f"phase {phase} is not supported.")
        if phase == "mixed" and is_array_like(cep_order):
            padding = cep_order
        else:
            # (left, right) zero padding: the left side is dropped only for
            # 'maximum' and the right side is dropped only for 'minimum'.
            padding = (
                0 if phase == "maximum" else cep_order,
                0 if phase == "minimum" else cep_order,
            )
        self.pad = nn.ConstantPad1d(padding, 0)

        # Prepare frequency transformation module.
        def make_mgc2c(in_order, out_order):
            return MelGeneralizedCepstrumToMelGeneralizedCepstrum(
                in_order,
                out_order,
                in_alpha=alpha,
                in_gamma=gamma,
                n_fft=n_fft,
            )

        if phase == "mixed":
            self.mgc2c = nn.ModuleList(
                [make_mgc2c(filter_order[i], padding[i]) for i in range(2)]
            )
        else:
            self.mgc2c = make_mgc2c(filter_order, cep_order)

        self.linear_intpl = LinearInterpolation(frame_period)

    def forward(self, x, mc):
        """Filter the excitation signal.

        Parameters
        ----------
        x : Tensor
            The excitation signal; the last axis is filtered.

        mc : Tensor or tuple[Tensor, Tensor]
            The mel-generalized cepstrum. A (minimum-phase, maximum-phase) pair is
            expected when the phase is 'mixed'.

        Returns
        -------
        out : Tensor
            The filtered signal.

        """
        if self.phase == "mixed":
            mc_min, mc_max = mc
            c_min = self.mgc2c[0](mc_min)
            c_max = self.mgc2c[1](mc_max)
            c0 = c_min[..., :1] + c_max[..., :1]
            # Two-sided cepstrum whose center (gain) term is zeroed out.
            c = torch.cat(
                [c_min[..., 1:].flip(-1), torch.zeros_like(c0), c_max[..., 1:]],
                dim=-1,
            )
        else:
            c = self.mgc2c(mc)
            c0, c = remove_gain(c, value=0, return_gain=True)
            if self.phase == "minimum":
                c = c.flip(-1)
            elif self.phase == "zero":
                c = mirror(c, half=True)
            elif self.phase != "maximum":
                raise RuntimeError

        c = self.linear_intpl(c)
        n_taps = c.size(-1)

        y = x.clone()
        stage = x
        for k in range(1, self.taylor_order + 1):
            # Slide a window of n_taps samples over the padded previous stage and
            # weight every window by the coefficients of the corresponding sample.
            windows = self.pad(stage).unfold(-1, n_taps, 1)
            stage = (windows * c).sum(-1) / k
            y += stage

        if not self.ignore_gain:
            gain = torch.exp(self.linear_intpl(c0))
            y *= gain.squeeze(-1)
        return y
class SingleStageFIRFilter(nn.Module):
    """Filter a signal with one time-varying FIR filter built from an impulse response.

    Parameters
    ----------
    filter_order : int >= 0 or tuple[int, int]
        The order of the input coefficients. A pair is indexed as
        (minimum-phase part, maximum-phase part) when **phase** is 'mixed'.

    frame_period : int >= 1
        The frame period handed to ``LinearInterpolation``.

    alpha : float
        The frequency warping factor of the input coefficients.

    gamma : float
        The gamma parameter of the input coefficients.

    ignore_gain : bool
        If True, filtering is performed without gain.

    phase : ['minimum', 'maximum', 'zero', 'mixed']
        The filter type.

    ir_length : int >= 1 or tuple[int, int]
        The length of the impulse response. A pair, indexed in the same way as
        **filter_order**, is allowed only if **phase** is 'mixed'.

    n_fft : int >= 1
        The number of FFT bins used for conversion.

    Raises
    ------
    ValueError
        If **phase** is not supported, or if a pair is given as **ir_length** although
        **phase** is not 'mixed'.

    """

    def __init__(
        self,
        filter_order,
        frame_period,
        *,
        alpha=0,
        gamma=0,
        ignore_gain=False,
        phase="minimum",
        ir_length=2000,
        n_fft=4096,
    ):
        super().__init__()

        if phase not in ("minimum", "maximum", "zero", "mixed"):
            raise ValueError(f"phase {phase} is not supported.")

        self.ignore_gain = ignore_gain
        self.phase = phase
        self.n_fft = n_fft

        # Prepare padding module.
        if is_array_like(ir_length):
            # ``ir_length - 1`` must not be evaluated on a pair: it raises TypeError.
            if phase != "mixed":
                raise ValueError(
                    "A tuple ir_length is allowed only if phase is 'mixed'."
                )
            padding = (ir_length[0] - 1, ir_length[1] - 1)
        else:
            taps = ir_length - 1
            if phase == "minimum":
                padding = (taps, 0)
            elif phase == "maximum":
                padding = (0, taps)
            else:  # 'zero' or 'mixed'
                padding = (taps, taps)
        self.pad = nn.ConstantPad1d(padding, 0)
        self.padding = padding

        if phase in ("minimum", "maximum"):
            self.mgc2ir = MelGeneralizedCepstrumToMelGeneralizedCepstrum(
                filter_order,
                ir_length - 1,
                in_alpha=alpha,
                in_gamma=gamma,
                out_gamma=1,
                out_mul=True,
                n_fft=n_fft,
            )
        elif phase == "zero":
            self.mgc2c = MelGeneralizedCepstrumToMelGeneralizedCepstrum(
                filter_order,
                ir_length - 1,
                in_alpha=alpha,
                in_gamma=gamma,
                n_fft=n_fft,
            )
            # Cepstrum -> spectrum via hfft, exponentiate, then back to the time
            # domain, keeping the first ir_length samples.
            self.c2ir = nn.Sequential(
                Lambda(lambda x: torch.fft.hfft(x, n=n_fft)),
                Lambda(lambda x: torch.fft.ifft(torch.exp(x)).real[..., :ir_length]),
            )
        else:  # 'mixed'
            self.mgc2c = nn.ModuleList()
            for i in range(2):
                self.mgc2c.append(
                    MelGeneralizedCepstrumToMelGeneralizedCepstrum(
                        filter_order[i],
                        padding[i],
                        in_alpha=alpha,
                        in_gamma=gamma,
                        n_fft=n_fft,
                    )
                )
            self.c2ir = CepstrumToMinimumPhaseImpulseResponse(
                n_fft - 1, n_fft, n_fft=n_fft
            )

        self.linear_intpl = LinearInterpolation(frame_period)

    def forward(self, x, mc):
        """Filter the excitation signal.

        Parameters
        ----------
        x : Tensor
            The excitation signal; the last axis is filtered.

        mc : Tensor or tuple[Tensor, Tensor]
            The mel-generalized cepstrum. A (minimum-phase, maximum-phase) pair is
            expected when the phase is 'mixed'.

        Returns
        -------
        out : Tensor
            The filtered signal.

        """
        if self.phase == "minimum":
            h = self.mgc2ir(mc)
            h = h.flip(-1)
        elif self.phase == "maximum":
            h = self.mgc2ir(mc)
        elif self.phase == "zero":
            c = self.mgc2c(mc)
            c[..., 1:] *= 0.5
            if self.ignore_gain:
                c = remove_gain(c, value=0)
            h = self.c2ir(c)
            h = mirror(h)
        elif self.phase == "mixed":
            mc_min, mc_max = mc
            c_min = self.mgc2c[0](mc_min)
            c_max = self.mgc2c[1](mc_max)
            if self.ignore_gain:
                c0 = torch.zeros_like(c_min[..., :1])
            else:
                c0 = c_min[..., :1] + c_max[..., :1]
            c = torch.cat([c_min[..., 1:].flip(-1), c0, c_max[..., 1:]], dim=-1)
            # Zero-pad to n_fft points and rotate so that c0 comes first; the
            # rotation is undone on the impulse response before it is truncated.
            c = F.pad(c, (0, self.n_fft - c.size(-1)))
            c = torch.roll(c, -self.padding[0], dims=-1)
            h = self.c2ir(c)
            h = torch.roll(h, self.padding[0], dims=-1)[..., : sum(self.padding) + 1]
        else:
            raise RuntimeError

        h = self.linear_intpl(h)

        if self.ignore_gain:
            # Normalize by the tap that holds h(0): the last one after the flip
            # above for 'minimum', the first one for 'maximum'.
            if self.phase == "minimum":
                h = h / h[..., -1:]
            elif self.phase == "maximum":
                h = h / h[..., :1]

        # Slide a window over the padded signal and weight every window by the
        # impulse response of the corresponding sample.
        x = self.pad(x)
        x = x.unfold(-1, h.size(-1), 1)
        y = (x * h).sum(-1)
        return y
class FrequencyDomainFIRFilter(nn.Module):
    """Filter a signal by multiplying its STFT with a frame-wise frequency response.

    Parameters
    ----------
    filter_order : int >= 0 or tuple[int, int]
        The order of the input coefficients. A pair is indexed as
        (minimum-phase part, maximum-phase part) when **phase** is 'mixed'.

    frame_period : int >= 1
        The frame period. It must be less than half of **frame_length**.

    alpha : float
        The frequency warping factor of the input coefficients.

    gamma : float
        The gamma parameter of the input coefficients.

    ignore_gain : bool
        If True, filtering is performed without gain.

    phase : ['minimum', 'maximum', 'zero', 'mixed']
        The filter type.

    frame_length : int
        The frame length passed to the STFT and inverse STFT modules.

    fft_length : int
        The FFT length passed to the STFT and inverse STFT modules and to the
        cepstrum-to-spectrum conversion.

    n_fft : int >= 1
        The number of FFT bins used for the cepstrum-to-spectrum conversion.

    **stft_kwargs : additional keyword arguments
        Forwarded to both ``ShortTimeFourierTransform`` and
        ``InverseShortTimeFourierTransform``.

    Raises
    ------
    ValueError
        If **frame_period** is not less than half of **frame_length**, or if **phase**
        is not supported.

    """

    def __init__(
        self,
        filter_order,
        frame_period,
        *,
        alpha=0,
        gamma=0,
        ignore_gain=False,
        phase="minimum",
        frame_length=400,
        fft_length=512,
        n_fft=512,
        **stft_kwargs,
    ):
        super().__init__()

        if frame_length <= 2 * frame_period:
            raise ValueError("frame_period must be less than half of frame_length.")
        # Reject an unknown phase here rather than with a bare RuntimeError in forward.
        if phase not in ("minimum", "maximum", "zero", "mixed"):
            raise ValueError(f"phase {phase} is not supported.")

        self.ignore_gain = ignore_gain
        self.phase = phase

        # The gain-removal modules are needed only when the gain is ignored.
        if self.ignore_gain:
            self.gnorm = nn.ModuleList()
            self.mc2b = nn.ModuleList()
            self.b2mc = nn.ModuleList()
        self.mgc2sp = nn.ModuleList()

        if not is_array_like(filter_order):
            filter_order = (filter_order, filter_order)

        # One set of modules per input: two for mixed phase, one otherwise.
        n = 2 if phase == "mixed" else 1
        for i in range(n):
            if self.ignore_gain:
                self.gnorm.append(
                    GeneralizedCepstrumGainNormalization(filter_order[i], gamma=gamma)
                )
                self.mc2b.append(
                    MelCepstrumToMLSADigitalFilterCoefficients(
                        filter_order[i], alpha=alpha
                    )
                )
                self.b2mc.append(
                    MLSADigitalFilterCoefficientsToMelCepstrum(
                        filter_order[i], alpha=alpha
                    )
                )
            self.mgc2sp.append(
                MelGeneralizedCepstrumToSpectrum(
                    filter_order[i],
                    fft_length,
                    alpha=alpha,
                    gamma=gamma,
                    out_format="complex",
                    n_fft=n_fft,
                )
            )

        self.stft = ShortTimeFourierTransform(
            frame_length, frame_period, fft_length, out_format="complex", **stft_kwargs
        )
        self.istft = InverseShortTimeFourierTransform(
            frame_length, frame_period, fft_length, **stft_kwargs
        )

    def forward(self, x, mc):
        """Filter the excitation signal.

        Parameters
        ----------
        x : Tensor
            The excitation signal; the output has the same length in the last axis.

        mc : Tensor or tuple[Tensor, Tensor]
            The mel-generalized cepstrum. A (minimum-phase, maximum-phase) pair is
            expected when the phase is 'mixed'.

        Returns
        -------
        out : Tensor
            The filtered signal.

        """
        if torch.is_tensor(mc):
            mc = [mc]

        Hs = []
        for i, c in enumerate(mc):
            if self.ignore_gain:
                # Convert to MLSA filter coefficients, normalize the gain, zero the
                # 0th coefficient, and convert back.
                b = self.mc2b[i](c)
                b = self.gnorm[i](b)
                b[..., 0] = 0
                c = self.b2mc[i](b)
            Hs.append(self.mgc2sp[i](c))

        if self.phase == "minimum":
            H = Hs[0]
        elif self.phase == "maximum":
            H = Hs[0].conj()
        elif self.phase == "zero":
            H = Hs[0].abs()
        elif self.phase == "mixed":
            H = Hs[0] * Hs[1].conj()
        else:
            raise RuntimeError

        X = self.stft(x)
        Y = H * X
        y = self.istft(Y, out_length=x.size(-1))
        return y