Permutation Invariant Training (PIT)

Module Interface

class torchmetrics.PermutationInvariantTraining(metric_func, eval_func='max', **kwargs)[source]

Permutation invariant training (PIT). PermutationInvariantTraining implements the Permutation Invariant Training method [1] from the speech separation field. It computes audio metrics in a permutation-invariant way: the metric is evaluated under the best assignment of predicted speakers to target speakers.
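The search that PIT performs can be sketched in plain PyTorch. This is a minimal brute-force illustration, not the library's implementation. The helper name pit_brute_force, the neg_mse metric, the averaging over speakers, and the convention that best_perm[b, j] is the index of the prediction paired with target j are all assumptions made for the example.

```python
from itertools import permutations

import torch


def pit_brute_force(preds, target, metric_func, eval_func="max"):
    """Hypothetical helper: exhaustive search over speaker permutations."""
    spk = preds.shape[1]
    perms = list(permutations(range(spk)))
    # scores[p, b]: metric for permutation p and batch item b, averaged over speakers
    scores = torch.stack([
        torch.stack([metric_func(preds[:, p[j]], target[:, j]) for j in range(spk)]).mean(dim=0)
        for p in perms
    ])
    pick = torch.max if eval_func == "max" else torch.min
    best_metric, best_idx = pick(scores, dim=0)   # both of shape [batch]
    best_perm = torch.tensor(perms)[best_idx]     # shape [batch, spk]
    return best_metric, best_perm


def neg_mse(preds, target):
    # Hypothetical batch metric: higher is better, returns shape [batch]
    return -((preds - target) ** 2).mean(dim=-1)


target = torch.tensor([[[1.0, 0.0, 0.0], [0.0, 1.0, 0.0]]])  # [batch=1, spk=2, time=3]
preds = target[:, [1, 0]]                                     # same signals, speakers swapped
best_metric, best_perm = pit_brute_force(preds, target, neg_mse, "max")
# best_perm is tensor([[1, 0]]): the search recovers the swap
```

The exhaustive loop visits spk! permutations, so it is only practical for a small number of speakers.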

Forward accepts

  • preds: shape [batch, spk, ...]

  • target: shape [batch, spk, ...]

Parameters
  • metric_func (Callable) – a metric function that accepts a batch of predictions and a batch of targets, i.e. metric_func(preds[:, i, ...], target[:, j, ...]), and returns a tensor of metric values of shape [batch]

  • eval_func (str) – how the best permutation is selected: 'min' if a smaller metric value is better, 'max' if a larger one is better.

  • kwargs (Any) – Additional keyword arguments for either the metric_func or distributed communication, see Advanced metric settings for more info.
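The shape contract for metric_func can be illustrated with a small custom function. This is a sketch under assumptions: batch_mse is a hypothetical name, and mean squared error is chosen only because it is easy to check by hand.

```python
import torch


def batch_mse(preds, target):
    """Hypothetical metric_func: mean squared error per batch item."""
    # Reduce over every dimension except the batch dimension -> shape [batch]
    return ((preds - target) ** 2).flatten(start_dim=1).mean(dim=1)


preds = torch.zeros(4, 2, 16)   # [batch, spk, time]
target = torch.ones(4, 2, 16)   # [batch, spk, time]

# One call compares estimate i with reference j across the whole batch
scores = batch_mse(preds[:, 0, ...], target[:, 1, ...])
print(scores.shape)  # torch.Size([4])
```

Because a lower error is better, a function like this would be paired with eval_func='min'.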

Returns

average PermutationInvariantTraining metric

Example

>>> import torch
>>> from torchmetrics import PermutationInvariantTraining
>>> from torchmetrics.functional import scale_invariant_signal_noise_ratio
>>> _ = torch.manual_seed(42)
>>> preds = torch.randn(3, 2, 5) # [batch, spk, time]
>>> target = torch.randn(3, 2, 5) # [batch, spk, time]
>>> pit = PermutationInvariantTraining(scale_invariant_signal_noise_ratio, 'max')
>>> pit(preds, target)
tensor(-2.1065)
Reference:

[1] D. Yu, M. Kolbaek, Z.-H. Tan, J. Jensen, Permutation invariant training of deep models for speaker-independent multi-talker speech separation, in: 2017 IEEE Int. Conf. Acoust. Speech Signal Process. ICASSP, IEEE, New Orleans, LA, 2017: pp. 241–245. https://doi.org/10.1109/ICASSP.2017.7952154.

Initializes internal Module state, shared by both nn.Module and ScriptModule.

compute()[source]

Computes average PermutationInvariantTraining metric.

Return type

Tensor

update(preds, target)[source]

Update state with predictions and targets.

Parameters
  • preds (Tensor) – Predictions from model

  • target (Tensor) – Ground truth values

Return type

None
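The relationship between update() and compute() can be pictured as a running average over all samples seen so far. The sketch below is an assumption about the accumulation pattern, not the torchmetrics implementation; the class name RunningAverage and its attributes are hypothetical.

```python
import torch


class RunningAverage:
    """Hypothetical stand-in for the metric state; not the torchmetrics implementation."""

    def __init__(self):
        self.total = torch.tensor(0.0)  # sum of per-sample best metric values
        self.count = 0                  # number of samples seen

    def update(self, best_metric):
        # best_metric: tensor of shape [batch]
        self.total = self.total + best_metric.sum()
        self.count += best_metric.numel()

    def compute(self):
        # Average over every sample passed to update() so far
        return self.total / self.count


avg = RunningAverage()
avg.update(torch.tensor([1.0, 3.0]))
avg.update(torch.tensor([5.0]))
result = avg.compute()  # mean over all three samples
```

Under this pattern, batches of different sizes contribute in proportion to their number of samples.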

Functional Interface

torchmetrics.functional.permutation_invariant_training(preds, target, metric_func, eval_func='max', **kwargs)[source]

Permutation invariant training (PIT). permutation_invariant_training implements the Permutation Invariant Training method [1] from the speech separation field. It computes audio metrics in a permutation-invariant way: the metric is evaluated under the best assignment of predicted speakers to target speakers.

Parameters
  • preds (Tensor) – shape [batch, spk, ...]

  • target (Tensor) – shape [batch, spk, ...]

  • metric_func (Callable) – a metric function that accepts a batch of predictions and a batch of targets, i.e. metric_func(preds[:, i, ...], target[:, j, ...]), and returns a tensor of metric values of shape [batch]

  • eval_func (str) – how the best permutation is selected: 'min' if a smaller metric value is better, 'max' if a larger one is better.

  • kwargs (Any) – Additional args for metric_func

Return type

Tuple[Tensor, Tensor]

Returns

best_metric of shape [batch], best_perm of shape [batch, spk]

Example

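>>> import torch
>>> from torchmetrics.functional import permutation_invariant_training, pit_permutate
>>> from torchmetrics.functional.audio import scale_invariant_signal_distortion_ratio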
>>> # [batch, spk, time]
>>> preds = torch.tensor([[[-0.0579,  0.3560, -0.9604], [-0.1719,  0.3205,  0.2951]]])
>>> target = torch.tensor([[[ 1.0958, -0.1648,  0.5228], [-0.4100,  1.1942, -0.5103]]])
>>> best_metric, best_perm = permutation_invariant_training(
...     preds, target, scale_invariant_signal_distortion_ratio, 'max')
>>> best_metric
tensor([-5.1091])
>>> best_perm
tensor([[0, 1]])
>>> pit_permutate(preds, best_perm)
tensor([[[-0.0579,  0.3560, -0.9604],
         [-0.1719,  0.3205,  0.2951]]])
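Applying a permutation tensor of shape [batch, spk] to the predictions amounts to a gather along the speaker dimension. The following is a minimal sketch in plain PyTorch. The helper name reorder_speakers is hypothetical, and the convention that position j receives preds[b, perm[b, j]] is an assumption made for the example.

```python
import torch


def reorder_speakers(preds, perm):
    """Hypothetical helper: place preds[b, perm[b, j]] at position j."""
    # Add trailing singleton dims to perm, then broadcast it to the shape of preds
    index = perm.view(*perm.shape, *([1] * (preds.dim() - 2))).expand_as(preds)
    return torch.gather(preds, 1, index)


preds = torch.tensor([[[1.0, 1.0], [2.0, 2.0]],
                      [[3.0, 3.0], [4.0, 4.0]]])  # [batch=2, spk=2, time=2]
perm = torch.tensor([[1, 0], [0, 1]])              # swap speakers in the first item only
reordered = reorder_speakers(preds, perm)
# The first item has its speakers swapped, the second is unchanged
```

With an identity permutation such as tensor([[0, 1]]), the predictions come back unchanged, which matches the example output above.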
Reference:

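[1] D. Yu, M. Kolbaek, Z.-H. Tan, J. Jensen, Permutation invariant training of deep models for speaker-independent multi-talker speech separation, in: 2017 IEEE Int. Conf. Acoust. Speech Signal Process. ICASSP, IEEE, New Orleans, LA, 2017: pp. 241–245. https://doi.org/10.1109/ICASSP.2017.7952154.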