
Binned Average Precision

Module Interface

class torchmetrics.BinnedAveragePrecision(num_classes, thresholds=100, compute_on_step=None, **kwargs)

Computes the average precision score, which summarises the precision-recall curve in a single number. Works for both binary and multiclass problems. In the multiclass case, the values are calculated using a one-vs-the-rest approach.
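To make the one-vs-the-rest approach concrete, the following plain-Python sketch splits a multiclass problem into one binary problem per class. The helper name one_vs_rest is hypothetical and is not part of torchmetrics; it only illustrates the idea under the assumption that preds is a list of per-class probability rows.

```python
def one_vs_rest(preds, target, num_classes):
    """Split a multiclass problem into num_classes binary problems.

    preds:  list of N rows, each holding num_classes probabilities
    target: list of N integer class labels
    Returns one (scores, binary_target) pair per class.
    """
    problems = []
    for c in range(num_classes):
        # Scores for class c are column c of the prediction matrix.
        scores = [row[c] for row in preds]
        # A sample is positive for class c only if its label equals c.
        binary_target = [1 if label == c else 0 for label in target]
        problems.append((scores, binary_target))
    return problems


# Same data as the multiclass example on this page, as plain lists.
preds = [
    [0.75, 0.05, 0.05, 0.05, 0.05],
    [0.05, 0.75, 0.05, 0.05, 0.05],
    [0.05, 0.05, 0.75, 0.05, 0.05],
    [0.05, 0.05, 0.05, 0.75, 0.05],
]
target = [0, 1, 3, 2]

problems = one_vs_rest(preds, target, num_classes=5)
scores_2, binary_2 = problems[2]
```

Each pair can then be scored as an ordinary binary problem. Note that class 4 never appears in target, so its binary problem has no positive samples.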

Computation is performed in constant memory by computing precision and recall at a fixed number of thresholds, set by the thresholds argument and evenly distributed between 0 and 1.
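The constant-memory idea can be sketched in plain Python for the binary case: instead of storing every prediction, only three counters per threshold are kept and updated batch by batch. This is an illustrative sketch, not the torchmetrics implementation; the function names (make_state, update, average_precision) and the conventions for empty denominators are assumptions made for this example.

```python
def make_state(num_thresholds):
    # Evenly spaced thresholds in [0, 1]; memory use depends only on
    # num_thresholds, never on the number of samples seen.
    thresholds = [i / (num_thresholds - 1) for i in range(num_thresholds)]
    zeros = [0] * num_thresholds
    return {"thresholds": thresholds, "tp": list(zeros),
            "fp": list(zeros), "fn": list(zeros)}


def update(state, preds, target):
    # Accumulate true positives, false positives and false negatives
    # for each threshold; the batch itself is not stored.
    for k, t in enumerate(state["thresholds"]):
        for p, y in zip(preds, target):
            predicted_positive = p >= t
            if predicted_positive and y == 1:
                state["tp"][k] += 1
            elif predicted_positive and y == 0:
                state["fp"][k] += 1
            elif not predicted_positive and y == 1:
                state["fn"][k] += 1


def average_precision(state):
    precisions, recalls = [], []
    for tp, fp, fn in zip(state["tp"], state["fp"], state["fn"]):
        # Assumed conventions: precision is 1 when nothing is predicted
        # positive, recall is 0 when there are no positive samples.
        precisions.append(tp / (tp + fp) if tp + fp else 1.0)
        recalls.append(tp / (tp + fn) if tp + fn else 0.0)
    # Close the curve at (recall=0, precision=1).
    precisions.append(1.0)
    recalls.append(0.0)
    # Sum of precision weighted by the drop in recall between thresholds.
    return sum((recalls[i] - recalls[i + 1]) * precisions[i]
               for i in range(len(recalls) - 1))


# Two batches are accumulated into the same fixed-size state.
state = make_state(num_thresholds=5)
update(state, [0.1, 0.4], [0, 0])
update(state, [0.35, 0.8], [1, 1])
ap = average_precision(state)
```

Because only the per-threshold counters are kept, the result is an approximation of the exact average precision whose resolution is controlled by the number of thresholds.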

Forward accepts

• preds (float tensor): (N, ...) (binary) or (N, C, ...) (multiclass) tensor with probabilities, where C is the number of classes.

• target (long tensor): (N, ...) tensor with integer labels

Parameters
Raises

ValueError – If thresholds is not an int, list or tensor
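As a rough illustration of how a thresholds argument of this kind can be validated, the plain-Python sketch below accepts either an integer count (as in the default thresholds=100) or an explicit list of values, and raises ValueError otherwise. The helper resolve_thresholds is hypothetical and not part of torchmetrics; tensor inputs are left out to keep the sketch dependency-free, and the minimum count of 2 is an assumption of this example.

```python
def resolve_thresholds(thresholds):
    """Turn a thresholds argument into an explicit list of floats."""
    if isinstance(thresholds, int):
        # Assumption of this sketch: at least two thresholds are needed
        # to span the interval [0, 1].
        if thresholds < 2:
            raise ValueError("need at least 2 thresholds")
        # Evenly distributed between 0 and 1, endpoints included.
        return [i / (thresholds - 1) for i in range(thresholds)]
    if isinstance(thresholds, (list, tuple)):
        # Explicit values are used as given.
        return [float(t) for t in thresholds]
    raise ValueError("thresholds must be an int or a list of values")


even = resolve_thresholds(5)
explicit = resolve_thresholds([0.2, 0.8])
```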

Example (binary case):
>>> import torch
>>> from torchmetrics import BinnedAveragePrecision
>>> pred = torch.tensor([0, 1, 2, 3])
>>> target = torch.tensor([0, 1, 1, 1])
>>> average_precision = BinnedAveragePrecision(num_classes=1, thresholds=10)
>>> average_precision(pred, target)
tensor(1.0000)

Example (multiclass case):
>>> pred = torch.tensor([[0.75, 0.05, 0.05, 0.05, 0.05],
...                      [0.05, 0.75, 0.05, 0.05, 0.05],
...                      [0.05, 0.05, 0.75, 0.05, 0.05],
...                      [0.05, 0.05, 0.05, 0.75, 0.05]])
>>> target = torch.tensor([0, 1, 3, 2])
>>> average_precision = BinnedAveragePrecision(num_classes=5, thresholds=10)
>>> average_precision(pred, target)
[tensor(1.0000), tensor(1.0000), tensor(0.2500), tensor(0.2500), tensor(-0.)]


Initializes internal Module state, shared by both nn.Module and ScriptModule.

compute()

Returns a float tensor of size num_classes.

Return type

© Copyright (c) 2020-2022, PyTorchLightning et al. Revision 60f1f185.
