
Retrieval R-Precision

Module Interface

class torchmetrics.RetrievalRPrecision(empty_target_action='neg', ignore_index=None, **kwargs)

Computes R-Precision for information retrieval (IR).

Works with binary target data. Accepts float predictions from a model output.

As input to forward and update, the metric accepts the following inputs:

  • preds (Tensor): A float tensor of shape (N, ...)

  • target (Tensor): A long or bool tensor of shape (N, ...)

  • indexes (Tensor): A long tensor of shape (N, ...) indicating which query each prediction belongs to

As output of forward and compute, the metric returns the following output:

  • p2 (Tensor): A single-value tensor with the r-precision of the predictions preds w.r.t. the labels target.

All of indexes, preds and target must have the same shape and are flattened at the beginning, so that, for example, a tensor of shape (N, M) is treated as (N * M, ). Predictions are first grouped by indexes; the metric is then computed for each query, and the final result is the mean over all queries.
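To make the flatten, group and average procedure concrete, here is a dependency-free sketch in plain Python, where lists stand in for already-flattened tensors. It illustrates the arithmetic only and is not torchmetrics' implementation; the helper names r_precision_single and grouped_r_precision are hypothetical, and the tie-breaking rule (equal scores keep their input order) is an assumption of this sketch.

```python
from collections import defaultdict
from typing import Dict, List, Sequence, Tuple


def r_precision_single(preds: Sequence[float], target: Sequence[bool]) -> float:
    """R-Precision of one query: fraction of relevant docs among the top-R scored docs."""
    r = sum(1 for t in target if t)  # R = number of relevant documents for this query
    if r == 0:
        return 0.0  # assumption: mirror the default empty_target_action='neg'
    # Rank positions by descending score; sorted() is stable, so ties keep input order.
    ranked = sorted(range(len(preds)), key=lambda i: preds[i], reverse=True)
    hits = sum(1 for i in ranked[:r] if target[i])
    return hits / r


def grouped_r_precision(
    indexes: Sequence[int], preds: Sequence[float], target: Sequence[bool]
) -> float:
    """Group flattened inputs by query id, score each query, then average."""
    if not (len(indexes) == len(preds) == len(target)):
        raise ValueError("indexes, preds and target must have the same length")
    groups: Dict[int, Tuple[List[float], List[bool]]] = defaultdict(lambda: ([], []))
    for q, p, t in zip(indexes, preds, target):
        groups[q][0].append(p)
        groups[q][1].append(t)
    scores = [r_precision_single(p, t) for p, t in groups.values()]
    return sum(scores) / len(scores) if scores else 0.0


indexes = [0, 0, 0, 1, 1, 1, 1]
preds = [0.2, 0.3, 0.5, 0.1, 0.3, 0.5, 0.2]
target = [False, False, True, False, True, False, True]
print(grouped_r_precision(indexes, preds, target))  # 0.75
```

On the inputs of the module example, query 0 has R = 1 and its top document is relevant (1.0), while query 1 has R = 2 and only one of its top two documents is relevant (0.5), so the mean is 0.75.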

Parameters
  • empty_target_action (str) –

    Specifies what to do with queries that have no positive target. Choose from:

    • 'neg': those queries count as 0.0 (default)

    • 'pos': those queries count as 1.0

    • 'skip': skip those queries; if all queries are skipped, 0.0 is returned

    • 'error': raise a ValueError

  • ignore_index (Optional[int]) – Ignore predictions where the target is equal to this number.

  • kwargs (Any) – Additional keyword arguments, see Advanced metric settings for more info.

Raises
  • ValueError – If empty_target_action is not one of error, skip, neg or pos.

  • ValueError – If ignore_index is not None or an integer.
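The empty_target_action and ignore_index options can be modeled in plain Python as below. This is an assumed, simplified model of the documented behavior rather than the library's code: drop_ignored and aggregate are hypothetical helpers, and a per-query score of None is used here to mark a query that has no positive target.

```python
from typing import List, Optional, Sequence, Tuple

_ACTIONS = ("error", "skip", "neg", "pos")


def drop_ignored(
    indexes: Sequence[int],
    preds: Sequence[float],
    target: Sequence[int],
    ignore_index: Optional[int] = None,
) -> Tuple[List[int], List[float], List[int]]:
    """Remove every position whose target equals ignore_index."""
    if ignore_index is not None and not isinstance(ignore_index, int):
        raise ValueError("ignore_index must be None or an integer")
    keep = [i for i, t in enumerate(target) if ignore_index is None or t != ignore_index]
    return [indexes[i] for i in keep], [preds[i] for i in keep], [target[i] for i in keep]


def aggregate(per_query: Sequence[Optional[float]], empty_target_action: str = "neg") -> float:
    """Average per-query scores; None marks a query without any positive target."""
    if empty_target_action not in _ACTIONS:
        raise ValueError(f"empty_target_action must be one of {_ACTIONS}")
    scores: List[float] = []
    for score in per_query:
        if score is not None:
            scores.append(score)
        elif empty_target_action == "error":
            raise ValueError("a query has no positive target")
        elif empty_target_action == "pos":
            scores.append(1.0)  # empty query counts as a perfect score
        elif empty_target_action == "neg":
            scores.append(0.0)  # empty query counts as zero
        # 'skip': the empty query contributes nothing
    # If every query was skipped (or there were none), fall back to 0.0.
    return sum(scores) / len(scores) if scores else 0.0
```

For instance, with per-query scores [1.0, None, 0.5], the default 'neg' gives (1.0 + 0.0 + 0.5) / 3 = 0.5, while 'skip' gives (1.0 + 0.5) / 2 = 0.75.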

Example

>>> from torch import tensor
>>> from torchmetrics import RetrievalRPrecision
>>> indexes = tensor([0, 0, 0, 1, 1, 1, 1])
>>> preds = tensor([0.2, 0.3, 0.5, 0.1, 0.3, 0.5, 0.2])
>>> target = tensor([False, False, True, False, True, False, True])
>>> p2 = RetrievalRPrecision()
>>> p2(preds, target, indexes=indexes)
tensor(0.7500)


Functional Interface

torchmetrics.functional.retrieval_r_precision(preds, target)

Computes the R-Precision metric for information retrieval. R-Precision is the fraction of relevant documents among the top k retrieved documents, where k equals the total number of relevant documents.
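Written out as a formula, the definition reads as follows. The notation is introduced here for illustration: n is the number of documents, d_(i) is the document with the i-th highest score in preds, and rel(d) is 1 if target marks d as relevant and 0 otherwise. The case R = 0 is handled separately, since the ratio is then undefined.

```latex
R = \sum_{j=1}^{n} \mathrm{rel}(d_j),
\qquad
\text{R-Precision} = \frac{1}{R} \sum_{i=1}^{R} \mathrm{rel}\big(d_{(i)}\big)
```

For the functional example at the end of this section (preds = [0.2, 0.3, 0.5], target = [True, False, True]), R = 2; the two highest-scoring documents have scores 0.5 (relevant) and 0.3 (not relevant), so R-Precision = 1/2 = 0.5.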

preds and target should have the same shape and live on the same device. If no target is True, 0 is returned. target must be either bool or integer and preds must be float; otherwise an error is raised.

Parameters
  • preds (Tensor) – estimated probability of each document being relevant.

  • target (Tensor) – ground truth about each document being relevant or not.

Return type

Tensor

Returns

a single-value tensor with the r-precision of the predictions preds w.r.t. the labels target.

Example

>>> from torch import tensor
>>> from torchmetrics.functional import retrieval_r_precision
>>> preds = tensor([0.2, 0.3, 0.5])
>>> target = tensor([True, False, True])
>>> retrieval_r_precision(preds, target)
tensor(0.5000)