Shortcuts

Signal to Distortion Ratio (SDR)

Module Interface

class torchmetrics.SignalDistortionRatio(use_cg_iter=None, filter_length=512, zero_mean=False, load_diag=None, **kwargs)[source]

Signal to Distortion Ratio (SDR) [1,2]

Forward accepts

  • preds: shape [..., time]

  • target: shape [..., time]

Parameters
  • use_cg_iter (Optional[int]) – If provided, conjugate gradient descent is used to solve for the distortion filter coefficients instead of direct Gaussian elimination, which requires that fast-bss-eval is installed and pytorch version >= 1.8. This can speed up the computation of the metrics in case the filters are long. Using a value of 10 here has been shown to provide good accuracy in most cases and is sufficient when using this loss to train neural separation networks.

  • filter_length (int) – The length of the distortion filter allowed

  • zero_mean (bool) – When set to True, the mean of all signals is subtracted prior to computation of the metrics

  • load_diag (Optional[float]) – If provided, this small value is added to the diagonal coefficients of the system metrics when solving for the filter coefficients. This can help stabilize the metric in the case where some reference signals may sometimes be zero

  • kwargs (Any) – Additional keyword arguments, see Advanced metric settings for more info.

Example

>>> from torchmetrics.audio import SignalDistortionRatio
>>> import torch
>>> g = torch.manual_seed(1)
>>> preds = torch.randn(8000)
>>> target = torch.randn(8000)
>>> sdr = SignalDistortionRatio()
>>> sdr(preds, target)
tensor(-12.0589)
>>> # use with pit
>>> from torchmetrics.audio import PermutationInvariantTraining
>>> from torchmetrics.functional.audio import signal_distortion_ratio
>>> preds = torch.randn(4, 2, 8000)  # [batch, spk, time]
>>> target = torch.randn(4, 2, 8000)
>>> pit = PermutationInvariantTraining(signal_distortion_ratio, 'max')
>>> pit(preds, target)
tensor(-11.6051)

References

[1] Vincent, E., Gribonval, R., & Fevotte, C. (2006). Performance measurement in blind audio source separation. IEEE Transactions on Audio, Speech and Language Processing, 14(4), 1462–1469.

[2] Scheibler, R. (2021). SDR – Medium Rare with Fast Computations.

Initializes internal Module state, shared by both nn.Module and ScriptModule.

compute()[source]

Computes average SDR.

Return type

Tensor

update(preds, target)[source]

Update state with predictions and targets.

Parameters
  • preds (Tensor) – Predictions from model

  • target (Tensor) – Ground truth values

Return type

None

Functional Interface

torchmetrics.functional.signal_distortion_ratio(preds, target, use_cg_iter=None, filter_length=512, zero_mean=False, load_diag=None)[source]

Signal to Distortion Ratio (SDR) [1,2]

Parameters
  • preds (Tensor) – shape [..., time]

  • target (Tensor) – shape [..., time]

  • use_cg_iter (Optional[int]) – If provided, conjugate gradient descent is used to solve for the distortion filter coefficients instead of direct Gaussian elimination, which requires that fast-bss-eval is installed and pytorch version >= 1.8. This can speed up the computation of the metrics in case the filters are long. Using a value of 10 here has been shown to provide good accuracy in most cases and is sufficient when using this loss to train neural separation networks.

  • filter_length (int) – The length of the distortion filter allowed

  • zero_mean (bool) – When set to True, the mean of all signals is subtracted prior to computation of the metrics

  • load_diag (Optional[float]) – If provided, this small value is added to the diagonal coefficients of the system metrics when solving for the filter coefficients. This can help stabilize the metric in the case where some reference signals may sometimes be zero

Return type

Tensor

Returns

sdr value of shape [...]

Example

>>> from torchmetrics.functional.audio import signal_distortion_ratio
>>> import torch
>>> g = torch.manual_seed(1)
>>> preds = torch.randn(8000)
>>> target = torch.randn(8000)
>>> signal_distortion_ratio(preds, target)
tensor(-12.0589)
>>> # use with permutation_invariant_training
>>> from torchmetrics.functional.audio import permutation_invariant_training
>>> preds = torch.randn(4, 2, 8000)  # [batch, spk, time]
>>> target = torch.randn(4, 2, 8000)
>>> best_metric, best_perm = permutation_invariant_training(preds, target, signal_distortion_ratio, 'max')
>>> best_metric
tensor([-11.6375, -11.4358, -11.7148, -11.6325])
>>> best_perm
tensor([[1, 0],
        [0, 1],
        [1, 0],
        [0, 1]])

References

[1] Vincent, E., Gribonval, R., & Fevotte, C. (2006). Performance measurement in blind audio source separation. IEEE Transactions on Audio, Speech and Language Processing, 14(4), 1462–1469.

[2] Scheibler, R. (2021). SDR – Medium Rare with Fast Computations.