Binned Average Precision
Module Interface
- class torchmetrics.BinnedAveragePrecision(num_classes, thresholds=100, **kwargs)
Computes the average precision score, which summarises the precision-recall curve into one number. Works for both binary and multiclass problems. In the multiclass case, a score is calculated for each class using a one-vs-the-rest approach.
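To make "summarises the precision-recall curve into one number" and "one-vs-the-rest" concrete, here is a minimal pure-PyTorch sketch of exact (unbinned) average precision: the mean of the precision values at the rank of each positive sample. A second helper scores each class against all the others. The function names are hypothetical and this is not the torchmetrics implementation; the sketch assumes the scores contain no ties and returns 0.0 for a class without positive samples, both choices made only for illustration. The binned metric documented on this page approximates this quantity using a fixed set of thresholds instead of every distinct score.

```python
from typing import List

import torch
from torch import Tensor


def binary_average_precision(scores: Tensor, labels: Tensor) -> Tensor:
    """Exact AP for one binary problem (hypothetical helper, assumes no tied scores)."""
    # Rank samples from the highest to the lowest score
    order = torch.argsort(scores, descending=True)
    hits = (labels[order] == 1).float()
    num_pos = hits.sum()
    if num_pos == 0:
        # Assumption of this sketch: no positives means a score of 0.0
        return torch.tensor(0.0)
    ranks = torch.arange(1, hits.numel() + 1, dtype=torch.float32)
    # Precision after the top-k samples, for every k
    precision_at_k = torch.cumsum(hits, dim=0) / ranks
    # Average the precision at the ranks where a positive sample occurs
    return (precision_at_k * hits).sum() / num_pos


def one_vs_rest_average_precision(preds: Tensor, target: Tensor, num_classes: int) -> List[Tensor]:
    """Treat class c as a binary problem: column c of preds vs. (target == c)."""
    return [binary_average_precision(preds[:, c], (target == c).long()) for c in range(num_classes)]
```

For example, scores [0.9, 0.8, 0.7, 0.6] with labels [1, 0, 1, 0] place the positives at ranks 1 and 3, giving (1/1 + 2/3) / 2 = 5/6.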
Computation is performed in constant memory by computing precision and recall for a fixed number of buckets/thresholds, set by the thresholds argument (evenly distributed between 0 and 1 when thresholds is an integer).

Forward accepts

- preds (float tensor): (N, ...) (binary) or (N, C, ...) (multiclass) tensor with probabilities, where C is the number of classes.
- target (long tensor): (N, ...) with integer labels.
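The constant-memory idea can be sketched as follows for a binary problem, assuming the usual step-function definition of AP. The only state kept across batches is one true-positive, false-positive and false-negative counter per threshold, so its size depends on the number of thresholds rather than on the number of samples seen. The helpers init_state, update_state and compute_average_precision are hypothetical names; closing the curve at recall 0 and precision 1 and the small eps guard are choices of this sketch, not necessarily what torchmetrics does internally.

```python
from typing import Tuple

import torch
from torch import Tensor

State = Tuple[Tensor, Tensor, Tensor]


def init_state(num_thresholds: int) -> State:
    # One TP, FP and FN counter per threshold; the size never grows with the data
    return tuple(torch.zeros(num_thresholds, dtype=torch.long) for _ in range(3))


def update_state(state: State, preds: Tensor, target: Tensor, thresholds: Tensor) -> None:
    """Add one batch of binary predictions to the counters, in place."""
    tp, fp, fn = state
    # Temporary (N, T) matrix: does prediction n clear threshold t?
    predicted_pos = preds.unsqueeze(-1) >= thresholds
    actual_pos = (target == 1).unsqueeze(-1)
    tp += (predicted_pos & actual_pos).sum(dim=0)
    fp += (predicted_pos & ~actual_pos).sum(dim=0)
    fn += (~predicted_pos & actual_pos).sum(dim=0)


def compute_average_precision(state: State, eps: float = 1e-6) -> Tensor:
    tp, fp, fn = (x.float() for x in state)
    precision = tp / (tp + fp + eps)
    recall = tp / (tp + fn + eps)
    # Close the curve at recall = 0, precision = 1
    precision = torch.cat([precision, torch.ones(1)])
    recall = torch.cat([recall, torch.zeros(1)])
    # Recall does not increase as the threshold rises, so each difference is <= 0
    return -torch.sum((recall[1:] - recall[:-1]) * precision[:-1])
```

With thresholds = torch.linspace(0, 1, 11), scores [0.95, 0.85, 0.75, 0.65] with labels [1, 0, 1, 0] give roughly 0.833, matching the exact AP for that ranking because each score falls into its own bin. When several scores share a bin, the binned value can drift from the exact one, which is why more thresholds give more accurate estimates.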
- Parameters
num_classes (int) – integer with number of classes. Not necessary to provide for binary problems.
thresholds (Union[int, Tensor, List[float]]) – list or tensor with specific thresholds, or a number of bins from linear sampling. More thresholds lead to a more detailed curve and more accurate estimates, but computation will be slower and consume more memory.
kwargs (Any) – Additional keyword arguments, see Advanced metric settings for more info.
- Raises
ValueError – If thresholds is not an int, a list or a tensor.
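The Parameters and Raises entries describe three accepted forms for thresholds. One way such an argument could be normalised to a 1-D float tensor is sketched here, assuming an integer means that many evenly spaced values from 0 to 1 inclusive (via torch.linspace). resolve_thresholds is a hypothetical helper, not part of the torchmetrics API.

```python
from typing import List, Union

import torch
from torch import Tensor


def resolve_thresholds(thresholds: Union[int, Tensor, List[float]]) -> Tensor:
    """Normalise the thresholds argument to a 1-D float tensor (hypothetical helper)."""
    if isinstance(thresholds, int):
        # A number of bins: evenly spaced thresholds from 0 to 1, both ends included
        return torch.linspace(0, 1.0, thresholds)
    if isinstance(thresholds, list):
        # Explicit threshold values given as Python floats
        return torch.tensor(thresholds, dtype=torch.float32)
    if isinstance(thresholds, Tensor):
        return thresholds.float()
    raise ValueError("Expected `thresholds` to be an int, a list of floats or a tensor")
```

For instance, resolve_thresholds(5) gives the thresholds 0.0, 0.25, 0.5, 0.75 and 1.0, while a tuple is rejected with ValueError.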
- Example (binary case):
>>> import torch
>>> from torchmetrics import BinnedAveragePrecision
>>> pred = torch.tensor([0, 1, 2, 3])
>>> target = torch.tensor([0, 1, 1, 1])
>>> average_precision = BinnedAveragePrecision(num_classes=1, thresholds=10)
>>> average_precision(pred, target)
tensor(1.0000)
- Example (multiclass case):
>>> pred = torch.tensor([[0.75, 0.05, 0.05, 0.05, 0.05],
...                      [0.05, 0.75, 0.05, 0.05, 0.05],
...                      [0.05, 0.05, 0.75, 0.05, 0.05],
...                      [0.05, 0.05, 0.05, 0.75, 0.05]])
>>> target = torch.tensor([0, 1, 3, 2])
>>> average_precision = BinnedAveragePrecision(num_classes=5, thresholds=10)
>>> average_precision(pred, target)
[tensor(1.0000), tensor(1.0000), tensor(0.2500), tensor(0.2500), tensor(-0.)]
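In the multiclass output, the last entry belongs to class 4, which never appears in target. With no positive samples, recall is zero at every threshold, so every recall difference in the step sum is 0.0. Assuming the score is computed as the negated step sum AP = -Σ_k (R_{k+1} - R_k) P_k (an assumption about the implementation, not something this page states), negating that positive zero yields the signed zero -0.0, which still compares equal to 0. The snippet reproduces only that arithmetic, with placeholder precision values.

```python
import torch

# Class 4 has no positive targets: recall is 0 at every threshold (plus the closing point)
recall = torch.zeros(11)
# Placeholder precision values; they are multiplied by zero, so they do not matter
precision = torch.ones(11)

# Negated step-function sum over the curve; -(0.0) gives the signed zero -0.0
ap = -torch.sum((recall[1:] - recall[:-1]) * precision[:-1])
print(ap)
```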