from torch import randint
from torchmetrics.classification import MulticlassConfusionMatrix
# Example: render the confusion matrix of a 5-class problem fed with random labels.
num_classes = 5
metric = MulticlassConfusionMatrix(num_classes=num_classes)

# Predictions are drawn before targets so the random generator is consumed in
# the same order as a single `update(randint(...), randint(...))` call would.
preds = randint(low=0, high=num_classes, size=(20,))
target = randint(low=0, high=num_classes, size=(20,))
metric.update(preds, target)

# plot() returns a pair; the figure and axes handles are kept in fig_ / ax_.
fig_, ax_ = metric.plot()
