
Binned Average Precision

Module Interface

class torchmetrics.BinnedAveragePrecision(num_classes, thresholds=None, compute_on_step=None, **kwargs)

Computes the average precision score, which summarises the precision-recall curve into one number. Works for both binary and multiclass problems. In the multiclass case, the score is computed separately for each class using a one-vs-the-rest approach.

Computation is performed in constant memory: precision and recall are computed only at a fixed set of thresholds (when thresholds is an integer, these are evenly distributed between 0 and 1).
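To make the constant-memory idea concrete, here is a minimal pure-Python sketch for the binary case. It is not the torchmetrics implementation. The class name BinnedAPState is hypothetical. The sketch also assumes that thresholds are sorted in ascending order, that a sample counts as predicted positive at threshold t when its prediction is >= t, and that precision is taken as 1 (and recall as 0) when its denominator is zero. Only three counters per threshold are kept, so memory grows with the number of thresholds rather than with the number of samples. The average precision is then the step-function sum of precision over the drops in recall.

```python
from typing import List, Sequence


class BinnedAPState:
    """Hypothetical sketch of constant-memory, binned average precision (binary case).

    Only three counters per threshold are stored, so memory depends on the
    number of thresholds, not on the number of samples seen.
    """

    def __init__(self, thresholds: Sequence[float]) -> None:
        # Assumption: thresholds are given in ascending order.
        self.thresholds: List[float] = list(thresholds)
        self.tp = [0] * len(self.thresholds)
        self.fp = [0] * len(self.thresholds)
        self.fn = [0] * len(self.thresholds)

    def update(self, preds: Sequence[float], target: Sequence[int]) -> None:
        # A sample counts as "predicted positive" at threshold t when pred >= t.
        for i, t in enumerate(self.thresholds):
            for p, y in zip(preds, target):
                if p >= t:
                    if y == 1:
                        self.tp[i] += 1
                    else:
                        self.fp[i] += 1
                elif y == 1:
                    self.fn[i] += 1

    def compute(self) -> float:
        precision: List[float] = []
        recall: List[float] = []
        for tp, fp, fn in zip(self.tp, self.fp, self.fn):
            # Assumed conventions: precision is 1 when nothing is predicted
            # positive, recall is 0 when there are no positive targets.
            precision.append(tp / (tp + fp) if tp + fp else 1.0)
            recall.append(tp / (tp + fn) if tp + fn else 0.0)
        # Close the curve at recall 0, precision 1.
        precision.append(1.0)
        recall.append(0.0)
        # With ascending thresholds recall never increases, so each step
        # contributes (drop in recall) * (precision at the start of the step).
        return sum(
            (recall[i] - recall[i + 1]) * precision[i]
            for i in range(len(self.thresholds))
        )


state = BinnedAPState([i / 9 for i in range(10)])  # 10 evenly spaced thresholds
state.update([0, 1], [0, 1])  # first batch
state.update([2, 3], [1, 1])  # second batch
print(state.compute())  # 1.0
```

Splitting the binary example's data into two batches produces exactly the same counters as a single batch, and therefore the same score, which is what allows per-batch updates without storing past predictions.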

Forward accepts

  • preds (float tensor): (N, ...) (binary) or (N, C, ...) (multiclass) tensor with probabilities, where C is the number of classes.

  • target (long tensor): (N, ...) with integer labels

Parameters
  • num_classes (int) – integer with number of classes. Not necessary to provide for binary problems.

  • thresholds (Union[int, Tensor, List[float], None]) – either a list or tensor of specific thresholds, or an integer number of bins obtained by linear sampling between 0 and 1. More thresholds give a more detailed curve and more accurate estimates, but computation is slower and consumes more memory.

  • compute_on_step (Optional[bool]) –

    Forward only calls update() and returns None if this is set to False.

    Deprecated since version v0.8: Argument has no use anymore and will be removed in v0.9.

  • kwargs (Dict[str, Any]) – Additional keyword arguments, see Advanced metric settings for more info.

Raises

ValueError – If thresholds is not an int, a list or a tensor
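The thresholds argument accepts either an integer number of bins or an explicit list or tensor of thresholds. The helper below is a hypothetical pure-Python sketch (resolve_thresholds is not a torchmetrics function) of how such an argument can be turned into a concrete list of thresholds, raising ValueError for anything else. Tensors are left out because the sketch avoids a torch dependency. The check for fewer than two bins is an assumption of this sketch, added only to avoid dividing by zero.

```python
from typing import List, Union


def resolve_thresholds(thresholds: Union[int, List[float]]) -> List[float]:
    """Hypothetical helper: turn the thresholds argument into a list of floats."""
    if isinstance(thresholds, int):
        # Assumption of this sketch: at least two bins, so 0 and 1 are both included.
        if thresholds < 2:
            raise ValueError("Expected at least two thresholds")
        # An integer n means n evenly spaced thresholds from 0 to 1 inclusive.
        return [i / (thresholds - 1) for i in range(thresholds)]
    if isinstance(thresholds, list):
        # An explicit list is used as given, converted to floats.
        return [float(t) for t in thresholds]
    raise ValueError("Expected thresholds to be an int, a list of floats or a tensor")


print(resolve_thresholds(5))  # [0.0, 0.25, 0.5, 0.75, 1.0]
```

With thresholds=10, as in the examples below, this sketch yields 0, 1/9, 2/9, ..., 1.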

Example (binary case):
>>> import torch
>>> from torchmetrics import BinnedAveragePrecision
>>> pred = torch.tensor([0, 1, 2, 3])
>>> target = torch.tensor([0, 1, 1, 1])
>>> average_precision = BinnedAveragePrecision(num_classes=1, thresholds=10)
>>> average_precision(pred, target)
tensor(1.0000)

Example (multiclass case):
>>> pred = torch.tensor([[0.75, 0.05, 0.05, 0.05, 0.05],
...                      [0.05, 0.75, 0.05, 0.05, 0.05],
...                      [0.05, 0.05, 0.75, 0.05, 0.05],
...                      [0.05, 0.05, 0.05, 0.75, 0.05]])
>>> target = torch.tensor([0, 1, 3, 2])
>>> average_precision = BinnedAveragePrecision(num_classes=5, thresholds=10)
>>> average_precision(pred, target)
[tensor(1.0000), tensor(1.0000), tensor(0.2500), tensor(0.2500), tensor(-0.)]
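The multiclass result can be reproduced by hand. Below is a hypothetical, non-streaming pure-Python sketch (binned_average_precision_per_class is not a library function). It applies the one-vs-the-rest reduction to each class and then computes a binned step-function sum. Its conventions are assumptions, not the library's code: thresholds evenly spaced over [0, 1], a >= comparison, precision 1 when nothing is predicted positive, and recall 0 when a class has no positive samples.

```python
from typing import List, Sequence


def binned_average_precision_per_class(
    preds: Sequence[Sequence[float]],
    target: Sequence[int],
    num_classes: int,
    num_thresholds: int,
) -> List[float]:
    """Hypothetical sketch: binned average precision, one-vs-the-rest per class."""
    # Evenly spaced thresholds over [0, 1], in ascending order.
    thresholds = [i / (num_thresholds - 1) for i in range(num_thresholds)]
    result: List[float] = []
    for c in range(num_classes):
        # One-vs-the-rest: class c is positive, every other class is negative.
        scores = [row[c] for row in preds]
        labels = [1 if y == c else 0 for y in target]
        precision: List[float] = []
        recall: List[float] = []
        for t in thresholds:
            tp = sum(1 for s, y in zip(scores, labels) if s >= t and y == 1)
            fp = sum(1 for s, y in zip(scores, labels) if s >= t and y == 0)
            fn = sum(1 for s, y in zip(scores, labels) if s < t and y == 1)
            precision.append(tp / (tp + fp) if tp + fp else 1.0)
            recall.append(tp / (tp + fn) if tp + fn else 0.0)
        # Close the curve at recall 0, precision 1.
        precision.append(1.0)
        recall.append(0.0)
        # Step-function sum: each drop in recall is weighted by the precision
        # at the threshold where the drop starts.
        result.append(
            sum((recall[i] - recall[i + 1]) * precision[i] for i in range(num_thresholds))
        )
    return result


pred = [
    [0.75, 0.05, 0.05, 0.05, 0.05],
    [0.05, 0.75, 0.05, 0.05, 0.05],
    [0.05, 0.05, 0.75, 0.05, 0.05],
    [0.05, 0.05, 0.05, 0.75, 0.05],
]
target = [0, 1, 3, 2]
print(binned_average_precision_per_class(pred, target, num_classes=5, num_thresholds=10))
# [1.0, 1.0, 0.25, 0.25, 0.0]
```

Classes 2 and 3 score 0.25 because each one's confident prediction (0.75) falls on the wrong sample. The true sample of each class only receives 0.05, so it is recovered only at threshold 0, where all four samples are predicted positive and precision is 1/4. Class 4 has no positive samples, so its recall is zero at every threshold and its score is zero (printed as -0. above).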

compute()

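Returns the average precision: a single float tensor for binary problems, or a list with one float tensor per class for multiclass problems.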

Return type

Union[List[Tensor], Tensor]