
BERT Score

Module Interface

class torchmetrics.text.bert.BERTScore(model_name_or_path=None, num_layers=None, all_layers=False, model=None, user_tokenizer=None, user_forward_fn=None, verbose=False, idf=False, device=None, max_length=512, batch_size=64, num_threads=0, return_hash=False, lang='en', rescale_with_baseline=False, baseline_path=None, baseline_url=None, **kwargs)[source]

BERTScore (BERTScore: Evaluating Text Generation with BERT) for measuring text similarity.

BERTScore leverages the pre-trained contextual embeddings from BERT and matches words in candidate and reference sentences by cosine similarity. It has been shown to correlate with human judgment on sentence-level and system-level evaluation. Moreover, BERTScore computes precision, recall, and F1 measure, which can be useful for evaluating different language generation tasks. This implementation follows the original implementation from BERT_score.
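
For intuition, the following is a minimal sketch of the greedy cosine-similarity matching underlying BERTScore (without the optional IDF weighting or baseline rescaling). The helper greedy_match_score is hypothetical and not part of torchmetrics; it assumes the token embeddings are already L2-normalized, so inner products equal cosine similarities:

import torch

def greedy_match_score(pred_emb: torch.Tensor, ref_emb: torch.Tensor):
    # pred_emb: (num_candidate_tokens, dim), ref_emb: (num_reference_tokens, dim),
    # both L2-normalized token embeddings of a single sentence pair.
    sim = pred_emb @ ref_emb.T  # pairwise cosine similarities
    precision = sim.max(dim=1).values.mean()  # each candidate token greedily matched to a reference token
    recall = sim.max(dim=0).values.mean()     # each reference token greedily matched to a candidate token
    f1 = 2 * precision * recall / (precision + recall)
    return precision, recall, f1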

As input to forward and update the metric accepts the following input:

  • preds (List): An iterable of predicted sentences

  • target (List): An iterable of reference sentences

As output of forward and compute the metric returns the following output:

  • score (Dict): A dictionary containing the keys precision, recall and f1 with corresponding values

Parameters
  • model_name_or_path (Optional[str]) – A name or a model path used to load a transformers pretrained model.

  • num_layers (Optional[int]) – A layer of representation to use.

  • all_layers (bool) – An indication of whether the representations from all of the model’s layers should be used. If all_layers=True, the argument num_layers is ignored.

  • model (Optional[Module]) – A user’s own model. Must be of torch.nn.Module instance.

  • user_tokenizer (Optional[Any]) – A user’s own tokenizer used with the user’s own model. This must be an instance with the __call__ method, which takes an iterable of sentences (List[str]) and returns a python dictionary containing "input_ids" and "attention_mask" represented by Tensor. It is up to the user’s model whether "input_ids" is a Tensor of input ids or embedding vectors. This tokenizer must prepend an equivalent of the [CLS] token and append an equivalent of the [SEP] token, as a transformers tokenizer does. (See the sketch after this parameter list.)

  • user_forward_fn (Optional[Callable[[Module, Dict[str, Tensor]], Tensor]]) – A user’s own forward function used in combination with the user’s own model. This function must take the model and a python dictionary containing "input_ids" and "attention_mask" represented by Tensor as input, and return the model’s output as a single Tensor.

  • verbose (bool) – An indication of whether a progress bar should be displayed during the embeddings’ calculation.

  • idf (bool) – An indication of whether normalization using inverse document frequencies should be used.

  • device (Union[str, device, None]) – A device to be used for calculation.

  • max_length (int) – A maximum length of input sequences. Sequences longer than max_length are to be trimmed.

  • batch_size (int) – A batch size used for model processing.

  • num_threads (int) – A number of threads to use for a dataloader.

  • return_hash (bool) – An indication of whether the corresponding hash_code should be returned.

  • lang (str) – A language of input sentences.

  • rescale_with_baseline (bool) – An indication of whether BERTScore should be rescaled with a pre-computed baseline. When a pretrained model from transformers is used, the corresponding baseline is downloaded from the original bert-score package from BERT_score if available. In other cases, please specify a path to the baseline csv/tsv file, which must follow the formatting of the files from BERT_score.

  • baseline_path (Optional[str]) – A path to the user’s own local csv/tsv file with the baseline scale.

  • baseline_url (Optional[str]) – A url path to the user’s own csv/tsv file with the baseline scale.

  • kwargs (Any) – Additional keyword arguments, see Advanced metric settings for more info.
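
For illustration, below is a minimal sketch of the contract described by the model, user_tokenizer and user_forward_fn arguments above. ToyTokenizer and the toy embedding "encoder" are hypothetical stand-ins invented for this example; a real setup would plug in a trained encoder and a proper tokenizer:

import torch
from typing import Dict, List
from torch import Tensor, nn
from torchmetrics.text.bert import BERTScore

class ToyTokenizer:
    # Hypothetical tokenizer satisfying the user_tokenizer contract.
    cls_id, sep_id, pad_id = 0, 1, 2

    def __init__(self, max_len: int = 8) -> None:
        self.max_len = max_len

    def __call__(self, sentences: List[str]) -> Dict[str, Tensor]:
        batch = []
        for sent in sentences:
            # Toy hashed "vocabulary"; prepend a [CLS] and append a [SEP] equivalent as required.
            ids = [self.cls_id] + [hash(w) % 100 + 3 for w in sent.split()] + [self.sep_id]
            ids = ids[: self.max_len]
            ids += [self.pad_id] * (self.max_len - len(ids))
            batch.append(ids)
        input_ids = torch.tensor(batch)
        return {"input_ids": input_ids, "attention_mask": (input_ids != self.pad_id).long()}

def user_forward_fn(model: nn.Module, batch: Dict[str, Tensor]) -> Tensor:
    # Must return the model's output for the batch as a single Tensor
    # of token embeddings, here of shape (batch, seq_len, dim).
    return model(batch["input_ids"])

toy_model = nn.Embedding(103, 16)  # toy "encoder": one vector per token id
bertscore = BERTScore(model=toy_model, user_tokenizer=ToyTokenizer(), user_forward_fn=user_forward_fn, max_length=8)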

Example

>>> from pprint import pprint
>>> from torchmetrics.text.bert import BERTScore
>>> preds = ["hello there", "general kenobi"]
>>> target = ["hello there", "master kenobi"]
>>> bertscore = BERTScore()
>>> pprint(bertscore(preds, target))
{'f1': tensor([1.0000, 0.9961]), 'precision': tensor([1.0000, 0.9961]), 'recall': tensor([1.0000, 0.9961])}
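
When model_name_or_path is not given, as above, a default transformers model is used. As a hedged variant (assuming the roberta-large checkpoint can be downloaded from the Hugging Face hub; layer 17 is its recommended layer in the bert-score package), the backbone can be pinned explicitly:

>>> bertscore = BERTScore(model_name_or_path="roberta-large", num_layers=17)
>>> score = bertscore(preds, target)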

plot(val=None, ax=None)[source]

Plot a single or multiple values from the metric.

Parameters
  • val (Union[Tensor, Sequence[Tensor], None]) – Either a single result from calling metric.forward or metric.compute or a list of these results. If no value is provided, will automatically call metric.compute and plot that result.

  • ax (Optional[Axes]) – A matplotlib axis object. If provided, the plot is added to this axis.

Return type

Tuple[Figure, Union[Axes, ndarray]]

Returns

Figure and Axes object

Raises

ModuleNotFoundError – If matplotlib is not installed

>>> # Example plotting a single value
>>> from torchmetrics.text.bert import BERTScore
>>> preds = ["hello there", "general kenobi"]
>>> target = ["hello there", "master kenobi"]
>>> metric = BERTScore()
>>> metric.update(preds, target)
>>> fig_, ax_ = metric.plot()

>>> # Example plotting multiple values
>>> from torch import tensor
>>> from torchmetrics.text.bert import BERTScore
>>> preds = ["hello there", "general kenobi"]
>>> target = ["hello there", "master kenobi"]
>>> metric = BERTScore()
>>> values = []
>>> for _ in range(10):
...     val = metric(preds, target)
...     val = {k: tensor(v).mean() for k, v in val.items()}  # convert to a single value per key
...     values.append(val)
>>> fig_, ax_ = metric.plot(values)


Functional Interface

torchmetrics.functional.text.bert.bert_score(preds, target, model_name_or_path=None, num_layers=None, all_layers=False, model=None, user_tokenizer=None, user_forward_fn=None, verbose=False, idf=False, device=None, max_length=512, batch_size=64, num_threads=0, return_hash=False, lang='en', rescale_with_baseline=False, baseline_path=None, baseline_url=None)[source]

BERTScore (BERTScore: Evaluating Text Generation with BERT) for text similarity matching.

This metric leverages the pre-trained contextual embeddings from BERT and matches words in candidate and reference sentences by cosine similarity. It has been shown to correlate with human judgment on sentence-level and system-level evaluation. Moreover, BERTScore computes precision, recall, and F1 measure, which can be useful for evaluating different language generation tasks.
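
In symbols (following the BERTScore paper, with x denoting the reference tokens’ embeddings and \hat{x} the candidate’s, all pre-normalized so that inner products equal cosine similarities):

R_{\mathrm{BERT}} = \frac{1}{|x|}\sum_{x_i \in x}\max_{\hat{x}_j \in \hat{x}} x_i^{\top}\hat{x}_j,\qquad
P_{\mathrm{BERT}} = \frac{1}{|\hat{x}|}\sum_{\hat{x}_j \in \hat{x}}\max_{x_i \in x} x_i^{\top}\hat{x}_j,\qquad
F_{\mathrm{BERT}} = 2\,\frac{P_{\mathrm{BERT}} \cdot R_{\mathrm{BERT}}}{P_{\mathrm{BERT}} + R_{\mathrm{BERT}}}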

This implementation follows the original implementation from BERT_score.

Parameters
  • preds (Union[str, Sequence[str], Dict[str, Tensor]]) – Either an iterable of predicted sentences or a Dict[str, Tensor] containing "input_ids" and "attention_mask".

  • target (Union[str, Sequence[str], Dict[str, Tensor]]) – Either an iterable of target sentences or a Dict[str, Tensor] containing "input_ids" and "attention_mask".

  • model_name_or_path (Optional[str]) – A name or a model path used to load transformers pretrained model.

  • num_layers (Optional[int]) – A layer of representation to use.

  • all_layers (bool) – An indication of whether the representations from all of the model’s layers should be used. If all_layers=True, the argument num_layers is ignored.

  • model (Optional[Module]) – A user’s own model.

  • user_tokenizer (Optional[Any]) – A user’s own tokenizer used with the user’s own model. This must be an instance with the __call__ method, which takes an iterable of sentences (List[str]) and returns a python dictionary containing "input_ids" and "attention_mask" represented by Tensor. It is up to the user’s model whether "input_ids" is a Tensor of input ids or embedding vectors. This tokenizer must prepend an equivalent of the [CLS] token and append an equivalent of the [SEP] token, as a transformers tokenizer does.

  • user_forward_fn (Optional[Callable[[Module, Dict[str, Tensor]], Tensor]]) – A user’s own forward function used in combination with the user’s own model. This function must take the model and a python dictionary containing "input_ids" and "attention_mask" represented by Tensor as input, and return the model’s output as a single Tensor.

  • verbose (bool) – An indication of whether a progress bar should be displayed during the embeddings’ calculation.

  • idf (bool) – An indication of whether normalization using inverse document frequencies should be used.

  • device (Union[str, device, None]) – A device to be used for calculation.

  • max_length (int) – A maximum length of input sequences. Sequences longer than max_length are to be trimmed.

  • batch_size (int) – A batch size used for model processing.

  • num_threads (int) – A number of threads to use for a dataloader.

  • return_hash (bool) – An indication of whether the corresponding hash_code should be returned.

  • lang (str) – A language of input sentences. It is used when the scores are rescaled with a baseline.

  • rescale_with_baseline (bool) – An indication of whether BERTScore should be rescaled with a pre-computed baseline. When a pretrained model from transformers is used, the corresponding baseline is downloaded from the original bert-score package from BERT_score if available. In other cases, please specify a path to the baseline csv/tsv file, which must follow the formatting of the files from BERT_score.

  • baseline_path (Optional[str]) – A path to the user’s own local csv/tsv file with the baseline scale.

  • baseline_url (Optional[str]) – A url path to the user’s own csv/tsv file with the baseline scale.

Return type

Dict[str, Union[Tensor, List[float], str]]

Returns

Python dictionary containing the keys precision, recall and f1 with corresponding values.

Example

>>> from pprint import pprint
>>> from torchmetrics.functional.text.bert import bert_score
>>> preds = ["hello there", "general kenobi"]
>>> target = ["hello there", "master kenobi"]
>>> pprint(bert_score(preds, target))
{'f1': tensor([1.0000, 0.9961]), 'precision': tensor([1.0000, 0.9961]), 'recall': tensor([1.0000, 0.9961])}
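
A hedged variant: rescaling the raw scores with the pre-computed English baseline (assuming the baseline file for the default model can be downloaded, as described under rescale_with_baseline above):

>>> score = bert_score(preds, target, lang="en", rescale_with_baseline=True)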