import torch
from torchmetrics.retrieval import RetrievalRPrecision
# Example: update a RetrievalRPrecision metric with random data and plot it.
#
# Seed the global RNG so the random example inputs (and therefore the plotted
# value) are identical on every run; without a seed the output is not
# reproducible.
torch.manual_seed(42)

# Ten random prediction scores in [0, 1), ten binary targets (0 or 1), and a
# query index (0 or 1) for each of the ten items.
preds = torch.rand(10)
target = torch.randint(2, (10,))
indexes = torch.randint(2, (10,))

metric = RetrievalRPrecision()
metric.update(preds, target, indexes=indexes)
# NOTE(review): presumably returns the matplotlib figure and axes that were
# drawn on — confirm against the torchmetrics `Metric.plot` docs.
fig_, ax_ = metric.plot()
