# Source code for diffsptk.modules.pitch_spec

# ------------------------------------------------------------------------ #
# Copyright 2022 SPTK Working Group                                        #
#                                                                          #
# Licensed under the Apache License, Version 2.0 (the "License");          #
# you may not use this file except in compliance with the License.         #
# You may obtain a copy of the License at                                  #
#                                                                          #
#     http://www.apache.org/licenses/LICENSE-2.0                           #
#                                                                          #
# Unless required by applicable law or agreed to in writing, software      #
# distributed under the License is distributed on an "AS IS" BASIS,        #
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. #
# See the License for the specific language governing permissions and      #
# limitations under the License.                                           #
# ------------------------------------------------------------------------ #

import numpy as np
import torch
import torch.nn.functional as F
from torch import nn

from ..third_party.world import dc_correction, get_windowed_waveform, linear_smoothing
from ..typing import Callable
from ..utils.private import TAU, iir, next_power_of_two, numpy_to_torch
from .base import BaseNonFunctionalModule
from .frame import Frame
from .spec import Spectrum


class PitchAdaptiveSpectralAnalysis(BaseNonFunctionalModule):
    """Pitch-adaptive spectral envelope estimation.

    See `this page <https://sp-nitech.github.io/sptk/latest/main/pitch_spec.html>`_
    for details. Note that the gradients are not propagated through F0.

    Parameters
    ----------
    frame_period : int >= 1
        The frame period in samples, :math:`P`.

    sample_rate : int >= 8000
        The sample rate in Hz.

    fft_length : int >= 1024
        The number of FFT bins, :math:`L`.

    algorithm : ['cheap-trick', 'straight']
        The algorithm used to estimate the spectral envelope. The STRAIGHT
        algorithm supports only double precision.

    out_format : ['db', 'log-magnitude', 'magnitude', 'power']
        The output format.

    default_f0 : float > 0
        The F0 value used when the input F0 is unvoiced.

    References
    ----------
    .. [1] M. Morise, "CheapTrick, a spectral envelope estimator for
           high-quality speech synthesis", *Speech Communication*, vol. 67,
           pp. 1-7, 2015.

    .. [2] H. Kawahara et al., "Restructuring speech representations using a
           pitch-adaptive time-frequency smoothing and an
           instantaneous-frequency-based F0 extraction: Possible role of a
           repetitive structure in sounds", *Speech Communication*, vol. 27,
           no. 3-4, pp. 187-207, 1999.

    """

    def __init__(
        self,
        frame_period: int,
        sample_rate: int,
        fft_length: int,
        algorithm: str = "cheap-trick",
        out_format: str | int = "power",
        **kwargs,
    ) -> None:
        super().__init__()

        # Validate the parameters shared by every algorithm.
        if frame_period <= 0:
            raise ValueError("frame_period must be positive.")
        if sample_rate < 8000:
            raise ValueError("sample_rate must be at least 8000 Hz.")
        if fft_length < 1024:
            raise ValueError("fft_length must be at least 1024.")

        # Dispatch to the requested spectral envelope extractor.
        extractors = {
            "cheap-trick": SpectrumExtractionByCheapTrick,
            "straight": SpectrumExtractionBySTRAIGHT,
        }
        if algorithm not in extractors:
            raise ValueError(f"algorithm {algorithm} is not supported.")
        self.extractor = extractors[algorithm](
            frame_period, sample_rate, fft_length, **kwargs
        )

        self.formatter = self._formatter(out_format)

    def forward(self, x: torch.Tensor, f0: torch.Tensor) -> torch.Tensor:
        """Estimate the spectral envelope.

        Parameters
        ----------
        x : Tensor [shape=(..., T)]
            The input waveform.

        f0 : Tensor [shape=(..., T/P)]
            The F0 in Hz.

        Returns
        -------
        out : Tensor [shape=(..., T/P, L/2+1)]
            The spectral envelope.

        Examples
        --------
        >>> x = diffsptk.sin(1000, 80)
        >>> pitch = diffsptk.Pitch(160, 8000, out_format="f0")
        >>> f0 = pitch(x)
        >>> f0.shape
        torch.Size([7])
        >>> pitch_spec = diffsptk.PitchAdaptiveSpectralAnalysis(160, 8000, 1024)
        >>> sp = pitch_spec(x, f0)
        >>> sp.shape
        torch.Size([7, 513])

        """
        return self.formatter(self.extractor(x, f0))

    @staticmethod
    def _formatter(out_format: str | int) -> Callable:
        # The extractor emits log power; each format is a fixed transform of it,
        # selectable either by index or by name.
        for keys, convert in (
            ((0, "db"), lambda x: x * (10 / np.log(10))),
            ((1, "log-magnitude"), lambda x: x / 2),
            ((2, "magnitude"), lambda x: torch.exp(x / 2)),
            ((3, "power"), lambda x: torch.exp(x)),
        ):
            if out_format in keys:
                return convert
        raise ValueError(f"out_format {out_format} is not supported.")
# ----------------------------------------------------------------- #
# Copyright (c) 2010 M. Morise                                      #
#                                                                   #
# All rights reserved.                                              #
#                                                                   #
# Redistribution and use in source and binary forms, with or        #
# without modification, are permitted provided that the following   #
# conditions are met:                                               #
#                                                                   #
# - Redistributions of source code must retain the above copyright  #
#   notice, this list of conditions and the following disclaimer.   #
# - Redistributions in binary form must reproduce the above         #
#   copyright notice, this list of conditions and the following     #
#   disclaimer in the documentation and/or other materials provided #
#   with the distribution.                                          #
# - Neither the name of the M. Morise nor the names of its          #
#   contributors may be used to endorse or promote products derived #
#   from this software without specific prior written permission.   #
#                                                                   #
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND            #
# CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES,       #
# INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF          #
# MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE          #
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS #
# BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,          #
# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED   #
# TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,     #
# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON #
# ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,   #
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY    #
# OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE           #
# POSSIBILITY OF SUCH DAMAGE.
# ----------------------------------------------------------------- #
class SpectrumExtractionByCheapTrick(nn.Module):
    """Spectral envelope estimation based on CheapTrick.

    This mirrors the CheapTrick pipeline of the WORLD vocoder:
    F0-adaptive windowing, power spectrum computation, DC correction,
    linear smoothing, and liftering-based smoothing with recovery.
    """

    def __init__(
        self,
        frame_period: int,
        sample_rate: int,
        fft_length: int,
        *,
        default_f0: float = 500,
        q1: float = -0.15,
    ) -> None:
        super().__init__()

        self.frame_period = frame_period
        self.sample_rate = sample_rate
        self.fft_length = fft_length

        # GetF0FloorForCheapTrick()
        self.f_min = 3 * sample_rate / (fft_length - 3)
        if default_f0 < self.f_min:
            raise ValueError(f"default_f0 must be at least {self.f_min}.")

        # GetFFTSizeForCheapTrick()
        min_fft_length = 2 ** (
            1 + int(np.log(3 * sample_rate / self.f_min + 1) / np.log(2))
        )
        if fft_length < min_fft_length:
            raise ValueError(f"fft_length must be at least {min_fft_length}.")

        # Set WORLD constants.
        # q1: compensation lifter coefficient; default_f0: fallback for low F0.
        self.q1 = q1
        self.default_f0 = default_f0

        self.spec = Spectrum(fft_length)
        # Integer ramp reused by the WORLD helper functions for indexing.
        self.register_buffer("ramp", torch.arange(fft_length))

    def forward(self, x: torch.Tensor, f0: torch.Tensor) -> torch.Tensor:
        """Estimate the log power spectral envelope.

        Parameters
        ----------
        x : Tensor [shape=(..., T)]
            The input waveform.

        f0 : Tensor [shape=(..., T/P)]
            The F0 in Hz; values at or below the F0 floor are replaced by
            default_f0. Gradients are cut here via detach().

        Returns
        -------
        out : Tensor [shape=(..., T/P, L/2+1)]
            The log power spectral envelope.
        """
        f0 = torch.where(f0 <= self.f_min, self.default_f0, f0).unsqueeze(-1).detach()

        # GetWindowedWaveform()
        waveform = get_windowed_waveform(
            x,
            f0,
            3,
            0,
            self.frame_period,
            self.sample_rate,
            self.fft_length,
            "hanning",
            True,
            1e-12,
            self.ramp,
        )

        # GetPowerSpectrum()
        power_spectrum = self.spec(waveform)

        # DCCorrection()
        power_spectrum = dc_correction(
            power_spectrum, f0, self.sample_rate, self.fft_length, self.ramp
        )

        # LinearSmoothing()
        power_spectrum = linear_smoothing(
            power_spectrum, f0 * (2 / 3), self.sample_rate, self.fft_length, self.ramp
        )

        # AddInfinitesimalNoise()
        power_spectrum += (
            torch.randn_like(power_spectrum).abs() * torch.finfo(x.dtype).eps
        )

        # SmoothingWithRecovery(): lifter the cepstrum and transform back.
        one_sided_length = self.fft_length // 2 + 1
        quefrency = self.ramp[:one_sided_length] / self.sample_rate
        z = f0 * quefrency
        smoothing_lifter = torch.sinc(z)
        compensation_lifter = (1 - 2 * self.q1) + 2 * self.q1 * torch.cos(TAU * z)
        smoothing_lifter[..., 0] = 1
        cepstrum = torch.fft.irfft(torch.log(power_spectrum))[..., :one_sided_length]
        log_power_spectrum = torch.fft.hfft(
            cepstrum * smoothing_lifter * compensation_lifter,
        )[..., :one_sided_length]
        return log_power_spectrum


# ------------------------------------------------------------------------ #
# Copyright 2018 Hideki Kawahara                                           #
#                                                                          #
# Licensed under the Apache License, Version 2.0 (the "License");          #
# you may not use this file except in compliance with the License.         #
# You may obtain a copy of the License at                                  #
#                                                                          #
#     http://www.apache.org/licenses/LICENSE-2.0                           #
#                                                                          #
# Unless required by applicable law or agreed to in writing, software      #
# distributed under the License is distributed on an "AS IS" BASIS,        #
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. #
# See the License for the specific language governing permissions and      #
# limitations under the License.                                           #
# ------------------------------------------------------------------------ #
class SpectrumExtractionBySTRAIGHT(nn.Module):
    """Spectral envelope estimation based on STRAIGHT.

    Port of the STRAIGHT spectral analysis. Requires double precision.
    """

    def __init__(
        self,
        frame_period: int,
        sample_rate: int,
        fft_length: int,
        *,
        default_f0: float = 160,
        spectral_exponent: float = 0.6,
        compensation_factor: float = 0.2,
    ) -> None:
        super().__init__()

        self.frame_period = frame_period
        self.sample_rate = sample_rate
        self.fft_length = fft_length
        self.default_f0 = default_f0
        self.pc = spectral_exponent
        self.mag = compensation_factor

        from scipy import signal

        # High-pass filter bank: 70 Hz for analysis, 300/3000 Hz for the
        # power trajectory used in unvoiced-frame compensation.
        b1, a1 = signal.butter(6, 70 / sample_rate * 2, btype="highpass")
        b2, a2 = signal.butter(6, 300 / sample_rate * 2, btype="highpass")
        b3, a3 = signal.butter(6, 3000 / sample_rate * 2, btype="highpass")
        self.register_buffer("b", numpy_to_torch(np.stack([b1, b2, b3])))
        self.register_buffer("a", numpy_to_torch(np.stack([a1, a2, a3])))

        frame_length_in_msec = 80
        frame_length = sample_rate * frame_length_in_msec // 1000
        if fft_length < frame_length:
            raise ValueError(f"fft_length must be at least {frame_length}.")
        self.frame = Frame(frame_length, frame_period, zmean=True)

        self.register_buffer("ramp", torch.arange(max(frame_length * 2, fft_length)))

        # Time axis centered on the analysis frame (seconds).
        tt = (self.ramp[:frame_length] + (1 - frame_length / 2)) / sample_rate
        self.register_buffer("tt", tt)

        # Pitch-synchronous Gaussian window seed, normalized to unit peak.
        self.fNominal = 40
        eta = 1
        wGaussian = torch.exp(-torch.pi * (tt * self.fNominal / eta) ** 2)
        wSynchronousBartlett = 1 - torch.abs(tt * self.fNominal)
        wPSGSeed = self.fftfilt(
            wSynchronousBartlett[0 < wSynchronousBartlett],
            F.pad(wGaussian, (0, frame_length)),
        )
        maxValue, maxLocation = torch.max(wPSGSeed, dim=-1)
        wPSGSeed = wPSGSeed / maxValue
        tNominal = (self.ramp[: 2 * frame_length] - maxLocation) / sample_rate
        self.register_buffer("wPSGSeed", wPSGSeed)
        self.register_buffer("tNominal", tNominal)

        # Quefrency axis for the full (wrapped) FFT length.
        one_sided_length = fft_length // 2 + 1
        remaining_length = fft_length - one_sided_length
        ttm = (
            torch.cat(
                [
                    self.ramp[:one_sided_length],
                    self.ramp[:remaining_length] - remaining_length,
                ]
            )
            / sample_rate
        )
        ttm[0] = 1e-5 / sample_rate
        self.register_buffer("ttm", ttm)

        # Sigmoid lifter suppressing high-quefrency components.
        lft = torch.sigmoid(
            ((self.ramp[:fft_length] - fft_length // 2).abs() - fft_length / 30) / 2
        )
        self.register_buffer("lft", lft)

        from pylstraight.core.sp import optimumsmoothing as optimum_smoothing

        ovc = optimum_smoothing(eta, self.pc)
        self.register_buffer("ovc", numpy_to_torch(ovc))

        # Smoothing kernel for the high-passed power trajectory (2 ms window).
        ncw = round(2 * sample_rate / 1000)
        h3 = signal.convolve(
            np.hanning(ncw // 2 + 2)[1:-1],
            np.exp(-1400 / sample_rate * np.arange(2 * ncw + 1)),
            mode="full",
        )
        self.register_buffer("h3", numpy_to_torch(h3))

        # Normalized Hanning weights for frame-rate smoothing (~7 ms span).
        ipwm = 7
        ipl = round(ipwm / (frame_period / sample_rate * 1000))
        ww = np.hanning(ipl * 2 + 3)[1:-1]
        ww /= np.sum(ww)
        self.register_buffer("ww", numpy_to_torch(ww))

        # Quadratic compensation coefficients derived from ovc.
        hh = np.array(
            [
                [1, 1, 1, 1],
                [0, 1 / 2, 2 / 3, 3 / 4],
                [0, 0, 1 / 3, 2 / 4],
                [0, 0, 0, 1 / 4],
            ]
        )
        bb = np.linalg.solve(hh, ovc)
        cc = np.array([1, 4, 9, 16])
        tt = np.arange(one_sided_length) / sample_rate
        pb2 = (np.pi / eta**2 + np.pi**2 / 3 * np.sum(bb * cc)) * tt**2
        self.register_buffer("pb2", numpy_to_torch(pb2))

    @staticmethod
    def fftfilt(b: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
        """FIR-filter x with b via FFT-based (circular) convolution,
        truncated to the length of x; returns the real part."""
        nb = b.size(-1)
        nx = x.size(-1)
        fft_length = next_power_of_two(nb + nx - 1)
        B = torch.fft.fft(b, n=fft_length)
        X = torch.fft.fft(x, n=fft_length)
        y = torch.fft.ifft(X * B)[..., :nx]
        return y.real

    def forward(self, x: torch.Tensor, f0: torch.Tensor) -> torch.Tensor:
        """Estimate the log spectral envelope (STRAIGHT n3sgram).

        Parameters
        ----------
        x : Tensor [shape=(..., T)]
            The input waveform (double precision only).

        f0 : Tensor [shape=(..., T/P)]
            The F0 in Hz; zero marks unvoiced frames.

        Returns
        -------
        out : Tensor [shape=(..., T/P, L/2+1)]
            The log power spectral envelope (2 * log magnitude).
        """
        if x.dtype != torch.double or self.a.dtype != torch.double:
            raise ValueError("Only double precision is supported.")

        # Normalize the waveform amplitude (undone at the end).
        xamp = torch.std(x, dim=-1, keepdim=True)
        scaleconst = 2200
        x = torch.where(xamp < 1e-10, x, x * (scaleconst / xamp))
        xh = iir(x, self.b, self.a, batching=False)
        tx = self.frame(xh[..., 0, :])

        f0 = f0.unsqueeze(-1).detach()
        f0raw = f0
        unvoiced = f0 == 0
        f0 = torch.where(unvoiced, self.default_f0, f0)
        ttf = self.tt * f0

        def safe_div(x, y, eps=1e-10):
            # Division guarded against zero denominators.
            return x / (y + eps)

        # Batched 1-D linear interpolation (torch.searchsorted based).
        # https://github.com/pytorch/pytorch/issues/50334#issuecomment-2304751532
        def interp1(x, y, xq, method="linear", batching=(False, False)):
            if not batching[0]:
                x = x.repeat(*xq.shape[0:-1], 1)
            if not batching[1]:
                y = y.repeat(*xq.shape[0:-1], 1)
            m = torch.diff(y) / torch.diff(x)
            b = y[..., :-1] - m * x[..., :-1]
            indices = torch.searchsorted(x, xq, right=False)
            if method == "linear":
                m = F.pad(m, (1, 1))
                b = torch.cat([y[..., :1], b, y[..., -1:]], dim=-1)
            elif method == "*linear":
                indices = torch.clamp(indices - 1, 0, m.shape[-1] - 1)
            values = m.gather(-1, indices) * xq + b.gather(-1, indices)
            return values

        # Pitch-synchronous windowing and double-window power spectrum.
        wxe = interp1(
            self.tNominal, self.wPSGSeed, ttf / self.fNominal, method="*linear"
        )
        wxe /= torch.linalg.vector_norm(wxe, dim=-1, keepdim=True)
        bcf = 0.36
        wxd = bcf * wxe * torch.sin(torch.pi * ttf)
        one_sided_length = self.fft_length // 2 + 1
        pw = (
            torch.fft.rfft(tx * wxe, n=self.fft_length).abs() ** 2
            + torch.fft.rfft(tx * wxd, n=self.fft_length).abs() ** 2
        )
        pw = torch.clip(pw, min=1e-6) ** (self.pc / 2)

        # Rebuild the low-frequency band (below ~F0/2) by interpolation.
        f0pr = f0 * (self.fft_length / self.sample_rate) + 1
        f0p = torch.ceil(f0pr).long()
        f0p2 = torch.floor((f0pr + 1) / 2).long()
        f0pm = f0p.max()
        f0p2m = f0p2.max()
        pwx = self.ramp[:f0pm] + 1
        pwxq = f0pr - self.ramp[:f0p2m]
        tmppw = interp1(
            pwx, pw[..., :f0pm], pwxq, method="linear", batching=(False, True)
        )
        tmppw = F.pad(tmppw, (0, one_sided_length - f0p2m))
        mask = self.ramp[:one_sided_length] < f0p2
        pw = torch.where(mask, tmppw, pw)

        # F0-adaptive smoothing in the quefrency domain with recovery.
        ttmf = self.ttm * f0
        ww2t = torch.sinc(3 * ttmf) ** 2
        spw2 = torch.fft.ihfft(ww2t * torch.fft.hfft(pw) * self.lft).real
        wwt = torch.sinc(ttmf) ** 2
        wwt *= (
            self.ovc[0]
            + self.ovc[1] * 2 * torch.cos(TAU * ttmf)
            + self.ovc[2] * 2 * torch.cos(2 * TAU * ttmf)
        )
        spw = (
            torch.fft.ihfft(wwt * torch.fft.hfft(safe_div(pw, spw2)) * self.lft).real
            / wwt[..., :1]
        )
        n2sgram = spw2 * (
            0.175 * torch.log(2 * torch.cosh(4 / 1.4 * spw) + 1e-10) + 0.5 * spw
        )
        n2sgram = torch.clip(n2sgram, min=1e-6) ** (2 / self.pc)

        # Frame-rate power trajectories from the 300/3000 Hz high-passed bands.
        nframe = f0.size(-2)
        pwcs = self.fftfilt(
            self.h3, F.pad(xh[..., 1:, :].abs() ** 2, (0, 4 * len(self.h3)))
        )
        end = self.frame_period * nframe
        pwcs = pwcs[..., : end : self.frame_period]
        lbb = round(300 / self.sample_rate * self.fft_length) - 1
        numer = torch.cat(
            [
                torch.sum(n2sgram[..., lbb:], dim=(-1, -2), keepdim=True),
                torch.sum(n2sgram, dim=(-1, -2), keepdim=True),
            ],
            dim=-2,
        )
        denom = torch.sum(pwcs, dim=-1, keepdim=True)
        pwcs = pwcs * safe_div(numer, denom)
        pwch = pwcs[..., 1, :]

        # Unvoiced-frame power compensation driven by trajectory roughness.
        apwt = self.fftfilt(self.ww, F.pad(pwch, (0, len(self.ww))))
        begin = len(self.ww) // 2
        apwt = apwt[..., begin : begin + nframe]
        mmaa = torch.amax(apwt, dim=-1, keepdim=True)
        apwt = torch.where(apwt <= 0, mmaa, apwt)
        dpwt = self.fftfilt(self.ww, F.pad(torch.diff(pwch) ** 2, (0, len(self.ww))))
        dpwt = dpwt[..., begin : begin + nframe]
        dpwt = torch.sqrt(torch.clip(dpwt, min=1e-10))
        rr = safe_div(dpwt, apwt)
        lmbd = torch.sigmoid((torch.sqrt(rr) - 0.75) * 20)
        pwc = lmbd * safe_div(pwcs[..., 0, :], torch.sum(n2sgram, dim=-1)) + (1 - lmbd)
        n2sgram = torch.where(unvoiced, n2sgram * pwc.unsqueeze(-1), n2sgram)
        n2sgram = torch.sqrt(torch.abs(n2sgram + 1e-10))

        # Optional cepstral-domain spectral compensation (skipped when mag <= 0).
        if 0 < self.mag:
            ccs2 = torch.fft.hfft(n2sgram)[..., :one_sided_length] * torch.clip(
                1 + self.mag * self.pb2 * f0raw**2, max=20
            )
            n2sgram3 = torch.fft.hfft(ccs2, norm="forward")[..., :one_sided_length]
            n2sgram = (n2sgram3.abs() + n2sgram3) / 2 + 0.1

        # Undo the input amplitude normalization and move to the log domain.
        xamp = xamp.unsqueeze(-1)
        n3sgram = torch.where(xamp < 1e-10, n2sgram, n2sgram * (xamp / scaleconst))
        n3sgram = 2 * torch.log(torch.abs(n3sgram + 1e-10))
        return n3sgram