Source code for diffsptk.modules.unframe

# ------------------------------------------------------------------------ #
# Copyright 2022 SPTK Working Group                                        #
#                                                                          #
# Licensed under the Apache License, Version 2.0 (the "License");          #
# you may not use this file except in compliance with the License.         #
# You may obtain a copy of the License at                                  #
#                                                                          #
#     http://www.apache.org/licenses/LICENSE-2.0                           #
#                                                                          #
# Unless required by applicable law or agreed to in writing, software      #
# distributed under the License is distributed on an "AS IS" BASIS,        #
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. #
# See the License for the specific language governing permissions and      #
# limitations under the License.                                           #
# ------------------------------------------------------------------------ #

import torch.nn.functional as F
from torch import nn

from .window import Window


class Unframe(nn.Module):
    """This is the opposite module to :func:`~diffsptk.Frame`.

    Parameters
    ----------
    frame_length : int >= 1
        Frame length, :math:`L`.

    frame_period : int >= 1
        Frame period, :math:`P`.

    center : bool
        If True, assume that the center of data is the center of frame, otherwise
        assume that the center of data is the left edge of frame.

    window : ['blackman', 'hamming', 'hanning', 'bartlett', 'trapezoidal', \
              'rectangular']
        Window type.

    norm : ['none', 'power', 'magnitude']
        Normalization type of window.

    """

    def __init__(
        self,
        frame_length,
        frame_period,
        *,
        center=True,
        window="rectangular",
        norm="none",
    ):
        super().__init__()

        assert 1 <= frame_period <= frame_length

        self.frame_length = frame_length
        self.frame_period = frame_period
        self.center = center
        # Kept as a buffer of shape (1, L, 1) so it moves with the module and can
        # be tiled along the frame axis in `_forward`.
        self.register_buffer(
            "window",
            Window._precompute(self.frame_length, window, norm).view(1, -1, 1),
        )

    def forward(self, y, out_length=None):
        """Revert framed waveform.

        Parameters
        ----------
        y : Tensor [shape=(..., T/P, L)]
            Framed waveform.

        out_length : int or None
            Length of original signal, :math:`T`. If None, the whole overlap-added
            signal (after removing the centering offset) is returned.

        Returns
        -------
        out : Tensor [shape=(..., T)]
            Waveform.

        Examples
        --------
        >>> x = diffsptk.ramp(1, 9)
        >>> frame = diffsptk.Frame(5, 2)
        >>> y = frame(x)
        >>> y
        tensor([[0., 0., 1., 2., 3.],
                [1., 2., 3., 4., 5.],
                [3., 4., 5., 6., 7.],
                [5., 6., 7., 8., 9.],
                [7., 8., 9., 0., 0.]])
        >>> unframe = diffsptk.Unframe(5, 2)
        >>> z = unframe(y, out_length=x.size(0))
        >>> z
        tensor([1., 2., 3., 4., 5., 6., 7., 8., 9.])

        """
        return self._forward(
            y,
            out_length,
            self.frame_period,
            self.center,
            self.window,
        )

    @staticmethod
    def _forward(y, out_length, frame_period, center, window):
        """Overlap-add frames and normalize by the overlap-added window.

        Parameters
        ----------
        y : Tensor [shape=(..., N, L)]
            Frames; any number (>= 0) of leading batch dimensions is accepted.

        out_length : int or None
            Number of output samples to keep, or None to keep all of them.

        frame_period : int >= 1
            Hop size between consecutive frames.

        center : bool
            If True, drop the first ``L // 2`` samples of the overlap-added signal.

        window : Tensor [shape=(1, L, 1)]
            Window whose overlap-added sum is used as the normalizer.

        Returns
        -------
        out : Tensor [shape=(..., T)]
            Reconstructed waveform.

        """
        d = y.dim()
        # Check before touching y.size(-2), which would otherwise raise an
        # IndexError for 1D input.
        assert 2 <= d, "Input must be at least 2D tensor."

        frame_length = window.size(-2)
        N = y.size(-2)

        def fold(x):
            # Overlap-add: (B, L, N) -> (B, 1, 1, (N - 1) * P + L), then crop.
            x = F.fold(
                x,
                (1, (N - 1) * frame_period + frame_length),
                (1, frame_length),
                stride=(1, frame_period),
            )
            s = frame_length // 2 if center else 0
            e = None if out_length is None else s + out_length
            return x[..., 0, 0, s:e]

        # F.fold only supports (C * prod(kernel), L) or (B, C * prod(kernel), L)
        # input, so flatten every leading dimension into a single batch axis
        # and restore it after folding.
        batch_shape = y.shape[:-2]
        x = y.transpose(-2, -1).reshape(-1, y.size(-1), N)
        w = window.repeat(1, 1, N)

        # NOTE(review): positions where the summed window is zero or nearly so
        # (e.g. edges of windows that vanish at their endpoints with
        # center=False) produce inf/NaN here; confirm callers avoid this.
        x = fold(x) / fold(w)
        return x.reshape(*batch_shape, x.size(-1))

    @staticmethod
    def _func(y, out_length, frame_length, frame_period, center, window, norm):
        # Same validation as in __init__ so the functional path fails early
        # instead of leaving uncovered (zero-normalizer) samples.
        assert 1 <= frame_period <= frame_length
        window = Unframe._precompute(
            frame_length, window, norm, dtype=y.dtype, device=y.device
        )
        return Unframe._forward(y, out_length, frame_period, center, window)

    @staticmethod
    def _precompute(length, window, norm, dtype=None, device=None):
        # Shape (1, L, 1) to match the layout expected by `_forward`.
        return Window._precompute(
            length, window, norm, dtype=dtype, device=device
        ).view(1, -1, 1)