Exact Match

Module Interface

MultilabelExactMatch

class torchmetrics.classification.MultilabelExactMatch(num_labels, threshold=0.5, multidim_average='global', ignore_index=None, validate_args=True, **kwargs)[source]

Computes exact match (also known as subset accuracy) for multilabel tasks. Exact match is a stricter version of accuracy: a sample counts as correctly classified only if all of its labels match the target exactly.
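To make the difference from ordinary accuracy concrete, the sketch below compares the two on plain Python lists. It is an illustration only and does not use torchmetrics; the helper names exact_match and label_accuracy are hypothetical and chosen for this example.

```python
def exact_match(preds, target):
    """Fraction of samples whose whole label vector equals the target."""
    matches = [p == t for p, t in zip(preds, target)]
    return sum(matches) / len(matches)


def label_accuracy(preds, target):
    """Fraction of individual label decisions that are correct."""
    pairs = [(p, t) for ps, ts in zip(preds, target) for p, t in zip(ps, ts)]
    return sum(p == t for p, t in pairs) / len(pairs)


target = [[0, 1, 0], [1, 0, 1]]
preds = [[0, 0, 1], [1, 0, 1]]

em = exact_match(preds, target)      # 0.5: only the second sample matches fully
acc = label_accuracy(preds, target)  # 4 of 6 labels are correct, about 0.667
print(em, acc)
```

The first sample gets one of its three labels right, which still contributes to per-label accuracy but counts as a complete miss for exact match.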

Accepts the following input tensors:

  • preds (int or float tensor): (N, C, ...). If preds is a floating point tensor with values outside the [0,1] range, the input is treated as logits and a sigmoid is applied to each element automatically. Additionally, floating point predictions are converted to an int tensor by thresholding with the value in threshold.

  • target (int tensor): (N, C, ...)

How the additional dimensions ... (if present) are handled is determined by the multidim_average argument.
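The preprocessing of preds described above can be sketched in plain Python as follows. This is not the library's code: the function name to_binary is hypothetical, and the strict greater-than comparison against the threshold is an assumption made for this sketch.

```python
import math


def to_binary(preds, threshold=0.5):
    """Sketch of the described handling of float predictions (N, C) as nested lists."""
    flat = [v for row in preds for v in row]
    is_float = any(isinstance(v, float) for v in flat)
    if not is_float:
        # Integer predictions are already binary labels and are left unchanged.
        return preds
    if any(v < 0.0 or v > 1.0 for v in flat):
        # Values outside [0, 1] are treated as logits: apply sigmoid per element.
        preds = [[1.0 / (1.0 + math.exp(-v)) for v in row] for row in preds]
    # Assumption: a strict ">" comparison against the threshold.
    return [[int(v > threshold) for v in row] for row in preds]


probs = [[0.11, 0.22, 0.84], [0.73, 0.33, 0.92]]
logits = [[-2.0, -1.3, 1.7], [1.0, -0.7, 2.4]]

from_probs = to_binary(probs)    # [[0, 0, 1], [1, 0, 1]]
from_logits = to_binary(logits)  # [[0, 0, 1], [1, 0, 1]]
print(from_probs, from_logits)
```

Both inputs end up as the same binary predictions, because the logits above map to probabilities on the same side of the 0.5 threshold.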

Parameters
  • num_labels (int) – Integer specifying the number of labels

  • threshold (float) – Threshold for transforming probability to binary (0,1) predictions

  • multidim_average (Literal['global', 'samplewise']) –

    Defines how additional dimensions ... should be handled. Should be one of the following:

    • global: Additional dimensions are flattened along the batch dimension

    • samplewise: Statistic will be calculated independently for each sample on the N axis. The statistics in this case are calculated over the additional dimensions.

  • ignore_index (Optional[int]) – Specifies a target value that is ignored and does not contribute to the metric calculation

  • validate_args (bool) – Boolean indicating whether input arguments and tensors should be validated for correctness. Set to False for faster computation.

Returns

The returned shape depends on the multidim_average argument:

  • If multidim_average is set to global, the output will be a scalar tensor

  • If multidim_average is set to samplewise, the output will be a tensor of shape (N,)

Return type

Tensor

Example (preds is int tensor):
>>> import torch
>>> from torchmetrics.classification import MultilabelExactMatch
>>> target = torch.tensor([[0, 1, 0], [1, 0, 1]])
>>> preds = torch.tensor([[0, 0, 1], [1, 0, 1]])
>>> metric = MultilabelExactMatch(num_labels=3)
>>> metric(preds, target)
tensor(0.5000)
Example (preds is float tensor):
>>> import torch
>>> from torchmetrics.classification import MultilabelExactMatch
>>> target = torch.tensor([[0, 1, 0], [1, 0, 1]])
>>> preds = torch.tensor([[0.11, 0.22, 0.84], [0.73, 0.33, 0.92]])
>>> metric = MultilabelExactMatch(num_labels=3)
>>> metric(preds, target)
tensor(0.5000)
Example (multidim tensors):
>>> import torch
>>> from torchmetrics.classification import MultilabelExactMatch
>>> target = torch.tensor([[[0, 1], [1, 0], [0, 1]], [[1, 1], [0, 0], [1, 0]]])
>>> preds = torch.tensor(
...     [
...         [[0.59, 0.91], [0.91, 0.99], [0.63, 0.04]],
...         [[0.38, 0.04], [0.86, 0.780], [0.45, 0.37]],
...     ]
... )
>>> metric = MultilabelExactMatch(num_labels=3, multidim_average='samplewise')
>>> metric(preds, target)
tensor([0., 0.])

compute()[source]

Computes the final metric value from the state variables, synchronized across the distributed backend.

Return type

Tensor

update(preds, target)[source]

Update state with predictions and targets.

Parameters
  • preds (Tensor) – Tensor with predictions

  • target (Tensor) – Tensor with true labels

Return type

None
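The update/compute split means that state is accumulated batch by batch and the final value is only derived at the end. The sketch below mimics that pattern in plain Python, assuming the state is a pair of running counts. ExactMatchAccumulator is a hypothetical stand-in written for this example, not the torchmetrics class.

```python
class ExactMatchAccumulator:
    """Hypothetical stand-in showing the update/compute pattern on nested lists."""

    def __init__(self):
        self.correct = 0  # samples whose label vector matched exactly
        self.total = 0    # samples seen so far

    def update(self, preds, target):
        """Add one batch of binary predictions and targets to the running counts."""
        for p, t in zip(preds, target):
            self.correct += int(p == t)
            self.total += 1

    def compute(self):
        """Return the exact match ratio over everything seen so far."""
        return self.correct / self.total if self.total else 0.0


metric = ExactMatchAccumulator()
metric.update([[0, 0, 1], [1, 0, 1]], [[0, 1, 0], [1, 0, 1]])  # 1 of 2 samples match
metric.update([[1, 1, 0]], [[1, 1, 0]])                        # 1 of 1 samples match
result = metric.compute()                                      # 2 of 3 overall
print(result)
```

Accumulating counts rather than averaging per-batch ratios keeps the result correct when batches have different sizes.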

Functional Interface

multilabel_exact_match

torchmetrics.functional.classification.multilabel_exact_match(preds, target, num_labels, threshold=0.5, multidim_average='global', ignore_index=None, validate_args=True)[source]

Computes exact match (also known as subset accuracy) for multilabel tasks. Exact match is a stricter version of accuracy: a sample counts as correctly classified only if all of its labels match the target exactly.

Accepts the following input tensors:

  • preds (int or float tensor): (N, C, ...). If preds is a floating point tensor with values outside the [0,1] range, the input is treated as logits and a sigmoid is applied to each element automatically. Additionally, floating point predictions are converted to an int tensor by thresholding with the value in threshold.

  • target (int tensor): (N, C, ...)

How the additional dimensions ... (if present) are handled is determined by the multidim_average argument.

Parameters
  • preds (Tensor) – Tensor with predictions

  • target (Tensor) – Tensor with true labels

  • num_labels (int) – Integer specifying the number of labels

  • threshold (float) – Threshold for transforming probability to binary (0,1) predictions

  • multidim_average (Literal['global', 'samplewise']) –

    Defines how additional dimensions ... should be handled. Should be one of the following:

    • global: Additional dimensions are flattened along the batch dimension

    • samplewise: Statistic will be calculated independently for each sample on the N axis. The statistics in this case are calculated over the additional dimensions.

  • ignore_index (Optional[int]) – Specifies a target value that is ignored and does not contribute to the metric calculation

  • validate_args (bool) – Boolean indicating whether input arguments and tensors should be validated for correctness. Set to False for faster computation.
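The two multidim_average modes can be illustrated with a small plain-Python sketch for inputs of shape (N, C, D), where D is a single additional dimension. The helper names and the binary input data are hypothetical, and the code only mirrors the behaviour described in the parameter list rather than the library's implementation.

```python
def position_matches(pred_sample, target_sample):
    """For one sample of shape (C, D), check at each of the D positions
    whether every one of the C labels matches."""
    num_labels = len(pred_sample)
    num_positions = len(pred_sample[0])
    return [
        all(pred_sample[c][d] == target_sample[c][d] for c in range(num_labels))
        for d in range(num_positions)
    ]


def exact_match_multidim(preds, target, multidim_average="global"):
    per_sample = [position_matches(p, t) for p, t in zip(preds, target)]
    if multidim_average == "samplewise":
        # One value per sample, averaged over that sample's extra positions.
        return [sum(m) / len(m) for m in per_sample]
    # "global": extra positions are treated like additional batch entries.
    flat = [m for sample in per_sample for m in sample]
    return sum(flat) / len(flat)


# Shape (N=2, C=3, D=2), already binary.
target = [[[0, 1], [1, 0], [0, 1]], [[1, 1], [0, 0], [1, 0]]]
preds = [[[0, 1], [1, 1], [0, 1]], [[1, 1], [0, 0], [1, 0]]]

global_score = exact_match_multidim(preds, target)                   # 0.75
sample_scores = exact_match_multidim(preds, target, "samplewise")    # [0.5, 1.0]
print(global_score, sample_scores)
```

In the first sample only one of the two positions matches on all three labels, so it scores 0.5 samplewise, while the global mode pools all four positions and reports 3 of 4.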

Returns

The returned shape depends on the multidim_average argument:

  • If multidim_average is set to global, the output will be a scalar tensor

  • If multidim_average is set to samplewise, the output will be a tensor of shape (N,)

Return type

Tensor

Example (preds is int tensor):
>>> import torch
>>> from torchmetrics.functional.classification import multilabel_exact_match
>>> target = torch.tensor([[0, 1, 0], [1, 0, 1]])
>>> preds = torch.tensor([[0, 0, 1], [1, 0, 1]])
>>> multilabel_exact_match(preds, target, num_labels=3)
tensor(0.5000)
Example (preds is float tensor):
>>> import torch
>>> from torchmetrics.functional.classification import multilabel_exact_match
>>> target = torch.tensor([[0, 1, 0], [1, 0, 1]])
>>> preds = torch.tensor([[0.11, 0.22, 0.84], [0.73, 0.33, 0.92]])
>>> multilabel_exact_match(preds, target, num_labels=3)
tensor(0.5000)
Example (multidim tensors):
>>> import torch
>>> from torchmetrics.functional.classification import multilabel_exact_match
>>> target = torch.tensor([[[0, 1], [1, 0], [0, 1]], [[1, 1], [0, 0], [1, 0]]])
>>> preds = torch.tensor(
...     [
...         [[0.59, 0.91], [0.91, 0.99], [0.63, 0.04]],
...         [[0.38, 0.04], [0.86, 0.780], [0.45, 0.37]],
...     ]
... )
>>> multilabel_exact_match(preds, target, num_labels=3, multidim_average='samplewise')
tensor([0., 0.])