import torch
from torchmetrics.image import RelativeAverageSpectralError
# Instantiate the metric with its default constructor arguments.
metric = RelativeAverageSpectralError()

# Two random batches, each of shape (4, 3, 16, 16) as written here.
# They are drawn in the same order as before: first positional argument, then second.
# NOTE(review): the positional order presumably maps to (preds, target) per
# torchmetrics' update() convention. Confirm against the library docs.
preds = torch.rand(4, 3, 16, 16)
target = torch.rand(4, 3, 16, 16)
metric.update(preds, target)

# plot() returns a two-element result, unpacked here into the names fig_ / ax_.
fig_, ax_ = metric.plot()
