# Source code for diffsptk.misc.utils

# ------------------------------------------------------------------------ #
# Copyright 2022 SPTK Working Group                                        #
#                                                                          #
# Licensed under the Apache License, Version 2.0 (the "License");          #
# you may not use this file except in compliance with the License.         #
# You may obtain a copy of the License at                                  #
#                                                                          #
#     http://www.apache.org/licenses/LICENSE-2.0                           #
#                                                                          #
# Unless required by applicable law or agreed to in writing, software      #
# distributed under the License is distributed on an "AS IS" BASIS,        #
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. #
# See the License for the specific language governing permissions and      #
# limitations under the License.                                           #
# ------------------------------------------------------------------------ #

import logging
import math
from importlib import import_module

import numpy as np
import soundfile as sf
import torch
import torch.nn.functional as F
import torchaudio
from torch import nn

# Sentinel value marking unvoiced frames (presumably consumed by pitch-related
# modules elsewhere in the package — not visible from this file).
UNVOICED_SYMBOL = 0
# The constant 2*pi (math.tau).
TWO_PI = math.tau


class Lambda(nn.Module):
    """Wrap a plain callable as an ``nn.Module``.

    Parameters
    ----------
    func : callable
        Function invoked in :meth:`forward`.

    **opt : additional keyword arguments
        Keyword arguments bound to every invocation of ``func``.

    """

    def __init__(self, func, **opt):
        super().__init__()
        self.func = func
        self.opt = opt

    def forward(self, x):
        """Apply the wrapped callable to the input."""
        return self.func(x, **self.opt)


def delayed_import(module_path, item_name):
    """Import ``module_path`` at call time and return its ``item_name`` attribute."""
    return getattr(import_module(module_path), item_name)


def get_logger(name):
    """Build an INFO-level logger that writes to stderr with a fixed format.

    Any handlers previously attached to the logger are discarded so that
    repeated calls do not duplicate output.

    Parameters
    ----------
    name : str
        Logger name.

    Returns
    -------
    out : logging.Logger
        Configured logger with a single stream handler.

    """
    logger = logging.getLogger(name)
    logger.setLevel(logging.INFO)
    logger.handlers.clear()
    handler = logging.StreamHandler()
    handler.setFormatter(
        logging.Formatter(
            "%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s"
        )
    )
    logger.addHandler(handler)
    return logger


def get_generator(seed=None):
    """Create a CPU random number generator, optionally seeded.

    Parameters
    ----------
    seed : int or None
        Seed value; if None, the generator keeps its default state.

    Returns
    -------
    out : torch.Generator
        Random number generator.

    """
    generator = torch.Generator()
    # manual_seed returns the generator itself, so this stays a single expression.
    return generator if seed is None else generator.manual_seed(seed)


def is_power_of_two(n):
    """Return True if ``n`` is a power of two (zero yields False)."""
    if n == 0:
        return False
    # A power of two has exactly one set bit, so n & (n - 1) clears it to zero.
    return n & (n - 1) == 0


def next_power_of_two(n):
    """Return the smallest power of two greater than or equal to ``n``."""
    # bit_length of (n - 1) gives the exponent of the next power of two.
    return 2 ** (n - 1).bit_length()


def default_dtype():
    """Return the numpy dtype matching torch's current default dtype.

    Returns
    -------
    out : type
        ``np.float32`` or ``np.float64``.

    Raises
    ------
    RuntimeError
        If the default dtype is neither float nor double.

    """
    t = torch.get_default_dtype()
    if t == torch.float:
        return np.float32
    elif t == torch.double:
        return np.float64
    # Fixed: the message was a plain string literal missing the f prefix,
    # so {t} was never interpolated.
    raise RuntimeError(f"Unknown default dtype: {t}.")


def default_complex_dtype():
    """Return the numpy complex dtype matching torch's current default dtype.

    Returns
    -------
    out : type
        ``np.complex64`` or ``np.complex128``.

    Raises
    ------
    RuntimeError
        If the default dtype is neither float nor double.

    """
    t = torch.get_default_dtype()
    if t == torch.float:
        return np.complex64
    elif t == torch.double:
        return np.complex128
    # Fixed: missing f prefix left {t} uninterpolated in the error message.
    raise RuntimeError(f"Unknown default dtype: {t}.")


def torch_default_complex_dtype():
    """Return the torch complex dtype matching torch's current default dtype.

    Returns
    -------
    out : torch.dtype
        ``torch.complex64`` or ``torch.complex128``.

    Raises
    ------
    RuntimeError
        If the default dtype is neither float nor double.

    """
    t = torch.get_default_dtype()
    if t == torch.float:
        return torch.complex64
    elif t == torch.double:
        return torch.complex128
    # Fixed: missing f prefix left {t} uninterpolated in the error message.
    raise RuntimeError(f"Unknown default dtype: {t}.")


def numpy_to_torch(x):
    """Convert a numpy array to a tensor of the default (complex) dtype.

    Parameters
    ----------
    x : np.ndarray
        Input array; complex arrays map to the default complex dtype,
        real arrays to the default real dtype.

    Returns
    -------
    out : Tensor
        Converted tensor.

    """
    target = default_complex_dtype() if np.iscomplexobj(x) else default_dtype()
    return torch.from_numpy(x.astype(target))


def to(x, dtype=None, device=None):
    """Move/cast a tensor, defaulting the dtype to torch's current default.

    Parameters
    ----------
    x : Tensor
        Input tensor.

    dtype : torch.dtype or None
        Target dtype; None selects the default (complex) dtype based on ``x``.

    device : torch.device or None
        Target device.

    Returns
    -------
    out : Tensor
        Converted tensor.

    """
    if dtype is None:
        dtype = (
            torch_default_complex_dtype()
            if torch.is_complex(x)
            else torch.get_default_dtype()
        )
    return x.to(dtype=dtype, device=device)


def to_2d(x):
    """Collapse all leading dimensions, returning a (N, T) view."""
    return x.view(-1, x.size(-1))


def to_3d(x):
    """Collapse all leading dimensions, returning a (N, 1, T) view."""
    return x.view(-1, 1, x.size(-1))


def to_dataloader(x, batch_size=None):
    """Wrap a tensor in a non-shuffling DataLoader; pass DataLoaders through.

    Parameters
    ----------
    x : Tensor or torch.utils.data.DataLoader
        Input data.

    batch_size : int or None
        Batch size; None loads everything in a single batch.

    Returns
    -------
    out : torch.utils.data.DataLoader
        Data loader over ``x``.

    """
    if isinstance(x, torch.utils.data.DataLoader):
        return x
    if torch.is_tensor(x):
        if batch_size is None:
            batch_size = len(x)
        return torch.utils.data.DataLoader(
            torch.utils.data.TensorDataset(x),
            batch_size=batch_size,
            shuffle=False,
            drop_last=False,
        )
    raise ValueError(f"Unsupported input type: {type(x)}.")


def reflect(x):
    """Prepend the mirror image of the signal (length T becomes 2T-1).

    Parameters
    ----------
    x : Tensor [shape=(..., T)]
        Input signal.

    Returns
    -------
    out : Tensor [shape=(..., 2T-1)]
        Reflected signal.

    """
    d = x.size(-1)
    padded = F.pad(x.view(-1, d), (d - 1, 0), mode="reflect")
    return padded.view(*x.size()[:-1], -1)


def replicate1(x, left=True, right=True):
    """Replicate the edge samples once on the selected sides.

    Parameters
    ----------
    x : Tensor [shape=(..., T)]
        Input signal.

    left : bool
        If True, repeat the first sample once.

    right : bool
        If True, repeat the last sample once.

    Returns
    -------
    out : Tensor
        Edge-replicated signal.

    """
    d = x.size(-1)
    pad = (int(left), int(right))
    padded = F.pad(x.view(-1, d), pad, mode="replicate")
    return padded.view(*x.size()[:-1], -1)


def remove_gain(a, value=1, return_gain=False):
    """Replace the leading (gain) coefficient of filter coefficients.

    Parameters
    ----------
    a : Tensor [shape=(..., M+1)]
        Filter coefficients; the first element is treated as the gain.

    value : float
        Value substituted for the gain term.

    return_gain : bool
        If True, also return the extracted gain.

    Returns
    -------
    out : Tensor or tuple[Tensor, Tensor]
        Gainless coefficients, optionally preceded by the gain.

    """
    K, body = torch.split(a, [1, a.size(-1) - 1], dim=-1)
    normalized = F.pad(body, (1, 0), value=value)
    return (K, normalized) if return_gain else normalized


def get_resample_params(mode="kaiser_best"):
    """Return torchaudio resampling parameters for a named quality preset.

    Parameters
    ----------
    mode : ['kaiser_best', 'kaiser_fast']
        Preset name.

    Returns
    -------
    out : dict
        Keyword arguments for torchaudio resampling.

    Raises
    ------
    ValueError
        If the mode is unknown.

    """
    # From https://pytorch.org/audio/stable/tutorials/audio_resampling_tutorial.html
    if mode == "kaiser_best":
        return {
            "lowpass_filter_width": 64,
            "rolloff": 0.9475937167399596,
            "resampling_method": "sinc_interp_kaiser",
            "beta": 14.769656459379492,
        }
    if mode == "kaiser_fast":
        return {
            "lowpass_filter_width": 16,
            "rolloff": 0.85,
            "resampling_method": "sinc_interp_kaiser",
            "beta": 8.555504641634386,
        }
    raise ValueError("Only kaiser_best and kaiser_fast are supported.")


def get_alpha(sr, mode="hts", n_freq=10, n_alpha=100):
    """Compute an appropriate frequency warping factor under given sample rate.

    Parameters
    ----------
    sr : int >= 1
        Sample rate in Hz.

    mode : ['hts', 'auto']
        'hts' returns traditional alpha used in HTS. 'auto' computes appropriate
        alpha in L2 sense.

    n_freq : int >= 2
        Number of sample points in the frequency domain.

    n_alpha : int >= 1
        Number of sample points to search alpha.

    Returns
    -------
    out : float in [0, 1)
        Frequency warping factor, :math:`\\alpha`.

    Examples
    --------
    >>> _, sr = diffsptk.read("assets/data.wav")
    >>> alpha = diffsptk.get_alpha(sr)
    >>> alpha
    0.42

    """

    def get_hts_alpha(sr):
        # Fixed table of alphas traditionally used in HTS, keyed by sample rate.
        sr_to_alpha = {
            "8000": 0.31,
            "10000": 0.35,
            "12000": 0.37,
            "16000": 0.42,
            "22050": 0.45,
            "32000": 0.50,
            "44100": 0.53,
            "48000": 0.55,
        }
        key = str(int(sr))
        if key not in sr_to_alpha:
            raise ValueError(f"Unsupported sample rate: {sr}.")
        selected_alpha = sr_to_alpha[key]
        return selected_alpha

    def get_auto_alpha(sr, n_freq, n_alpha):
        # Compute target mel-frequencies.
        freq = np.arange(n_freq) * (0.5 * sr / (n_freq - 1))
        mel_freq = np.log1p(freq / 1000)
        mel_freq = mel_freq * (np.pi / mel_freq[-1])
        mel_freq = np.expand_dims(mel_freq, 0)

        # Compute phase characteristic of the 1st order all-pass filter.
        alpha = np.linspace(0, 1, n_alpha, endpoint=False)
        alpha = np.expand_dims(alpha, 1)
        alpha2 = alpha * alpha
        omega = np.arange(n_freq) * (np.pi / (n_freq - 1))
        omega = np.expand_dims(omega, 0)
        numer = (1 - alpha2) * np.sin(omega)
        denom = (1 + alpha2) * np.cos(omega) - 2 * alpha
        warped_omega = np.arctan(numer / denom)
        warped_omega[warped_omega < 0] += np.pi

        # Select an appropriate alpha in terms of L2 distance.
        distance = np.square(mel_freq - warped_omega).sum(axis=1)
        selected_alpha = float(np.squeeze(alpha[np.argmin(distance)]))
        return selected_alpha

    if mode == "hts":
        alpha = get_hts_alpha(sr)
    elif mode == "auto":
        alpha = get_auto_alpha(sr, n_freq, n_alpha)
    else:
        raise ValueError("Only hts and auto are supported.")
    return alpha
def get_gamma(gamma, c):
    """Resolve the generalized cepstrum gamma from an optional integer c.

    If ``c`` is None or 0, ``gamma`` is used as-is; otherwise gamma is -1/c.
    NOTE(review): validation uses ``assert``, which is stripped under ``-O`` —
    kept as-is to preserve the original exception behavior.
    """
    if c is None or c == 0:
        return gamma
    assert 1 <= c
    return -1 / c


def symmetric_toeplitz(x):
    """Build symmetric Toeplitz matrices from the last axis of ``x``."""
    d = x.size(-1)
    xx = reflect(x)
    X = xx.unfold(-1, d, 1).flip(-2)
    return X


def hankel(x):
    """Build Hankel matrices of size n x n (n = (T+1)//2) from the last axis."""
    d = x.size(-1)
    n = (d + 1) // 2
    X = x.unfold(-1, n, 1)[..., :n, :]
    return X


def vander(x):
    """Build transposed Vandermonde matrices from the last axis of ``x``."""
    X = torch.linalg.vander(x).transpose(-2, -1)
    return X


def cas(x):
    """Compute cas(x) = cos(x) + sin(x) via a single shifted cosine."""
    return (2**0.5) * torch.cos(x - 0.25 * torch.pi)  # cos(x) + sin(x)


def cexp(x):
    """Complex exponential: exp(Re(x)) * (cos(Im(x)) + i sin(Im(x)))."""
    return torch.polar(torch.exp(x.real), x.imag)


def clog(x):
    """Log magnitude of a complex tensor."""
    return torch.log(x.abs())


def outer(x, y=None):
    """Outer product along the last axis; y defaults to x (self outer product)."""
    return torch.matmul(
        x.unsqueeze(-1), x.unsqueeze(-2) if y is None else y.unsqueeze(-2)
    )


def iir(x, b=None, a=None):
    """Apply an IIR filter with numerator b and denominator a (defaults: ones).

    Coefficient vectors are zero-padded to a common length, as torchaudio's
    lfilter requires equal-length a and b.
    """
    if b is None:
        b = torch.ones(1, dtype=x.dtype, device=x.device)
    if a is None:
        a = torch.ones(1, dtype=x.dtype, device=x.device)
    diff = b.size(-1) - a.size(-1)
    if 0 < diff:
        a = F.pad(a, (0, diff))
    elif diff < 0:
        b = F.pad(b, (0, -diff))
    y = torchaudio.functional.lfilter(x, a, b, clamp=False, batching=True)
    return y


def plateau(length, first, middle, last=None, dtype=None, device=None):
    """Make a 1-D tensor filled with ``middle``, with custom first/last samples."""
    x = torch.full((length,), middle, dtype=dtype, device=device)
    x[0] = first
    if last is not None:
        x[-1] = last
    return x


def deconv1d(x, weight):
    """Deconvolve input. This is not transposed convolution.

    Parameters
    ----------
    x : Tensor [shape=(..., T)]
        Input signal.

    weight : Tensor [shape=(M+1,)]
        Filter coefficients.

    Returns
    -------
    out : Tensor [shape=(..., T-M)]
        Output signal.

    """
    assert weight.dim() == 1
    b = x.view(-1, x.size(-1))
    a = weight.view(1, -1).expand(b.size(0), -1)
    # Deconvolution by x is equivalent to IIR filtering an impulse with
    # numerator x and denominator weight.
    impulse = F.pad(torch.ones_like(b[..., :1]), (0, b.size(-1) - a.size(-1)))
    y = iir(impulse, b, a)
    y = y.view(x.size()[:-1] + y.size()[-1:])
    return y


def check_size(x, y, cause):
    """Assert that a size matches its expected value, with a descriptive message."""
    assert x == y, f"Unexpected {cause} (input {x} vs target {y})."
def read(filename, double=False, device=None, **kwargs):
    """Read waveform from file.

    Parameters
    ----------
    filename : str
        Path of wav file.

    double : bool
        If True, return double-type tensor.

    device : torch.device or None
        Device of returned tensor.

    **kwargs : additional keyword arguments
        Additional arguments passed to `soundfile.read
        <https://python-soundfile.readthedocs.io/en/latest/#soundfile.read>`_.

    Returns
    -------
    x : Tensor
        Waveform.

    sr : int
        Sample rate in Hz.

    Examples
    --------
    >>> x, sr = diffsptk.read("assets/data.wav")
    >>> x
    tensor([ 0.0002,  0.0004,  0.0006,  ...,  0.0006, -0.0006, -0.0007])
    >>> sr
    16000

    """
    x, sr = sf.read(filename, **kwargs)
    if double:
        x = torch.DoubleTensor(x)
    else:
        x = torch.FloatTensor(x)
    if device is not None:
        x = x.to(device)
    return x, sr
def write(filename, x, sr, **kwargs):
    """Write waveform to file.

    Parameters
    ----------
    filename : str
        Path of wav file.

    x : Tensor
        Waveform.

    sr : int
        Sample rate in Hz.

    **kwargs : additional keyword arguments
        Additional arguments passed to `soundfile.write
        <https://python-soundfile.readthedocs.io/en/latest/#soundfile.write>`_.

    Examples
    --------
    >>> x, sr = diffsptk.read("assets/data.wav")
    >>> diffsptk.write("out.wav", x, sr)

    """
    # Accept either a tensor or an array-like; soundfile needs host memory.
    x = x.cpu().numpy() if torch.is_tensor(x) else x
    sf.write(filename, x, sr, **kwargs)