import torch
from torchmetrics.detection import GeneralizedIntersectionOverUnion
# Predictions for a single image: two candidate boxes, each with a confidence
# score and a class label. Wrapped in a one-element list because the metric
# consumes a batch (one dict per image).
preds = [
    dict(
        boxes=torch.tensor(
            [
                [296.55, 93.96, 314.97, 152.79],
                [298.55, 98.96, 314.97, 151.79],
            ]
        ),
        scores=torch.tensor([0.236, 0.56]),
        labels=torch.tensor([4, 5]),
    )
]
def target():
    """Return a freshly jittered ground-truth annotation for one image.

    Every coordinate of the reference box ``[300, 100, 315, 150]`` is shifted
    by an independent integer offset drawn from ``[-10, 10)``. Each call
    therefore yields a slightly different target, and so a different metric
    value for the plot.

    Returns:
        A one-element list holding a dict with ``"boxes"`` (a ``(1, 4)`` float
        tensor ordered ``(x1, y1, x2, y2)``, the metric's default ``xyxy``
        format) and ``"labels"`` (a one-element int tensor).
    """
    jittered = torch.tensor([[300.00, 100.00, 315.00, 150.00]]) + torch.randint(-10, 10, (1, 4))
    # The box is only 15 px wide, but independent offsets can differ by up to
    # 19 px. Without this step x1 can exceed x2, giving a degenerate box and a
    # meaningless IoU. Re-sorting each corner pair keeps the box a valid
    # rectangle without changing the jitter distribution.
    top_left = torch.minimum(jittered[:, :2], jittered[:, 2:])
    bottom_right = torch.maximum(jittered[:, :2], jittered[:, 2:])
    return [
        {
            "boxes": torch.cat([top_left, bottom_right], dim=1),
            "labels": torch.tensor([5]),
        }
    ]
# Score the same predictions against 20 independently jittered targets. Each
# call returns the GIoU for that batch. The collected values are then drawn as
# a series by the metric's built-in plotting helper.
metric = GeneralizedIntersectionOverUnion()
vals = [metric(preds, target()) for _ in range(20)]
fig_, ax_ = metric.plot(vals)
