import torch
from torchmetrics.wrappers import MultitaskWrapper
from torchmetrics.regression import MeanSquaredError
from torchmetrics.classification import BinaryAccuracy
# Demonstrate MultitaskWrapper: one metric per task, each fed its own
# predictions/targets taken from dicts keyed by the same task names.

# Ground-truth values for each task, keyed by task name.
classification_target = torch.tensor([0, 1, 0])
regression_target = torch.tensor([2.5, 5.0, 4.0])
targets = {"Classification": classification_target, "Regression": regression_target}

# Predictions use the same task-name keys as ``targets``.
classification_preds = torch.tensor([0, 0, 1])
regression_preds = torch.tensor([3.0, 5.0, 2.5])
preds = {"Classification": classification_preds, "Regression": regression_preds}

# Map each task name to the metric that evaluates it.
metrics = MultitaskWrapper({
    "Classification": BinaryAccuracy(),
    "Regression": MeanSquaredError(),
})
metrics.update(preds, targets)

# Per-task results keyed by task name. For the inputs above:
#   Classification: accuracy 1/3 (only the first sample matches its label)
#   Regression:     MSE (0.5**2 + 0.0**2 + 1.5**2) / 3 ~= 0.8333
value = metrics.compute()

# MultitaskWrapper.plot draws every task on its own axes and returns a list
# holding one (figure, axes) pair per task -- not a single (figure, axes)
# pair. Unpacking the result as ``fig_, ax_`` only succeeds because there
# happen to be two tasks, and it binds each name to a whole per-task tuple.
# Unpack each task's pair explicitly instead.
task_plots = metrics.plot(value)
(fig_classification, ax_classification), (fig_regression, ax_regression) = task_plots
