from torch import randn
from torchmetrics.regression import KLDivergence
# Example: plot a single KL-divergence value.
# Two random (10, 3) tensors are turned into row-wise probability
# distributions with softmax over the last dimension, fed to the metric,
# and the accumulated result is then plotted.
metric = KLDivergence()

# Generated in the same order as the metric's positional arguments (first p, then q)
# so the random draws match the one-line form exactly.
dist_p = randn(10, 3).softmax(dim=-1)  # each row now sums to 1
dist_q = randn(10, 3).softmax(dim=-1)  # each row now sums to 1
metric.update(dist_p, dist_q)

# plot() returns a pair, unpacked here into figure/axes handles
# (presumably matplotlib objects; confirm against the torchmetrics docs).
fig_, ax_ = metric.plot()
