Source code for diffsptk.modules.root_pol

# ------------------------------------------------------------------------ #
# Copyright 2022 SPTK Working Group                                        #
#                                                                          #
# Licensed under the Apache License, Version 2.0 (the "License");          #
# you may not use this file except in compliance with the License.         #
# You may obtain a copy of the License at                                  #
#                                                                          #
#     http://www.apache.org/licenses/LICENSE-2.0                           #
#                                                                          #
# Unless required by applicable law or agreed to in writing, software      #
# distributed under the License is distributed on an "AS IS" BASIS,        #
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. #
# See the License for the specific language governing permissions and      #
# limitations under the License.                                           #
# ------------------------------------------------------------------------ #

import torch
from torch import nn

from ..misc.utils import check_size


class PolynomialToRoots(nn.Module):
    """See `this page <https://sp-nitech.github.io/sptk/latest/main/root_pol.html>`_
    for details.

    The roots are obtained as the eigenvalues of the companion matrix built
    from the (normalized) polynomial coefficients.

    Parameters
    ----------
    order : int >= 1
        Order of polynomial.

    out_format : ['rectangular', 'polar']
        Output format. The integers 0 and 1 are accepted as aliases of
        'rectangular' and 'polar', respectively. In the polar format, the
        real part of the output holds the magnitude of each root and the
        imaginary part holds its phase angle.

    Raises
    ------
    AssertionError
        If ``order`` is less than 1.

    ValueError
        If ``out_format`` is not one of the supported values.

    """

    def __init__(self, order, out_format="rectangular"):
        super().__init__()

        # Raised explicitly rather than through an ``assert`` statement so
        # that the check is not stripped when Python runs with ``-O``; the
        # exception type is kept for backward compatibility.
        if not 1 <= order:
            raise AssertionError("order must be greater than or equal to 1.")

        self.order = order
        self.formatter = self._formatter(out_format)
        # Constant lower part of the companion matrix; registered as a buffer
        # so that it follows the module across devices.
        self.register_buffer("eye", self._precompute(self.order))

    def forward(self, a):
        """Find roots of polynomial.

        Parameters
        ----------
        a : Tensor [shape=(..., M+1)]
            Polynomial coefficients.

        Returns
        -------
        out : Tensor [shape=(..., M)]
            Complex roots.

        Raises
        ------
        RuntimeError
            If any leading coefficient ``a[..., 0]`` is zero.

        Examples
        --------
        >>> a = torch.tensor([3, 4, 5])
        >>> root_pol = diffsptk.PolynomialToRoots(a.size(-1) - 1)
        >>> x = root_pol(a)
        >>> x
        tensor([-0.6667+1.1055j, -0.6667-1.1055j])

        """
        check_size(a.size(-1), self.order + 1, "order of polynomial")
        return self._forward(a, self.formatter, self.eye)

    @staticmethod
    def _forward(a, formatter, eye):
        """Compute the roots from coefficients, a formatter, and a shift block.

        Parameters
        ----------
        a : Tensor [shape=(..., M+1)]
            Polynomial coefficients.

        formatter : callable
            Function applied to the complex eigenvalues before returning.

        eye : Tensor [shape=(M-1, M)]
            Lower block of the companion matrix (see ``_precompute``).

        Returns
        -------
        out : Tensor [shape=(..., M)]
            Formatted roots.

        """
        if torch.any(a[..., 0] == 0):
            raise RuntimeError("Leading coefficient must be non-zero.")

        # Make companion matrix: the first row holds the negated coefficients
        # normalized by the leading one, and the remaining rows are the
        # rectangular identity, broadcast over the batch dimensions.
        a = -a[..., 1:] / a[..., :1]  # (..., M)
        E = eye.expand(a.size()[:-1] + eye.size())  # (..., M-1, M)
        A = torch.cat((a.unsqueeze(-2), E), dim=-2)  # (..., M, M)

        # Find roots as eigenvalues (the eigenvectors are not needed).
        x, _ = torch.linalg.eig(A)
        x = formatter(x)
        return x

    @staticmethod
    def _func(a, out_format="rectangular"):
        """Functional interface: the polynomial order is inferred from ``a``.

        Parameters
        ----------
        a : Tensor [shape=(..., M+1)]
            Polynomial coefficients.

        out_format : ['rectangular', 'polar']
            Output format.

        Returns
        -------
        out : Tensor [shape=(..., M)]
            Formatted roots.

        """
        formatter = PolynomialToRoots._formatter(out_format)
        eye = PolynomialToRoots._precompute(
            a.size(-1) - 1, dtype=a.dtype, device=a.device
        )
        return PolynomialToRoots._forward(a, formatter, eye)

    @staticmethod
    def _precompute(order, dtype=None, device=None):
        """Return the (order - 1, order) rectangular identity matrix.

        It forms every row of the companion matrix except the first one.

        """
        return torch.eye(order - 1, order, dtype=dtype, device=device)

    @staticmethod
    def _formatter(out_format):
        """Return the function that converts eigenvalues to ``out_format``.

        Raises
        ------
        ValueError
            If ``out_format`` is not supported.

        """
        if out_format in (0, "rectangular"):
            return lambda x: x
        elif out_format in (1, "polar"):
            # Pack (magnitude, angle) into the (real, imag) parts.
            return lambda x: torch.complex(x.abs(), x.angle())
        raise ValueError(f"out_format {out_format} is not supported.")