Permutation Invariant Training (PIT)
Module Interface
- class torchmetrics.PermutationInvariantTraining(metric_func, eval_func='max', **kwargs)
Calculates permutation invariant training (PIT), which evaluates models for speaker-independent multi-talker speech separation in a permutation-invariant way.
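To make the idea concrete, here is a framework-free sketch of PIT for a single sample. It is written in plain Python rather than with tensors so that it runs without dependencies, and it assumes exhaustive search: every assignment of estimates to target speakers is scored by averaging a pairwise metric over speakers, and the best score (by max or min) wins. The names pit_single_sample and neg_mse are hypothetical and are not part of torchmetrics.

```python
from itertools import permutations
from typing import Callable, Sequence, Tuple

Signal = Sequence[float]


def pit_single_sample(
    preds: Sequence[Signal],
    target: Sequence[Signal],
    metric: Callable[[Signal, Signal], float],
    eval_func: str = "max",
) -> Tuple[float, Tuple[int, ...]]:
    """Return (best score, best permutation) for one sample (hypothetical helper).

    perm[t] is the index of the estimate paired with target speaker t, and a
    permutation's score is the pairwise metric averaged over speakers.
    """
    n = len(target)
    pick = max if eval_func == "max" else min
    scored = [
        (sum(metric(preds[perm[t]], target[t]) for t in range(n)) / n, perm)
        for perm in permutations(range(n))
    ]
    return pick(scored, key=lambda item: item[0])


def neg_mse(est: Signal, ref: Signal) -> float:
    # Toy "larger is better" metric: negative mean squared error.
    return -sum((e - r) ** 2 for e, r in zip(est, ref)) / len(ref)


target = [[1.0, 2.0, 3.0], [0.0, -1.0, 0.0]]
# The estimates come out in swapped order relative to the targets.
preds = [[0.0, -1.0, 0.5], [1.0, 2.0, 3.0]]
score, perm = pit_single_sample(preds, target, neg_mse, "max")
print(score, perm)
```

Here the best permutation is (1, 0): estimate 1 is paired with target speaker 0 and estimate 0 with target speaker 1, which undoes the swap.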
As input to forward and update the metric accepts the following input:
- preds (Tensor): float tensor with shape (batch_size, num_speakers, ...)
- target (Tensor): float tensor with shape (batch_size, num_speakers, ...)
As output of forward and compute the metric returns the following output:
- pit (Tensor): float scalar tensor with the best-permutation metric value averaged over samples
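The stateful metric reports a single scalar: the best-permutation value averaged over every sample passed to update. The plain-Python sketch below uses a hypothetical RunningPIT class and a toy metric neg_abs_err, and it assumes the module keeps a running sum of per-sample best scores plus a sample count. It illustrates the update/compute life cycle only and is not the torchmetrics implementation.

```python
from itertools import permutations
from typing import Callable, Sequence

Signal = Sequence[float]


class RunningPIT:
    """Hypothetical stand-in that mimics an update/compute metric life cycle."""

    def __init__(self, metric: Callable[[Signal, Signal], float], eval_func: str = "max") -> None:
        self.metric = metric
        self.better = max if eval_func == "max" else min
        self.sum_best = 0.0  # running sum of per-sample best scores
        self.total = 0  # number of samples seen so far

    def update(self, preds: Sequence[Sequence[Signal]], target: Sequence[Sequence[Signal]]) -> None:
        # preds and target are nested as [batch][speaker][time].
        for p, t in zip(preds, target):
            n = len(t)
            scores = [
                sum(self.metric(p[perm[j]], t[j]) for j in range(n)) / n
                for perm in permutations(range(n))
            ]
            self.sum_best += self.better(scores)
            self.total += 1

    def compute(self) -> float:
        # Average of the best permutation score over all samples seen.
        return self.sum_best / self.total


def neg_abs_err(est: Signal, ref: Signal) -> float:
    # Toy "larger is better" metric: negative mean absolute error.
    return -sum(abs(e - r) for e, r in zip(est, ref)) / len(ref)


pit = RunningPIT(neg_abs_err, "max")
pit.update([[[1.0, 1.0], [5.0, 5.0]]], [[[5.0, 5.0], [1.0, 1.0]]])  # swapped but exact: best 0.0
pit.update([[[0.0, 0.0], [2.0, 2.0]]], [[[1.0, 1.0], [2.0, 2.0]]])  # one speaker off by 1: best -0.5
print(pit.compute())
```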
- Parameters
metric_func (Callable) – a metric function that accepts a batch of estimates and targets, i.e. metric_func(preds[:, i, ...], target[:, j, ...]), and returns a batch of metric values with shape (batch,)
eval_func (Literal['max', 'min']) – the function used to find the best permutation: 'max' if larger metric values are better, 'min' if smaller values are better.
kwargs (Any) – Additional keyword arguments for either the metric_func or distributed communication; see Advanced metric settings for more info.
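The choice of eval_func has to match the direction of metric_func. The sketch below uses a hypothetical choose_perm helper and a plain-Python mse, with lists standing in for tensors, to show that an error metric (smaller is better) needs 'min'; passing 'max' selects the worst speaker assignment instead.

```python
from itertools import permutations
from typing import Sequence, Tuple

Signal = Sequence[float]


def mse(est: Signal, ref: Signal) -> float:
    # Error metric: smaller is better, so it pairs with eval_func="min".
    return sum((e - r) ** 2 for e, r in zip(est, ref)) / len(ref)


def choose_perm(preds: Sequence[Signal], target: Sequence[Signal], eval_func: str) -> Tuple[int, ...]:
    # Hypothetical helper: score each permutation by the speaker-averaged MSE.
    n = len(target)
    pick = max if eval_func == "max" else min
    return pick(
        permutations(range(n)),
        key=lambda perm: sum(mse(preds[perm[t]], target[t]) for t in range(n)) / n,
    )


target = [[1.0, 0.0], [0.0, 1.0]]
preds = [[0.1, 0.9], [0.9, 0.1]]  # estimates come out in swapped order
print(choose_perm(preds, target, "min"))  # correct choice for an error metric
print(choose_perm(preds, target, "max"))  # picks the worst pairing instead
```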
Example
>>> import torch
>>> from torchmetrics import PermutationInvariantTraining
>>> from torchmetrics.functional import scale_invariant_signal_noise_ratio
>>> _ = torch.manual_seed(42)
>>> preds = torch.randn(3, 2, 5)  # [batch, spk, time]
>>> target = torch.randn(3, 2, 5)  # [batch, spk, time]
>>> pit = PermutationInvariantTraining(scale_invariant_signal_noise_ratio, 'max')
>>> pit(preds, target)
tensor(-2.1065)
Functional Interface
- torchmetrics.functional.permutation_invariant_training(preds, target, metric_func, eval_func='max', **kwargs)
Calculates permutation invariant training (PIT), which evaluates models for speaker-independent multi-talker speech separation in a permutation-invariant way.
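A brute-force evaluation scores every speaker-to-estimate assignment, and there are num_speakers! of them. The short count below (standard library only) shows how quickly that grows. It assumes exhaustive search and says nothing about which strategy a particular torchmetrics version uses internally.

```python
import math
from itertools import permutations

counts = {}
for num_speakers in range(1, 7):
    # Count every way of assigning estimates to target speakers.
    counts[num_speakers] = sum(1 for _ in permutations(range(num_speakers)))
    print(num_speakers, counts[num_speakers])

# The count equals num_speakers! in every case.
assert all(counts[s] == math.factorial(s) for s in counts)
```

Exhaustive search is cheap for two or three speakers but becomes expensive well before ten.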
- Parameters
preds (Tensor) – float tensor with shape (batch_size, num_speakers, ...)
target (Tensor) – float tensor with shape (batch_size, num_speakers, ...)
metric_func (Callable) – a metric function that accepts a batch of estimates and targets, i.e. metric_func(preds[:, i, ...], target[:, j, ...]), and returns a batch of metric values with shape (batch,)
eval_func (Literal['max', 'min']) – the function used to find the best permutation: 'max' if larger metric values are better, 'min' if smaller values are better.
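Because metric_func is defined on pairs (preds[:, i, ...], target[:, j, ...]), every permutation can be scored from one num_speakers by num_speakers table of pairwise values, so the metric is evaluated num_speakers squared times rather than once per speaker per permutation. The following plain-Python sketch shows that organization; pit_via_matrix and counting_neg_mse are hypothetical names, and the global call counter exists only to make the number of metric evaluations visible.

```python
from itertools import permutations
from typing import Callable, List, Sequence, Tuple

Signal = Sequence[float]


def pit_via_matrix(
    preds: Sequence[Signal],
    target: Sequence[Signal],
    metric_func: Callable[[Signal, Signal], float],
    eval_func: str = "max",
) -> Tuple[float, Tuple[int, ...]]:
    n = len(target)
    # mtx[t][p] = metric_func(preds[p], target[t]): n * n metric calls in total.
    mtx: List[List[float]] = [[metric_func(preds[p], target[t]) for p in range(n)] for t in range(n)]
    pick = max if eval_func == "max" else min
    # Each permutation is scored with table lookups only, no extra metric calls.
    scored = [(sum(mtx[t][perm[t]] for t in range(n)) / n, perm) for perm in permutations(range(n))]
    return pick(scored, key=lambda item: item[0])


calls = 0


def counting_neg_mse(est: Signal, ref: Signal) -> float:
    # Negative MSE ("larger is better") that also counts how often it runs.
    global calls
    calls += 1
    return -sum((e - r) ** 2 for e, r in zip(est, ref)) / len(ref)


target = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
preds = [[1.0, 1.0], [1.0, 0.0], [0.0, 1.0]]
best, perm = pit_via_matrix(preds, target, counting_neg_mse, "max")
print(best, perm, calls)  # 3 speakers: 6 permutations, but only 9 metric calls
```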
- Return type
Tuple[Tensor, Tensor]
- Returns
Tuple of two tensors. The first, a float tensor with shape (batch,), contains the best metric value for each sample; the second, an integer tensor with shape (batch, num_speakers), contains the best permutation for each sample.
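The permutation tensor holds one row of speaker indices per sample, and row b reorders the estimates of sample b so that position t lines up with target speaker t. The plain-Python sketch below uses a hypothetical apply_perm function, with string placeholders instead of waveforms, to illustrate that reordering. It assumes the same indexing convention as pit_permutate in the example that follows, which returns preds unchanged when the best permutation is the identity [0, 1].

```python
from typing import List, Sequence, TypeVar

T = TypeVar("T")


def apply_perm(preds: Sequence[Sequence[T]], best_perm: Sequence[Sequence[int]]) -> List[List[T]]:
    """Reorder each sample's estimates so that output[b][t] pairs with target[b][t].

    preds is nested as [batch][speaker][...]; best_perm is [batch][speaker].
    """
    return [[sample[idx] for idx in perm] for sample, perm in zip(preds, best_perm)]


preds = [
    [["a0"], ["a1"]],  # sample 0: estimates already in target order
    [["b0"], ["b1"]],  # sample 1: estimates swapped relative to the targets
]
best_perm = [[0, 1], [1, 0]]  # one row of speaker indices per sample
print(apply_perm(preds, best_perm))
```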
Example
>>> import torch
>>> from torchmetrics.functional import permutation_invariant_training
>>> from torchmetrics.functional.audio import pit_permutate, scale_invariant_signal_distortion_ratio
>>> # [batch, spk, time]
>>> preds = torch.tensor([[[-0.0579, 0.3560, -0.9604], [-0.1719, 0.3205, 0.2951]]])
>>> target = torch.tensor([[[ 1.0958, -0.1648, 0.5228], [-0.4100, 1.1942, -0.5103]]])
>>> best_metric, best_perm = permutation_invariant_training(
...     preds, target, scale_invariant_signal_distortion_ratio, 'max')
>>> best_metric
tensor([-5.1091])
>>> best_perm
tensor([[0, 1]])
>>> pit_permutate(preds, best_perm)
tensor([[[-0.0579,  0.3560, -0.9604],
         [-0.1719,  0.3205,  0.2951]]])