import torch
from torchmetrics.wrappers import Running
from torchmetrics.aggregation import SumMetric
# Wrap a SumMetric so that it reports a running value over a window of the
# last 2 updates (window passed by keyword for readability).
metric = Running(
    base_metric=SumMetric(),
    window=2,
)

# Feed one update: a 20x2 tensor of standard-normal samples.
metric.update(torch.randn(20, 2))

# Plot the current value; plot() hands back the figure and axis objects.
fig_, ax_ = metric.plot()
