Source code for diffsptk.modules.hilbert2
# ------------------------------------------------------------------------ #
# Copyright 2022 SPTK Working Group #
# #
# Licensed under the Apache License, Version 2.0 (the "License"); #
# you may not use this file except in compliance with the License. #
# You may obtain a copy of the License at #
# #
# http://www.apache.org/licenses/LICENSE-2.0 #
# #
# Unless required by applicable law or agreed to in writing, software #
# distributed under the License is distributed on an "AS IS" BASIS, #
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. #
# See the License for the specific language governing permissions and #
# limitations under the License. #
# ------------------------------------------------------------------------ #
import torch
from torch import nn
from ..misc.utils import to
from .hilbert import HilbertTransform
[docs]
class TwoDimensionalHilbertTransform(nn.Module):
    """2-D Hilbert transform module.

    Parameters
    ----------
    fft_length : int or list[int]
        Number of FFT bins. A single int is used for both axes. When two values
        are given, the first one applies to whichever of the two axes in ``dim``
        comes first in the input tensor and the second one to the other axis.

    dim : list[int]
        Dimensions along which to take the Hilbert transform. Exactly two entries
        are required and they must refer to different axes of the input.

    """

    def __init__(self, fft_length, dim=(-2, -1)):
        super().__init__()

        # Raise explicitly: an `assert` would vanish under `python -O`.
        if len(dim) != 2:
            raise ValueError("dim must have length 2.")

        self.dim = dim
        # The frequency-domain multiplier is a buffer so that it follows the
        # module across devices without being a trainable parameter.
        self.register_buffer("h", self._precompute(fft_length))

    def forward(self, x):
        """Compute analytic signal using the Hilbert transform.

        Parameters
        ----------
        x : Tensor [shape=(..., T1, T2, ...)]
            Input signal. It needs at least two dimensions.

        Returns
        -------
        out : Tensor [shape=(..., T1, T2, ...)]
            Analytic signal, where real part is the input signal and imaginary part is
            the Hilbert transform of the input signal. Along the transformed axes
            the output has ``fft_length`` samples (``torch.fft.fft2`` pads or trims
            the input to that size).

        Examples
        --------
        >>> x = diffsptk.nrand(3)
        >>> x
        tensor([[ 1.1809, -0.2834, -0.4169, 0.3883]])
        >>> hilbert2 = diffsptk.TwoDimensionalHilbertTransform((1, 4))
        >>> z = hilbert2(x)
        >>> z.real
        tensor([[ 1.1809, -0.2834, -0.4169, 0.3883]])
        >>> z.imag
        tensor([[ 0.3358, 0.7989, -0.3358, -0.7989]])

        """
        return self._forward(x, self.h, self.dim)

    @staticmethod
    def _normalize_dim(ndim, dim):
        """Turn ``dim`` into two distinct non-negative axis indices.

        Parameters
        ----------
        ndim : int
            Number of dimensions of the tensor that ``dim`` refers to.

        dim : list[int]
            Two axis indices, each of which may be negative.

        Returns
        -------
        out : list[int]
            The same two axes, in the same order, as indices in ``[0, ndim)``.

        """
        if len(dim) != 2:
            raise ValueError("dim must have length 2.")

        axes = []
        for d in dim:
            if not -ndim <= d < ndim:
                raise IndexError(
                    "Dimension out of range (expected to be in range of "
                    f"[{-ndim}, {ndim - 1}], but got {d})"
                )
            axes.append(d % ndim)

        if axes[0] == axes[1]:
            raise ValueError("dim must refer to two distinct axes.")
        return axes

    @staticmethod
    def _forward(x, h, dim):
        """Apply the 2-D multiplier ``h`` to ``x`` in the frequency domain.

        Parameters
        ----------
        x : Tensor [shape=(..., T1, T2, ...)]
            Input signal.

        h : Tensor [shape=(L1, L2)]
            Frequency-domain multiplier. Its rows go with the earlier of the two
            transformed axes of ``x`` and its columns with the later one.

        dim : list[int]
            The two axes of ``x`` to transform.

        Returns
        -------
        out : Tensor
            Complex result of ``ifft2(fft2(x) * h)`` over ``dim``.

        """
        ndim = x.dim()
        axes = TwoDimensionalHilbertTransform._normalize_dim(ndim, dim)

        # `h` is always 2-D, so its sizes must be read from its own axes 0 and 1.
        # Indexing it with axes of `x` fails as soon as an axis index is >= 2.
        first, second = sorted(axes)
        length = {first: h.size(0), second: h.size(1)}

        # Reshape `h` so it broadcasts against the spectrum of `x`.
        target_shape = [1] * ndim
        target_shape[first] = length[first]
        target_shape[second] = length[second]
        h = h.view(*target_shape)

        # FFT sizes are listed in the caller's axis order, as fft2 expects.
        s = (length[axes[0]], length[axes[1]])
        X = torch.fft.fft2(x, s=s, dim=axes)
        z = torch.fft.ifft2(X * h, s=s, dim=axes)
        return z

    @staticmethod
    def _func(x, fft_length, dim):
        """Functional form: build the multiplier on the fly and apply it.

        Parameters
        ----------
        x : Tensor [shape=(..., T1, T2, ...)]
            Input signal.

        fft_length : int or list[int] or None
            Number of FFT bins. If None, the sizes of ``x`` along ``dim`` are used.

        dim : list[int]
            The two axes of ``x`` to transform.

        Returns
        -------
        out : Tensor
            Analytic signal.

        """
        if fft_length is None:
            # Take the sizes in ascending axis order, which is the order in which
            # `_forward` assigns the rows and columns of `h`. Using `dim` order
            # would swap the two lengths for reversed dims such as (-1, -2).
            first, second = sorted(
                TwoDimensionalHilbertTransform._normalize_dim(x.dim(), dim)
            )
            fft_length = (x.size(first), x.size(second))
        h = TwoDimensionalHilbertTransform._precompute(
            fft_length, dtype=x.dtype, device=x.device
        )
        return TwoDimensionalHilbertTransform._forward(x, h, dim)

    @staticmethod
    def _precompute(fft_length, dtype=None, device=None):
        """Build the 2-D multiplier as the outer product of two 1-D ones.

        Parameters
        ----------
        fft_length : int or list[int]
            Number of FFT bins. A single int is used for both axes.

        dtype : torch.dtype or None
            Data type handed to ``to`` for the final conversion.

        device : torch.device or None
            Device on which the 1-D multipliers are created.

        Returns
        -------
        out : Tensor [shape=(L1, L2)]
            Frequency-domain multiplier.

        """
        if isinstance(fft_length, int):
            fft_length = (fft_length, fft_length)

        # The 1-D factors are requested in double precision; only the final
        # product is converted through `to`.
        h1 = HilbertTransform._precompute(
            fft_length[0], dtype=torch.double, device=device
        )
        h2 = HilbertTransform._precompute(
            fft_length[1], dtype=torch.double, device=device
        )
        h = h1.unsqueeze(1) * h2.unsqueeze(0)
        return to(h, dtype=dtype)