import matplotlib.pyplot as plt
import torch

import torchmetrics

# Run configuration for the accuracy-over-steps plot.
num_steps = 5  # outer-loop iterations; one accuracy value is plotted per step
num_updates = 10
# Number of metric.update() calls accumulated per step (N and num_updates hold
# the same value).
N = num_updates

# Per-class weights. Sampling uses ``softmax(it * w)``: at step 0 both classes
# are equally likely, and for positive steps the mass shifts towards class 1
# because w[1] > w[0].
w = torch.tensor([0.2, 0.8])


def _sample_labels(it, num_samples=100):
    """Draw ``num_samples`` class indices (0 or 1) from ``softmax(it * w)``.

    Sampling is with replacement, so the result is a 1-D integer tensor of
    length ``num_samples``. ``w`` is read at call time.
    """
    probs = (it * w).softmax(dim=-1)
    return torch.multinomial(probs, num_samples, replacement=True)


def target(it, *, num_samples=100):
    """Return synthetic ground-truth binary labels for step ``it``."""
    return _sample_labels(it, num_samples)


def preds(it, *, num_samples=100):
    """Return synthetic binary predictions for step ``it``.

    Drawn independently of :func:`target` but from the same distribution, so
    the agreement between the two grows as the distribution becomes skewed.
    """
    return _sample_labels(it, num_samples)


# Plot binary accuracy across ``num_steps`` steps. Each plotted point accumulates
# ``num_updates`` batches before ``compute()``. The metric is then reset, so each
# point reflects only its own step. Because preds and targets are independent
# draws from the same distribution, the expected accuracy is p0**2 + p1**2. This
# is 0.5 at step 0 and rises as the distribution skews towards one class.
fig, ax = plt.subplots(1, 1, figsize=(6.8, 4.8), dpi=500)

metric = torchmetrics.Accuracy(task="binary")
values = []
for step in range(num_steps):
    for _ in range(num_updates):
        metric.update(preds(step), target(step))
    values.append(metric.compute())  # one accuracy value per step
    metric.reset()  # start the next step from a clean state
metric.plot(values, ax=ax)
# ``Figure.show`` does not run a GUI event loop, so in a plain script the window
# may be shown only briefly or not at all. ``plt.show`` runs the loop and, by
# default in non-interactive mode, blocks until the window is closed.
plt.show()
