Signal to Distortion Ratio (SDR)

Module Interface

class torchmetrics.audio.SignalDistortionRatio(use_cg_iter=None, filter_length=512, zero_mean=False, load_diag=None, **kwargs)

Calculate Signal to Distortion Ratio (SDR) metric.

See SDR ref1 and SDR ref2 for details on the metric.

As input to forward and update, the metric accepts the following:

  • preds (Tensor): float tensor with shape (...,time)

  • target (Tensor): float tensor with shape (...,time)

As output of forward and compute, the metric returns the following:

  • sdr (Tensor): float scalar tensor with average SDR value over samples

Parameters:
  • use_cg_iter (Optional[int]) – If provided, conjugate gradient descent is used to solve for the distortion filter coefficients instead of direct Gaussian elimination. This option requires fast-bss-eval to be installed and PyTorch version >= 1.8. It can speed up the computation of the metric when the filters are long. Using a value of 10 here has been shown to provide good accuracy in most cases and is sufficient when using this loss to train neural separation networks.

  • filter_length (int) – The length of the distortion filter allowed

  • zero_mean (bool) – When set to True, the mean of all signals is subtracted prior to computation of the metrics

  • load_diag (Optional[float]) – If provided, this small value is added to the diagonal coefficients of the system matrices when solving for the filter coefficients. This can help stabilize the metric when some reference signals may be zero.

  • kwargs (Any) – Additional keyword arguments, see Advanced metric settings for more info.
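
The stabilizing effect of load_diag can be illustrated with a toy two-tap system. This is a minimal sketch, not the library's implementation: it assumes the filter is found by solving a symmetric Toeplitz system built from the reference autocorrelation, and the helpers loaded_toeplitz and solve_2x2 are hypothetical names introduced only for this illustration.

```python
def loaded_toeplitz(r0, r1, load_diag=None):
    """Build the 2x2 symmetric Toeplitz matrix [[r0, r1], [r1, r0]].

    If load_diag is given, it is added to both diagonal entries.
    (Hypothetical helper, not part of torchmetrics.)
    """
    if load_diag is not None:
        r0 = r0 + load_diag
    return [[r0, r1], [r1, r0]]


def solve_2x2(mat, rhs):
    """Solve mat @ x = rhs with Cramer's rule; raise if mat is singular."""
    det = mat[0][0] * mat[1][1] - mat[0][1] * mat[1][0]
    if det == 0.0:
        raise ZeroDivisionError("singular system")
    x0 = (rhs[0] * mat[1][1] - mat[0][1] * rhs[1]) / det
    x1 = (mat[0][0] * rhs[1] - mat[1][0] * rhs[0]) / det
    return [x0, x1]


# An all-zero reference has zero autocorrelation and zero cross-correlation.
r0, r1 = 0.0, 0.0
cross = [0.0, 0.0]

# Without diagonal loading the system is singular and cannot be solved.
try:
    solve_2x2(loaded_toeplitz(r0, r1), cross)
    unloaded_ok = True
except ZeroDivisionError:
    unloaded_ok = False

# With a small load on the diagonal the system becomes solvable.
loaded_solution = solve_2x2(loaded_toeplitz(r0, r1, load_diag=1e-8), cross)

print(unloaded_ok)      # False
print(loaded_solution)  # [0.0, 0.0]
```

In the real metric the system has filter_length rows rather than two, but the idea is the same: a silent reference yields a singular system, and a small diagonal term keeps it solvable.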

Example

>>> import torch
>>> from torchmetrics.audio import SignalDistortionRatio
>>> g = torch.manual_seed(1)
>>> preds = torch.randn(8000)
>>> target = torch.randn(8000)
>>> sdr = SignalDistortionRatio()
>>> sdr(preds, target)
tensor(-12.0589)
>>> # use with pit
>>> from torchmetrics.audio import PermutationInvariantTraining
>>> from torchmetrics.functional.audio import signal_distortion_ratio
>>> preds = torch.randn(4, 2, 8000)  # [batch, spk, time]
>>> target = torch.randn(4, 2, 8000)
>>> pit = PermutationInvariantTraining(signal_distortion_ratio,
...     mode="speaker-wise", eval_func="max")
>>> pit(preds, target)
tensor(-11.6051)

plot(val=None, ax=None)

Plot a single or multiple values from the metric.

Parameters:
  • val (Union[Tensor, Sequence[Tensor], None]) – Either a single result from calling metric.forward or metric.compute, or a list of these results. If no value is provided, metric.compute is called automatically and its result is plotted.

  • ax (Optional[Axes]) – A matplotlib axis object. If provided, the plot is added to that axis.

Return type:

Tuple[Figure, Union[Axes, ndarray]]

Returns:

Figure and Axes object

Raises:

ModuleNotFoundError – If matplotlib is not installed

>>> # Example plotting a single value
>>> import torch
>>> from torchmetrics.audio import SignalDistortionRatio
>>> metric = SignalDistortionRatio()
>>> metric.update(torch.rand(8000), torch.rand(8000))
>>> fig_, ax_ = metric.plot()
>>> # Example plotting multiple values
>>> import torch
>>> from torchmetrics.audio import SignalDistortionRatio
>>> metric = SignalDistortionRatio()
>>> values = []
>>> for _ in range(10):
...     values.append(metric(torch.rand(8000), torch.rand(8000)))
>>> fig_, ax_ = metric.plot(values)

Functional Interface

torchmetrics.functional.audio.signal_distortion_ratio(preds, target, use_cg_iter=None, filter_length=512, zero_mean=False, load_diag=None)

Calculate Signal to Distortion Ratio (SDR) metric. See SDR ref1 and SDR ref2 for details on the metric.

Parameters:
  • preds (Tensor) – float tensor with shape (...,time)

  • target (Tensor) – float tensor with shape (...,time)

  • use_cg_iter (Optional[int]) – If provided, conjugate gradient descent is used to solve for the distortion filter coefficients instead of direct Gaussian elimination. This option requires fast-bss-eval to be installed and PyTorch version >= 1.8. It can speed up the computation of the metric when the filters are long. Using a value of 10 here has been shown to provide good accuracy in most cases and is sufficient when using this loss to train neural separation networks.

  • filter_length (int) – The length of the distortion filter allowed

  • zero_mean (bool) – When set to True, the mean of all signals is subtracted prior to computation of the metrics

  • load_diag (Optional[float]) – If provided, this small value is added to the diagonal coefficients of the system matrices when solving for the filter coefficients. This can help stabilize the metric when some reference signals may be zero.

Return type:

Tensor

Returns:

Float tensor with shape (...,) of SDR values per sample

Raises:

RuntimeError – If preds and target do not have the same shape

Example

>>> import torch
>>> from torchmetrics.functional.audio import signal_distortion_ratio
>>> g = torch.manual_seed(1)
>>> preds = torch.randn(8000)
>>> target = torch.randn(8000)
>>> signal_distortion_ratio(preds, target)
tensor(-12.0589)
>>> # use with permutation_invariant_training
>>> from torchmetrics.functional.audio import permutation_invariant_training
>>> preds = torch.randn(4, 2, 8000)  # [batch, spk, time]
>>> target = torch.randn(4, 2, 8000)
>>> best_metric, best_perm = permutation_invariant_training(preds, target, signal_distortion_ratio)
>>> best_metric
tensor([-11.6375, -11.4358, -11.7148, -11.6325])
>>> best_perm
tensor([[1, 0],
        [0, 1],
        [1, 0],
        [0, 1]])