import torch
from torchmetrics.image.inception import InceptionScore
metric = InceptionScore()
# Score three batches of 50 random 3x299x299 images and collect one value per batch.
# `torch.randint` excludes its upper bound, so `high=256` is required for the
# generated pixels to span the full uint8 range 0..255 (255 would never appear).
# We index each result by 0 so that only the mean value is plotted.
values = [
    metric(torch.randint(0, 256, (50, 3, 299, 299), dtype=torch.uint8))[0]
    for _ in range(3)
]
# `plot` returns a (figure, axes) pair; the trailing underscores mark them as unused here.
fig_, ax_ = metric.plot(values)
