from torchmetrics.aggregation import SumMetric
# Single-value plotting example: feed one batch into a SumMetric, then ask
# the metric to plot itself.
metric = SumMetric()

# The whole batch is handed to update() in one call as a plain Python list.
sample_values = [1, 2, 3]
metric.update(sample_values)

# plot() is called with no arguments; its return value is unpacked into the
# two names below (presumably a matplotlib figure and axes -- confirm against
# the torchmetrics plotting docs).
plot_result = metric.plot()
fig_, ax_ = plot_result
