Signal to Distortion Ratio (SDR)
Module Interface
- class torchmetrics.SignalDistortionRatio(use_cg_iter=None, filter_length=512, zero_mean=False, load_diag=None, **kwargs)
Signal to Distortion Ratio (SDR) [1,2]
Forward accepts
preds: shape [..., time]
target: shape [..., time]
- Parameters
use_cg_iter (Optional[int]) – If provided, conjugate gradient descent is used to solve for the distortion filter coefficients instead of direct Gaussian elimination. This option requires that fast-bss-eval is installed and that the PyTorch version is >= 1.8. It can speed up the computation of the metric when the filters are long. A value of 10 has been shown to provide good accuracy in most cases and is sufficient when using this metric as a loss to train neural separation networks.
filter_length (int) – The length of the distortion filter allowed.
zero_mean (bool) – When set to True, the mean of all signals is subtracted prior to computation of the metric.
load_diag (Optional[float]) – If provided, this small value is added to the diagonal coefficients of the system matrices when solving for the filter coefficients. This can help stabilize the metric when some reference signals may be zero.
kwargs (Any) – Additional keyword arguments, see Advanced metric settings for more info.
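The trade-off behind use_cg_iter can be illustrated with a minimal pure-Python sketch. It is not the library's implementation: the helper names (toeplitz_system, conjugate_gradient), the test signal, and the way the system is built are assumptions made for illustration. The sketch forms a small Toeplitz autocorrelation system for a filter of length 8 and solves it with a fixed number of conjugate gradient iterations, here playing the role that use_cg_iter is assumed to play.

```python
import random


def toeplitz_system(target, preds, length):
    """Build a Toeplitz autocorrelation matrix R and cross-correlation vector b."""
    n = len(target)
    acf = [sum(target[i] * target[i + k] for i in range(n - k)) for k in range(length)]
    xcf = [sum(target[i] * preds[i + k] for i in range(n - k)) for k in range(length)]
    R = [[acf[abs(i - j)] for j in range(length)] for i in range(length)]
    return R, xcf


def matvec(A, v):
    """Dense matrix-vector product."""
    return [sum(a * x for a, x in zip(row, v)) for row in A]


def conjugate_gradient(A, b, n_iter):
    """Run at most n_iter conjugate gradient steps on a symmetric positive-definite system."""
    x = [0.0] * len(b)
    r = b[:]  # residual b - A x for the initial guess x = 0
    p = r[:]
    rs_old = sum(v * v for v in r)
    rs_start = rs_old
    for _ in range(n_iter):
        if rs_old <= 1e-20 * rs_start:
            break  # already converged; avoid dividing by a vanishing quantity
        Ap = matvec(A, p)
        alpha = rs_old / sum(pi * api for pi, api in zip(p, Ap))
        x = [xi + alpha * pi for xi, pi in zip(x, p)]
        r = [ri - alpha * api for ri, api in zip(r, Ap)]
        rs_new = sum(v * v for v in r)
        p = [ri + (rs_new / rs_old) * pi for ri, pi in zip(r, p)]
        rs_old = rs_new
    return x


rng = random.Random(0)
target = [rng.gauss(0.0, 1.0) for _ in range(400)]
# Hypothetical estimate: the target passed through a two-tap filter [0.8, 0.3].
preds = [0.8 * target[i] + (0.3 * target[i - 1] if i > 0 else 0.0) for i in range(400)]

R, b = toeplitz_system(target, preds, length=8)
h = conjugate_gradient(R, b, n_iter=10)
residual = max(abs(u - v) for u, v in zip(matvec(R, h), b))
```

Each conjugate gradient step costs one matrix-vector product, so a small fixed iteration count can be cheaper than a full elimination when the filter is long and the system is well conditioned.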
Example
>>> from torchmetrics.audio import SignalDistortionRatio
>>> import torch
>>> g = torch.manual_seed(1)
>>> preds = torch.randn(8000)
>>> target = torch.randn(8000)
>>> sdr = SignalDistortionRatio()
>>> sdr(preds, target)
tensor(-12.0589)
>>> # use with pit
>>> from torchmetrics.audio import PermutationInvariantTraining
>>> from torchmetrics.functional.audio import signal_distortion_ratio
>>> preds = torch.randn(4, 2, 8000)  # [batch, spk, time]
>>> target = torch.randn(4, 2, 8000)
>>> pit = PermutationInvariantTraining(signal_distortion_ratio, 'max')
>>> pit(preds, target)
tensor(-11.6051)
References
[1] Vincent, E., Gribonval, R., & Fevotte, C. (2006). Performance measurement in blind audio source separation. IEEE Transactions on Audio, Speech and Language Processing, 14(4), 1462–1469.
[2] Scheibler, R. (2021). SDR – Medium Rare with Fast Computations.
Functional Interface
- torchmetrics.functional.signal_distortion_ratio(preds, target, use_cg_iter=None, filter_length=512, zero_mean=False, load_diag=None)
Signal to Distortion Ratio (SDR) [1,2]
- Parameters
use_cg_iter (Optional[int]) – If provided, conjugate gradient descent is used to solve for the distortion filter coefficients instead of direct Gaussian elimination. This option requires that fast-bss-eval is installed and that the PyTorch version is >= 1.8. It can speed up the computation of the metric when the filters are long. A value of 10 has been shown to provide good accuracy in most cases and is sufficient when using this metric as a loss to train neural separation networks.
filter_length (int) – The length of the distortion filter allowed.
zero_mean (bool) – When set to True, the mean of all signals is subtracted prior to computation of the metric.
load_diag (Optional[float]) – If provided, this small value is added to the diagonal coefficients of the system matrices when solving for the filter coefficients. This can help stabilize the metric when some reference signals may be zero.
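The effect of zero_mean and load_diag can be sketched in the simplest setting, a distortion filter of length 1, where the filter reduces to a single gain. The function below (sdr_length_one, a hypothetical name) is an illustrative simplification written in plain Python and is not the library's computation: the real metric uses filter_length taps and its own normalization.

```python
import math


def sdr_length_one(preds, target, zero_mean=False, load_diag=None):
    """Sketch of an SDR with a one-tap distortion filter (illustrative assumption)."""
    if zero_mean:
        # Subtract the mean of each signal before computing the metric.
        p_mean = sum(preds) / len(preds)
        t_mean = sum(target) / len(target)
        preds = [p - p_mean for p in preds]
        target = [t - t_mean for t in target]

    target_energy = sum(t * t for t in target)
    cross = sum(t * p for t, p in zip(target, preds))

    # Diagonal loading: a small value added to the 1x1 "system matrix".
    denom = target_energy if load_diag is None else target_energy + load_diag
    gain = cross / denom  # raises ZeroDivisionError for an all-zero target without loading

    signal = gain * gain * target_energy
    distortion = sum((p - gain * t) ** 2 for p, t in zip(preds, target))
    if signal == 0.0:
        return float("-inf")
    if distortion == 0.0:
        return float("inf")
    return 10.0 * math.log10(signal / distortion)


target = [1.0, 0.0, -1.0, 0.0]
plain = sdr_length_one([2.0, 2.0, 0.0, 0.0], target)
centered = sdr_length_one([2.0, 2.0, 0.0, 0.0], target, zero_mean=True)
silent = sdr_length_one([1.0, -1.0, 1.0, -1.0], [0.0, 0.0, 0.0, 0.0], load_diag=1e-8)
```

In this sketch a constant offset in the estimate lowers the score unless zero_mean is set, and an all-zero reference only yields a result (negative infinity here) once load_diag keeps the division well defined.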
- Return type
Tensor
- Returns
SDR value of shape [...]
Example
>>> from torchmetrics.functional.audio import signal_distortion_ratio
>>> import torch
>>> g = torch.manual_seed(1)
>>> preds = torch.randn(8000)
>>> target = torch.randn(8000)
>>> signal_distortion_ratio(preds, target)
tensor(-12.0589)
>>> # use with permutation_invariant_training
>>> from torchmetrics.functional.audio import permutation_invariant_training
>>> preds = torch.randn(4, 2, 8000)  # [batch, spk, time]
>>> target = torch.randn(4, 2, 8000)
>>> best_metric, best_perm = permutation_invariant_training(preds, target, signal_distortion_ratio, 'max')
>>> best_metric
tensor([-11.6375, -11.4358, -11.7148, -11.6325])
>>> best_perm
tensor([[1, 0], [0, 1], [1, 0], [0, 1]])
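The idea behind pairing a metric with permutation invariant training in 'max' mode can be sketched as an exhaustive search over speaker orderings. The code below is a plain-Python illustration under stated assumptions, not the library's implementation: pit_max and neg_mse are hypothetical helpers, a negative mean squared error stands in for SDR, and the convention that perm[i] is the index of the estimate paired with target i is chosen for this sketch only.

```python
from itertools import permutations


def neg_mse(pred, ref):
    """Stand-in metric for the sketch: higher is better."""
    return -sum((a - b) ** 2 for a, b in zip(pred, ref)) / len(ref)


def pit_max(preds, target, metric):
    """Return the best speaker-averaged metric and the permutation that achieves it."""
    n_spk = len(target)
    best_val, best_perm = None, None
    for perm in permutations(range(n_spk)):
        # Sketch convention: perm[i] is the estimate paired with target i.
        val = sum(metric(preds[perm[i]], target[i]) for i in range(n_spk)) / n_spk
        if best_val is None or val > best_val:
            best_val, best_perm = val, perm
    return best_val, best_perm


# One batch element with two speakers whose estimates come out in swapped order.
target = [[1.0, 2.0, 3.0], [-1.0, 0.0, 1.0]]
preds = [[-1.0, 0.0, 1.1], [1.0, 2.0, 3.0]]
best_val, best_perm = pit_max(preds, target, neg_mse)
```

The exhaustive search grows factorially with the number of speakers, which is acceptable for the two-speaker case shown in the example above.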
References
[1] Vincent, E., Gribonval, R., & Fevotte, C. (2006). Performance measurement in blind audio source separation. IEEE Transactions on Audio, Speech and Language Processing, 14(4), 1462–1469.
[2] Scheibler, R. (2021). SDR – Medium Rare with Fast Computations.