Learned Perceptual Image Patch Similarity (LPIPS)
Module Interface¶
- class torchmetrics.image.lpip.LearnedPerceptualImagePatchSimilarity(net_type='alex', reduction='mean', normalize=False, **kwargs)
The Learned Perceptual Image Patch Similarity (LPIPS) is used to judge the perceptual similarity between two images. LPIPS compares the activations that two image patches produce in a pre-defined, pretrained backbone network and returns a distance between them. This measure has been shown to correlate well with human judgments of perceptual similarity. A low LPIPS score means that the image patches are perceptually similar.
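To make "comparing network activations" concrete, here is a toy, pure-Python sketch of an LPIPS-style distance. It assumes the per-layer feature maps have already been extracted; the real metric obtains them from the backbone chosen with net_type and uses learned per-channel weights shipped with the lpips package. The helper names `unit_normalize` and `lpips_like_distance`, the nested-list feature layout, and the hand-written weights are all hypothetical. At each spatial position the channel vector is scaled to unit length, squared differences are weighted per channel and summed, the result is averaged over positions, and the per-layer values are summed.

```python
import math
from typing import List

# Hypothetical layout: one feature map per layer, stored as a list of
# spatial positions, each holding that position's channel activations.
FeatureMap = List[List[float]]


def unit_normalize(vec: List[float], eps: float = 1e-10) -> List[float]:
    """Scale a channel vector to (approximately) unit L2 norm."""
    norm = math.sqrt(sum(v * v for v in vec))
    return [v / (norm + eps) for v in vec]


def lpips_like_distance(
    feats_x: List[FeatureMap],
    feats_y: List[FeatureMap],
    weights: List[List[float]],
) -> float:
    """Toy LPIPS-style distance between two stacks of feature maps.

    feats_x, feats_y: per-layer feature maps with matching shapes.
    weights: per-layer, per-channel non-negative weights (hypothetical
    values here; the real metric learns them).
    """
    total = 0.0
    for fx, fy, w in zip(feats_x, feats_y, weights):
        layer_sum = 0.0
        for px, py in zip(fx, fy):
            nx, ny = unit_normalize(px), unit_normalize(py)
            # Weighted squared difference, summed over channels.
            layer_sum += sum(wc * (a - b) ** 2 for wc, a, b in zip(w, nx, ny))
        # Average over the spatial positions of this layer.
        total += layer_sum / len(fx)
    # Sum the per-layer contributions.
    return total


# One layer, two spatial positions, two channels.
x = [[[1.0, 0.0], [0.0, 1.0]]]
y = [[[0.0, 1.0], [0.0, 1.0]]]
w = [[1.0, 1.0]]
print(lpips_like_distance(x, x, w))  # identical features -> 0.0
print(lpips_like_distance(x, y, w))  # approximately 1.0
```

Identical inputs give a distance of exactly zero, and larger values mean the activations (and, by the metric's premise, the perceived images) differ more.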
Both input image patches are expected to have shape [N, 3, H, W]. The minimum size of H and W depends on the chosen backbone (see the net_type argument).
Note
Using this metric requires the lpips package to be installed. Install it either with pip install torchmetrics[image] or with pip install lpips.
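Because constructing the metric raises ModuleNotFoundError when lpips is missing, a script can check for the package up front and print an actionable message. A minimal sketch using the standard-library importlib.util.find_spec; the helper name `lpips_available` is hypothetical and not part of torchmetrics:

```python
import importlib.util


def lpips_available() -> bool:
    """Return True if the `lpips` package can be imported in this environment."""
    return importlib.util.find_spec("lpips") is not None


if not lpips_available():
    print("lpips not found; install it with `pip install lpips` "
          "or `pip install torchmetrics[image]`.")
```

find_spec only locates the package without importing it, so the check is cheap and does not pull in torch.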
Note
This metric is not scriptable when using torch<1.8. Please update your PyTorch installation if this is an issue.
- Parameters
  - net_type (str) – Backbone network type to use. Choose between 'alex', 'vgg' or 'squeeze'.
  - reduction (Literal['sum', 'mean']) – How to reduce over the batch dimension. Choose between 'sum' or 'mean'.
  - normalize (bool) – By default this is False, meaning that the input is expected to be in the [-1, 1] range. If set to True, the input is instead expected to be in the [0, 1] range.
  - kwargs (Any) – Additional keyword arguments; see Advanced metric settings for more info.
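The normalize and reduction options boil down to two small transformations. A pure-Python sketch under that reading, with hypothetical helper names: with normalize=True, inputs in [0, 1] are conceptually mapped to the [-1, 1] range via x -> 2x - 1 (the same rescaling the Example applies by hand to torch.rand output), and reduction decides whether per-image scores are summed or averaged over the batch.

```python
from typing import List


def to_lpips_range(pixels: List[float]) -> List[float]:
    """Map pixel values from [0, 1] to [-1, 1] with x -> 2x - 1."""
    return [2.0 * p - 1.0 for p in pixels]


def reduce_batch(scores: List[float], reduction: str = "mean") -> float:
    """Collapse per-image scores over the batch dimension ('sum' or 'mean')."""
    if reduction == "sum":
        return sum(scores)
    if reduction == "mean":
        return sum(scores) / len(scores)
    raise ValueError(f"reduction must be 'sum' or 'mean', got {reduction!r}")


print(to_lpips_range([0.0, 0.5, 1.0]))     # [-1.0, 0.0, 1.0]
print(reduce_batch([0.25, 0.75], "mean"))  # 0.5
print(reduce_batch([0.25, 0.75], "sum"))   # 1.0
```

'mean' keeps the score comparable across batch sizes, while 'sum' grows with the number of images compared.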
- Raises
  - ModuleNotFoundError – If the lpips package is not installed.
  - ValueError – If net_type is not one of "vgg", "alex" or "squeeze".
  - ValueError – If reduction is not one of "mean" or "sum".
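The two ValueError cases are membership checks on the constructor arguments. When net_type or reduction come from a user-supplied config, equivalent checks can run before the metric is built. A sketch; `validate_lpips_args` is a hypothetical helper, not the library's internal code:

```python
VALID_NET_TYPES = ("alex", "vgg", "squeeze")
VALID_REDUCTIONS = ("mean", "sum")


def validate_lpips_args(net_type: str, reduction: str) -> None:
    """Raise ValueError for the argument values the documentation rejects."""
    if net_type not in VALID_NET_TYPES:
        raise ValueError(f"net_type must be one of {VALID_NET_TYPES}, got {net_type!r}")
    if reduction not in VALID_REDUCTIONS:
        raise ValueError(f"reduction must be one of {VALID_REDUCTIONS}, got {reduction!r}")


validate_lpips_args("vgg", "mean")  # valid: returns without raising
try:
    validate_lpips_args("resnet", "mean")
except ValueError as err:
    print(err)
```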
Example
>>> import torch
>>> _ = torch.manual_seed(123)
>>> from torchmetrics.image.lpip import LearnedPerceptualImagePatchSimilarity
>>> lpips = LearnedPerceptualImagePatchSimilarity(net_type='vgg')
>>> # LPIPS needs the images to be in the [-1, 1] range.
>>> img1 = (torch.rand(10, 3, 100, 100) * 2) - 1
>>> img2 = (torch.rand(10, 3, 100, 100) * 2) - 1
>>> lpips(img1, img2)
tensor(0.3493, grad_fn=<SqueezeBackward0>)