import torch
from torchmetrics.image.kid import KernelInceptionDistance
# Example: plot the Kernel Inception Distance (KID) between two batches of
# synthetic images.
#
# Both batches share one tensor shape. NOTE(review): this presumably follows
# the (N, C, H, W) image layout KID expects; confirm against the metric docs.
image_shape = (30, 3, 299, 299)

# torch.randint's upper bound is exclusive, so the first batch takes values
# in [0, 199] and the second in [100, 254]. The ranges only partly overlap,
# which makes the two distributions differ.
imgs_dist1 = torch.randint(low=0, high=200, size=image_shape, dtype=torch.uint8)
imgs_dist2 = torch.randint(low=100, high=255, size=image_shape, dtype=torch.uint8)

# KID is estimated over 3 subsets of 20 samples each. The subset size stays
# below the 30 samples available per distribution.
metric = KernelInceptionDistance(subsets=3, subset_size=20)

# Feed the "real" batch first and the "fake" batch second, as in the
# original example.
for batch, is_real in ((imgs_dist1, True), (imgs_dist2, False)):
    metric.update(batch, real=is_real)

# plot() computes the metric and returns the figure and axes it drew on.
fig_, ax_ = metric.plot()
