Perplexity

Module Interface

class torchmetrics.text.perplexity.Perplexity(ignore_index=None, **kwargs)

Perplexity measures how well a language model predicts a text sample. It is the exponential of the average negative log-likelihood the model assigns to each target token, so lower values indicate better predictions.
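For reference, the standard textbook definition can be written out in symbols. This is a sketch rather than a statement about the library's internals: N is assumed to be the number of target tokens that are not ignored, and p(y_i) the probability the model assigns to the i-th target token.

```latex
\mathrm{PPL} = \exp\!\left( -\frac{1}{N} \sum_{i=1}^{N} \log p(y_i) \right)
```

With the natural logarithm this equals 2 raised to the average number of bits per token, so a model that is uniformly uncertain over k tokens has a perplexity of k.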

Parameters
  • ignore_index (Optional[int]) – Integer specifying a target class to ignore. If given, this class index does not contribute to the returned score.

  • kwargs (Dict[str, Any]) – Additional keyword arguments; see Advanced metric settings for more information.

Examples

>>> import torch
>>> from torchmetrics.text.perplexity import Perplexity
>>> preds = torch.rand(2, 8, 5, generator=torch.manual_seed(22))
>>> target = torch.randint(5, (2, 8), generator=torch.manual_seed(22))
>>> target[0, 6:] = -100
>>> metric = Perplexity(ignore_index=-100)
>>> metric(preds, target)
tensor(5.2545)

Initializes internal Module state, shared by both nn.Module and ScriptModule.

compute()

Compute the Perplexity.

Return type

Tensor

Returns

Perplexity

update(preds, target)

Compute and store intermediate statistics for Perplexity.

Parameters
  • preds (Tensor) – Probabilities assigned to each token in a sequence with shape [batch_size, seq_len, vocab_size].

  • target (Tensor) – Ground-truth token indices with shape [batch_size, seq_len].

Return type

None
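The update/compute split can be illustrated with a small dependency-free sketch. It assumes the intermediate statistics are a running sum of negative log-likelihoods and a token count; RunningPerplexity is a hypothetical stand-in written for illustration, not the library class, and it takes plain Python lists of already-normalized probabilities instead of tensors.

```python
import math


class RunningPerplexity:
    """Hypothetical stand-in for the update/compute pattern (not the library class)."""

    def __init__(self, ignore_index=None):
        self.ignore_index = ignore_index
        self.total_nll = 0.0  # running sum of negative log-likelihoods
        self.count = 0        # number of tokens that contributed

    def update(self, probs, target):
        # probs: one probability distribution (list of floats) per token
        # target: one class index (int) per token
        for dist, tok in zip(probs, target):
            if tok == self.ignore_index:
                continue  # ignored tokens contribute nothing
            self.total_nll -= math.log(dist[tok])
            self.count += 1

    def compute(self):
        # exponential of the average negative log-likelihood over all batches
        return math.exp(self.total_nll / self.count)


metric = RunningPerplexity()
metric.update([[0.5, 0.5], [0.5, 0.5]], [0, 1])  # first batch, two tokens
metric.update([[0.125, 0.875]], [0])             # second batch, one token
corpus_ppl = metric.compute()                    # 2 ** (5 / 3), roughly 3.17
```

Accumulating sums and exponentiating once yields the perplexity over all tokens seen so far. Averaging the two per-batch perplexities (2 and 8) would give 5 instead, which is why the statistics are stored rather than the per-batch scores.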

Functional Interface

torchmetrics.functional.text.perplexity.perplexity(preds, target, ignore_index=None)

Perplexity measures how well a language model predicts a text sample. It is the exponential of the average negative log-likelihood the model assigns to each target token, so lower values indicate better predictions.

Parameters
  • preds (Tensor) – Probabilities assigned to each token in a sequence with shape [batch_size, seq_len, vocab_size].

  • target (Tensor) – Ground-truth token indices with shape [batch_size, seq_len].

  • ignore_index (Optional[int]) – Integer specifying a target class to ignore. If given, this class index does not contribute to the returned score.

Return type

Tensor

Returns

Perplexity value

Examples

>>> import torch
>>> from torchmetrics.functional.text.perplexity import perplexity
>>> preds = torch.rand(2, 8, 5, generator=torch.manual_seed(22))
>>> target = torch.randint(5, (2, 8), generator=torch.manual_seed(22))
>>> target[0, 6:] = -100
>>> perplexity(preds, target, ignore_index=-100)
tensor(5.2545)
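The effect of ignore_index can be checked by hand on a tiny input. The following is a minimal sketch under stated assumptions: perplexity_sketch is a hypothetical helper, not the library function, and it operates on flat Python lists of already-normalized probabilities instead of [batch_size, seq_len, vocab_size] tensors.

```python
import math


def perplexity_sketch(probs, target, ignore_index=None):
    """Illustrative only: exp of the mean negative log-likelihood of the targets.

    probs:  list of probability distributions, one per token
    target: list of class indices, one per token
    """
    total_nll = 0.0
    count = 0
    for dist, tok in zip(probs, target):
        if ignore_index is not None and tok == ignore_index:
            continue  # skipped tokens affect neither the sum nor the count
        total_nll -= math.log(dist[tok])
        count += 1
    return math.exp(total_nll / count)


# A model that is uniformly uncertain over 4 classes has perplexity 4,
# regardless of how many positions are masked out with ignore_index.
uniform = [[0.25, 0.25, 0.25, 0.25]] * 3
ppl = perplexity_sketch(uniform, [0, 1, -100], ignore_index=-100)  # approximately 4.0
```

Because ignored positions are excluded from both the sum and the count, padding a sequence with the ignored index leaves the score unchanged.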