Perceptual Path Length (PPL)

Module Interface

class torchmetrics.image.perceptual_path_length.PerceptualPathLength(num_samples=10000, conditional=False, batch_size=128, interpolation_method='lerp', epsilon=0.0001, resize=64, lower_discard=0.01, upper_discard=0.99, sim_net='vgg', **kwargs)[source]

Computes the perceptual path length (PPL) of a generator model.

The perceptual path length can be used to measure the consistency of interpolation in latent-space models. It is defined as

\[PPL = \mathbb{E}\left[\frac{1}{\epsilon^2} D(G(I(z_1, z_2, t)), G(I(z_1, z_2, t+\epsilon)))\right]\]

where \(G\) is the generator, \(I\) is the interpolation function, \(D\) is a similarity metric, \(z_1\) and \(z_2\) are two sets of latent points, and \(t\) is a parameter between 0 and 1. The metric thus works by interpolating between pairs of latent points and measuring the similarity between the corresponding generated images. The expectation is approximated by sampling \(z_1\) and \(z_2\) with the generator's sample method and averaging the calculated distances. The similarity metric \(D\) is by default the LPIPS metric, but can be changed by setting the sim_net argument.
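
To make the estimator concrete, a single term of the expectation can be sketched as below using plain linear interpolation; G and dist_fn are hypothetical placeholders for a generator and a perceptual distance such as LPIPS, not part of the API:

>>> import torch
>>> def ppl_single_pair(G, dist_fn, z1, z2, eps=1e-4):
...     # draw a random interpolation position t in [0, 1)
...     t = torch.rand(1)
...     # linear interpolation (lerp) at t and at t + eps
...     z_a = z1 + t * (z2 - z1)
...     z_b = z1 + (t + eps) * (z2 - z1)
...     # perceptual distance between the two generated images,
...     # scaled by 1 / eps**2 as in the definition above
...     return dist_fn(G(z_a), G(z_b)) / eps ** 2

The full metric averages such terms over many pairs of latent points (controlled by num_samples).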

The provided generator model must have a sample method with signature sample(num_samples: int) -> Tensor where the returned tensor has shape (num_samples, z_size). If the generator is conditional, it must also have a num_classes attribute. The forward method of the generator must have signature forward(z: Tensor) -> Tensor if conditional=False, and forward(z: Tensor, labels: Tensor) -> Tensor if conditional=True. The returned tensor should have shape (num_samples, C, H, W) and be scaled to the range [0, 255].
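
For the conditional case, a minimal sketch of a generator satisfying this interface could look like the following; the class name, sizes, and architecture are illustrative only:

>>> import torch
>>> class DummyConditionalGenerator(torch.nn.Module):
...     num_classes = 10  # required attribute when conditional=True
...     def __init__(self, z_size=64):
...         super().__init__()
...         self.z_size = z_size
...         self.net = torch.nn.Linear(z_size + self.num_classes, 3 * 64 * 64)
...     def sample(self, num_samples):
...         # latent samples of shape (num_samples, z_size)
...         return torch.randn(num_samples, self.z_size)
...     def forward(self, z, labels):
...         # concatenate one-hot labels to the latent code
...         one_hot = torch.nn.functional.one_hot(labels, self.num_classes).float()
...         img = self.net(torch.cat([z, one_hot], dim=1)).reshape(-1, 3, 64, 64)
...         # outputs should be scaled to the range [0, 255]
...         return 255 * torch.sigmoid(img)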

Note

Using this metric with the default feature extractor requires that torchvision is installed. Either install with pip install torchmetrics[image] or pip install torchvision.

As input to forward and update the metric accepts the following input

  • generator (Module): Generator model, with specific requirements. See above.

As output of forward and compute the metric returns the following output

  • ppl_mean (Tensor): float scalar tensor with mean PPL value over distances

  • ppl_std (Tensor): float scalar tensor with std PPL value over distances

  • ppl_raw (Tensor): float tensor with the raw PPL distances

Parameters:
  • num_samples (int) – Number of samples to use for the PPL computation.

  • conditional (bool) – Whether the generator is conditional or not (i.e. whether it takes labels as input).

  • batch_size (int) – Batch size to use for the PPL computation.

  • interpolation_method (Literal['lerp', 'slerp_any', 'slerp_unit']) – Interpolation method to use. Choose from ‘lerp’, ‘slerp_any’, ‘slerp_unit’.

  • epsilon (float) – Spacing between the points on the path between latent points.

  • resize (Optional[int]) – Resize images to this size before computing the similarity between generated images.

  • lower_discard (Optional[float]) – Lower quantile to discard from the distances, before computing the mean and standard deviation.

  • upper_discard (Optional[float]) – Upper quantile to discard from the distances, before computing the mean and standard deviation (see the sketch after this parameter list for how both discard quantiles are applied).

  • sim_net (Union[Module, Literal['alex', 'vgg', 'squeeze']]) – Similarity network to use. Can be a nn.Module or one of ‘alex’, ‘vgg’, ‘squeeze’, where the three latter options correspond to the pretrained networks from the LPIPS paper.

  • kwargs (Any) – Additional keyword arguments, see Advanced metric settings for more info.
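
The effect of lower_discard and upper_discard can be sketched roughly as follows; this is an illustrative approximation of the trimming step, not necessarily the exact implementation:

>>> import torch
>>> def trim_and_summarize(distances, lower_discard=0.01, upper_discard=0.99):
...     # drop distances outside the given quantiles before aggregating
...     lower = torch.quantile(distances, lower_discard) if lower_discard is not None else distances.min()
...     upper = torch.quantile(distances, upper_discard) if upper_discard is not None else distances.max()
...     kept = distances[(distances >= lower) & (distances <= upper)]
...     return kept.mean(), kept.std(), distances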

Raises:
  • ModuleNotFoundError – If torchvision is not installed.

  • ValueError – If num_samples is not a positive integer.

  • ValueError – If conditional is not a boolean.

  • ValueError – If batch_size is not a positive integer.

  • ValueError – If interpolation_method is not one of ‘lerp’, ‘slerp_any’, ‘slerp_unit’.

  • ValueError – If epsilon is not a positive float.

  • ValueError – If resize is not a positive integer.

  • ValueError – If lower_discard is not a float between 0 and 1 or None.

  • ValueError – If upper_discard is not a float between 0 and 1 or None.

Example::
>>> from torchmetrics.image import PerceptualPathLength
>>> import torch
>>> _ = torch.manual_seed(42)
>>> class DummyGenerator(torch.nn.Module):
...    def __init__(self, z_size) -> None:
...       super().__init__()
...       self.z_size = z_size
...       self.model = torch.nn.Sequential(torch.nn.Linear(z_size, 3*128*128), torch.nn.Sigmoid())
...    def forward(self, z):
...       return 255 * (self.model(z).reshape(-1, 3, 128, 128) + 1)
...    def sample(self, num_samples):
...       return torch.randn(num_samples, self.z_size)
>>> generator = DummyGenerator(2)
>>> ppl = PerceptualPathLength(num_samples=10)
>>> ppl(generator)  
(tensor(0.2371),
tensor(0.1763),
tensor([0.3502, 0.1362, 0.2535, 0.0902, 0.1784, 0.0769, 0.5871, 0.0691, 0.3921]))
class torchmetrics.image.perceptual_path_length.GeneratorType(*args, **kwargs)[source]

Basic interface for a generator model.

Users can inherit from this class and implement their own generator model. The requirements are that the sample method is implemented and that the num_classes attribute is present when the metric is used with conditional=True.

sample(num_samples)[source]

Sample from the generator.

Parameters:

num_samples (int) – Number of samples to generate.

Return type:

Tensor

property num_classes: int

Return the number of classes for conditional generation.
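
A minimal sketch of inheriting from this interface (assuming, as the signature above suggests, that GeneratorType can be subclassed like a torch.nn.Module; the forward body is a placeholder only):

>>> import torch
>>> from torchmetrics.image.perceptual_path_length import GeneratorType
>>> class MyGenerator(GeneratorType):
...     def __init__(self, z_size=128, n_classes=1000):
...         super().__init__()
...         self.z_size = z_size
...         self._n_classes = n_classes
...     @property
...     def num_classes(self):
...         # queried by the metric when conditional=True
...         return self._n_classes
...     def sample(self, num_samples):
...         return torch.randn(num_samples, self.z_size)
...     def forward(self, z, labels):
...         # placeholder: a real model maps (z, labels) to images scaled to [0, 255]
...         return 255 * torch.rand(z.shape[0], 3, 64, 64)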

Functional Interface

torchmetrics.functional.image.perceptual_path_length.perceptual_path_length(generator, num_samples=10000, conditional=False, batch_size=64, interpolation_method='lerp', epsilon=0.0001, resize=64, lower_discard=0.01, upper_discard=0.99, sim_net='vgg', device='cpu')[source]

Computes the perceptual path length (PPL) of a generator model.

The perceptual path length can be used to measure the consistency of interpolation in latent-space models. It is defined as

\[PPL = \mathbb{E}\left[\frac{1}{\epsilon^2} D(G(I(z_1, z_2, t)), G(I(z_1, z_2, t+\epsilon)))\right]\]

where \(G\) is the generator, \(I\) is the interpolation function, \(D\) is a similarity metric, \(z_1\) and \(z_2\) are two sets of latent points, and \(t\) is a parameter between 0 and 1. The metric thus works by interpolating between pairs of latent points and measuring the similarity between the corresponding generated images. The expectation is approximated by sampling \(z_1\) and \(z_2\) with the generator's sample method and averaging the calculated distances. The similarity metric \(D\) is by default the LPIPS metric, but can be changed by setting the sim_net argument.

The provided generator model must have a sample method with signature sample(num_samples: int) -> Tensor where the returned tensor has shape (num_samples, z_size). If the generator is conditional, it must also have a num_classes attribute. The forward method of the generator must have signature forward(z: Tensor) -> Tensor if conditional=False, and forward(z: Tensor, labels: Tensor) -> Tensor if conditional=True. The returned tensor should have shape (num_samples, C, H, W) and be scaled to the range [0, 255].
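
For intuition on the interpolation_method parameter listed below, the spherical variants can be related to the textbook slerp formula, sketched here; this is illustrative and may differ in details from the 'slerp_any' and 'slerp_unit' implementations:

>>> import torch
>>> def slerp(z1, z2, t, eps=1e-7):
...     # spherical linear interpolation between two latent vectors
...     z1_n, z2_n = z1 / z1.norm(), z2 / z2.norm()
...     omega = torch.acos((z1_n * z2_n).sum().clamp(-1 + eps, 1 - eps))
...     return (torch.sin((1 - t) * omega) * z1 + torch.sin(t * omega) * z2) / torch.sin(omega)

With 'lerp' the interpolation is simply z1 + t * (z2 - z1).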

Parameters:
  • generator (GeneratorType) – Generator model, with specific requirements. See above.

  • num_samples (int) – Number of samples to use for the PPL computation.

  • conditional (bool) – Whether the generator is conditional or not (i.e. whether it takes labels as input).

  • batch_size (int) – Batch size to use for the PPL computation.

  • interpolation_method (Literal['lerp', 'slerp_any', 'slerp_unit']) – Interpolation method to use. Choose from ‘lerp’, ‘slerp_any’, ‘slerp_unit’.

  • epsilon (float) – Spacing between the points on the path between latent points.

  • resize (Optional[int]) – Resize images to this size before computing the similarity between generated images.

  • lower_discard (Optional[float]) – Lower quantile to discard from the distances, before computing the mean and standard deviation.

  • upper_discard (Optional[float]) – Upper quantile to discard from the distances, before computing the mean and standard deviation.

  • sim_net (Union[Module, Literal['alex', 'vgg', 'squeeze']]) – Similarity network to use. Can be a nn.Module or one of ‘alex’, ‘vgg’, ‘squeeze’, where the three latter options correspond to the pretrained networks from the LPIPS paper.

  • device (Union[str, device]) – Device to use for the computation.

Return type:

Tuple[Tensor, Tensor, Tensor]

Returns:

A tuple containing the mean, standard deviation and all distances.

Example::
>>> from torchmetrics.functional.image import perceptual_path_length
>>> import torch
>>> _ = torch.manual_seed(42)
>>> class DummyGenerator(torch.nn.Module):
...    def __init__(self, z_size) -> None:
...       super().__init__()
...       self.z_size = z_size
...       self.model = torch.nn.Sequential(torch.nn.Linear(z_size, 3*128*128), torch.nn.Sigmoid())
...    def forward(self, z):
...       return 255 * (self.model(z).reshape(-1, 3, 128, 128) + 1)
...    def sample(self, num_samples):
...       return torch.randn(num_samples, self.z_size)
>>> generator = DummyGenerator(2)
>>> perceptual_path_length(generator, num_samples=10)  
(tensor(0.1945),
tensor(0.1222),
tensor([0.0990, 0.4173, 0.1628, 0.3573, 0.1875, 0.0335, 0.1095, 0.1887, 0.1953]))