import torch
from torchmetrics.image import ErrorRelativeGlobalDimensionlessSynthesis
# Example: plot a single ERGAS value computed on random images.
# torch.manual_seed seeds torch's global RNG and returns the generator,
# which is then handed to torch.rand so the sampled images are reproducible.
rng = torch.manual_seed(42)
shape = (16, 1, 16, 16)
preds = torch.rand(shape, generator=rng)

# The reference images are a uniformly down-scaled copy of the predictions.
scale = 0.75
target = scale * preds

# Accumulate one batch into the metric state, then draw the resulting value.
metric = ErrorRelativeGlobalDimensionlessSynthesis()
metric.update(preds, target)
fig_, ax_ = metric.plot()
