Retrieval Hit Rate

Module Interface

class torchmetrics.RetrievalHitRate(empty_target_action='neg', ignore_index=None, k=None, **kwargs)

Computes the hit rate for information retrieval (IR).

Works with binary target data. Accepts float predictions from a model output.

Forward accepts:

  • preds (float tensor): (N, ...)

  • target (long or bool tensor): (N, ...)

  • indexes (long tensor): (N, ...)

indexes, preds and target must have the same shape. indexes indicates which query each prediction belongs to. Predictions are first grouped by indexes, the hit rate is computed for each query, and the metric returns the mean of the per-query hit rates.
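The grouping step described above can be illustrated with a minimal pure-Python sketch. This is not the library's implementation: the helper names hit_rate_at_k and grouped_hit_rate are hypothetical, plain lists stand in for tensors, and ties between equal scores may be ordered differently than in torchmetrics.

```python
from collections import defaultdict

def hit_rate_at_k(preds, target, k=None):
    """Return 1.0 if any of the top-k scored items is relevant, else 0.0."""
    # Rank (score, label) pairs by score, highest first.
    ranked = sorted(zip(preds, target), key=lambda pair: pair[0], reverse=True)
    top = ranked if k is None else ranked[:k]
    return 1.0 if any(bool(t) for _, t in top) else 0.0

def grouped_hit_rate(preds, target, indexes, k=None):
    """Group predictions by query index, score each query, then average."""
    groups = defaultdict(list)
    for p, t, i in zip(preds, target, indexes):
        groups[i].append((p, t))
    per_query = [
        hit_rate_at_k([p for p, _ in items], [t for _, t in items], k)
        for items in groups.values()
    ]
    return sum(per_query) / len(per_query)

indexes = [0, 0, 0, 1, 1, 1, 1]
preds = [0.2, 0.3, 0.5, 0.1, 0.3, 0.5, 0.2]
target = [True, False, False, False, True, False, True]

# Query 0: top-2 scores are 0.5 and 0.3, neither relevant -> 0.0
# Query 1: top-2 scores are 0.5 and 0.3, the second is relevant -> 1.0
print(grouped_hit_rate(preds, target, indexes, k=2))  # 0.5
```

In this sketch a query without any positive target scores 0.0, which matches the documented default behaviour ('neg').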

Parameters
  • empty_target_action (str) –

    Specify what to do with queries that do not have at least one positive target. Choose from:

    • 'neg': those queries count as 0.0 (default)

    • 'pos': those queries count as 1.0

    • 'skip': skip those queries; if all queries are skipped, 0.0 is returned

    • 'error': raise a ValueError

  • ignore_index (Optional[int]) – Ignore predictions where the target is equal to this number.

  • k (Optional[int]) – consider only the top k elements for each query (default: None, which considers them all)

  • kwargs (Any) – Additional keyword arguments, see Advanced metric settings for more info.

Raises
  • ValueError – If empty_target_action is not one of error, skip, neg or pos.

  • ValueError – If ignore_index is not None or an integer.

  • ValueError – If k parameter is not None or an integer larger than 0.
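To make the four empty_target_action options concrete, here is a small sketch of how per-query scores might be aggregated under each setting. It only assumes the behaviour documented above; the helper aggregate_hit_rates and its arguments are hypothetical and are not part of torchmetrics.

```python
def aggregate_hit_rates(per_query_targets, per_query_hits, empty_target_action="neg"):
    """Average per-query hit rates, handling queries with no positive target.

    per_query_targets: list of lists of bool labels, one list per query.
    per_query_hits: hit rate already computed for each query (0.0 or 1.0).
    """
    if empty_target_action not in ("error", "skip", "neg", "pos"):
        raise ValueError(f"Invalid empty_target_action: {empty_target_action!r}")

    scores = []
    for targets, hit in zip(per_query_targets, per_query_hits):
        if not any(targets):
            # This query has no positive target.
            if empty_target_action == "error":
                raise ValueError("Query without any positive target")
            if empty_target_action == "skip":
                continue
            scores.append(1.0 if empty_target_action == "pos" else 0.0)
        else:
            scores.append(hit)

    # If every query was skipped, 0.0 is returned, as documented for 'skip'.
    return sum(scores) / len(scores) if scores else 0.0

targets = [[True, False], [False, False]]  # second query has no positives
hits = [1.0, 0.0]

print(aggregate_hit_rates(targets, hits, "neg"))   # 0.5
print(aggregate_hit_rates(targets, hits, "pos"))   # 1.0
print(aggregate_hit_rates(targets, hits, "skip"))  # 1.0
```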

Example

>>> from torch import tensor
>>> from torchmetrics import RetrievalHitRate
>>> indexes = tensor([0, 0, 0, 1, 1, 1, 1])
>>> preds = tensor([0.2, 0.3, 0.5, 0.1, 0.3, 0.5, 0.2])
>>> target = tensor([True, False, False, False, True, False, True])
>>> hr2 = RetrievalHitRate(k=2)
>>> hr2(preds, target, indexes=indexes)
tensor(0.5000)

Functional Interface

torchmetrics.functional.retrieval_hit_rate(preds, target, k=None)

Computes the hit rate (for information retrieval). The hit rate is 1.0 if there is at least one relevant document among the top k retrieved documents, and 0.0 otherwise.

preds and target must have the same shape and live on the same device. If no target is True, 0 is returned. target must be a bool or integer tensor and preds must be a float tensor; otherwise an error is raised. To measure HitRate@K, k must be a positive integer.
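The effect of k and the documented check on it can be seen in a short sketch that uses plain Python lists instead of tensors. The function hit_rate_sketch is a hypothetical stand-in written for illustration, not the body of retrieval_hit_rate, and its error messages are made up.

```python
def hit_rate_sketch(preds, target, k=None):
    """Illustrative hit rate at k for a single query."""
    # Mirror the documented check: k is either None or a positive integer.
    if k is not None and not (isinstance(k, int) and k > 0):
        raise ValueError("k has to be a positive integer or None")
    if len(preds) != len(target):
        raise ValueError("preds and target must have the same length")

    # Rank document positions by predicted score, highest first.
    order = sorted(range(len(preds)), key=lambda i: preds[i], reverse=True)
    top = order if k is None else order[:k]
    return 1.0 if any(bool(target[i]) for i in top) else 0.0

preds = [0.9, 0.6, 0.3]
target = [False, False, True]  # only the lowest-scored document is relevant

for k in (1, 2, 3, None):
    print(k, hit_rate_sketch(preds, target, k=k))
# 1 0.0
# 2 0.0
# 3 1.0
# None 1.0
```

The relevant document only enters the ranked window once k reaches 3, so smaller cut-offs report a miss.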

Parameters
  • preds (Tensor) – estimated probabilities of each document to be relevant.

  • target (Tensor) – ground truth about each document being relevant or not.

  • k (Optional[int]) – consider only the top k elements (default: None, which considers them all)

Return type

Tensor

Returns

a single-value tensor with the hit rate (at k) of the predictions preds w.r.t. the labels target.

Raises

ValueError – If k parameter is not None or an integer larger than 0.

Example

>>> from torch import tensor
>>> from torchmetrics.functional import retrieval_hit_rate
>>> preds = tensor([0.2, 0.3, 0.5])
>>> target = tensor([True, False, True])
>>> retrieval_hit_rate(preds, target, k=2)
tensor(1.)