

Modular wrapper metrics are not metrics in themselves; instead, they take a base metric and alter its internal logic.


class torchmetrics.BootStrapper(base_metric, num_bootstraps=10, mean=True, std=True, quantile=None, raw=False, sampling_strategy='poisson', compute_on_step=None, **kwargs)[source]

Turns a metric into a bootstrapped metric, which can automate the process of getting confidence intervals for metric values. This wrapper class keeps multiple copies of the same base metric in memory, and whenever update or forward is called, all input tensors are resampled (with replacement) along the first dimension.

  • base_metric (Metric) – base metric class to wrap

  • num_bootstraps (int) – number of copies to make of the base metric for bootstrapping

  • mean (bool) – if True return the mean of the bootstraps

  • std (bool) – if True return the standard deviation of the bootstraps

  • quantile (Union[float, Tensor, None]) – if given, returns the quantile of the bootstraps. Can only be used with pytorch version 1.6 or higher

  • raw (bool) – if True, return all bootstrapped values

  • sampling_strategy (str) – Determines how to produce bootstrapped samplings. Either 'poisson' or 'multinomial'. If 'poisson' is chosen, the number of times each sample is included in the bootstrap is given by n \sim Poisson(\lambda=1), which approximates the true bootstrap distribution when the number of samples is large. If 'multinomial' is chosen, true bootstrapping is applied at the batch level to approximate bootstrapping over the whole dataset.

  • compute_on_step (Optional[bool]) –

    Forward only calls update() and returns None if this is set to False.

    Deprecated since version v0.8: Argument has no use anymore and will be removed in v0.9.

  • kwargs (Dict[str, Any]) – Additional keyword arguments, see Advanced metric settings for more info.
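
The two sampling strategies described above can be illustrated with a minimal pure-Python sketch. It is not the library's implementation: the helper names poisson_draw and bootstrap_indices are hypothetical, and plain lists of row indices stand in for tensors.

```python
import math
import random


def poisson_draw(rng, lam=1.0):
    """Draw one Poisson(lam) sample using Knuth's multiplication method."""
    threshold = math.exp(-lam)
    count, product = 0, rng.random()
    while product > threshold:
        count += 1
        product *= rng.random()
    return count


def bootstrap_indices(size, sampling_strategy="poisson", rng=None):
    """Return the row indices that make up one bootstrap resample (hypothetical helper)."""
    if rng is None:
        rng = random.Random()
    if sampling_strategy == "poisson":
        # Row i is repeated n_i times with n_i ~ Poisson(1), so the length
        # of the resample varies around `size`.
        return [i for i in range(size) for _ in range(poisson_draw(rng))]
    if sampling_strategy == "multinomial":
        # Classic bootstrap: exactly `size` draws with replacement.
        return [rng.randrange(size) for _ in range(size)]
    raise ValueError(f"unknown sampling_strategy: {sampling_strategy!r}")
```

Each bootstrap copy of the base metric would then be updated with the input rows selected by one such index list.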

>>> from pprint import pprint
>>> import torch
>>> from torchmetrics import Accuracy, BootStrapper
>>> _ = torch.manual_seed(123)
>>> base_metric = Accuracy()
>>> bootstrap = BootStrapper(base_metric, num_bootstraps=20)
>>> bootstrap.update(torch.randint(5, (20,)), torch.randint(5, (20,)))
>>> output = bootstrap.compute()
>>> pprint(output)
{'mean': tensor(0.2205), 'std': tensor(0.0859)}

Computes the bootstrapped metric values.

Always returns a dict of tensors, which can contain the following keys: mean, std, quantile and raw depending on how the class was initialized.

Return type

Dict[str, Tensor]
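
As a rough illustration of how the returned dict depends on the constructor flags, the sketch below aggregates a list of per-bootstrap values in plain Python. The function names are hypothetical, and the use of the sample standard deviation and of linear interpolation for the quantile are assumptions, not statements about the library's internals.

```python
import statistics


def linear_quantile(values, q):
    """Quantile with linear interpolation between neighbouring order statistics."""
    ordered = sorted(values)
    position = q * (len(ordered) - 1)
    lower = int(position)
    upper = min(lower + 1, len(ordered) - 1)
    fraction = position - lower
    return ordered[lower] + fraction * (ordered[upper] - ordered[lower])


def summarize(values, mean=True, std=True, quantile=None, raw=False):
    """Build the output dict from one metric value per bootstrap copy."""
    output = {}
    if mean:
        output["mean"] = statistics.mean(values)
    if std:
        # Sample standard deviation (assumption); needs at least two values.
        output["std"] = statistics.stdev(values)
    if quantile is not None:
        output["quantile"] = linear_quantile(values, quantile)
    if raw:
        output["raw"] = list(values)
    return output
```

With the default flags only the mean and std keys are present, which matches the shape of the example output above.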

update(*args, **kwargs)[source]

Updates the state of the base metric.

Any tensor passed in will be bootstrapped along dimension 0.

Return type



class torchmetrics.ClasswiseWrapper(metric, labels=None)[source]

Wrapper class that alters the output of classification metrics returning multiple values so that the output includes label information.

  • metric (Metric) – base metric that should be wrapped. It is assumed that the metric outputs a single tensor that is split along the first dimension.

  • labels (Optional[List[str]]) – list of strings indicating the different classes.
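
The renaming performed by this wrapper can be sketched in plain Python. The key format (lower-cased metric name, underscore, label or class index) is inferred from the example outputs on this page; classwise_dict is a hypothetical helper, not part of torchmetrics.

```python
def classwise_dict(metric_name, values, labels=None):
    """Map one value per class to a dict keyed by '<metric>_<label>'."""
    prefix = metric_name.lower()
    if labels is None:
        # Without labels, fall back to the class index.
        labels = [str(i) for i in range(len(values))]
    if len(labels) != len(values):
        raise ValueError("expected one label per class value")
    return {f"{prefix}_{label}": value for label, value in zip(labels, values)}
```

For instance, classwise_dict("Accuracy", [0.5, 0.75, 0.0]) produces the keys accuracy_0, accuracy_1 and accuracy_2.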


>>> import torch
>>> _ = torch.manual_seed(42)
>>> from torchmetrics import Accuracy, ClasswiseWrapper
>>> metric = ClasswiseWrapper(Accuracy(num_classes=3, average=None))
>>> preds = torch.randn(10, 3).softmax(dim=-1)
>>> target = torch.randint(3, (10,))
>>> metric(preds, target)
{'accuracy_0': tensor(0.5000), 'accuracy_1': tensor(0.7500), 'accuracy_2': tensor(0.)}
Example (labels as list of strings):
>>> import torch
>>> from torchmetrics import Accuracy, ClasswiseWrapper
>>> metric = ClasswiseWrapper(
...    Accuracy(num_classes=3, average=None),
...    labels=["horse", "fish", "dog"]
... )
>>> preds = torch.randn(10, 3).softmax(dim=-1)
>>> target = torch.randint(3, (10,))
>>> metric(preds, target)
{'accuracy_horse': tensor(0.3333), 'accuracy_fish': tensor(0.6667), 'accuracy_dog': tensor(0.)}
Example (in metric collection):
>>> import torch
>>> from torchmetrics import Accuracy, ClasswiseWrapper, MetricCollection, Recall
>>> labels = ["horse", "fish", "dog"]
>>> metric = MetricCollection(
...     {'accuracy': ClasswiseWrapper(Accuracy(num_classes=3, average=None), labels),
...     'recall': ClasswiseWrapper(Recall(num_classes=3, average=None), labels)}
... )
>>> preds = torch.randn(10, 3).softmax(dim=-1)
>>> target = torch.randint(3, (10,))
>>> metric(preds, target)  
{'accuracy_horse': tensor(0.), 'accuracy_fish': tensor(0.3333), 'accuracy_dog': tensor(0.4000),
'recall_horse': tensor(0.), 'recall_fish': tensor(0.3333), 'recall_dog': tensor(0.4000)}

Override this method to compute the final metric value from state variables synchronized across the distributed backend.

Return type

Dict[str, Tensor]

update(*args, **kwargs)[source]

Override this method to update the state variables of your metric class.

Return type



class torchmetrics.MetricTracker(metric, maximize=True)[source]

A wrapper class that helps keep track of a metric or metric collection over time and implements useful methods. The wrapper implements the standard .update(), .compute() and .reset() methods, which simply call the corresponding method of the currently tracked metric. In addition, the following are provided:

  • MetricTracker.n_steps: number of metrics being tracked

  • MetricTracker.increment(): initialize a new metric for being tracked

  • MetricTracker.compute_all(): get the metric value for all steps

  • MetricTracker.best_metric(): returns the best value

  • metric (Union[Metric, MetricCollection]) – instance of a torchmetrics.Metric or torchmetrics.MetricCollection to keep track of at each timestep.

  • maximize (Union[bool, List[bool]]) – either single bool or list of bool indicating if higher metric values are better (True) or lower is better (False).
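
The bookkeeping behind these methods can be sketched in plain Python. TrackerSketch is a hypothetical stand-in that stores one finished scalar per step; its record() method replaces the update()/compute() pair of a real metric and is not part of the torchmetrics API.

```python
class TrackerSketch:
    """Keeps one scalar per step and reports the best one."""

    def __init__(self, maximize=True):
        self.maximize = maximize
        self.steps = []

    @property
    def n_steps(self):
        return len(self.steps)

    def increment(self):
        # Start tracking a new step; its value is filled in by record().
        self.steps.append(None)

    def record(self, value):
        if not self.steps:
            raise ValueError("call increment() before recording a value")
        self.steps[-1] = value

    def compute_all(self):
        return list(self.steps)

    def best_metric(self, return_step=False):
        # Higher is better when maximize is True, lower otherwise.
        pick = max if self.maximize else min
        step = pick(range(len(self.steps)), key=lambda i: self.steps[i])
        best = self.steps[step]
        return (best, step) if return_step else best
```

The sketch assumes every step has a recorded value before best_metric() is called.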

Example (single metric):
>>> import torch
>>> from torchmetrics import Accuracy, MetricTracker
>>> _ = torch.manual_seed(42)
>>> tracker = MetricTracker(Accuracy(num_classes=10))
>>> for epoch in range(5):
...     tracker.increment()
...     for batch_idx in range(5):
...         preds, target = torch.randint(10, (100,)), torch.randint(10, (100,))
...         tracker.update(preds, target)
...     print(f"current acc={tracker.compute()}")
current acc=0.1120000034570694
current acc=0.08799999952316284
current acc=0.12600000202655792
current acc=0.07999999821186066
current acc=0.10199999809265137
>>> best_acc, which_epoch = tracker.best_metric(return_step=True)
>>> best_acc  
>>> which_epoch
>>> tracker.compute_all()
tensor([0.1120, 0.0880, 0.1260, 0.0800, 0.1020])
Example (multiple metrics using MetricCollection):
>>> from torchmetrics import MetricTracker, MetricCollection, MeanSquaredError, ExplainedVariance
>>> _ = torch.manual_seed(42)
>>> tracker = MetricTracker(MetricCollection([MeanSquaredError(), ExplainedVariance()]), maximize=[False, True])
>>> for epoch in range(5):
...     tracker.increment()
...     for batch_idx in range(5):
...         preds, target = torch.randn(100), torch.randn(100)
...         tracker.update(preds, target)
...     print(f"current stats={tracker.compute()}")  
current stats={'MeanSquaredError': tensor(1.8218), 'ExplainedVariance': tensor(-0.8969)}
current stats={'MeanSquaredError': tensor(2.0268), 'ExplainedVariance': tensor(-1.0206)}
current stats={'MeanSquaredError': tensor(1.9491), 'ExplainedVariance': tensor(-0.8298)}
current stats={'MeanSquaredError': tensor(1.9800), 'ExplainedVariance': tensor(-0.9199)}
current stats={'MeanSquaredError': tensor(2.2481), 'ExplainedVariance': tensor(-1.1622)}
>>> from pprint import pprint
>>> best_res, which_epoch = tracker.best_metric(return_step=True)
>>> pprint(best_res)  
{'ExplainedVariance': -0.829...,
 'MeanSquaredError': 1.821...}
>>> which_epoch
{'MeanSquaredError': 0, 'ExplainedVariance': 2}
>>> pprint(tracker.compute_all())
{'ExplainedVariance': tensor([-0.8969, -1.0206, -0.8298, -0.9199, -1.1622]),
 'MeanSquaredError': tensor([1.8218, 2.0268, 1.9491, 1.9800, 2.2481])}

Returns the best metric value out of all tracked steps (the highest if maximize is True, otherwise the lowest).


return_step (bool) – If True, also returns the step with the best metric value.

Return type

Union[None, float, Tuple[int, float], Tuple[None, None], Dict[str, Optional[float]], Tuple[Dict[str, Optional[int]], Dict[str, Optional[float]]]]


The best metric value, and optionally the time-step.


Call compute of the current metric being tracked.

Return type



Compute the metric value for all tracked metrics.

Return type


forward(*args, **kwargs)[source]

Calls forward of the current metric being tracked.

Return type



Creates a new instance of the input metric that will be updated next.

Return type



Resets the current metric being tracked.

Return type



Resets all metrics being tracked.

Return type


update(*args, **kwargs)[source]

Updates the current metric being tracked.

Return type


property n_steps: int[source]

Returns the number of times the tracker has been incremented.

Return type



class torchmetrics.MinMaxMetric(base_metric, compute_on_step=None, **kwargs)[source]

Wrapper Metric that tracks both the minimum and maximum of a scalar/tensor across an experiment. The min/max value will be updated each time .compute is called.

  • base_metric (Metric) – The metric whose maximum and minimum values you want to track.

  • compute_on_step (Optional[bool]) –

    Forward only calls update() and returns None if this is set to False.

    Deprecated since version v0.8: Argument has no use anymore and will be removed in v0.9.

  • kwargs (Dict[str, Any]) – Additional keyword arguments, see Advanced metric settings for more info.


ValueError – If the base_metric argument is not an instance of a subclass of torchmetrics.Metric

>>> import torch
>>> from torchmetrics import Accuracy, MinMaxMetric
>>> from pprint import pprint
>>> base_metric = Accuracy()
>>> minmax_metric = MinMaxMetric(base_metric)
>>> preds_1 = torch.Tensor([[0.1, 0.9], [0.2, 0.8]])
>>> preds_2 = torch.Tensor([[0.9, 0.1], [0.2, 0.8]])
>>> labels = torch.Tensor([[0, 1], [0, 1]]).long()
>>> pprint(minmax_metric(preds_1, labels))
{'max': tensor(1.), 'min': tensor(1.), 'raw': tensor(1.)}
>>> pprint(minmax_metric.compute())
{'max': tensor(1.), 'min': tensor(1.), 'raw': tensor(1.)}
>>> minmax_metric.update(preds_2, labels)
>>> pprint(minmax_metric.compute())
{'max': tensor(1.), 'min': tensor(0.7500), 'raw': tensor(0.7500)}

Computes the underlying metric as well as max and min values for this metric.

Returns a dictionary that consists of the computed value (raw), as well as the minimum (min) and maximum (max) values.

Return type

Dict[str, Tensor]
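
The min/max bookkeeping can be sketched in plain Python as follows. MinMaxSketch is hypothetical and works on plain floats; the initialization bounds of plus and minus infinity are an assumption about what "initialization bounds" means here.

```python
class MinMaxSketch:
    """Tracks the smallest and largest value seen across compute() calls."""

    def __init__(self):
        self.reset()

    def reset(self):
        # Assumed initialization bounds: any real value replaces them.
        self.min_val = float("inf")
        self.max_val = float("-inf")

    def compute(self, raw):
        # `raw` stands in for the value the wrapped metric would return.
        self.max_val = max(self.max_val, raw)
        self.min_val = min(self.min_val, raw)
        return {"raw": raw, "max": self.max_val, "min": self.min_val}
```

Feeding it 1.0 and then 0.75 reproduces the progression of the example above: the maximum stays at 1.0 while the minimum drops to 0.75.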


Sets max_val and min_val to the initialization bounds and resets the base metric.

Return type


update(*args, **kwargs)[source]

Updates the underlying metric.

Return type



class torchmetrics.MultioutputWrapper(base_metric, num_outputs, output_dim=-1, remove_nans=True, squeeze_outputs=True)[source]

Wrap a base metric to enable it to support multiple outputs.

Several torchmetrics metrics, such as torchmetrics.regression.spearman.SpearmanCorrcoef, lack support for multioutput mode. This class wraps such metrics to support computing one metric per output. Unlike specific torchmetrics metrics, it doesn’t support any aggregation across outputs. This means that if you set num_outputs to 2, .compute() will return a Tensor of dimension (2, ...), where ... represents the dimensions the metric returns when not wrapped.

In addition to enabling multioutput support for metrics that lack it, this class also supports, albeit in a crude fashion, dealing with missing labels (or other data). When remove_nans is passed, the class will remove the intersection of NaN-containing “rows” upon each update for each output. For example, suppose a user uses MultioutputWrapper to wrap torchmetrics.regression.r2.R2Score with 2 outputs, one of which occasionally has missing labels. Because NaN values are removed on a per-output basis (parameter remove_nans), the wrapper drops the rows containing NaNs only for the affected output.

  • base_metric (Metric) – Metric being wrapped.

  • num_outputs (int) – Expected dimensionality of the output dimension. This parameter is used to determine the number of distinct metrics we need to track.

  • output_dim (int) – Dimension on which output is expected. Note that while this provides some flexibility, the output dimension must be the same for all inputs to update. This applies even for metrics such as Accuracy where the labels can have a different number of dimensions than the predictions. This can be worked around if the output dimension can be set to -1 for both, even if -1 corresponds to different dimensions in different inputs.

  • remove_nans (bool) – Whether to remove the intersection of rows containing NaNs from the values passed through to each underlying metric. Proper operation requires all tensors passed to update to have dimension (N, ...) where N represents the length of the batch or dataset being passed in.

  • squeeze_outputs (bool) – If True, will squeeze the 1-item dimensions left after index_select is applied. This is sometimes unnecessary but harmless for metrics such as R2Score but useful for certain classification metrics that can’t handle additional 1-item dimensions.
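
The per-output splitting and NaN removal described by these parameters can be sketched in plain Python. split_outputs is a hypothetical helper that uses nested lists with the output dimension last (as with output_dim=-1); it only prepares the inputs each underlying metric would receive and does not compute any metric.

```python
import math


def split_outputs(preds, target, num_outputs, remove_nans=True):
    """Return one (preds, target) pair of flat lists per output."""
    per_output = []
    for out in range(num_outputs):
        # Select column `out` from every row of both inputs.
        pairs = [(p[out], t[out]) for p, t in zip(preds, target)]
        if remove_nans:
            # Drop a row for this output if either value is NaN.
            pairs = [(p, t) for p, t in pairs
                     if not (math.isnan(p) or math.isnan(t))]
        per_output.append(([p for p, _ in pairs], [t for _, t in pairs]))
    return per_output
```

A row with a missing label in the second output is removed only from the data passed to the second metric; the first metric still sees every row.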


>>> # Mimic R2Score in `multioutput`, `raw_values` mode:
>>> import torch
>>> from torchmetrics import MultioutputWrapper, R2Score
>>> target = torch.tensor([[0.5, 1], [-1, 1], [7, -6]])
>>> preds = torch.tensor([[0, 2], [-1, 2], [8, -5]])
>>> r2score = MultioutputWrapper(R2Score(), 2)
>>> r2score(preds, target)
[tensor(0.9654), tensor(0.9082)]
>>> # Classification metric where prediction and label tensors have different shapes.
>>> from torchmetrics import BinnedAveragePrecision
>>> target = torch.tensor([[1, 2], [2, 0], [1, 2]])
>>> preds = torch.tensor([
...     [[.1, .8], [.8, .05], [.1, .15]],
...     [[.1, .1], [.2, .3], [.7, .6]],
...     [[.002, .4], [.95, .45], [.048, .15]]
... ])
>>> binned_avg_precision = MultioutputWrapper(BinnedAveragePrecision(3, thresholds=5), 2)
>>> binned_avg_precision(preds, target)
[[tensor(-0.), tensor(1.0000), tensor(1.0000)], [tensor(0.3333), tensor(-0.), tensor(0.6667)]]

Initializes internal Module state, shared by both nn.Module and ScriptModule.


Compute metrics.

Return type


forward(*args, **kwargs)[source]

Call underlying forward methods and aggregate the results if they’re non-null.

We override this method to ensure that state variables get copied over on the underlying metrics.

Return type



Reset all underlying metrics.

Return type


update(*args, **kwargs)[source]

Update each underlying metric with the corresponding output.

Return type
