Source code for diffsptk.modules.fftr

# ------------------------------------------------------------------------ #
# Copyright 2022 SPTK Working Group                                        #
#                                                                          #
# Licensed under the Apache License, Version 2.0 (the "License");          #
# you may not use this file except in compliance with the License.         #
# You may obtain a copy of the License at                                  #
#                                                                          #
#     http://www.apache.org/licenses/LICENSE-2.0                           #
#                                                                          #
# Unless required by applicable law or agreed to in writing, software      #
# distributed under the License is distributed on an "AS IS" BASIS,        #
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. #
# See the License for the specific language governing permissions and      #
# limitations under the License.                                           #
# ------------------------------------------------------------------------ #

import torch
import torch.nn.functional as F
from torch import nn

from ..typing import Callable, Precomputed
from ..utils.private import get_values, to
from .base import BaseFunctionalModule


class RealValuedFastFourierTransform(BaseFunctionalModule):
    """See `this page <https://sp-nitech.github.io/sptk/latest/main/fftr.html>`_
    for details.

    Parameters
    ----------
    fft_length : int >= 2
        The FFT length, :math:`L`.

    out_format : ['complex', 'real', 'imaginary', 'amplitude', 'power']
        The output format.

    learnable : bool
        Whether to make the DFT basis learnable. If True, the module performs
        DFT rather than FFT.

    """

    def __init__(
        self,
        fft_length: int,
        out_format: str | int = "complex",
        learnable: bool = False,
    ) -> None:
        super().__init__()

        self.values, _, tensors = self._precompute(*get_values(locals()))
        if learnable is True:
            self.W = nn.Parameter(tensors[0])
        elif learnable == "debug":
            # Undocumented debug mode: keep the DFT basis as a fixed buffer so
            # the matmul path is exercised without making W trainable.
            self.register_buffer("W", tensors[0])
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Compute FFT of a real signal.

        Parameters
        ----------
        x : Tensor [shape=(..., N)]
            The real input signal.

        Returns
        -------
        out : Tensor [shape=(..., L/2+1)]
            The output spectrum.

        Examples
        --------
        >>> x = diffsptk.ramp(1, 3)
        >>> x
        tensor([1., 2., 3.])
        >>> fftr = diffsptk.RealValuedFastFourierTransform(8, out_format="real")
        >>> y = fftr(x)
        >>> y
        tensor([ 6.0000,  2.4142, -2.0000, -0.4142,  2.0000])

        """
        return self._forward(x, *self.values, **self._buffers, **self._parameters)
    @staticmethod
    def _func(x: torch.Tensor, *args, **kwargs) -> torch.Tensor:
        values, _, _ = RealValuedFastFourierTransform._precompute(*args, **kwargs)
        return RealValuedFastFourierTransform._forward(x, *values)

    @staticmethod
    def _takes_input_size() -> bool:
        return False

    @staticmethod
    def _check(fft_length: int | None) -> None:
        if fft_length is not None and (fft_length <= 0 or fft_length % 2 == 1):
            raise ValueError("fft_length must be a positive even number.")

    @staticmethod
    def _precompute(
        fft_length: int | None,
        out_format: str | int,
        learnable: bool = False,
        device: torch.device | None = None,
        dtype: torch.dtype | None = None,
    ) -> Precomputed:
        RealValuedFastFourierTransform._check(fft_length)

        if out_format in (0, "complex"):
            formatter = lambda x: x
        elif out_format in (1, "real"):
            formatter = lambda x: x.real
        elif out_format in (2, "imaginary"):
            formatter = lambda x: x.imag
        elif out_format in (3, "amplitude"):
            formatter = lambda x: x.abs()
        elif out_format in (4, "power"):
            formatter = lambda x: x.abs().square()
        else:
            raise ValueError(f"out_format {out_format} is not supported.")

        if learnable:
            # Build the DFT matrix by transforming the identity, keep only the
            # non-negative frequency bins, and stack real and imaginary parts.
            W = torch.fft.fft(torch.eye(fft_length, device=device, dtype=torch.double))
            W = W[..., : fft_length // 2 + 1]
            W = torch.cat([W.real, W.imag], dim=-1)
            tensors = (to(W, dtype=dtype),)
        else:
            tensors = None

        return (fft_length, formatter), None, tensors

    @staticmethod
    def _forward(
        x: torch.Tensor,
        fft_length: int | None,
        formatter: Callable,
        W: torch.Tensor | None = None,
    ) -> torch.Tensor:
        if W is None:
            y = torch.fft.rfft(x, n=fft_length)
        else:
            # DFT via matrix multiplication: zero-pad the input, multiply by
            # the stacked basis, then reassemble the complex spectrum.
            if fft_length is not None and fft_length != x.size(-1):
                x = F.pad(x, (0, fft_length - x.size(-1)))
            y = torch.matmul(x, W)
            y = torch.complex(*torch.tensor_split(y, 2, dim=-1))
        return formatter(y)
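
A minimal usage sketch (not part of the module source above): it exercises two of
the documented out_format options. The names diffsptk.ramp and
diffsptk.RealValuedFastFourierTransform are taken from the Examples block in
forward; the rest follows from the parameter documentation.

import diffsptk

x = diffsptk.ramp(1, 3)  # tensor([1., 2., 3.])

# Default format: complex spectrum with L/2 + 1 = 5 bins.
fftr = diffsptk.RealValuedFastFourierTransform(8)
print(fftr(x))

# Power spectrum; the integer alias 4 selects the same format.
fftr = diffsptk.RealValuedFastFourierTransform(8, out_format="power")
print(fftr(x))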
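
Since the learnable path replaces torch.fft.rfft with a matrix multiplication
against an explicit DFT basis, the two paths should agree up to floating-point
precision. A sanity-check sketch under that assumption (my construction, not part
of the library or its tests):

import torch
import diffsptk

x = torch.randn(4, 8)

fft = diffsptk.RealValuedFastFourierTransform(8)                  # rfft path
dft = diffsptk.RealValuedFastFourierTransform(8, learnable=True)  # matmul path

# Both return a complex spectrum of shape (4, 5); the matmul result should
# match the FFT result up to numerical precision.
print(torch.allclose(fft(x), dft(x), atol=1e-5))  # expected: True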