
Stat Scores

Module Interface

StatScores

class torchmetrics.StatScores(threshold=0.5, top_k=None, reduce='micro', num_classes=None, ignore_index=None, mdmc_reduce=None, multiclass=None, task=None, average='macro', num_labels=None, multidim_average='global', validate_args=True, **kwargs)

StatScores.

Note

From v0.10, 'binary_*', 'multiclass_*' and 'multilabel_*' versions of each classification metric exist. Moving forward, we recommend using these versions. This base metric will continue to work as it did prior to v0.10 until v0.11. From v0.11, the task argument introduced in this metric will be required and the general order of arguments may change, so that this metric functions only as a single entrypoint to the three specialized versions.

Computes the number of true positives, false positives, true negatives and false negatives. Related to Type I and Type II errors and the confusion matrix.

The reduction method (how the statistics are aggregated) is controlled by the reduce parameter, and additionally by the mdmc_reduce parameter in the multi-dimensional multi-class case.
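As a rough illustration of reduce='macro' and reduce='micro' for plain multi-class labels, the sketch below counts the statistics one class at a time (one-vs-rest) and then sums them over classes. It is pure Python, not the torchmetrics implementation, and the helper names `per_class_stat_scores` and `micro_stat_scores` are hypothetical. It reproduces the example at the end of this class description.

```python
from typing import List, Sequence


def per_class_stat_scores(
    preds: Sequence[int], target: Sequence[int], num_classes: int
) -> List[List[int]]:
    """Return one [tp, fp, tn, fn, sup] row per class, treating each class one-vs-rest."""
    rows = []
    for c in range(num_classes):
        tp = sum(p == c and t == c for p, t in zip(preds, target))
        fp = sum(p == c and t != c for p, t in zip(preds, target))
        tn = sum(p != c and t != c for p, t in zip(preds, target))
        fn = sum(p != c and t == c for p, t in zip(preds, target))
        rows.append([tp, fp, tn, fn, tp + fn])  # support = tp + fn
    return rows


def micro_stat_scores(rows: Sequence[Sequence[int]]) -> List[int]:
    """Sum each statistic over all classes (the 'micro' reduction)."""
    return [sum(col) for col in zip(*rows)]


# Same inputs as the StatScores example in this section.
preds = [1, 0, 2, 1]
target = [1, 1, 2, 0]
macro = per_class_stat_scores(preds, target, num_classes=3)
print(macro)                     # [[0, 1, 2, 1, 1], [1, 1, 1, 1, 2], [1, 0, 3, 0, 1]]
print(micro_stat_scores(macro))  # [2, 2, 6, 2, 4]
```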

Accepts all inputs listed in Input types.

Parameters
  • threshold (float) – Threshold for transforming probability or logit predictions to binary (0,1) predictions, in the case of binary or multi-label inputs. Default value of 0.5 corresponds to input being probabilities.

  • top_k (Optional[int]) – Number of highest probability or logit score predictions considered to find the correct label, relevant only for (multi-dimensional) multi-class inputs. The default value (None) will be interpreted as 1 for these inputs. Should be left at default (None) for all other types of inputs.

  • reduce (str) –

    Defines the reduction that is applied. Should be one of the following:

    • 'micro' [default]: Counts the statistics by summing over all [sample, class] combinations (globally). Each statistic is represented by a single integer.

    • 'macro': Counts the statistics for each class separately (over all samples). Each statistic is represented by a (C,) tensor. Requires num_classes to be set.

    • 'samples': Counts the statistics for each sample separately (over all classes). Each statistic is represented by a (N, ) 1d tensor.

    Note

    What is considered a sample in the multi-dimensional multi-class case depends on the value of mdmc_reduce.

  • num_classes (Optional[int]) – Number of classes. Necessary for (multi-dimensional) multi-class or multi-label data.

  • ignore_index (Optional[int]) – Specify a class (label) to ignore. If given, this class index does not contribute to the returned score, regardless of reduction method. If an index is ignored, and reduce='macro', the class statistics for the ignored class will all be returned as -1.

  • mdmc_reduce (Optional[str]) –

    Defines how the multi-dimensional multi-class inputs are handled. Should be one of the following:

    • None [default]: Should be left unchanged if your data is not multi-dimensional multi-class (see Input types for the definition of input types).

    • 'samplewise': In this case, the statistics are computed separately for each sample on the N axis, and then the outputs are concatenated together. In each sample the extra axes ... are flattened to become the sub-sample axis, and statistics for each sample are computed by treating the sub-sample axis as the N axis for that sample.

    • 'global': In this case the N and ... dimensions of the inputs are flattened into a new N_X sample axis, i.e. the inputs are treated as if they were (N_X, C). From here on the reduce parameter applies as usual.

  • multiclass (Optional[bool]) – Used only in certain special cases, where you want to treat inputs as a different type than what they appear to be. See the parameter’s documentation section for a more detailed explanation and examples.

  • kwargs (Any) – Additional keyword arguments, see Advanced metric settings for more info.

Raises
  • ValueError – If reduce is none of "micro", "macro" or "samples".

  • ValueError – If mdmc_reduce is none of None, "samplewise", "global".

  • ValueError – If reduce is set to "macro" and num_classes is not provided.

  • ValueError – If num_classes is set and ignore_index is not in the range 0 <= ignore_index < num_classes.

Example

>>> import torch
>>> from torchmetrics.classification import StatScores
>>> preds = torch.tensor([1, 0, 2, 1])
>>> target = torch.tensor([1, 1, 2, 0])
>>> stat_scores = StatScores(reduce='macro', num_classes=3)
>>> stat_scores(preds, target)
tensor([[0, 1, 2, 1, 1],
        [1, 1, 1, 1, 2],
        [1, 0, 3, 0, 1]])
>>> stat_scores = StatScores(reduce='micro')
>>> stat_scores(preds, target)
tensor([2, 2, 6, 2, 4])


compute()

Computes the stat scores based on the inputs previously passed to update.

Return type

Tensor

Returns

The metric returns a tensor of shape (..., 5), where the last dimension corresponds to [tp, fp, tn, fn, sup] (sup stands for support and equals tp + fn). The shape depends on the reduce and mdmc_reduce (in case of multi-dimensional multi-class data) parameters:

  • If the data is not multi-dimensional multi-class, then

    • If reduce='micro', the shape will be (5, )

    • If reduce='macro', the shape will be (C, 5), where C stands for the number of classes

    • If reduce='samples', the shape will be (N, 5), where N stands for the number of samples

  • If the data is multi-dimensional multi-class and mdmc_reduce='global', then

    • If reduce='micro', the shape will be (5, )

    • If reduce='macro', the shape will be (C, 5)

    • If reduce='samples', the shape will be (N*X, 5), where X stands for the product of sizes of all “extra” dimensions of the data (i.e. all dimensions except for C and N)

  • If the data is multi-dimensional multi-class and mdmc_reduce='samplewise', then

    • If reduce='micro', the shape will be (N, 5)

    • If reduce='macro', the shape will be (N, C, 5)

    • If reduce='samples', the shape will be (N, X, 5)
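The shape rules above can be condensed into a small lookup. `legacy_stat_scores_shape` is a hypothetical helper that only encodes the list as documented (N samples, C classes, X the product of the extra dimension sizes); it does not call torchmetrics.

```python
from typing import Optional, Tuple


def legacy_stat_scores_shape(
    reduce: str, mdmc_reduce: Optional[str], n: int, c: int, x: int = 1
) -> Tuple[int, ...]:
    """Documented output shape of StatScores / stat_scores for the given settings."""
    if reduce not in ("micro", "macro", "samples"):
        raise ValueError(f"unknown reduce: {reduce!r}")
    if mdmc_reduce in (None, "global"):
        # No extra dims, or extra dims flattened together with N into an N*X sample axis.
        samples = n if mdmc_reduce is None else n * x
        return {"micro": (5,), "macro": (c, 5), "samples": (samples, 5)}[reduce]
    if mdmc_reduce == "samplewise":
        # One result per sample on the N axis.
        return {"micro": (n, 5), "macro": (n, c, 5), "samples": (n, x, 5)}[reduce]
    raise ValueError(f"unknown mdmc_reduce: {mdmc_reduce!r}")


print(legacy_stat_scores_shape("samples", "global", n=4, c=3, x=6))    # (24, 5)
print(legacy_stat_scores_shape("macro", "samplewise", n=4, c=3, x=6))  # (4, 3, 5)
```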

update(preds, target)

Update state with predictions and targets.

See Input types for more information on input types.

Parameters
  • preds (Tensor) – Predictions from model (probabilities, logits or labels)

  • target (Tensor) – Ground truth values

Return type

None

BinaryStatScores

class torchmetrics.classification.BinaryStatScores(threshold=0.5, multidim_average='global', ignore_index=None, validate_args=True, **kwargs)

Computes the number of true positives, false positives, true negatives, false negatives and the support for binary tasks. Related to Type I and Type II errors.

Accepts the following input tensors:

  • preds (int or float tensor): (N, ...). If preds is a floating point tensor with values outside the [0,1] range, we consider the input to be logits and automatically apply sigmoid per element. Additionally, we convert it to an int tensor by thresholding with the value in threshold.

  • target (int tensor): (N, ...)

The influence of the additional dimension ... (if present) will be determined by the multidim_average argument.
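The preprocessing of float predictions described above can be sketched in pure Python. The helpers `binarize` and `binary_counts` are hypothetical, and the strict `>` comparison against threshold is an assumption. With the float example below, the result matches the documented [2, 1, 2, 1, 3]; logits of the same probabilities give the same labels because the sigmoid undoes the log-odds transform.

```python
import math
from typing import List, Sequence


def binarize(preds: Sequence[float], threshold: float = 0.5) -> List[int]:
    """Turn float scores into {0, 1} labels."""
    # Any value outside [0, 1] means the input is treated as logits: apply a sigmoid first.
    if any(p < 0 or p > 1 for p in preds):
        preds = [1 / (1 + math.exp(-p)) for p in preds]
    # Strict comparison is an assumption of this sketch.
    return [int(p > threshold) for p in preds]


def binary_counts(preds: Sequence[int], target: Sequence[int]) -> List[int]:
    """Return [tp, fp, tn, fn, sup] for 0/1 predictions and targets."""
    pairs = list(zip(preds, target))
    tp, fp = pairs.count((1, 1)), pairs.count((1, 0))
    tn, fn = pairs.count((0, 0)), pairs.count((0, 1))
    return [tp, fp, tn, fn, tp + fn]


target = [0, 1, 0, 1, 0, 1]
probs = [0.11, 0.22, 0.84, 0.73, 0.33, 0.92]
print(binary_counts(binarize(probs), target))  # [2, 1, 2, 1, 3]

# The same scores expressed as logits yield the same labels after the sigmoid.
logits = [math.log(p / (1 - p)) for p in probs]
print(binarize(logits))  # [0, 0, 1, 1, 0, 1]
```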

Parameters
  • threshold (float) – Threshold for transforming probability to binary {0,1} predictions

  • multidim_average (Literal[‘global’, ‘samplewise’]) –

    Defines how additional dimensions ... should be handled. Should be one of the following:

    • global: Additional dimensions are flattened into the batch dimension

    • samplewise: Statistics are calculated independently for each sample on the N axis, over the additional dimensions of that sample.

  • ignore_index (Optional[int]) – Specifies a target value that is ignored and does not contribute to the metric calculation

  • validate_args (bool) – bool indicating if input arguments and tensors should be validated for correctness. Set to False for faster computations.

  • kwargs (Any) – Additional keyword arguments, see Advanced metric settings for more info.
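The effect of multidim_average can be sketched for binary inputs: samplewise keeps one row per sample, while global flattens everything, which amounts to adding up the per-sample rows. The helpers `counts`, `flatten` and `multidim_counts` are hypothetical, and the predictions below are the multidim example's probabilities already thresholded at 0.5.

```python
from typing import List, Sequence


def counts(preds: Sequence[int], target: Sequence[int]) -> List[int]:
    """[tp, fp, tn, fn, sup] for flat 0/1 sequences."""
    pairs = list(zip(preds, target))
    tp, fp = pairs.count((1, 1)), pairs.count((1, 0))
    tn, fn = pairs.count((0, 0)), pairs.count((0, 1))
    return [tp, fp, tn, fn, tp + fn]


def flatten(nested) -> List[int]:
    """Flatten the extra dimensions ... of one sample into a flat list."""
    if isinstance(nested, (list, tuple)):
        return [v for item in nested for v in flatten(item)]
    return [nested]


def multidim_counts(preds, target, multidim_average: str = "global"):
    """preds/target are nested 0/1 lists of shape (N, ...)."""
    per_sample = [counts(flatten(p), flatten(t)) for p, t in zip(preds, target)]
    if multidim_average == "samplewise":
        return per_sample  # shape (N, 5)
    if multidim_average == "global":
        # Flattening all samples into one axis just adds up the per-sample counts.
        return [sum(col) for col in zip(*per_sample)]  # shape (5,)
    raise ValueError(f"unknown multidim_average: {multidim_average!r}")


target = [[[0, 1], [1, 0], [0, 1]], [[1, 1], [0, 0], [1, 0]]]
preds = [[[1, 1], [1, 1], [1, 0]], [[0, 0], [1, 1], [0, 0]]]
print(multidim_counts(preds, target, "samplewise"))  # [[2, 3, 0, 1, 3], [0, 2, 1, 3, 3]]
print(multidim_counts(preds, target, "global"))      # [2, 5, 1, 4, 6]
```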

Example (preds is int tensor):
>>> import torch
>>> from torchmetrics.classification import BinaryStatScores
>>> target = torch.tensor([0, 1, 0, 1, 0, 1])
>>> preds = torch.tensor([0, 0, 1, 1, 0, 1])
>>> metric = BinaryStatScores()
>>> metric(preds, target)
tensor([2, 1, 2, 1, 3])
Example (preds is float tensor):
>>> import torch
>>> from torchmetrics.classification import BinaryStatScores
>>> target = torch.tensor([0, 1, 0, 1, 0, 1])
>>> preds = torch.tensor([0.11, 0.22, 0.84, 0.73, 0.33, 0.92])
>>> metric = BinaryStatScores()
>>> metric(preds, target)
tensor([2, 1, 2, 1, 3])
Example (multidim tensors):
>>> import torch
>>> from torchmetrics.classification import BinaryStatScores
>>> target = torch.tensor([[[0, 1], [1, 0], [0, 1]], [[1, 1], [0, 0], [1, 0]]])
>>> preds = torch.tensor(
...     [
...         [[0.59, 0.91], [0.91, 0.99], [0.63, 0.04]],
...         [[0.38, 0.04], [0.86, 0.780], [0.45, 0.37]],
...     ]
... )
>>> metric = BinaryStatScores(multidim_average='samplewise')
>>> metric(preds, target)
tensor([[2, 3, 0, 1, 3],
        [0, 2, 1, 3, 3]])


MulticlassStatScores

class torchmetrics.classification.MulticlassStatScores(num_classes, top_k=1, average='macro', multidim_average='global', ignore_index=None, validate_args=True, **kwargs)

Computes the number of true positives, false positives, true negatives, false negatives and the support for multiclass tasks. Related to Type I and Type II errors.

Accepts the following input tensors:

  • preds: (N, ...) (int tensor) or (N, C, ...) (float tensor). If preds is a floating point tensor, we apply torch.argmax along the C dimension to automatically convert probabilities/logits into an int tensor.

  • target (int tensor): (N, ...)

The influence of the additional dimension ... (if present) will be determined by the multidim_average argument.
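For (N, C) float predictions, the conversion above amounts to an argmax over the C axis. The hypothetical helper `argmax_labels` shows that the float example below reduces to the int predictions [2, 1, 0, 1] used in the int example; ties resolve to the first index here, which may not match torch.argmax.

```python
from typing import List, Sequence


def argmax_labels(scores: Sequence[Sequence[float]]) -> List[int]:
    """Index of the largest score along the class axis C, for each sample."""
    # max() returns the first maximal index on ties.
    return [max(range(len(row)), key=row.__getitem__) for row in scores]


probs = [
    [0.16, 0.26, 0.58],
    [0.22, 0.61, 0.17],
    [0.71, 0.09, 0.20],
    [0.05, 0.82, 0.13],
]
print(argmax_labels(probs))  # [2, 1, 0, 1]
```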

Parameters
  • num_classes (int) – Integer specifying the number of classes

  • average (Optional[Literal[‘micro’, ‘macro’, ‘weighted’, ‘none’]]) –

    Defines the reduction that is applied over labels. Should be one of the following:

    • micro: Sum statistics over all labels

    • macro: Calculate statistics for each label and average them

    • weighted: Calculates statistics for each label and computes weighted average using their support

    • "none" or None: Calculates statistic for each label and applies no reduction

  • top_k (int) – Number of highest probability or logit score predictions considered to find the correct label. Only works when preds contain probabilities/logits.

  • multidim_average (Literal[‘global’, ‘samplewise’]) –

    Defines how additional dimensions ... should be handled. Should be one of the following:

    • global: Additional dimensions are flattened into the batch dimension

    • samplewise: Statistics are calculated independently for each sample on the N axis, over the additional dimensions of that sample.

  • ignore_index (Optional[int]) – Specifies a target value that is ignored and does not contribute to the metric calculation

  • validate_args (bool) – bool indicating if input arguments and tensors should be validated for correctness. Set to False for faster computations.

  • kwargs (Any) – Additional keyword arguments, see Advanced metric settings for more info.
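To make the average options concrete, the sketch below applies them to per-class rows, reading the bullet points literally: micro sums, macro takes the unweighted mean, weighted averages using the support as weights. The helper `reduce_rows` is hypothetical. The micro result matches the average='micro' example below, but the weighted output is only illustrative; torchmetrics' exact dtype and its handling of classes with zero support may differ.

```python
from typing import List, Optional, Sequence, Union


def reduce_rows(
    rows: Sequence[Sequence[int]], average: Optional[str]
) -> Union[List[float], List[List[int]]]:
    """Reduce per-class [tp, fp, tn, fn, sup] rows according to `average`."""
    if average is None or average == "none":
        return [list(r) for r in rows]  # no reduction: one row per class, shape (C, 5)
    columns = list(zip(*rows))  # one tuple per statistic: tp, fp, tn, fn, sup
    if average == "micro":
        return [sum(col) for col in columns]  # plain sum over classes
    if average == "macro":
        return [sum(col) / len(rows) for col in columns]  # unweighted mean over classes
    if average == "weighted":
        support = [r[4] for r in rows]
        total = sum(support)
        if total == 0:
            return [0.0] * 5  # nothing to weight by (assumed behaviour for this sketch)
        return [sum(v * s for v, s in zip(col, support)) / total for col in columns]
    raise ValueError(f"unknown average: {average!r}")


# Per-class rows of the MulticlassStatScores example (its average=None output).
rows = [[1, 0, 2, 1, 2], [1, 1, 2, 0, 1], [1, 0, 3, 0, 1]]
print(reduce_rows(rows, "micro"))     # [3, 1, 7, 1, 4]
print(reduce_rows(rows, "weighted"))  # [1.0, 0.25, 2.25, 0.5, 1.5]
```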

Example (preds is int tensor):
>>> import torch
>>> from torchmetrics.classification import MulticlassStatScores
>>> target = torch.tensor([2, 1, 0, 0])
>>> preds = torch.tensor([2, 1, 0, 1])
>>> metric = MulticlassStatScores(num_classes=3, average='micro')
>>> metric(preds, target)
tensor([3, 1, 7, 1, 4])
>>> metric = MulticlassStatScores(num_classes=3, average=None)
>>> metric(preds, target)
tensor([[1, 0, 2, 1, 2],
        [1, 1, 2, 0, 1],
        [1, 0, 3, 0, 1]])
Example (preds is float tensor):
>>> import torch
>>> from torchmetrics.classification import MulticlassStatScores
>>> target = torch.tensor([2, 1, 0, 0])
>>> preds = torch.tensor([
...   [0.16, 0.26, 0.58],
...   [0.22, 0.61, 0.17],
...   [0.71, 0.09, 0.20],
...   [0.05, 0.82, 0.13],
... ])
>>> metric = MulticlassStatScores(num_classes=3, average='micro')
>>> metric(preds, target)
tensor([3, 1, 7, 1, 4])
>>> metric = MulticlassStatScores(num_classes=3, average=None)
>>> metric(preds, target)
tensor([[1, 0, 2, 1, 2],
        [1, 1, 2, 0, 1],
        [1, 0, 3, 0, 1]])
Example (multidim tensors):
>>> import torch
>>> from torchmetrics.classification import MulticlassStatScores
>>> target = torch.tensor([[[0, 1], [2, 1], [0, 2]], [[1, 1], [2, 0], [1, 2]]])
>>> preds = torch.tensor([[[0, 2], [2, 0], [0, 1]], [[2, 2], [2, 1], [1, 0]]])
>>> metric = MulticlassStatScores(num_classes=3, multidim_average="samplewise", average='micro')
>>> metric(preds, target)
tensor([[3, 3, 9, 3, 6],
        [2, 4, 8, 4, 6]])
>>> metric = MulticlassStatScores(num_classes=3, multidim_average="samplewise", average=None)
>>> metric(preds, target)
tensor([[[2, 1, 3, 0, 2],
         [0, 1, 3, 2, 2],
         [1, 1, 3, 1, 2]],
        [[0, 1, 4, 1, 1],
         [1, 1, 2, 2, 3],
         [1, 2, 2, 1, 2]]])


MultilabelStatScores

class torchmetrics.classification.MultilabelStatScores(num_labels, threshold=0.5, average='macro', multidim_average='global', ignore_index=None, validate_args=True, **kwargs)

Computes the number of true positives, false positives, true negatives, false negatives and the support for multilabel tasks. Related to Type I and Type II errors.

Accepts the following input tensors:

  • preds (int or float tensor): (N, C, ...). If preds is a floating point tensor with values outside the [0,1] range, we consider the input to be logits and automatically apply sigmoid per element. Additionally, we convert it to an int tensor by thresholding with the value in threshold.

  • target (int tensor): (N, C, ...)

The influence of the additional dimension ... (if present) will be determined by the multidim_average argument.
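Since every label column of a multilabel input is its own binary problem, the average=None output can be sketched by counting each column separately. `multilabel_counts` is a hypothetical pure-Python helper assuming (N, C) inputs that are already 0/1; it reproduces the int example below, and summing its rows gives the average='micro' result.

```python
from typing import List, Sequence


def multilabel_counts(
    preds: Sequence[Sequence[int]], target: Sequence[Sequence[int]]
) -> List[List[int]]:
    """Per-label [tp, fp, tn, fn, sup]; each label column is treated as a binary task."""
    num_labels = len(target[0])
    rows = []
    for c in range(num_labels):
        pairs = [(p[c], t[c]) for p, t in zip(preds, target)]
        tp, fp = pairs.count((1, 1)), pairs.count((1, 0))
        tn, fn = pairs.count((0, 0)), pairs.count((0, 1))
        rows.append([tp, fp, tn, fn, tp + fn])
    return rows


target = [[0, 1, 0], [1, 0, 1]]
preds = [[0, 0, 1], [1, 0, 1]]
rows = multilabel_counts(preds, target)
print(rows)                              # [[1, 0, 1, 0, 1], [0, 0, 1, 1, 1], [1, 1, 0, 0, 1]]
print([sum(col) for col in zip(*rows)])  # [2, 1, 2, 1, 3]
```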

Parameters
  • num_labels (int) – Integer specifying the number of labels

  • threshold (float) – Threshold for transforming probability to binary (0,1) predictions

  • average (Optional[Literal[‘micro’, ‘macro’, ‘weighted’, ‘none’]]) –

    Defines the reduction that is applied over labels. Should be one of the following:

    • micro: Sum statistics over all labels

    • macro: Calculate statistics for each label and average them

    • weighted: Calculates statistics for each label and computes weighted average using their support

    • "none" or None: Calculates statistic for each label and applies no reduction

  • multidim_average (Literal[‘global’, ‘samplewise’]) –

    Defines how additional dimensions ... should be handled. Should be one of the following:

    • global: Additional dimensions are flattened into the batch dimension

    • samplewise: Statistics are calculated independently for each sample on the N axis, over the additional dimensions of that sample.

  • ignore_index (Optional[int]) – Specifies a target value that is ignored and does not contribute to the metric calculation

  • validate_args (bool) – bool indicating if input arguments and tensors should be validated for correctness. Set to False for faster computations.

  • kwargs (Any) – Additional keyword arguments, see Advanced metric settings for more info.

Example (preds is int tensor):
>>> import torch
>>> from torchmetrics.classification import MultilabelStatScores
>>> target = torch.tensor([[0, 1, 0], [1, 0, 1]])
>>> preds = torch.tensor([[0, 0, 1], [1, 0, 1]])
>>> metric = MultilabelStatScores(num_labels=3, average='micro')
>>> metric(preds, target)
tensor([2, 1, 2, 1, 3])
>>> metric = MultilabelStatScores(num_labels=3, average=None)
>>> metric(preds, target)
tensor([[1, 0, 1, 0, 1],
        [0, 0, 1, 1, 1],
        [1, 1, 0, 0, 1]])
Example (preds is float tensor):
>>> import torch
>>> from torchmetrics.classification import MultilabelStatScores
>>> target = torch.tensor([[0, 1, 0], [1, 0, 1]])
>>> preds = torch.tensor([[0.11, 0.22, 0.84], [0.73, 0.33, 0.92]])
>>> metric = MultilabelStatScores(num_labels=3, average='micro')
>>> metric(preds, target)
tensor([2, 1, 2, 1, 3])
>>> metric = MultilabelStatScores(num_labels=3, average=None)
>>> metric(preds, target)
tensor([[1, 0, 1, 0, 1],
        [0, 0, 1, 1, 1],
        [1, 1, 0, 0, 1]])
Example (multidim tensors):
>>> import torch
>>> from torchmetrics.classification import MultilabelStatScores
>>> target = torch.tensor([[[0, 1], [1, 0], [0, 1]], [[1, 1], [0, 0], [1, 0]]])
>>> preds = torch.tensor(
...     [
...         [[0.59, 0.91], [0.91, 0.99], [0.63, 0.04]],
...         [[0.38, 0.04], [0.86, 0.780], [0.45, 0.37]],
...     ]
... )
>>> metric = MultilabelStatScores(num_labels=3, multidim_average='samplewise', average='micro')
>>> metric(preds, target)
tensor([[2, 3, 0, 1, 3],
        [0, 2, 1, 3, 3]])
>>> metric = MultilabelStatScores(num_labels=3, multidim_average='samplewise', average=None)
>>> metric(preds, target)
tensor([[[1, 1, 0, 0, 1],
         [1, 1, 0, 0, 1],
         [0, 1, 0, 1, 1]],
        [[0, 0, 0, 2, 2],
         [0, 2, 0, 0, 0],
         [0, 0, 1, 1, 1]]])


Functional Interface

stat_scores

torchmetrics.functional.stat_scores(preds, target, reduce='micro', mdmc_reduce=None, num_classes=None, top_k=None, threshold=0.5, multiclass=None, ignore_index=None, task=None, num_labels=None, average='micro', multidim_average='global', validate_args=True)

Stat scores.

Note

From v0.10, 'binary_*', 'multiclass_*' and 'multilabel_*' versions of each classification metric exist. Moving forward, we recommend using these versions. This base metric will continue to work as it did prior to v0.10 until v0.11. From v0.11, the task argument introduced in this metric will be required and the general order of arguments may change, so that this metric functions only as a single entrypoint to the three specialized versions.

Computes the number of true positives, false positives, true negatives and false negatives. Related to Type I and Type II errors and the confusion matrix.

The reduction method (how the statistics are aggregated) is controlled by the reduce parameter, and additionally by the mdmc_reduce parameter in the multi-dimensional multi-class case. Accepts all inputs listed in Input types.
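A sketch of reduce='samples' for plain multi-class labels, assuming each label is one-hot encoded over the C classes and the statistics are then counted across the classes of each sample. The helpers `one_hot` and `samples_stat_scores` are hypothetical, and this is a reading of the description here, not torchmetrics' code. Summing the rows recovers the reduce='micro' result of the example at the end of this function's description.

```python
from typing import List, Sequence


def one_hot(labels: Sequence[int], num_classes: int) -> List[List[int]]:
    """Encode each integer label as a 0/1 row of length num_classes."""
    return [[int(label == c) for c in range(num_classes)] for label in labels]


def samples_stat_scores(
    preds: Sequence[int], target: Sequence[int], num_classes: int
) -> List[List[int]]:
    """reduce='samples': one [tp, fp, tn, fn, sup] row per sample, counted over its classes."""
    rows = []
    for p_row, t_row in zip(one_hot(preds, num_classes), one_hot(target, num_classes)):
        pairs = list(zip(p_row, t_row))
        tp, fp = pairs.count((1, 1)), pairs.count((1, 0))
        tn, fn = pairs.count((0, 0)), pairs.count((0, 1))
        rows.append([tp, fp, tn, fn, tp + fn])
    return rows


rows = samples_stat_scores([1, 0, 2, 1], [1, 1, 2, 0], num_classes=3)
print(rows)                              # [[1, 0, 2, 0, 1], [0, 1, 1, 1, 1], [1, 0, 2, 0, 1], [0, 1, 1, 1, 1]]
print([sum(col) for col in zip(*rows)])  # [2, 2, 6, 2, 4]
```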

Parameters
  • preds (Tensor) – Predictions from model (probabilities, logits or labels)

  • target (Tensor) – Ground truth values

  • threshold (float) – Threshold for transforming probability or logit predictions to binary (0,1) predictions, in the case of binary or multi-label inputs. Default value of 0.5 corresponds to input being probabilities.

  • top_k (Optional[int]) –

    Number of highest probability or logit score predictions considered to find the correct label, relevant only for (multi-dimensional) multi-class inputs. The default value (None) will be interpreted as 1 for these inputs.

    Should be left at default (None) for all other types of inputs.

  • reduce (str) –

    Defines the reduction that is applied. Should be one of the following:

    • 'micro' [default]: Counts the statistics by summing over all [sample, class] combinations (globally). Each statistic is represented by a single integer.

    • 'macro': Counts the statistics for each class separately (over all samples). Each statistic is represented by a (C,) tensor. Requires num_classes to be set.

    • 'samples': Counts the statistics for each sample separately (over all classes). Each statistic is represented by a (N, ) 1d tensor.

    Note

    What is considered a sample in the multi-dimensional multi-class case depends on the value of mdmc_reduce.

  • num_classes (Optional[int]) – Number of classes. Necessary for (multi-dimensional) multi-class or multi-label data.

  • ignore_index (Optional[int]) – Specify a class (label) to ignore. If given, this class index does not contribute to the returned score, regardless of reduction method. If an index is ignored, and reduce='macro', the class statistics for the ignored class will all be returned as -1.

  • mdmc_reduce (Optional[str]) –

    Defines how the multi-dimensional multi-class inputs are handled. Should be one of the following:

    • None [default]: Should be left unchanged if your data is not multi-dimensional multi-class (see Input types for the definition of input types).

    • 'samplewise': In this case, the statistics are computed separately for each sample on the N axis, and then the outputs are concatenated together. In each sample the extra axes ... are flattened to become the sub-sample axis, and statistics for each sample are computed by treating the sub-sample axis as the N axis for that sample.

    • 'global': In this case the N and ... dimensions of the inputs are flattened into a new N_X sample axis, i.e. the inputs are treated as if they were (N_X, C). From here on the reduce parameter applies as usual.

  • multiclass (Optional[bool]) – Used only in certain special cases, where you want to treat inputs as a different type than what they appear to be. See the parameter’s documentation section for a more detailed explanation and examples.

Return type

Tensor

Returns

The metric returns a tensor of shape (..., 5), where the last dimension corresponds to [tp, fp, tn, fn, sup] (sup stands for support and equals tp + fn). The shape depends on the reduce and mdmc_reduce (in case of multi-dimensional multi-class data) parameters:

  • If the data is not multi-dimensional multi-class, then

    • If reduce='micro', the shape will be (5, )

    • If reduce='macro', the shape will be (C, 5), where C stands for the number of classes

    • If reduce='samples', the shape will be (N, 5), where N stands for the number of samples

  • If the data is multi-dimensional multi-class and mdmc_reduce='global', then

    • If reduce='micro', the shape will be (5, )

    • If reduce='macro', the shape will be (C, 5)

    • If reduce='samples', the shape will be (N*X, 5), where X stands for the product of sizes of all “extra” dimensions of the data (i.e. all dimensions except for C and N)

  • If the data is multi-dimensional multi-class and mdmc_reduce='samplewise', then

    • If reduce='micro', the shape will be (N, 5)

    • If reduce='macro', the shape will be (N, C, 5)

    • If reduce='samples', the shape will be (N, X, 5)

Raises
  • ValueError – If reduce is none of "micro", "macro" or "samples".

  • ValueError – If mdmc_reduce is none of None, "samplewise", "global".

  • ValueError – If reduce is set to "macro" and num_classes is not provided.

  • ValueError – If num_classes is set and ignore_index is not in the range [0, num_classes).

  • ValueError – If ignore_index is used with binary data.

  • ValueError – If inputs are multi-dimensional multi-class and mdmc_reduce is not provided.

Example

>>> import torch
>>> from torchmetrics.functional import stat_scores
>>> preds = torch.tensor([1, 0, 2, 1])
>>> target = torch.tensor([1, 1, 2, 0])
>>> stat_scores(preds, target, reduce='macro', num_classes=3)
tensor([[0, 1, 2, 1, 1],
        [1, 1, 1, 1, 2],
        [1, 0, 3, 0, 1]])
>>> stat_scores(preds, target, reduce='micro')
tensor([2, 2, 6, 2, 4])

binary_stat_scores

torchmetrics.functional.classification.binary_stat_scores(preds, target, threshold=0.5, multidim_average='global', ignore_index=None, validate_args=True)

Computes the number of true positives, false positives, true negatives, false negatives and the support for binary tasks. Related to Type I and Type II errors.

Accepts the following input tensors:

  • preds (int or float tensor): (N, ...). If preds is a floating point tensor with values outside the [0,1] range, we consider the input to be logits and automatically apply sigmoid per element. Additionally, we convert it to an int tensor by thresholding with the value in threshold.

  • target (int tensor): (N, ...)

The influence of the additional dimension ... (if present) will be determined by the multidim_average argument.

Parameters
  • preds (Tensor) – Tensor with predictions

  • target (Tensor) – Tensor with true labels

  • threshold (float) – Threshold for transforming probability to binary {0,1} predictions

  • multidim_average (Literal[‘global’, ‘samplewise’]) –

    Defines how additional dimensions ... should be handled. Should be one of the following:

    • global: Additional dimensions are flattened into the batch dimension

    • samplewise: Statistics are calculated independently for each sample on the N axis, over the additional dimensions of that sample.

  • ignore_index (Optional[int]) – Specifies a target value that is ignored and does not contribute to the metric calculation

  • validate_args (bool) – bool indicating if input arguments and tensors should be validated for correctness. Set to False for faster computations.
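The ignore_index parameter can be illustrated by dropping every position whose target equals the ignored value before counting. This sketch assumes that reading of "does not contribute to the metric calculation"; `binary_counts_with_ignore` is a hypothetical helper, and torchmetrics' internal masking may differ in detail. The inputs are the int example below, padded with two ignored positions.

```python
from typing import List, Optional, Sequence


def binary_counts_with_ignore(
    preds: Sequence[int], target: Sequence[int], ignore_index: Optional[int] = None
) -> List[int]:
    """[tp, fp, tn, fn, sup] for 0/1 labels, skipping positions whose target is ignore_index."""
    pairs = [
        (p, t) for p, t in zip(preds, target) if ignore_index is None or t != ignore_index
    ]
    tp, fp = pairs.count((1, 1)), pairs.count((1, 0))
    tn, fn = pairs.count((0, 0)), pairs.count((0, 1))
    return [tp, fp, tn, fn, tp + fn]


target = [0, 1, 0, 1, 0, 1, -1, -1]  # -1 marks padding positions
preds = [0, 0, 1, 1, 0, 1, 1, 0]
scores = binary_counts_with_ignore(preds, target, ignore_index=-1)
print(scores)           # [2, 1, 2, 1, 3]
print(sum(scores[:4]))  # 6: the two padded positions are not counted anywhere
```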

Return type

Tensor

Returns

The metric returns a tensor of shape (..., 5), where the last dimension corresponds to [tp, fp, tn, fn, sup] (sup stands for support and equals tp + fn). The shape depends on the multidim_average parameter:

  • If multidim_average is set to global, the shape will be (5,)

  • If multidim_average is set to samplewise, the shape will be (N, 5)

Example (preds is int tensor):
>>> import torch
>>> from torchmetrics.functional.classification import binary_stat_scores
>>> target = torch.tensor([0, 1, 0, 1, 0, 1])
>>> preds = torch.tensor([0, 0, 1, 1, 0, 1])
>>> binary_stat_scores(preds, target)
tensor([2, 1, 2, 1, 3])
Example (preds is float tensor):
>>> import torch
>>> from torchmetrics.functional.classification import binary_stat_scores
>>> target = torch.tensor([0, 1, 0, 1, 0, 1])
>>> preds = torch.tensor([0.11, 0.22, 0.84, 0.73, 0.33, 0.92])
>>> binary_stat_scores(preds, target)
tensor([2, 1, 2, 1, 3])
Example (multidim tensors):
>>> import torch
>>> from torchmetrics.functional.classification import binary_stat_scores
>>> target = torch.tensor([[[0, 1], [1, 0], [0, 1]], [[1, 1], [0, 0], [1, 0]]])
>>> preds = torch.tensor(
...     [
...         [[0.59, 0.91], [0.91, 0.99], [0.63, 0.04]],
...         [[0.38, 0.04], [0.86, 0.780], [0.45, 0.37]],
...     ]
... )
>>> binary_stat_scores(preds, target, multidim_average='samplewise')
tensor([[2, 3, 0, 1, 3],
        [0, 2, 1, 3, 3]])

multiclass_stat_scores

torchmetrics.functional.classification.multiclass_stat_scores(preds, target, num_classes, average='macro', top_k=1, multidim_average='global', ignore_index=None, validate_args=True)

Computes the number of true positives, false positives, true negatives, false negatives and the support for multiclass tasks. Related to Type I and Type II errors.

Accepts the following input tensors:

  • preds: (N, ...) (int tensor) or (N, C, ...) (float tensor). If preds is a floating point tensor, we apply torch.argmax along the C dimension to automatically convert probabilities/logits into an int tensor.

  • target (int tensor): (N, ...)

The influence of the additional dimension ... (if present) will be determined by the multidim_average argument.

Parameters
  • preds (Tensor) – Tensor with predictions

  • target (Tensor) – Tensor with true labels

  • num_classes (int) – Integer specifying the number of classes

  • average (Optional[Literal[‘micro’, ‘macro’, ‘weighted’, ‘none’]]) –

    Defines the reduction that is applied over labels. Should be one of the following:

    • micro: Sum statistics over all labels

    • macro: Calculate statistics for each label and average them

    • weighted: Calculates statistics for each label and computes weighted average using their support

    • "none" or None: Calculates statistic for each label and applies no reduction

  • top_k (int) – Number of highest probability or logit score predictions considered to find the correct label. Only works when preds contain probabilities/logits.

  • multidim_average (Literal[‘global’, ‘samplewise’]) –

    Defines how additional dimensions ... should be handled. Should be one of the following:

    • global: Additional dimensions are flattened into the batch dimension

    • samplewise: Statistics are calculated independently for each sample on the N axis, over the additional dimensions of that sample.

  • ignore_index (Optional[int]) – Specifies a target value that is ignored and does not contribute to the metric calculation

  • validate_args (bool) – bool indicating if input arguments and tensors should be validated for correctness. Set to False for faster computations.
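With average='micro', the per-class counts collapse to a closed form: summed over classes, each correct prediction contributes one tp, each wrong prediction contributes one fp (for the predicted class) and one fn (for the true class), and the remaining (sample, class) cells are tn. The hypothetical helper `micro_from_labels` below assumes integer labels, top_k=1 and no ignore_index; applied per sample, it reproduces the samplewise multidim example below.

```python
from typing import List, Sequence


def flatten(nested) -> List[int]:
    """Flatten the extra dimensions ... of one sample into a flat list."""
    if isinstance(nested, (list, tuple)):
        return [v for item in nested for v in flatten(item)]
    return [nested]


def micro_from_labels(preds: Sequence[int], target: Sequence[int], num_classes: int) -> List[int]:
    """Micro-summed [tp, fp, tn, fn, sup] for flat integer class labels."""
    n = len(target)
    tp = sum(p == t for p, t in zip(preds, target))  # correct predictions
    wrong = n - tp                                   # each one is an fp and an fn
    tn = n * num_classes - tp - 2 * wrong            # all remaining (sample, class) cells
    return [tp, wrong, tn, wrong, n]                 # support summed over classes is n


target = [[[0, 1], [2, 1], [0, 2]], [[1, 1], [2, 0], [1, 2]]]
preds = [[[0, 2], [2, 0], [0, 1]], [[2, 2], [2, 1], [1, 0]]]
samplewise = [micro_from_labels(flatten(p), flatten(t), 3) for p, t in zip(preds, target)]
print(samplewise)  # [[3, 3, 9, 3, 6], [2, 4, 8, 4, 6]]
```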

Return type

Tensor

Returns

The metric returns a tensor of shape (..., 5), where the last dimension corresponds to [tp, fp, tn, fn, sup] (sup stands for support and equals tp + fn). The shape depends on average and multidim_average parameters:

  • If multidim_average is set to global:

    • If average='micro'/'macro'/'weighted', the shape will be (5,)

    • If average=None/'none', the shape will be (C, 5)

  • If multidim_average is set to samplewise:

    • If average='micro'/'macro'/'weighted', the shape will be (N, 5)

    • If average=None/'none', the shape will be (N, C, 5)

Example (preds is int tensor):
>>> import torch
>>> from torchmetrics.functional.classification import multiclass_stat_scores
>>> target = torch.tensor([2, 1, 0, 0])
>>> preds = torch.tensor([2, 1, 0, 1])
>>> multiclass_stat_scores(preds, target, num_classes=3, average='micro')
tensor([3, 1, 7, 1, 4])
>>> multiclass_stat_scores(preds, target, num_classes=3, average=None)
tensor([[1, 0, 2, 1, 2],
        [1, 1, 2, 0, 1],
        [1, 0, 3, 0, 1]])
Example (preds is float tensor):
>>> import torch
>>> from torchmetrics.functional.classification import multiclass_stat_scores
>>> target = torch.tensor([2, 1, 0, 0])
>>> preds = torch.tensor([
...   [0.16, 0.26, 0.58],
...   [0.22, 0.61, 0.17],
...   [0.71, 0.09, 0.20],
...   [0.05, 0.82, 0.13],
... ])
>>> multiclass_stat_scores(preds, target, num_classes=3, average='micro')
tensor([3, 1, 7, 1, 4])
>>> multiclass_stat_scores(preds, target, num_classes=3, average=None)
tensor([[1, 0, 2, 1, 2],
        [1, 1, 2, 0, 1],
        [1, 0, 3, 0, 1]])
Example (multidim tensors):
>>> import torch
>>> from torchmetrics.functional.classification import multiclass_stat_scores
>>> target = torch.tensor([[[0, 1], [2, 1], [0, 2]], [[1, 1], [2, 0], [1, 2]]])
>>> preds = torch.tensor([[[0, 2], [2, 0], [0, 1]], [[2, 2], [2, 1], [1, 0]]])
>>> multiclass_stat_scores(preds, target, num_classes=3, multidim_average='samplewise', average='micro')
tensor([[3, 3, 9, 3, 6],
        [2, 4, 8, 4, 6]])
>>> multiclass_stat_scores(preds, target, num_classes=3, multidim_average='samplewise', average=None)
tensor([[[2, 1, 3, 0, 2],
         [0, 1, 3, 2, 2],
         [1, 1, 3, 1, 2]],
        [[0, 1, 4, 1, 1],
         [1, 1, 2, 2, 3],
         [1, 2, 2, 1, 2]]])

multilabel_stat_scores

torchmetrics.functional.classification.multilabel_stat_scores(preds, target, num_labels, threshold=0.5, average='macro', multidim_average='global', ignore_index=None, validate_args=True)

Computes the number of true positives, false positives, true negatives, false negatives and the support for multilabel tasks. Related to Type I and Type II errors.

Accepts the following input tensors:

  • preds (int or float tensor): (N, C, ...). If preds is a floating point tensor with values outside the [0,1] range, we consider the input to be logits and automatically apply sigmoid per element. Additionally, we convert it to an int tensor by thresholding with the value in threshold.

  • target (int tensor): (N, C, ...)

The influence of the additional dimension ... (if present) will be determined by the multidim_average argument.

Parameters
  • preds (Tensor) – Tensor with predictions

  • target (Tensor) – Tensor with true labels

  • num_labels (int) – Integer specifying the number of labels

  • threshold (float) – Threshold for transforming probability to binary (0,1) predictions

  • average (Optional[Literal[‘micro’, ‘macro’, ‘weighted’, ‘none’]]) –

    Defines the reduction that is applied over labels. Should be one of the following:

    • micro: Sum statistics over all labels

    • macro: Calculate statistics for each label and average them

    • weighted: Calculates statistics for each label and computes weighted average using their support

    • "none" or None: Calculates statistic for each label and applies no reduction

  • multidim_average (Literal[‘global’, ‘samplewise’]) –

    Defines how additional dimensions ... should be handled. Should be one of the following:

    • global: Additional dimensions are flattened into the batch dimension

    • samplewise: Statistics are calculated independently for each sample on the N axis, over the additional dimensions of that sample.

  • ignore_index (Optional[int]) – Specifies a target value that is ignored and does not contribute to the metric calculation

  • validate_args (bool) – bool indicating if input arguments and tensors should be validated for correctness. Set to False for faster computations.

Return type

Tensor

Returns

The metric returns a tensor of shape (..., 5), where the last dimension corresponds to [tp, fp, tn, fn, sup] (sup stands for support and equals tp + fn). The shape depends on average and multidim_average parameters:

  • If multidim_average is set to global:

    • If average='micro'/'macro'/'weighted', the shape will be (5,)

    • If average=None/'none', the shape will be (C, 5)

  • If multidim_average is set to samplewise:

    • If average='micro'/'macro'/'weighted', the shape will be (N, 5)

    • If average=None/'none', the shape will be (N, C, 5)
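For multilabel inputs of shape (N, C, ...), samplewise averaging with average=None keeps both the sample and the label axes and counts over the extra dimensions only. The hypothetical helper `multilabel_samplewise` assumes exactly one extra dimension, probabilities already in [0, 1] (so no sigmoid) and a strict `>` threshold comparison; it reproduces the (N, C, 5) multidim example below.

```python
from typing import List, Sequence


def label_counts(preds: Sequence[int], target: Sequence[int]) -> List[int]:
    """[tp, fp, tn, fn, sup] for flat 0/1 sequences."""
    pairs = list(zip(preds, target))
    tp, fp = pairs.count((1, 1)), pairs.count((1, 0))
    tn, fn = pairs.count((0, 0)), pairs.count((0, 1))
    return [tp, fp, tn, fn, tp + fn]


def multilabel_samplewise(preds, target, threshold: float = 0.5) -> List[List[List[int]]]:
    """(N, C, X) probabilities and 0/1 targets -> (N, C, 5) stats."""
    out = []
    for sample_p, sample_t in zip(preds, target):          # loop over N
        per_label = []
        for label_p, label_t in zip(sample_p, sample_t):   # loop over C
            binary_p = [int(v > threshold) for v in label_p]
            per_label.append(label_counts(binary_p, label_t))  # counted over X only
        out.append(per_label)
    return out


target = [[[0, 1], [1, 0], [0, 1]], [[1, 1], [0, 0], [1, 0]]]
preds = [
    [[0.59, 0.91], [0.91, 0.99], [0.63, 0.04]],
    [[0.38, 0.04], [0.86, 0.780], [0.45, 0.37]],
]
result = multilabel_samplewise(preds, target)
print(result[0])  # [[1, 1, 0, 0, 1], [1, 1, 0, 0, 1], [0, 1, 0, 1, 1]]
print(result[1])  # [[0, 0, 0, 2, 2], [0, 2, 0, 0, 0], [0, 0, 1, 1, 1]]
```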

Example (preds is int tensor):
>>> import torch
>>> from torchmetrics.functional.classification import multilabel_stat_scores
>>> target = torch.tensor([[0, 1, 0], [1, 0, 1]])
>>> preds = torch.tensor([[0, 0, 1], [1, 0, 1]])
>>> multilabel_stat_scores(preds, target, num_labels=3, average='micro')
tensor([2, 1, 2, 1, 3])
>>> multilabel_stat_scores(preds, target, num_labels=3, average=None)
tensor([[1, 0, 1, 0, 1],
        [0, 0, 1, 1, 1],
        [1, 1, 0, 0, 1]])
Example (preds is float tensor):
>>> import torch
>>> from torchmetrics.functional.classification import multilabel_stat_scores
>>> target = torch.tensor([[0, 1, 0], [1, 0, 1]])
>>> preds = torch.tensor([[0.11, 0.22, 0.84], [0.73, 0.33, 0.92]])
>>> multilabel_stat_scores(preds, target, num_labels=3, average='micro')
tensor([2, 1, 2, 1, 3])
>>> multilabel_stat_scores(preds, target, num_labels=3, average=None)
tensor([[1, 0, 1, 0, 1],
        [0, 0, 1, 1, 1],
        [1, 1, 0, 0, 1]])
Example (multidim tensors):
>>> import torch
>>> from torchmetrics.functional.classification import multilabel_stat_scores
>>> target = torch.tensor([[[0, 1], [1, 0], [0, 1]], [[1, 1], [0, 0], [1, 0]]])
>>> preds = torch.tensor(
...     [
...         [[0.59, 0.91], [0.91, 0.99], [0.63, 0.04]],
...         [[0.38, 0.04], [0.86, 0.780], [0.45, 0.37]],
...     ]
... )
>>> multilabel_stat_scores(preds, target, num_labels=3, multidim_average='samplewise', average='micro')
tensor([[2, 3, 0, 1, 3],
        [0, 2, 1, 3, 3]])
>>> multilabel_stat_scores(preds, target, num_labels=3, multidim_average='samplewise', average=None)
tensor([[[1, 1, 0, 0, 1],
         [1, 1, 0, 0, 1],
         [0, 1, 0, 1, 1]],
        [[0, 0, 0, 2, 2],
         [0, 2, 0, 0, 0],
         [0, 0, 1, 1, 1]]])