Average Precision

Module Interface

class torchmetrics.AveragePrecision(num_classes=None, pos_label=None, average='macro', **kwargs)[source]

Computes the average precision score, which summarises the precision-recall curve into one number. Works for both binary and multiclass problems. In the case of multiclass, the values are calculated based on a one-vs-the-rest approach.
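
Formally, the score follows the standard threshold-based definition (the same one used by scikit-learn):

    AP = \sum_n (R_n - R_{n-1}) P_n

where P_n and R_n are the precision and recall at the n-th threshold, respectively.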

Forward accepts

  • preds (float tensor): (N, ...) (binary) or (N, C, ...) (multiclass) tensor with probabilities, where C is the number of classes.

  • target (long tensor): (N, ...) with integer labels

Parameters
  • num_classes (Optional[int]) – integer with number of classes. Not necessary to provide for binary problems.

  • pos_label (Optional[int]) – integer determining the positive class. Default is None, which for binary problems is translated to 1. For multiclass problems this argument should not be set, as it is iteratively changed over the range [0, num_classes-1].

  • average (Optional[str]) –

    defines the reduction that is applied in the case of multiclass and multilabel input. Should be one of the following:

    • 'macro' [default]: Calculate the metric for each class separately, and average the metrics across classes (with equal weights for each class).

    • 'micro': Calculate the metric globally, across all samples and classes. Cannot be used with multiclass input.

    • 'weighted': Calculate the metric for each class separately, and average the metrics across classes, weighting each class by its support.

    • 'none' or None: Calculate the metric for each class separately, and return the metric for every class.

  • kwargs (Any) – Additional keyword arguments, see Advanced metric settings for more info.

Example (binary case):
>>> import torch
>>> from torchmetrics import AveragePrecision
>>> pred = torch.tensor([0, 0.1, 0.8, 0.4])
>>> target = torch.tensor([0, 1, 1, 1])
>>> average_precision = AveragePrecision(pos_label=1)
>>> average_precision(pred, target)
tensor(1.)
Example (multiclass case):
>>> pred = torch.tensor([[0.75, 0.05, 0.05, 0.05, 0.05],
...                      [0.05, 0.75, 0.05, 0.05, 0.05],
...                      [0.05, 0.05, 0.75, 0.05, 0.05],
...                      [0.05, 0.05, 0.05, 0.75, 0.05]])
>>> target = torch.tensor([0, 1, 3, 2])
>>> average_precision = AveragePrecision(num_classes=5, average=None)
>>> average_precision(pred, target)
[tensor(1.), tensor(1.), tensor(0.2500), tensor(0.2500), tensor(nan)]
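
Example (macro and weighted reductions): a minimal sketch on the same data; outputs are omitted because class 4 has no positive targets (its per-class score is NaN), and how NaN scores enter the averaged result can vary between torchmetrics versions.
>>> ap_macro = AveragePrecision(num_classes=5, average='macro')
>>> macro = ap_macro(pred, target)        # unweighted mean over classes
>>> ap_weighted = AveragePrecision(num_classes=5, average='weighted')
>>> weighted = ap_weighted(pred, target)  # mean weighted by class support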

compute()[source]

Compute the average precision score.

Return type

Union[Tensor, List[Tensor]]

Returns

tensor with average precision. If multiclass, returns a list of such tensors, one for each class.

update(preds, target)[source]

Update state with predictions and targets.

Parameters
  • preds (Tensor) – Predictions from model

  • target (Tensor) – Ground truth values

Return type

None
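
In practice, update and compute are used together to accumulate the metric over batches. A minimal sketch with random data (shapes and values are illustrative, not from the original examples):
>>> import torch
>>> from torchmetrics import AveragePrecision
>>> metric = AveragePrecision(pos_label=1)
>>> for _ in range(3):
...     preds = torch.rand(8)               # probabilities, shape (N,)
...     target = torch.randint(0, 2, (8,))  # binary labels, shape (N,)
...     metric.update(preds, target)
>>> ap = metric.compute()  # average precision over all accumulated batches
>>> metric.reset()         # clear internal state, e.g. between epochs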

Functional Interface

torchmetrics.functional.average_precision(preds, target, num_classes=None, pos_label=None, average='macro', sample_weights=None)[source]

Computes the average precision score.

Parameters
  • preds (Tensor) – predictions from model (logits or probabilities)

  • target (Tensor) – ground truth values

  • num_classes (Optional[int]) – integer with number of classes. Not necessary to provide for binary problems.

  • pos_label (Optional[int]) – integer determining the positive class. Default is None, which for binary problems is translated to 1. For multiclass problems this argument should not be set, as it is iteratively changed over the range [0, num_classes-1].

  • average (Optional[str]) –

    defines the reduction that is applied in the case of multiclass and multilabel input. Should be one of the following:

    • 'macro' [default]: Calculate the metric for each class separately, and average the metrics across classes (with equal weights for each class).

    • 'micro': Calculate the metric globally, across all samples and classes. Cannot be used with multiclass input.

    • 'weighted': Calculate the metric for each class separately, and average the metrics across classes, weighting each class by its support.

    • 'none' or None: Calculate the metric for each class separately, and return the metric for every class.

  • sample_weights (Optional[Sequence]) – sample weights for each data point

Return type

Union[List[Tensor], Tensor]

Returns

tensor with average precision. If multiclass, returns a list of such tensors, one for each class.

Example (binary case):
>>> import torch
>>> from torchmetrics.functional import average_precision
>>> pred = torch.tensor([0, 1, 2, 3])
>>> target = torch.tensor([0, 1, 1, 1])
>>> average_precision(pred, target, pos_label=1)
tensor(1.)
Example (multiclass case):
>>> pred = torch.tensor([[0.75, 0.05, 0.05, 0.05, 0.05],
...                      [0.05, 0.75, 0.05, 0.05, 0.05],
...                      [0.05, 0.05, 0.75, 0.05, 0.05],
...                      [0.05, 0.05, 0.05, 0.75, 0.05]])
>>> target = torch.tensor([0, 1, 3, 2])
>>> average_precision(pred, target, num_classes=5, average=None)
[tensor(1.), tensor(1.), tensor(0.2500), tensor(0.2500), tensor(nan)]
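
Example (per-sample weights): the signature exposes sample_weights as a sequence with one weight per data point. A minimal sketch with illustrative weight values; note that in some torchmetrics releases this argument is deprecated and has no effect, so treat the call below as illustrative only.
>>> pred = torch.tensor([0.0, 0.1, 0.8, 0.4])
>>> target = torch.tensor([0, 1, 1, 1])
>>> weights = [1.0, 0.5, 2.0, 1.0]  # one weight per sample (illustrative)
>>> ap = average_precision(pred, target, pos_label=1, sample_weights=weights)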