import torch
from torchmetrics.audio import SignalNoiseRatio
# Minimal plotting example: feed one batch of random data to a
# default-constructed SignalNoiseRatio metric, then call its plot() method.
metric = SignalNoiseRatio()

# Two independent random 1-D tensors with 4 elements each. The first is
# drawn before the second, matching the positional order given to update().
random_preds = torch.rand(4)
random_target = torch.rand(4)
metric.update(random_preds, random_target)

# plot() returns a pair, unpacked here as the figure and its axes.
fig_, ax_ = metric.plot()
