
Permutation Invariant Training (PIT)

Module Interface

class torchmetrics.PermutationInvariantTraining(metric_func, eval_func='max', compute_on_step=None, **kwargs)

Permutation invariant training (PIT). PermutationInvariantTraining implements the Permutation Invariant Training method [1] from the speech separation field, so that audio metrics can be calculated in a permutation-invariant way: the predicted speakers are matched to the target speakers using the assignment that gives the best average metric.
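The idea can be sketched for a single sample without any framework. The sketch first scores every (prediction, target) pairing, then keeps the speaker assignment whose mean score is best. This is a plain-Python illustration, not the library's implementation. The names pit_single and neg_abs_error and the toy signals are hypothetical, and the exhaustive loop over all S! permutations is just the simplest possible search. It is not necessarily the search strategy torchmetrics uses.

```python
from itertools import permutations


def neg_abs_error(est, ref):
    # Hypothetical toy metric: negative mean absolute error (larger is better).
    return -sum(abs(e - r) for e, r in zip(est, ref)) / len(ref)


def pit_single(preds, target, metric, maximize=True):
    """Return (best_score, best_perm) for one sample.

    preds, target: lists of S signals, one per speaker.
    best_perm[t] is the index of the prediction assigned to target speaker t.
    """
    spk = len(target)
    # Metric matrix: mtx[t][e] = metric(preds[e], target[t]).
    mtx = [[metric(preds[e], target[t]) for e in range(spk)] for t in range(spk)]
    best_score, best_perm = None, None
    for perm in permutations(range(spk)):
        # Average the metric over speakers for this assignment.
        score = sum(mtx[t][perm[t]] for t in range(spk)) / spk
        if best_score is None or (score > best_score if maximize else score < best_score):
            best_score, best_perm = score, perm
    return best_score, best_perm


# The estimates come out in the opposite speaker order to the references.
target = [[1.0, 2.0, 3.0], [-1.0, 0.0, 1.0]]
preds = [[-1.0, 0.0, 1.0], [1.0, 2.0, 3.0]]
score, perm = pit_single(preds, target, neg_abs_error)
print(score, perm)  # 0.0 (1, 0)
```

Scoring the speakers in their fixed order would give -2.0 here. The permutation-invariant score of 0.0 reflects that both speakers were in fact estimated perfectly, only swapped.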

Forward accepts

  • preds: shape [batch, spk, ...]

  • target: shape [batch, spk, ...]

Parameters
  • metric_func (Callable) – a metric function that accepts a batch of estimates and a batch of targets, i.e. metric_func(preds[:, i, ...], target[:, j, ...]), and returns a batch of metric values of shape [batch]

  • eval_func (str) – how to pick the best permutation: 'min' if smaller metric values are better, 'max' if larger metric values are better.

  • compute_on_step (Optional[bool]) –

    Forward only calls update() and returns None if this is set to False.

    Deprecated since version v0.8: Argument has no use anymore and will be removed in v0.9.

  • kwargs (Dict[str, Any]) – Additional keyword arguments for either the metric_func or distributed communication; see Advanced metric settings for more info.
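metric_func is called on whole batches, one speaker pair at a time. It therefore has to map two batches of single-speaker signals to one value per sample. Below is a plain-Python sketch of a function that meets this contract. batched_mse is a hypothetical name, and nested lists stand in for tensors. Because it measures an error (smaller is better), it would be paired with eval_func='min' rather than the default 'max'.

```python
def batched_mse(preds_b, target_b):
    """Hypothetical metric_func: mean squared error for each sample.

    preds_b, target_b: lists of length `batch`, each entry one signal,
    mirroring preds[:, i, ...] and target[:, j, ...].
    Returns a list of length `batch`, i.e. one metric value per sample.
    """
    return [
        sum((p - t) ** 2 for p, t in zip(pred, ref)) / len(ref)
        for pred, ref in zip(preds_b, target_b)
    ]


# A batch of two samples; lower values mean a better match, hence eval_func='min'.
batch_preds = [[0.0, 1.0], [2.0, 2.0]]
batch_target = [[0.0, 1.0], [0.0, 0.0]]
print(batched_mse(batch_preds, batch_target))  # [0.0, 4.0]
```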

Returns

average PermutationInvariantTraining metric

Example

>>> import torch
>>> from torchmetrics import PermutationInvariantTraining
>>> from torchmetrics.functional import scale_invariant_signal_noise_ratio
>>> _ = torch.manual_seed(42)
>>> preds = torch.randn(3, 2, 5) # [batch, spk, time]
>>> target = torch.randn(3, 2, 5) # [batch, spk, time]
>>> pit = PermutationInvariantTraining(scale_invariant_signal_noise_ratio, 'max')
>>> pit(preds, target)
tensor(-2.1065)

Reference:

[1] D. Yu, M. Kolbaek, Z.-H. Tan, J. Jensen, Permutation invariant training of deep models for speaker-independent multi-talker speech separation, in: 2017 IEEE Int. Conf. Acoust. Speech Signal Process. ICASSP, IEEE, New Orleans, LA, 2017: pp. 241–245. https://doi.org/10.1109/ICASSP.2017.7952154.


compute()

Computes the average PermutationInvariantTraining metric.

Return type

Tensor

update(preds, target)

Update state with predictions and targets.

Parameters
  • preds (Tensor) – Predictions from model

  • target (Tensor) – Ground truth values

Return type

None
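The module reports a single average over everything passed to update(). A plain-Python sketch of one way to keep that average follows. It assumes the state is a running sum of per-sample best metrics plus a sample count. RunningPITMean is a hypothetical name, and this is an illustration of the bookkeeping, not the module's actual code.

```python
class RunningPITMean:
    """Hypothetical stand-in for accumulated metric state.

    update() receives the per-sample best metrics of one batch (what the
    functional interface returns as best_metric); compute() returns their
    average over every sample seen so far.
    """

    def __init__(self):
        self.sum_metric = 0.0
        self.total = 0

    def update(self, best_metric_batch):
        # Accumulate a sum and a count rather than per-batch means, so each
        # batch is weighted by how many samples it contains.
        self.sum_metric += sum(best_metric_batch)
        self.total += len(best_metric_batch)

    def compute(self):
        # Assumes update() has been called at least once (total > 0).
        return self.sum_metric / self.total


acc = RunningPITMean()
acc.update([1.0, 3.0])  # batch of 2 samples
acc.update([5.0])       # batch of 1 sample
print(acc.compute())    # 3.0, not the mean of the batch means (3.5)
```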

Functional Interface

torchmetrics.functional.permutation_invariant_training(preds, target, metric_func, eval_func='max', **kwargs)

Permutation invariant training (PIT). permutation_invariant_training implements the Permutation Invariant Training method [1] from the speech separation field, so that audio metrics can be calculated in a permutation-invariant way.

Parameters
  • preds (Tensor) – shape [batch, spk, ...]

  • target (Tensor) – shape [batch, spk, ...]

  • metric_func (Callable) – a metric function that accepts a batch of estimates and a batch of targets, i.e. metric_func(preds[:, i, ...], target[:, j, ...]), and returns a batch of metric values of shape [batch]

  • eval_func (str) – how to pick the best permutation: 'min' if smaller metric values are better, 'max' if larger metric values are better.

  • kwargs (Dict[str, Any]) – Additional keyword arguments passed through to metric_func
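How these parameters interact can be sketched in plain Python. Here metric_func is called once per (target, prediction) speaker pair on the whole batch, with **kwargs forwarded to it. eval_func then selects the best speaker-averaged score per sample. pit_batched and batched_sq_error (with its scale keyword) are hypothetical names, and the exhaustive permutation search is an assumption chosen for clarity. This is not the torchmetrics implementation.

```python
from itertools import permutations


def pit_batched(preds, target, metric_func, eval_func="max", **kwargs):
    """Hypothetical plain-Python analogue of the functional interface.

    preds, target: nested lists shaped [batch][spk][time].
    Returns (best_metric, best_perm), both of length batch; best_perm[b][t]
    is the index of the prediction matched to target speaker t.
    """
    if eval_func not in ("max", "min"):
        raise ValueError("eval_func must be 'max' or 'min'")
    batch, spk = len(target), len(target[0])
    # mtx[b][t][e] = metric of prediction e against target t for sample b.
    mtx = [[[None] * spk for _ in range(spk)] for _ in range(batch)]
    for t in range(spk):
        for e in range(spk):
            # One call per speaker pair on the whole batch; kwargs are forwarded.
            values = metric_func([s[e] for s in preds], [s[t] for s in target], **kwargs)
            for b in range(batch):
                mtx[b][t][e] = values[b]
    pick = max if eval_func == "max" else min
    best_metric, best_perm = [], []
    for b in range(batch):
        # Mean metric over speakers for every candidate assignment.
        scored = [(sum(mtx[b][t][p[t]] for t in range(spk)) / spk, p)
                  for p in permutations(range(spk))]
        score, perm = pick(scored, key=lambda item: item[0])
        best_metric.append(score)
        best_perm.append(list(perm))
    return best_metric, best_perm


def batched_sq_error(preds_b, target_b, scale=1.0):
    # Hypothetical metric_func with an extra keyword argument (smaller is better).
    return [scale * sum((p - t) ** 2 for p, t in zip(pr, tg)) / len(tg)
            for pr, tg in zip(preds_b, target_b)]


preds = [[[0.0, 0.0], [1.0, 1.0]]]   # [batch=1][spk=2][time=2]
target = [[[1.0, 1.0], [0.0, 0.0]]]  # same speakers, swapped order
print(pit_batched(preds, target, batched_sq_error, "min", scale=2.0))
# ([0.0], [[1, 0]])
```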

Return type

Tuple[Tensor, Tensor]

Returns

best_metric of shape [batch] and best_perm of shape [batch, spk]
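best_perm[b][t] names the prediction matched to target speaker t. Reordering the predictions to line up with the targets therefore means indexing each sample's speaker axis with its permutation, so that output[b][t] = preds[b][best_perm[b][t]]. The Example section calls pit_permutate for this step, and the plain-Python sketch below mirrors that indexing. permute_preds is a hypothetical name, and nested lists stand in for tensors.

```python
def permute_preds(preds, best_perm):
    """Hypothetical plain-Python analogue of reordering predictions.

    preds: nested lists [batch][spk][...]; best_perm: [batch][spk].
    Output sample b, speaker t is preds[b][best_perm[b][t]], so after the
    reordering, prediction t lines up with target speaker t.
    """
    return [[sample[idx] for idx in perm] for sample, perm in zip(preds, best_perm)]


preds = [[["a0"], ["a1"]], [["b0"], ["b1"]]]  # [batch=2][spk=2][time=1]
best_perm = [[0, 1], [1, 0]]                  # second sample's speakers swapped
print(permute_preds(preds, best_perm))
# [[['a0'], ['a1']], [['b1'], ['b0']]]
```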

Example

>>> import torch
>>> from torchmetrics.functional.audio import scale_invariant_signal_distortion_ratio
>>> from torchmetrics.functional.audio.pit import permutation_invariant_training, pit_permutate
>>> # [batch, spk, time]
>>> preds = torch.tensor([[[-0.0579,  0.3560, -0.9604], [-0.1719,  0.3205,  0.2951]]])
>>> target = torch.tensor([[[ 1.0958, -0.1648,  0.5228], [-0.4100,  1.1942, -0.5103]]])
>>> best_metric, best_perm = permutation_invariant_training(
...     preds, target, scale_invariant_signal_distortion_ratio, 'max')
>>> best_metric
tensor([-5.1091])
>>> best_perm
tensor([[0, 1]])
>>> # reorder preds so that prediction t lines up with target speaker t
>>> pit_permutate(preds, best_perm)
tensor([[[-0.0579,  0.3560, -0.9604],
         [-0.1719,  0.3205,  0.2951]]])

Reference:

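[1] D. Yu, M. Kolbaek, Z.-H. Tan, J. Jensen, Permutation invariant training of deep models for speaker-independent multi-talker speech separation, in: 2017 IEEE Int. Conf. Acoust. Speech Signal Process. ICASSP, IEEE, New Orleans, LA, 2017: pp. 241–245. https://doi.org/10.1109/ICASSP.2017.7952154.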
