SQuAD

Module Interface

class torchmetrics.text.SQuAD(**kwargs)

Calculate the SQuAD metric, which evaluates question answering models.

This metric corresponds to the scoring script for version 1 of the Stanford Question Answering Dataset (SQuAD).
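To make the scoring concrete, the sketch below re-implements the per-answer scoring logic of the official SQuAD v1.1 evaluation script in plain Python. Answers are lowercased, stripped of punctuation and of the articles "a", "an" and "the", and whitespace-normalized before comparison. Exact match is then a 0/1 string comparison, and F1 is computed over the bag of normalized tokens. This is an illustrative sketch, not TorchMetrics' internal code, and the function names are hypothetical.

```python
import re
import string
from collections import Counter

_PUNCTUATION = set(string.punctuation)


def normalize_answer(text: str) -> str:
    """Lowercase, drop punctuation and articles, then collapse whitespace."""
    text = text.lower()
    text = "".join(ch for ch in text if ch not in _PUNCTUATION)
    text = re.sub(r"\b(a|an|the)\b", " ", text)
    return " ".join(text.split())


def exact_match(prediction: str, ground_truth: str) -> float:
    """1.0 if the normalized strings are identical, else 0.0."""
    return float(normalize_answer(prediction) == normalize_answer(ground_truth))


def token_f1(prediction: str, ground_truth: str) -> float:
    """Harmonic mean of token-level precision and recall."""
    pred_tokens = normalize_answer(prediction).split()
    gold_tokens = normalize_answer(ground_truth).split()
    # Multiset intersection: a shared token counts min(count_pred, count_gold) times
    common = Counter(pred_tokens) & Counter(gold_tokens)
    num_same = sum(common.values())
    if num_same == 0:
        return 0.0
    precision = num_same / len(pred_tokens)
    recall = num_same / len(gold_tokens)
    return 2 * precision * recall / (precision + recall)
```

For example, token_f1("in 1976", "1976") is 2/3 (precision 1/2, recall 1), while exact_match for the same pair is 0.0.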

As input to forward and update, the metric accepts the following inputs:

  • preds (Dict): A dictionary, or a list of dictionaries, mapping the keys id and prediction_text to their respective values.

    Example prediction:

    {"prediction_text": "TorchMetrics is awesome", "id": "123"}
    
  • target (Dict): A dictionary, or a list of dictionaries, containing the answers and id in the SQuAD format.

    Example target:

    {
        'answers': {'answer_start': [1], 'text': ['This is a test answer']},
        'id': '1',
    }
    

    Reference SQuAD Format:

    {
        'answers': {'answer_start': [1], 'text': ['This is a test text']},
        'context': 'This is a test context.',
        'id': '1',
        'question': 'Is this a test?',
        'title': 'train test'
    }
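Targets often come from dataset records that carry extra fields such as context, question and title. Assuming, as the description above indicates, that a target only needs its answers and id entries, a hypothetical helper like the one below can trim each record and pair it with a model's answer string. The name build_squad_inputs and the answers_by_id mapping are illustrative, not part of TorchMetrics.

```python
from typing import Any, Dict, List, Tuple


def build_squad_inputs(
    records: List[Dict[str, Any]],
    answers_by_id: Dict[str, str],  # hypothetical: question id -> model answer text
) -> Tuple[List[Dict[str, str]], List[Dict[str, Any]]]:
    preds: List[Dict[str, str]] = []
    target: List[Dict[str, Any]] = []
    for rec in records:
        qid = rec["id"]
        # Keep only the keys documented for targets; drop context, question, title
        target.append({"answers": rec["answers"], "id": qid})
        # Questions the model did not answer are simply left out of preds
        if qid in answers_by_id:
            preds.append({"prediction_text": answers_by_id[qid], "id": qid})
    return preds, target
```

The two returned lists follow the documented shapes for preds and target.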
    

As output of forward and compute, the metric returns the following output:

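  • squad (Dict): A dictionary containing the F1 score (key: "f1") and the exact match score (key: "exact_match"), both as percentages between 0 and 100. Called via forward, the scores cover the current batch; via compute, they cover all inputs accumulated since the last reset.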
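The reported scores are averages over the target questions, scaled to percentages. The sketch below is modeled on the aggregation loop of the official SQuAD v1.1 evaluation script: each prediction is credited with its best score over the question's reference answers, a question with no matching prediction id contributes 0, and the sums are divided by the number of target questions. The scoring functions here are simplified placeholders (no punctuation or article removal), and aggregate is a hypothetical name, not a TorchMetrics function.

```python
from collections import Counter
from typing import Callable, Dict, List

ScoreFn = Callable[[str, str], float]


def simple_em(prediction: str, ground_truth: str) -> float:
    # Placeholder: case-insensitive comparison, NOT the full SQuAD normalization
    return float(prediction.strip().lower() == ground_truth.strip().lower())


def simple_f1(prediction: str, ground_truth: str) -> float:
    # Placeholder: token-overlap F1 on lowercased whitespace tokens
    pred, gold = prediction.lower().split(), ground_truth.lower().split()
    num_same = sum((Counter(pred) & Counter(gold)).values())
    if num_same == 0:
        return 0.0
    precision, recall = num_same / len(pred), num_same / len(gold)
    return 2 * precision * recall / (precision + recall)


def aggregate(
    predictions: Dict[str, str],       # id -> prediction_text
    references: Dict[str, List[str]],  # id -> reference answer texts (non-empty)
    em_fn: ScoreFn = simple_em,
    f1_fn: ScoreFn = simple_f1,
) -> Dict[str, float]:
    total = len(references)  # every target question counts, answered or not
    if total == 0:
        raise ValueError("need at least one target question")
    em_sum = f1_sum = 0.0
    for qid, gold_answers in references.items():
        if qid not in predictions:
            continue  # unanswered: adds 0 to both sums
        pred = predictions[qid]
        # Credit the prediction with its best match among the reference answers
        em_sum += max(em_fn(pred, gold) for gold in gold_answers)
        f1_sum += max(f1_fn(pred, gold) for gold in gold_answers)
    # Scale the per-question averages to percentages in [0, 100]
    return {"exact_match": 100.0 * em_sum / total, "f1": 100.0 * f1_sum / total}
```

With predictions {"q1": "1976", "q2": "in 1976"} and references {"q1": ["1976"], "q2": ["1975", "In 1976"], "q3": ["Paris"]}, both scores come out at 200/3 (about 66.67): q1 and q2 are matched (q2 through its second reference), while the unanswered q3 scores 0.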

Parameters:

kwargs (Any) – Additional keyword arguments; see Advanced metric settings for more info.

Example

>>> from torchmetrics.text import SQuAD
>>> preds = [{"prediction_text": "1976", "id": "56e10a3be3433e1400422b22"}]
>>> target = [{"answers": {"answer_start": [97], "text": ["1976"]}, "id": "56e10a3be3433e1400422b22"}]
>>> squad = SQuAD()
>>> squad(preds, target)
{'exact_match': tensor(100.), 'f1': tensor(100.)}
plot(val=None, ax=None)

Plot a single or multiple values from the metric.

Parameters:
  • val (Union[Tensor, Sequence[Tensor], None]) – Either a single result from calling metric.forward or metric.compute, or a list of such results. If no value is provided, metric.compute is called automatically and its result is plotted.

  • ax (Optional[Axes]) – A matplotlib axis object. If provided, the plot is added to that axis.

Return type:

Tuple[Figure, Union[Axes, ndarray]]

Returns:

Figure and Axes object

Raises:

ModuleNotFoundError – If matplotlib is not installed

>>> # Example plotting a single value
>>> from torchmetrics.text import SQuAD
>>> metric = SQuAD()
>>> preds = [{"prediction_text": "1976", "id": "56e10a3be3433e1400422b22"}]
>>> target = [{"answers": {"answer_start": [97], "text": ["1976"]}, "id": "56e10a3be3433e1400422b22"}]
>>> metric.update(preds, target)
>>> fig_, ax_ = metric.plot()
>>> # Example plotting multiple values
>>> from torchmetrics.text import SQuAD
>>> metric = SQuAD()
>>> preds = [{"prediction_text": "1976", "id": "56e10a3be3433e1400422b22"}]
>>> target = [{"answers": {"answer_start": [97], "text": ["1976"]}, "id": "56e10a3be3433e1400422b22"}]
>>> values = []
>>> for _ in range(10):
...     values.append(metric(preds, target))
>>> fig_, ax_ = metric.plot(values)

Functional Interface

torchmetrics.functional.text.squad(preds, target)

Calculate the SQuAD metric.

Parameters:
  • preds (Union[Dict[str, str], List[Dict[str, str]]]) –

    A dictionary, or a list of dictionaries, mapping the keys id and prediction_text to their respective values.

    Example prediction:

    {"prediction_text": "TorchMetrics is awesome", "id": "123"}
    

  • target (Union[Dict[str, Union[str, Dict[str, Union[List[str], List[int]]]]], List[Dict[str, Union[str, Dict[str, Union[List[str], List[int]]]]]]]) –

    A dictionary, or a list of dictionaries, containing the answers and id in the SQuAD format.

    Example target:

    {
        'answers': {'answer_start': [1], 'text': ['This is a test answer']},
        'id': '1',
    }
    

    Reference SQuAD Format:

    {
        'answers': {'answer_start': [1], 'text': ['This is a test text']},
        'context': 'This is a test context.',
        'id': '1',
        'question': 'Is this a test?',
        'title': 'train test'
    }
    

Return type:

Dict[str, Tensor]

Returns:

Dictionary containing the F1 score (key: "f1") and the exact match score (key: "exact_match") for the batch, both as percentages between 0 and 100.

Example

>>> from torchmetrics.functional.text.squad import squad
>>> preds = [{"prediction_text": "1976", "id": "56e10a3be3433e1400422b22"}]
>>> target = [{"answers": {"answer_start": [97], "text": ["1976"]}, "id": "56e10a3be3433e1400422b22"}]
>>> squad(preds, target)
{'exact_match': tensor(100.), 'f1': tensor(100.)}
Raises:

KeyError – If the required keys are missing in either predictions or targets.
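The KeyError conditions can be illustrated with a small validation helper. The sketch below is a hypothetical re-implementation based on the documented input format, not TorchMetrics' own code. It wraps a single dictionary in a list, checks for the prediction_text/id and answers/id keys (plus a text entry inside answers), and flattens the inputs into an id-to-prediction map and an id-to-reference-answers map.

```python
from typing import Any, Dict, List, Tuple, Union

Record = Dict[str, Any]


def check_and_flatten(
    preds: Union[Record, List[Record]],
    target: Union[Record, List[Record]],
) -> Tuple[Dict[str, str], Dict[str, List[str]]]:
    # Accept a single dictionary as shorthand for a one-element list
    if isinstance(preds, dict):
        preds = [preds]
    if isinstance(target, dict):
        target = [target]

    for pred in preds:
        if "prediction_text" not in pred or "id" not in pred:
            raise KeyError("Each prediction needs 'prediction_text' and 'id' keys.")
    for tgt in target:
        if "answers" not in tgt or "id" not in tgt:
            raise KeyError("Each target needs 'answers' and 'id' keys.")
        answers = tgt["answers"]
        # A non-dict 'answers' (e.g. a list) has no 'text' key, so treat it as missing
        if not isinstance(answers, dict) or "text" not in answers:
            raise KeyError("'answers' must be a SQuAD-format dict with a 'text' list.")

    pred_by_id = {p["id"]: p["prediction_text"] for p in preds}
    answers_by_id = {t["id"]: list(t["answers"]["text"]) for t in target}
    return pred_by_id, answers_by_id
```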

References

[1] Pranav Rajpurkar, Jian Zhang, Konstantin Lopyrev, Percy Liang. SQuAD: 100,000+ Questions for Machine Comprehension of Text.