Metric Tracker
Module Interface
- class torchmetrics.MetricTracker(metric, maximize=True)
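A wrapper class that helps keep track of a metric or metric collection over time and implements useful methods. The wrapper implements the standard .update(), .compute() and .reset() methods, which simply call the corresponding method of the currently tracked metric. In addition, the following methods are provided:
  - MetricTracker.n_steps: number of steps (metric instances) being tracked
  - MetricTracker.increment(): initialize a new metric to be tracked
  - MetricTracker.compute_all(): get the metric value for all steps
  - MetricTracker.best_metric(): return the best value
- Parameters
metric (Metric or MetricCollection) – the metric or metric collection to keep track of at each step.
maximize (bool or list of bool) – whether a higher value is better (True) or a lower value is better (False); when tracking a collection, one flag per metric can be given as a list.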
- Example (single metric):
>>> import torch
>>> from torchmetrics import MetricTracker
>>> from torchmetrics.classification import MulticlassAccuracy
>>> _ = torch.manual_seed(42)
>>> tracker = MetricTracker(MulticlassAccuracy(num_classes=10, average='micro'))
>>> for epoch in range(5):
...     tracker.increment()
...     for batch_idx in range(5):
...         preds, target = torch.randint(10, (100,)), torch.randint(10, (100,))
...         tracker.update(preds, target)
...     print(f"current acc={tracker.compute()}")
current acc=0.1120000034570694
current acc=0.08799999952316284
current acc=0.12600000202655792
current acc=0.07999999821186066
current acc=0.10199999809265137
>>> best_acc, which_epoch = tracker.best_metric(return_step=True)
>>> best_acc
0.1260...
>>> which_epoch
2
>>> tracker.compute_all()
tensor([0.1120, 0.0880, 0.1260, 0.0800, 0.1020])
- Example (multiple metrics using MetricCollection):
>>> import torch
>>> from torchmetrics import MetricTracker, MetricCollection, MeanSquaredError, ExplainedVariance
>>> _ = torch.manual_seed(42)
>>> tracker = MetricTracker(MetricCollection([MeanSquaredError(), ExplainedVariance()]), maximize=[False, True])
>>> for epoch in range(5):
...     tracker.increment()
...     for batch_idx in range(5):
...         preds, target = torch.randn(100), torch.randn(100)
...         tracker.update(preds, target)
...     print(f"current stats={tracker.compute()}")
current stats={'MeanSquaredError': tensor(1.8218), 'ExplainedVariance': tensor(-0.8969)}
current stats={'MeanSquaredError': tensor(2.0268), 'ExplainedVariance': tensor(-1.0206)}
current stats={'MeanSquaredError': tensor(1.9491), 'ExplainedVariance': tensor(-0.8298)}
current stats={'MeanSquaredError': tensor(1.9800), 'ExplainedVariance': tensor(-0.9199)}
current stats={'MeanSquaredError': tensor(2.2481), 'ExplainedVariance': tensor(-1.1622)}
>>> from pprint import pprint
>>> best_res, which_epoch = tracker.best_metric(return_step=True)
>>> pprint(best_res)
{'ExplainedVariance': -0.829..., 'MeanSquaredError': 1.821...}
>>> which_epoch
{'MeanSquaredError': 0, 'ExplainedVariance': 2}
>>> pprint(tracker.compute_all())
{'ExplainedVariance': tensor([-0.8969, -1.0206, -0.8298, -0.9199, -1.1622]),
 'MeanSquaredError': tensor([1.8218, 2.0268, 1.9491, 1.9800, 2.2481])}
- best_metric(return_step=False)
Returns the best metric value out of all tracked steps: the highest value if maximize is True for that metric, the lowest otherwise.
- Parameters
return_step (bool) – If True, also returns the step at which the best metric value was obtained.
- Return type
Union[None, float, Tuple[float, int], Tuple[None, None], Dict[str, Optional[float]], Tuple[Dict[str, Optional[float]], Dict[str, Optional[int]]]]
- Returns
Either a single value or a tuple, depending on the value of return_step and on the object being tracked:
  - If a single metric is being tracked and return_step=False, a single float is returned.
  - If a single metric is being tracked and return_step=True, a 2-element tuple is returned, where the first value is the optimal value and the second value is the corresponding optimal step.
  - If a metric collection is being tracked and return_step=False, a single dict is returned, whose keys correspond to the different metrics of the collection and whose values are the optimal metric values.
  - If a metric collection is being tracked and return_step=True, a 2-element tuple of dicts is returned; both dicts are keyed by the metrics of the collection, the first holding the optimal values and the second the optimal steps.
In addition, the value in all cases may be None if the underlying metric does not have a properly defined notion of being optimal.
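The selection rule above can be sketched in plain Python. This is a simplified model under stated assumptions, not the torchmetrics implementation: the per-step values are assumed to be already computed, given as a list (single metric) or a dict of lists (collection); a list of maximize flags is assumed to pair with the collection's metrics in order; the None case is not modeled; and the helper names `best_of` and `_best` are hypothetical.

```python
from typing import Dict, List, Tuple, Union


def _best(values: List[float], maximize: bool) -> Tuple[float, int]:
    # Index of the best value: highest when maximizing, lowest otherwise.
    pick = max if maximize else min
    step = pick(range(len(values)), key=lambda i: values[i])
    return values[step], step


def best_of(
    history: Union[List[float], Dict[str, List[float]]],
    maximize: Union[bool, List[bool]] = True,
    return_step: bool = False,
):
    """Hypothetical model of best_metric() on already computed per-step values."""
    if isinstance(history, dict):
        # One maximize flag per metric, paired with the collection's order (assumption).
        flags = maximize if isinstance(maximize, list) else [maximize] * len(history)
        picked = {name: _best(vals, flag) for (name, vals), flag in zip(history.items(), flags)}
        best = {name: value for name, (value, _) in picked.items()}
        if return_step:
            return best, {name: step for name, (_, step) in picked.items()}
        return best
    value, step = _best(history, maximize)
    return (value, step) if return_step else value
```

Fed the per-epoch values printed in the examples above, this model picks step 2 for the accuracy run, and steps 0 and 2 for MeanSquaredError (minimized) and ExplainedVariance (maximized), matching the documented output.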
- compute_all()
Compute the metric value for all tracked metrics.
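The examples above show that compute_all() returns one value per step: a 1-D tensor for a single metric and a dict of such tensors for a metric collection. The plain-Python sketch below models only that reshaping, with lists standing in for tensors; `stack_history` is a hypothetical helper, not part of torchmetrics.

```python
from typing import Dict, List, Union

# Result of compute() at one step: a number, or a dict for a collection.
StepResult = Union[float, Dict[str, float]]


def stack_history(per_step: List[StepResult]) -> Union[List[float], Dict[str, List[float]]]:
    """Hypothetical model of compute_all(): gather one computed result per step."""
    if per_step and isinstance(per_step[0], dict):
        # Collection: one list per metric name, ordered by step.
        return {name: [step[name] for step in per_step] for name in per_step[0]}
    # Single metric: a flat list ordered by step.
    return list(per_step)
```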
- forward(*args, **kwargs)
Calls forward of the current metric being tracked.
- Return type
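To illustrate what forwarding a call means compared with update, here is a plain-Python sketch of a hypothetical running-mean metric (`ToyMean` is not a torchmetrics class). Following the usual torchmetrics convention, its forward both accumulates the batch into the running state and returns the value for that batch alone; since MetricTracker.forward passes the call through to the metric of the current step, the same split would apply to values obtained through the tracker.

```python
from typing import List


class ToyMean:
    """Hypothetical stand-in for a metric: running mean of every value seen."""

    def __init__(self) -> None:
        self.total = 0.0
        self.count = 0

    def update(self, values: List[float]) -> None:
        # Accumulate state only; nothing is returned.
        self.total += sum(values)
        self.count += len(values)

    def compute(self) -> float:
        # Value over everything accumulated so far.
        return self.total / self.count

    def forward(self, values: List[float]) -> float:
        # Accumulate the batch, then return the batch-only value.
        self.update(values)
        return sum(values) / len(values)


metric = ToyMean()
batch_values = [metric.forward([1.0, 3.0]), metric.forward([5.0, 7.0])]
print(batch_values, metric.compute())  # [2.0, 6.0] 4.0
```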
- increment()
Creates a new instance of the input metric that will be updated next.
- Return type
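A simplified model of this bookkeeping, assuming (for illustration) that each new step is a deep copy of the untouched metric passed to the constructor, so no state leaks from one step into the next. `ToySum` and `ToyTracker` are hypothetical plain-Python stand-ins, not the torchmetrics classes; they only mirror the increment(), update(), compute() and n_steps behavior described for MetricTracker.

```python
import copy
from typing import List


class ToySum:
    """Hypothetical stand-in for a metric: sum of all values seen."""

    def __init__(self) -> None:
        self.total = 0.0

    def update(self, values: List[float]) -> None:
        self.total += sum(values)

    def compute(self) -> float:
        return self.total


class ToyTracker:
    """Hypothetical model of increment(), update(), compute() and n_steps."""

    def __init__(self, metric: ToySum) -> None:
        self._base = metric              # kept pristine, only ever copied
        self._steps: List[ToySum] = []   # one metric instance per step

    @property
    def n_steps(self) -> int:
        return len(self._steps)

    def increment(self) -> None:
        # Start a new step from a fresh copy of the original metric (assumption: deep copy).
        self._steps.append(copy.deepcopy(self._base))

    def update(self, values: List[float]) -> None:
        # Only the newest step receives data.
        self._steps[-1].update(values)

    def compute(self) -> float:
        return self._steps[-1].compute()


tracker = ToyTracker(ToySum())
for epoch_values in ([1.0, 2.0], [10.0]):
    tracker.increment()
    tracker.update(epoch_values)
print(tracker.n_steps, tracker.compute())  # 2 10.0
```

Copying the pristine base metric, rather than the previous step's instance, is what keeps each step's value independent of earlier epochs in this model.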