import matplotlib.pyplot as plt
import torch

import torchmetrics

N = 10
num_updates = 10
num_steps = 5

w = torch.tensor([0.2, 0.8])
target = lambda it: torch.multinomial((it * w).softmax(dim=-1), 100, replacement=True)
preds = lambda it: torch.multinomial((it * w).softmax(dim=-1), 100, replacement=True)

collection = torchmetrics.MetricCollection(
    torchmetrics.Accuracy(task="binary"),
    torchmetrics.Recall(task="binary"),
    torchmetrics.Precision(task="binary"),
)

values = []
fig, ax = plt.subplots(1, 1, figsize=(6.8, 4.8), dpi=500)
for step in range(num_steps):
    for _ in range(N):
        collection.update(preds(step), target(step))
    values.append(collection.compute())
    collection.reset()
collection.plot(val=values, ax=ax, together=True)
fig.tight_layout()
fig.show()
