
Retrieval Precision

Module Interface

class torchmetrics.RetrievalPrecision(empty_target_action='neg', ignore_index=None, k=None, adaptive_k=False, **kwargs)

Computes Precision for information retrieval (IR).

Works with binary target data. Accepts float predictions from a model output.

Forward accepts:

  • preds (float tensor): (N, ...)

  • target (long or bool tensor): (N, ...)

  • indexes (long tensor): (N, ...)

indexes, preds and target must have the same shape. indexes indicates the query each prediction belongs to. Predictions are first grouped by indexes, and Precision is then computed as the mean of the per-query Precision values.
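The grouping step can be sketched in pure Python. This is a minimal illustration under stated assumptions, not the torchmetrics implementation: the helper names precision_at_k and grouped_precision are hypothetical, plain lists stand in for tensors, score ties are broken by input order (the library may order ties differently), and empty_target_action and ignore_index handling are omitted.

```python
from collections import defaultdict
from typing import Dict, List, Optional, Sequence, Tuple


def precision_at_k(preds: Sequence[float], target: Sequence[bool], k: Optional[int] = None) -> float:
    """Precision@k for one query: relevant documents among the k best-scored, divided by k."""
    if k is None:
        k = len(preds)  # consider every document of the query
    # Positions sorted by descending score; ties keep their input order.
    ranked = sorted(range(len(preds)), key=lambda i: preds[i], reverse=True)[:k]
    return sum(1 for i in ranked if target[i]) / k


def grouped_precision(
    preds: Sequence[float],
    target: Sequence[bool],
    indexes: Sequence[int],
    k: Optional[int] = None,
) -> float:
    """Group documents by query id, then average the per-query Precision@k."""
    groups: Dict[int, Tuple[List[float], List[bool]]] = defaultdict(lambda: ([], []))
    for p, t, q in zip(preds, target, indexes):
        groups[q][0].append(p)
        groups[q][1].append(t)
    per_query = [precision_at_k(p, t, k) for p, t in groups.values()]
    return sum(per_query) / len(per_query)


# Same inputs as the doctest Example of the module interface.
print(grouped_precision(
    [0.2, 0.3, 0.5, 0.1, 0.3, 0.5, 0.2],
    [False, False, True, False, True, False, True],
    [0, 0, 0, 1, 1, 1, 1],
    k=2,
))  # 0.5
```

Averaging per query (rather than pooling all documents) means every query weighs the same, regardless of how many documents it has.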

Parameters
  • empty_target_action (str) –

    Specify what to do with queries that do not have at least one positive target. Choose from:

    • 'neg': those queries count as 0.0 (default)

    • 'pos': those queries count as 1.0

    • 'skip': skip those queries; if all queries are skipped, 0.0 is returned

    • 'error': raise a ValueError

  • ignore_index (Optional[int]) – Ignore predictions where the target is equal to this number.

  • k (Optional[int]) – consider only the top k elements for each query (default: None, which considers them all)

  • adaptive_k (bool) – adjust k to min(k, number of documents) for each query

  • kwargs (Dict[str, Any]) – Additional keyword arguments; see Advanced metric settings for more info.
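The per-query effect of empty_target_action and ignore_index can be illustrated with another pure-Python sketch. It assumes that documents whose target equals ignore_index are dropped before the query is checked for positives, and it uses hypothetical names (mean_precision, _precision_at_k); it is not the library's code.

```python
from typing import List, Optional, Sequence, Tuple

# One query: (scores, binary targets); targets may also contain ignore_index.
Query = Tuple[Sequence[float], Sequence[int]]


def _precision_at_k(preds: Sequence[float], target: Sequence[int], k: Optional[int]) -> float:
    k = len(preds) if k is None else k
    top = sorted(range(len(preds)), key=lambda i: preds[i], reverse=True)[:k]
    return sum(1 for i in top if target[i]) / k


def mean_precision(
    queries: Sequence[Query],
    empty_target_action: str = "neg",
    ignore_index: Optional[int] = None,
    k: Optional[int] = None,
) -> float:
    scores: List[float] = []
    for preds, target in queries:
        if ignore_index is not None:
            # Assumption: documents whose target equals ignore_index are dropped first.
            kept = [(p, t) for p, t in zip(preds, target) if t != ignore_index]
            preds = [p for p, _ in kept]
            target = [t for _, t in kept]
        if not any(target):  # the query has no positive target
            if empty_target_action == "error":
                raise ValueError("found a query without any positive target")
            if empty_target_action == "pos":
                scores.append(1.0)
            elif empty_target_action == "neg":
                scores.append(0.0)
            elif empty_target_action != "skip":
                raise ValueError(f"unknown empty_target_action: {empty_target_action!r}")
            continue  # move to the next query; "skip" appends nothing
        scores.append(_precision_at_k(preds, target, k))
    # If every query was skipped, return 0.0, as the 'skip' option describes.
    return sum(scores) / len(scores) if scores else 0.0
```

With one query at Precision@2 = 0.5 and one query without positives, 'neg' averages to 0.25, 'pos' to 0.75, and 'skip' keeps only the first query (0.5).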

Raises
  • ValueError – If empty_target_action is not one of error, skip, neg or pos.

  • ValueError – If ignore_index is not None or an integer.

  • ValueError – If k is not None or an integer larger than 0.

  • ValueError – If adaptive_k is not boolean.
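The conditions listed under Raises amount to plain argument validation. The hypothetical check_retrieval_args helper below mirrors them one by one; its error messages are illustrative and not the library's wording.

```python
from typing import Optional


def check_retrieval_args(
    empty_target_action: str = "neg",
    ignore_index: Optional[int] = None,
    k: Optional[int] = None,
    adaptive_k: bool = False,
) -> None:
    """Raise ValueError for each invalid argument described under Raises."""
    if empty_target_action not in ("error", "skip", "neg", "pos"):
        raise ValueError(f"empty_target_action must be one of error, skip, neg, pos; got {empty_target_action!r}")
    if ignore_index is not None and not isinstance(ignore_index, int):
        raise ValueError("ignore_index must be None or an integer")
    if k is not None and not (isinstance(k, int) and k > 0):
        raise ValueError("k must be None or an integer larger than 0")
    if not isinstance(adaptive_k, bool):
        raise ValueError("adaptive_k must be a boolean")
```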

Example

>>> from torch import tensor
>>> from torchmetrics import RetrievalPrecision
>>> indexes = tensor([0, 0, 0, 1, 1, 1, 1])
>>> preds = tensor([0.2, 0.3, 0.5, 0.1, 0.3, 0.5, 0.2])
>>> target = tensor([False, False, True, False, True, False, True])
>>> p2 = RetrievalPrecision(k=2)
>>> p2(preds, target, indexes=indexes)
tensor(0.5000)
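The 0.5 in the example can be checked by hand: each query keeps its two highest-scoring documents, and exactly one of them is relevant in both cases (in query 1, the second relevant document, scored 0.2, falls outside the top 2).

```latex
\begin{aligned}
\text{query 0: } & \text{top-2 scores } (0.5, 0.3) \text{ have relevance } (1, 0) \;\Rightarrow\; P@2 = \tfrac{1}{2} \\
\text{query 1: } & \text{top-2 scores } (0.5, 0.3) \text{ have relevance } (0, 1) \;\Rightarrow\; P@2 = \tfrac{1}{2} \\
\text{result: } & \tfrac{1}{2}\left(\tfrac{1}{2} + \tfrac{1}{2}\right) = 0.5
\end{aligned}
```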


Functional Interface

torchmetrics.functional.retrieval_precision(preds, target, k=None, adaptive_k=False)

Computes the precision metric (for information retrieval). Precision is the fraction of relevant documents among all the retrieved documents.

preds and target should be of the same shape and live on the same device. If no target is True, 0 is returned. target must be either bool or integer and preds must be float; otherwise an error is raised. To measure Precision@K, set k to a positive integer.

Parameters
  • preds (Tensor) – estimated probability of each document being relevant.

  • target (Tensor) – ground truth about each document being relevant or not.

  • k (Optional[int]) – consider only the top k elements (default: None, which considers them all)

  • adaptive_k (bool) – adjust k to min(k, number of documents)
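The single-query behavior described above can be sketched in pure Python. The name precision_sketch is hypothetical and plain lists replace tensors. One assumption goes beyond the text: without adaptive_k, the denominator stays k even when fewer than k documents exist, so asking for more documents than are available lowers the score; with adaptive_k=True, k is first reduced to the number of documents.

```python
from typing import Optional, Sequence


def precision_sketch(
    preds: Sequence[float],
    target: Sequence[bool],
    k: Optional[int] = None,
    adaptive_k: bool = False,
) -> float:
    """Pure-Python sketch of single-query Precision@k (hypothetical helper, not torchmetrics)."""
    if len(preds) != len(target):
        raise ValueError("preds and target must have the same shape")
    if not isinstance(adaptive_k, bool):
        raise ValueError("adaptive_k must be a boolean")
    n = len(preds)
    if k is None or (adaptive_k and k > n):
        k = n  # consider every document
    if not (isinstance(k, int) and k > 0):
        raise ValueError("k must be None or a positive integer")
    if not any(target):
        return 0.0  # no relevant document at all
    # Indices of the k highest-scoring documents (all of them if n < k).
    top = sorted(range(n), key=lambda i: preds[i], reverse=True)[:k]
    # Assumption: without adaptive_k the denominator stays k even when n < k.
    return sum(1 for i in top if target[i]) / k
```

For example, with three documents of which two are relevant and k=5, the sketch gives 2/5 = 0.4 without adaptive_k and 2/3 with it.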

Return type

Tensor

Returns

a single-value tensor with the precision (at k) of the predictions preds w.r.t. the labels target.

Raises
  • ValueError – If k is not None or an integer larger than 0.

  • ValueError – If adaptive_k is not boolean.

Example

>>> from torch import tensor
>>> from torchmetrics.functional import retrieval_precision
>>> preds = tensor([0.2, 0.3, 0.5])
>>> target = tensor([True, False, True])
>>> retrieval_precision(preds, target, k=2)
tensor(0.5000)