Metric Tracker

Module Interface

class torchmetrics.wrappers.MetricTracker(metric, maximize=True)[source]

A wrapper class that helps keep track of a metric or metric collection over time.

The wrapper implements the standard .update(), .compute(), .reset() methods, which simply call the corresponding method of the currently tracked metric. However, the following additional methods are provided:

  • MetricTracker.n_steps: number of metrics being tracked

  • MetricTracker.increment(): initialize a new metric for being tracked

  • MetricTracker.compute_all(): get the metric value for all steps

  • MetricTracker.best_metric(): returns the best value

Out of the box, this wrapper class fully supports tracking a single Metric, a MetricCollection, or another metric wrapper wrapped around a metric. However, multiple layers of nesting, such as a Metric inside a MetricWrapper inside a MetricCollection, are not fully supported; in particular, the .best_metric method cannot automatically compute the best metric and index for such nested structures.

Parameters:
  • metric (Union[Metric, MetricCollection]) – instance of a torchmetrics.Metric or torchmetrics.MetricCollection to keep track of at each timestep.

  • maximize (Union[bool, List[bool]]) – either a single bool or a list of bool indicating whether higher metric values are better (True) or lower values are better (False).

Example (single metric):
>>> import torch
>>> from torchmetrics.wrappers import MetricTracker
>>> from torchmetrics.classification import MulticlassAccuracy
>>> _ = torch.manual_seed(42)
>>> tracker = MetricTracker(MulticlassAccuracy(num_classes=10, average='micro'))
>>> for epoch in range(5):
...     tracker.increment()
...     for batch_idx in range(5):
...         preds, target = torch.randint(10, (100,)), torch.randint(10, (100,))
...         tracker.update(preds, target)
...     print(f"current acc={tracker.compute()}")
current acc=0.1120000034570694
current acc=0.08799999952316284
current acc=0.12600000202655792
current acc=0.07999999821186066
current acc=0.10199999809265137
>>> best_acc, which_epoch = tracker.best_metric(return_step=True)
>>> best_acc  
0.1260...
>>> which_epoch
2
>>> tracker.compute_all()
tensor([0.1120, 0.0880, 0.1260, 0.0800, 0.1020])
Example (multiple metrics using MetricCollection):
>>> import torch
>>> from torchmetrics.wrappers import MetricTracker
>>> from torchmetrics import MetricCollection
>>> from torchmetrics.regression import MeanSquaredError, ExplainedVariance
>>> _ = torch.manual_seed(42)
>>> tracker = MetricTracker(MetricCollection([MeanSquaredError(), ExplainedVariance()]), maximize=[False, True])
>>> for epoch in range(5):
...     tracker.increment()
...     for batch_idx in range(5):
...         preds, target = torch.randn(100), torch.randn(100)
...         tracker.update(preds, target)
...     print(f"current stats={tracker.compute()}")  
current stats={'MeanSquaredError': tensor(1.8218), 'ExplainedVariance': tensor(-0.8969)}
current stats={'MeanSquaredError': tensor(2.0268), 'ExplainedVariance': tensor(-1.0206)}
current stats={'MeanSquaredError': tensor(1.9491), 'ExplainedVariance': tensor(-0.8298)}
current stats={'MeanSquaredError': tensor(1.9800), 'ExplainedVariance': tensor(-0.9199)}
current stats={'MeanSquaredError': tensor(2.2481), 'ExplainedVariance': tensor(-1.1622)}
>>> from pprint import pprint
>>> best_res, which_epoch = tracker.best_metric(return_step=True)
>>> pprint(best_res)  
{'ExplainedVariance': -0.829...,
 'MeanSquaredError': 1.821...}
>>> which_epoch
{'MeanSquaredError': 0, 'ExplainedVariance': 2}
>>> pprint(tracker.compute_all())
{'ExplainedVariance': tensor([-0.8969, -1.0206, -0.8298, -0.9199, -1.1622]),
 'MeanSquaredError': tensor([1.8218, 2.0268, 1.9491, 1.9800, 2.2481])}
best_metric(return_step=False)[source]

Return the best metric value out of all tracked metrics (the highest if maximize=True, the lowest if maximize=False).

Parameters:

return_step (bool) – If True, will also return the step at which the best metric value occurred.

Return type:

Union[None, float, Tuple[float, int], Tuple[None, None], Dict[str, Optional[float]], Tuple[Dict[str, Optional[float]], Dict[str, Optional[int]]]]

Returns:

Either a single value or a tuple, depending on the value of return_step and the object being tracked.

  • If a single metric is being tracked and return_step=False then a single tensor will be returned

  • If a single metric is being tracked and return_step=True then a 2-element tuple will be returned, where the first value is the optimal value and the second value is the corresponding optimal step

  • If a metric collection is being tracked and return_step=False then a single dict will be returned, where the keys correspond to the different metrics of the collection and the values are the optimal metric values

  • If a metric collection is being tracked and return_step=True then a 2-element tuple will be returned, where each element is a dict with keys corresponding to the different metrics of the collection; the values of the first dict are the optimal values and the values of the second dict are the optimal steps

In addition, the value may in all cases be None if the underlying metric does not have a properly defined way of being optimal, or if a nested structure of metrics is being tracked.
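
As an illustration of how maximize interacts with this method, the following minimal sketch tracks MeanSquaredError with maximize=False, so best_metric returns the lowest tracked value (outputs are omitted since they depend on the random data):
>>> import torch
>>> from torchmetrics.wrappers import MetricTracker
>>> from torchmetrics.regression import MeanSquaredError
>>> tracker = MetricTracker(MeanSquaredError(), maximize=False)
>>> for epoch in range(3):
...     tracker.increment()
...     tracker.update(torch.randn(100), torch.randn(100))
>>> best_mse, best_step = tracker.best_metric(return_step=True)  # lowest MSE and the step it occurred at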

compute_all()[source]

Compute the metric value for all tracked metrics.

Return type:

Any

Returns:

By default, the results from all increments are stacked into a single tensor if the tracked base object is a single metric. If a metric collection is being tracked, a dict of stacked tensors is returned instead. If the stacking process fails, a list of the computed results is returned.

Raises:

ValueError – If self.increment has not been called before this method is called.
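
For a single scalar metric, the stacked result has one entry per call to increment(). A minimal sketch (exact values omitted since they depend on the random data):
>>> import torch
>>> from torchmetrics.wrappers import MetricTracker
>>> from torchmetrics.classification import BinaryAccuracy
>>> tracker = MetricTracker(BinaryAccuracy())
>>> for epoch in range(4):
...     tracker.increment()
...     tracker.update(torch.randint(2, (10,)), torch.randint(2, (10,)))
>>> tracker.compute_all().shape  # one accuracy value per increment
torch.Size([4])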

forward(*args, **kwargs)[source]

Call forward of the current metric being tracked.

Return type:

None
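
A minimal sketch: forward passes its arguments straight to the forward of the metric for the current step, so it can be used in place of update inside a training loop (the return value is ignored here):
>>> import torch
>>> from torchmetrics.wrappers import MetricTracker
>>> from torchmetrics.classification import BinaryAccuracy
>>> tracker = MetricTracker(BinaryAccuracy())
>>> tracker.increment()
>>> _ = tracker.forward(torch.randint(2, (10,)), torch.randint(2, (10,)))  # forwards the batch to the current metric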

increment()[source]

Create a new instance of the input metric that will be updated next.

Return type:

None

plot(val=None, ax=None)[source]

Plot a single or multiple values from the metric.

Parameters:
  • val (Union[Tensor, Sequence[Tensor], None]) – Either a single result from calling metric.forward or metric.compute or a list of these results. If no value is provided, will automatically call metric.compute and plot that result.

  • ax (Optional[Axes]) – A matplotlib axis object. If provided, will add the plot to that axis.

Return type:

Tuple[Figure, Union[Axes, ndarray]]

Returns:

Figure and Axes object

Raises:

ModuleNotFoundError – If matplotlib is not installed

>>> # Example plotting a single value
>>> import torch
>>> from torchmetrics.wrappers import MetricTracker
>>> from torchmetrics.classification import BinaryAccuracy
>>> tracker = MetricTracker(BinaryAccuracy())
>>> for epoch in range(5):
...     tracker.increment()
...     for batch_idx in range(5):
...         tracker.update(torch.randint(2, (10,)), torch.randint(2, (10,)))
>>> fig_, ax_ = tracker.plot()  # plot all epochs
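
To draw into an existing figure, an axis can also be passed explicitly. A minimal sketch reusing the tracker from the example above (assuming matplotlib is installed):
>>> # Example plotting into a user-provided axis
>>> import matplotlib.pyplot as plt
>>> fig, ax = plt.subplots()
>>> fig_, ax_ = tracker.plot(ax=ax)
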
reset()[source]

Reset the current metric being tracked.

Return type:

None

reset_all()[source]

Reset all metrics being tracked.

Return type:

None
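
A minimal sketch of the difference between the two reset methods, reusing a tracker from the examples above: reset() only clears the accumulated state of the currently tracked metric, while reset_all() clears every tracked copy:
>>> tracker.reset()      # clear the current (latest) metric only
>>> tracker.reset_all()  # clear all tracked metrics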

property n_steps: int

Returns the number of times the tracker has been incremented.
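
A minimal illustration, assuming n_steps counts each call to increment():
>>> from torchmetrics.wrappers import MetricTracker
>>> from torchmetrics.classification import BinaryAccuracy
>>> tracker = MetricTracker(BinaryAccuracy())
>>> for _ in range(3):
...     tracker.increment()
>>> tracker.n_steps
3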