import torch
from torchmetrics.image.inception import InceptionScore
# Example: build an InceptionScore metric, feed it one batch of random uint8
# data shaped (50, 3, 299, 299), and plot the result.
metric = InceptionScore()
# torch.randint excludes its upper bound, so 256 is needed for the samples to
# cover the whole uint8 range [0, 255]; an upper bound of 255 never yields 255.
metric.update(torch.randint(0, 256, (50, 3, 299, 299), dtype=torch.uint8))
fig_, ax_ = metric.plot()  # the returned plot only shows the mean value by default
