from torch import rand, randint
from torchmetrics.classification import BinaryFairness
# Example: fit a binary fairness metric on one batch of random data and plot it.
num_groups = 2   # passed positionally as BinaryFairness' first argument
n_samples = 20
metric = BinaryFairness(num_groups)

# Draw the inputs in the same order as a single inline update call would:
# predictions first, then targets, then group assignments.
preds = rand(n_samples)                     # uniform scores in [0, 1)
target = randint(2, (n_samples,))           # labels drawn from {0, 1}
groups = randint(num_groups, (n_samples,))  # group index in {0, 1} per sample
metric.update(preds, target, groups)

fig_, ax_ = metric.plot()
