Source Aggregated Signal-to-Distortion Ratio (SA-SDR)

Module Interface

class torchmetrics.audio.sdr.SourceAggregatedSignalDistortionRatio(scale_invariant=True, zero_mean=False, **kwargs)

Source-aggregated signal-to-distortion ratio (SA-SDR).

The SA-SDR was proposed to provide a stable gradient for meeting-style source separation, where one-speaker and multiple-speaker scenes coexist.

As input to forward and update, the metric accepts the following:

  • preds (Tensor): float tensor with shape (..., spk, time)

  • target (Tensor): float tensor with shape (..., spk, time)

As output of forward and compute, the metric returns the following:

  • sa_sdr (Tensor): float scalar tensor with the SA-SDR value averaged over samples

Parameters:
  • preds – float tensor with shape (..., spk, time)

  • target – float tensor with shape (..., spk, time)

  • scale_invariant (bool) – If True, scale the targets of all speakers by the same factor alpha

  • zero_mean (bool) – Whether to subtract the mean from target and preds before computing the metric

  • kwargs (Any) – Additional keyword arguments, see Advanced metric settings for more info.
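To make the roles of scale_invariant and zero_mean concrete, the following is a minimal, dependency-free sketch of how an SA-SDR value could be computed for a single example. It assumes the usual definition (total target energy over all speakers divided by total distortion energy over all speakers, in dB) and that the mean is taken along the time axis. The function name sa_sdr_sketch and the eps value are hypothetical and are not part of torchmetrics; the library's own implementation may differ in details.

```python
import math


def _center(signal):
    """Subtract the mean over time from one signal (a list of floats)."""
    mean = sum(signal) / len(signal)
    return [x - mean for x in signal]


def sa_sdr_sketch(preds, target, scale_invariant=True, zero_mean=False, eps=1e-12):
    """Illustrative SA-SDR for one example.

    preds and target are lists with one list of samples per speaker,
    i.e. the plain-Python analogue of a (spk, time) tensor.
    eps is an assumed small constant that guards against division by zero.
    """
    if zero_mean:
        preds = [_center(p) for p in preds]
        target = [_center(t) for t in target]

    if scale_invariant:
        # A single alpha shared by all speakers, not one alpha per speaker.
        dot = sum(p * t for ps, ts in zip(preds, target) for p, t in zip(ps, ts))
        energy = sum(t * t for ts in target for t in ts)
        alpha = (dot + eps) / (energy + eps)
        target = [[alpha * t for t in ts] for ts in target]

    # Energies are aggregated over speakers and time before taking the ratio.
    signal = sum(t * t for ts in target for t in ts)
    distortion = sum(
        (t - p) ** 2 for ps, ts in zip(preds, target) for p, t in zip(ps, ts)
    )
    return 10 * math.log10((signal + eps) / (distortion + eps))


# Two speakers, the second one silent (a one-speaker scene).
target = [[1.0, -1.0, 1.0, -1.0], [0.0, 0.0, 0.0, 0.0]]
preds = [[0.9, -1.1, 1.1, -0.9], [0.1, -0.1, 0.1, -0.1]]
value = sa_sdr_sketch(preds, target)  # about 16.99 dB
```

In this example the silent speaker adds nothing to the numerator but its estimation error still counts in the denominator, so the ratio stays finite. A per-speaker SDR for the silent source would instead have zero signal energy.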

Example

>>> import torch
>>> from torchmetrics.audio import SourceAggregatedSignalDistortionRatio
>>> g = torch.manual_seed(1)
>>> preds = torch.randn(2, 8000) # [..., spk, time]
>>> target = torch.randn(2, 8000)
>>> sasdr = SourceAggregatedSignalDistortionRatio()
>>> sasdr(preds, target)
tensor(-41.6579)
>>> # use with pit
>>> from torchmetrics.audio import PermutationInvariantTraining
>>> from torchmetrics.functional.audio import source_aggregated_signal_distortion_ratio
>>> preds = torch.randn(4, 2, 8000)  # [batch, spk, time]
>>> target = torch.randn(4, 2, 8000)
>>> pit = PermutationInvariantTraining(source_aggregated_signal_distortion_ratio,
...     mode="permutation-wise", eval_func="max")
>>> pit(preds, target)
tensor(-41.2790)

plot(val=None, ax=None)

Plot a single or multiple values from the metric.

Parameters:
  • val (Union[Tensor, Sequence[Tensor], None]) – Either a single result from calling metric.forward or metric.compute or a list of these results. If no value is provided, will automatically call metric.compute and plot that result.

  • ax (Optional[Axes]) – A matplotlib axis object. If provided, the plot is added to that axis

Return type:

Tuple[Figure, Union[Axes, ndarray]]

Returns:

Figure and Axes object

Raises:

ModuleNotFoundError – If matplotlib is not installed

>>> # Example plotting a single value
>>> import torch
>>> from torchmetrics.audio import SourceAggregatedSignalDistortionRatio
>>> metric = SourceAggregatedSignalDistortionRatio()
>>> metric.update(torch.rand(2, 8000), torch.rand(2, 8000))
>>> fig_, ax_ = metric.plot()
>>> # Example plotting multiple values
>>> import torch
>>> from torchmetrics.audio import SourceAggregatedSignalDistortionRatio
>>> metric = SourceAggregatedSignalDistortionRatio()
>>> values = []
>>> for _ in range(10):
...     values.append(metric(torch.rand(2, 8000), torch.rand(2, 8000)))
>>> fig_, ax_ = metric.plot(values)

Functional Interface

torchmetrics.functional.audio.sdr.source_aggregated_signal_distortion_ratio(preds, target, scale_invariant=True, zero_mean=False)

Source-aggregated signal-to-distortion ratio (SA-SDR).

The SA-SDR was proposed to provide a stable gradient for meeting-style source separation, where one-speaker and multiple-speaker scenes coexist.

Parameters:
  • preds (Tensor) – float tensor with shape (..., spk, time)

  • target (Tensor) – float tensor with shape (..., spk, time)

  • scale_invariant (bool) – If True, scale the targets of all speakers by the same factor alpha

  • zero_mean (bool) – Whether to subtract the mean from target and preds before computing the metric

Return type:

Tensor

Returns:

SA-SDR with shape (...)

Example

>>> import torch
>>> from torchmetrics.functional.audio import source_aggregated_signal_distortion_ratio
>>> g = torch.manual_seed(1)
>>> preds = torch.randn(2, 8000)  # [..., spk, time]
>>> target = torch.randn(2, 8000)
>>> source_aggregated_signal_distortion_ratio(preds, target)
tensor(-41.6579)
>>> # use with permutation_invariant_training
>>> from torchmetrics.functional.audio import permutation_invariant_training
>>> preds = torch.randn(4, 2, 8000)  # [batch, spk, time]
>>> target = torch.randn(4, 2, 8000)
>>> best_metric, best_perm = permutation_invariant_training(preds, target,
...     source_aggregated_signal_distortion_ratio, mode="permutation-wise")
>>> best_metric
tensor([-37.9511, -41.9124, -42.7369, -42.5155])
>>> best_perm
tensor([[1, 0],
        [1, 0],
        [0, 1],
        [1, 0]])