import torch
from torchmetrics import MetricCollection
from torchmetrics.classification import BinaryAccuracy, BinaryPrecision, BinaryRecall
# Group three binary-classification metrics in one collection so that a
# single update call feeds all of them.
metrics = MetricCollection(
    [
        BinaryAccuracy(),
        BinaryPrecision(),
        BinaryRecall(),
    ]
)

# Ten random prediction scores in [0, 1) and ten random 0/1 targets.
# The scores are drawn before the targets so the RNG stream is consumed in
# the same order as a single-line `update(torch.rand(...), torch.randint(...))`.
preds = torch.rand(10)
target = torch.randint(2, (10,))
metrics.update(preds, target)

# Plot the collection's current metric values; whatever `plot()` returns is
# kept in `fig_ax_`.
fig_ax_ = metrics.plot()
