# ------------------------------------------------------------------------ #
# Copyright 2022 SPTK Working Group #
# #
# Licensed under the Apache License, Version 2.0 (the "License"); #
# you may not use this file except in compliance with the License. #
# You may obtain a copy of the License at #
# #
# http://www.apache.org/licenses/LICENSE-2.0 #
# #
# Unless required by applicable law or agreed to in writing, software #
# distributed under the License is distributed on an "AS IS" BASIS, #
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. #
# See the License for the specific language governing permissions and #
# limitations under the License. #
# ------------------------------------------------------------------------ #
# ----------------------------------------------------------------- #
# Copyright (c) 2010 M. Morise #
# #
# All rights reserved. #
# #
# Redistribution and use in source and binary forms, with or #
# without modification, are permitted provided that the following #
# conditions are met: #
# #
# - Redistributions of source code must retain the above copyright #
# notice, this list of conditions and the following disclaimer. #
# - Redistributions in binary form must reproduce the above #
# copyright notice, this list of conditions and the following #
# disclaimer in the documentation and/or other materials provided #
# with the distribution. #
# - Neither the name of the M. Morise nor the names of its #
# contributors may be used to endorse or promote products derived #
# from this software without specific prior written permission. #
# #
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND #
# CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, #
# INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF #
# MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE #
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS #
# BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, #
# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED #
# TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, #
# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON #
# ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, #
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY #
# OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE #
# POSSIBILITY OF SUCH DAMAGE. #
# ----------------------------------------------------------------- #
import torch
from ..third_party.world import get_minimum_phase_spectrum, interp1
from ..utils.private import TAU, to
from .base import BaseNonFunctionalModule
class WorldSynthesis(BaseNonFunctionalModule):
"""See `this page <https://sp-nitech.github.io/sptk/latest/main/world_synth.html>`_
for details. Note that the gradients do not propagated through F0.
Parameters
----------
    frame_period : int >= 1
        The frame period in samples, :math:`P`.

    sample_rate : int >= 8000
        The sample rate in Hz.

    fft_length : int >= 1024
        The number of FFT bins, :math:`L`.

    default_f0 : float > 0
        The F0 value used when the input F0 is unvoiced.

    device : torch.device or None
        The device of this module.

    dtype : torch.dtype or None
        The data type of this module.

    """

def __init__(
self,
frame_period: int,
sample_rate: int,
fft_length: int,
*,
default_f0: float = 500,
device: torch.device | None = None,
dtype: torch.dtype | None = None,
) -> None:
super().__init__()
if frame_period <= 0:
raise ValueError("frame_period must be positive.")
if sample_rate < 8000:
raise ValueError("sample_rate must be at least 8000 Hz.")
        if fft_length < 1024:
            raise ValueError("fft_length must be at least 1024.")
        if default_f0 <= 0:
            raise ValueError("default_f0 must be positive.")
self.frame_period = frame_period
self.sample_rate = sample_rate
self.fft_length = fft_length
self.default_f0 = default_f0
self.register_buffer("ramp", torch.arange(fft_length, device=device))
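        # The ramp buffer (0, 1, ..., L - 1) is reused in forward() for the
        # linear-phase term, the noise mask, and the overlap-add indices.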
# GetDCRemover()
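        # The DC remover is a Hann-shaped window normalized to unit sum; in
        # forward() it spreads the negated DC component of each periodic
        # response over the FFT window.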
ramp = self.ramp[1 : fft_length // 2 + 1].double()
dc_remover = 0.5 - 0.5 * torch.cos(TAU / (1 + fft_length) * ramp)
dc_component = 2 * torch.sum(dc_remover)
dc_remover /= dc_component
dc_remover = torch.cat([dc_remover, dc_remover.flip(-1)], dim=-1)
self.register_buffer("dc_remover", to(dc_remover, dtype=dtype))
def forward(
self,
f0: torch.Tensor,
ap: torch.Tensor,
sp: torch.Tensor,
out_length: int | None = None,
) -> torch.Tensor:
"""Synthesize speech using WORLD vocoder.
Parameters
----------
f0 : Tensor [shape=(B, T/P) or (T/P,)]
The F0 in Hz.
ap : Tensor [shape=(B, T/P, L/2+1) or (T/P, L/2+1)]
The aperiodicity in [0, 1].
sp : Tensor [shape=(B, T/P, L/2+1) or (T/P, L/2+1)]
The spectral envelope (power spectrum).
out_length : int > 0 or None
The length of the output waveform.
Returns
-------
out : Tensor [shape=(B, T) or (T,)]
The synthesized speech waveform.
Examples
--------
>>> import diffsptk
>>> pitch = diffsptk.Pitch(160, 16000, out_format="f0")
>>> aperiodicity = diffsptk.Aperiodicity(160, 16000, 1024)
>>> spec = diffsptk.PitchAdaptiveSpectralAnalysis(160, 16000, 1024)
>>> world_synth = diffsptk.WorldSynthesis(160, 16000, 1024)
>>> x = diffsptk.sin(2000 - 1, 80)
>>> f0 = pitch(x)
>>> ap = aperiodicity(x, f0)
>>> sp = spec(x, f0)
>>> y = world_synth(f0, ap, sp, out_length=x.size(0))
>>> y.shape
torch.Size([2000])
"""
is_batched_input = f0.ndim == 2
if not is_batched_input:
f0 = f0.unsqueeze(0)
ap = ap.unsqueeze(0)
sp = sp.unsqueeze(0)
# Check the input shape.
        if f0.dim() != 2:
            raise ValueError("f0 must be a 1D or 2D tensor.")
        if ap.dim() != 3 or sp.dim() != 3:
            raise ValueError("ap and sp must be 2D or 3D tensors.")
if len(set([f0.shape[0], ap.shape[0], sp.shape[0]])) != 1:
raise ValueError("f0, ap, and sp must have the same batch size.")
if len(set([f0.shape[1], ap.shape[1], sp.shape[1]])) != 1:
raise ValueError("f0, ap, and sp must have the same length.")
if len(set([ap.shape[2], sp.shape[2]])) != 1:
raise ValueError("ap and sp must have the same dimension.")
# Get the input shape.
B, N, D = sp.shape
T = N * self.frame_period
# Restrict the input range.
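        # Clipping keeps the aperiodicity strictly inside (0, 1) and the
        # envelope strictly positive so that the minimum-phase computation
        # (which involves a logarithm of the spectrum) stays well-defined.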
eps = 1e-6
ap = torch.clip(ap, min=eps, max=1 - eps)
sp = torch.clip(sp, min=eps)
# GetTemporalParametersForTimeBase()
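        # F0 values below one FFT-bin width (plus 1 Hz) are treated as
        # unvoiced. The frame-rate F0 and voiced/unvoiced flags are then
        # linearly interpolated up to the sample rate; the detach() here is
        # why gradients are not propagated through F0.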
f_min = self.sample_rate / self.fft_length + 1
coarse_f0 = torch.where(f0 < f_min, 0, f0).detach()
coarse_vuv = (0 < coarse_f0).type(coarse_f0.dtype)
time_axis = (
torch.arange(
f0.shape[-1] * self.frame_period, device=f0.device, dtype=f0.dtype
)
/ self.sample_rate
)
time_axis = time_axis.repeat(B, 1)
coarse_time_axis = torch.arange(
coarse_f0.shape[-1], device=coarse_f0.device, dtype=coarse_f0.dtype
) * (self.frame_period / self.sample_rate)
coarse_time_axis = coarse_time_axis.repeat(B, 1)
interpolated_f0 = interp1(
coarse_time_axis, coarse_f0, time_axis, batching=(True, True)
)
interpolated_vuv = interp1(
coarse_time_axis, coarse_vuv, time_axis, batching=(True, True)
)
interpolated_vuv = 0.5 < interpolated_vuv
interpolated_f0 = torch.where(
interpolated_vuv, interpolated_f0, self.default_f0
)
# GetPulseLocationsForTimeBase()
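        # Pulse locations are the samples at which the wrapped cumulative
        # phase jumps by more than pi, i.e., where the running phase crosses
        # a multiple of TAU.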
total_phase = torch.cumsum(
TAU / self.sample_rate * interpolated_f0.double(), dim=-1
).type(f0.dtype)
wrap_phase = torch.fmod(total_phase, TAU)
wrap_phase_abs = torch.abs(torch.diff(wrap_phase))
pulse_locations_index = torch.nonzero(torch.pi < wrap_phase_abs, as_tuple=True)
pulse_locations = time_axis[pulse_locations_index]
vuv = interpolated_vuv[pulse_locations_index].unsqueeze(-1)
batch_index, time_index = pulse_locations_index
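        # The exact phase zero crossing lies between two samples: y1 is the
        # (negative) phase just before the wrap and y2 the phase just after,
        # so -y1 / (y2 - y1) is the sub-sample fraction of the shift.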
y1 = wrap_phase[pulse_locations_index] - TAU
y2 = wrap_phase[batch_index, time_index + 1]
pulse_locations_time_shift = -y1 / (y2 - y1) / self.sample_rate
# GetSpectralEnvelope()
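        # The frame-rate spectral envelope is linearly interpolated at each
        # pulse location; frame indices are clipped at the last frame.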
frame = pulse_locations * (self.sample_rate / self.frame_period)
frame_floor = frame.floor().long().clip(max=N - 1)
frame_ceil = frame.ceil().long().clip(max=N - 1)
interpolation = (frame - frame_floor).unsqueeze(-1)
lower_weight = 1 - interpolation
upper_weight = interpolation
spectral_envelope = (
lower_weight * sp[batch_index, frame_floor]
+ upper_weight * sp[batch_index, frame_ceil]
)
# GetAperiodicRatio()
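        # Squaring converts the interpolated aperiodicity from an amplitude
        # ratio to a power ratio, matching the power-spectrum envelope.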
aperiodic_ratio = (
lower_weight * ap[batch_index, frame_floor]
+ upper_weight * ap[batch_index, frame_ceil]
) ** 2
# GetPeriodicResponse()
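        # The voiced excitation uses the envelope weighted by the periodic
        # power ratio (1 - aperiodic_ratio), converted to minimum phase.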
weight = 1 - aperiodic_ratio
spectrum = get_minimum_phase_spectrum(weight * spectral_envelope)
# GetSpectrumWithFractionalTimeShift()
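        # A fractional time shift is a linear phase in the frequency domain:
        # bin k is rotated by exp(-j * TAU * k * sample_rate / fft_length
        # * pulse_locations_time_shift).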
coefficient = (
TAU * self.sample_rate / self.fft_length * pulse_locations_time_shift
)
phase = torch.exp(-1j * self.ramp[:D] * coefficient.unsqueeze(-1))
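        # hfft returns the forward FFT of the Hermitian extension; reversing
        # all but the first sample turns it into fft_length times the inverse
        # FFT (the scale is undone in Synthesis() below), and fftshift
        # centers the impulse response within the window.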
periodic_response = torch.fft.hfft(spectrum * phase)
periodic_response = torch.cat(
[periodic_response[..., :1], periodic_response[..., 1:].flip(-1)], dim=-1
)
periodic_response = torch.fft.fftshift(periodic_response, dim=-1)
# RemoveDCComponent()
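        # The DC of the causal half of the response is spread, negated, over
        # the whole window by the unit-sum dc_remover so that each pulse
        # contributes zero net DC to the waveform.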
H = self.fft_length // 2
dc_component = periodic_response[..., H:].sum(-1, keepdim=True)
dd = -dc_component * self.dc_remover
periodic_response = torch.cat(
(dd[..., :H], periodic_response[..., H:] + dd[..., H:]), dim=-1
)
periodic_response = periodic_response * (0.5 < vuv)
# GetNoiseSpectrum()
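        # noise_size is the number of samples between consecutive pulses;
        # zero-mean white noise of that length excites the aperiodic part.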
noise_size = torch.diff(time_index, append=time_index[-1:])
noise_size = noise_size.clip(min=0).unsqueeze(-1)
noise_waveform = torch.randn_like(periodic_response)
mask = self.ramp < noise_size
noise_waveform = noise_waveform * mask
average = noise_waveform.sum(dim=-1, keepdim=True) / noise_size
average = torch.nan_to_num(average)
noise_waveform = (noise_waveform - average) * mask
noise_spectrum = torch.fft.rfft(noise_waveform)
# GetAperiodicResponse()
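        # Voiced pulses use the envelope weighted by the aperiodic power
        # ratio; unvoiced segments use the full envelope. The minimum-phase
        # spectrum is then colored by the noise spectrum.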
weight = torch.where(0 < vuv, aperiodic_ratio, 1)
spectrum = (
get_minimum_phase_spectrum(weight * spectral_envelope) * noise_spectrum
)
aperiodic_response = torch.fft.hfft(spectrum)
aperiodic_response = torch.cat(
[aperiodic_response[..., :1], aperiodic_response[..., 1:].flip(-1)], dim=-1
)
aperiodic_response = torch.fft.fftshift(aperiodic_response, dim=-1)
# Synthesis()
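        # The periodic response is scaled by sqrt(noise_size) so that pulse
        # energy grows with the pulse interval, keeping the periodic and
        # aperiodic powers in balance; 1 / fft_length undoes the hfft scale.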
sqrt_noise_size = torch.sqrt(noise_size)
response = (
periodic_response * sqrt_noise_size + aperiodic_response
) / self.fft_length
margin = (
(self.fft_length + self.frame_period - 1)
// self.frame_period
* self.frame_period
)
T_ = T + margin
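        # Padding by a whole number of frames guarantees that scatter_add_
        # never writes past the buffer. Each response is overlap-added at its
        # pulse location; the final narrow(start=H) realigns the waveform,
        # since each fftshift-centered response begins H samples before its
        # pulse.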
index = (batch_index * T_ + time_index).unsqueeze(-1) + self.ramp
y = torch.zeros((B, T_), device=sp.device, dtype=sp.dtype)
y.view(-1).scatter_add_(
dim=-1,
index=index.view(-1),
src=response.view(-1),
)
y = torch.narrow(y, dim=-1, start=H, length=T)
if not is_batched_input:
y = y.squeeze(0)
if out_length is not None:
y = y[..., :out_length]
return y