Source code for diffsptk.modules.snr
# ------------------------------------------------------------------------ #
# Copyright 2022 SPTK Working Group #
# #
# Licensed under the Apache License, Version 2.0 (the "License"); #
# you may not use this file except in compliance with the License. #
# You may obtain a copy of the License at #
# #
# http://www.apache.org/licenses/LICENSE-2.0 #
# #
# Unless required by applicable law or agreed to in writing, software #
# distributed under the License is distributed on an "AS IS" BASIS, #
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. #
# See the License for the specific language governing permissions and #
# limitations under the License. #
# ------------------------------------------------------------------------ #
import torch
from torch import nn
[docs]
class SignalToNoiseRatio(nn.Module):
    """See `this page <https://sp-nitech.github.io/sptk/latest/main/snr.html>`_
    for details.

    Parameters
    ----------
    frame_length : int >= 1 or None
        Frame length, :math:`L`. If given, calculate segmental SNR, i.e., the
        SNR is evaluated on non-overlapping frames of :math:`L` samples.

    full : bool
        If True, multiply the log10 power ratio by 10 so that the result is
        expressed in decibels. If False, the plain log10 ratio is returned.

    reduction : ['none', 'mean', 'sum']
        Reduction type.

    eps : float >= 0
        A small value added to both the signal and noise powers to prevent
        NaN.

    Raises
    ------
    ValueError
        If ``frame_length``, ``reduction``, or ``eps`` is out of range.

    """

    def __init__(self, frame_length=None, full=False, reduction="mean", eps=1e-8):
        super().__init__()

        # Explicit exceptions are used instead of ``assert`` so that the
        # validation is still performed when Python runs with ``-O``.
        if frame_length is not None and not 1 <= frame_length:
            raise ValueError("frame_length must be greater than or equal to 1.")
        if reduction not in ("none", "mean", "sum"):
            raise ValueError(f"reduction {reduction} is not supported.")
        if not 0 <= eps:
            raise ValueError("eps must be non-negative.")

        self.frame_length = frame_length
        self.full = full
        self.reduction = reduction
        self.eps = eps

    def forward(self, s, sn):
        """Calculate SNR.

        Parameters
        ----------
        s : Tensor [shape=(..., T)]
            Signal.

        sn : Tensor [shape=(..., T)]
            Signal plus noise.

        Returns
        -------
        out : Tensor [shape=(...,) or scalar]
            Signal-to-noise ratio. A scalar is returned when the reduction is
            'mean' or 'sum'. When the reduction is 'none' and ``frame_length``
            is given, one value per frame is returned, i.e., the shape is
            (..., N) unless there is exactly one frame.

        Examples
        --------
        >>> s = diffsptk.nrand(4)
        >>> s
        tensor([-0.5804, -0.8002, -0.0645,  0.6101,  0.4396])
        >>> n = diffsptk.nrand(4) * 0.1
        >>> n
        tensor([ 0.0854,  0.0485, -0.0826,  0.1455,  0.0257])
        >>> snr = diffsptk.SignalToNoiseRatio(full=True)
        >>> y = snr(s, s + n)
        >>> y
        tensor(16.0614)

        """
        return self._forward(
            s, sn, self.frame_length, self.full, self.reduction, self.eps
        )

    @staticmethod
    def _forward(s, sn, frame_length, full, reduction, eps):
        """Functional form of :meth:`forward` that takes every setting
        explicitly instead of reading it from a module instance.

        Raises
        ------
        ValueError
            If ``reduction`` is not one of 'none', 'mean', or 'sum'.

        """
        segmental = frame_length is not None
        if segmental:
            # Split the time axis into non-overlapping frames:
            # (..., T) -> (..., N, L). Samples that do not fill a whole frame
            # are not covered by any window.
            s = s.unfold(-1, frame_length, frame_length)
            sn = sn.unfold(-1, frame_length, frame_length)

        # Signal power and noise power, where the noise is recovered as the
        # difference between the noisy and the clean signals.
        s2 = torch.square(s).sum(-1)
        n2 = torch.square(sn - s).sum(-1)
        snr = torch.log10((s2 + eps) / (n2 + eps))

        if segmental:
            # Drop the frame axis only when it holds a single frame.
            snr = snr.squeeze(-1)

        if reduction == "none":
            pass
        elif reduction == "sum":
            snr = snr.sum()
        elif reduction == "mean":
            snr = snr.mean()
        else:
            raise ValueError(f"reduction {reduction} is not supported.")

        if full:
            # Convert the log10 power ratio into decibels.
            snr = 10 * snr
        return snr

    _func = _forward