Source code for diffsptk.modules.gmm

# ------------------------------------------------------------------------ #
# Copyright 2022 SPTK Working Group                                        #
#                                                                          #
# Licensed under the Apache License, Version 2.0 (the "License");          #
# you may not use this file except in compliance with the License.         #
# You may obtain a copy of the License at                                  #
#                                                                          #
#     http://www.apache.org/licenses/LICENSE-2.0                           #
#                                                                          #
# Unless required by applicable law or agreed to in writing, software      #
# distributed under the License is distributed on an "AS IS" BASIS,        #
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. #
# See the License for the specific language governing permissions and      #
# limitations under the License.                                           #
# ------------------------------------------------------------------------ #

import logging

import numpy as np
import torch
from torch import nn

from ..misc.utils import check_size
from .lbg import LindeBuzoGrayAlgorithm


class GaussianMixtureModeling(nn.Module):
    """See `this page <https://sp-nitech.github.io/sptk/latest/main/gmm.html>`_
    for details. This module is not differentiable.

    Parameters
    ----------
    order : int >= 0
        Order of vector.

    n_mixture : int >= 1
        Number of mixture components.

    n_iter : int >= 1
        Number of iterations.

    eps : float >= 0
        Convergence threshold.

    weight_floor : float >= 0
        Floor value for mixture weights.

    var_floor : float >= 0
        Floor value for variance.

    var_type : ['diag', 'full']
        Type of covariance.

    block_size : list[int]
        Block size of covariance matrix.

    ubm : tuple of Tensors [shape=((K,), (K, M+1), (K, M+1, M+1))]
        Parameters of universal background model.

    alpha : float in [0, 1]
        Smoothing parameter.

    verbose : bool
        If True, print progress.

    """

    def __init__(
        self,
        order,
        n_mixture,
        n_iter=100,
        eps=1e-5,
        weight_floor=1e-5,
        var_floor=1e-6,
        var_type="diag",
        block_size=None,
        ubm=None,
        alpha=0,
        verbose=False,
    ):
        super().__init__()

        assert 0 <= order
        assert 1 <= n_mixture
        assert 1 <= n_iter
        assert 0 <= eps
        assert 0 <= weight_floor <= 1 / n_mixture
        assert 0 <= var_floor
        assert 0 <= alpha <= 1

        self.order = order
        self.n_mixture = n_mixture
        self.n_iter = n_iter
        self.eps = eps
        self.weight_floor = weight_floor
        self.var_floor = var_floor
        self.alpha = alpha
        self.verbose = verbose

        if self.alpha != 0:
            assert ubm is not None

        # Check block size.
        L = self.order + 1
        if block_size is None:
            block_size = [L]
        assert sum(block_size) == L

        self.is_diag = var_type == "diag" and len(block_size) == 1

        # Make mask for covariance matrix.
        mask = torch.zeros((L, L))
        cumsum = np.cumsum(np.insert(block_size, 0, 0))
        for b1, s1, e1 in zip(block_size, cumsum[:-1], cumsum[1:]):
            assert 0 < b1
            if var_type == "diag":
                for b2, s2, e2 in zip(block_size, cumsum[:-1], cumsum[1:]):
                    if b1 == b2:
                        mask[s1:e1, s2:e2] = torch.eye(b1)
            elif var_type == "full":
                mask[s1:e1, s1:e1] = 1
            else:
                raise ValueError(f"var_type {var_type} is not supported.")
        self.register_buffer("mask", mask)

        # Initialize model parameters.
        K = self.n_mixture
        self.register_buffer("w", torch.ones(K) / K)
        self.register_buffer("mu", torch.randn(K, L))
        self.register_buffer("sigma", torch.eye(L).repeat(K, 1, 1))

        # Save UBM parameters.
        if ubm is not None:
            self.set_params(ubm)
            ubm_w, ubm_mu, ubm_sigma = ubm
            self.register_buffer("ubm_w", ubm_w)
            self.register_buffer("ubm_mu", ubm_mu)
            self.register_buffer("ubm_sigma", ubm_sigma)

        if self.verbose:
            self.logger = logging.getLogger("gmm")
            self.logger.setLevel(logging.INFO)
            formatter = logging.Formatter(
                "%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s"
            )
            self.logger.handlers.clear()
            handler = logging.StreamHandler()
            handler.setFormatter(formatter)
            self.logger.addHandler(handler)
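
    # For example, with order=3 (L=4), var_type="diag", and block_size=[2, 2],
    # the mask built in __init__ keeps the diagonals of all equally sized
    # blocks:
    #   [[1, 0, 1, 0],
    #    [0, 1, 0, 1],
    #    [1, 0, 1, 0],
    #    [0, 1, 0, 1]]
    # so covariances between corresponding coefficients of the two blocks are
    # retained while all other cross terms are zeroed.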

    def set_params(self, params):
        """Set model parameters.

        Parameters
        ----------
        params : tuple of Tensors [shape=((K,), (K, M+1), (K, M+1, M+1))]
            Parameters of Gaussian mixture model.

        """
        w, mu, sigma = params
        if w is not None:
            self.w[:] = w
        if mu is not None:
            self.mu[:] = mu
        if sigma is not None:
            self.sigma[:] = sigma
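
    # A minimal usage sketch: any element of the tuple may be None to leave
    # the corresponding buffer unchanged, e.g.,
    #   gmm.set_params((w, mu, None))  # update weights and means only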

    def warmup(self, x, **lbg_params):
        """Initialize model parameters by K-means clustering.

        Parameters
        ----------
        x : Tensor [shape=(..., M+1)]
            Training data.

        lbg_params : additional keyword arguments
            Parameters for Linde-Buzo-Gray algorithm.

        """
        x = x.view(-1, x.size(-1))
        lbg = LindeBuzoGrayAlgorithm(self.order, self.n_mixture, **lbg_params).to(
            x.device
        )
        codebook, indices, _ = lbg(x)

        count = torch.bincount(indices, minlength=self.n_mixture).to(x.dtype)
        w = count / len(indices)
        mu = codebook

        xx = torch.matmul(x.unsqueeze(-1), x.unsqueeze(-2))  # [B, L, L]
        mm = torch.matmul(mu.unsqueeze(-1), mu.unsqueeze(-2))  # [K, L, L]
        idx = indices.view(-1, 1, 1).expand(-1, self.order + 1, self.order + 1)
        kxx = torch.zeros_like(self.sigma)
        kxx.scatter_add_(0, idx, xx)
        sigma = kxx / count.view(-1, 1, 1) - mm
        sigma = sigma * self.mask

        self.set_params((w, mu, sigma))
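
    # A typical warm start (a sketch; lbg_params are forwarded verbatim to
    # LindeBuzoGrayAlgorithm):
    #   gmm = GaussianMixtureModeling(order, n_mixture)
    #   gmm.warmup(x)
    #   params, log_likelihood = gmm(x)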

    def forward(self, x):
        """Train Gaussian mixture models.

        Parameters
        ----------
        x : Tensor [shape=(..., M+1)]
            Input vectors.

        Returns
        -------
        params : tuple of Tensors [shape=((K,), (K, M+1), (K, M+1, M+1))]
            GMM parameters.

        log_likelihood : Tensor [scalar]
            Total log-likelihood.

        Examples
        --------
        >>> x = diffsptk.nrand(10, 1)
        >>> gmm = diffsptk.GMM(1, 2)
        >>> params, log_likelihood = gmm(x)
        >>> w, mu, sigma = params
        >>> w
        tensor([0.1917, 0.8083])
        >>> mu
        tensor([[ 1.2321,  0.2058],
                [-0.1326, -0.7006]])
        >>> sigma
        tensor([[[3.4010e-01, 0.0000e+00],
                 [0.0000e+00, 6.2351e-04]],
                [[3.0944e-01, 0.0000e+00],
                 [0.0000e+00, 8.6096e-01]]])
        >>> log_likelihood
        tensor(-19.5235)

        """
        check_size(x.size(-1), self.order + 1, "dimension of input")

        x = x.view(-1, x.size(-1))
        B = x.size(0)

        prev_log_likelihood = -torch.inf
        for n in range(self.n_iter):
            # Compute log probabilities.
            log_pi = (self.order + 1) * np.log(2 * np.pi)
            if self.is_diag:
                log_det = torch.log(torch.diagonal(self.sigma, dim1=-2, dim2=-1)).sum(
                    -1
                )  # [K]
                precision = torch.reciprocal(
                    torch.diagonal(self.sigma, dim1=-2, dim2=-1)
                )  # [K, L]
                diff = x.unsqueeze(1) - self.mu.unsqueeze(0)  # [B, K, L]
                mahala = (diff**2 * precision).sum(-1)  # [B, K]
            else:
                col = torch.linalg.cholesky(self.sigma)
                log_det = (
                    torch.log(torch.diagonal(col, dim1=-2, dim2=-1)).sum(-1) * 2
                )  # [K]
                precision = torch.cholesky_inverse(col).unsqueeze(0)  # [1, K, L, L]
                diff = x.unsqueeze(1) - self.mu.unsqueeze(0)  # [B, K, L]
                right = torch.matmul(precision, diff.unsqueeze(-1))  # [B, K, L, 1]
                mahala = (
                    torch.matmul(diff.unsqueeze(-2), right).squeeze(-1).squeeze(-1)
                )  # [B, K]
            numer = torch.log(self.w) - 0.5 * (log_pi + log_det + mahala)  # [B, K]
            denom = torch.logsumexp(numer, dim=-1, keepdim=True)  # [B, 1]
            posterior = torch.exp(numer - denom)  # [B, K]
            log_likelihood = torch.sum(denom)

            # Update mixture weights.
            if self.alpha == 0:
                z = posterior.sum(dim=0)
                self.w = z / B
            else:
                xi = self.ubm_w * self.alpha
                z = posterior.sum(dim=0) + xi
                self.w = z / (B + self.alpha)
            z = 1 / z
            self.w = torch.clamp(self.w, min=self.weight_floor)
            sum_floor = self.weight_floor * self.n_mixture
            a = (1 - sum_floor) / (self.w.sum() - sum_floor)
            b = self.weight_floor * (1 - a)
            self.w = a * self.w + b

            # Update mean vectors.
            px = torch.matmul(posterior.t(), x)
            if self.alpha == 0:
                self.mu = px * z.view(-1, 1)
            else:
                self.mu = (px + xi.view(-1, 1) * self.ubm_mu) * z.view(-1, 1)

            # Update covariance matrices.
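            # With alpha > 0, the update below follows the standard MAP
            # estimate (per mixture k, writing xi_k = alpha * ubm_w_k,
            # y_k = sum_t posterior_tk, and z_k = 1 / (y_k + xi_k)):
            #   sigma_k = z_k * (pxx_k
            #                    - y_k (nu_k mu_k^T + mu_k nu_k^T - mu_k mu_k^T)
            #                    + xi_k ubm_sigma_k
            #                    + xi_k (ubm_mu_k - mu_k)(ubm_mu_k - mu_k)^T)
            # which reduces to the ML estimate pxx_k / y_k - mu_k mu_k^T
            # when alpha == 0.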
            if self.is_diag:
                xx = x**2
                mm = self.mu**2
                pxx = torch.matmul(posterior.t(), xx)
                if self.alpha == 0:
                    sigma = pxx * z.view(-1, 1) - mm
                else:
                    y = posterior.sum(dim=0)
                    nu = px / y.view(-1, 1)
                    # Cross term: the diagonal of nu mu^T + mu nu^T is
                    # 2 * nu * mu, matching the full-covariance branch.
                    nm = nu * self.mu
                    a = pxx - y.view(-1, 1) * (2 * nm - mm)
                    b = xi.view(-1, 1) * self.ubm_sigma.diagonal(dim1=-2, dim2=-1)
                    diff = self.ubm_mu - self.mu
                    dd = diff**2
                    c = xi.view(-1, 1) * dd
                    sigma = (a + b + c) * z.view(-1, 1)
                self.sigma.diagonal(dim1=-2, dim2=-1).copy_(sigma)
            else:
                xx = torch.matmul(x.unsqueeze(-1), x.unsqueeze(-2))
                mm = torch.matmul(self.mu.unsqueeze(-1), self.mu.unsqueeze(-2))
                pxx = torch.einsum("bk,blm->klm", posterior, xx)
                if self.alpha == 0:
                    sigma = pxx * z.view(-1, 1, 1) - mm
                else:
                    y = posterior.sum(dim=0)
                    nu = px / y.view(-1, 1)
                    nm = torch.matmul(nu.unsqueeze(-1), self.mu.unsqueeze(-2))
                    mn = nm.transpose(-2, -1)
                    a = pxx - y.view(-1, 1, 1) * (nm + mn - mm)
                    b = xi.view(-1, 1, 1) * self.ubm_sigma
                    diff = self.ubm_mu - self.mu
                    dd = torch.matmul(diff.unsqueeze(-1), diff.unsqueeze(-2))
                    c = xi.view(-1, 1, 1) * dd
                    sigma = (a + b + c) * z.view(-1, 1, 1)
                self.sigma = sigma * self.mask
            self.sigma.diagonal(dim1=-2, dim2=-1).clamp_(min=self.var_floor)

            # Check convergence.
            change = log_likelihood - prev_log_likelihood
            if self.verbose:
                self.logger.info(f"iter {n+1:5d}: average = {log_likelihood / B:g}")
            if n and change < self.eps:
                break
            prev_log_likelihood = log_likelihood

        params = [self.w, self.mu, self.sigma]
        return params, log_likelihood
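

# A minimal end-to-end sketch of UBM training followed by MAP adaptation,
# using the public aliases from the docstrings above (diffsptk.GMM and
# diffsptk.nrand); the data tensors are illustrative only:
#
#   import diffsptk
#
#   x_ubm = diffsptk.nrand(1000, 4)      # background data, order M = 4
#   ubm = diffsptk.GMM(4, n_mixture=8, n_iter=50)
#   ubm.warmup(x_ubm)                    # LBG initialization
#   ubm_params, _ = ubm(x_ubm)           # EM training
#
#   x_new = diffsptk.nrand(100, 4)       # adaptation data
#   gmm = diffsptk.GMM(4, n_mixture=8, ubm=ubm_params, alpha=0.1)
#   params, log_likelihood = gmm(x_new)  # MAP adaptation toward the UBM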