import torch
from torchmetrics.image import TotalVariation
# Example: build a TotalVariation metric, feed it a single random batch,
# then render the accumulated state via the metric's plot() helper.
metric = TotalVariation()

# Uniform random values in [0, 1); the shape (5, 3, 28, 28) is presumably
# read as (batch, channels, height, width) — confirm against the metric docs.
batch = torch.rand(5, 3, 28, 28)
metric.update(batch)

# plot() yields two objects (unpacked here as figure / axes); the trailing
# underscores mark them as intentionally unused in this example.
fig_, ax_ = metric.plot()
