from torchmetrics.image import MultiScaleStructuralSimilarityIndexMeasure
import torch
# Example: plot the Multi-Scale SSIM between a random batch and a copy of it
# whose intensities are scaled by 0.75.

# torch.manual_seed seeds the default generator and returns it; passing that
# generator to torch.rand makes the random batch reproducible across runs.
rng = torch.manual_seed(42)
batch_shape = [3, 3, 256, 256]  # NOTE(review): presumably (N, C, H, W) — confirm
preds = torch.rand(batch_shape, generator=rng)

# The target is the same batch, uniformly darkened to 75% intensity.
target = preds * 0.75

# torch.rand draws values in [0, 1), hence data_range=1.0.
metric = MultiScaleStructuralSimilarityIndexMeasure(data_range=1.0)
metric.update(preds, target)

# plot() yields a figure/axes pair; both are kept under the original names.
fig_, ax_ = metric.plot()
