
Signal to Distortion Ratio (SDR)

Module Interface

class torchmetrics.SignalDistortionRatio(use_cg_iter=None, filter_length=512, zero_mean=False, load_diag=None, compute_on_step=None, **kwargs)

Signal to Distortion Ratio (SDR) [1,2,3]

Forward accepts

  • preds: shape [..., time]

  • target: shape [..., time]

Parameters
  • use_cg_iter (Optional[int]) – If provided, an iterative method is used to solve for the distortion filter coefficients instead of direct Gaussian elimination. This can speed up the computation of the metric when the filters are long. A value of 10 has been shown to give good accuracy in most cases and is sufficient when using this metric as a loss to train neural separation networks.

  • filter_length (int) – The length of the allowed distortion filter.

  • zero_mean (bool) – When set to True, the mean of all signals is subtracted before the metric is computed.

  • load_diag (Optional[float]) – If provided, this small value is added to the diagonal coefficients of the system matrix when solving for the filter coefficients. This can help stabilize the metric when some reference signals are occasionally zero.

  • compute_on_step (Optional[bool]) –

    Forward only calls update() and returns None if this is set to False.

    Deprecated since version v0.8: The argument has no use anymore and will be removed in v0.9.

  • kwargs (Dict[str, Any]) – Additional keyword arguments, see Advanced metric settings for more info.
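The parameters above all act on one underlying computation: fit a distortion filter of filter_length taps that maps target onto preds, then compare the energy the filter explains with the energy it does not. The following is a simplified NumPy sketch of that idea, written for illustration under stated assumptions: single-channel 1-D signals, linear (zero-padded) correlations, a dense Toeplitz system matrix, and a plain unpreconditioned conjugate-gradient loop standing in for whatever iterative solver the library actually uses. The names sdr_sketch and _conjugate_gradient are hypothetical; this is not the torchmetrics or fast-bss-eval code and is not guaranteed to reproduce its numbers.

```python
import numpy as np


def _conjugate_gradient(a, b, n_iter):
    """Plain (unpreconditioned) conjugate gradient for a symmetric positive-definite matrix a."""
    x = np.zeros_like(b)
    r = b.copy()  # residual b - a @ x, with x starting at zero
    p = r.copy()
    rs_old = r @ r
    for _ in range(n_iter):
        if np.sqrt(rs_old) < 1e-12:  # already converged
            break
        ap = a @ p
        alpha = rs_old / (p @ ap)
        x += alpha * p
        r -= alpha * ap
        rs_new = r @ r
        p = r + (rs_new / rs_old) * p
        rs_old = rs_new
    return x


def sdr_sketch(preds, target, filter_length=512, zero_mean=False,
               load_diag=None, use_cg_iter=None):
    """SDR in dB for two 1-D signals (hypothetical simplified sketch, not library code)."""
    preds = np.asarray(preds, dtype=np.float64)
    target = np.asarray(target, dtype=np.float64)
    if zero_mean:
        preds = preds - preds.mean()
        target = target - target.mean()
    # Unit-energy signals make the coherence below a fraction in [0, 1]
    preds = preds / np.linalg.norm(preds)
    target = target / np.linalg.norm(target)

    n = len(target)
    # acf[k] = sum_t target[t] * target[t + k]; xcorr[k] = sum_t target[t] * preds[t + k]
    acf = np.correlate(target, target, mode="full")[n - 1 : n - 1 + filter_length].copy()
    xcorr = np.correlate(preds, target, mode="full")[n - 1 : n - 1 + filter_length]
    if load_diag is not None:
        acf[0] += load_diag  # acf[0] fills the diagonal of the Toeplitz system matrix

    # Symmetric Toeplitz system matrix: system[i, j] = acf[|i - j|]
    size = len(acf)
    lags = np.abs(np.subtract.outer(np.arange(size), np.arange(size)))
    system = acf[lags]

    if use_cg_iter is None:
        h = np.linalg.solve(system, xcorr)  # direct solve (Gaussian elimination)
    else:
        h = _conjugate_gradient(system, xcorr, use_cg_iter)  # iterative solve

    # Fraction of preds' energy explained by the filtered target; the rest is distortion
    coh = xcorr @ h
    return 10.0 * np.log10(coh / (1.0 - coh))
```

For instance, sdr_sketch(preds, target, use_cg_iter=10) takes the 10-iteration route suggested for use_cg_iter. Because this sketch materializes the dense L × L matrix, each CG iteration costs O(L²) versus roughly O(L³) for the direct solve; an optimized implementation would typically also exploit the Toeplitz structure instead of building the matrix.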

Raises

ModuleNotFoundError – If the fast-bss-eval package is not installed.
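A script can check for this dependency up front instead of waiting for the error at construction time. A minimal sketch, assuming the fast-bss-eval package installs the import name fast_bss_eval; the helper name fast_bss_eval_available is hypothetical.

```python
import importlib.util


def fast_bss_eval_available() -> bool:
    """Return True if fast_bss_eval can be imported, without actually importing it."""
    # Assumption: the pip package "fast-bss-eval" provides the module "fast_bss_eval"
    return importlib.util.find_spec("fast_bss_eval") is not None


if not fast_bss_eval_available():
    print("Install with: pip install torchmetrics[audio]  or  pip install fast-bss-eval")
```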

Example

>>> from torchmetrics.audio import SignalDistortionRatio
>>> import torch
>>> g = torch.manual_seed(1)
>>> preds = torch.randn(8000)
>>> target = torch.randn(8000)
>>> sdr = SignalDistortionRatio()
>>> sdr(preds, target)
tensor(-12.0589)
>>> # use with permutation invariant training (PIT)
>>> from torchmetrics.audio import PermutationInvariantTraining
>>> from torchmetrics.functional.audio import signal_distortion_ratio
>>> preds = torch.randn(4, 2, 8000)  # [batch, spk, time]
>>> target = torch.randn(4, 2, 8000)
>>> pit = PermutationInvariantTraining(signal_distortion_ratio, 'max')
>>> pit(preds, target)
tensor(-11.6051)

Note

  1. With PyTorch < 1.8.0, NumPy is used to calculate this metric, which makes SDR non-differentiable and slower to calculate.

  2. Using this metric requires fast-bss-eval to be installed. Install it with either pip install torchmetrics[audio] or pip install fast-bss-eval.

  3. preds and target need to have the same dtype; otherwise, target is converted to the dtype of preds.

References

[1] Vincent, E., Gribonval, R., & Fevotte, C. (2006). Performance measurement in blind audio source separation. IEEE Transactions on Audio, Speech and Language Processing, 14(4), 1462–1469.

[2] Scheibler, R. (2021). SDR – Medium Rare with Fast Computations.

[3] https://github.com/fakufaku/fast_bss_eval


compute()

Computes average SDR.

Return type

Tensor

update(preds, target)

Update state with predictions and targets.

Parameters
  • preds (Tensor) – Predictions from model

  • target (Tensor) – Ground truth values

Return type

None
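compute() returns an average over everything passed to update() across calls. The sketch below shows one common way to keep such a running average, assuming the metric stores a running sum of per-signal SDR values and a signal count. The class RunningSDRMean and its attributes are hypothetical, and to stay self-contained its update() takes precomputed per-signal SDR values rather than preds and target.

```python
class RunningSDRMean:
    """Hypothetical running average of per-signal SDR values (dB)."""

    def __init__(self):
        self.sum_sdr = 0.0  # running sum of per-signal SDR values
        self.total = 0      # number of signals seen so far

    def update(self, batch_sdr):
        # batch_sdr: one SDR value per signal in the batch
        self.sum_sdr += sum(batch_sdr)
        self.total += len(batch_sdr)

    def compute(self):
        # Mean over every signal seen, not a mean of per-batch means
        return self.sum_sdr / self.total


meter = RunningSDRMean()
meter.update([1.0, 3.0])  # batch of two signals
meter.update([5.0])       # batch of one signal
print(meter.compute())    # 3.0
```

Averaging per signal means a larger batch carries proportionally more weight; with the values above, a mean of the two batch means would instead give (2.0 + 5.0) / 2 = 3.5.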

Functional Interface

torchmetrics.functional.signal_distortion_ratio(preds, target, use_cg_iter=None, filter_length=512, zero_mean=False, load_diag=None)

Signal to Distortion Ratio (SDR) [1,2,3]

Parameters
  • preds (Tensor) – shape [..., time]

  • target (Tensor) – shape [..., time]

  • use_cg_iter (Optional[int]) – If provided, an iterative method is used to solve for the distortion filter coefficients instead of direct Gaussian elimination. This can speed up the computation of the metric when the filters are long. A value of 10 has been shown to give good accuracy in most cases and is sufficient when using this metric as a loss to train neural separation networks.

  • filter_length (int) – The length of the allowed distortion filter.

  • zero_mean (bool) – When set to True, the mean of all signals is subtracted before the metric is computed.

  • load_diag (Optional[float]) – If provided, this small value is added to the diagonal coefficients of the system matrix when solving for the filter coefficients. This can help stabilize the metric when some reference signals are occasionally zero.

Raises

ModuleNotFoundError – If the fast-bss-eval package is not installed.

Return type

Tensor

Returns

SDR value of shape [...]
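The shape rule [..., time] → [...] means one SDR value is produced per signal, with the time axis reduced and all leading axes kept. The NumPy sketch below illustrates this under a simplifying assumption: a one-tap filter (filter_length=1), where the optimal distortion filter is just a gain and the coherence reduces to the squared cosine similarity between preds and target. The helper sdr_single_tap is hypothetical; it is not part of torchmetrics and does not reproduce the result for the default filter_length=512.

```python
import numpy as np


def sdr_single_tap(preds, target, eps=1e-12):
    """SDR (dB) with a one-tap distortion filter, reduced over the last axis only."""
    preds = np.asarray(preds, dtype=np.float64)
    target = np.asarray(target, dtype=np.float64)
    # Normalize along time so the inner product equals the cosine similarity
    preds = preds / (np.linalg.norm(preds, axis=-1, keepdims=True) + eps)
    target = target / (np.linalg.norm(target, axis=-1, keepdims=True) + eps)
    coh = np.sum(preds * target, axis=-1) ** 2  # cos^2, shape [...]
    return 10.0 * np.log10(coh / (1.0 - coh))


rng = np.random.default_rng(0)
preds = rng.standard_normal((4, 2, 8000))   # [batch, spk, time]
target = rng.standard_normal((4, 2, 8000))
print(sdr_single_tap(preds, target).shape)  # (4, 2)
```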

Example

>>> from torchmetrics.functional.audio import signal_distortion_ratio
>>> import torch
>>> g = torch.manual_seed(1)
>>> preds = torch.randn(8000)
>>> target = torch.randn(8000)
>>> signal_distortion_ratio(preds, target)
tensor(-12.0589)
>>> # use with permutation_invariant_training
>>> from torchmetrics.functional.audio import permutation_invariant_training
>>> preds = torch.randn(4, 2, 8000)  # [batch, spk, time]
>>> target = torch.randn(4, 2, 8000)
>>> best_metric, best_perm = permutation_invariant_training(preds, target, signal_distortion_ratio, 'max')
>>> best_metric
tensor([-11.6375, -11.4358, -11.7148, -11.6325])
>>> best_perm
tensor([[1, 0],
        [0, 1],
        [1, 0],
        [0, 1]])
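Permutation invariant training (PIT) scores each example under every assignment of predicted sources to target sources and keeps the best one ('max' here, since a higher SDR is better). The brute-force sketch below illustrates that idea for a single example in plain Python. The function pit_sketch, its permutation convention (perm[i] is the index of the prediction matched to target speaker i), and the toy closeness metric are hypothetical and are not the torchmetrics implementation.

```python
from itertools import permutations


def pit_sketch(preds, target, metric, eval_func="max"):
    """Hypothetical exhaustive PIT for one example.

    preds, target: sequences of per-speaker signals with the same length n_spk.
    metric(p, t): returns a float score for one prediction/target pair.
    Returns (best mean score, best permutation).
    """
    n_spk = len(target)
    pick_better = max if eval_func == "max" else min
    candidates = []
    for perm in permutations(range(n_spk)):
        # Mean metric when prediction perm[i] is matched to target speaker i
        score = sum(metric(preds[perm[i]], target[i]) for i in range(n_spk)) / n_spk
        candidates.append((score, perm))
    return pick_better(candidates, key=lambda c: c[0])


def closeness(p, t):
    # Toy "metric" on scalars: higher is better when the values are close
    return -abs(p - t)


print(pit_sketch([10.0, 1.0], [1.0, 10.0], closeness))  # (0.0, (1, 0))
```

Enumerating all n_spk! permutations is only practical for a few speakers, and this sketch recomputes the metric for every permutation; computing the n_spk × n_spk matrix of pairwise scores once and then searching over assignments avoids that repeated work.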

Note

  1. With PyTorch < 1.8.0, NumPy is used to calculate this metric, which makes SDR non-differentiable and slower to calculate.

  2. Using this metric requires fast-bss-eval to be installed. Install it with either pip install torchmetrics[audio] or pip install fast-bss-eval.

  3. preds and target need to have the same dtype; otherwise, target is converted to the dtype of preds.

Documentation version: v0.8.0