Bootstrapper

Module Interface

class torchmetrics.wrappers.BootStrapper(base_metric, num_bootstraps=10, mean=True, std=True, quantile=None, raw=False, sampling_strategy='poisson', **kwargs)[source]

Turn a metric into a bootstrapped metric.

This wrapper automates the computation of confidence intervals for metric values. It keeps multiple copies of the same base metric in memory, and whenever update or forward is called, all input tensors are resampled (with replacement) along the first dimension.

Parameters:
  • base_metric (Metric) – base metric instance to wrap

  • num_bootstraps (int) – number of copies to make of the base metric for bootstrapping

  • mean (bool) – if True return the mean of the bootstraps

  • std (bool) – if True return the standard deviation of the bootstraps

  • quantile (Union[float, Tensor, None]) – if given, return the quantile of the bootstraps. Can only be used with PyTorch version 1.6 or higher

  • raw (bool) – if True, return all bootstrapped values

  • sampling_strategy (str) – Determines how the bootstrapped samples are drawn. Either 'poisson' or 'multinomial'. If 'poisson' is chosen, the number of times each sample is included in the bootstrap is given by \(n\sim Poisson(\lambda=1)\), which approximates the true bootstrap distribution when the number of samples is large. If 'multinomial' is chosen, true bootstrapping is applied at the batch level to approximate bootstrapping over the whole dataset.

  • kwargs (Any) – Additional keyword arguments, see Advanced metric settings for more info.
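The two sampling strategies described above can be illustrated with plain PyTorch. This is a minimal sketch and not necessarily the library's own implementation; the helper name bootstrap_indices is hypothetical. It returns the indices along dimension 0 that one bootstrap copy would receive for a batch of a given size.

```python
import torch


def bootstrap_indices(size: int, sampling_strategy: str = "poisson") -> torch.Tensor:
    """Return resampled indices for a batch of `size` samples (hypothetical helper)."""
    if sampling_strategy == "poisson":
        # Draw a count n ~ Poisson(lambda=1) for every sample,
        # then repeat each index n times (n may be 0).
        counts = torch.poisson(torch.ones(size)).long()
        return torch.arange(size).repeat_interleave(counts)
    if sampling_strategy == "multinomial":
        # Draw exactly `size` indices uniformly, with replacement.
        return torch.multinomial(torch.ones(size), num_samples=size, replacement=True)
    raise ValueError(f"Unknown sampling strategy: {sampling_strategy}")


torch.manual_seed(0)
poisson_idx = bootstrap_indices(8, "poisson")          # length varies from draw to draw
multinomial_idx = bootstrap_indices(8, "multinomial")  # length is always 8
```

In this sketch the Poisson strategy produces a resampled batch whose size fluctuates around the original batch size, while the multinomial strategy always keeps the batch size fixed.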

Example:
>>> import torch
>>> from pprint import pprint
>>> from torchmetrics.wrappers import BootStrapper
>>> from torchmetrics.classification import MulticlassAccuracy
>>> _ = torch.manual_seed(123)
>>> base_metric = MulticlassAccuracy(num_classes=5, average='micro')
>>> bootstrap = BootStrapper(base_metric, num_bootstraps=20)
>>> bootstrap.update(torch.randint(5, (20,)), torch.randint(5, (20,)))
>>> output = bootstrap.compute()
>>> pprint(output)
{'mean': tensor(0.2205), 'std': tensor(0.0859)}

compute()[source]

Compute the bootstrapped metric values.

Always returns a dict of tensors, which can contain the following keys: mean, std, quantile and raw depending on how the class was initialized.

Return type:

Dict[str, Tensor]

forward(*args, **kwargs)[source]

Use the original forward method of the base metric class.

Return type:

Any

plot(val=None, ax=None)[source]

Plot a single or multiple values from the metric.

Parameters:
  • val (Union[Tensor, Sequence[Tensor], None]) – Either a single result from calling metric.forward or metric.compute or a list of these results. If no value is provided, will automatically call metric.compute and plot that result.

  • ax (Optional[Axes]) – A matplotlib axis object. If provided, the plot is added to that axis

Return type:

Tuple[Figure, Union[Axes, ndarray]]

Returns:

Figure and Axes object

Raises:

ModuleNotFoundError – If matplotlib is not installed

>>> # Example plotting a single value
>>> import torch
>>> from torchmetrics.wrappers import BootStrapper
>>> from torchmetrics.regression import MeanSquaredError
>>> metric = BootStrapper(MeanSquaredError(), num_bootstraps=20)
>>> metric.update(torch.randn(100,), torch.randn(100,))
>>> fig_, ax_ = metric.plot()
>>> # Example plotting multiple values
>>> import torch
>>> from torchmetrics.wrappers import BootStrapper
>>> from torchmetrics.regression import MeanSquaredError
>>> metric = BootStrapper(MeanSquaredError(), num_bootstraps=20)
>>> values = []
>>> for _ in range(3):
...     values.append(metric(torch.randn(100,), torch.randn(100,)))
>>> fig_, ax_ = metric.plot(values)

update(*args, **kwargs)[source]

Update the state of the base metric.

Any tensor passed in will be bootstrapped along dimension 0.

Return type:

None
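Bootstrapping along dimension 0 means that the same resampled indices are applied to every input tensor, so that predictions and targets stay paired. The snippet below is a sketch of that idea in plain PyTorch, assuming multinomial sampling; it is an illustration and not the wrapper's actual code.

```python
import torch

torch.manual_seed(0)

# Toy inputs where target[i] == 10 * preds[i], so pairing is easy to check.
preds = torch.arange(6, dtype=torch.float32)
target = preds * 10

# Draw one set of indices with replacement ...
idx = torch.multinomial(torch.ones(6), num_samples=6, replacement=True)

# ... and apply it to every tensor along dimension 0.
preds_bs = torch.index_select(preds, dim=0, index=idx)
target_bs = torch.index_select(target, dim=0, index=idx)

# Each resampled prediction still lines up with its own target.
still_paired = torch.equal(target_bs, preds_bs * 10)
```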