
Signal to Distortion Ratio (SDR)

Module Interface

class torchmetrics.SignalDistortionRatio(use_cg_iter=None, filter_length=512, zero_mean=False, load_diag=None, **kwargs)

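Calculate Signal to Distortion Ratio (SDR) metric, in decibels. The target is passed through a distortion filter of length filter_length, fitted by least squares to best match preds, and SDR is the energy of the filtered target relative to the energy of the remaining error. Higher values are better. See SDR ref1 and SDR ref2 for details on the metric.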
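To make the definition concrete, here is a minimal NumPy sketch, assuming the formulation described above: preds is projected onto the span of filter_length delayed copies of target, and SDR is the ratio of projected to residual energy. The function name sdr_least_squares is hypothetical; it is not part of torchmetrics and is not guaranteed to reproduce its numbers exactly.

```python
import numpy as np


def sdr_least_squares(preds: np.ndarray, target: np.ndarray, filter_length: int = 512) -> float:
    """Hypothetical sketch: SDR in dB between two 1-D signals of equal length."""
    if preds.ndim != 1 or preds.shape != target.shape:
        raise ValueError("expected two 1-D arrays with the same length")
    n = target.shape[0]
    full_len = n + filter_length - 1
    # Column k of `delayed` is the target delayed by k samples (full convolution length).
    delayed = np.zeros((full_len, filter_length))
    for k in range(filter_length):
        delayed[k : k + n, k] = target
    # Zero-pad preds to the same length so both live in the same vector space.
    padded_preds = np.concatenate([preds, np.zeros(filter_length - 1)])
    # Least-squares distortion filter h: delayed @ h is the filtered target closest to preds.
    h, *_ = np.linalg.lstsq(delayed, padded_preds, rcond=None)
    filtered_target = delayed @ h
    residual = padded_preds - filtered_target
    return float(10.0 * np.log10(np.sum(filtered_target**2) / np.sum(residual**2)))


rng = np.random.default_rng(0)
target = rng.standard_normal(2000)
noise = rng.standard_normal(2000)
print(sdr_least_squares(target + 0.1 * noise, target, filter_length=64))  # well above 0 dB
print(sdr_least_squares(noise, target, filter_length=64))  # below 0 dB: noise is barely explained
```

Because the filter absorbs any short linear distortion of the target (gain, small delays, coloration), an estimate that is just a filtered copy of the target still scores highly; only what the filter cannot explain counts as error.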

As input to forward and update, the metric accepts the following inputs:

  • preds (Tensor): float tensor with shape (...,time)

  • target (Tensor): float tensor with shape (...,time)

As output of forward and compute, the metric returns the following output:

  • sdr (Tensor): float scalar tensor with the average SDR value over all samples
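The module reports one averaged value, while the functional form returns one value per sample. A sketch of how such averaging can work across several update calls, assuming the module keeps a running sum and a running sample count (MeanOverSamples is a hypothetical name, not the torchmetrics class):

```python
import numpy as np


class MeanOverSamples:
    """Hypothetical accumulator: averages per-sample scores across update() calls."""

    def __init__(self) -> None:
        self.total_score = 0.0
        self.count = 0

    def update(self, per_sample: np.ndarray) -> None:
        # per_sample has shape (...,), e.g. the per-sample output of a functional SDR call
        self.total_score += float(np.sum(per_sample))
        self.count += int(np.size(per_sample))

    def compute(self) -> float:
        if self.count == 0:
            raise RuntimeError("compute() called before update()")
        return self.total_score / self.count


acc = MeanOverSamples()
acc.update(np.array([-12.0, -10.0]))  # batch of two samples
acc.update(np.array([-5.0]))  # batch of one sample
print(acc.compute())  # -9.0, the mean over all three samples
```

Note that this is the mean over samples, not the mean of per-batch means (which would be -8.0 here); the two differ whenever batches have different sizes.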

Parameters
  • use_cg_iter (Optional[int]) – If provided, conjugate gradient descent with this number of iterations is used to solve for the distortion filter coefficients instead of direct Gaussian elimination. This option requires that fast-bss-eval is installed and PyTorch version >= 1.8. It can speed up computation when the filters are long. A value of 10 has been shown to provide good accuracy in most cases and is sufficient when using this metric as a loss to train neural separation networks.

  • filter_length (int) – The length (number of taps) of the allowed distortion filter

  • zero_mean (bool) – When set to True, the mean of each signal along the time axis is subtracted before the metric is computed

  • load_diag (Optional[float]) – If provided, this small value is added to the diagonal coefficients of the system matrix when solving for the filter coefficients. This can help stabilize the metric when some reference signals may be zero

  • kwargs (Any) – Additional keyword arguments, see Advanced metric settings for more info.
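The parameters above map onto a small linear system. The following is a sketch, assuming the filter comes from Toeplitz normal equations built from the autocorrelation of target and its cross-correlation with preds. Plain, unpreconditioned conjugate gradient stands in for the fast-bss-eval solver, and sdr_toeplitz and conjugate_gradient are hypothetical names.

```python
import numpy as np


def conjugate_gradient(mat: np.ndarray, rhs: np.ndarray, n_iter: int) -> np.ndarray:
    """Plain conjugate gradient for a symmetric positive-definite system."""
    x = np.zeros_like(rhs)
    r = rhs.copy()  # residual rhs - mat @ x, with x = 0
    p = r.copy()
    rs = r @ r
    for _ in range(n_iter):
        if np.sqrt(rs) < 1e-12:  # converged (or rhs was zero)
            break
        ap = mat @ p
        alpha = rs / (p @ ap)
        x = x + alpha * p
        r = r - alpha * ap
        rs_new = r @ r
        p = r + (rs_new / rs) * p
        rs = rs_new
    return x


def sdr_toeplitz(preds, target, filter_length=512, zero_mean=False, load_diag=None, use_cg_iter=None):
    """Hypothetical SDR sketch for 1-D arrays, returning dB."""
    if zero_mean:
        # subtract each signal's mean over time
        preds = preds - preds.mean()
        target = target - target.mean()
    # Unit-energy signals: SDR is scale-invariant, and ||preds|| = 1 simplifies the last step.
    preds = preds / np.linalg.norm(preds)
    target = target / np.linalg.norm(target)
    n = target.shape[0]
    # r0[k]: autocorrelation of target at lag k; b[k] = sum_i target[i] * preds[i + k]
    r0 = np.array([target[: n - k] @ target[k:] for k in range(filter_length)])
    b = np.array([target[: n - k] @ preds[k:] for k in range(filter_length)])
    if load_diag is not None:
        r0[0] += load_diag  # r0[0] fills every diagonal entry of the Toeplitz matrix
    lags = np.arange(filter_length)
    system = r0[np.abs(lags[:, None] - lags[None, :])]  # system[i, j] = r0[|i - j|]
    if use_cg_iter is None:
        h = np.linalg.solve(system, b)  # direct solve
    else:
        h = conjugate_gradient(system, b, use_cg_iter)  # iterative solve
    coh = b @ h  # share of the unit preds energy explained by the filtered target
    return float(10.0 * np.log10(coh / (1.0 - coh)))


rng = np.random.default_rng(0)
target = rng.standard_normal(2000)
preds = target + 0.5 * rng.standard_normal(2000)
print(sdr_toeplitz(preds, target, filter_length=32))
print(sdr_toeplitz(preds, target, filter_length=32, use_cg_iter=10))  # close to the direct solve
```

The system matrix is only filter_length × filter_length, so its cost grows with the filter rather than with the signal length; that is where an iterative solver with a fixed iteration budget can pay off for long filters.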

Example

>>> import torch
>>> from torchmetrics.audio import SignalDistortionRatio
>>> g = torch.manual_seed(1)
>>> preds = torch.randn(8000)
>>> target = torch.randn(8000)
>>> sdr = SignalDistortionRatio()
>>> sdr(preds, target)
tensor(-12.0589)
>>> # use with pit
>>> from torchmetrics.audio import PermutationInvariantTraining
>>> from torchmetrics.functional.audio import signal_distortion_ratio
>>> preds = torch.randn(4, 2, 8000)  # [batch, spk, time]
>>> target = torch.randn(4, 2, 8000)
>>> pit = PermutationInvariantTraining(signal_distortion_ratio, 'max')
>>> pit(preds, target)
tensor(-11.6051)
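A brute-force sketch of permutation invariant training, assuming each speaker permutation is scored by the mean metric over speakers and the best score ('max' mode) is kept per sample. The module output is then taken to be the batch mean of those per-sample values, which is consistent with the -11.6051 above and the four per-sample values in the functional example. To stay short, a plain SNR stands in for SDR; pit_exhaustive and snr_db are hypothetical names.

```python
from itertools import permutations

import numpy as np


def snr_db(est: np.ndarray, ref: np.ndarray) -> np.ndarray:
    """Stand-in metric: plain SNR in dB over the last axis."""
    err = est - ref
    return 10.0 * np.log10(np.sum(ref**2, axis=-1) / np.sum(err**2, axis=-1))


def pit_exhaustive(preds, target, metric_fn, eval_func="max"):
    """Hypothetical PIT sketch. preds, target: (batch, spk, time)."""
    batch, spk = target.shape[:2]
    # metric_mtx[b, t, e]: metric of estimate e scored against target speaker t
    metric_mtx = np.empty((batch, spk, spk))
    for t in range(spk):
        for e in range(spk):
            metric_mtx[:, t, e] = metric_fn(preds[:, e], target[:, t])
    perms = np.array(list(permutations(range(spk))))  # (n_perm, spk)
    speakers = np.arange(spk)
    # score of a permutation = mean over target speakers of the matched metric
    scores = np.stack([metric_mtx[:, speakers, p].mean(axis=1) for p in perms], axis=1)
    best_idx = scores.argmax(axis=1) if eval_func == "max" else scores.argmin(axis=1)
    best_metric = scores[np.arange(batch), best_idx]  # (batch,)
    best_perm = perms[best_idx]  # (batch, spk): best_perm[b, t] is the estimate matched to target t
    return best_metric, best_perm


rng = np.random.default_rng(0)
target = rng.standard_normal((4, 2, 1000))
# Estimates come out in swapped speaker order, plus a little noise.
preds = target[:, ::-1] + 0.1 * rng.standard_normal((4, 2, 1000))
best_metric, best_perm = pit_exhaustive(preds, target, snr_db, "max")
print(best_perm)  # every row is [1 0]
print(best_metric.mean())  # batch average, a module-style scalar
```

The exhaustive search evaluates spk × spk metric calls but spk! permutations, which is fine for the two or three speakers typical in separation tasks.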


Functional Interface

torchmetrics.functional.signal_distortion_ratio(preds, target, use_cg_iter=None, filter_length=512, zero_mean=False, load_diag=None)

Calculate Signal to Distortion Ratio (SDR) metric. See SDR ref1 and SDR ref2 for details on the metric.

Parameters
  • preds (Tensor) – float tensor with shape (...,time)

  • target (Tensor) – float tensor with shape (...,time)

  • use_cg_iter (Optional[int]) – If provided, conjugate gradient descent with this number of iterations is used to solve for the distortion filter coefficients instead of direct Gaussian elimination. This option requires that fast-bss-eval is installed and PyTorch version >= 1.8. It can speed up computation when the filters are long. A value of 10 has been shown to provide good accuracy in most cases and is sufficient when using this metric as a loss to train neural separation networks.

  • filter_length (int) – The length (number of taps) of the allowed distortion filter

  • zero_mean (bool) – When set to True, the mean of each signal along the time axis is subtracted before the metric is computed

  • load_diag (Optional[float]) – If provided, this small value is added to the diagonal coefficients of the system matrix when solving for the filter coefficients. This can help stabilize the metric when some reference signals may be zero

Return type

Tensor

Returns

Float tensor with shape (...,) of SDR values per sample

Raises

RuntimeError – If preds and target do not have the same shape
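The return shape (...,) and the shape check can be sketched as a wrapper that applies a per-signal metric over every leading dimension. This is an illustration of the contract only; per_sample and toy_metric are hypothetical names, and toy_metric is a stand-in rather than SDR.

```python
import numpy as np


def per_sample(metric_1d, preds: np.ndarray, target: np.ndarray) -> np.ndarray:
    """Hypothetical wrapper: apply a 1-D metric over all leading dims, returning shape (...,)."""
    if preds.shape != target.shape:
        raise RuntimeError(
            f"preds and target must have the same shape, got {preds.shape} and {target.shape}"
        )
    leading = preds.shape[:-1]
    # Flatten every leading dimension into one batch axis of 1-D signals.
    flat_preds = preds.reshape(-1, preds.shape[-1])
    flat_target = target.reshape(-1, target.shape[-1])
    values = np.array([metric_1d(p, t) for p, t in zip(flat_preds, flat_target)])
    # Restore the leading shape; a single 1-D signal yields a 0-d array.
    return values.reshape(leading)


def toy_metric(p: np.ndarray, t: np.ndarray) -> float:
    """Stand-in for a per-signal SDR: mean product of the two signals."""
    return float(np.mean(p * t))


rng = np.random.default_rng(0)
print(per_sample(toy_metric, rng.standard_normal((4, 2, 100)), rng.standard_normal((4, 2, 100))).shape)  # (4, 2)
print(per_sample(toy_metric, rng.standard_normal(100), rng.standard_normal(100)).shape)  # ()
```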

Example

>>> import torch
>>> from torchmetrics.functional.audio import signal_distortion_ratio
>>> g = torch.manual_seed(1)
>>> preds = torch.randn(8000)
>>> target = torch.randn(8000)
>>> signal_distortion_ratio(preds, target)
tensor(-12.0589)
>>> # use with permutation_invariant_training
>>> from torchmetrics.functional.audio import permutation_invariant_training
>>> preds = torch.randn(4, 2, 8000)  # [batch, spk, time]
>>> target = torch.randn(4, 2, 8000)
>>> best_metric, best_perm = permutation_invariant_training(preds, target, signal_distortion_ratio, 'max')
>>> best_metric
tensor([-11.6375, -11.4358, -11.7148, -11.6325])
>>> best_perm
tensor([[1, 0],
        [0, 1],
        [1, 0],
        [0, 1]])
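best_perm also tells which estimate belongs to which target, so it can be used to reorder the separated signals. A sketch, assuming best_perm[b, j] is the index of the estimate matched to target speaker j (apply_perm is a hypothetical helper):

```python
import numpy as np


def apply_perm(preds: np.ndarray, best_perm: np.ndarray) -> np.ndarray:
    """Reorder estimates so that output[b, j] is the estimate matched to target j.

    preds: (batch, spk, time); best_perm: (batch, spk) integer array.
    """
    batch_idx = np.arange(preds.shape[0])[:, None]  # (batch, 1), broadcasts against best_perm
    return preds[batch_idx, best_perm]  # (batch, spk, time)


best_perm = np.array([[1, 0], [0, 1], [1, 0], [0, 1]])  # the permutation printed above
preds = np.random.default_rng(0).standard_normal((4, 2, 8000))
aligned = apply_perm(preds, best_perm)
print(aligned.shape)  # (4, 2, 8000)
```

After this step, aligned[:, j] and target[:, j] refer to the same speaker, so per-speaker metrics or losses can be computed without any further matching.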