import torch
from torchmetrics.wrappers import MultioutputWrapper
from torchmetrics.regression import R2Score
# Wrap a single R2Score so it is computed separately for each of 2 outputs;
# the count of 2 matches the last dimension of the (20, 2) tensors below.
metric = MultioutputWrapper(R2Score(), num_outputs=2)

# Random predictions and targets: 20 samples, 2 outputs each.
# preds is drawn before target, the same order as the original call arguments.
preds = torch.randn(20, 2)
target = torch.randn(20, 2)
metric.update(preds, target)

# Render the accumulated per-output scores.
# NOTE(review): plotting presumably needs matplotlib installed - confirm.
fig_, ax_ = metric.plot()
