from torch import randn, randint
import torch.nn.functional as F
from torchmetrics.classification import BinaryROC
# Example: plot a binary ROC curve from random data.
# No RNG seed is set here, so the scores, labels and plot change on every run.

# Draw 20 random two-column scores and normalise each row with softmax,
# so every row sums to 1 across its two columns.
logits = randn(20, 2)
preds = F.softmax(logits, dim=1)

# Random 0/1 labels, one per sample (randint's upper bound 2 is exclusive).
target = randint(2, (20,))

# BinaryROC takes one score per sample, so pass only column 1 (the
# probability assigned to the positive class).
positive_probs = preds[:, 1]
metric = BinaryROC()
metric.update(positive_probs, target)

# plot() is unpacked into a figure and an axes object. The trailing
# underscores presumably mark them as intentionally unused (TODO confirm).
fig_, ax_ = metric.plot()
