# ------------------------------------------------------------------------ #
# Copyright 2022 SPTK Working Group #
# #
# Licensed under the Apache License, Version 2.0 (the "License"); #
# you may not use this file except in compliance with the License. #
# You may obtain a copy of the License at #
# #
# http://www.apache.org/licenses/LICENSE-2.0 #
# #
# Unless required by applicable law or agreed to in writing, software #
# distributed under the License is distributed on an "AS IS" BASIS, #
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. #
# See the License for the specific language governing permissions and #
# limitations under the License. #
# ------------------------------------------------------------------------ #
import torch
import torch.nn.functional as F
from ..typing import Precomputed
from ..utils.private import get_values, remove_gain
from .base import BaseFunctionalModule
class GroupDelay(BaseFunctionalModule):
    """See `this page <https://sp-nitech.github.io/sptk/latest/main/grpdelay.html>`_
    for details.

    Parameters
    ----------
    fft_length : int >= 2
        The number of FFT bins, :math:`L`.

    alpha : float > 0
        The tuning parameter, :math:`\\alpha`.

    gamma : float > 0
        The tuning parameter, :math:`\\gamma`.

    """
    def __init__(
        self,
        fft_length: int,
        alpha: float = 1,
        gamma: float = 1,
    ) -> None:
        super().__init__()

        self.values, _, tensors = self._precompute(*get_values(locals()))
        self.register_buffer("ramp", tensors[0])
    def forward(
        self, b: torch.Tensor | None = None, a: torch.Tensor | None = None
    ) -> torch.Tensor:
        """Compute group delay.

        Parameters
        ----------
        b : Tensor [shape=(..., M+1)] or None
            The numerator coefficients.

        a : Tensor [shape=(..., N+1)] or None
            The denominator coefficients.

        Returns
        -------
        out : Tensor [shape=(..., L/2+1)]
            The group delay or modified group delay function.

        Examples
        --------
        >>> x = diffsptk.ramp(3)
        >>> grpdelay = diffsptk.GroupDelay(8)
        >>> g = grpdelay(x)
        >>> g
        tensor([2.3333, 2.4278, 3.0000, 3.9252, 3.0000])

        """
        return self._forward(b, a, *self.values, **self._buffers)
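    # A usage sketch (illustrative, not from the original docstring; the
    # coefficient values and the alpha/gamma settings are arbitrary): passing
    # both coefficient sets yields the (modified) group delay of B(z)/A(z).
    #
    #   >>> b = diffsptk.ramp(3)
    #   >>> a = torch.tensor([1.0, 0.5])
    #   >>> grpdelay = diffsptk.GroupDelay(8, alpha=0.5, gamma=0.9)
    #   >>> g = grpdelay(b, a)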
    @staticmethod
    def _func(
        b: torch.Tensor | None, a: torch.Tensor | None, *args, **kwargs
    ) -> torch.Tensor:
        x = a if b is None else b
        values, _, tensors = GroupDelay._precompute(
            *args, **kwargs, device=x.device, dtype=x.dtype
        )
        return GroupDelay._forward(b, a, *values, *tensors)

    @staticmethod
    def _takes_input_size() -> bool:
        return False

    @staticmethod
    def _check(fft_length: int, alpha: float, gamma: float) -> None:
        if fft_length <= 1:
            raise ValueError("fft_length must be greater than 1.")
        if alpha <= 0:
            raise ValueError("alpha must be positive.")
        if gamma <= 0:
            raise ValueError("gamma must be positive.")

    @staticmethod
    def _precompute(
        fft_length: int,
        alpha: float,
        gamma: float,
        device: torch.device | None = None,
        dtype: torch.dtype | None = None,
    ) -> Precomputed:
        GroupDelay._check(fft_length, alpha, gamma)
        ramp = torch.arange(fft_length, device=device, dtype=dtype)
        return (fft_length, alpha, gamma), None, (ramp,)
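    # Added note: _precompute returns the validated scalar settings together
    # with the ramp buffer [0, 1, ..., L-1]; the ramp is what forms
    # d[n] = n * c[n] inside _forward.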
    @staticmethod
    def _forward(
        b: torch.Tensor | None,
        a: torch.Tensor | None,
        fft_length: int,
        alpha: float,
        gamma: float,
        ramp: torch.Tensor,
    ) -> torch.Tensor:
        if b is None and a is None:
            raise ValueError("Either b or a must be specified.")

        if a is None:
            order = 0
        else:
            a = remove_gain(a)
            order = a.size(-1) - 1

        if b is None:
            c = a.flip(-1)
        elif a is None:
            c = b
        else:
            # Perform full convolution of b with the time-reversed a.
            b1 = F.pad(b, (order, order))
            b2 = b1.unfold(-1, b.size(-1) + order, 1)
            c = (b2 * a.unsqueeze(-1)).sum(-2)

        data_length = c.size(-1)
        if fft_length < data_length:
            raise RuntimeError("Please increase FFT length.")

        # Group delay via the closed form Re{C(w) conj(D(w))} / |C(w)|^2,
        # where d[n] = n * c[n] and the denominator is raised to gamma.
        d = c * ramp[:data_length]
        C = torch.fft.rfft(c, n=fft_length)
        D = torch.fft.rfft(d, n=fft_length)
        numer = C.real * D.real + C.imag * D.imag
        denom = C.real * C.real + C.imag * C.imag
        if gamma != 1:
            denom = torch.pow(denom, gamma)

        # Compensate for the constant delay (order samples) introduced by
        # time-reversing a, then apply the alpha modification.
        g = numer / denom - order
        if alpha != 1:
            g = torch.sign(g) * torch.pow(torch.abs(g), alpha)
        return g
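# An illustrative cross-check (an added sketch, not part of the original
# module): for the numerator-only case (a is None, alpha = gamma = 1) the
# closed form used in _forward can be reproduced with plain torch.fft, and
# matches the forward() docstring example above.
#
#   >>> import torch
#   >>> b = torch.arange(4, dtype=torch.float)  # same values as diffsptk.ramp(3)
#   >>> d = torch.arange(4, dtype=torch.float) * b  # d[n] = n * b[n]
#   >>> C = torch.fft.rfft(b, n=8)
#   >>> D = torch.fft.rfft(d, n=8)
#   >>> (C.real * D.real + C.imag * D.imag) / (C.real**2 + C.imag**2)
#   tensor([2.3333, 2.4278, 3.0000, 3.9252, 3.0000])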