Source code for diffsptk.modules.excite

# ------------------------------------------------------------------------ #
# Copyright 2022 SPTK Working Group                                        #
#                                                                          #
# Licensed under the Apache License, Version 2.0 (the "License");          #
# you may not use this file except in compliance with the License.         #
# You may obtain a copy of the License at                                  #
#                                                                          #
#     http://www.apache.org/licenses/LICENSE-2.0                           #
#                                                                          #
# Unless required by applicable law or agreed to in writing, software      #
# distributed under the License is distributed on an "AS IS" BASIS,        #
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. #
# See the License for the specific language governing permissions and      #
# limitations under the License.                                           #
# ------------------------------------------------------------------------ #

import torch
import torch.nn.functional as F

from ..typing import Precomputed
from ..utils.private import TAU, UNVOICED_SYMBOL, get_values
from .base import BaseFunctionalModule
from .linear_intpl import LinearInterpolation


[docs] class ExcitationGeneration(BaseFunctionalModule): """See `this page <https://sp-nitech.github.io/sptk/latest/main/excite.html>`_ for details. Parameters ---------- frame_period : int >= 1 The frame period in samples, :math:`P`. voiced_region : ['pulse', 'sinusoidal', 'sawtooth', 'inverted-sawtooth', \ 'triangle', 'square'] The type of voiced region. unvoiced_region : ['zeros', 'gauss'] The type of unvoiced region. polarity : ['auto', 'unipolar', 'bipolar'] The polarity. init_phase : ['zeros', 'random'] The initial phase. """ def __init__( self, frame_period: int, *, voiced_region: str = "pulse", unvoiced_region: str = "gauss", polarity: str = "auto", init_phase: str = "zeros", ) -> None: super().__init__() self.values = self._precompute(*get_values(locals()))
[docs] def forward(self, p: torch.Tensor) -> torch.Tensor: """Generate a simple excitation signal. Parameters ---------- p : Tensor [shape=(..., N)] The pitch in seconds. Returns ------- out : Tensor [shape=(..., NxP)] The excitation signal. Examples -------- >>> p = torch.tensor([2.0, 3.0]) >>> excite = diffsptk.ExcitationGeneration(3) >>> e = excite(p) >>> e tensor([1.4142, 0.0000, 1.6330, 0.0000, 0.0000, 1.7321]) """ return self._forward(p, *self.values)
@staticmethod def _func(x: torch.Tensor, *args, **kwargs) -> torch.Tensor: values = ExcitationGeneration._precompute(*args, **kwargs) return ExcitationGeneration._forward(x, *values) @staticmethod def _takes_input_size() -> bool: return False @staticmethod def _check(frame_period: int) -> None: if frame_period <= 0: raise ValueError("frame_period must be positive.") @staticmethod def _precompute( frame_period: int, voiced_region: str, unvoiced_region: str, polarity: str, init_phase: str, ) -> Precomputed: ExcitationGeneration._check(frame_period) return (frame_period, voiced_region, unvoiced_region, polarity, init_phase) @staticmethod @torch.inference_mode() def _forward( p: torch.Tensor, frame_period: int, voiced_region: str, unvoiced_region: str, polarity: str, init_phase: str, ) -> torch.Tensor: # Make mask represents voiced region. base_mask = torch.clip(p, min=0, max=1) mask = torch.ne(base_mask, UNVOICED_SYMBOL) mask = torch.repeat_interleave(mask, frame_period, dim=-1) # Extend right side for interpolation. tmp_mask = F.pad(base_mask, (1, 0)) tmp_mask = torch.eq(torch.diff(tmp_mask), -1) p[tmp_mask] = torch.roll(p, 1, dims=-1)[tmp_mask] # Interpolate pitch. if p.dim() != 1: p = p.transpose(-2, -1) p = LinearInterpolation._func(p, frame_period) if p.dim() != 1: p = p.transpose(-2, -1) p *= mask # Compute phase. voiced_pos = torch.gt(p, 0) q = torch.zeros_like(p) q[voiced_pos] = torch.reciprocal(p[voiced_pos]) s = torch.cumsum(q.double(), dim=-1) bias, _ = torch.cummax(s * ~mask, dim=-1) phase = (s - bias).to(p.dtype) if init_phase == "zeros": pass elif init_phase == "random": phase += torch.rand_like(p[..., :1]) else: raise ValueError(f"init_phase {init_phase} is not supported.") # Generate excitation signal using phase. if polarity == "auto": unipolar = voiced_region == "pulse" elif polarity in ("unipolar", "bipolar"): unipolar = polarity == "unipolar" else: raise ValueError(f"polarity {polarity} is not supported.") e = torch.zeros_like(p) if voiced_region == "pulse": def get_pulse_pos(p): r = torch.ceil(p) r = F.pad(r, (1, 0)) return torch.ge(torch.diff(r), 1) if unipolar: pulse_pos = get_pulse_pos(phase) e[pulse_pos] = torch.sqrt(p[pulse_pos]) else: pulse_pos1 = get_pulse_pos(phase) pulse_pos2 = get_pulse_pos(0.5 * phase) e[pulse_pos1] = torch.sqrt(p[pulse_pos1]) e[pulse_pos1 & ~pulse_pos2] *= -1 elif voiced_region == "sinusoidal": if unipolar: e[mask] = 0.5 * (1 - torch.cos(TAU * phase[mask])) else: e[mask] = torch.sin(TAU * phase[mask]) elif voiced_region == "sawtooth": if unipolar: e[mask] = torch.fmod(phase[mask], 1) else: e[mask] = 2 * torch.fmod(phase[mask], 1) - 1 elif voiced_region == "inverted-sawtooth": if unipolar: e[mask] = 1 - torch.fmod(phase[mask], 1) else: e[mask] = 1 - 2 * torch.fmod(phase[mask], 1) elif voiced_region == "triangle": if unipolar: e[mask] = torch.abs(2 * torch.fmod(phase[mask] + 0.5, 1) - 1) else: e[mask] = 2 * torch.abs(2 * torch.fmod(phase[mask] + 0.75, 1) - 1) - 1 elif voiced_region == "square": if unipolar: e[mask] = torch.le(torch.fmod(phase[mask], 1), 0.5).to(e.dtype) else: e[mask] = 2 * torch.le(torch.fmod(phase[mask], 1), 0.5).to(e.dtype) - 1 else: raise ValueError(f"voiced_region {voiced_region} is not supported.") if unvoiced_region == "zeros": pass elif unvoiced_region == "gauss": e[~mask] = torch.randn(torch.sum(~mask), device=e.device) else: raise ValueError(f"unvoiced_region {unvoiced_region} is not supported.") return e