import torch
from torchmetrics.wrappers import BootStrapper
from torchmetrics.regression import MeanSquaredError
# Example: bootstrapped mean squared error on random data, then plotted.
# The MeanSquaredError metric is wrapped in a BootStrapper with
# num_bootstraps=20 (the number of bootstrap resamples, per the torchmetrics docs).
metric = BootStrapper(MeanSquaredError(), num_bootstraps=20)

# One batch of 100 standard-normal predictions and 100 standard-normal targets.
# Drawn in this order (preds, then target) to match the original call order.
preds = torch.randn(100)
target = torch.randn(100)
metric.update(preds, target)

# plot() returns two objects, unpacked here; presumably (figure, axes) —
# see the torchmetrics plotting docs.
fig_, ax_ = metric.plot()
