import torch
from torchmetrics.detection import GeneralizedIntersectionOverUnion
# Example: feed one image's predictions and ground truth into a Generalized
# IoU metric, then render the result with the metric's built-in plot().


def _make_example_inputs():
    """Build the ``(preds, target)`` lists passed to ``metric.update``.

    Each list holds one dict per image; here there is a single image.
    The prediction dict carries two boxes (labels 4 and 5) with scores,
    and the target dict carries one box labelled 5.
    """
    # NOTE(review): boxes look like (x1, y1, x2, y2) corner coordinates, which
    # would match torchmetrics' default box format — confirm against the docs.
    predicted = dict(
        boxes=torch.tensor(
            [
                [296.55, 93.96, 314.97, 152.79],
                [298.55, 98.96, 314.97, 151.79],
            ]
        ),
        scores=torch.tensor([0.236, 0.56]),
        labels=torch.tensor([4, 5]),
    )
    ground_truth = dict(
        boxes=torch.tensor([[300.00, 100.00, 315.00, 150.00]]),
        labels=torch.tensor([5]),
    )
    return [predicted], [ground_truth]


preds, target = _make_example_inputs()

metric = GeneralizedIntersectionOverUnion()
metric.update(preds, target)

# plot() hands back a (figure, axes) pair; the trailing underscores mark
# them as intentionally unused in this example.
fig_, ax_ = metric.plot()
