# ------------------------------------------------------------------------ #
# Copyright 2022 SPTK Working Group                                        #
#                                                                          #
# Licensed under the Apache License, Version 2.0 (the "License");          #
# you may not use this file except in compliance with the License.         #
# You may obtain a copy of the License at                                  #
#                                                                          #
#     http://www.apache.org/licenses/LICENSE-2.0                           #
#                                                                          #
# Unless required by applicable law or agreed to in writing, software      #
# distributed under the License is distributed on an "AS IS" BASIS,        #
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. #
# See the License for the specific language governing permissions and      #
# limitations under the License.                                           #
# ------------------------------------------------------------------------ #
import torch
from torch import nn
from ..misc.utils import to
class Histogram(nn.Module):
    """See `this page <https://sp-nitech.github.io/sptk/latest/main/histogram.html>`_
    for details.

    Each bin is modeled as the difference of two sigmoids placed at the bin edges,
    so the output is a soft count that is differentiable with respect to the input.

    Parameters
    ----------
    n_bin : int >= 1
        Number of bins, :math:`K`.

    lower_bound : float < U
        Lower bound of the histogram, :math:`L`.

    upper_bound : float > L
        Upper bound of the histogram, :math:`U`.

    norm : bool
        If True, normalize the histogram.

    softness : float > 0
        A smoothing parameter. The smaller value makes the output closer to the true
        histogram, but the gradient vanishes.

    References
    ----------
    .. [1] M. Avi-Aharon et al., "DeepHist: Differentiable joint and color histogram
           layers for image-to-image translation," *arXiv preprint arXiv:2005.03995*,
           2020.

    """

    def __init__(
        self, n_bin=10, lower_bound=0, upper_bound=1, norm=False, softness=1e-3
    ):
        super().__init__()

        assert 1 <= n_bin, "n_bin must be at least 1"
        assert lower_bound < upper_bound, "lower_bound must be less than upper_bound"
        assert 0 < softness, "softness must be positive"

        self.norm = norm
        self.softness = softness
        # Kept as a plain number (not a buffer) so that the state dict layout is
        # unchanged. It cannot be recovered from `centers` when n_bin == 1.
        self.half_width = self._half_width(n_bin, lower_bound, upper_bound)
        centers = self._precompute(n_bin, lower_bound, upper_bound)
        self.register_buffer("centers", centers)

    def forward(self, x):
        """Compute histogram.

        Parameters
        ----------
        x : Tensor [shape=(..., T)]
            Input data.

        Returns
        -------
        out : Tensor [shape=(..., K)]
            Histogram.

        Examples
        --------
        >>> x = diffsptk.ramp(9)
        >>> histogram = diffsptk.Histogram(n_bin=4, lower_bound=-0.1, upper_bound=9.1)
        >>> h = histogram(x)
        >>> h
        tensor([3., 2., 2., 3.])

        """
        return self._forward(
            x, self.norm, self.softness, self.centers, self.half_width
        )

    @staticmethod
    def _forward(x, norm, softness, centers, half_width=None):
        """Compute the soft histogram of `x` for the given bin centers.

        Parameters
        ----------
        x : Tensor [shape=(..., T)]
            Input data.

        norm : bool
            If True, divide the soft counts by their sum over the bins.

        softness : float > 0
            Scale of the sigmoids placed at the bin edges.

        centers : Tensor [shape=(K,)]
            Equally spaced bin centers.

        half_width : float or Tensor or None
            Half of the bin width. If None, it is derived from the spacing of the
            first two centers, which requires at least two centers.

        Returns
        -------
        out : Tensor [shape=(..., K)]
            Histogram.

        Raises
        ------
        ValueError
            If `half_width` is None and fewer than two centers are given.

        """
        if half_width is None:
            if centers.size(-1) < 2:
                raise ValueError(
                    "half_width is required when there is only one bin center."
                )
            half_width = 0.5 * (centers[1] - centers[0])

        y = x.unsqueeze(-2) - centers.unsqueeze(-1)  # (..., K, T)
        # Soft indicator of "lower edge < x < upper edge" for every bin.
        h = torch.sigmoid((y + half_width) / softness) - torch.sigmoid(
            (y - half_width) / softness
        )
        h = h.sum(-1)
        if norm:
            # Out-of-place division avoids mutating a tensor in the autograd graph.
            # If the soft counts sum to zero, the result is not finite.
            h = h / h.sum(-1, keepdim=True)
        return h

    @staticmethod
    def _func(x, n_bin, lower_bound, upper_bound, norm, softness):
        """Compute the histogram without instantiating the module.

        The bin centers are created with the dtype and device of `x`. No argument
        validation is performed here.
        """
        centers = Histogram._precompute(
            n_bin, lower_bound, upper_bound, dtype=x.dtype, device=x.device
        )
        half_width = Histogram._half_width(n_bin, lower_bound, upper_bound)
        return Histogram._forward(x, norm, softness, centers, half_width)

    @staticmethod
    def _half_width(n_bin, lower_bound, upper_bound):
        """Return half of the width of one bin."""
        return 0.5 * (upper_bound - lower_bound) / n_bin

    @staticmethod
    def _precompute(n_bin, lower_bound, upper_bound, dtype=None, device=None):
        """Return the centers of `n_bin` equal-width bins on [lower_bound, upper_bound].

        The centers are computed in double precision and then passed to the project
        helper `to` together with `dtype` (presumably a cast to the requested or the
        default dtype; confirm against `misc.utils.to`).
        """
        width = (upper_bound - lower_bound) / n_bin
        bias = lower_bound + 0.5 * width
        centers = torch.arange(n_bin, dtype=torch.double, device=device) * width + bias
        return to(centers, dtype=dtype)