Signal to Distortion Ratio (SDR)
Module Interface
- class torchmetrics.SignalDistortionRatio(use_cg_iter=None, filter_length=512, zero_mean=False, load_diag=None, compute_on_step=None, **kwargs)
Signal to Distortion Ratio (SDR) [1,2,3]
Forward accepts
- preds: shape [..., time]
- target: shape [..., time]
- Parameters
use_cg_iter (Optional[int]) – If provided, an iterative method is used to solve for the distortion filter coefficients instead of direct Gaussian elimination. This can speed up the computation of the metric when the filters are long. A value of 10 has been shown to provide good accuracy in most cases and is sufficient when using this metric as a loss to train neural separation networks.
filter_length (int) – The maximum allowed length of the distortion filter.
zero_mean (bool) – When set to True, the mean of all signals is subtracted before the metric is computed.
load_diag (Optional[float]) – If provided, this small value is added to the diagonal coefficients of the system matrix when solving for the filter coefficients. This can help stabilize the metric when some reference signals may be zero.
compute_on_step (Optional[bool]) – Forward only calls update() and returns None if this is set to False. Deprecated since version v0.8: the argument no longer has any effect and will be removed in v0.9.
kwargs (Dict[str, Any]) – Additional keyword arguments, see Advanced metric settings for more info.
- Raises
ModuleNotFoundError – If the fast-bss-eval package is not installed.
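To make the roles of filter_length, zero_mean and load_diag more concrete, the following is a minimal pure-Python sketch of an SDR with a short distortion filter. It is a simplified illustration, not the code used by torchmetrics or fast-bss-eval: the helper names solve_linear and sdr_sketch are hypothetical, the signals are plain lists, and the least-squares system is built explicitly from delayed copies of the target, which only scales to tiny inputs.

```python
import math


def solve_linear(a, b):
    """Solve a x = b by Gaussian elimination with partial pivoting."""
    n = len(b)
    m = [list(row) + [rhs] for row, rhs in zip(a, b)]  # augmented matrix [a | b]
    for col in range(n):
        pivot = max(range(col, n), key=lambda r: abs(m[r][col]))
        m[col], m[pivot] = m[pivot], m[col]
        for r in range(col + 1, n):
            factor = m[r][col] / m[col][col]
            for c in range(col, n + 1):
                m[r][c] -= factor * m[col][c]
    x = [0.0] * n
    for r in range(n - 1, -1, -1):  # back substitution
        tail = sum(m[r][c] * x[c] for c in range(r + 1, n))
        x[r] = (m[r][n] - tail) / m[r][r]
    return x


def sdr_sketch(preds, target, filter_length=1, zero_mean=False, load_diag=None):
    """Toy SDR in dB for two equal-length 1-D sequences (hypothetical helper)."""
    preds, target = list(preds), list(target)
    n = len(target)
    if zero_mean:
        p_mean, t_mean = sum(preds) / n, sum(target) / n
        preds = [p - p_mean for p in preds]
        target = [t - t_mean for t in target]
    # Row k holds the target delayed by k samples (zeros shifted in at the start).
    delayed = [[0.0] * k + target[: n - k] for k in range(filter_length)]
    # Normal equations of the least-squares fit of preds by the delayed targets.
    gram = [[sum(x * y for x, y in zip(di, dj)) for dj in delayed] for di in delayed]
    rhs = [sum(x * p for x, p in zip(di, preds)) for di in delayed]
    if load_diag is not None:
        for i in range(filter_length):
            gram[i][i] += load_diag  # diagonal loading stabilizes the solve
    h = solve_linear(gram, rhs)
    # Filtered target (the part of preds explained by the target) and the residual.
    filtered = [sum(h[k] * delayed[k][t] for k in range(filter_length)) for t in range(n)]
    signal_energy = sum(s * s for s in filtered)
    error_energy = sum((p - s) ** 2 for p, s in zip(preds, filtered))
    return 10.0 * math.log10(signal_energy / error_energy)
```

For instance, sdr_sketch([1.1, -0.9, 1.1, -0.9], [1.0, -1.0, 1.0, -1.0]) evaluates to about 20 dB, because the residual carries roughly one hundredth of the energy of the filtered target. A longer filter can only explain more of preds, so in this sketch the value does not decrease as filter_length grows. The sketch does not guard against a zero residual, which would make the ratio undefined.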
Example
>>> from torchmetrics.audio import SignalDistortionRatio
>>> import torch
>>> g = torch.manual_seed(1)
>>> preds = torch.randn(8000)
>>> target = torch.randn(8000)
>>> sdr = SignalDistortionRatio()
>>> sdr(preds, target)
tensor(-12.0589)
>>> # use with pit
>>> from torchmetrics.audio import PermutationInvariantTraining
>>> from torchmetrics.functional.audio import signal_distortion_ratio
>>> preds = torch.randn(4, 2, 8000)  # [batch, spk, time]
>>> target = torch.randn(4, 2, 8000)
>>> pit = PermutationInvariantTraining(signal_distortion_ratio, 'max')
>>> pit(preds, target)
tensor(-11.6051)
Note
- When pytorch<1.8.0, numpy is used to calculate this metric, which makes sdr non-differentiable and slower to calculate.
- Using this metric requires fast-bss-eval to be installed. Install it with pip install torchmetrics[audio] or pip install fast-bss-eval.
- preds and target need to have the same dtype; otherwise target is converted to the dtype of preds.
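Since the metric raises ModuleNotFoundError when the dependency is absent, it can be convenient to check for it up front. The snippet below is a small sketch using only the standard library; the helper name fast_bss_eval_available is hypothetical, and it assumes the package is imported under the module name fast_bss_eval.

```python
import importlib.util


def fast_bss_eval_available() -> bool:
    """Return True if a module named fast_bss_eval can be found (hypothetical helper)."""
    return importlib.util.find_spec("fast_bss_eval") is not None


if not fast_bss_eval_available():
    # Point the user at one of the install commands from the note above.
    print("fast-bss-eval is missing; try: pip install fast-bss-eval")
```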
References
[1] Vincent, E., Gribonval, R., & Fevotte, C. (2006). Performance measurement in blind audio source separation. IEEE Transactions on Audio, Speech and Language Processing, 14(4), 1462–1469.
[2] Scheibler, R. (2021). SDR – Medium Rare with Fast Computations.
[3] https://github.com/fakufaku/fast_bss_eval
Functional Interface
- torchmetrics.functional.signal_distortion_ratio(preds, target, use_cg_iter=None, filter_length=512, zero_mean=False, load_diag=None)
Signal to Distortion Ratio (SDR) [1,2,3]
- Parameters
use_cg_iter (Optional[int]) – If provided, an iterative method is used to solve for the distortion filter coefficients instead of direct Gaussian elimination. This can speed up the computation of the metric when the filters are long. A value of 10 has been shown to provide good accuracy in most cases and is sufficient when using this metric as a loss to train neural separation networks.
filter_length (int) – The maximum allowed length of the distortion filter.
zero_mean (bool) – When set to True, the mean of all signals is subtracted before the metric is computed.
load_diag (Optional[float]) – If provided, this small value is added to the diagonal coefficients of the system matrix when solving for the filter coefficients. This can help stabilize the metric when some reference signals may be zero.
- Raises
ModuleNotFoundError – If the fast-bss-eval package is not installed.
- Return type
Tensor
- Returns
sdr value of shape [...]
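The use_cg_iter parameter replaces the direct solve with a fixed number of iterations of an iterative method. As a rough illustration of that trade-off, here is a generic textbook conjugate gradient routine in pure Python. It is an assumption that this resembles what the parameter controls (the name suggests conjugate gradient), and the function conjugate_gradient is hypothetical; the solver inside fast-bss-eval may differ, for example by exploiting matrix structure.

```python
def conjugate_gradient(a, b, n_iter):
    """Approximately solve a x = b for a symmetric positive-definite matrix a."""
    n = len(b)
    x = [0.0] * n
    r = list(b)  # residual b - a x for the initial guess x = 0
    p = list(r)  # first search direction
    rs_old = sum(v * v for v in r)
    for _ in range(n_iter):
        if rs_old < 1e-30:  # already converged, avoid dividing by ~0
            break
        ap = [sum(a[i][j] * p[j] for j in range(n)) for i in range(n)]
        alpha = rs_old / sum(pi * api for pi, api in zip(p, ap))
        x = [xi + alpha * pi for xi, pi in zip(x, p)]
        r = [ri - alpha * api for ri, api in zip(r, ap)]
        rs_new = sum(v * v for v in r)
        # New direction: current residual plus a multiple of the previous direction.
        p = [ri + (rs_new / rs_old) * pi for ri, pi in zip(r, p)]
        rs_old = rs_new
    return x
```

Each iteration costs one matrix-vector product, so stopping after a small, fixed number of iterations (such as the 10 suggested above) can be much cheaper than a full elimination when the system is large, at the price of an approximate solution.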
Example
>>> from torchmetrics.functional.audio import signal_distortion_ratio
>>> import torch
>>> g = torch.manual_seed(1)
>>> preds = torch.randn(8000)
>>> target = torch.randn(8000)
>>> signal_distortion_ratio(preds, target)
tensor(-12.0589)
>>> # use with permutation_invariant_training
>>> from torchmetrics.functional.audio import permutation_invariant_training
>>> preds = torch.randn(4, 2, 8000)  # [batch, spk, time]
>>> target = torch.randn(4, 2, 8000)
>>> best_metric, best_perm = permutation_invariant_training(preds, target, signal_distortion_ratio, 'max')
>>> best_metric
tensor([-11.6375, -11.4358, -11.7148, -11.6325])
>>> best_perm
tensor([[1, 0],
        [0, 1],
        [1, 0],
        [0, 1]])
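The second half of the example pairs each predicted speaker with a target speaker so that the averaged metric is maximized. The idea can be sketched for a single batch item in pure Python as below. This is an illustration only: pit_sketch and neg_mse are hypothetical helpers, the stand-in metric is a negative mean squared error rather than SDR, and the convention that perm[i] is the prediction index paired with target i is an assumption of this sketch, not a statement about the library's output layout.

```python
from itertools import permutations


def neg_mse(pred, ref):
    """Stand-in metric (higher is better): negative mean squared error."""
    return -sum((p - r) ** 2 for p, r in zip(pred, ref)) / len(ref)


def pit_sketch(preds, target, metric, eval_func="max"):
    """Search all speaker permutations for one batch item (hypothetical helper).

    preds and target are lists of per-speaker signals of shape [spk][time].
    """
    n_spk = len(target)
    scored = []
    for perm in permutations(range(n_spk)):
        # Assumption of this sketch: perm[i] is the prediction paired with target i.
        value = sum(metric(preds[perm[i]], target[i]) for i in range(n_spk)) / n_spk
        scored.append((value, perm))
    pick = max if eval_func == "max" else min
    return pick(scored, key=lambda item: item[0])


# The two predictions are given in swapped speaker order.
target = [[1.0, 1.0], [0.0, 0.0]]
preds = [[0.0, 0.1], [1.0, 0.9]]
best_value, best_perm = pit_sketch(preds, target, neg_mse, "max")
```

The exhaustive search grows factorially with the number of speakers, which is acceptable for the two-speaker case shown in the example.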
Note
- When pytorch<1.8.0, numpy is used to calculate this metric, which makes sdr non-differentiable and slower to calculate.
- Using this metric requires fast-bss-eval to be installed. Install it with pip install torchmetrics[audio] or pip install fast-bss-eval.
- preds and target need to have the same dtype; otherwise target is converted to the dtype of preds.
References
[1] Vincent, E., Gribonval, R., & Fevotte, C. (2006). Performance measurement in blind audio source separation. IEEE Transactions on Audio, Speech and Language Processing, 14(4), 1462–1469.
[2] Scheibler, R. (2021). SDR – Medium Rare with Fast Computations.