from torchmetrics.aggregation import MaxMetric
# Minimal usage example: construct a MaxMetric with default arguments, give it
# a few values, and plot the result.
metric = MaxMetric()
# A plain Python list is passed directly to update(); no conversion is done
# here, so any casting is left to the metric itself.
metric.update([1, 2, 3])
# plot() is called with no arguments and its result is unpacked into exactly
# two names. NOTE(review): presumably a matplotlib (figure, axes) pair showing
# the maximum of the values seen so far -- confirm against the torchmetrics
# plotting documentation.
fig_, ax_ = metric.plot()
