import torch
from torchmetrics.image import RootMeanSquaredErrorUsingSlidingWindow
# Seed torch's global RNG so the random inputs below, and therefore the
# RMSE value that gets plotted, are identical on every run of this example.
torch.manual_seed(42)

# Random prediction/target batches of shape (4, 3, 16, 16).
# NOTE(review): assumed to be an (N, C, H, W) image layout; confirm against
# the metric's documented input format.
preds = torch.rand(4, 3, 16, 16)
target = torch.rand(4, 3, 16, 16)

# Feed one batch into the metric's accumulated state, then plot it.
# plot() is called with no explicit value and returns a (figure, axes) pair.
metric = RootMeanSquaredErrorUsingSlidingWindow()
metric.update(preds, target)
fig_, ax_ = metric.plot()
