
Source Aggregated Signal-to-Distortion Ratio (SA-SDR)

Module Interface

class torchmetrics.audio.sdr.SourceAggregatedSignalDistortionRatio(scale_invariant=True, zero_mean=False, **kwargs)

Source-aggregated signal-to-distortion ratio (SA-SDR).

SA-SDR was proposed to provide a stable gradient for meeting-style source separation, where single-speaker and multi-speaker scenes coexist.
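
For reference, the quantity can be sketched as below, where s_k and ŝ_k denote the target and predicted signals of speaker k and K is the number of speakers. This is our reading of the definition, assuming energies are summed over all speakers before the ratio is taken and that scale_invariant=True uses one shared scaling factor alpha (alpha = 1 otherwise); any small epsilon added for numerical stability is omitted.

```latex
\text{SA-SDR}(\hat{\mathbf{s}}, \mathbf{s})
  = 10 \log_{10}
    \frac{\sum_{k=1}^{K} \lVert \alpha\, \mathbf{s}_k \rVert^2}
         {\sum_{k=1}^{K} \lVert \alpha\, \mathbf{s}_k - \hat{\mathbf{s}}_k \rVert^2},
\qquad
\alpha =
  \begin{cases}
    \dfrac{\sum_{k=1}^{K} \langle \hat{\mathbf{s}}_k, \mathbf{s}_k \rangle}
          {\sum_{k=1}^{K} \lVert \mathbf{s}_k \rVert^2} & \text{if scale\_invariant} \\
    1 & \text{otherwise}
  \end{cases}
```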

As input to forward and update, the metric accepts the following inputs:

  • preds (Tensor): float tensor with shape (..., spk, time)

  • target (Tensor): float tensor with shape (..., spk, time)

As output of forward and compute, the metric returns the following output:

  • sa_sdr (Tensor): float scalar tensor with the SA-SDR value averaged over all samples

Parameters:
  • preds – float tensor with shape (..., spk, time)

  • target – float tensor with shape (..., spk, time)

  • scale_invariant (bool) – If True, scale the targets of all speakers by the same factor alpha

  • zero_mean (bool) – Whether to zero-mean target and preds before computing the metric

  • kwargs (Any) – Additional keyword arguments, see Advanced metric settings for more info.
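
To make the effect of scale_invariant and zero_mean concrete, here is a minimal NumPy sketch of the computation. It is a hypothetical re-implementation (the name sa_sdr is ours, not part of torchmetrics), assuming energies are reduced over the last two axes (spk, time) with a single alpha per sample; it is not guaranteed to reproduce torchmetrics' outputs exactly.

```python
import numpy as np


def sa_sdr(preds: np.ndarray, target: np.ndarray,
           scale_invariant: bool = True, zero_mean: bool = False) -> np.ndarray:
    """Hypothetical NumPy sketch of SA-SDR (not the torchmetrics implementation).

    preds, target: float arrays with shape (..., spk, time).
    Returns an array with shape (...), in dB.
    """
    if preds.shape != target.shape or preds.ndim < 2:
        raise ValueError("expected matching shapes (..., spk, time)")
    eps = np.finfo(preds.dtype).eps  # avoids division by zero / log of zero
    if zero_mean:
        # remove the DC offset of every speaker channel independently
        preds = preds - preds.mean(axis=-1, keepdims=True)
        target = target - target.mean(axis=-1, keepdims=True)
    if scale_invariant:
        # one alpha per sample, shared by all speakers (reduced over spk and time)
        alpha = ((preds * target).sum(axis=(-2, -1), keepdims=True) + eps) / (
            (target ** 2).sum(axis=(-2, -1), keepdims=True) + eps
        )
        target = alpha * target
    distortion = target - preds
    # aggregate energies over speakers *and* time before taking the ratio
    ratio = ((target ** 2).sum(axis=(-2, -1)) + eps) / (
        (distortion ** 2).sum(axis=(-2, -1)) + eps
    )
    return 10 * np.log10(ratio)
```

With a batch dimension, e.g. arrays of shape (4, 2, 8000), the sketch returns one value per sample with shape (4,); the module interface then averages such per-sample values into a scalar. Rescaling preds leaves the result unchanged when scale_invariant=True, and a constant offset on preds is ignored when zero_mean=True.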

Example

>>> import torch
>>> from torchmetrics.audio import SourceAggregatedSignalDistortionRatio
>>> g = torch.manual_seed(1)
>>> preds = torch.randn(2, 8000) # [..., spk, time]
>>> target = torch.randn(2, 8000)
>>> sasdr = SourceAggregatedSignalDistortionRatio()
>>> sasdr(preds, target)
tensor(-41.6579)
>>> # use with permutation invariant training (PIT)
>>> from torchmetrics.audio import PermutationInvariantTraining
>>> from torchmetrics.functional.audio import source_aggregated_signal_distortion_ratio
>>> preds = torch.randn(4, 2, 8000)  # [batch, spk, time]
>>> target = torch.randn(4, 2, 8000)
>>> pit = PermutationInvariantTraining(source_aggregated_signal_distortion_ratio,
...     mode="permutation-wise", eval_func="max")
>>> pit(preds, target)
tensor(-41.2790)
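
As we understand it, mode="permutation-wise" evaluates the metric on the whole prediction once per candidate ordering of the predicted speakers and keeps the best score (eval_func="max"). The NumPy sketch below illustrates that exhaustive search. pit_permutation_wise and _sa_sdr are hypothetical names, the compact scale-invariant SA-SDR helper is repeated so the block runs on its own, and the code mirrors the idea only, not torchmetrics' implementation.

```python
from itertools import permutations

import numpy as np


def _sa_sdr(preds: np.ndarray, target: np.ndarray) -> np.ndarray:
    # compact scale-invariant SA-SDR over the last two axes (spk, time), in dB
    eps = np.finfo(preds.dtype).eps
    alpha = ((preds * target).sum(axis=(-2, -1), keepdims=True) + eps) / (
        (target ** 2).sum(axis=(-2, -1), keepdims=True) + eps)
    target = alpha * target
    ratio = ((target ** 2).sum(axis=(-2, -1)) + eps) / (
        ((target - preds) ** 2).sum(axis=(-2, -1)) + eps)
    return 10 * np.log10(ratio)


def pit_permutation_wise(preds: np.ndarray, target: np.ndarray, metric=_sa_sdr):
    """Hypothetical exhaustive PIT for arrays shaped (batch, spk, time).

    Returns (best_metric with shape (batch,), best_perm with shape (batch, spk)),
    where preds[b, best_perm[b]] is the speaker ordering that maximizes the metric.
    """
    n_spk = preds.shape[1]
    perms = np.array(list(permutations(range(n_spk))))  # (n_perm, spk)
    # score every candidate ordering of the predicted speakers: (batch, n_perm)
    scores = np.stack([metric(preds[:, p], target) for p in perms], axis=1)
    best_idx = scores.argmax(axis=1)
    best_metric = scores[np.arange(len(best_idx)), best_idx]
    return best_metric, perms[best_idx]
```

The search visits spk! orderings per sample, so this brute-force form is only practical for a small number of speakers.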

plot(val=None, ax=None)

Plot a single or multiple values from the metric.

Parameters:
  • val (Union[Tensor, Sequence[Tensor], None]) – Either a single result from calling metric.forward or metric.compute, or a list of these results. If no value is provided, metric.compute is called automatically and that result is plotted.

  • ax (Optional[Axes]) – A matplotlib axis object. If provided, the plot is added to that axis.

Return type:

Tuple[Figure, Union[Axes, ndarray]]

Returns:

Figure and Axes object

Raises:

ModuleNotFoundError – If matplotlib is not installed

>>> # Example plotting a single value
>>> import torch
>>> from torchmetrics.audio import SourceAggregatedSignalDistortionRatio
>>> metric = SourceAggregatedSignalDistortionRatio()
>>> metric.update(torch.rand(2,8000), torch.rand(2,8000))
>>> fig_, ax_ = metric.plot()
>>> # Example plotting multiple values
>>> import torch
>>> from torchmetrics.audio import SourceAggregatedSignalDistortionRatio
>>> metric = SourceAggregatedSignalDistortionRatio()
>>> values = []
>>> for _ in range(10):
...     values.append(metric(torch.rand(2,8000), torch.rand(2,8000)))
>>> fig_, ax_ = metric.plot(values)

Functional Interface

torchmetrics.functional.audio.sdr.source_aggregated_signal_distortion_ratio(preds, target, scale_invariant=True, zero_mean=False)

Source-aggregated signal-to-distortion ratio (SA-SDR).

SA-SDR was proposed to provide a stable gradient for meeting-style source separation, where single-speaker and multi-speaker scenes coexist.
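
The toy example below illustrates one intuition for why aggregating over speakers helps when a target speaker is silent. It contrasts a plain (non-scale-invariant) SDR computed per speaker and then averaged with its source-aggregated counterpart. Both functions are hypothetical NumPy helpers written for this illustration, and the signal levels are made up; it is not a reproduction of any published experiment.

```python
import numpy as np

EPS = np.finfo(np.float64).eps  # guards against division by zero


def mean_per_speaker_sdr(preds: np.ndarray, target: np.ndarray) -> np.ndarray:
    """Hypothetical helper: SDR per speaker (sums over time only), then averaged."""
    num = (target ** 2).sum(axis=-1) + EPS
    den = ((target - preds) ** 2).sum(axis=-1) + EPS
    return (10 * np.log10(num / den)).mean(axis=-1)


def plain_sa_sdr(preds: np.ndarray, target: np.ndarray) -> np.ndarray:
    """Hypothetical helper: energies summed over speakers and time before the ratio."""
    num = (target ** 2).sum(axis=(-2, -1)) + EPS
    den = ((target - preds) ** 2).sum(axis=(-2, -1)) + EPS
    return 10 * np.log10(num / den)


rng = np.random.default_rng(0)
n_time = 8000
active = rng.standard_normal(n_time)
# target: speaker 0 talks, speaker 1 is silent (all zeros)
target = np.stack([active, np.zeros(n_time)])
# preds: a good estimate of speaker 0 plus faint leakage on speaker 1
preds = np.stack([
    active + 0.1 * rng.standard_normal(n_time),
    0.01 * rng.standard_normal(n_time),
])

per_speaker = mean_per_speaker_sdr(preds, target)  # dominated by the silent speaker
aggregated = plain_sa_sdr(preds, target)            # tracks the active speaker's quality
print(f"mean per-speaker SDR: {per_speaker:.1f} dB")
print(f"SA-SDR:               {aggregated:.1f} dB")
```

In this setup the silent speaker's own SDR is extremely negative because its reference energy is essentially zero, which drags the per-speaker average far below zero, while the aggregated ratio stays close to the roughly 20 dB quality of the active speaker.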

Parameters:
  • preds (Tensor) – float tensor with shape (..., spk, time)

  • target (Tensor) – float tensor with shape (..., spk, time)

  • scale_invariant (bool) – If True, scale the targets of all speakers by the same factor alpha

  • zero_mean (bool) – Whether to zero-mean target and preds before computing the metric

Return type:

Tensor

Returns:

SA-SDR with shape (...)

Example

>>> import torch
>>> from torchmetrics.functional.audio import source_aggregated_signal_distortion_ratio
>>> g = torch.manual_seed(1)
>>> preds = torch.randn(2, 8000)  # [..., spk, time]
>>> target = torch.randn(2, 8000)
>>> source_aggregated_signal_distortion_ratio(preds, target)
tensor(-41.6579)
>>> # use with permutation_invariant_training
>>> from torchmetrics.functional.audio import permutation_invariant_training
>>> preds = torch.randn(4, 2, 8000)  # [batch, spk, time]
>>> target = torch.randn(4, 2, 8000)
>>> best_metric, best_perm = permutation_invariant_training(preds, target,
...     source_aggregated_signal_distortion_ratio, mode="permutation-wise")
>>> best_metric
tensor([-37.9511, -41.9124, -42.7369, -42.5155])
>>> best_perm
tensor([[1, 0],
        [1, 0],
        [0, 1],
        [1, 0]])