import torch
from torchmetrics.image import PeakSignalNoiseRatioWithBlockedEffect
# Instantiate the PSNR-B metric with its default constructor arguments.
metric = PeakSignalNoiseRatioWithBlockedEffect()

# Two random tensors of shape (2, 1, 10, 10), passed positionally to update().
# They are drawn in the same order as before: first argument, then second.
# Presumably this is (batch, channel, height, width); confirm against the
# torchmetrics docs.
preds = torch.rand(2, 1, 10, 10)
target = torch.rand(2, 1, 10, 10)
metric.update(preds, target)

# Render the metric's current state. plot() yields two values, unpacked into
# fig_ / ax_ (by naming, presumably the figure and axes).
fig_, ax_ = metric.plot()
