Source code for diffsptk.modules.hilbert2

# ------------------------------------------------------------------------ #
# Copyright 2022 SPTK Working Group                                        #
#                                                                          #
# Licensed under the Apache License, Version 2.0 (the "License");          #
# you may not use this file except in compliance with the License.         #
# You may obtain a copy of the License at                                  #
#                                                                          #
#     http://www.apache.org/licenses/LICENSE-2.0                           #
#                                                                          #
# Unless required by applicable law or agreed to in writing, software      #
# distributed under the License is distributed on an "AS IS" BASIS,        #
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. #
# See the License for the specific language governing permissions and      #
# limitations under the License.                                           #
# ------------------------------------------------------------------------ #

import torch

from ..typing import ArrayLike, Precomputed
from ..utils.private import get_values, to
from .base import BaseFunctionalModule
from .hilbert import HilbertTransform


[docs] class TwoDimensionalHilbertTransform(BaseFunctionalModule): """2-D Hilbert transform module. Parameters ---------- fft_length : int >= 1 or tuple[int, int] The number of FFT bins. dim : tuple[int, int] The dimension along which to take the Hilbert transform. """ def __init__( self, fft_length: ArrayLike[int] | int, dim: ArrayLike[int] = (-2, -1) ) -> None: super().__init__() self.values, _, tensors = self._precompute(*get_values(locals())) self.register_buffer("h", tensors[0])
[docs] def forward(self, x: torch.Tensor) -> torch.Tensor: """Compute the analytic signal using the Hilbert transform. Parameters ---------- x : Tensor [shape=(..., T1, T2, ...)] The input signal. Returns ------- out : Tensor [shape=(..., T1, T2, ...)] The analytic signal, where the real part is the input signal and the imaginary part is the Hilbert transform of the input signal. Examples -------- >>> x = diffsptk.nrand(3) >>> x tensor([[ 1.1809, -0.2834, -0.4169, 0.3883]]) >>> hilbert2 = diffsptk.TwoDimensionalHilbertTransform((1, 4)) >>> z = hilbert2(x) >>> z.real tensor([[ 1.1809, -0.2834, -0.4169, 0.3883]]) >>> z.imag tensor([[ 0.3358, 0.7989, -0.3358, -0.7989]]) """ return self._forward(x, *self.values, **self._buffers)
@staticmethod def _func( x: torch.Tensor, fft_length: ArrayLike[int] | int | None, dim: ArrayLike[int] ) -> torch.Tensor: values, _, tensors = TwoDimensionalHilbertTransform._precompute( (x.size(dim[0]), x.size(dim[1])) if fft_length is None else fft_length, dim, device=x.device, dtype=x.dtype, ) return TwoDimensionalHilbertTransform._forward(x, *values, *tensors) @staticmethod def _takes_input_size() -> bool: return True @staticmethod def _check(dim: ArrayLike[int]) -> None: if len(dim) != 2: raise ValueError("dim must have length 2.") @staticmethod def _precompute( fft_length: ArrayLike[int] | int, dim: ArrayLike[int], device: torch.device | None = None, dtype: torch.dtype | None = None, ) -> Precomputed: TwoDimensionalHilbertTransform._check(dim) if isinstance(fft_length, int): fft_length = (fft_length, fft_length) _, _, h1 = HilbertTransform._precompute( fft_length[0], None, device=device, dtype=torch.double ) _, _, h2 = HilbertTransform._precompute( fft_length[1], None, device=device, dtype=torch.double ) h = h1[0].unsqueeze(1) * h2[0].unsqueeze(0) return (dim,), None, (to(h, dtype=dtype),) @staticmethod def _forward(x: torch.Tensor, dim: ArrayLike[int], h: torch.Tensor) -> torch.Tensor: L = h.size(dim[0]), h.size(dim[1]) target_shape = [1] * x.dim() target_shape[dim[0]] = L[0] target_shape[dim[1]] = L[1] h = h.view(*target_shape) X = torch.fft.fft2(x, s=L, dim=dim) z = torch.fft.ifft2(X * h, s=L, dim=dim) return z