Permutation Invariant Training (PIT)¶
Module Interface¶
- class torchmetrics.PermutationInvariantTraining(metric_func, eval_func='max', **kwargs)[source]
Permutation invariant training (PermutationInvariantTraining). This class implements the Permutation Invariant Training method [1] from the speech separation field, which calculates audio metrics in a permutation-invariant way.
Forward accepts
- preds: shape [batch, spk, ...]
- target: shape [batch, spk, ...]
- Parameters
  - metric_func (Callable) – a metric function that accepts a batch of estimates and targets, i.e. metric_func(preds[:, i, ...], target[:, j, ...]), and returns a batch of metric tensors of shape [batch]
  - eval_func (str) – the function used to find the best permutation, either 'min' or 'max', i.e. the smaller the better or the larger the better
  - kwargs (Any) – additional keyword arguments for either the metric_func or distributed communication; see Advanced metric settings for more info
- Returns
average PermutationInvariantTraining metric
Example
>>> import torch
>>> from torchmetrics import PermutationInvariantTraining
>>> from torchmetrics.functional import scale_invariant_signal_noise_ratio
>>> _ = torch.manual_seed(42)
>>> preds = torch.randn(3, 2, 5)  # [batch, spk, time]
>>> target = torch.randn(3, 2, 5)  # [batch, spk, time]
>>> pit = PermutationInvariantTraining(scale_invariant_signal_noise_ratio, 'max')
>>> pit(preds, target)
tensor(-2.1065)
- Reference:
[1] D. Yu, M. Kolbaek, Z.-H. Tan, J. Jensen, Permutation invariant training of deep models for speaker-independent multi-talker speech separation, in: 2017 IEEE Int. Conf. Acoust. Speech Signal Process. ICASSP, IEEE, New Orleans, LA, 2017: pp. 241–245. https://doi.org/10.1109/ICASSP.2017.7952154.
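Conceptually, PIT evaluates the metric under every possible speaker permutation and keeps the best-scoring one. The following pure-Python sketch illustrates that search; the `pit_brute_force` and `mse` helpers here are hypothetical, didactic stand-ins, not part of torchmetrics (which performs this on batched tensors):

```python
from itertools import permutations

def mse(a, b):
    """Mean squared error between two equal-length sequences."""
    return sum((x - y) ** 2 for x, y in zip(a, b)) / len(a)

def pit_brute_force(preds, target, metric_func, eval_func="max"):
    """Score metric_func under every speaker permutation and keep the best.

    preds/target: nested lists of shape [spk][time] (single example, no batch).
    Returns (best_metric, best_perm), where best_perm[i] is the target
    speaker matched to predicted speaker i.
    """
    n_spk = len(preds)
    choose = max if eval_func == "max" else min
    best = None
    for perm in permutations(range(n_spk)):
        # Average the pairwise metric over the matched speaker pairs.
        score = sum(metric_func(preds[i], target[p])
                    for i, p in enumerate(perm)) / n_spk
        if best is None or choose(score, best[0]) == score:
            best = (score, perm)
    return best

# Two "speakers" whose order in preds is swapped relative to target.
target = [[1.0, 0.0, 1.0], [0.0, 1.0, 0.0]]
preds = [target[1], target[0]]
best_metric, best_perm = pit_brute_force(preds, target, mse, eval_func="min")
print(best_metric, best_perm)  # 0.0 (1, 0): swapping speakers gives a perfect match
```

Note that the brute-force search visits spk! permutations, which is why practical implementations matter for larger speaker counts.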
Functional Interface¶
- torchmetrics.functional.permutation_invariant_training(preds, target, metric_func, eval_func='max', **kwargs)[source]
Permutation invariant training (PIT). The permutation_invariant_training function implements the Permutation Invariant Training method [1] from the speech separation field, which calculates audio metrics in a permutation-invariant way.
- Parameters
  - preds (Tensor) – shape [batch, spk, ...]
  - target (Tensor) – shape [batch, spk, ...]
  - metric_func (Callable) – a metric function that accepts a batch of estimates and targets, i.e. metric_func(preds[:, i, ...], target[:, j, ...]), and returns a batch of metric tensors of shape [batch]
  - eval_func (str) – the function used to find the best permutation, either 'min' or 'max', i.e. the smaller the better or the larger the better
  - kwargs (Any) – additional keyword arguments for the metric_func
- Return type
Tuple[Tensor, Tensor]
- Returns
  - best_metric of shape [batch]
  - best_perm of shape [batch, spk]
Example
>>> import torch
>>> from torchmetrics.functional.audio import (permutation_invariant_training,
...     pit_permutate, scale_invariant_signal_distortion_ratio)
>>> # [batch, spk, time]
>>> preds = torch.tensor([[[-0.0579, 0.3560, -0.9604], [-0.1719, 0.3205, 0.2951]]])
>>> target = torch.tensor([[[ 1.0958, -0.1648, 0.5228], [-0.4100, 1.1942, -0.5103]]])
>>> best_metric, best_perm = permutation_invariant_training(
...     preds, target, scale_invariant_signal_distortion_ratio, 'max')
>>> best_metric
tensor([-5.1091])
>>> best_perm
tensor([[0, 1]])
>>> pit_permutate(preds, best_perm)
tensor([[[-0.0579, 0.3560, -0.9604], [-0.1719, 0.3205, 0.2951]]])
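The permutation step can be pictured with plain nested lists: for each batch element, the speaker axis of preds is gathered according to that element's best permutation. The `permutate` helper below is a hypothetical, list-based sketch of this assumed behavior, not the actual tensor implementation in torchmetrics:

```python
def permutate(preds, perm):
    """Reorder the speaker axis of each batch element by its permutation.

    preds: nested lists of shape [batch][spk][time];
    perm:  nested lists of shape [batch][spk], one index list per element.
    Sketch only: gathers preds[b][p] for each p in perm[b].
    """
    return [[batch_preds[p] for p in batch_perm]
            for batch_preds, batch_perm in zip(preds, perm)]

preds = [[[1, 1], [2, 2]]]          # one batch element, two speakers
print(permutate(preds, [[1, 0]]))   # [[[2, 2], [1, 1]]] -- speakers swapped
print(permutate(preds, [[0, 1]]))   # identity permutation leaves preds unchanged
```

This matches the doctest above, where the identity permutation [[0, 1]] returns preds unchanged.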
- Reference:
[1] D. Yu, M. Kolbaek, Z.-H. Tan, J. Jensen, Permutation invariant training of deep models for speaker-independent multi-talker speech separation, in: 2017 IEEE Int. Conf. Acoust. Speech Signal Process. ICASSP, IEEE, New Orleans, LA, 2017: pp. 241–245. https://doi.org/10.1109/ICASSP.2017.7952154.