import torch
from torchmetrics.image.lpip import LearnedPerceptualImagePatchSimilarity
# Seed the RNG so the random images, and therefore the plotted value, are
# reproducible from run to run.
torch.manual_seed(42)
# LPIPS on a SqueezeNet backbone. ``torch.rand`` below yields values in
# [0, 1], so ``normalize=True`` is required: with the default
# (``normalize=False``) the metric would interpret the tensors as images
# already scaled to [-1, 1] and the score would be computed on mis-scaled input.
metric = LearnedPerceptualImagePatchSimilarity(net_type='squeeze', normalize=True)
# Accumulate one batch of 10 random 3-channel 100x100 image pairs.
metric.update(torch.rand(10, 3, 100, 100), torch.rand(10, 3, 100, 100))
# Plot the value computed from the accumulated state; the result is unpacked
# into a figure and an axes object.
fig_, ax_ = metric.plot()
