torchmetrics.wrappers

Modular wrapper metrics are not metrics in themselves; instead, they wrap a base metric and alter its internal logic.

BootStrapper

class torchmetrics.BootStrapper(base_metric, num_bootstraps=10, mean=True, std=True, quantile=None, raw=False, sampling_strategy='poisson', compute_on_step=None, **kwargs)[source]

Turns a Metric into a bootstrapped metric that can automate the process of getting confidence intervals for metric values.

This wrapper class keeps multiple copies of the same base metric in memory. Whenever update or forward is called, all input tensors are resampled (with replacement) along the first dimension before being passed to each copy.

Parameters
  • base_metric (Metric) – base metric class to wrap

  • num_bootstraps (int) – number of copies to make of the base metric for bootstrapping

  • mean (bool) – if True return the mean of the bootstraps

  • std (bool) – if True return the standard deviation of the bootstraps

  • quantile (Union[float, Tensor, None]) – if given, returns the quantile of the bootstraps. Can only be used with PyTorch version 1.6 or higher.

  • raw (bool) – if True, return all bootstrapped values

  • sampling_strategy (str) – Determines how to produce bootstrapped samplings. Either 'poisson' or 'multinomial'. If 'poisson' is chosen, the number of times each sample is included in a given bootstrap is drawn as n ~ Poisson(λ=1), which approximates the true bootstrap distribution when the number of samples is large (a conceptual sketch follows this parameter list). If 'multinomial' is chosen, true bootstrapping is applied at the batch level to approximate bootstrapping over the whole dataset.

  • compute_on_step (Optional[bool]) –

    Forward only calls update() and returns None if this is set to False.

    Deprecated since version v0.8: Argument has no use anymore and will be removed v0.9.

  • kwargs (Dict[str, Any]) – Additional keyword arguments, see Advanced metric settings for more info.
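
The 'poisson' strategy mentioned above can be pictured as drawing a per-sample count and repeating each sample that many times. The following is only a conceptual sketch of that idea, not the wrapper's internal code:

>>> import torch
>>> _ = torch.manual_seed(0)
>>> batch = torch.arange(6)                             # six samples along dimension 0
>>> counts = torch.poisson(torch.ones(len(batch)))      # n ~ Poisson(lambda=1) per sample
>>> resampled = batch.repeat_interleave(counts.long())  # samples seen by one bootstrap copy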

Example::
>>> import torch
>>> from pprint import pprint
>>> from torchmetrics import Accuracy, BootStrapper
>>> _ = torch.manual_seed(123)
>>> base_metric = Accuracy()
>>> bootstrap = BootStrapper(base_metric, num_bootstraps=20)
>>> bootstrap.update(torch.randint(5, (20,)), torch.randint(5, (20,)))
>>> output = bootstrap.compute()
>>> pprint(output)
{'mean': tensor(0.2205), 'std': tensor(0.0859)}
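
The defaults shown above return mean and std. As a further sketch, passing quantile and raw adds the corresponding keys to the output dict (actual values depend on the random resampling, so only keys and shapes are shown):

>>> bootstrap = BootStrapper(
...     Accuracy(), num_bootstraps=20, quantile=torch.tensor([0.05, 0.95]), raw=True
... )
>>> bootstrap.update(torch.randint(5, (20,)), torch.randint(5, (20,)))
>>> output = bootstrap.compute()
>>> sorted(output.keys())
['mean', 'quantile', 'raw', 'std']
>>> output['raw'].shape
torch.Size([20])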

compute()[source]

Computes the bootstrapped metric values.

Always returns a dict of tensors, which can contain the following keys: mean, std, quantile and raw depending on how the class was initialized.

Return type

Dict[str, Tensor]

update(*args, **kwargs)[source]

Updates the state of the base metric.

Any tensor passed in will be bootstrapped along dimension 0.

Return type

None

ClasswiseWrapper

class torchmetrics.ClasswiseWrapper(metric, labels=None)[source]

Wrapper class for altering the output of classification metrics that return multiple values to include label information.

Parameters
  • metric (Metric) – base metric that should be wrapped. It is assumed that the metric outputs a single tensor that is split along the first dimension.

  • labels (Optional[List[str]]) – list of strings indicating the different classes.

Example

>>> import torch
>>> _ = torch.manual_seed(42)
>>> from torchmetrics import Accuracy, ClasswiseWrapper
>>> metric = ClasswiseWrapper(Accuracy(num_classes=3, average=None))
>>> preds = torch.randn(10, 3).softmax(dim=-1)
>>> target = torch.randint(3, (10,))
>>> metric(preds, target)
{'accuracy_0': tensor(0.5000), 'accuracy_1': tensor(0.7500), 'accuracy_2': tensor(0.)}
Example (labels as list of strings):
>>> import torch
>>> from torchmetrics import Accuracy, ClasswiseWrapper
>>> metric = ClasswiseWrapper(
...    Accuracy(num_classes=3, average=None),
...    labels=["horse", "fish", "dog"]
... )
>>> preds = torch.randn(10, 3).softmax(dim=-1)
>>> target = torch.randint(3, (10,))
>>> metric(preds, target)
{'accuracy_horse': tensor(0.3333), 'accuracy_fish': tensor(0.6667), 'accuracy_dog': tensor(0.)}
Example (in metric collection):
>>> import torch
>>> from torchmetrics import Accuracy, ClasswiseWrapper, MetricCollection, Recall
>>> labels = ["horse", "fish", "dog"]
>>> metric = MetricCollection(
...     {'accuracy': ClasswiseWrapper(Accuracy(num_classes=3, average=None), labels),
...     'recall': ClasswiseWrapper(Recall(num_classes=3, average=None), labels)}
... )
>>> preds = torch.randn(10, 3).softmax(dim=-1)
>>> target = torch.randint(3, (10,))
>>> metric(preds, target)  
{'accuracy_horse': tensor(0.), 'accuracy_fish': tensor(0.3333), 'accuracy_dog': tensor(0.4000),
'recall_horse': tensor(0.), 'recall_fish': tensor(0.3333), 'recall_dog': tensor(0.4000)}

compute()[source]

Computes the wrapped metric and returns a dict mapping each class label to its value.

Return type

Dict[str, Tensor]

update(*args, **kwargs)[source]

Updates the state of the wrapped metric.

Return type

None

MetricTracker

class torchmetrics.MetricTracker(metric, maximize=True)[source]

A wrapper class that helps keep track of a metric or metric collection over time and implements useful methods. The wrapper implements the standard .update(), .compute() and .reset() methods, which just call the corresponding method of the currently tracked metric. In addition, the following methods are provided:

  • MetricTracker.n_steps: number of metrics being tracked
  • MetricTracker.increment(): initialize a new metric for being tracked
  • MetricTracker.compute_all(): get the metric value for all steps
  • MetricTracker.best_metric(): returns the best value

Parameters
  • metric (Union[Metric, MetricCollection]) – instance of a torchmetrics.Metric or torchmetrics.MetricCollection to keep track of at each timestep.

  • maximize (Union[bool, List[bool]]) – either single bool or list of bool indicating if higher metric values are better (True) or lower is better (False).

Example (single metric):
>>> import torch
>>> from torchmetrics import Accuracy, MetricTracker
>>> _ = torch.manual_seed(42)
>>> tracker = MetricTracker(Accuracy(num_classes=10))
>>> for epoch in range(5):
...     tracker.increment()
...     for batch_idx in range(5):
...         preds, target = torch.randint(10, (100,)), torch.randint(10, (100,))
...         tracker.update(preds, target)
...     print(f"current acc={tracker.compute()}")
current acc=0.1120000034570694
current acc=0.08799999952316284
current acc=0.12600000202655792
current acc=0.07999999821186066
current acc=0.10199999809265137
>>> best_acc, which_epoch = tracker.best_metric(return_step=True)
>>> best_acc  
0.1260...
>>> which_epoch
2
>>> tracker.compute_all()
tensor([0.1120, 0.0880, 0.1260, 0.0800, 0.1020])
Example (multiple metrics using MetricCollection):
>>> from torchmetrics import MetricTracker, MetricCollection, MeanSquaredError, ExplainedVariance
>>> _ = torch.manual_seed(42)
>>> tracker = MetricTracker(MetricCollection([MeanSquaredError(), ExplainedVariance()]), maximize=[False, True])
>>> for epoch in range(5):
...     tracker.increment()
...     for batch_idx in range(5):
...         preds, target = torch.randn(100), torch.randn(100)
...         tracker.update(preds, target)
...     print(f"current stats={tracker.compute()}")  
current stats={'MeanSquaredError': tensor(1.8218), 'ExplainedVariance': tensor(-0.8969)}
current stats={'MeanSquaredError': tensor(2.0268), 'ExplainedVariance': tensor(-1.0206)}
current stats={'MeanSquaredError': tensor(1.9491), 'ExplainedVariance': tensor(-0.8298)}
current stats={'MeanSquaredError': tensor(1.9800), 'ExplainedVariance': tensor(-0.9199)}
current stats={'MeanSquaredError': tensor(2.2481), 'ExplainedVariance': tensor(-1.1622)}
>>> from pprint import pprint
>>> best_res, which_epoch = tracker.best_metric(return_step=True)
>>> pprint(best_res)  
{'ExplainedVariance': -0.829...,
 'MeanSquaredError': 1.821...}
>>> which_epoch
{'MeanSquaredError': 0, 'ExplainedVariance': 2}
>>> pprint(tracker.compute_all())
{'ExplainedVariance': tensor([-0.8969, -1.0206, -0.8298, -0.9199, -1.1622]),
 'MeanSquaredError': tensor([1.8218, 2.0268, 1.9491, 1.9800, 2.2481])}
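
The examples above cover increment, update, compute, compute_all and best_metric. Below is a short sketch of the remaining bookkeeping helpers, n_steps and reset_all (the data here is illustrative only):

>>> import torch
>>> from torchmetrics import Accuracy, MetricTracker
>>> tracker = MetricTracker(Accuracy(num_classes=2))
>>> for _ in range(3):
...     tracker.increment()
...     tracker.update(torch.tensor([1, 0, 1]), torch.tensor([1, 1, 1]))
>>> steps = tracker.n_steps   # number of times the tracker was incremented, i.e. 3 here
>>> tracker.reset_all()       # clears the accumulated state of every tracked copy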

best_metric(return_step=False)[source]

Returns the best metric value out of all tracked values (the highest if maximize is True, the lowest otherwise).

Parameters

return_step (bool) – If True, will also return the step with the best metric value.

Return type

Union[None, float, Tuple[int, float], Tuple[None, None], Dict[str, Optional[float]], Tuple[Dict[str, Optional[int]], Dict[str, Optional[float]]]]

Returns

The best metric value, and optionally the time-step.

compute()[source]

Call compute of the current metric being tracked.

Return type

Any

compute_all()[source]

Compute the metric value for all tracked metrics.

Return type

Tensor

forward(*args, **kwargs)[source]

Calls forward of the current metric being tracked.

Return type

None

increment()[source]

Creates a new instance of the input metric that will be updated next.

Return type

None

reset()[source]

Resets the current metric being tracked.

Return type

None

reset_all()[source]

Resets all metrics being tracked.

Return type

None

update(*args, **kwargs)[source]

Updates the current metric being tracked.

Return type

None

property n_steps: int[source]

Returns the number of times the tracker has been incremented.

Return type

int

MinMaxMetric

class torchmetrics.MinMaxMetric(base_metric, compute_on_step=None, **kwargs)[source]

Wrapper Metric that tracks both the minimum and maximum of a scalar/tensor across an experiment. The min/max value will be updated each time .compute is called.

Parameters
  • base_metric (Metric) – The metric of which you want to keep track of its maximum and minimum values.

  • compute_on_step (Optional[bool]) –

    Forward only calls update() and returns None if this is set to False.

    Deprecated since version v0.8: Argument has no use anymore and will be removed v0.9.

  • kwargs (Dict[str, Any]) – Additional keyword arguments, see Advanced metric settings for more info.

Raises

ValueError – If the base_metric argument is not an instance of torchmetrics.Metric.

Example::
>>> import torch
>>> from torchmetrics import Accuracy, MinMaxMetric
>>> from pprint import pprint
>>> base_metric = Accuracy()
>>> minmax_metric = MinMaxMetric(base_metric)
>>> preds_1 = torch.Tensor([[0.1, 0.9], [0.2, 0.8]])
>>> preds_2 = torch.Tensor([[0.9, 0.1], [0.2, 0.8]])
>>> labels = torch.Tensor([[0, 1], [0, 1]]).long()
>>> pprint(minmax_metric(preds_1, labels))
{'max': tensor(1.), 'min': tensor(1.), 'raw': tensor(1.)}
>>> pprint(minmax_metric.compute())
{'max': tensor(1.), 'min': tensor(1.), 'raw': tensor(1.)}
>>> minmax_metric.update(preds_2, labels)
>>> pprint(minmax_metric.compute())
{'max': tensor(1.), 'min': tensor(0.7500), 'raw': tensor(0.7500)}
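
A further sketch of the typical workflow, calling the wrapper once per validation pass so that min and max track the worst and best accuracy seen so far (the values depend on the random data and are therefore not shown):

>>> _ = torch.manual_seed(42)
>>> minmax_acc = MinMaxMetric(Accuracy())
>>> for _ in range(3):  # e.g. one iteration per validation epoch
...     preds = torch.rand(10, 2).softmax(dim=-1)
...     target = torch.randint(2, (10,))
...     batch_result = minmax_acc(preds, target)  # dict with 'raw', 'min' and 'max'
>>> summary = minmax_acc.compute()                # min/max across all calls so far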

compute()[source]

Computes the underlying metric as well as max and min values for this metric.

Returns a dictionary that consists of the computed value (raw), as well as the minimum (min) and maximum (max) values.

Return type

Dict[str, Tensor]

reset()[source]

Sets max_val and min_val to the initialization bounds and resets the base metric.

Return type

None

update(*args, **kwargs)[source]

Updates the underlying metric.

Return type

None

MultioutputWrapper

class torchmetrics.MultioutputWrapper(base_metric, num_outputs, output_dim=-1, remove_nans=True, squeeze_outputs=True)[source]

Wrap a base metric to enable it to support multiple outputs.

Several torchmetrics metrics, such as torchmetrics.regression.spearman.SpearmanCorrcoef, lack support for multioutput mode. This class wraps such metrics to support computing one metric per output. Unlike specific torchmetrics metrics, it doesn’t support any aggregation across outputs. This means that if you set num_outputs to 2, .compute() will return a Tensor of dimension (2, ...) where ... represents the dimensions the metric returns when not wrapped.

In addition to enabling multioutput support for metrics that lack it, this class also supports, albeit in a crude fashion, dealing with missing labels (or other data). When remove_nans is passed, the class removes the intersection of NaN-containing “rows” upon each update for each output. For example, suppose a user wraps torchmetrics.regression.r2.R2Score with 2 outputs, one of which occasionally has missing labels; with remove_nans enabled, any row containing a NaN in that output is dropped before the corresponding update.

Parameters
  • base_metric (Metric) – Metric being wrapped.

  • num_outputs (int) – Expected dimensionality of the output dimension. This parameter is used to determine the number of distinct metrics we need to track.

  • output_dim (int) – Dimension on which the output is expected. Note that while this provides some flexibility, the output dimension must be the same for all inputs to update. This applies even for metrics such as Accuracy where the labels can have a different number of dimensions than the predictions; this can be worked around by setting the output dimension to -1 for both, even if -1 corresponds to different dimensions in different inputs.

  • remove_nans (bool) – Whether to remove the intersection of rows containing NaNs from the values passed through to each underlying metric. Proper operation requires all tensors passed to update to have dimension (N, ...) where N represents the length of the batch or dataset being passed in.

  • squeeze_outputs (bool) – If True, will squeeze the 1-item dimensions left after index_select is applied. This is sometimes unnecessary (but harmless) for metrics such as R2Score, yet useful for certain classification metrics that cannot handle the additional 1-item dimensions.

Example

>>> # Mimic R2Score in `multioutput`, `raw_values` mode:
>>> import torch
>>> from torchmetrics import MultioutputWrapper, R2Score
>>> target = torch.tensor([[0.5, 1], [-1, 1], [7, -6]])
>>> preds = torch.tensor([[0, 2], [-1, 2], [8, -5]])
>>> r2score = MultioutputWrapper(R2Score(), 2)
>>> r2score(preds, target)
[tensor(0.9654), tensor(0.9082)]
>>> # Classification metric where prediction and label tensors have different shapes.
>>> from torchmetrics import BinnedAveragePrecision
>>> target = torch.tensor([[1, 2], [2, 0], [1, 2]])
>>> preds = torch.tensor([
...     [[.1, .8], [.8, .05], [.1, .15]],
...     [[.1, .1], [.2, .3], [.7, .6]],
...     [[.002, .4], [.95, .45], [.048, .15]]
... ])
>>> binned_avg_precision = MultioutputWrapper(BinnedAveragePrecision(3, thresholds=5), 2)
>>> binned_avg_precision(preds, target)
[[tensor(-0.), tensor(1.0000), tensor(1.0000)], [tensor(0.3333), tensor(-0.), tensor(0.6667)]]
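
To illustrate the remove_nans behaviour described above, a minimal sketch (hypothetical data; with the default remove_nans=True, a NaN row is dropped only for the output in which it occurs):

>>> target_nan = torch.tensor([[0.5, float("nan")], [-1.0, 1.0], [7.0, -6.0]])
>>> preds_nan = torch.tensor([[0.0, 2.0], [-1.0, 2.0], [8.0, -5.0]])
>>> r2score_nan = MultioutputWrapper(R2Score(), 2)
>>> vals = r2score_nan(preds_nan, target_nan)  # output 1 only sees the two NaN-free rows
>>> len(vals)
2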

compute()[source]

Compute metrics.

Return type

List[Tensor]

forward(*args, **kwargs)[source]

Call underlying forward methods and aggregate the results if they’re non-null.

We override this method to ensure that state variables get copied over on the underlying metrics.

Return type

Any

reset()[source]

Reset all underlying metrics.

Return type

None

update(*args, **kwargs)[source]

Update each underlying metric with the corresponding output.

Return type

None