Source code for diffsptk.modules.lbg

# ------------------------------------------------------------------------ #
# Copyright 2022 SPTK Working Group                                        #
#                                                                          #
# Licensed under the Apache License, Version 2.0 (the "License");          #
# you may not use this file except in compliance with the License.         #
# You may obtain a copy of the License at                                  #
#                                                                          #
#     http://www.apache.org/licenses/LICENSE-2.0                           #
#                                                                          #
# Unless required by applicable law or agreed to in writing, software      #
# distributed under the License is distributed on an "AS IS" BASIS,        #
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. #
# See the License for the specific language governing permissions and      #
# limitations under the License.                                           #
# ------------------------------------------------------------------------ #

import logging

import torch
from torch import nn

from ..misc.utils import check_size
from ..misc.utils import is_power_of_two
from .vq import VectorQuantization


class LindeBuzoGrayAlgorithm(nn.Module):
    """See `this page <https://sp-nitech.github.io/sptk/latest/main/lbg.html>`_
    for details. This module is not differentiable.

    Parameters
    ----------
    order : int >= 0
        Order of vector.

    codebook_size : int >= 1
        Target codebook size, must be a power of two.

    min_data_per_cluster : int >= 1
        Minimum number of data points in a cluster.

    n_iter : int >= 1
        Number of iterations.

    eps : float >= 0
        Convergence threshold.

    perturb_factor : float > 0
        Perturbation factor.

    verbose : bool
        If True, print progress.

    """

    def __init__(
        self,
        order,
        codebook_size,
        min_data_per_cluster=1,
        n_iter=100,
        eps=1e-5,
        perturb_factor=1e-5,
        verbose=False,
    ):
        super().__init__()

        assert 0 <= order
        assert is_power_of_two(codebook_size)
        assert 1 <= min_data_per_cluster
        assert 1 <= n_iter
        assert 0 <= eps
        assert 0 < perturb_factor

        self.order = order
        self.codebook_size = codebook_size
        self.min_data_per_cluster = min_data_per_cluster
        self.n_iter = n_iter
        self.eps = eps
        self.perturb_factor = perturb_factor
        self.verbose = verbose

        self.vq = VectorQuantization(order, codebook_size).eval()
        self.vq.codebook[:] = 1e10

        if self.verbose:
            self.logger = logging.getLogger("lbg")
            self.logger.setLevel(logging.INFO)
            formatter = logging.Formatter(
                "%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s"
            )
            self.logger.handlers.clear()
            handler = logging.StreamHandler()
            handler.setFormatter(formatter)
            self.logger.addHandler(handler)
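    # The split-and-refine procedure implemented in forward() below:
    #   1. Start from a single centroid, the mean of all input vectors.
    #   2. Split every centroid in two by adding and subtracting a small
    #      random perturbation, doubling the codebook size.
    #   3. Refine the doubled codebook with k-means-style E/M iterations
    #      until the relative change in distortion falls below eps; any
    #      cluster holding fewer than min_data_per_cluster points is
    #      re-seeded from the largest cluster during each M-step.
    #   4. Repeat from step 2 until the target codebook size is reached.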
    def forward(self, x):
        """Design a codebook.

        Parameters
        ----------
        x : Tensor [shape=(..., M+1)]
            Input vectors.

        Returns
        -------
        codebook : Tensor [shape=(K, M+1)]
            Codebook.

        indices : Tensor [shape=(...,)]
            Codebook indices.

        distance : Tensor [scalar]
            Distance.

        Examples
        --------
        >>> x = diffsptk.nrand(10, 0)
        >>> lbg = diffsptk.LBG(0, 2)
        >>> codebook, indices, distance = lbg(x)
        >>> codebook
        tensor([[-0.5277],
                [ 0.6747]])
        >>> indices
        tensor([0, 0, 0, 1, 0, 1, 1, 1, 1, 0])
        >>> distance
        tensor(0.2331)

        """
        check_size(x.size(-1), self.order + 1, "dimension of input")

        # Initialize the codebook with the mean of all input vectors.
        x = x.view(-1, x.size(-1))
        mean = x.mean(0)
        self.vq.codebook[0] = mean

        distance = torch.inf
        curr_codebook_size = 1
        next_codebook_size = 2
        while next_codebook_size <= self.codebook_size:
            # Double the codebook by perturbing each centroid in
            # two opposite directions.
            codebook = self.vq.codebook[:curr_codebook_size]
            r = torch.randn_like(codebook) * self.perturb_factor
            self.vq.codebook[curr_codebook_size:next_codebook_size] = codebook - r
            self.vq.codebook[:curr_codebook_size] += r
            curr_codebook_size = next_codebook_size
            next_codebook_size *= 2
            if self.verbose:
                self.logger.info(f"K = {curr_codebook_size}")

            prev_distance = distance  # Suppress flake8 warnings.
            for n in range(self.n_iter):
                # E-step: quantize the input and measure the mean distortion.
                xq, indices, _ = self.vq(x)
                distance = (x - xq).square().sum()
                distance /= x.size(0)
                if self.verbose:
                    self.logger.info(f"iter {n+1:5d}: distance = {distance:g}")

                # Check convergence via the relative change in distortion.
                change = (prev_distance - distance).abs()
                if n and change / (distance + 1e-16) < self.eps:
                    break
                prev_distance = distance

                # Count the data points assigned to each cluster.
                n_data = torch.histc(
                    indices.float(),
                    bins=curr_codebook_size,
                    min=0,
                    max=curr_codebook_size - 1,
                )
                mask = self.min_data_per_cluster <= n_data

                # M-step: update the centroids as per-cluster means.
                centroids = torch.zeros(
                    (curr_codebook_size, self.order + 1), dtype=x.dtype, device=x.device
                )
                idx = indices.unsqueeze(1).expand(-1, self.order + 1)
                centroids.scatter_add_(0, idx, x)
                centroids[mask] /= n_data[mask].unsqueeze(1)

                if torch.any(~mask):
                    # Re-seed underpopulated clusters with perturbed copies
                    # of the largest cluster's centroid.
                    _, m = n_data.max(0)
                    copied_centroids = centroids[m : m + 1].expand((~mask).sum(), -1)
                    r = torch.randn_like(copied_centroids) * self.perturb_factor
                    centroids[~mask] = copied_centroids - r
                    centroids[m] += r.mean(0)

                self.vq.codebook[:curr_codebook_size] = centroids

        _, indices, _ = self.vq(x)
        return self.vq.codebook, indices, distance
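A minimal usage sketch, following the docstring example above (the top-level
``diffsptk.LBG`` alias and the ``diffsptk.nrand`` helper are assumed to be
exposed as in that example; reusing the internal ``vq`` module mirrors how
``forward`` itself calls it)::

    import diffsptk

    # 100 random 2-dimensional vectors (order M = 1).
    x = diffsptk.nrand(100, 1)

    # Design a codebook with K = 4 entries (must be a power of two).
    lbg = diffsptk.LBG(1, 4, verbose=True)
    codebook, indices, distance = lbg(x)

    # The trained quantizer can then quantize unseen vectors.
    y = diffsptk.nrand(10, 1)
    yq, y_indices, _ = lbg.vq(y)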