from torchmetrics.aggregation import MeanMetric
# Example: accumulate a few sample values into a running-mean metric, then
# plot it. The plot call's result is unpacked into two names (figure and
# axes, presumably — confirm against the torchmetrics `plot` docs).
sample_values = [1, 2, 3]

metric = MeanMetric()
metric.update(sample_values)

fig_, ax_ = metric.plot()
