# Source code for diffsptk.modules.df2

# ------------------------------------------------------------------------ #
# Copyright 2022 SPTK Working Group                                        #
#                                                                          #
# Licensed under the Apache License, Version 2.0 (the "License");          #
# you may not use this file except in compliance with the License.         #
# You may obtain a copy of the License at                                  #
#                                                                          #
#     http://www.apache.org/licenses/LICENSE-2.0                           #
#                                                                          #
# Unless required by applicable law or agreed to in writing, software      #
# distributed under the License is distributed on an "AS IS" BASIS,        #
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. #
# See the License for the specific language governing permissions and      #
# limitations under the License.                                           #
# ------------------------------------------------------------------------ #

import math

import torch
from torch import nn

from ..typing import Callable, Precomputed
from ..utils.private import filter_values
from .base import BaseNonFunctionalModule
from .dfs import InfiniteImpulseResponseDigitalFilter


class SecondOrderDigitalFilter(BaseNonFunctionalModule):
    """See `this page <https://sp-nitech.github.io/sptk/latest/main/df2.html>`_
    for details.

    Parameters
    ----------
    sample_rate : int >= 1
        The sample rate in Hz.

    pole_frequency : float > 0 or None
        The pole frequency in Hz. If given, ``pole_bandwidth`` must also be given.

    pole_bandwidth : float > 0 or None
        The pole bandwidth in Hz.

    zero_frequency : float > 0 or None
        The zero frequency in Hz. If given, ``zero_bandwidth`` must also be given.

    zero_bandwidth : float > 0 or None
        The zero bandwidth in Hz.

    ir_length : int >= 1 or None
        The length of the truncated impulse response. If given, the filter is
        approximated by an FIR filter.

    device : torch.device or None
        The device of this module.

    dtype : torch.dtype or None
        The data type of this module.

    """

    def __init__(
        self,
        sample_rate: int,
        pole_frequency: float | None = None,
        pole_bandwidth: float | None = None,
        zero_frequency: float | None = None,
        zero_bandwidth: float | None = None,
        ir_length: int | None = None,
        device: torch.device | None = None,
        dtype: torch.dtype | None = None,
    ) -> None:
        super().__init__()

        # Forward every constructor argument to the shared precompute routine.
        _, layers, _ = self._precompute(**filter_values(locals()))
        self.layers = nn.ModuleList(layers)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply a second order digital filter to the input waveform.

        Parameters
        ----------
        x : Tensor [shape=(..., T)]
            The input waveform.

        Returns
        -------
        out : Tensor [shape=(..., T)]
            The filtered waveform.

        Examples
        --------
        >>> x = diffsptk.ramp(4)
        >>> df2 = diffsptk.SecondOrderDigitalFilter(16000, 100, 200, 1000, 50, 5)
        >>> y = df2(x)
        >>> y
        tensor([0.0000, 1.0000, 2.0918, 3.4161, 5.1021])

        """
        return self._forward(x, *self.layers)

    @staticmethod
    def _check(
        sample_rate: int,
        pole_frequency: float | None,
        pole_bandwidth: float | None,
        zero_frequency: float | None,
        zero_bandwidth: float | None,
    ) -> None:
        """Validate the filter specification.

        Raises
        ------
        ValueError
            If the sample rate is not positive, if any given frequency or
            bandwidth is not positive, if a frequency is given without its
            bandwidth, or if a frequency exceeds the Nyquist frequency.

        """
        # The sample rate is a divisor in the coefficient computation.
        if sample_rate <= 0:
            raise ValueError("sample_rate must be positive.")
        if pole_frequency is not None and pole_frequency <= 0:
            raise ValueError("pole_frequency must be positive.")
        if pole_bandwidth is not None and pole_bandwidth <= 0:
            raise ValueError("pole_bandwidth must be positive.")
        if zero_frequency is not None and zero_frequency <= 0:
            raise ValueError("zero_frequency must be positive.")
        if zero_bandwidth is not None and zero_bandwidth <= 0:
            raise ValueError("zero_bandwidth must be positive.")
        # A section is defined by both its frequency and its bandwidth; without
        # this check a missing bandwidth surfaces later as a TypeError.
        if pole_frequency is not None and pole_bandwidth is None:
            raise ValueError("pole_bandwidth must be given with pole_frequency.")
        if zero_frequency is not None and zero_bandwidth is None:
            raise ValueError("zero_bandwidth must be given with zero_frequency.")
        nyquist = sample_rate / 2
        if pole_frequency is not None and nyquist < pole_frequency:
            raise ValueError("pole_frequency must be less than Nyquist frequency.")
        if zero_frequency is not None and nyquist < zero_frequency:
            raise ValueError("zero_frequency must be less than Nyquist frequency.")

    @staticmethod
    def _precompute(
        sample_rate: int,
        pole_frequency: float | None,
        pole_bandwidth: float | None,
        zero_frequency: float | None,
        zero_bandwidth: float | None,
        ir_length: int | None,
        device: torch.device | None,
        dtype: torch.dtype | None,
    ) -> Precomputed:
        """Validate the arguments and build the underlying filter layer.

        Returns
        -------
        out : Precomputed
            ``(None, (dfs,), None)``, where ``dfs`` is the
            ``InfiniteImpulseResponseDigitalFilter`` applied in ``_forward``.

        """
        SecondOrderDigitalFilter._check(
            sample_rate, pole_frequency, pole_bandwidth, zero_frequency, zero_bandwidth
        )

        def get_filter_coefficients(sample_rate, frequency, bandwidth):
            # Second-order polynomial [1, -2 r cos(theta), r^2], whose roots are
            # r * exp(+-j theta): the radius r shrinks as the bandwidth grows,
            # and the angle theta is the center frequency in rad/sample.
            r = math.exp(-math.pi * bandwidth / sample_rate)
            theta = math.tau * frequency / sample_rate
            return [1, -2 * r * math.cos(theta), r * r]

        # The pole polynomial goes to `a` and the zero polynomial to `b`.
        # A section left as None is presumably treated as the identity by
        # InfiniteImpulseResponseDigitalFilter (TODO: confirm in dfs).
        a = b = None
        if pole_frequency is not None:
            a = get_filter_coefficients(sample_rate, pole_frequency, pole_bandwidth)
        if zero_frequency is not None:
            b = get_filter_coefficients(sample_rate, zero_frequency, zero_bandwidth)

        dfs = InfiniteImpulseResponseDigitalFilter(
            a=a, b=b, ir_length=ir_length, device=device, dtype=dtype
        )
        return None, (dfs,), None

    @staticmethod
    def _forward(x: torch.Tensor, dfs: Callable) -> torch.Tensor:
        # Delegate the actual filtering to the precomputed filter module.
        return dfs(x)