from torch import tensor
from torchmetrics.detection import PanopticQuality
preds = tensor([[[[6, 0], [0, 0], [6, 0], [6, 0]],
                 [[0, 0], [0, 0], [6, 0], [0, 1]],
                 [[0, 0], [0, 0], [6, 0], [0, 1]],
                 [[0, 0], [7, 0], [6, 0], [1, 0]],
                 [[0, 0], [7, 0], [7, 0], [7, 0]]]])
target = tensor([[[[6, 0], [0, 1], [6, 0], [0, 1]],
                  [[0, 1], [0, 1], [6, 0], [0, 1]],
                  [[0, 1], [0, 1], [6, 0], [1, 0]],
                  [[0, 1], [7, 0], [1, 0], [1, 0]],
                  [[0, 1], [7, 0], [7, 0], [7, 0]]]])
metric = PanopticQuality(things = {0, 1}, stuffs = {6, 7})
metric.update(preds, target)
fig_, ax_ = metric.plot()
