import torch
from torchmetrics.audio import ComplexScaleInvariantSignalNoiseRatio
# Plot the value of ComplexScaleInvariantSignalNoiseRatio computed on a single
# batch of random inputs.
metric = ComplexScaleInvariantSignalNoiseRatio()

# Prediction and target share the shape (1, 257, 100, 2). The trailing size-2
# axis is presumably read as (real, imag) for real-valued input; confirm
# against the torchmetrics documentation. Predictions are drawn before the
# target, matching the original RNG order.
spec_shape = (1, 257, 100, 2)
preds = torch.rand(spec_shape)
target = torch.rand(spec_shape)
metric.update(preds, target)

# plot() returns a (figure, axes) pair, presumably matplotlib objects.
# Both are kept but unused here.
fig_, ax_ = metric.plot()
