
Retrieval Mean Average Precision (MAP)

Module Interface

class torchmetrics.RetrievalMAP(empty_target_action='neg', ignore_index=None, **kwargs)

Computes Mean Average Precision (MAP) for information retrieval.

Works with binary target data. Accepts float predictions from a model output.

forward accepts:

  • preds (float tensor): (N, ...)

  • target (long or bool tensor): (N, ...)

  • indexes (long tensor): (N, ...)

indexes, preds and target must have the same shape. indexes indicates the query to which each prediction belongs. Predictions are first grouped by indexes, and MAP is then computed as the mean of the Average Precision (AP) of each query.
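The grouping-then-averaging procedure can be sketched in plain Python. This is a minimal illustration under stated assumptions, not torchmetrics' implementation: it uses lists instead of tensors, the helper names `average_precision` and `mean_average_precision` are hypothetical, tied scores keep their input order, and a query with no positive target scores 0.0 (mirroring the default 'neg' behaviour).

```python
from collections import defaultdict


def average_precision(preds, target):
    """AP of one query: mean precision at the rank of each relevant document."""
    # Rank documents by descending score; sorted() is stable, so ties keep input order.
    order = sorted(range(len(preds)), key=lambda i: preds[i], reverse=True)
    hits = 0
    precisions = []
    for rank, i in enumerate(order, start=1):
        if target[i]:
            hits += 1
            precisions.append(hits / rank)  # precision@rank at a relevant document
    # Assumption: a query without positives counts as 0.0 (like empty_target_action='neg').
    return sum(precisions) / len(precisions) if precisions else 0.0


def mean_average_precision(preds, target, indexes):
    """Group (pred, target) pairs by query index, then average the per-query APs."""
    groups = defaultdict(lambda: ([], []))
    for p, t, q in zip(preds, target, indexes):
        groups[q][0].append(p)
        groups[q][1].append(t)
    aps = [average_precision(p, t) for p, t in groups.values()]
    return sum(aps) / len(aps)


indexes = [0, 0, 0, 1, 1, 1, 1]
preds = [0.2, 0.3, 0.5, 0.1, 0.3, 0.5, 0.2]
target = [False, False, True, False, True, False, True]
print(round(mean_average_precision(preds, target, indexes), 4))  # 0.7917
```

On the same inputs as the class example further down, this sketch reproduces its value of 0.7917.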

Parameters
  • empty_target_action (str) –

    Specify what to do with queries that do not have at least one positive target. Choose from:

    • 'neg': those queries count as 0.0 (default)

    • 'pos': those queries count as 1.0

    • 'skip': skip those queries; if all queries are skipped, 0.0 is returned

    • 'error': raise a ValueError

  • ignore_index (Optional[int]) – Ignore predictions where the target is equal to this number.

  • kwargs (Any) – Additional keyword arguments, see Advanced metric settings for more info.
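How empty_target_action and ignore_index change the aggregation can be sketched as follows. This is a hypothetical plain-Python helper (`map_with_options` is not a torchmetrics function) written only to illustrate the documented semantics: entries whose target equals ignore_index are dropped before grouping, and queries without a positive target are scored 0.0, scored 1.0, skipped, or rejected according to empty_target_action. The argument checks mirror the documented ValueError cases; the exact messages are assumptions.

```python
from collections import defaultdict


def average_precision(preds, target):
    # Rank by descending score and average precision@k over the relevant ranks.
    order = sorted(range(len(preds)), key=lambda i: preds[i], reverse=True)
    hits, total = 0, 0.0
    for rank, i in enumerate(order, start=1):
        if target[i]:
            hits += 1
            total += hits / rank
    return total / hits  # only called for queries with at least one positive


def map_with_options(preds, target, indexes, empty_target_action="neg", ignore_index=None):
    if empty_target_action not in ("error", "skip", "neg", "pos"):
        raise ValueError(f"invalid empty_target_action: {empty_target_action!r}")
    if ignore_index is not None and not isinstance(ignore_index, int):
        raise ValueError("ignore_index must be None or an integer")
    groups = defaultdict(lambda: ([], []))
    for p, t, q in zip(preds, target, indexes):
        # Drop entries whose target equals ignore_index before grouping by query.
        if ignore_index is not None and t == ignore_index:
            continue
        groups[q][0].append(p)
        groups[q][1].append(t)
    scores = []
    for q, (p, t) in groups.items():
        if any(t):
            scores.append(average_precision(p, t))
        elif empty_target_action == "neg":
            scores.append(0.0)
        elif empty_target_action == "pos":
            scores.append(1.0)
        elif empty_target_action == "error":
            raise ValueError(f"query {q} has no positive target")
        # 'skip': the query contributes nothing to the mean
    # If every query was skipped, return 0.0 as documented.
    return sum(scores) / len(scores) if scores else 0.0


preds = [0.2, 0.3, 0.5, 0.4]
target = [0, 0, 1, 0]
indexes = [0, 0, 0, 1]  # query 1 has no positive target
print(map_with_options(preds, target, indexes, "neg"))   # 0.5
print(map_with_options(preds, target, indexes, "pos"))   # 1.0
print(map_with_options(preds, target, indexes, "skip"))  # 1.0
```

Query 0 ranks its only relevant document first (AP = 1.0), so the three settings differ only in how query 1 is counted.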

Raises
  • ValueError – If empty_target_action is not one of error, skip, neg or pos.

  • ValueError – If ignore_index is neither None nor an integer.

Example

>>> from torch import tensor
>>> from torchmetrics import RetrievalMAP
>>> indexes = tensor([0, 0, 0, 1, 1, 1, 1])
>>> preds = tensor([0.2, 0.3, 0.5, 0.1, 0.3, 0.5, 0.2])
>>> target = tensor([False, False, True, False, True, False, True])
>>> rmap = RetrievalMAP()
>>> rmap(preds, target, indexes=indexes)
tensor(0.7917)
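To see where 0.7917 comes from, sort each query's documents by descending score. Query 0 ranks its only relevant document first (score 0.5). Query 1 ranks its documents 0.5 (not relevant), 0.3 (relevant), 0.2 (relevant), 0.1 (not relevant), so its relevant documents sit at ranks 2 and 3. Averaging the precision at each relevant rank, then averaging over the two queries, gives:

```latex
\begin{aligned}
\mathrm{AP}_0 &= \tfrac{1}{1}\left(\tfrac{1}{1}\right) = 1 \\
\mathrm{AP}_1 &= \tfrac{1}{2}\left(\tfrac{1}{2} + \tfrac{2}{3}\right) = \tfrac{7}{12} \approx 0.5833 \\
\mathrm{MAP}  &= \tfrac{1}{2}\left(1 + \tfrac{7}{12}\right) = \tfrac{19}{24} \approx 0.7917
\end{aligned}
```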


Functional Interface

torchmetrics.functional.retrieval_average_precision(preds, target)

Computes average precision (for information retrieval), as explained in IR Average precision.

preds and target must have the same shape and live on the same device. If no element of target is True, 0 is returned. target must have a bool or integer dtype and preds a float dtype; otherwise an error is raised.
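For reference, the standard IR definition of average precision over a ranked list of n documents can be written as below. The symbols (rel, P, R, n) are introduced here for illustration and are not part of the torchmetrics API; rel(k) is 1 if the document at rank k, after sorting by descending preds, is relevant and 0 otherwise.

```latex
\mathrm{AP} = \frac{1}{R} \sum_{k=1}^{n} P(k)\,\mathrm{rel}(k),
\qquad
P(k) = \frac{1}{k} \sum_{j=1}^{k} \mathrm{rel}(j),
\qquad
R = \sum_{k=1}^{n} \mathrm{rel}(k)
```

When R = 0 the fraction is undefined, which is why the function returns 0 in that case.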

Parameters
  • preds (Tensor) – estimated probability that each document is relevant.

  • target (Tensor) – ground-truth relevance of each document.

Return type

Tensor

Returns

a single-value tensor with the average precision (AP) of the predictions preds with respect to the labels target.

Example

>>> from torch import tensor
>>> from torchmetrics.functional import retrieval_average_precision
>>> preds = tensor([0.2, 0.3, 0.5])
>>> target = tensor([True, False, True])
>>> retrieval_average_precision(preds, target)
tensor(0.8333)
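Checking this result by hand: ranked by descending score, the documents are 0.5 (relevant), 0.3 (not relevant), 0.2 (relevant), so the relevant documents sit at ranks 1 and 3, with precision 1/1 and 2/3 at those ranks:

```latex
\mathrm{AP} = \tfrac{1}{2}\left(\tfrac{1}{1} + \tfrac{2}{3}\right) = \tfrac{5}{6} \approx 0.8333
```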