
Metric Tracker

Module Interface

class torchmetrics.MetricTracker(metric, maximize=True)

A wrapper class that helps keep track of a metric or metric collection over time and adds useful methods on top of it. The wrapper implements the standard .update(), .compute() and .reset() methods, which simply call the corresponding method of the currently tracked metric. In addition, it provides the following:

  • MetricTracker.n_steps: number of metrics (steps) being tracked
  • MetricTracker.increment(): initialize a new metric to be tracked
  • MetricTracker.compute_all(): get the metric value for all steps
  • MetricTracker.best_metric(): return the best value
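To make the delegation pattern concrete, here is a minimal pure-Python sketch of a tracker wrapper. It is not the torchmetrics implementation: ToyMeanMetric and ToyTracker are hypothetical stand-ins, and the only assumption carried over from the description above is that each step gets its own fresh copy of the wrapped metric, while update() and compute() go to the latest copy.

```python
import copy


class ToyMeanMetric:
    """Hypothetical metric: running mean of the values passed to update()."""

    def __init__(self):
        self.reset()

    def update(self, values):
        self.total += sum(values)
        self.count += len(values)

    def compute(self):
        return self.total / self.count

    def reset(self):
        self.total = 0.0
        self.count = 0


class ToyTracker:
    """Hypothetical tracker: one independent metric copy per step."""

    def __init__(self, metric):
        self._prototype = metric  # template; never updated directly
        self._history = []        # one metric instance per step

    @property
    def n_steps(self):
        return len(self._history)

    def increment(self):
        # A deep copy gives the new step its own, untouched state.
        self._history.append(copy.deepcopy(self._prototype))

    def update(self, *args, **kwargs):
        # Delegate to the metric of the current (latest) step.
        self._history[-1].update(*args, **kwargs)

    def compute(self):
        return self._history[-1].compute()

    def compute_all(self):
        return [m.compute() for m in self._history]


tracker = ToyTracker(ToyMeanMetric())
for step_values in ([1.0, 3.0], [4.0, 6.0]):
    tracker.increment()
    tracker.update(step_values)
print(tracker.n_steps, tracker.compute(), tracker.compute_all())  # 2 5.0 [2.0, 5.0]
```

Keeping an untouched prototype and copying it on every increment means earlier steps stay frozen once the tracker moves on.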

Parameters
  • metric (Union[Metric, MetricCollection]) – instance of a torchmetrics.Metric or torchmetrics.MetricCollection to keep track of at each timestep.

  • maximize (Union[bool, List[bool]]) – either a single bool or a list of bools (one per metric in a MetricCollection) indicating whether higher metric values are better (True) or lower values are better (False).
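For a single metric, the maximize flag boils down to choosing an argmax or an argmin over the per-step values. The sketch below illustrates that selection rule with a hypothetical helper, best_step, applied to the rounded per-epoch accuracies from the single-metric example; it is an illustration of the idea, not the library's code.

```python
def best_step(values, maximize=True):
    """Return (best_value, step) for a list of per-step metric values."""
    pick = max if maximize else min
    # Ties resolve to the earliest step, since range() is scanned in order.
    step = pick(range(len(values)), key=lambda i: values[i])
    return values[step], step


# Per-epoch accuracies from the single-metric example (rounded).
acc = [0.1120, 0.0880, 0.1260, 0.0800, 0.1020]
print(best_step(acc, maximize=True))   # (0.126, 2): higher is better
print(best_step(acc, maximize=False))  # (0.08, 3): lower is better
```

With maximize=True the selected step (2) matches the which_epoch value reported in the example.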

Example (single metric):
>>> import torch
>>> from torchmetrics import Accuracy, MetricTracker
>>> _ = torch.manual_seed(42)
>>> tracker = MetricTracker(Accuracy(num_classes=10))
>>> for epoch in range(5):
...     tracker.increment()
...     for batch_idx in range(5):
...         preds, target = torch.randint(10, (100,)), torch.randint(10, (100,))
...         tracker.update(preds, target)
...     print(f"current acc={tracker.compute()}")
current acc=0.1120000034570694
current acc=0.08799999952316284
current acc=0.12600000202655792
current acc=0.07999999821186066
current acc=0.10199999809265137
>>> best_acc, which_epoch = tracker.best_metric(return_step=True)
>>> best_acc  
0.1260...
>>> which_epoch
2
>>> tracker.compute_all()
tensor([0.1120, 0.0880, 0.1260, 0.0800, 0.1020])

Example (multiple metrics using MetricCollection):
>>> import torch
>>> from torchmetrics import MetricTracker, MetricCollection, MeanSquaredError, ExplainedVariance
>>> _ = torch.manual_seed(42)
>>> tracker = MetricTracker(MetricCollection([MeanSquaredError(), ExplainedVariance()]), maximize=[False, True])
>>> for epoch in range(5):
...     tracker.increment()
...     for batch_idx in range(5):
...         preds, target = torch.randn(100), torch.randn(100)
...         tracker.update(preds, target)
...     print(f"current stats={tracker.compute()}")  
current stats={'MeanSquaredError': tensor(1.8218), 'ExplainedVariance': tensor(-0.8969)}
current stats={'MeanSquaredError': tensor(2.0268), 'ExplainedVariance': tensor(-1.0206)}
current stats={'MeanSquaredError': tensor(1.9491), 'ExplainedVariance': tensor(-0.8298)}
current stats={'MeanSquaredError': tensor(1.9800), 'ExplainedVariance': tensor(-0.9199)}
current stats={'MeanSquaredError': tensor(2.2481), 'ExplainedVariance': tensor(-1.1622)}
>>> from pprint import pprint
>>> best_res, which_epoch = tracker.best_metric(return_step=True)
>>> pprint(best_res)  
{'ExplainedVariance': -0.829...,
 'MeanSquaredError': 1.821...}
>>> which_epoch
{'MeanSquaredError': 0, 'ExplainedVariance': 2}
>>> pprint(tracker.compute_all())
{'ExplainedVariance': tensor([-0.8969, -1.0206, -0.8298, -0.9199, -1.1622]),
 'MeanSquaredError': tensor([1.8218, 2.0268, 1.9491, 1.9800, 2.2481])}


best_metric(return_step=False)

Returns the best metric value out of all tracked steps: the highest if maximize=True, the lowest if maximize=False.

Parameters

return_step (bool) – If True, also return the step at which the best metric value was achieved.

Return type

Union[None, float, Tuple[int, float], Tuple[None, None], Dict[str, Optional[float]], Tuple[Dict[str, Optional[int]], Dict[str, Optional[float]]]]

Returns

The best metric value, and optionally the time-step.
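For a MetricCollection, the best value and its step are reported per metric, as the second example shows. The following sketch reproduces that per-key behaviour with a hypothetical best_metric function over plain dicts of lists. Two assumptions differ from the library: maximize is given as a dict keyed by metric name (the real parameter is a list), and values are plain floats rather than tensors.

```python
def best_metric(history, maximize, return_step=False):
    """history maps metric name -> list of per-step values;
    maximize maps metric name -> bool (True if higher is better)."""
    best, steps = {}, {}
    for name, values in history.items():
        pick = max if maximize[name] else min
        step = pick(range(len(values)), key=lambda i: values[i])
        best[name], steps[name] = values[step], step
    return (best, steps) if return_step else best


# Values copied from the MetricCollection example's compute_all() output.
history = {
    "MeanSquaredError": [1.8218, 2.0268, 1.9491, 1.9800, 2.2481],
    "ExplainedVariance": [-0.8969, -1.0206, -0.8298, -0.9199, -1.1622],
}
maximize = {"MeanSquaredError": False, "ExplainedVariance": True}
best, steps = best_metric(history, maximize, return_step=True)
print(best)   # {'MeanSquaredError': 1.8218, 'ExplainedVariance': -0.8298}
print(steps)  # {'MeanSquaredError': 0, 'ExplainedVariance': 2}
```

The resulting steps agree with the which_epoch dict printed in the example: MSE is best (lowest) at step 0, explained variance is best (highest) at step 2.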

compute()

Calls compute() of the metric currently being tracked.

Return type

Any

compute_all()

Computes the metric value for all tracked steps: one value per step, or a dict of per-step values keyed by metric name when tracking a MetricCollection.

Return type

Tensor
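As the examples show, compute_all() turns the per-step results into one sequence per metric. The hypothetical helper stack_steps below sketches that regrouping; it uses plain Python lists instead of torch tensors to stay dependency-free, so it illustrates the output shape rather than the library's exact return type.

```python
def stack_steps(per_step_results):
    """Turn a list of per-step results into per-metric sequences.

    Each element is either a float (single metric) or a dict of floats
    (collection); dicts are regrouped into one list per metric name.
    """
    if per_step_results and isinstance(per_step_results[0], dict):
        return {
            name: [step[name] for step in per_step_results]
            for name in per_step_results[0]
        }
    return list(per_step_results)


single = stack_steps([0.5, 0.75])
collection = stack_steps([{"mse": 1.5, "ev": -0.25}, {"mse": 2.0, "ev": -0.5}])
print(single)      # [0.5, 0.75]
print(collection)  # {'mse': [1.5, 2.0], 'ev': [-0.25, -0.5]}
```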

forward(*args, **kwargs)

Calls forward() of the metric currently being tracked.

Return type

None

increment()

Creates a new instance of the input metric that will be updated next.

Return type

None

reset()

Resets the current metric being tracked.

Return type

None

reset_all()

Resets all metrics being tracked.

Return type

None
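The difference between reset() and reset_all() can be illustrated with a list holding one metric per tracked step. This is a toy sketch: ToyCounter is a hypothetical metric, and it assumes reset_all() resets each tracked metric in place rather than discarding the tracked steps.

```python
class ToyCounter:
    """Hypothetical metric that counts update() calls."""

    def __init__(self):
        self.count = 0

    def update(self):
        self.count += 1

    def reset(self):
        self.count = 0


history = [ToyCounter() for _ in range(3)]  # one metric per tracked step
for metric in history:
    metric.update()

# reset(): only the current (latest) step's metric is cleared.
history[-1].reset()
print([m.count for m in history])  # [1, 1, 0]

# reset_all(): every tracked step's metric is cleared.
for metric in history:
    metric.reset()
print([m.count for m in history])  # [0, 0, 0]
```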

update(*args, **kwargs)

Updates the current metric being tracked.

Return type

None

property n_steps: int

Returns the number of times the tracker has been incremented.

Return type

int