# ------------------------------------------------------------------------ #
# Copyright 2022 SPTK Working Group #
# #
# Licensed under the Apache License, Version 2.0 (the "License"); #
# you may not use this file except in compliance with the License. #
# You may obtain a copy of the License at #
# #
# http://www.apache.org/licenses/LICENSE-2.0 #
# #
# Unless required by applicable law or agreed to in writing, software #
# distributed under the License is distributed on an "AS IS" BASIS, #
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. #
# See the License for the specific language governing permissions and #
# limitations under the License. #
# ------------------------------------------------------------------------ #
import torch
from ..typing import Precomputed
from ..utils.private import UNVOICED_SYMBOL, get_values
from .base import BaseFunctionalModule
class MagicNumberInterpolation(BaseFunctionalModule):
    """See `this page <https://sp-nitech.github.io/sptk/latest/main/magic_intpl.html>`_
    for details.

    Parameters
    ----------
    magic_number : float
        The magic number to be interpolated.

    """

    def __init__(self, magic_number: float = UNVOICED_SYMBOL) -> None:
        super().__init__()

        # NOTE(review): `get_values` receives the raw `locals()` mapping, so the
        # set of local names at this point matters. Presumably it extracts the
        # constructor arguments; do not introduce new locals before this call
        # without checking `utils.private.get_values`.
        _, _, tensors = self._precompute(*get_values(locals()))
        self.register_buffer("magic_number", tensors[0])

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Interpolate magic number.

        Parameters
        ----------
        x : Tensor [shape=(B, N, D) or (N, D) or (N,)]
            The data containing magic number.

        Returns
        -------
        out : Tensor [shape=(B, N, D) or (N, D) or (N,)]
            The data after interpolation.

        Examples
        --------
        >>> x = torch.tensor([0, 1, 2, 0, 4, 0]).float()
        >>> x
        tensor([0., 1., 2., 0., 4., 0.])
        >>> magic_intpl = diffsptk.MagicNumberInterpolation(0)
        >>> y = magic_intpl(x)
        >>> y
        tensor([1., 1., 2., 3., 4., 4.])

        """
        # The registered buffers (here only `magic_number`) are forwarded as
        # keyword arguments to the static implementation.
        return self._forward(x, **self._buffers)

    @staticmethod
    def _func(x: torch.Tensor, *args, **kwargs) -> torch.Tensor:
        # Functional entry point: build the magic-number tensor on the fly with
        # the device and dtype of the input, then run the same static forward.
        _, _, tensors = MagicNumberInterpolation._precompute(
            *args, **kwargs, device=x.device, dtype=x.dtype
        )
        return MagicNumberInterpolation._forward(x, *tensors)

    @staticmethod
    def _takes_input_size() -> bool:
        return False

    @staticmethod
    def _check() -> None:
        # There is nothing to validate for this module.
        pass

    @staticmethod
    def _precompute(
        magic_number: float,
        device: torch.device | None = None,
        dtype: torch.dtype | None = None,
    ) -> Precomputed:
        # Convert the scalar magic number into a 0-dim tensor; it is the only
        # precomputed tensor of this module.
        MagicNumberInterpolation._check()
        magic_number = torch.tensor(magic_number, device=device, dtype=dtype)
        return None, None, (magic_number,)

    @staticmethod
    def _forward(x: torch.Tensor, magic_number: torch.Tensor) -> torch.Tensor:
        # Delegate to the custom autograd function defined in this file.
        return MagicNumberInterpolationImpl.apply(x, magic_number)
class MagicNumberInterpolationImpl(torch.autograd.Function):
    """Autograd function that replaces magic numbers by linear interpolation.

    Runs of the magic number along the time axis are filled by linearly
    interpolating between the nearest non-magic neighbors. A leading run takes
    the first following non-magic value and a trailing run takes the last
    preceding non-magic value.
    """

    @staticmethod
    def forward(ctx, x, magic_number):
        """Interpolate magic numbers in ``x``.

        Parameters
        ----------
        x : Tensor [shape=(B, T, D) or (T, D) or (T,)]
            The data containing magic number.

        magic_number : Tensor [scalar]
            The value to be replaced.

        Returns
        -------
        out : Tensor [same shape as ``x``]
            The interpolated data. ``x`` itself is returned when it contains
            no magic number.

        Raises
        ------
        ValueError
            If ``x`` is not a 1D, 2D, or 3D tensor.

        """
        ctx.save_for_backward(x, magic_number)

        # Pass through if magic number is not found.
        if torch.all(x != magic_number):
            return x

        # Normalize the input to (B, T, D).
        d = x.dim()
        if d == 1:
            x = x.view(1, -1, 1)
        elif d == 2:
            x = x.unsqueeze(0)
        if x.dim() != 3:
            raise ValueError("Input must be 1D, 2D, or 3D tensor.")

        B, T, D = x.shape

        # Treat every (batch, dimension) pair as an independent 1D sequence.
        x = x.transpose(-2, -1).reshape(B * D, T)
        starts, ends, weights = MagicNumberInterpolationImpl._compute_lerp_inputs(
            x, magic_number
        )
        y = torch.lerp(starts, ends, weights)
        y = y.reshape(B, D, T).transpose(-2, -1)

        # Restore the original number of dimensions.
        if d == 1:
            y = y.view(-1)
        elif d == 2:
            y = y.squeeze(0)
        return y

    @staticmethod
    def _compute_lerp_inputs(x, magic_number):
        """Build per-sample start values, end values, and weights for lerp.

        Parameters
        ----------
        x : Tensor [shape=(R, T)]
            The sequences to be processed row by row.

        magic_number : Tensor [scalar]
            The value to be replaced.

        Returns
        -------
        starts, ends, weights : Tensor [shape=(R, T)]
            Inputs for ``torch.lerp(starts, ends, weights)``.

        """
        is_magic_number = x == magic_number

        starts = []
        ends = []
        weights = []
        for i in range(x.size(0)):
            # Run-length encode the magic/non-magic mask.
            uniques, counts = torch.unique_consecutive(
                is_magic_number[i],
                return_inverse=False,
                return_counts=True,
                dim=-1,
            )

            # Each sample of a magic run of length c gets the step 1 / (c + 1),
            # while non-magic samples get 0. The step is computed in the dtype
            # of x so that the weights match starts/ends in torch.lerp instead
            # of silently falling back to the default float dtype.
            step = uniques.to(x.dtype) / (counts + 1)
            w = torch.repeat_interleave(step, counts, dim=-1)
            if uniques[0]:
                # A leading magic run must not contribute to the running sum.
                w[..., : counts[0]] = 0
            w = torch.cumsum(w, dim=-1)
            # Subtract the running sum observed at the latest non-magic sample,
            # so the k-th sample of a magic run ends up with k / (c + 1) and
            # non-magic samples end up with 0.
            w = w - torch.cummax(w * ~is_magic_number[i], dim=-1)[0]
            if uniques[0]:
                # Leading run: weight 1 selects the end (next valid) value.
                w[..., : counts[0]] = 1
            if uniques[-1]:
                # Trailing run: weight 0 selects the start (previous valid) value.
                w[..., -counts[-1] :] = 0

            # Collapse consecutive duplicates so that each magic run becomes a
            # single entry whose neighbors are the surrounding values.
            uniques, indices = torch.unique_consecutive(
                x[i],
                return_inverse=True,
                return_counts=False,
                dim=-1,
            )
            pos = uniques == magic_number
            # Start value of a magic run: the entry just before it. torch.roll
            # wraps around, but the wrapped side is masked out by the 0/1
            # weights assigned to leading and trailing runs above.
            uniques[pos] = torch.roll(uniques, 1, dims=-1)[pos]
            s = uniques[indices]
            # End value of a magic run: the entry just after it.
            uniques[pos] = torch.roll(uniques, -1, dims=-1)[pos]
            e = uniques[indices]

            starts.append(s)
            ends.append(e)
            weights.append(w)

        starts = torch.stack(starts)
        ends = torch.stack(ends)
        weights = torch.stack(weights)
        return starts, ends, weights

    @staticmethod
    def backward(ctx, grad_output):
        """Propagate gradients only through samples that were not magic numbers.

        Returns the masked gradient for ``x`` and ``None`` for ``magic_number``.
        """
        x, magic_number = ctx.saved_tensors
        return grad_output * (x != magic_number), None