from torchmetrics.text import SQuAD
# Example: score one SQuAD-style prediction against its reference answer,
# then draw the resulting metric values with the metric's built-in plotter.

# Prediction and reference share the same question id so they are matched.
_question_id = "56e10a3be3433e1400422b22"

preds = [{"prediction_text": "1976", "id": _question_id}]
target = [
    {
        "answers": {"answer_start": [97], "text": ["1976"]},
        "id": _question_id,
    }
]

# Accumulate the single sample, then render the figure/axes pair.
metric = SQuAD()
metric.update(preds, target)
fig_, ax_ = metric.plot()
