Source code for diffsptk.modules.grpdelay

# ------------------------------------------------------------------------ #
# Copyright 2022 SPTK Working Group                                        #
#                                                                          #
# Licensed under the Apache License, Version 2.0 (the "License");          #
# you may not use this file except in compliance with the License.         #
# You may obtain a copy of the License at                                  #
#                                                                          #
#     http://www.apache.org/licenses/LICENSE-2.0                           #
#                                                                          #
# Unless required by applicable law or agreed to in writing, software      #
# distributed under the License is distributed on an "AS IS" BASIS,        #
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. #
# See the License for the specific language governing permissions and      #
# limitations under the License.                                           #
# ------------------------------------------------------------------------ #

import torch
from torch import nn
import torch.nn.functional as F

from ..misc.utils import remove_gain
from ..misc.utils import to


class GroupDelay(nn.Module):
    """Group delay (or modified group delay) of a digital filter.

    See `this page <https://sp-nitech.github.io/sptk/latest/main/grpdelay.html>`_
    for details.

    Parameters
    ----------
    fft_length : int >= 2
        Number of FFT bins, :math:`L`.

    alpha : float > 0
        Tuning parameter, :math:`\\alpha`.

    gamma : float > 0
        Tuning parameter, :math:`\\gamma`.

    """

    def __init__(self, fft_length, alpha=1, gamma=1):
        super().__init__()

        assert fft_length >= 2
        assert alpha > 0
        assert gamma > 0

        self.fft_length = fft_length
        self.alpha = alpha
        self.gamma = gamma

        # One ramp entry per FFT bin; `_forward` slices off what it needs.
        ramp = self._precompute(self.fft_length)
        self.register_buffer("ramp", ramp)

    def forward(self, b=None, a=None):
        """Compute group delay.

        Parameters
        ----------
        b : Tensor [shape=(..., M+1)] or None
            Numerator coefficients.

        a : Tensor [shape=(..., N+1)] or None
            Denominator coefficients.

        Returns
        -------
        out : Tensor [shape=(..., L/2+1)]
            Group delay or modified group delay function.

        Examples
        --------
        >>> x = diffsptk.ramp(3)
        >>> grpdelay = diffsptk.GroupDelay(8)
        >>> g = grpdelay(x)
        >>> g
        tensor([2.3333, 2.4278, 3.0000, 3.9252, 3.0000])

        """
        return self._forward(
            b, a, self.fft_length, self.alpha, self.gamma, self.ramp
        )

    @staticmethod
    def _forward(b, a, fft_length, alpha, gamma, ramp):
        """Shared implementation behind `forward` and `_func`.

        Raises ``ValueError`` when both ``b`` and ``a`` are None, and
        ``RuntimeError`` when the combined sequence is longer than
        ``fft_length``.
        """
        if b is None and a is None:
            raise ValueError("Either b or a must be specified.")

        if a is None:
            # Numerator only: the sequence is used as given.
            order = 0
            c = b
        else:
            a = remove_gain(a)
            order = a.size(-1) - 1
            if b is None:
                # Denominator only: use the time-reversed coefficients.
                c = a.flip(-1)
            else:
                # Full-length product of b with the time-reversed a, built
                # from sliding windows over the zero-padded numerator.
                padded = F.pad(b, (order, order))
                windows = padded.unfold(-1, b.size(-1) + order, 1)
                c = (windows * a.unsqueeze(-1)).sum(-2)

        data_length = c.size(-1)
        if data_length > fft_length:
            raise RuntimeError("Please increase FFT length.")

        # Spectra of the sequence and of its index-weighted copy.
        weighted = c * ramp[:data_length]
        spec_c = torch.fft.rfft(c, n=fft_length)
        spec_d = torch.fft.rfft(weighted, n=fft_length)

        numer = spec_c.real * spec_d.real + spec_c.imag * spec_d.imag
        denom = spec_c.real * spec_c.real + spec_c.imag * spec_c.imag
        if gamma != 1:
            denom = torch.pow(denom, gamma)

        # Offset by the denominator order (zero when only b is given).
        g = numer / denom - order
        if alpha != 1:
            # Sign-preserving power compression of the result.
            g = torch.sign(g) * torch.pow(torch.abs(g), alpha)
        return g

    @staticmethod
    def _func(b, a, fft_length, alpha, gamma):
        """Functional entry point: build a matching ramp, then `_forward`."""
        # The ramp takes its dtype/device from b when given, otherwise a.
        if b is None:
            reference = a
            data_length = a.size(-1)
        elif a is None:
            reference = b
            data_length = b.size(-1)
        else:
            reference = b
            data_length = b.size(-1) + a.size(-1) - 1

        ramp = GroupDelay._precompute(
            data_length, dtype=reference.dtype, device=reference.device
        )
        return GroupDelay._forward(b, a, fft_length, alpha, gamma, ramp)

    @staticmethod
    def _precompute(length, dtype=None, device=None):
        """Return the index ramp 0, 1, ..., length-1 passed through `to`."""
        return to(torch.arange(length, device=device), dtype=dtype)