Permutation Invariant Training (PIT)

Module Interface

class torchmetrics.PermutationInvariantTraining(metric_func, eval_func='max', **kwargs)

Calculates the permutation invariant training (PIT) metric, which evaluates models for speaker-independent multi-talker speech separation in a permutation-invariant way.

As input to forward and update, the metric accepts the following:

  • preds (Tensor): float tensor with shape (batch_size, num_speakers, ...)

  • target (Tensor): float tensor with shape (batch_size, num_speakers, ...)

As output of forward and compute, the metric returns the following:

  • pit_metric (Tensor): float scalar tensor with the best-permutation metric value averaged over samples
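The "average over samples" wording can be illustrated with a small accumulator. This is only a plain-Python sketch, assuming every sample is weighted equally across update calls; the class name RunningMean and its attributes are hypothetical and are not part of torchmetrics.

```python
class RunningMean:
    """Hypothetical accumulator: averages per-sample values across batches."""

    def __init__(self):
        self.total = 0.0  # sum of all per-sample values seen so far
        self.count = 0    # number of samples seen so far

    def update(self, per_sample_values):
        # Accumulate the sum and the sample count, not the batch mean.
        self.total += sum(per_sample_values)
        self.count += len(per_sample_values)

    def compute(self):
        if self.count == 0:
            raise ValueError("compute() called before any update()")
        return self.total / self.count


acc = RunningMean()
acc.update([1.0, 3.0])  # first batch, two samples
acc.update([5.0])       # second batch, one sample
mean_over_samples = acc.compute()
```

Under this assumption the result is the mean of all three values, which differs from the mean of the two batch means when batch sizes are unequal.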

Parameters
  • metric_func (Callable) – a metric function that accepts a batch of estimates and targets, i.e. metric_func(preds[:, i, ...], target[:, j, ...]), and returns a batch of metric values with shape (batch,)

  • eval_func (Literal['max', 'min']) – the function used to select the best permutation, either 'min' or 'max', i.e. whether smaller or larger metric values are better.

  • kwargs (Any) – Additional keyword arguments for either the metric_func or distributed communication; see Advanced metric settings for more info.
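The roles of metric_func and eval_func can be sketched for a single sample in plain Python. This is a brute-force illustration over all speaker permutations and not the library's implementation; the helper names best_permutation and neg_mse are hypothetical, and averaging the pairwise metric over speakers is an assumption of the sketch.

```python
from itertools import permutations


def neg_mse(estimate, reference):
    """Toy metric (hypothetical): negative mean squared error, larger is better."""
    return -sum((e - r) ** 2 for e, r in zip(estimate, reference)) / len(reference)


def best_permutation(preds, target, metric_func, eval_func="max"):
    """Brute-force PIT for one sample.

    preds, target: per-speaker signals as nested lists, shape (num_speakers, time).
    Returns (best_metric, best_perm), where best_perm[j] is the index of the
    estimate matched to target speaker j.
    """
    if eval_func not in ("max", "min"):
        raise ValueError("eval_func must be 'max' or 'min'")
    pick = max if eval_func == "max" else min
    num_speakers = len(target)
    scored = []
    for perm in permutations(range(num_speakers)):
        # Average the pairwise metric over speakers for this assignment.
        pair_values = [metric_func(preds[perm[j]], target[j]) for j in range(num_speakers)]
        scored.append((sum(pair_values) / num_speakers, perm))
    return pick(scored, key=lambda item: item[0])


# The estimates are the targets in swapped speaker order.
target = [[1.0, 2.0, 3.0], [-1.0, 0.0, 1.0]]
preds = [[-1.0, 0.0, 1.0], [1.0, 2.0, 3.0]]
best_metric, best_perm = best_permutation(preds, target, neg_mse, "max")
```

The search visits num_speakers! assignments, so this form is only practical for a small number of speakers.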

Example

>>> import torch
>>> from torchmetrics import PermutationInvariantTraining
>>> from torchmetrics.functional import scale_invariant_signal_noise_ratio
>>> _ = torch.manual_seed(42)
>>> preds = torch.randn(3, 2, 5) # [batch, spk, time]
>>> target = torch.randn(3, 2, 5) # [batch, spk, time]
>>> pit = PermutationInvariantTraining(scale_invariant_signal_noise_ratio, 'max')
>>> pit(preds, target)
tensor(-2.1065)

Functional Interface

torchmetrics.functional.permutation_invariant_training(preds, target, metric_func, eval_func='max', **kwargs)

Calculates the permutation invariant training (PIT) metric, which evaluates models for speaker-independent multi-talker speech separation in a permutation-invariant way.

Parameters
  • preds (Tensor) – float tensor with shape (batch_size, num_speakers, ...)

  • target (Tensor) – float tensor with shape (batch_size, num_speakers, ...)

  • metric_func (Callable) – a metric function that accepts a batch of estimates and targets, i.e. metric_func(preds[:, i, ...], target[:, j, ...]), and returns a batch of metric values with shape (batch,)

  • eval_func (Literal['max', 'min']) – the function used to select the best permutation, either 'min' or 'max', i.e. whether smaller or larger metric values are better.

  • kwargs (Any) – Additional keyword arguments for metric_func

Return type

Tuple[Tensor, Tensor]

Returns

Tuple of two tensors. The first, a float tensor with shape (batch,), contains the best metric value for each sample; the second, with shape (batch, num_speakers), contains the best permutation for each sample.

Example

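>>> import torch
>>> from torchmetrics.functional.audio import permutation_invariant_training, pit_permutate
>>> from torchmetrics.functional.audio import scale_invariant_signal_distortion_ratio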
>>> # [batch, spk, time]
>>> preds = torch.tensor([[[-0.0579,  0.3560, -0.9604], [-0.1719,  0.3205,  0.2951]]])
>>> target = torch.tensor([[[ 1.0958, -0.1648,  0.5228], [-0.4100,  1.1942, -0.5103]]])
>>> best_metric, best_perm = permutation_invariant_training(
...     preds, target, scale_invariant_signal_distortion_ratio, 'max')
>>> best_metric
tensor([-5.1091])
>>> best_perm
tensor([[0, 1]])
>>> pit_permutate(preds, best_perm)
tensor([[[-0.0579,  0.3560, -0.9604],
         [-0.1719,  0.3205,  0.2951]]])
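In the example above the best permutation is the identity, so the reordered preds are unchanged. A plain-Python sketch with a non-trivial permutation shows the reordering more clearly. It assumes the convention that entry i of a sample's permutation names the estimate placed at speaker position i; the helper permute_speakers is hypothetical and is not the library function.

```python
def permute_speakers(preds, perm):
    """Reorder the speaker axis per sample: out[b][i] = preds[b][perm[b][i]]."""
    return [
        [sample[i] for i in sample_perm]
        for sample, sample_perm in zip(preds, perm)
    ]


# Two samples, two speakers, two time steps each: [batch, spk, time]
example_preds = [
    [[0.1, 0.2], [0.3, 0.4]],
    [[0.5, 0.6], [0.7, 0.8]],
]
# Keep the first sample as is, swap the speakers of the second.
example_perm = [[0, 1], [1, 0]]
reordered = permute_speakers(example_preds, example_perm)
```

Each sample carries its own permutation, which is why the permutation has one row per sample.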