Confusion Matrix

Module Interface

ConfusionMatrix

class torchmetrics.ConfusionMatrix(task: Literal['binary', 'multiclass', 'multilabel'], threshold: float = 0.5, num_classes: int | None = None, num_labels: int | None = None, normalize: Literal['true', 'pred', 'all', 'none'] | None = None, ignore_index: int | None = None, validate_args: bool = True, **kwargs: Any)[source]

Compute the confusion matrix.

This class is a simple wrapper to get the task-specific version of this metric, which is done by setting the task argument to either 'binary', 'multiclass' or 'multilabel'. See the documentation of BinaryConfusionMatrix, MulticlassConfusionMatrix and MultilabelConfusionMatrix for the specific details of how each argument influences the metric, along with examples.

Legacy Example:
>>> from torch import tensor
>>> from torchmetrics import ConfusionMatrix
>>> target = tensor([1, 1, 0, 0])
>>> preds = tensor([0, 1, 0, 0])
>>> confmat = ConfusionMatrix(task="binary", num_classes=2)
>>> confmat(preds, target)
tensor([[2, 0],
        [1, 1]])
>>> target = tensor([2, 1, 0, 0])
>>> preds = tensor([2, 1, 0, 1])
>>> confmat = ConfusionMatrix(task="multiclass", num_classes=3)
>>> confmat(preds, target)
tensor([[1, 1, 0],
        [0, 1, 0],
        [0, 0, 1]])
>>> target = tensor([[0, 1, 0], [1, 0, 1]])
>>> preds = tensor([[0, 0, 1], [1, 0, 1]])
>>> confmat = ConfusionMatrix(task="multilabel", num_labels=3)
>>> confmat(preds, target)
tensor([[[1, 0], [0, 1]],
        [[1, 0], [1, 0]],
        [[0, 1], [0, 1]]])

Initialize task metric.
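
Because the wrapper simply dispatches on the task argument, the object it returns should be an instance of the corresponding task-specific class. A minimal sketch illustrating this (assuming the dispatch behaviour described above):
>>> from torchmetrics import ConfusionMatrix
>>> from torchmetrics.classification import MulticlassConfusionMatrix
>>> confmat = ConfusionMatrix(task="multiclass", num_classes=3)  # dispatches on task
>>> isinstance(confmat, MulticlassConfusionMatrix)
True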

BinaryConfusionMatrix

class torchmetrics.classification.BinaryConfusionMatrix(threshold=0.5, ignore_index=None, normalize=None, validate_args=True, **kwargs)[source]

Compute the confusion matrix for binary tasks.

As input to forward and update the metric accepts the following input:

  • preds (Tensor): An int or float tensor of shape (N, ...). If preds is a floating point tensor with values outside the [0,1] range, we consider the input to be logits and will auto apply sigmoid per element. Additionally, we convert to an int tensor by thresholding with the value in threshold.

  • target (Tensor): An int tensor of shape (N, ...).

As output to forward and compute the metric returns the following output:

  • confusion_matrix (Tensor): A tensor containing a (2, 2) matrix

Additional dimension ... will be flattened into the batch dimension.

Parameters:
  • threshold (float) – Threshold for transforming probability to binary (0,1) predictions

  • ignore_index (Optional[int]) – Specifies a target value that is ignored and does not contribute to the metric calculation

  • normalize (Optional[Literal['true', 'pred', 'all', 'none']]) –

    Normalization mode for confusion matrix. Choose from:

    • None or 'none': no normalization (default)

    • 'true': normalization over the targets (most commonly used)

    • 'pred': normalization over the predictions

    • 'all': normalization over the whole matrix

  • validate_args (bool) – bool indicating if input arguments and tensors should be validated for correctness. Set to False for faster computations.

  • kwargs (Any) – Additional keyword arguments, see Advanced metric settings for more info.

Example (preds is int tensor):
>>> import torch
>>> from torchmetrics.classification import BinaryConfusionMatrix
>>> target = torch.tensor([1, 1, 0, 0])
>>> preds = torch.tensor([0, 1, 0, 0])
>>> bcm = BinaryConfusionMatrix()
>>> bcm(preds, target)
tensor([[2, 0],
        [1, 1]])
Example (preds is float tensor):
>>> import torch
>>> from torchmetrics.classification import BinaryConfusionMatrix
>>> target = torch.tensor([1, 1, 0, 0])
>>> preds = torch.tensor([0.35, 0.85, 0.48, 0.01])
>>> bcm = BinaryConfusionMatrix()
>>> bcm(preds, target)
tensor([[2, 0],
        [1, 1]])
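
The sigmoid-plus-threshold handling described above means that raw logits can be passed directly. A sketch with illustrative logit values (after the sigmoid they threshold to the same predictions as the integer example, so the same matrix is expected):
>>> import torch
>>> from torchmetrics.classification import BinaryConfusionMatrix
>>> target = torch.tensor([1, 1, 0, 0])
>>> logits = torch.tensor([-0.6, 1.7, -0.1, -4.6])  # values outside [0, 1] are treated as logits
>>> bcm = BinaryConfusionMatrix()
>>> bcm(logits, target)
tensor([[2, 0],
        [1, 1]])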
plot(val=None, ax=None, add_text=True, labels=None)[source]

Plot a single or multiple values from the metric.

Parameters:
  • val (Optional[Tensor]) – Either a single result from calling metric.forward or metric.compute or a list of these results. If no value is provided, will automatically call metric.compute and plot that result.

  • ax (Optional[Axes]) – A matplotlib axes object. If provided, the plot will be added to that axes

  • add_text (bool) – Whether the value of each cell should be added to the plot

  • labels (Optional[List[str]]) – A list of strings; if provided, they will be added to the plot to indicate the different classes

Return type:

Tuple[Figure, Union[Axes, ndarray]]

Returns:

Figure and Axes object

Raises:

ModuleNotFoundError – If matplotlib is not installed

>>> from torch import randint
>>> from torchmetrics.classification import BinaryConfusionMatrix
>>> metric = BinaryConfusionMatrix()
>>> metric.update(randint(2, (20,)), randint(2, (20,)))
>>> fig_, ax_ = metric.plot()

[Figure: confusion_matrix-1.png, the confusion matrix plot produced by the example above]

MulticlassConfusionMatrix

class torchmetrics.classification.MulticlassConfusionMatrix(num_classes, ignore_index=None, normalize=None, validate_args=True, **kwargs)[source]

Compute the confusion matrix for multiclass tasks.

As input to forward and update the metric accepts the following input:

  • preds (Tensor): An int tensor of shape (N, ...) or a float tensor of shape (N, C, ...). If preds is a floating point tensor we apply torch.argmax along the C dimension to automatically convert probabilities/logits into an int tensor.

  • target (Tensor): An int tensor of shape (N, ...).

As output to forward and compute the metric returns the following output:

  • confusion_matrix (Tensor): A tensor containing a (num_classes, num_classes) matrix

Parameters:
  • num_classes (int) – Integer specifying the number of classes

  • ignore_index (Optional[int]) – Specifies a target value that is ignored and does not contribute to the metric calculation

  • normalize (Optional[Literal['true', 'pred', 'all', 'none']]) –

    Normalization mode for confusion matrix. Choose from:

    • None or 'none': no normalization (default)

    • 'true': normalization over the targets (most commonly used)

    • 'pred': normalization over the predictions

    • 'all': normalization over the whole matrix

  • validate_args (bool) – bool indicating if input arguments and tensors should be validated for correctness. Set to False for faster computations.

  • kwargs (Any) – Additional keyword arguments, see Advanced metric settings for more info.

Example (pred is integer tensor):
>>> from torch import tensor
>>> from torchmetrics.classification import MulticlassConfusionMatrix
>>> target = tensor([2, 1, 0, 0])
>>> preds = tensor([2, 1, 0, 1])
>>> metric = MulticlassConfusionMatrix(num_classes=3)
>>> metric(preds, target)
tensor([[1, 1, 0],
        [0, 1, 0],
        [0, 0, 1]])
Example (pred is float tensor):
>>> from torchmetrics.classification import MulticlassConfusionMatrix
>>> target = tensor([2, 1, 0, 0])
>>> preds = tensor([[0.16, 0.26, 0.58],
...                 [0.22, 0.61, 0.17],
...                 [0.71, 0.09, 0.20],
...                 [0.05, 0.82, 0.13]])
>>> metric = MulticlassConfusionMatrix(num_classes=3)
>>> metric(preds, target)
tensor([[1, 1, 0],
        [0, 1, 0],
        [0, 0, 1]])
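
The normalize argument rescales these counts; with normalize="true" each row is divided by the number of true samples of that class. A sketch based on the data above (the expected values simply follow from dividing each row of the unnormalized matrix by its row sum):
>>> from torch import tensor
>>> from torchmetrics.classification import MulticlassConfusionMatrix
>>> target = tensor([2, 1, 0, 0])
>>> preds = tensor([2, 1, 0, 1])
>>> metric = MulticlassConfusionMatrix(num_classes=3, normalize="true")
>>> metric(preds, target)
tensor([[0.5000, 0.5000, 0.0000],
        [0.0000, 1.0000, 0.0000],
        [0.0000, 0.0000, 1.0000]])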
plot(val=None, ax=None, add_text=True, labels=None)[source]

Plot a single or multiple values from the metric.

Parameters:
  • val (Optional[Tensor]) – Either a single result from calling metric.forward or metric.compute or a list of these results. If no value is provided, will automatically call metric.compute and plot that result.

  • ax (Optional[Axes]) – A matplotlib axes object. If provided, the plot will be added to that axes

  • add_text (bool) – Whether the value of each cell should be added to the plot

  • labels (Optional[List[str]]) – A list of strings; if provided, they will be added to the plot to indicate the different classes

Return type:

Tuple[Figure, Union[Axes, ndarray]]

Returns:

Figure and Axes object

Raises:

ModuleNotFoundError – If matplotlib is not installed

>>> from torch import randint
>>> from torchmetrics.classification import MulticlassConfusionMatrix
>>> metric = MulticlassConfusionMatrix(num_classes=5)
>>> metric.update(randint(5, (20,)), randint(5, (20,)))
>>> fig_, ax_ = metric.plot()

[Figure: confusion_matrix-2.png, the confusion matrix plot produced by the example above]
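
The labels and add_text arguments can make the plot easier to read; a sketch (the class names here are placeholders, not something the metric provides):
>>> from torch import randint
>>> from torchmetrics.classification import MulticlassConfusionMatrix
>>> metric = MulticlassConfusionMatrix(num_classes=3)
>>> metric.update(randint(3, (50,)), randint(3, (50,)))
>>> fig_, ax_ = metric.plot(labels=["cat", "dog", "bird"], add_text=True)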

MultilabelConfusionMatrix

class torchmetrics.classification.MultilabelConfusionMatrix(num_labels, threshold=0.5, ignore_index=None, normalize=None, validate_args=True, **kwargs)[source]

Compute the confusion matrix for multilabel tasks.

As input to forward and update the metric accepts the following input:

  • preds (Tensor): An int or float tensor of shape (N, C, ...). If preds is a floating point tensor with values outside the [0,1] range, we consider the input to be logits and will auto apply sigmoid per element. Additionally, we convert to an int tensor by thresholding with the value in threshold.

  • target (Tensor): An int tensor of shape (N, C, ...).

As output to forward and compute the metric returns the following output:

  • confusion_matrix (Tensor): A tensor containing a (num_labels, 2, 2) matrix

Parameters:
  • num_labels (int) – Integer specifying the number of labels

  • threshold (float) – Threshold for transforming probability to binary (0,1) predictions

  • ignore_index (Optional[int]) – Specifies a target value that is ignored and does not contribute to the metric calculation

  • normalize (Optional[Literal['true', 'pred', 'all', 'none']]) –

    Normalization mode for confusion matrix. Choose from:

    • None or 'none': no normalization (default)

    • 'true': normalization over the targets (most commonly used)

    • 'pred': normalization over the predictions

    • 'all': normalization over the whole matrix

  • validate_args (bool) – bool indicating if input arguments and tensors should be validated for correctness. Set to False for faster computations.

  • kwargs (Any) – Additional keyword arguments, see Advanced metric settings for more info.

Example (preds is int tensor):
>>> from torch import tensor
>>> from torchmetrics.classification import MultilabelConfusionMatrix
>>> target = tensor([[0, 1, 0], [1, 0, 1]])
>>> preds = tensor([[0, 0, 1], [1, 0, 1]])
>>> metric = MultilabelConfusionMatrix(num_labels=3)
>>> metric(preds, target)
tensor([[[1, 0], [0, 1]],
        [[1, 0], [1, 0]],
        [[0, 1], [0, 1]]])
Example (preds is float tensor):
>>> from torchmetrics.classification import MultilabelConfusionMatrix
>>> target = tensor([[0, 1, 0], [1, 0, 1]])
>>> preds = tensor([[0.11, 0.22, 0.84], [0.73, 0.33, 0.92]])
>>> metric = MultilabelConfusionMatrix(num_labels=3)
>>> metric(preds, target)
tensor([[[1, 0], [0, 1]],
        [[1, 0], [1, 0]],
        [[0, 1], [0, 1]]])
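
Each of the num_labels slices of the output is a 2x2 matrix with targets along the rows and predictions along the columns, so per-label counts can be read off directly. A small sketch, assuming the [[TN, FP], [FN, TP]] layout implied by the examples above:
>>> from torch import tensor
>>> from torchmetrics.classification import MultilabelConfusionMatrix
>>> target = tensor([[0, 1, 0], [1, 0, 1]])
>>> preds = tensor([[0, 0, 1], [1, 0, 1]])
>>> metric = MultilabelConfusionMatrix(num_labels=3)
>>> cm = metric(preds, target)  # shape [num_labels, 2, 2]
>>> tn, fp, fn, tp = cm[:, 0, 0], cm[:, 0, 1], cm[:, 1, 0], cm[:, 1, 1]
>>> tp  # true positives per label
tensor([1, 0, 1])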
plot(val=None, ax=None, add_text=True, labels=None)[source]

Plot a single or multiple values from the metric.

Parameters:
  • val (Optional[Tensor]) – Either a single result from calling metric.forward or metric.compute or a list of these results. If no value is provided, will automatically call metric.compute and plot that result.

  • ax (Optional[Axes]) – A matplotlib axes object. If provided, the plot will be added to that axes

  • add_text (bool) – Whether the value of each cell should be added to the plot

  • labels (Optional[List[str]]) – A list of strings; if provided, they will be added to the plot to indicate the different classes

Return type:

Tuple[Figure, Union[Axes, ndarray]]

Returns:

Figure and Axes object

Raises:

ModuleNotFoundError – If matplotlib is not installed

>>> from torch import randint
>>> from torchmetrics.classification import MultilabelConfusionMatrix
>>> metric = MultilabelConfusionMatrix(num_labels=3)
>>> metric.update(randint(2, (20, 3)), randint(2, (20, 3)))
>>> fig_, ax_ = metric.plot()

[Figure: confusion_matrix-3.png, the confusion matrix plot produced by the example above]

Functional Interface

confusion_matrix

torchmetrics.functional.confusion_matrix(preds, target, task, threshold=0.5, num_classes=None, num_labels=None, normalize=None, ignore_index=None, validate_args=True)[source]

Compute the confusion matrix.

This function is a simple wrapper to get the task-specific version of this metric, which is done by setting the task argument to either 'binary', 'multiclass' or 'multilabel'. See the documentation of binary_confusion_matrix(), multiclass_confusion_matrix() and multilabel_confusion_matrix() for the specific details of how each argument influences the metric, along with examples.

Return type:

Tensor

Legacy Example:
>>> from torch import tensor
>>> from torchmetrics.classification import ConfusionMatrix
>>> target = tensor([1, 1, 0, 0])
>>> preds = tensor([0, 1, 0, 0])
>>> confmat = ConfusionMatrix(task="binary")
>>> confmat(preds, target)
tensor([[2, 0],
        [1, 1]])
>>> target = tensor([2, 1, 0, 0])
>>> preds = tensor([2, 1, 0, 1])
>>> confmat = ConfusionMatrix(task="multiclass", num_classes=3)
>>> confmat(preds, target)
tensor([[1, 1, 0],
        [0, 1, 0],
        [0, 0, 1]])
>>> target = tensor([[0, 1, 0], [1, 0, 1]])
>>> preds = tensor([[0, 0, 1], [1, 0, 1]])
>>> confmat = ConfusionMatrix(task="multilabel", num_labels=3)
>>> confmat(preds, target)
tensor([[[1, 0], [0, 1]],
        [[1, 0], [1, 0]],
        [[0, 1], [0, 1]]])
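
The functional wrapper documented above can also be called directly; a minimal sketch using the binary data from the legacy example (the same matrix is expected):
>>> from torch import tensor
>>> from torchmetrics.functional import confusion_matrix
>>> target = tensor([1, 1, 0, 0])
>>> preds = tensor([0, 1, 0, 0])
>>> confusion_matrix(preds, target, task="binary")
tensor([[2, 0],
        [1, 1]])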

binary_confusion_matrix

torchmetrics.functional.classification.binary_confusion_matrix(preds, target, threshold=0.5, normalize=None, ignore_index=None, validate_args=True)[source]

Compute the confusion matrix for binary tasks.

Accepts the following input tensors:

  • preds (int or float tensor): (N, ...). If preds is a floating point tensor with values outside the [0,1] range, we consider the input to be logits and will auto apply sigmoid per element. Additionally, we convert to an int tensor by thresholding with the value in threshold.

  • target (int tensor): (N, ...)

Additional dimension ... will be flattened into the batch dimension.

Parameters:
  • preds (Tensor) – Tensor with predictions

  • target (Tensor) – Tensor with true labels

  • threshold (float) – Threshold for transforming probability to binary (0,1) predictions

  • normalize (Optional[Literal['true', 'pred', 'all', 'none']]) –

    Normalization mode for confusion matrix. Choose from:

    • None or 'none': no normalization (default)

    • 'true': normalization over the targets (most commonly used)

    • 'pred': normalization over the predictions

    • 'all': normalization over the whole matrix

  • ignore_index (Optional[int]) – Specifies a target value that is ignored and does not contribute to the metric calculation

  • validate_args (bool) – bool indicating if input arguments and tensors should be validated for correctness. Set to False for faster computations.

Return type:

Tensor

Returns:

A [2, 2] tensor

Example (preds is int tensor):
>>> from torch import tensor
>>> from torchmetrics.functional.classification import binary_confusion_matrix
>>> target = tensor([1, 1, 0, 0])
>>> preds = tensor([0, 1, 0, 0])
>>> binary_confusion_matrix(preds, target)
tensor([[2, 0],
        [1, 1]])
Example (preds is float tensor):
>>> from torchmetrics.functional.classification import binary_confusion_matrix
>>> target = tensor([1, 1, 0, 0])
>>> preds = tensor([0.35, 0.85, 0.48, 0.01])
>>> binary_confusion_matrix(preds, target)
tensor([[2, 0],
        [1, 1]])
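
The normalize argument works the same way here; with normalize="true" each row is divided by the number of true samples of that class. A sketch (the expected values follow from normalizing the [2, 2] matrix shown above):
>>> from torch import tensor
>>> from torchmetrics.functional.classification import binary_confusion_matrix
>>> target = tensor([1, 1, 0, 0])
>>> preds = tensor([0, 1, 0, 0])
>>> binary_confusion_matrix(preds, target, normalize="true")
tensor([[1.0000, 0.0000],
        [0.5000, 0.5000]])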

multiclass_confusion_matrix

torchmetrics.functional.classification.multiclass_confusion_matrix(preds, target, num_classes, normalize=None, ignore_index=None, validate_args=True)[source]

Compute the confusion matrix for multiclass tasks.

Accepts the following input tensors:

  • preds: (N, ...) (int tensor) or (N, C, ...) (float tensor). If preds is a floating point tensor we apply torch.argmax along the C dimension to automatically convert probabilities/logits into an int tensor.

  • target (int tensor): (N, ...)

Additional dimension ... will be flattened into the batch dimension.

Parameters:
  • preds (Tensor) – Tensor with predictions

  • target (Tensor) – Tensor with true labels

  • num_classes (int) – Integer specifying the number of classes

  • normalize (Optional[Literal['true', 'pred', 'all', 'none']]) –

    Normalization mode for confusion matrix. Choose from:

    • None or 'none': no normalization (default)

    • 'true': normalization over the targets (most commonly used)

    • 'pred': normalization over the predictions

    • 'all': normalization over the whole matrix

  • ignore_index (Optional[int]) – Specifies a target value that is ignored and does not contribute to the metric calculation

  • validate_args (bool) – bool indicating if input arguments and tensors should be validated for correctness. Set to False for faster computations.

Return type:

Tensor

Returns:

A [num_classes, num_classes] tensor

Example (pred is integer tensor):
>>> from torch import tensor
>>> from torchmetrics.functional.classification import multiclass_confusion_matrix
>>> target = tensor([2, 1, 0, 0])
>>> preds = tensor([2, 1, 0, 1])
>>> multiclass_confusion_matrix(preds, target, num_classes=3)
tensor([[1, 1, 0],
        [0, 1, 0],
        [0, 0, 1]])
Example (pred is float tensor):
>>> from torchmetrics.functional.classification import multiclass_confusion_matrix
>>> target = tensor([2, 1, 0, 0])
>>> preds = tensor([[0.16, 0.26, 0.58],
...                 [0.22, 0.61, 0.17],
...                 [0.71, 0.09, 0.20],
...                 [0.05, 0.82, 0.13]])
>>> multiclass_confusion_matrix(preds, target, num_classes=3)
tensor([[1, 1, 0],
        [0, 1, 0],
        [0, 0, 1]])
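
ignore_index can be used to drop padding or otherwise invalid targets from the counts. A sketch, assuming samples whose target equals ignore_index are simply excluded as described above (the padded sample is ignored, so only the three remaining, correctly classified, samples are counted):
>>> from torch import tensor
>>> from torchmetrics.functional.classification import multiclass_confusion_matrix
>>> target = tensor([2, 1, 0, -1])  # the last sample is padding
>>> preds = tensor([2, 1, 0, 1])
>>> multiclass_confusion_matrix(preds, target, num_classes=3, ignore_index=-1)
tensor([[1, 0, 0],
        [0, 1, 0],
        [0, 0, 1]])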

multilabel_confusion_matrix

torchmetrics.functional.classification.multilabel_confusion_matrix(preds, target, num_labels, threshold=0.5, normalize=None, ignore_index=None, validate_args=True)[source]

Compute the confusion matrix for multilabel tasks.

Accepts the following input tensors:

  • preds (int or float tensor): (N, C, ...). If preds is a floating point tensor with values outside the [0,1] range, we consider the input to be logits and will auto apply sigmoid per element. Additionally, we convert to an int tensor by thresholding with the value in threshold.

  • target (int tensor): (N, C, ...)

Additional dimension ... will be flattened into the batch dimension.

Parameters:
  • preds (Tensor) – Tensor with predictions

  • target (Tensor) – Tensor with true labels

  • num_labels (int) – Integer specifying the number of labels

  • threshold (float) – Threshold for transforming probability to binary (0,1) predictions

  • normalize (Optional[Literal['true', 'pred', 'all', 'none']]) –

    Normalization mode for confusion matrix. Choose from:

    • None or 'none': no normalization (default)

    • 'true': normalization over the targets (most commonly used)

    • 'pred': normalization over the predictions

    • 'all': normalization over the whole matrix

  • ignore_index (Optional[int]) – Specifies a target value that is ignored and does not contribute to the metric calculation

  • validate_args (bool) – bool indicating if input arguments and tensors should be validated for correctness. Set to False for faster computations.

Return type:

Tensor

Returns:

A [num_labels, 2, 2] tensor

Example (preds is int tensor):
>>> from torch import tensor
>>> from torchmetrics.functional.classification import multilabel_confusion_matrix
>>> target = tensor([[0, 1, 0], [1, 0, 1]])
>>> preds = tensor([[0, 0, 1], [1, 0, 1]])
>>> multilabel_confusion_matrix(preds, target, num_labels=3)
tensor([[[1, 0], [0, 1]],
        [[1, 0], [1, 0]],
        [[0, 1], [0, 1]]])
Example (preds is float tensor):
>>> from torchmetrics.functional.classification import multilabel_confusion_matrix
>>> target = tensor([[0, 1, 0], [1, 0, 1]])
>>> preds = tensor([[0.11, 0.22, 0.84], [0.73, 0.33, 0.92]])
>>> multilabel_confusion_matrix(preds, target, num_labels=3)
tensor([[[1, 0], [0, 1]],
        [[1, 0], [1, 0]],
        [[0, 1], [0, 1]]])
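
Raising threshold changes which probabilities count as positive predictions. A sketch reusing the float example above with threshold=0.9 (only the 0.92 entry clears the cut-off, so the expected matrices follow from re-thresholding by hand):
>>> from torch import tensor
>>> from torchmetrics.functional.classification import multilabel_confusion_matrix
>>> target = tensor([[0, 1, 0], [1, 0, 1]])
>>> preds = tensor([[0.11, 0.22, 0.84], [0.73, 0.33, 0.92]])
>>> multilabel_confusion_matrix(preds, target, num_labels=3, threshold=0.9)
tensor([[[1, 0], [1, 0]],
        [[1, 0], [1, 0]],
        [[1, 0], [0, 1]]])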