import torch
from torchmetrics.retrieval import RetrievalMRR
# Example script: call a RetrievalMRR metric on ten random batches and plot
# the sequence of values those calls return.
metric = RetrievalMRR()

# Every call receives freshly drawn inputs of length 10: random floats in
# [0, 1) as the first argument, random 0/1 integers as the second, and random
# 0/1 integers as ``indexes`` (presumably scores, relevance labels and query
# ids respectively -- confirm against the torchmetrics documentation).
values = [
    metric(
        torch.rand(10),
        torch.randint(2, (10,)),
        indexes=torch.randint(2, (10,)),
    )
    for _ in range(10)
]

# ``plot`` is handed the collected values; its result unpacks into two
# objects, bound here as ``fig`` and ``ax``.
fig, ax = metric.plot(values)
