from torch import tensor
from torchmetrics.detection import ModifiedPanopticQuality
# Example: evaluate ModifiedPanopticQuality repeatedly on one fixed batch and
# plot the resulting series of values.

# Predicted panoptic map. The nested literal has shape (1, 5, 4, 2): one
# sample, a 5x4 grid, and a pair of integers per grid cell.
# NOTE(review): each pair is presumably (category_id, instance_id) -- confirm
# against the torchmetrics input specification.
preds = tensor(
    [
        [
            [[6, 0], [0, 0], [6, 0], [6, 0]],
            [[0, 0], [0, 0], [6, 0], [0, 1]],
            [[0, 0], [0, 0], [6, 0], [0, 1]],
            [[0, 0], [7, 0], [6, 0], [1, 0]],
            [[0, 0], [7, 0], [7, 0], [7, 0]],
        ]
    ]
)

# Ground-truth panoptic map, laid out exactly like ``preds``.
target = tensor(
    [
        [
            [[6, 0], [0, 1], [6, 0], [0, 1]],
            [[0, 1], [0, 1], [6, 0], [0, 1]],
            [[0, 1], [0, 1], [6, 0], [1, 0]],
            [[0, 1], [7, 0], [1, 0], [1, 0]],
            [[0, 1], [7, 0], [7, 0], [7, 0]],
        ]
    ]
)

# Ids 0 and 1 are registered as "things", ids 6 and 7 as "stuffs"; these are
# the only values that appear in the first slot of the pairs above.
metric = ModifiedPanopticQuality(things={0, 1}, stuffs={6, 7})

# Call the metric 20 times on the same (preds, target) batch and keep every
# returned value so the whole series can be plotted at once.
vals = [metric(preds, target) for _ in range(20)]

# ``plot`` is given the collected values and yields a (figure, axes) pair.
fig_, ax_ = metric.plot(vals)
