import matplotlib.pyplot as plt
import torch

import torchmetrics

# Experiment size: ``num_steps`` tracker steps, each fed ``num_updates`` batches.
num_updates = 10  # ``tracker.update`` calls per step
num_steps = 5  # ``tracker.increment`` calls
# Short alias iterated by the update loop; deriving it from ``num_updates``
# keeps a single source of truth for the per-step batch count.
N = num_updates

# Fix the global RNG so the synthetic data, and therefore the rendered figure,
# is identical on every run.
torch.manual_seed(42)

# Per-class logit weights: ``softmax(it * w)`` is uniform at ``it == 0`` and
# shifts probability mass towards class 1 as ``it`` grows (0.8 > 0.2).
w = torch.tensor([0.2, 0.8])


def target(it: float, num_samples: int = 100) -> torch.Tensor:
    """Sample ``num_samples`` binary labels (0 or 1) for step ``it``.

    Labels are drawn with replacement from ``softmax(it * w)``, so later steps
    produce an increasingly class-1-heavy target distribution.
    """
    return torch.multinomial((it * w).softmax(dim=-1), num_samples, replacement=True)


def preds(it: float, num_samples: int = 100) -> torch.Tensor:
    """Return ``num_samples`` synthetic probabilities in [0, 1] for step ``it``.

    Scaled standard-normal noise is squashed with a sigmoid: at ``it == 0``
    every prediction is exactly 0.5, and larger ``it`` pushes values towards
    0 and 1. The predictions are drawn independently of ``target``.
    """
    return (it * torch.randn(num_samples)).sigmoid()


# ``confmat`` and ``roc`` get their own names so that their ``.plot`` methods
# can later draw the final step's confusion matrix and ROC curve; all five
# metrics are tracked together as one collection.
confmat = torchmetrics.ConfusionMatrix(task="binary")
roc = torchmetrics.ROC(task="binary")

tracker = torchmetrics.wrappers.MetricTracker(
    torchmetrics.MetricCollection(
        [
            torchmetrics.Accuracy(task="binary"),
            torchmetrics.Recall(task="binary"),
            torchmetrics.Precision(task="binary"),
            confmat,
            roc,
        ]
    )
)

# 2x2 grid: ax1 (top-left) and ax2 (top-right) each take one cell, ax3 spans
# the whole bottom row (cells 3 and 4). Axes are attached to ``fig`` explicitly
# instead of relying on pyplot's implicit "current figure".
fig = plt.figure(layout="constrained", figsize=(6.8, 4.8), dpi=500)
ax1 = fig.add_subplot(2, 2, 1)
ax2 = fig.add_subplot(2, 2, 2)
ax3 = fig.add_subplot(2, 2, (3, 4))

# Feed the tracker: begin a new tracked step, then push N synthetic batches
# generated for that step.
for step in range(num_steps):
    tracker.increment()
    for _batch in range(N):
        # ``preds`` is drawn before ``target`` to keep the same RNG call order.
        batch_preds = preds(step)
        batch_target = target(step)
        tracker.update(batch_preds, batch_target)

# Compute every tracked step. The indexing below treats the result as a
# per-step sequence of dicts keyed by metric class name.
all_results = tracker.compute_all()
final_results = all_results[-1]

# Non-scalar outputs of the last step get dedicated plots. ``roc.plot`` takes
# its curve positionally (its first parameter is not named ``val``).
confmat.plot(val=final_results["BinaryConfusionMatrix"], ax=ax1)
roc.plot(final_results["BinaryROC"], ax=ax2)

# Keep only single-element tensors for the per-step line plot; multi-element or
# non-tensor outputs (confusion matrix, ROC curve) are dropped here.
scalar_results = [
    {
        name: value
        for name, value in step_result.items()
        if isinstance(value, torch.Tensor) and value.numel() == 1
    }
    for step_result in all_results
]

tracker.plot(val=scalar_results, ax=ax3)
# ``plt.show`` runs the GUI event loop until the window is closed.
# ``Figure.show`` does not manage an event loop, so under a plain
# ``python script.py`` run the window may flash briefly or never appear.
plt.show()
