torchmetrics.wrappers
Modular wrapper metrics are not metrics in themselves; instead, they take a metric and alter the internal logic of the base metric.
BootStrapper
- class torchmetrics.BootStrapper(base_metric, num_bootstraps=10, mean=True, std=True, quantile=None, raw=False, sampling_strategy='poisson', compute_on_step=None, **kwargs)
Turns a metric into a bootstrapped metric, which can automate the process of getting confidence intervals for metric values. This wrapper class keeps multiple copies of the same base metric in memory, and whenever update or forward is called, all input tensors are resampled (with replacement) along the first dimension.
- Parameters
base_metric (Metric) – base metric to wrap.
num_bootstraps (int) – number of copies of the base metric to make for bootstrapping.
mean (bool) – if True, return the mean of the bootstraps.
std (bool) – if True, return the standard deviation of the bootstraps.
quantile (Union[float, Tensor, None]) – if given, return this quantile of the bootstraps. Can only be used with PyTorch version 1.6 or higher.
raw (bool) – if True, return all bootstrapped values.
sampling_strategy (str) – determines how to produce bootstrapped samplings. Either 'poisson' or 'multinomial'. If 'poisson' is chosen, the number of times each sample is included in the bootstrap is given by $n \sim \mathrm{Poisson}(\lambda = 1)$, which approximates the true bootstrap distribution when the number of samples is large. If 'multinomial' is chosen, true bootstrapping is applied at the batch level to approximate bootstrapping over the whole dataset.
compute_on_step (Optional[bool]) – forward only calls update() and returns None if this is set to False. Deprecated since version v0.8: the argument has no use anymore and will be removed in v0.9.
kwargs (Dict[str, Any]) – additional keyword arguments, see Advanced metric settings for more info.
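To make the two sampling strategies concrete, here is a dependency-free sketch that works on plain Python lists rather than tensors. It is not the library's implementation, and the helper names poisson_one, bootstrap_indices and resample are hypothetical. With 'poisson', each row is kept n ~ Poisson(1) times, so a resample may be shorter or longer than the batch; with 'multinomial', exactly as many rows as the batch holds are drawn uniformly with replacement.

```python
from __future__ import annotations

import math
import random


def poisson_one(rng: random.Random) -> int:
    """Draw n ~ Poisson(lambda=1) with Knuth's multiplication method."""
    limit = math.exp(-1.0)
    k, p = 0, rng.random()
    while p > limit:
        k += 1
        p *= rng.random()
    return k


def bootstrap_indices(size: int, strategy: str, rng: random.Random) -> list[int]:
    """Row indices for one bootstrap resample of a batch with `size` rows."""
    if strategy == "poisson":
        # Row i is repeated n_i ~ Poisson(1) times, so the resample length varies.
        return [i for i in range(size) for _ in range(poisson_one(rng))]
    if strategy == "multinomial":
        # Exactly `size` rows drawn uniformly with replacement.
        return [rng.randrange(size) for _ in range(size)]
    raise ValueError(f"unknown sampling strategy: {strategy!r}")


def resample(rows: list, strategy: str, rng: random.Random) -> list:
    """Resample along the first dimension (here: list position) with replacement."""
    return [rows[i] for i in bootstrap_indices(len(rows), strategy, rng)]


# One independent resample per hypothetical bootstrap copy of a single batch.
rng = random.Random(0)
batch = [10, 20, 30, 40, 50]
for copy_id in range(3):
    print(copy_id, resample(batch, "poisson", rng), resample(batch, "multinomial", rng))
```

Drawing a fresh resample for every copy is what lets the spread of the copies' results serve as an estimate of the metric's variability.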
- Example:
>>> import torch
>>> from pprint import pprint
>>> from torchmetrics import Accuracy, BootStrapper
>>> _ = torch.manual_seed(123)
>>> base_metric = Accuracy()
>>> bootstrap = BootStrapper(base_metric, num_bootstraps=20)
>>> bootstrap.update(torch.randint(5, (20,)), torch.randint(5, (20,)))
>>> output = bootstrap.compute()
>>> pprint(output)
{'mean': tensor(0.2205), 'std': tensor(0.0859)}
- compute()
Computes the bootstrapped metric values.
Always returns a dict of tensors, which can contain the following keys: mean, std, quantile and raw, depending on how the class was initialized.
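The dictionary returned by compute() can be sketched as follows, assuming each bootstrap copy has already produced one scalar value. The names summarize and linear_quantile are hypothetical, and using the sample (n - 1) standard deviation and linear interpolation for the quantile are assumptions about the estimators, not documented behavior.

```python
from __future__ import annotations

import math
import statistics
from typing import Dict, List, Optional


def linear_quantile(values: List[float], q: float) -> float:
    """Quantile q in [0, 1], linearly interpolating between sorted values."""
    ordered = sorted(values)
    pos = q * (len(ordered) - 1)
    lo, hi = math.floor(pos), math.ceil(pos)
    return ordered[lo] + (ordered[hi] - ordered[lo]) * (pos - lo)


def summarize(values: List[float], mean: bool = True, std: bool = True,
              quantile: Optional[float] = None, raw: bool = False) -> Dict[str, object]:
    """Collect one value per bootstrap copy into the keys listed for compute()."""
    out: Dict[str, object] = {}
    if mean:
        out["mean"] = statistics.fmean(values)
    if std:
        out["std"] = statistics.stdev(values)  # sample (n - 1) standard deviation
    if quantile is not None:
        out["quantile"] = linear_quantile(values, quantile)
    if raw:
        out["raw"] = list(values)  # every per-copy value, unaggregated
    return out


# Hypothetical per-copy results, e.g. one accuracy value per bootstrap copy.
copies = [0.1, 0.2, 0.3, 0.4]
print(summarize(copies))
print(summarize(copies, quantile=0.5, raw=True))
```

With the defaults only mean and std are present, which matches the two keys in the example output above.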
ClasswiseWrapper
- class torchmetrics.ClasswiseWrapper(metric, labels=None)
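Wrapper class for altering the output of classification metrics that return multiple values, so that the output includes label information.
- Parameters
metric (Metric) – base metric to wrap. It is assumed to output a single tensor that is split along the first dimension.
labels (Optional[List[str]]) – list of strings naming the classes. If not given, the class index is used in the output keys.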
- Example:
>>> import torch
>>> _ = torch.manual_seed(42)
>>> from torchmetrics import Accuracy, ClasswiseWrapper
>>> metric = ClasswiseWrapper(Accuracy(num_classes=3, average=None))
>>> preds = torch.randn(10, 3).softmax(dim=-1)
>>> target = torch.randint(3, (10,))
>>> metric(preds, target)
{'accuracy_0': tensor(0.5000), 'accuracy_1': tensor(0.7500), 'accuracy_2': tensor(0.)}
- Example (labels as list of strings):
>>> import torch
>>> from torchmetrics import Accuracy, ClasswiseWrapper
>>> metric = ClasswiseWrapper(
...     Accuracy(num_classes=3, average=None),
...     labels=["horse", "fish", "dog"]
... )
>>> preds = torch.randn(10, 3).softmax(dim=-1)
>>> target = torch.randint(3, (10,))
>>> metric(preds, target)
{'accuracy_horse': tensor(0.3333), 'accuracy_fish': tensor(0.6667), 'accuracy_dog': tensor(0.)}
- Example (in metric collection):
>>> import torch
>>> from torchmetrics import Accuracy, ClasswiseWrapper, MetricCollection, Recall
>>> labels = ["horse", "fish", "dog"]
>>> metric = MetricCollection(
...     {'accuracy': ClasswiseWrapper(Accuracy(num_classes=3, average=None), labels),
...     'recall': ClasswiseWrapper(Recall(num_classes=3, average=None), labels)}
... )
>>> preds = torch.randn(10, 3).softmax(dim=-1)
>>> target = torch.randint(3, (10,))
>>> metric(preds, target)
{'accuracy_horse': tensor(0.), 'accuracy_fish': tensor(0.3333), 'accuracy_dog': tensor(0.4000), 'recall_horse': tensor(0.), 'recall_fish': tensor(0.3333), 'recall_dog': tensor(0.4000)}
- compute()
Computes the wrapped metric and returns its per-class values as a dict, with keys formed from the metric name and the class label (or the class index when no labels are given).
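The key naming seen in the ClasswiseWrapper examples can be mimicked with a small helper. This is a hypothetical sketch (classwise_dict is not a torchmetrics function) that assumes the wrapped metric yields one value per class, in class order; the prefix is passed in explicitly here, whereas the wrapper derives it from the wrapped metric.

```python
from __future__ import annotations

from typing import Dict, Optional, Sequence


def classwise_dict(prefix: str, values: Sequence[float],
                   labels: Optional[Sequence[str]] = None) -> Dict[str, float]:
    """Turn one value per class into a dict keyed '<prefix>_<label>'."""
    if labels is None:
        labels = [str(i) for i in range(len(values))]  # fall back to class indices
    if len(labels) != len(values):
        raise ValueError("expected exactly one label per class value")
    return {f"{prefix}_{label}": value for label, value in zip(labels, values)}


print(classwise_dict("accuracy", [0.5, 0.75, 0.0]))
# {'accuracy_0': 0.5, 'accuracy_1': 0.75, 'accuracy_2': 0.0}
print(classwise_dict("recall", [0.0, 0.3333, 0.4], labels=["horse", "fish", "dog"]))
```

Prefixing every key with the metric name is what keeps the entries distinct when several wrapped metrics share one MetricCollection.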
MetricTracker
- class torchmetrics.MetricTracker(metric, maximize=True)
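A wrapper class that can help keep track of a metric or metric collection over time and implements useful methods. The wrapper implements the standard .update(), .compute() and .reset() methods, which just call the corresponding method of the currently tracked metric. In addition, the following methods are provided:
- MetricTracker.n_steps: number of metrics being tracked
- MetricTracker.increment(): initialize a new metric for being tracked
- MetricTracker.compute_all(): get the metric value for all steps
- MetricTracker.best_metric(): return the best value
- Parameters
metric (Union[Metric, MetricCollection]) – instance of a torchmetrics.Metric or torchmetrics.MetricCollection to keep track of at each time step.
maximize (Union[bool, List[bool]]) – either a single bool or a list of bools indicating whether higher metric values are better (True) or lower values are better (False).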
- Example (single metric):
>>> import torch
>>> from torchmetrics import Accuracy, MetricTracker
>>> _ = torch.manual_seed(42)
>>> tracker = MetricTracker(Accuracy(num_classes=10))
>>> for epoch in range(5):
...     tracker.increment()
...     for batch_idx in range(5):
...         preds, target = torch.randint(10, (100,)), torch.randint(10, (100,))
...         tracker.update(preds, target)
...     print(f"current acc={tracker.compute()}")
current acc=0.1120000034570694
current acc=0.08799999952316284
current acc=0.12600000202655792
current acc=0.07999999821186066
current acc=0.10199999809265137
>>> best_acc, which_epoch = tracker.best_metric(return_step=True)
>>> best_acc
0.1260...
>>> which_epoch
2
>>> tracker.compute_all()
tensor([0.1120, 0.0880, 0.1260, 0.0800, 0.1020])
- Example (multiple metrics using MetricCollection):
>>> import torch
>>> from torchmetrics import MetricTracker, MetricCollection, MeanSquaredError, ExplainedVariance
>>> _ = torch.manual_seed(42)
>>> tracker = MetricTracker(MetricCollection([MeanSquaredError(), ExplainedVariance()]), maximize=[False, True])
>>> for epoch in range(5):
...     tracker.increment()
...     for batch_idx in range(5):
...         preds, target = torch.randn(100), torch.randn(100)
...         tracker.update(preds, target)
...     print(f"current stats={tracker.compute()}")
current stats={'MeanSquaredError': tensor(1.8218), 'ExplainedVariance': tensor(-0.8969)}
current stats={'MeanSquaredError': tensor(2.0268), 'ExplainedVariance': tensor(-1.0206)}
current stats={'MeanSquaredError': tensor(1.9491), 'ExplainedVariance': tensor(-0.8298)}
current stats={'MeanSquaredError': tensor(1.9800), 'ExplainedVariance': tensor(-0.9199)}
current stats={'MeanSquaredError': tensor(2.2481), 'ExplainedVariance': tensor(-1.1622)}
>>> from pprint import pprint
>>> best_res, which_epoch = tracker.best_metric(return_step=True)
>>> pprint(best_res)
{'ExplainedVariance': -0.829..., 'MeanSquaredError': 1.821...}
>>> which_epoch
{'MeanSquaredError': 0, 'ExplainedVariance': 2}
>>> pprint(tracker.compute_all())
{'ExplainedVariance': tensor([-0.8969, -1.0206, -0.8298, -0.9199, -1.1622]),
 'MeanSquaredError': tensor([1.8218, 2.0268, 1.9491, 1.9800, 2.2481])}
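The increment/update/compute_all/best_metric workflow described above can be sketched without any dependencies. ToyTracker and RunningMean are hypothetical stand-ins: the sketch builds each step's metric from a factory callable, a simplification of wrapping an existing metric instance, and it handles only a single metric with a single maximize flag.

```python
from __future__ import annotations

from typing import Callable, List


class RunningMean:
    """Hypothetical stand-in for a metric: mean of all values seen since creation."""

    def __init__(self) -> None:
        self.total = 0.0
        self.count = 0

    def update(self, values: List[float]) -> None:
        self.total += sum(values)
        self.count += len(values)

    def compute(self) -> float:
        return self.total / self.count


class ToyTracker:
    """Sketch of the MetricTracker workflow: one fresh metric per step."""

    def __init__(self, make_metric: Callable[[], RunningMean], maximize: bool = True) -> None:
        self._make_metric = make_metric
        self.maximize = maximize
        self._steps: List[RunningMean] = []

    @property
    def n_steps(self) -> int:
        return len(self._steps)  # number of metrics being tracked

    def increment(self) -> None:
        self._steps.append(self._make_metric())  # start tracking a new step

    def update(self, values: List[float]) -> None:
        self._steps[-1].update(values)  # only the newest step receives data

    def compute(self) -> float:
        return self._steps[-1].compute()

    def compute_all(self) -> List[float]:
        return [m.compute() for m in self._steps]

    def best_metric(self, return_step: bool = False):
        values = self.compute_all()
        best = max(values) if self.maximize else min(values)
        return (best, values.index(best)) if return_step else best


tracker = ToyTracker(RunningMean, maximize=True)
for epoch_values in ([1, 3], [6, 8], [4]):
    tracker.increment()  # new metric for this epoch
    tracker.update(epoch_values)
print(tracker.compute_all())  # [2.0, 7.0, 4.0]
print(tracker.best_metric(return_step=True))  # (7.0, 1)
```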
- best_metric(return_step=False)
Returns the best metric value out of all tracked steps: the highest if maximize is True, otherwise the lowest.
- Parameters
return_step (bool) – if True, also return the step with the best metric value.
- Return type
Union[None, float, Tuple[int, float], Tuple[None, None], Dict[str, Optional[float]], Tuple[Dict[str, Optional[int]], Dict[str, Optional[float]]]]
- Returns
The best metric value, and optionally the time step.
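When a MetricCollection is tracked and maximize is a list, the best value and its step are chosen separately for each metric. A minimal sketch of that selection, assuming the maximize list follows the collection's order (best_per_metric is a hypothetical helper), using the per-epoch values printed in the MetricCollection example:

```python
from __future__ import annotations

from typing import Dict, List, Tuple


def best_per_metric(history: Dict[str, List[float]],
                    maximize: List[bool]) -> Tuple[Dict[str, float], Dict[str, int]]:
    """Best value and its step for each metric; `maximize` follows the key order of `history`."""
    best_vals: Dict[str, float] = {}
    best_steps: Dict[str, int] = {}
    for (name, values), higher_is_better in zip(history.items(), maximize):
        best = max(values) if higher_is_better else min(values)
        best_vals[name] = best
        best_steps[name] = values.index(best)  # first step reaching the best value
    return best_vals, best_steps


# Per-epoch values taken from the MetricCollection example.
history = {
    "MeanSquaredError": [1.8218, 2.0268, 1.9491, 1.9800, 2.2481],
    "ExplainedVariance": [-0.8969, -1.0206, -0.8298, -0.9199, -1.1622],
}
print(best_per_metric(history, maximize=[False, True]))
# ({'MeanSquaredError': 1.8218, 'ExplainedVariance': -0.8298}, {'MeanSquaredError': 0, 'ExplainedVariance': 2})
```

The selected steps (0 for the error, 2 for the explained variance) agree with which_epoch in the tracker example.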
- forward(*args, **kwargs)
Calls forward of the current metric being tracked.
- increment()
Creates a new instance of the input metric that will be updated next.
MinMaxMetric
- class torchmetrics.MinMaxMetric(base_metric, compute_on_step=None, **kwargs)
Wrapper metric that tracks both the minimum and maximum of a scalar/tensor across an experiment. The min/max values are updated each time .compute is called.
- Parameters
base_metric (Metric) – the metric whose minimum and maximum values you want to track.
compute_on_step (Optional[bool]) – forward only calls update() and returns None if this is set to False. Deprecated since version v0.8: the argument has no use anymore and will be removed in v0.9.
kwargs (Dict[str, Any]) – additional keyword arguments, see Advanced metric settings for more info.
- Raises
ValueError – if the base_metric argument is not an instance of a subclass of torchmetrics.Metric.
- Example:
>>> import torch
>>> from pprint import pprint
>>> from torchmetrics import Accuracy, MinMaxMetric
>>> base_metric = Accuracy()
>>> minmax_metric = MinMaxMetric(base_metric)
>>> preds_1 = torch.Tensor([[0.1, 0.9], [0.2, 0.8]])
>>> preds_2 = torch.Tensor([[0.9, 0.1], [0.2, 0.8]])
>>> labels = torch.Tensor([[0, 1], [0, 1]]).long()
>>> pprint(minmax_metric(preds_1, labels))
{'max': tensor(1.), 'min': tensor(1.), 'raw': tensor(1.)}
>>> pprint(minmax_metric.compute())
{'max': tensor(1.), 'min': tensor(1.), 'raw': tensor(1.)}
>>> minmax_metric.update(preds_2, labels)
>>> pprint(minmax_metric.compute())
{'max': tensor(1.), 'min': tensor(0.7500), 'raw': tensor(0.7500)}
- compute()
Computes the underlying metric as well as the max and min values for this metric.
Returns a dictionary that consists of the computed value (raw), as well as the minimum (min) and maximum (max) values.
- reset()
Sets max_val and min_val to the initialization bounds and resets the base metric.
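The bookkeeping described for compute() and reset() can be sketched as follows, assuming the base metric's value is a Python float. ToyMinMax and its compute_base callable are hypothetical; unlike the real wrapper, reset() here only restores the bounds and has no base metric to reset.

```python
from __future__ import annotations

import math
from typing import Callable, Dict


class ToyMinMax:
    """Sketch of MinMaxMetric: remember the extreme values returned by compute()."""

    def __init__(self, compute_base: Callable[[], float]) -> None:
        self._compute_base = compute_base  # hypothetical: current value of the base metric
        self.reset()

    def reset(self) -> None:
        # Initialization bounds: the first real value replaces both of them.
        self.min_val, self.max_val = math.inf, -math.inf

    def compute(self) -> Dict[str, float]:
        raw = self._compute_base()
        # min/max only move when compute() is called, never on update.
        self.min_val = min(self.min_val, raw)
        self.max_val = max(self.max_val, raw)
        return {"raw": raw, "min": self.min_val, "max": self.max_val}


# Hypothetical base-metric values, mirroring the accuracy example (1.0, then 0.75).
values = iter([1.0, 0.75])
minmax = ToyMinMax(lambda: next(values))
print(minmax.compute())  # {'raw': 1.0, 'min': 1.0, 'max': 1.0}
print(minmax.compute())  # {'raw': 0.75, 'min': 0.75, 'max': 1.0}
```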
MultioutputWrapper
- class torchmetrics.MultioutputWrapper(base_metric, num_outputs, output_dim=-1, remove_nans=True, squeeze_outputs=True)
Wrap a base metric to enable it to support multiple outputs.
Several torchmetrics metrics, such as torchmetrics.regression.spearman.SpearmanCorrcoef, lack support for multioutput mode. This class wraps such metrics to support computing one metric per output. Unlike specific torchmetrics metrics, it does not support any aggregation across outputs. This means that if you set num_outputs to 2, .compute() will return a Tensor of dimension (2, ...), where ... represents the dimensions the metric returns when not wrapped.

In addition to enabling multioutput support for metrics that lack it, this class also supports, albeit in a crude fashion, dealing with missing labels (or other data). When remove_nans is passed, the class removes NaN-containing "rows" upon each update, separately for each output. For example, suppose a user uses MultioutputWrapper to wrap torchmetrics.regression.r2.R2Score with 2 outputs, one of which occasionally has missing labels. With remove_nans, the rows containing NaN values are dropped only for the output in which they occur, so the data for the other output is still used.
- Parameters
base_metric (Metric) – metric being wrapped.
num_outputs (int) – expected size of the output dimension. This parameter is used to determine the number of distinct metrics to track.
output_dim (int) – dimension on which the outputs are expected. Note that while this provides some flexibility, the output dimension must be the same for all inputs to update. This applies even for metrics such as Accuracy, where the labels can have a different number of dimensions than the predictions. This can be worked around if the output dimension can be set to -1 for both, even if -1 corresponds to different dimensions in different inputs.
remove_nans (bool) – whether to remove the rows containing NaNs from the values passed through to each underlying metric. Proper operation requires all tensors passed to update to have dimension (N, ...), where N represents the length of the batch or dataset being passed in.
squeeze_outputs (bool) – if True, squeeze the 1-item dimensions left after index_select is applied. This is unnecessary but harmless for metrics such as R2Score, and useful for certain classification metrics that cannot handle additional 1-item dimensions.
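A plain-Python sketch of the per-output evaluation with remove_nans, assuming two-dimensional (N, num_outputs) inputs with the outputs on the last dimension (the default output_dim=-1). per_output and mean_abs_error are hypothetical names, and dropping a row whenever either input is NaN for that output is an assumption about how the NaN rows are chosen.

```python
from __future__ import annotations

import math
from typing import Callable, List, Sequence


def per_output(
    metric_fn: Callable[[List[float], List[float]], float],
    preds: Sequence[Sequence[float]],
    target: Sequence[Sequence[float]],
    num_outputs: int,
    remove_nans: bool = True,
) -> List[float]:
    """Apply `metric_fn` separately to each output column of (N, num_outputs) data."""
    results = []
    for i in range(num_outputs):
        p = [row[i] for row in preds]
        t = [row[i] for row in target]
        if remove_nans:
            # Keep only the rows where neither input is NaN for this output.
            keep = [k for k in range(len(p)) if not (math.isnan(p[k]) or math.isnan(t[k]))]
            p, t = [p[k] for k in keep], [t[k] for k in keep]
        results.append(metric_fn(p, t))  # no aggregation across outputs
    return results


def mean_abs_error(p: List[float], t: List[float]) -> float:
    """Hypothetical example metric."""
    return sum(abs(a - b) for a, b in zip(p, t)) / len(p)


preds = [[0.0, 2.0], [1.0, 2.0], [8.0, -5.0]]
target = [[0.5, 1.0], [1.0, float("nan")], [7.0, -6.0]]  # second output has a missing label
print(per_output(mean_abs_error, preds, target, num_outputs=2))  # [0.5, 1.0]
```

The missing label removes one row from the second output only; the first output is still computed over all three rows.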
- Example:
>>> # Mimic R2Score in `multioutput`, `raw_values` mode:
>>> import torch
>>> from torchmetrics import MultioutputWrapper, R2Score
>>> target = torch.tensor([[0.5, 1], [-1, 1], [7, -6]])
>>> preds = torch.tensor([[0, 2], [-1, 2], [8, -5]])
>>> r2score = MultioutputWrapper(R2Score(), 2)
>>> r2score(preds, target)
[tensor(0.9654), tensor(0.9082)]
>>> # Classification metric where prediction and label tensors have different shapes.
>>> from torchmetrics import BinnedAveragePrecision
>>> target = torch.tensor([[1, 2], [2, 0], [1, 2]])
>>> preds = torch.tensor([
...     [[.1, .8], [.8, .05], [.1, .15]],
...     [[.1, .1], [.2, .3], [.7, .6]],
...     [[.002, .4], [.95, .45], [.048, .15]]
... ])
>>> binned_avg_precision = MultioutputWrapper(BinnedAveragePrecision(3, thresholds=5), 2)
>>> binned_avg_precision(preds, target)
[[tensor(-0.), tensor(1.0000), tensor(1.0000)], [tensor(0.3333), tensor(-0.), tensor(0.6667)]]
- forward(*args, **kwargs)
Calls the underlying forward methods and aggregates the results if they are non-null.
We override this method to ensure that state variables get copied over on the underlying metrics.