Signal to Distortion Ratio (SDR)

Module Interface

class torchmetrics.SignalDistortionRatio(use_cg_iter=None, filter_length=512, zero_mean=False, load_diag=None, **kwargs)

Calculates the Signal to Distortion Ratio (SDR) metric. See SDR ref1 and SDR ref2 for details on the metric.

As input to forward and update, the metric accepts the following:

  • preds (Tensor): float tensor with shape (...,time)

  • target (Tensor): float tensor with shape (...,time)

As output of forward and compute, the metric returns the following:

  • sdr (Tensor): float scalar tensor with average SDR value over samples

Parameters
  • use_cg_iter (Optional[int]) – If provided, the distortion filter coefficients are solved for with the conjugate gradient method, run for this many iterations, instead of direct Gaussian elimination. This option requires that fast-bss-eval is installed and that the PyTorch version is >= 1.8. It can speed up the computation of the metric when the filters are long. A value of 10 has been shown to provide good accuracy in most cases and is sufficient when using this metric as a loss to train neural separation networks.

  • filter_length (int) – The length of the distortion filter allowed

  • zero_mean (bool) – When set to True, the mean of all signals is subtracted prior to computation of the metrics

  • load_diag (Optional[float]) – If provided, this small value is added to the diagonal coefficients of the system matrix when solving for the filter coefficients. This can help stabilize the metric when some reference signals may be zero.

  • kwargs (Any) – Additional keyword arguments, see Advanced metric settings for more info.

Example

>>> from torchmetrics.audio import SignalDistortionRatio
>>> import torch
>>> g = torch.manual_seed(1)
>>> preds = torch.randn(8000)
>>> target = torch.randn(8000)
>>> sdr = SignalDistortionRatio()
>>> sdr(preds, target)
tensor(-12.0589)
>>> # use with pit
>>> from torchmetrics.audio import PermutationInvariantTraining
>>> from torchmetrics.functional.audio import signal_distortion_ratio
>>> preds = torch.randn(4, 2, 8000)  # [batch, spk, time]
>>> target = torch.randn(4, 2, 8000)
>>> pit = PermutationInvariantTraining(signal_distortion_ratio, 'max')
>>> pit(preds, target)
tensor(-11.6051)

Functional Interface

torchmetrics.functional.signal_distortion_ratio(preds, target, use_cg_iter=None, filter_length=512, zero_mean=False, load_diag=None)

Calculates the Signal to Distortion Ratio (SDR) metric. See SDR ref1 and SDR ref2 for details on the metric.

Parameters
  • preds (Tensor) – float tensor with shape (...,time)

  • target (Tensor) – float tensor with shape (...,time)

  • use_cg_iter (Optional[int]) – If provided, the distortion filter coefficients are solved for with the conjugate gradient method, run for this many iterations, instead of direct Gaussian elimination. This option requires that fast-bss-eval is installed and that the PyTorch version is >= 1.8. It can speed up the computation of the metric when the filters are long. A value of 10 has been shown to provide good accuracy in most cases and is sufficient when using this metric as a loss to train neural separation networks.

  • filter_length (int) – The length of the distortion filter allowed

  • zero_mean (bool) – When set to True, the mean of all signals is subtracted prior to computation of the metrics

  • load_diag (Optional[float]) – If provided, this small value is added to the diagonal coefficients of the system matrix when solving for the filter coefficients. This can help stabilize the metric when some reference signals may be zero.

Return type

Tensor

Returns

Float tensor with shape (...,) of SDR values per sample

Raises

RuntimeError – If preds and target do not have the same shape

Example

>>> from torchmetrics.functional.audio import signal_distortion_ratio
>>> import torch
>>> g = torch.manual_seed(1)
>>> preds = torch.randn(8000)
>>> target = torch.randn(8000)
>>> signal_distortion_ratio(preds, target)
tensor(-12.0589)
>>> # use with permutation_invariant_training
>>> from torchmetrics.functional.audio import permutation_invariant_training
>>> preds = torch.randn(4, 2, 8000)  # [batch, spk, time]
>>> target = torch.randn(4, 2, 8000)
>>> best_metric, best_perm = permutation_invariant_training(preds, target, signal_distortion_ratio, 'max')
>>> best_metric
tensor([-11.6375, -11.4358, -11.7148, -11.6325])
>>> best_perm
tensor([[1, 0],
        [0, 1],
        [1, 0],
        [0, 1]])