import torch
from torchmetrics.wrappers import MetricTracker
from torchmetrics.classification import BinaryAccuracy
# Demo: track binary accuracy over several epochs on random data, then plot
# the tracked per-epoch values in a single figure.
NUM_EPOCHS = 5
BATCHES_PER_EPOCH = 5
BATCH_SIZE = 10
NUM_CLASSES = 2  # randint upper bound is exclusive, so labels are 0 or 1

accuracy = BinaryAccuracy()
tracker = MetricTracker(accuracy)

for epoch in range(NUM_EPOCHS):
    # Must be called before the first update of every epoch so the tracker
    # starts a new entry for it.
    tracker.increment()
    for batch_idx in range(BATCHES_PER_EPOCH):
        # Predictions are drawn first, then targets, both as random 0/1
        # integer tensors of length BATCH_SIZE.
        preds = torch.randint(NUM_CLASSES, (BATCH_SIZE,))
        target = torch.randint(NUM_CLASSES, (BATCH_SIZE,))
        tracker.update(preds, target)

# With no arguments, plot() draws the values tracked for all epochs.
fig_, ax_ = tracker.plot()
