
Stat Scores

Module Interface

StatScores

class torchmetrics.StatScores(task: Literal['binary', 'multiclass', 'multilabel'], threshold: float = 0.5, num_classes: Optional[int] = None, num_labels: Optional[int] = None, average: Optional[Literal['micro', 'macro', 'weighted', 'none']] = 'micro', multidim_average: Optional[Literal['global', 'samplewise']] = 'global', top_k: Optional[int] = 1, ignore_index: Optional[int] = None, validate_args: bool = True, **kwargs: Any)

Computes the number of true positives, false positives, true negatives, false negatives and the support.

This class is a simple wrapper that returns the task-specific version of this metric, selected by setting the task argument to 'binary', 'multiclass' or 'multilabel'. See the documentation of BinaryStatScores, MulticlassStatScores and MultilabelStatScores for details on how each argument affects the result, and for examples.

Legacy Example:
>>> import torch
>>> from torchmetrics import StatScores
>>> preds = torch.tensor([1, 0, 2, 1])
>>> target = torch.tensor([1, 1, 2, 0])
>>> stat_scores = StatScores(task="multiclass", num_classes=3, average='micro')
>>> stat_scores(preds, target)
tensor([2, 2, 6, 2, 4])
>>> stat_scores = StatScores(task="multiclass", num_classes=3, average=None)
>>> stat_scores(preds, target)
tensor([[0, 1, 2, 1, 1],
        [1, 1, 1, 1, 2],
        [1, 0, 3, 0, 1]])
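Because the five entries are plain counts, the average='micro' vector is simply the per-class rows summed over the class axis. A quick check in plain PyTorch, reusing the average=None output printed above:

```python
import torch

# Per-class [tp, fp, tn, fn, sup] rows from the average=None call above
per_class = torch.tensor([[0, 1, 2, 1, 1],
                          [1, 1, 1, 1, 2],
                          [1, 0, 3, 0, 1]])

# Summing over the class dimension reproduces the average='micro' result
micro = per_class.sum(dim=0)
print(micro)  # tensor([2, 2, 6, 2, 4])
```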

BinaryStatScores

class torchmetrics.classification.BinaryStatScores(threshold=0.5, multidim_average='global', ignore_index=None, validate_args=True, **kwargs)

Computes the number of true positives, false positives, true negatives, false negatives and the support for binary tasks. Related to Type I and Type II errors.

As input to forward and update the metric accepts the following input:

  • preds (Tensor): An int or float tensor of shape (N, ...). If preds is a floating-point tensor with values outside the [0, 1] range, the input is treated as logits and a sigmoid is applied element-wise. Additionally, preds is converted to an int tensor by thresholding with the value of threshold.

  • target (Tensor): An int tensor of shape (N, ...)

As output to forward and compute the metric returns the following output:

  • bss (Tensor): A tensor of shape (..., 5), where the last dimension corresponds to [tp, fp, tn, fn, sup] (sup stands for support and equals tp + fn). The shape depends on the multidim_average parameter:

    • If multidim_average is set to global, the shape will be (5,)

    • If multidim_average is set to samplewise, the shape will be (N, 5)

Parameters
  • threshold (float) – Threshold for transforming probability to binary {0,1} predictions

  • multidim_average (Literal[‘global’, ‘samplewise’]) –

    Defines how the additional dimensions ... should be handled. Should be one of the following:

    • global: Additional dimensions are flattened along the batch dimension

    • samplewise: Statistic will be calculated independently for each sample on the N axis. The statistics in this case are calculated over the additional dimensions.

  • ignore_index (Optional[int]) – Specifies a target value that is ignored and does not contribute to the metric calculation

  • validate_args (bool) – bool indicating if input arguments and tensors should be validated for correctness. Set to False for faster computations.

  • kwargs (Any) – Additional keyword arguments, see Advanced metric settings for more info.
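To make the preprocessing and the multidim_average behaviour concrete, here is a minimal pure-PyTorch sketch of the binary counts. The function name binary_stat_scores_sketch is hypothetical, ignore_index is left out, and the sketch assumes a strict preds > threshold comparison and a sigmoid applied to the whole tensor whenever any value lies outside [0, 1]. It illustrates the idea and is not the torchmetrics implementation.

```python
import torch


def binary_stat_scores_sketch(preds, target, threshold=0.5, multidim_average="global"):
    """Hypothetical illustration of binary [tp, fp, tn, fn, sup] counting (not library code)."""
    if preds.is_floating_point():
        # Assumed rule: if any value is outside [0, 1], treat the tensor as logits.
        if ((preds < 0) | (preds > 1)).any():
            preds = preds.sigmoid()
        # Threshold probabilities into hard {0, 1} predictions.
        preds = (preds > threshold).long()
    if multidim_average == "global":
        # Flatten every dimension into one vector -> result of shape (5,)
        preds, target, dim = preds.reshape(-1), target.reshape(-1), 0
    else:
        # Keep the N axis and flatten the rest -> result of shape (N, 5)
        n = preds.shape[0]
        preds, target, dim = preds.reshape(n, -1), target.reshape(n, -1), 1
    tp = ((preds == 1) & (target == 1)).sum(dim)
    fp = ((preds == 1) & (target == 0)).sum(dim)
    tn = ((preds == 0) & (target == 0)).sum(dim)
    fn = ((preds == 0) & (target == 1)).sum(dim)
    return torch.stack([tp, fp, tn, fn, tp + fn], dim=-1)


# Float example: probabilities are thresholded at 0.5
target = torch.tensor([0, 1, 0, 1, 0, 1])
preds = torch.tensor([0.11, 0.22, 0.84, 0.73, 0.33, 0.92])
flat = binary_stat_scores_sketch(preds, target)

# Multidim example with multidim_average="samplewise"
target_md = torch.tensor([[[0, 1], [1, 0], [0, 1]], [[1, 1], [0, 0], [1, 0]]])
preds_md = torch.tensor([[[0.59, 0.91], [0.91, 0.99], [0.63, 0.04]],
                         [[0.38, 0.04], [0.86, 0.78], [0.45, 0.37]]])
per_sample = binary_stat_scores_sketch(preds_md, target_md, multidim_average="samplewise")
print(flat)        # tensor([2, 1, 2, 1, 3])
print(per_sample)  # rows [2, 3, 0, 1, 3] and [0, 2, 1, 3, 3]
```

Both results line up with the float-tensor and multidim outputs in the examples below.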

Example (preds is int tensor):
>>> import torch
>>> from torchmetrics.classification import BinaryStatScores
>>> target = torch.tensor([0, 1, 0, 1, 0, 1])
>>> preds = torch.tensor([0, 0, 1, 1, 0, 1])
>>> metric = BinaryStatScores()
>>> metric(preds, target)
tensor([2, 1, 2, 1, 3])

Example (preds is float tensor):
>>> import torch
>>> from torchmetrics.classification import BinaryStatScores
>>> target = torch.tensor([0, 1, 0, 1, 0, 1])
>>> preds = torch.tensor([0.11, 0.22, 0.84, 0.73, 0.33, 0.92])
>>> metric = BinaryStatScores()
>>> metric(preds, target)
tensor([2, 1, 2, 1, 3])

Example (multidim tensors):
>>> import torch
>>> from torchmetrics.classification import BinaryStatScores
>>> target = torch.tensor([[[0, 1], [1, 0], [0, 1]], [[1, 1], [0, 0], [1, 0]]])
>>> preds = torch.tensor(
...     [
...         [[0.59, 0.91], [0.91, 0.99], [0.63, 0.04]],
...         [[0.38, 0.04], [0.86, 0.780], [0.45, 0.37]],
...     ]
... )
>>> metric = BinaryStatScores(multidim_average='samplewise')
>>> metric(preds, target)
tensor([[2, 3, 0, 1, 3],
        [0, 2, 1, 3, 3]])
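The counts map directly onto the error rates mentioned in the description: the Type I error rate (false positive rate) is fp / (fp + tn) and the Type II error rate (false negative rate) is fn / (fn + tp). A small sketch using the [2, 1, 2, 1, 3] output of the int-tensor example above; precision and recall are included only to show how other common metrics follow from the same five numbers.

```python
import torch

# [tp, fp, tn, fn, sup] from the BinaryStatScores int-tensor example above
stats = torch.tensor([2, 1, 2, 1, 3])
tp, fp, tn, fn, sup = stats.tolist()

type_i_error_rate = fp / (fp + tn)   # false positive rate: 1 / 3
type_ii_error_rate = fn / (fn + tp)  # false negative rate: 1 / 3
precision = tp / (tp + fp)           # 2 / 3
recall = tp / sup                    # sup == tp + fn, so 2 / 3
```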


MulticlassStatScores

class torchmetrics.classification.MulticlassStatScores(num_classes, top_k=1, average='macro', multidim_average='global', ignore_index=None, validate_args=True, **kwargs)

Computes the number of true positives, false positives, true negatives, false negatives and the support for multiclass tasks. Related to Type I and Type II errors.

As input to forward and update the metric accepts the following input:

  • preds (Tensor): An int tensor of shape (N, ...) or a float tensor of shape (N, C, ...). If preds is a floating-point tensor, torch.argmax is applied along the C dimension to automatically convert probabilities/logits into an int tensor.

  • target (Tensor): An int tensor of shape (N, ...)

As output to forward and compute the metric returns the following output:

  • mcss (Tensor): A tensor of shape (..., 5), where the last dimension corresponds to [tp, fp, tn, fn, sup] (sup stands for support and equals tp + fn). The shape depends on the average and multidim_average parameters:

    • If multidim_average is set to global:

      • If average='micro'/'macro'/'weighted', the shape will be (5,)

      • If average=None/'none', the shape will be (C, 5)

    • If multidim_average is set to samplewise:

      • If average='micro'/'macro'/'weighted', the shape will be (N, 5)

      • If average=None/'none', the shape will be (N, C, 5)
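One way to picture where the (C, 5) shape comes from is to score every class one-vs-rest on one-hot encodings of preds and target. The helper below, multiclass_stat_scores_sketch, is hypothetical: it only covers 1-D targets with multidim_average='global' and average=None, and it leaves out top_k and ignore_index. It is a sketch of the idea, not the torchmetrics code path.

```python
import torch
import torch.nn.functional as F


def multiclass_stat_scores_sketch(preds, target, num_classes):
    """Hypothetical per-class [tp, fp, tn, fn, sup] for 1-D targets (average=None)."""
    if preds.is_floating_point():
        # (N, C) probabilities/logits -> (N,) hard labels
        preds = preds.argmax(dim=1)
    # Column c of each one-hot matrix answers "is this sample class c?"
    p = F.one_hot(preds, num_classes).bool()
    t = F.one_hot(target, num_classes).bool()
    tp = (p & t).sum(0)
    fp = (p & ~t).sum(0)
    tn = (~p & ~t).sum(0)
    fn = (~p & t).sum(0)
    return torch.stack([tp, fp, tn, fn, tp + fn], dim=-1)  # shape (C, 5)


target = torch.tensor([2, 1, 0, 0])
int_preds = torch.tensor([2, 1, 0, 1])
float_preds = torch.tensor([[0.16, 0.26, 0.58],
                            [0.22, 0.61, 0.17],
                            [0.71, 0.09, 0.20],
                            [0.05, 0.82, 0.13]])
per_class = multiclass_stat_scores_sketch(int_preds, target, num_classes=3)
per_class_float = multiclass_stat_scores_sketch(float_preds, target, num_classes=3)
print(per_class.shape)  # torch.Size([3, 5])
```

Both calls give the rows shown in the average=None examples below, since argmax maps the float rows onto the same labels as the int example.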

Parameters
  • num_classes (int) – Integer specifying the number of classes

  • average (Optional[Literal[‘micro’, ‘macro’, ‘weighted’, ‘none’]]) –

    Defines the reduction that is applied over labels. Should be one of the following:

    • micro: Sum statistics over all labels

    • macro: Calculate statistics for each label and average them

    • weighted: Calculates statistics for each label and computes weighted average using their support

    • "none" or None: Calculates statistic for each label and applies no reduction

  • top_k (int) – Number of highest probability or logit score predictions considered to find the correct label. Only works when preds contain probabilities/logits.

  • multidim_average (Literal[‘global’, ‘samplewise’]) –

    Defines how the additional dimensions ... should be handled. Should be one of the following:

    • global: Additional dimensions are flattened along the batch dimension

    • samplewise: Statistic will be calculated independently for each sample on the N axis. The statistics in this case are calculated over the additional dimensions.

  • ignore_index (Optional[int]) – Specifies a target value that is ignored and does not contribute to the metric calculation

  • validate_args (bool) – bool indicating if input arguments and tensors should be validated for correctness. Set to False for faster computations.

  • kwargs (Any) – Additional keyword arguments, see Advanced metric settings for more info.
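The macro and weighted reductions can be sketched directly on a (C, 5) per-class tensor. This assumes, following the parameter description, that macro is an unweighted mean over classes and weighted is a mean weighted by each class's support (the last column); how the library treats edge cases such as zero total support is not covered here.

```python
import torch

# Per-class [tp, fp, tn, fn, sup] rows from the average=None example below
per_class = torch.tensor([[1, 0, 2, 1, 2],
                          [1, 1, 2, 0, 1],
                          [1, 0, 3, 0, 1]]).float()

# macro: unweighted mean over the class axis
macro = per_class.mean(dim=0)  # approx. [1.00, 0.33, 2.33, 0.33, 1.33]

# weighted: mean over classes, weighted by support (last column)
support = per_class[:, 4]
weights = support / support.sum()  # [0.5, 0.25, 0.25]
weighted = (per_class * weights.unsqueeze(1)).sum(dim=0)  # [1.0, 0.25, 2.25, 0.5, 1.5]
```

Unlike micro, both reductions generally produce fractional values, so the result is a float tensor.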

Example (preds is int tensor):
>>> import torch
>>> from torchmetrics.classification import MulticlassStatScores
>>> target = torch.tensor([2, 1, 0, 0])
>>> preds = torch.tensor([2, 1, 0, 1])
>>> metric = MulticlassStatScores(num_classes=3, average='micro')
>>> metric(preds, target)
tensor([3, 1, 7, 1, 4])
>>> mcss = MulticlassStatScores(num_classes=3, average=None)
>>> mcss(preds, target)
tensor([[1, 0, 2, 1, 2],
        [1, 1, 2, 0, 1],
        [1, 0, 3, 0, 1]])

Example (preds is float tensor):
>>> import torch
>>> from torchmetrics.classification import MulticlassStatScores
>>> target = torch.tensor([2, 1, 0, 0])
>>> preds = torch.tensor([
...   [0.16, 0.26, 0.58],
...   [0.22, 0.61, 0.17],
...   [0.71, 0.09, 0.20],
...   [0.05, 0.82, 0.13],
... ])
>>> metric = MulticlassStatScores(num_classes=3, average='micro')
>>> metric(preds, target)
tensor([3, 1, 7, 1, 4])
>>> mcss = MulticlassStatScores(num_classes=3, average=None)
>>> mcss(preds, target)
tensor([[1, 0, 2, 1, 2],
        [1, 1, 2, 0, 1],
        [1, 0, 3, 0, 1]])

Example (multidim tensors):
>>> import torch
>>> from torchmetrics.classification import MulticlassStatScores
>>> target = torch.tensor([[[0, 1], [2, 1], [0, 2]], [[1, 1], [2, 0], [1, 2]]])
>>> preds = torch.tensor([[[0, 2], [2, 0], [0, 1]], [[2, 2], [2, 1], [1, 0]]])
>>> metric = MulticlassStatScores(num_classes=3, multidim_average="samplewise", average='micro')
>>> metric(preds, target)
tensor([[3, 3, 9, 3, 6],
        [2, 4, 8, 4, 6]])
>>> mcss = MulticlassStatScores(num_classes=3, multidim_average="samplewise", average=None)
>>> mcss(preds, target)
tensor([[[2, 1, 3, 0, 2],
         [0, 1, 3, 2, 2],
         [1, 1, 3, 1, 2]],
        [[0, 1, 4, 1, 1],
         [1, 1, 2, 2, 3],
         [1, 2, 2, 1, 2]]])


MultilabelStatScores

class torchmetrics.classification.MultilabelStatScores(num_labels, threshold=0.5, average='macro', multidim_average='global', ignore_index=None, validate_args=True, **kwargs)

Computes the number of true positives, false positives, true negatives, false negatives and the support for multilabel tasks. Related to Type I and Type II errors.

As input to forward and update the metric accepts the following input:

  • preds (Tensor): An int or float tensor of shape (N, C, ...). If preds is a floating-point tensor with values outside the [0, 1] range, the input is treated as logits and a sigmoid is applied element-wise. Additionally, preds is converted to an int tensor by thresholding with the value of threshold.

  • target (Tensor): An int tensor of shape (N, C, ...)

As output to forward and compute the metric returns the following output:

  • mlss (Tensor): A tensor of shape (..., 5), where the last dimension corresponds to [tp, fp, tn, fn, sup] (sup stands for support and equals tp + fn). The shape depends on the average and multidim_average parameters:

    • If multidim_average is set to global:

      • If average='micro'/'macro'/'weighted', the shape will be (5,)

      • If average=None/'none', the shape will be (C, 5)

    • If multidim_average is set to samplewise:

      • If average='micro'/'macro'/'weighted', the shape will be (N, 5)

      • If average=None/'none', the shape will be (N, C, 5)

Parameters
  • num_labels (int) – Integer specifying the number of labels

  • threshold (float) – Threshold for transforming probability to binary {0,1} predictions

  • average (Optional[Literal[‘micro’, ‘macro’, ‘weighted’, ‘none’]]) –

    Defines the reduction that is applied over labels. Should be one of the following:

    • micro: Sum statistics over all labels

    • macro: Calculate statistics for each label and average them

    • weighted: Calculates statistics for each label and computes weighted average using their support

    • "none" or None: Calculates statistic for each label and applies no reduction

  • multidim_average (Literal[‘global’, ‘samplewise’]) –

    Defines how the additional dimensions ... should be handled. Should be one of the following:

    • global: Additional dimensions are flattened along the batch dimension

    • samplewise: Statistic will be calculated independently for each sample on the N axis. The statistics in this case are calculated over the additional dimensions.

  • ignore_index (Optional[int]) – Specifies a target value that is ignored and does not contribute to the metric calculation

  • validate_args (bool) – bool indicating if input arguments and tensors should be validated for correctness. Set to False for faster computations.

  • kwargs (Any) – Additional keyword arguments, see Advanced metric settings for more info.
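In the multilabel case every column is its own binary problem, so the per-label counts come from reducing over the sample axis only. The helper multilabel_stat_scores_sketch below is hypothetical: it handles 2-D (N, C) inputs with average=None and multidim_average='global', omits ignore_index, and uses the same assumed preprocessing as the binary case (sigmoid if any value lies outside [0, 1], then a strict > threshold comparison).

```python
import torch


def multilabel_stat_scores_sketch(preds, target, threshold=0.5):
    """Hypothetical per-label [tp, fp, tn, fn, sup] for (N, C) inputs (average=None)."""
    if preds.is_floating_point():
        # Assumed rule: any value outside [0, 1] means the tensor holds logits.
        if ((preds < 0) | (preds > 1)).any():
            preds = preds.sigmoid()
        preds = (preds > threshold).long()
    # Reduce over the sample axis (dim 0) only, keeping one row per label.
    tp = ((preds == 1) & (target == 1)).sum(0)
    fp = ((preds == 1) & (target == 0)).sum(0)
    tn = ((preds == 0) & (target == 0)).sum(0)
    fn = ((preds == 0) & (target == 1)).sum(0)
    return torch.stack([tp, fp, tn, fn, tp + fn], dim=-1)  # shape (C, 5)


target = torch.tensor([[0, 1, 0], [1, 0, 1]])
preds = torch.tensor([[0.11, 0.22, 0.84], [0.73, 0.33, 0.92]])
per_label = multilabel_stat_scores_sketch(preds, target)
print(per_label)
```

The printed rows match the average=None output of the float-tensor example below.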

Example (preds is int tensor):
>>> import torch
>>> from torchmetrics.classification import MultilabelStatScores
>>> target = torch.tensor([[0, 1, 0], [1, 0, 1]])
>>> preds = torch.tensor([[0, 0, 1], [1, 0, 1]])
>>> metric = MultilabelStatScores(num_labels=3, average='micro')
>>> metric(preds, target)
tensor([2, 1, 2, 1, 3])
>>> mlss = MultilabelStatScores(num_labels=3, average=None)
>>> mlss(preds, target)
tensor([[1, 0, 1, 0, 1],
        [0, 0, 1, 1, 1],
        [1, 1, 0, 0, 1]])

Example (preds is float tensor):
>>> import torch
>>> from torchmetrics.classification import MultilabelStatScores
>>> target = torch.tensor([[0, 1, 0], [1, 0, 1]])
>>> preds = torch.tensor([[0.11, 0.22, 0.84], [0.73, 0.33, 0.92]])
>>> metric = MultilabelStatScores(num_labels=3, average='micro')
>>> metric(preds, target)
tensor([2, 1, 2, 1, 3])
>>> mlss = MultilabelStatScores(num_labels=3, average=None)
>>> mlss(preds, target)
tensor([[1, 0, 1, 0, 1],
        [0, 0, 1, 1, 1],
        [1, 1, 0, 0, 1]])

Example (multidim tensors):
>>> import torch
>>> from torchmetrics.classification import MultilabelStatScores
>>> target = torch.tensor([[[0, 1], [1, 0], [0, 1]], [[1, 1], [0, 0], [1, 0]]])
>>> preds = torch.tensor(
...     [
...         [[0.59, 0.91], [0.91, 0.99], [0.63, 0.04]],
...         [[0.38, 0.04], [0.86, 0.780], [0.45, 0.37]],
...     ]
... )
>>> metric = MultilabelStatScores(num_labels=3, multidim_average='samplewise', average='micro')
>>> metric(preds, target)
tensor([[2, 3, 0, 1, 3],
        [0, 2, 1, 3, 3]])
>>> mlss = MultilabelStatScores(num_labels=3, multidim_average='samplewise', average=None)
>>> mlss(preds, target)
tensor([[[1, 1, 0, 0, 1],
         [1, 1, 0, 0, 1],
         [0, 1, 0, 1, 1]],
        [[0, 0, 0, 2, 2],
         [0, 2, 0, 0, 0],
         [0, 0, 1, 1, 1]]])


Functional Interface

stat_scores

torchmetrics.functional.stat_scores(preds, target, task, threshold=0.5, num_classes=None, num_labels=None, average='micro', multidim_average='global', top_k=1, ignore_index=None, validate_args=True)

Computes the number of true positives, false positives, true negatives, false negatives and the support.

This function is a simple wrapper that dispatches to the task-specific version of this metric, selected by setting the task argument to 'binary', 'multiclass' or 'multilabel'. See the documentation of binary_stat_scores(), multiclass_stat_scores() and multilabel_stat_scores() for details on how each argument affects the result, and for examples.

Legacy Example:
>>> import torch
>>> from torchmetrics.functional import stat_scores
>>> preds = torch.tensor([1, 0, 2, 1])
>>> target = torch.tensor([1, 1, 2, 0])
>>> stat_scores(preds, target, task='multiclass', num_classes=3, average='micro')
tensor([2, 2, 6, 2, 4])
>>> stat_scores(preds, target, task='multiclass', num_classes=3, average=None)
tensor([[0, 1, 2, 1, 1],
        [1, 1, 1, 1, 2],
        [1, 0, 3, 0, 1]])

Return type

Tensor

binary_stat_scores

torchmetrics.functional.classification.binary_stat_scores(preds, target, threshold=0.5, multidim_average='global', ignore_index=None, validate_args=True)

Computes the number of true positives, false positives, true negatives, false negatives and the support for binary tasks. Related to Type I and Type II errors.

Accepts the following input tensors:

  • preds (int or float tensor): (N, ...). If preds is a floating-point tensor with values outside the [0, 1] range, the input is treated as logits and a sigmoid is applied element-wise. Additionally, preds is converted to an int tensor by thresholding with the value of threshold.

  • target (int tensor): (N, ...)

Parameters
  • preds (Tensor) – Tensor with predictions

  • target (Tensor) – Tensor with true labels

  • threshold (float) – Threshold for transforming probability to binary {0,1} predictions

  • multidim_average (Literal[‘global’, ‘samplewise’]) –

    Defines how the additional dimensions ... should be handled. Should be one of the following:

    • global: Additional dimensions are flattened along the batch dimension

    • samplewise: Statistic will be calculated independently for each sample on the N axis. The statistics in this case are calculated over the additional dimensions.

  • ignore_index (Optional[int]) – Specifies a target value that is ignored and does not contribute to the metric calculation

  • validate_args (bool) – bool indicating if input arguments and tensors should be validated for correctness. Set to False for faster computations.
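Conceptually, ignore_index removes the positions whose target matches it before anything is counted. The sketch below assumes that simply dropping those positions gives the same global counts as the library; the sentinel value -1 is chosen only for this illustration.

```python
import torch

ignore_index = -1  # sentinel value chosen only for this illustration
target = torch.tensor([0, 1, -1, 1, 0, -1])
preds = torch.tensor([0, 0, 1, 1, 0, 1])

# Keep only the positions whose target is not the ignored value
keep = target != ignore_index
p, t = preds[keep], target[keep]

tp = ((p == 1) & (t == 1)).sum()
fp = ((p == 1) & (t == 0)).sum()
tn = ((p == 0) & (t == 0)).sum()
fn = ((p == 0) & (t == 1)).sum()
stats = torch.stack([tp, fp, tn, fn, tp + fn])
print(stats)  # tensor([1, 0, 2, 1, 2])
```

The two ignored positions are both predicted as 1, yet they add nothing to fp or any other count.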

Return type

Tensor

Returns

The metric returns a tensor of shape (..., 5), where the last dimension corresponds to [tp, fp, tn, fn, sup] (sup stands for support and equals tp + fn). The shape depends on the multidim_average parameter:

  • If multidim_average is set to global, the shape will be (5,)

  • If multidim_average is set to samplewise, the shape will be (N, 5)
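Since every entry is a count, the global result for a multidimensional input equals the samplewise rows summed over N. Using the samplewise output of the multidim example below, the default multidim_average='global' on the same inputs should therefore give the summed row:

```python
import torch

# Samplewise [tp, fp, tn, fn, sup] rows from the multidim example below
samplewise = torch.tensor([[2, 3, 0, 1, 3],
                           [0, 2, 1, 3, 3]])

# Counts are additive, so summing over the N axis gives the global counts
global_stats = samplewise.sum(dim=0)
print(global_stats)  # tensor([2, 5, 1, 4, 6])
```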

Example (preds is int tensor):
>>> import torch
>>> from torchmetrics.functional.classification import binary_stat_scores
>>> target = torch.tensor([0, 1, 0, 1, 0, 1])
>>> preds = torch.tensor([0, 0, 1, 1, 0, 1])
>>> binary_stat_scores(preds, target)
tensor([2, 1, 2, 1, 3])

Example (preds is float tensor):
>>> import torch
>>> from torchmetrics.functional.classification import binary_stat_scores
>>> target = torch.tensor([0, 1, 0, 1, 0, 1])
>>> preds = torch.tensor([0.11, 0.22, 0.84, 0.73, 0.33, 0.92])
>>> binary_stat_scores(preds, target)
tensor([2, 1, 2, 1, 3])

Example (multidim tensors):
>>> import torch
>>> from torchmetrics.functional.classification import binary_stat_scores
>>> target = torch.tensor([[[0, 1], [1, 0], [0, 1]], [[1, 1], [0, 0], [1, 0]]])
>>> preds = torch.tensor(
...     [
...         [[0.59, 0.91], [0.91, 0.99], [0.63, 0.04]],
...         [[0.38, 0.04], [0.86, 0.780], [0.45, 0.37]],
...     ]
... )
>>> binary_stat_scores(preds, target, multidim_average='samplewise')
tensor([[2, 3, 0, 1, 3],
        [0, 2, 1, 3, 3]])

multiclass_stat_scores

torchmetrics.functional.classification.multiclass_stat_scores(preds, target, num_classes, average='macro', top_k=1, multidim_average='global', ignore_index=None, validate_args=True)

Computes the number of true positives, false positives, true negatives, false negatives and the support for multiclass tasks. Related to Type I and Type II errors.

Accepts the following input tensors:

  • preds: (N, ...) (int tensor) or (N, C, ...) (float tensor). If preds is a floating-point tensor, torch.argmax is applied along the C dimension to automatically convert probabilities/logits into an int tensor.

  • target (int tensor): (N, ...)

Parameters
  • preds (Tensor) – Tensor with predictions

  • target (Tensor) – Tensor with true labels

  • num_classes (int) – Integer specifying the number of classes

  • average (Optional[Literal[‘micro’, ‘macro’, ‘weighted’, ‘none’]]) –

    Defines the reduction that is applied over labels. Should be one of the following:

    • micro: Sum statistics over all labels

    • macro: Calculate statistics for each label and average them

    • weighted: Calculates statistics for each label and computes weighted average using their support

    • "none" or None: Calculates statistic for each label and applies no reduction

  • top_k (int) – Number of highest probability or logit score predictions considered to find the correct label. Only works when preds contain probabilities/logits.

  • multidim_average (Literal[‘global’, ‘samplewise’]) –

    Defines how the additional dimensions ... should be handled. Should be one of the following:

    • global: Additional dimensions are flattened along the batch dimension

    • samplewise: Statistic will be calculated independently for each sample on the N axis. The statistics in this case are calculated over the additional dimensions.

  • ignore_index (Optional[int]) – Specifies a target value that is ignored and does not contribute to the metric calculation

  • validate_args (bool) – bool indicating if input arguments and tensors should be validated for correctness. Set to False for faster computations.
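A sketch of how top_k > 1 could change the counts, assuming that a class is treated as predicted for a sample whenever it is among that sample's top_k scores, so one sample can contribute several false positives. The inputs are those of the float example below with top_k=2; this is an illustration, not the library implementation.

```python
import torch
import torch.nn.functional as F

target = torch.tensor([2, 1, 0, 0])
preds = torch.tensor([[0.16, 0.26, 0.58],
                      [0.22, 0.61, 0.17],
                      [0.71, 0.09, 0.20],
                      [0.05, 0.82, 0.13]])
num_classes, top_k = 3, 2

# Mark the top_k highest-scoring classes of each sample as predicted
topk_idx = preds.topk(top_k, dim=1).indices
p = torch.zeros_like(preds).scatter_(1, topk_idx, 1.0).bool()
t = F.one_hot(target, num_classes).bool()

# One-vs-rest counts per class, as in the top_k=1 case
tp = (p & t).sum(0)
fp = (p & ~t).sum(0)
tn = (~p & ~t).sum(0)
fn = (~p & t).sum(0)
per_class = torch.stack([tp, fp, tn, fn, tp + fn], dim=-1)
print(per_class)
```

Under this assumption the total number of false positives grows from 1 (top_k=1) to 5: each sample now predicts two classes and at most one of them can be correct.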

Return type

Tensor

Returns

The metric returns a tensor of shape (..., 5), where the last dimension corresponds to [tp, fp, tn, fn, sup] (sup stands for support and equals tp + fn). The shape depends on average and multidim_average parameters:

  • If multidim_average is set to global:

    • If average='micro'/'macro'/'weighted', the shape will be (5,)

    • If average=None/'none', the shape will be (C, 5)

  • If multidim_average is set to samplewise:

    • If average='micro'/'macro'/'weighted', the shape will be (N, 5)

    • If average=None/'none', the shape will be (N, C, 5)

Example (preds is int tensor):
>>> import torch
>>> from torchmetrics.functional.classification import multiclass_stat_scores
>>> target = torch.tensor([2, 1, 0, 0])
>>> preds = torch.tensor([2, 1, 0, 1])
>>> multiclass_stat_scores(preds, target, num_classes=3, average='micro')
tensor([3, 1, 7, 1, 4])
>>> multiclass_stat_scores(preds, target, num_classes=3, average=None)
tensor([[1, 0, 2, 1, 2],
        [1, 1, 2, 0, 1],
        [1, 0, 3, 0, 1]])

Example (preds is float tensor):
>>> import torch
>>> from torchmetrics.functional.classification import multiclass_stat_scores
>>> target = torch.tensor([2, 1, 0, 0])
>>> preds = torch.tensor([
...   [0.16, 0.26, 0.58],
...   [0.22, 0.61, 0.17],
...   [0.71, 0.09, 0.20],
...   [0.05, 0.82, 0.13],
... ])
>>> multiclass_stat_scores(preds, target, num_classes=3, average='micro')
tensor([3, 1, 7, 1, 4])
>>> multiclass_stat_scores(preds, target, num_classes=3, average=None)
tensor([[1, 0, 2, 1, 2],
        [1, 1, 2, 0, 1],
        [1, 0, 3, 0, 1]])

Example (multidim tensors):
>>> import torch
>>> from torchmetrics.functional.classification import multiclass_stat_scores
>>> target = torch.tensor([[[0, 1], [2, 1], [0, 2]], [[1, 1], [2, 0], [1, 2]]])
>>> preds = torch.tensor([[[0, 2], [2, 0], [0, 1]], [[2, 2], [2, 1], [1, 0]]])
>>> multiclass_stat_scores(preds, target, num_classes=3, multidim_average='samplewise', average='micro')
tensor([[3, 3, 9, 3, 6],
        [2, 4, 8, 4, 6]])
>>> multiclass_stat_scores(preds, target, num_classes=3, multidim_average='samplewise', average=None)
tensor([[[2, 1, 3, 0, 2],
         [0, 1, 3, 2, 2],
         [1, 1, 3, 1, 2]],
        [[0, 1, 4, 1, 1],
         [1, 1, 2, 2, 3],
         [1, 2, 2, 1, 2]]])

multilabel_stat_scores

torchmetrics.functional.classification.multilabel_stat_scores(preds, target, num_labels, threshold=0.5, average='macro', multidim_average='global', ignore_index=None, validate_args=True)

Computes the number of true positives, false positives, true negatives, false negatives and the support for multilabel tasks. Related to Type I and Type II errors.

Accepts the following input tensors:

  • preds (int or float tensor): (N, C, ...). If preds is a floating-point tensor with values outside the [0, 1] range, the input is treated as logits and a sigmoid is applied element-wise. Additionally, preds is converted to an int tensor by thresholding with the value of threshold.

  • target (int tensor): (N, C, ...)
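The logits rule can be sketched as follows. The logit values are made up for illustration, and the sketch assumes, as the description suggests, that the check is made on the tensor as a whole, so a single value outside [0, 1] causes the sigmoid to be applied everywhere. With the default threshold of 0.5, thresholding sigmoid(logits) is equivalent to testing logits > 0.

```python
import torch

# Made-up logits for 2 samples x 3 labels; some values fall outside [0, 1]
logits = torch.tensor([[-2.0, 0.5, 3.0],
                       [1.5, -0.3, 0.2]])
threshold = 0.5

# Any value outside [0, 1] -> treat the whole tensor as logits
is_logits = bool(((logits < 0) | (logits > 1)).any())
probs = logits.sigmoid() if is_logits else logits
hard = (probs > threshold).long()

# sigmoid is strictly increasing and sigmoid(0) == 0.5, so for threshold=0.5
# this matches a direct sign test on the logits
same_as_sign_test = torch.equal(hard, (logits > 0).long())
print(hard.tolist(), same_as_sign_test)  # [[0, 1, 1], [1, 0, 1]] True
```

Note that the in-range values 0.5 and 0.2 are still passed through the sigmoid, because the tensor as a whole is classified as logits.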

Parameters
  • preds (Tensor) – Tensor with predictions

  • target (Tensor) – Tensor with true labels

  • num_labels (int) – Integer specifying the number of labels

  • threshold (float) – Threshold for transforming probability to binary {0,1} predictions

  • average (Optional[Literal[‘micro’, ‘macro’, ‘weighted’, ‘none’]]) –

    Defines the reduction that is applied over labels. Should be one of the following:

    • micro: Sum statistics over all labels

    • macro: Calculate statistics for each label and average them

    • weighted: Calculates statistics for each label and computes weighted average using their support

    • "none" or None: Calculates statistic for each label and applies no reduction

  • multidim_average (Literal[‘global’, ‘samplewise’]) –

    Defines how the additional dimensions ... should be handled. Should be one of the following:

    • global: Additional dimensions are flattened along the batch dimension

    • samplewise: Statistic will be calculated independently for each sample on the N axis. The statistics in this case are calculated over the additional dimensions.

  • ignore_index (Optional[int]) – Specifies a target value that is ignored and does not contribute to the metric calculation

  • validate_args (bool) – bool indicating if input arguments and tensors should be validated for correctness. Set to False for faster computations.

Return type

Tensor

Returns

The metric returns a tensor of shape (..., 5), where the last dimension corresponds to [tp, fp, tn, fn, sup] (sup stands for support and equals tp + fn). The shape depends on average and multidim_average parameters:

  • If multidim_average is set to global:

    • If average='micro'/'macro'/'weighted', the shape will be (5,)

    • If average=None/'none', the shape will be (C, 5)

  • If multidim_average is set to samplewise:

    • If average='micro'/'macro'/'weighted', the shape will be (N, 5)

    • If average=None/'none', the shape will be (N, C, 5)

Example (preds is int tensor):
>>> import torch
>>> from torchmetrics.functional.classification import multilabel_stat_scores
>>> target = torch.tensor([[0, 1, 0], [1, 0, 1]])
>>> preds = torch.tensor([[0, 0, 1], [1, 0, 1]])
>>> multilabel_stat_scores(preds, target, num_labels=3, average='micro')
tensor([2, 1, 2, 1, 3])
>>> multilabel_stat_scores(preds, target, num_labels=3, average=None)
tensor([[1, 0, 1, 0, 1],
        [0, 0, 1, 1, 1],
        [1, 1, 0, 0, 1]])

Example (preds is float tensor):
>>> import torch
>>> from torchmetrics.functional.classification import multilabel_stat_scores
>>> target = torch.tensor([[0, 1, 0], [1, 0, 1]])
>>> preds = torch.tensor([[0.11, 0.22, 0.84], [0.73, 0.33, 0.92]])
>>> multilabel_stat_scores(preds, target, num_labels=3, average='micro')
tensor([2, 1, 2, 1, 3])
>>> multilabel_stat_scores(preds, target, num_labels=3, average=None)
tensor([[1, 0, 1, 0, 1],
        [0, 0, 1, 1, 1],
        [1, 1, 0, 0, 1]])

Example (multidim tensors):
>>> import torch
>>> from torchmetrics.functional.classification import multilabel_stat_scores
>>> target = torch.tensor([[[0, 1], [1, 0], [0, 1]], [[1, 1], [0, 0], [1, 0]]])
>>> preds = torch.tensor(
...     [
...         [[0.59, 0.91], [0.91, 0.99], [0.63, 0.04]],
...         [[0.38, 0.04], [0.86, 0.780], [0.45, 0.37]],
...     ]
... )
>>> multilabel_stat_scores(preds, target, num_labels=3, multidim_average='samplewise', average='micro')
tensor([[2, 3, 0, 1, 3],
        [0, 2, 1, 3, 3]])
>>> multilabel_stat_scores(preds, target, num_labels=3, multidim_average='samplewise', average=None)
tensor([[[1, 1, 0, 0, 1],
         [1, 1, 0, 0, 1],
         [0, 1, 0, 1, 1]],
        [[0, 0, 0, 2, 2],
         [0, 2, 0, 0, 0],
         [0, 0, 1, 1, 1]]])