Kernel Inception Distance

Module Interface

class torchmetrics.image.kid.KernelInceptionDistance(feature=2048, subsets=100, subset_size=1000, degree=3, gamma=None, coef=1.0, reset_real_features=True, normalize=False, **kwargs)

Calculate Kernel Inception Distance (KID), which is used to assess the quality of generated images.

\[KID = MMD(f_{real}, f_{fake})^2\]

where \(MMD\) is the maximum mean discrepancy and \(f_{real}, f_{fake}\) are the features extracted from real and fake images; see kid ref1 for more details. In particular, calculating the MMD requires evaluating a polynomial kernel function \(k\)

\[k(x,y) = (\gamma * x^T y + coef)^{degree}\]

which defines the similarity between two feature vectors. In practice, the MMD is calculated over a number of random subsets so that both the mean and the standard deviation of KID can be reported.
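The kernel and the subset averaging described above can be sketched in plain Python. This is a minimal illustration under stated assumptions, not torchmetrics' implementation: the names `poly_kernel`, `mmd2_unbiased` and `kid_from_features` are hypothetical, features are plain lists of floats, `gamma=None` is assumed to fall back to `1 / d` (with d the feature size), the MMD² is assumed to be the unbiased estimator (diagonal terms of the within-set kernel sums excluded), and the spread is reported as a population standard deviation.

```python
import random
import statistics
from typing import List, Optional, Sequence, Tuple

Vector = Sequence[float]


def poly_kernel(x: Vector, y: Vector, degree: int = 3,
                gamma: Optional[float] = None, coef: float = 1.0) -> float:
    """k(x, y) = (gamma * x^T y + coef) ** degree."""
    if gamma is None:
        # Assumed default: 1 / feature size.
        gamma = 1.0 / len(x)
    dot = sum(a * b for a, b in zip(x, y))
    return (gamma * dot + coef) ** degree


def mmd2_unbiased(real: List[Vector], fake: List[Vector], degree: int = 3,
                  gamma: Optional[float] = None, coef: float = 1.0) -> float:
    """Unbiased estimate of MMD^2 between two equally sized feature sets."""
    m = len(real)
    if m != len(fake) or m < 2:
        raise ValueError("need two feature sets of equal size >= 2")

    def k(a: Vector, b: Vector) -> float:
        return poly_kernel(a, b, degree, gamma, coef)

    # Within-set sums skip i == j (the diagonal of the kernel matrix).
    k_xx = sum(k(real[i], real[j]) for i in range(m) for j in range(m) if i != j)
    k_yy = sum(k(fake[i], fake[j]) for i in range(m) for j in range(m) if i != j)
    # The cross-set sum uses every (real, fake) pair.
    k_xy = sum(k(a, b) for a in real for b in fake)
    return (k_xx + k_yy) / (m * (m - 1)) - 2.0 * k_xy / (m * m)


def kid_from_features(real: List[Vector], fake: List[Vector], subsets: int = 100,
                      subset_size: int = 1000, **kernel_kwargs) -> Tuple[float, float]:
    """Mean and population standard deviation of MMD^2 over random subsets."""
    if subset_size > min(len(real), len(fake)):
        raise ValueError("subset_size must not exceed the number of samples")
    scores = []
    for _ in range(subsets):
        # Draw subset_size samples without replacement from each set.
        r = random.sample(real, subset_size)
        f = random.sample(fake, subset_size)
        scores.append(mmd2_unbiased(r, f, **kernel_kwargs))
    return statistics.mean(scores), statistics.pstdev(scores)
```

As a hand-checkable case, 1-D features with a linear kernel (`degree=1`, `gamma=1`, `coef=0`) give `mmd2_unbiased([[1], [2]], [[3], [5]], degree=1, gamma=1, coef=0)` = (4 + 30) / 2 - 2 * 24 / 4 = 5.0.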

Using the default feature extractor (Inception v3 with the original weights from kid ref2), the input is expected to be mini-batches of 3-channel RGB images of shape (N, 3, H, W). If the argument normalize is True, images are expected to have a floating-point dtype with values in the [0, 1] range; if normalize is False, images are expected to have dtype uint8 with values in the [0, 255] range. All images are resized to 299 x 299, which is the size of the original training data. The boolean flag real determines whether the images update the statistics of the real distribution or of the fake distribution.
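A minimal sketch of getting a batch into the form the default extractor expects, assuming `torch` is installed. The helpers `to_uint8_images` and `check_default_input` are hypothetical and not part of torchmetrics: the first maps float images in [0, 1] to uint8 images in [0, 255] (the `normalize=False` format), the second checks channel count, dtype and value range for either setting of `normalize`.

```python
import torch


def to_uint8_images(imgs: torch.Tensor) -> torch.Tensor:
    """Map float images in [0, 1] to uint8 images in [0, 255]."""
    return (imgs.clamp(0.0, 1.0) * 255).round().to(torch.uint8)


def check_default_input(imgs: torch.Tensor, normalize: bool) -> None:
    """Raise ValueError if imgs does not match the documented input format."""
    if imgs.ndim != 4 or imgs.shape[1] != 3:
        raise ValueError(f"expected shape (N, 3, H, W), got {tuple(imgs.shape)}")
    if normalize:
        # normalize=True: floating point values in [0, 1].
        if not imgs.is_floating_point():
            raise ValueError("normalize=True expects a floating point tensor")
        if imgs.min() < 0 or imgs.max() > 1:
            raise ValueError("normalize=True expects values in [0, 1]")
    elif imgs.dtype != torch.uint8:
        # normalize=False: uint8, whose range is [0, 255] by construction.
        raise ValueError("normalize=False expects a uint8 tensor")


float_batch = torch.rand(4, 3, 64, 64)  # float32 values in [0, 1)
check_default_input(float_batch, normalize=True)
uint8_batch = to_uint8_images(float_batch)
check_default_input(uint8_batch, normalize=False)
```

The height and width are left arbitrary here because, per the text above, resizing to 299 x 299 happens inside the metric.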

A custom feature extractor can also be used by passing a torch.nn.Module as the feature argument, which replaces the default Inception v3 network. The custom extractor is expected to return features of shape (N, num_features) for a batch of N images, i.e. (1, num_features) for a single image. In this case the normalize argument has no effect, and update expects the tensor passed as imgs to already have the shape and dtype that the custom feature extractor accepts.
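The shape contract above can be illustrated with a toy module, assuming `torch` is installed. `TinyExtractor` is a hypothetical example with untrained weights, so its features are meaningless for real evaluation; it only shows the interface: images of shape (N, 3, H, W) in, a feature matrix of shape (N, num_features) out. Because normalize is ignored for custom extractors, the module does its own uint8 to float conversion.

```python
import torch
from torch import nn


class TinyExtractor(nn.Module):
    """Toy feature extractor: (N, 3, H, W) images -> (N, num_features)."""

    def __init__(self, num_features: int = 32) -> None:
        super().__init__()
        self.conv = nn.Conv2d(3, num_features, kernel_size=3, padding=1)
        self.pool = nn.AdaptiveAvgPool2d(1)  # works for any H, W

    def forward(self, imgs: torch.Tensor) -> torch.Tensor:
        # normalize is ignored for custom extractors, so convert here:
        # uint8 values in [0, 255] -> float values in [0, 1].
        x = imgs.float() / 255.0 if imgs.dtype == torch.uint8 else imgs
        x = self.pool(torch.relu(self.conv(x)))
        return x.flatten(1)  # shape (N, num_features)


extractor = TinyExtractor().eval()
with torch.no_grad():
    feats = extractor(torch.randint(0, 256, (5, 3, 40, 40), dtype=torch.uint8))
```

Such a module could then be passed as `KernelInceptionDistance(feature=TinyExtractor(), subset_size=...)`, keeping `subset_size` no larger than the number of images given for each distribution.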

Note

Using this metric with the default feature extractor requires that torch-fidelity is installed. Either install it with pip install torchmetrics[image] or with pip install torch-fidelity.
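Whether the optional dependency is present can be checked with the standard library before constructing the metric. This is a small sketch; `has_torch_fidelity` is a hypothetical helper, and it assumes the torch-fidelity package is importable as `torch_fidelity`.

```python
import importlib.util


def has_torch_fidelity() -> bool:
    """Return True if the torch_fidelity package can be imported."""
    return importlib.util.find_spec("torch_fidelity") is not None
```

When it returns False, one option is to pass a custom nn.Module as feature instead of relying on the default Inception v3 extractor.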

As input to forward and update, the metric accepts the following inputs:

  • imgs (Tensor): tensor of images fed to the feature extractor, of shape (N,C,H,W)

  • real (bool): bool indicating if imgs belong to the real or the fake distribution

As output of forward and compute, the metric returns the following outputs:

  • kid_mean (Tensor): float scalar tensor with mean value over subsets

  • kid_std (Tensor): float scalar tensor with standard deviation value over subsets

Parameters:
  • feature (Union[str, int, Module]) –

    Either a str, an int or an nn.Module:

    • a str or an int indicates the Inception v3 feature layer to use. Can be one of the following: ‘logits_unbiased’, 64, 192, 768, 2048

    • an nn.Module to use as a custom feature extractor. Its forward method is expected to return an (N,d) matrix, where N is the batch size and d is the feature size.

  • subsets (int) – Number of subsets to calculate the mean and standard deviation scores over

  • subset_size (int) – Number of randomly picked samples in each subset

  • degree (int) – Degree of the polynomial kernel function

  • gamma (Optional[float]) – Scale factor of the polynomial kernel. If set to None, it is automatically set to 1 / feature size

  • coef (float) – Bias term in the polynomial kernel.

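  • reset_real_features (bool) – Whether to also reset the real features. Since in many cases the real dataset does not change, the real features can be cached to avoid recomputing them, which is costly. Set this to False if your dataset does not change.

  • normalize (bool) – Whether input images are floating-point tensors with values in [0, 1] (True) or uint8 tensors with values in [0, 255] (False). Has no effect when a custom feature extractor is used.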

  • kwargs (Any) – Additional keyword arguments, see Advanced metric settings for more info.
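To make the real flag and reset_real_features concrete, here is a toy accumulator in plain Python. `FeatureStore` is hypothetical, not a torchmetrics class, and it assumes only what the text states: each update appends features to either the real or the fake set, and a reset always clears the fake set but keeps the real set when reset_real_features is False.

```python
from typing import List, Sequence


class FeatureStore:
    """Toy stand-in for the metric state: lists of real and fake features."""

    def __init__(self, reset_real_features: bool = True) -> None:
        self.reset_real_features = reset_real_features
        self.real: List[Sequence[float]] = []
        self.fake: List[Sequence[float]] = []

    def update(self, feats: List[Sequence[float]], real: bool) -> None:
        # The real flag selects which distribution this batch updates.
        (self.real if real else self.fake).extend(feats)

    def reset(self) -> None:
        # Fake features are always cleared; real ones only if requested.
        self.fake = []
        if self.reset_real_features:
            self.real = []
```

Keeping the real set across resets is what lets the costly real-image features be computed once and reused when evaluating several generators against the same dataset.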

Raises:
  • ValueError – If feature is set to an int (default settings) and torch-fidelity is not installed

  • ValueError – If feature is set to an int not in (64, 192, 768, 2048)

  • ValueError – If subsets is not an integer larger than 0

  • ValueError – If subset_size is not an integer larger than 0

  • ValueError – If degree is not an integer larger than 0

  • ValueError – If gamma is neither None nor a float larger than 0

  • ValueError – If coef is not a float larger than 0

  • ValueError – If reset_real_features is not a bool
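The argument constraints in the list above can be summarised as a validation function. This is a hypothetical sketch (`validate_kid_args` is not a torchmetrics function); it mirrors only the ValueError conditions listed here, omits the torch-fidelity check, and the checks torchmetrics actually performs may differ in detail.

```python
from typing import Optional, Union


def validate_kid_args(feature: Union[str, int] = 2048, subsets: int = 100,
                      subset_size: int = 1000, degree: int = 3,
                      gamma: Optional[float] = None, coef: float = 1.0,
                      reset_real_features: bool = True) -> None:
    """Raise ValueError for the invalid argument values listed above."""
    if isinstance(feature, int) and feature not in (64, 192, 768, 2048):
        raise ValueError("feature must be one of 64, 192, 768, 2048")
    for name, value in (("subsets", subsets), ("subset_size", subset_size),
                        ("degree", degree)):
        if not (isinstance(value, int) and value > 0):
            raise ValueError(f"{name} must be an integer larger than 0")
    if gamma is not None and not (isinstance(gamma, float) and gamma > 0):
        raise ValueError("gamma must be None or a float larger than 0")
    if not (isinstance(coef, float) and coef > 0):
        raise ValueError("coef must be a float larger than 0")
    if not isinstance(reset_real_features, bool):
        raise ValueError("reset_real_features must be a bool")
```

The defaults from the signature pass all checks, as does a string feature such as 'logits_unbiased'.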

Example

>>> import torch
>>> _ = torch.manual_seed(123)
>>> from torchmetrics.image.kid import KernelInceptionDistance
>>> kid = KernelInceptionDistance(subset_size=50)
>>> # generate two slightly overlapping image intensity distributions
>>> imgs_dist1 = torch.randint(0, 200, (100, 3, 299, 299), dtype=torch.uint8)
>>> imgs_dist2 = torch.randint(100, 255, (100, 3, 299, 299), dtype=torch.uint8)
>>> kid.update(imgs_dist1, real=True)
>>> kid.update(imgs_dist2, real=False)
>>> kid.compute()
(tensor(0.0337), tensor(0.0023))
plot(val=None, ax=None)

Plot a single or multiple values from the metric.

Parameters:
  • val (Union[Tensor, Sequence[Tensor], None]) – Either a single result from calling metric.forward or metric.compute or a list of these results. If no value is provided, will automatically call metric.compute and plot that result.

  • ax (Optional[Axes]) – A matplotlib axis object. If provided, the plot is added to that axis

Return type:

Tuple[Figure, Union[Axes, ndarray]]

Returns:

Figure and Axes object

Raises:

ModuleNotFoundError – If matplotlib is not installed

>>> # Example plotting a single value
>>> import torch
>>> from torchmetrics.image.kid import KernelInceptionDistance
>>> imgs_dist1 = torch.randint(0, 200, (30, 3, 299, 299), dtype=torch.uint8)
>>> imgs_dist2 = torch.randint(100, 255, (30, 3, 299, 299), dtype=torch.uint8)
>>> metric = KernelInceptionDistance(subsets=3, subset_size=20)
>>> metric.update(imgs_dist1, real=True)
>>> metric.update(imgs_dist2, real=False)
>>> fig_, ax_ = metric.plot()
>>> # Example plotting multiple values
>>> import torch
>>> from torchmetrics.image.kid import KernelInceptionDistance
>>> imgs_dist1 = lambda: torch.randint(0, 200, (30, 3, 299, 299), dtype=torch.uint8)
>>> imgs_dist2 = lambda: torch.randint(100, 255, (30, 3, 299, 299), dtype=torch.uint8)
>>> metric = KernelInceptionDistance(subsets=3, subset_size=20)
>>> values = []
>>> for _ in range(3):
...     metric.update(imgs_dist1(), real=True)
...     metric.update(imgs_dist2(), real=False)
...     values.append(metric.compute()[0])
...     metric.reset()
>>> fig_, ax_ = metric.plot(values)
reset()

Reset metric states.

Return type:

None