Perplexity
Module Interface
- class torchmetrics.text.perplexity.Perplexity(ignore_index=None, **kwargs)
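Perplexity measures how well a language model predicts a text sample. It is calculated as the exponential of the average negative log-likelihood per token, which equals 2 raised to the average number of bits per token the model needs to represent the sample. Lower values indicate better predictions.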
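As a rough illustration of that definition, the sketch below works in plain Python on the probabilities a model assigned to each correct token; the function names are hypothetical and not part of torchmetrics. It shows that exponentiating the mean negative natural-log likelihood and raising 2 to the mean number of bits give the same perplexity.

```python
import math
from typing import Sequence


def perplexity_nats(target_probs: Sequence[float]) -> float:
    # exp of the mean negative natural-log likelihood of the correct tokens
    nll = [-math.log(p) for p in target_probs]
    return math.exp(sum(nll) / len(nll))


def perplexity_bits(target_probs: Sequence[float]) -> float:
    # 2 raised to the mean number of bits needed per correct token
    bits = [-math.log2(p) for p in target_probs]
    return 2 ** (sum(bits) / len(bits))


# A model that gives every correct token probability 1/4 is as uncertain
# as a uniform choice among 4 tokens, so its perplexity is 4.
uniform = [0.25, 0.25, 0.25, 0.25]
print(round(perplexity_nats(uniform), 6))  # 4.0

# Both formulations agree on arbitrary probabilities.
mixed = [0.5, 0.1, 0.9]
print(math.isclose(perplexity_nats(mixed), perplexity_bits(mixed)))  # True
```

The uniform case is a handy sanity check: a model with no information over a vocabulary of V tokens has perplexity V.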
As input to forward and update the metric accepts the following input:
- preds (Tensor): Unnormalized scores (logits) assigned to each token in a sequence, with shape [batch_size, seq_len, vocab_size]; they are normalized internally with softmax
- target (Tensor): Ground-truth token indices with shape [batch_size, seq_len]
As output of forward and compute the metric returns the following output:
- perp (Tensor): A scalar tensor with the perplexity score
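A minimal plain-Python sketch of what the metric computes for inputs of these shapes, with nested lists standing in for tensors. It assumes each vocabulary row of preds is normalized with softmax before the score of the target token is read off (the example outputs, close to the vocabulary size of 5 for random scores, are consistent with this). The helpers log_softmax and sketch_perplexity are hypothetical, and this is not the torchmetrics source.

```python
import math
from typing import List

Scores = List[List[List[float]]]  # [batch_size][seq_len][vocab_size]
Targets = List[List[int]]         # [batch_size][seq_len]


def log_softmax(row: List[float]) -> List[float]:
    # Numerically stable log-softmax over one vocabulary row.
    m = max(row)
    log_z = m + math.log(sum(math.exp(x - m) for x in row))
    return [x - log_z for x in row]


def sketch_perplexity(preds: Scores, target: Targets) -> float:
    # Shape checks mirroring the documented input contract.
    assert len(preds) == len(target), "batch_size mismatch"
    total_nll, count = 0.0, 0
    for seq_scores, seq_target in zip(preds, target):
        assert len(seq_scores) == len(seq_target), "seq_len mismatch"
        for row, token in zip(seq_scores, seq_target):
            total_nll -= log_softmax(row)[token]  # NLL of the correct token
            count += 1
    return math.exp(total_nll / count)  # a single scalar, like perp


# All-equal scores over a vocabulary of 5 give a uniform distribution,
# so the perplexity equals the vocabulary size.
preds = [[[0.0] * 5 for _ in range(3)] for _ in range(2)]
target = [[0, 1, 2], [3, 4, 0]]
print(round(sketch_perplexity(preds, target), 6))  # 5.0
```

Because softmax is shift-invariant, adding the same constant to every score in a row leaves the result unchanged.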
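- Parameters
  - ignore_index (Optional[int]) – Integer specifying a target class to ignore. If given, this class index does not contribute to the returned score.
  - kwargs (Any) – Additional keyword arguments passed on to the base Metric class.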
Examples
>>> import torch
>>> from torchmetrics.text.perplexity import Perplexity
>>> preds = torch.rand(2, 8, 5, generator=torch.manual_seed(22))
>>> target = torch.randint(5, (2, 8), generator=torch.manual_seed(22))
>>> target[0, 6:] = -100
>>> perp = Perplexity(ignore_index=-100)
>>> perp(preds, target)
tensor(5.2545)
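The module interface is stateful: each update call adds to the metric's state, and compute reports a value over everything seen since the last reset. The sketch below assumes that state is a running sum of per-token negative log-likelihoods plus a token count; RunningPerplexity is a hypothetical class for illustration, not the torchmetrics implementation. Under that assumption, the result after several batches equals the perplexity of all tokens pooled together, which generally differs from the mean of per-batch perplexities.

```python
import math
from typing import List, Optional


def log_softmax(row: List[float]) -> List[float]:
    # Numerically stable log-softmax over one vocabulary row.
    m = max(row)
    log_z = m + math.log(sum(math.exp(x - m) for x in row))
    return [x - log_z for x in row]


class RunningPerplexity:
    """Hypothetical stateful metric: accumulates token NLLs across update() calls."""

    def __init__(self, ignore_index: Optional[int] = None) -> None:
        self.ignore_index = ignore_index
        self.reset()

    def reset(self) -> None:
        self.total_nll = 0.0
        self.count = 0

    def update(self, preds: List[List[List[float]]], target: List[List[int]]) -> None:
        for seq_scores, seq_target in zip(preds, target):
            for row, token in zip(seq_scores, seq_target):
                if token == self.ignore_index:
                    continue  # ignored positions add neither NLL nor count
                self.total_nll -= log_softmax(row)[token]
                self.count += 1

    def compute(self) -> float:
        # Pooled over every token seen since the last reset().
        return math.exp(self.total_nll / self.count)


# Two batches of shape [1, 2, 3] as (preds, target) pairs.
batch_a = ([[[2.0, 0.0, 0.0], [0.0, 1.0, 0.0]]], [[0, 1]])
batch_b = ([[[0.0, 0.0, 3.0], [1.0, 1.0, 1.0]]], [[0, 2]])

metric = RunningPerplexity()
metric.update(*batch_a)
metric.update(*batch_b)
pooled = metric.compute()

# Feeding the same tokens in one update gives the same value.
single = RunningPerplexity()
single.update(batch_a[0] + batch_b[0], batch_a[1] + batch_b[1])
print(math.isclose(pooled, single.compute()))  # True

# Averaging per-batch perplexities does not reproduce the pooled value.
per_batch = []
for p, t in (batch_a, batch_b):
    per = RunningPerplexity()
    per.update(p, t)
    per_batch.append(per.compute())
print(math.isclose(pooled, sum(per_batch) / len(per_batch)))  # False
```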
Functional Interface
- torchmetrics.functional.text.perplexity.perplexity(preds, target, ignore_index=None)
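Perplexity measures how well a language model predicts a text sample. It is calculated as the exponential of the average negative log-likelihood per token, which equals 2 raised to the average number of bits per token the model needs to represent the sample. Lower values indicate better predictions.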
- Parameters
  - preds (Tensor) – Unnormalized scores (logits) assigned to each token in a sequence, with shape [batch_size, seq_len, vocab_size]; they are normalized internally with softmax.
  - target (Tensor) – Ground-truth token indices with shape [batch_size, seq_len].
  - ignore_index (Optional[int]) – Integer specifying a target class to ignore. If given, this class index does not contribute to the returned score.
- Return type
  Tensor
- Returns
  Perplexity value as a scalar tensor
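To illustrate ignore_index, here is a small functional sketch, assuming ignored positions are dropped from both the negative log-likelihood sum and the token count. The function masked_perplexity is hypothetical. It raises a ValueError when every position is ignored; that choice is made here for clarity and is not claimed to match torchmetrics.

```python
import math
from typing import List, Optional


def log_softmax(row: List[float]) -> List[float]:
    # Numerically stable log-softmax over one vocabulary row.
    m = max(row)
    log_z = m + math.log(sum(math.exp(x - m) for x in row))
    return [x - log_z for x in row]


def masked_perplexity(
    preds: List[List[List[float]]],
    target: List[List[int]],
    ignore_index: Optional[int] = None,
) -> float:
    total_nll, count = 0.0, 0
    for seq_scores, seq_target in zip(preds, target):
        for row, token in zip(seq_scores, seq_target):
            # Skip before indexing: a padding value such as -100 is not a
            # valid vocabulary id, so looking it up would fail.
            if ignore_index is not None and token == ignore_index:
                continue
            total_nll -= log_softmax(row)[token]
            count += 1
    if count == 0:
        raise ValueError("every target position is ignored")
    return math.exp(total_nll / count)


# One real token followed by two padding positions marked with -100.
preds = [[[0.0, 0.0, 0.0, 0.0], [9.0, 0.0, 0.0, 0.0], [9.0, 0.0, 0.0, 0.0]]]
padded = [[2, -100, -100]]
unpadded_preds = [[[0.0, 0.0, 0.0, 0.0]]]
print(round(masked_perplexity(preds, padded, ignore_index=-100), 6))  # 4.0
print(round(masked_perplexity(unpadded_preds, [[2]]), 6))  # 4.0
```

Padding with -100 and passing ignore_index=-100 gives the same result as removing the padded positions, which is the pattern the examples use with target[0, 6:] = -100.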
Examples
>>> import torch
>>> from torchmetrics.functional.text.perplexity import perplexity
>>> preds = torch.rand(2, 8, 5, generator=torch.manual_seed(22))
>>> target = torch.randint(5, (2, 8), generator=torch.manual_seed(22))
>>> target[0, 6:] = -100
>>> perplexity(preds, target, ignore_index=-100)
tensor(5.2545)