Confusion Matrix
Module Interface
- class torchmetrics.ConfusionMatrix(num_classes, normalize=None, threshold=0.5, multilabel=False, compute_on_step=None, **kwargs)
Computes the confusion matrix.
Works with binary, multiclass, and multilabel data. Accepts probabilities or logits from a model output, or integer class labels, as predictions. Multi-dimensional preds and target are supported, but any additional dimensions are flattened.
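The flattening rule can be illustrated with a minimal pure-Python sketch. This is not the torchmetrics implementation, and `flatten` and `confusion_matrix_sketch` are hypothetical helpers. The sketch flattens nested per-sample labels and then counts (target, prediction) pairs. As in the examples on this page, rows index the true class and columns index the predicted class.

```python
from typing import Any, List


def flatten(x: Any) -> List[int]:
    """Hypothetical helper: recursively flatten nested lists of integer labels."""
    if isinstance(x, int):
        return [x]
    out: List[int] = []
    for item in x:
        out.extend(flatten(item))
    return out


def confusion_matrix_sketch(preds: Any, target: Any, num_classes: int) -> List[List[int]]:
    """Count (target, pred) pairs after flattening; rows = true class, columns = predicted class."""
    p, t = flatten(preds), flatten(target)
    if len(p) != len(t):
        raise ValueError("preds and target must contain the same number of elements")
    cm = [[0] * num_classes for _ in range(num_classes)]
    for pred_label, true_label in zip(p, t):
        cm[true_label][pred_label] += 1
    return cm


# An (N=2, 2) input: the extra dimension is flattened away before counting.
target = [[1, 1], [0, 0]]
preds = [[0, 1], [0, 0]]
print(confusion_matrix_sketch(preds, target, num_classes=2))  # [[2, 0], [1, 1]]
```

This gives the same matrix as the binary example further down, because that example uses the same labels in flat form.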
Forward accepts:
  preds (float or long tensor): (N, ...) or (N, C, ...), where C is the number of classes
  target (long tensor): (N, ...)
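The following paragraph describes how preds is turned into integer labels. For 1-D targets of shape (N,), those rules can be sketched in plain Python as below. `to_labels` is a hypothetical helper, not part of torchmetrics. Comparing with `>=` at the threshold is an assumption of this sketch.

```python
from typing import List, Sequence


def to_labels(preds: Sequence, target: Sequence[int], threshold: float = 0.5) -> List[int]:
    """Hypothetical helper: map preds of shape (N,) or (N, C) to N integer labels."""
    if len(preds) != len(target):
        raise ValueError("preds and target must share the batch dimension N")
    if len(preds) == 0:
        return []
    first = preds[0]
    if isinstance(first, (list, tuple)):
        # Extra class dimension (N, C): argmax over dim=1 (ties go to the first index here).
        return [max(range(len(row)), key=row.__getitem__) for row in preds]
    if isinstance(first, float):
        # Same shape as target and floating point: binarize (">=" is an assumption).
        return [int(p >= threshold) for p in preds]
    # Same shape and already integer labels: pass through unchanged.
    return [int(p) for p in preds]


print(to_labels([0.2, 0.7, 0.9, 0.1], target=[0, 1, 1, 0]))          # [0, 1, 1, 0]
print(to_labels([[0.1, 0.2, 0.7], [0.5, 0.3, 0.2]], target=[2, 0]))  # [2, 0]
```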
If preds and target have the same shape and preds is a float tensor, we use the threshold argument (self.threshold) to convert preds into integer labels. This is the case for binary and multilabel probabilities or logits. If preds has an extra dimension, as with multiclass scores, we take an argmax over dim=1. For multilabel data, setting the multilabel argument to True makes sure that a separate confusion matrix is computed for each label.
- Parameters
num_classes (int) – Number of classes in the dataset.
normalize (Optional[str]) – Normalization mode for the confusion matrix. Choose from:
  None or 'none': no normalization (default)
  'true': normalization over the targets (most commonly used)
  'pred': normalization over the predictions
  'all': normalization over the whole matrix
threshold (float) – Threshold for transforming probability or logit predictions to binary (0, 1) predictions, in the case of binary or multilabel inputs. The default value of 0.5 assumes that the inputs are probabilities.
multilabel (bool) – Whether the data is multilabel.
compute_on_step (Optional[bool]) – If set to False, forward only calls update() and returns None.
Deprecated since version v0.8: This argument has no effect anymore and will be removed in v0.9.
kwargs (Dict[str, Any]) – Additional keyword arguments; see Advanced metric settings for more info.
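The normalize modes listed above can be read as dividing counts by row sums ('true'), column sums ('pred'), or the grand total ('all'). The sketch below assumes rows are true classes and columns are predictions, as in the examples. `normalize_confmat` is a hypothetical helper, and mapping an empty row or column to 0.0 is an assumption of this sketch rather than documented library behaviour.

```python
from typing import List, Optional


def normalize_confmat(cm: List[List[int]], normalize: Optional[str] = None) -> List[List[float]]:
    """Hypothetical helper: normalize a square confusion matrix (rows = target, cols = pred)."""
    n = len(cm)
    if normalize in (None, "none"):
        return [[float(v) for v in row] for row in cm]
    if normalize == "true":
        # Divide each row (one true class) by its total.
        row_sums = [sum(row) for row in cm]
        return [[v / row_sums[i] if row_sums[i] else 0.0 for v in row] for i, row in enumerate(cm)]
    if normalize == "pred":
        # Divide each column (one predicted class) by its total.
        col_sums = [sum(cm[i][j] for i in range(n)) for j in range(n)]
        return [[v / col_sums[j] if col_sums[j] else 0.0 for j, v in enumerate(row)] for row in cm]
    if normalize == "all":
        # Divide every cell by the total number of samples.
        total = sum(sum(row) for row in cm)
        return [[v / total if total else 0.0 for v in row] for row in cm]
    raise ValueError(f"unknown normalize mode: {normalize!r}")


cm = [[1, 1, 0], [0, 1, 0], [0, 0, 1]]  # the multiclass example's result
print(normalize_confmat(cm, "true"))  # [[0.5, 0.5, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 1.0]]
```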
- Example (binary data):
>>> import torch
>>> from torchmetrics import ConfusionMatrix
>>> target = torch.tensor([1, 1, 0, 0])
>>> preds = torch.tensor([0, 1, 0, 0])
>>> confmat = ConfusionMatrix(num_classes=2)
>>> confmat(preds, target)
tensor([[2, 0],
        [1, 1]])
- Example (multiclass data):
>>> import torch
>>> from torchmetrics import ConfusionMatrix
>>> target = torch.tensor([2, 1, 0, 0])
>>> preds = torch.tensor([2, 1, 0, 1])
>>> confmat = ConfusionMatrix(num_classes=3)
>>> confmat(preds, target)
tensor([[1, 1, 0],
        [0, 1, 0],
        [0, 0, 1]])
- Example (multilabel data):
>>> import torch
>>> from torchmetrics import ConfusionMatrix
>>> target = torch.tensor([[0, 1, 0], [1, 0, 1]])
>>> preds = torch.tensor([[0, 0, 1], [1, 0, 1]])
>>> confmat = ConfusionMatrix(num_classes=3, multilabel=True)
>>> confmat(preds, target)
tensor([[[1, 0], [0, 1]],
        [[1, 0], [1, 0]],
        [[0, 1], [0, 1]]])
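One common way to compute the binary and multiclass results above is a histogram (bincount) over a combined index `target * num_classes + pred`. The flat histogram is then reshaped into a num_classes x num_classes grid. The sketch below shows the idea in plain Python. `multiclass_confmat` is a hypothetical helper, and this is not presented as the library's own code.

```python
from collections import Counter
from typing import List, Sequence


def multiclass_confmat(preds: Sequence[int], target: Sequence[int], num_classes: int) -> List[List[int]]:
    """Hypothetical helper: bincount over target * num_classes + pred, reshaped to a square."""
    if len(preds) != len(target):
        raise ValueError("preds and target must have the same length")
    counts = Counter(t * num_classes + p for p, t in zip(preds, target))
    flat = [counts[i] for i in range(num_classes * num_classes)]
    # Row r holds the counts for true class r across all predicted classes.
    return [flat[r * num_classes:(r + 1) * num_classes] for r in range(num_classes)]


print(multiclass_confmat([2, 1, 0, 1], [2, 1, 0, 0], num_classes=3))
# [[1, 1, 0], [0, 1, 0], [0, 0, 1]]
```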
- compute()
Computes the confusion matrix.
- Return type
Tensor
- Returns
If multilabel=False, this will be a [n_classes, n_classes] tensor; if multilabel=True, it will be a [n_classes, 2, 2] tensor.
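In the multilabel case, the [n_classes, 2, 2] result holds one 2x2 matrix per label. The rows of each matrix are the true value (0/1) and the columns are the predicted value (0/1). That reading is consistent with the multilabel example and gives the layout [[TN, FP], [FN, TP]]. The sketch below rebuilds that example in plain Python. `multilabel_confmat` is a hypothetical helper.

```python
from typing import List, Sequence


def multilabel_confmat(
    preds: Sequence[Sequence[int]], target: Sequence[Sequence[int]], num_labels: int
) -> List[List[List[int]]]:
    """Hypothetical helper: one [[TN, FP], [FN, TP]] matrix per label, shape [num_labels, 2, 2]."""
    out = [[[0, 0], [0, 0]] for _ in range(num_labels)]
    for p_row, t_row in zip(preds, target):
        for label in range(num_labels):
            # Index by (true value, predicted value) for this label.
            out[label][t_row[label]][p_row[label]] += 1
    return out


target = [[0, 1, 0], [1, 0, 1]]
preds = [[0, 0, 1], [1, 0, 1]]
print(multilabel_confmat(preds, target, num_labels=3))
# [[[1, 0], [0, 1]], [[1, 0], [1, 0]], [[0, 1], [0, 1]]]
```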
Functional Interface
- torchmetrics.functional.confusion_matrix(preds, target, num_classes, normalize=None, threshold=0.5, multilabel=False)
Computes the confusion matrix. Works with binary, multiclass, and multilabel data. Accepts probabilities or logits from a model output, or integer class labels, as predictions. Multi-dimensional preds and target are supported, but any additional dimensions are flattened.
If preds and target have the same shape and preds is a float tensor, we use the threshold argument to convert preds into integer labels. This is the case for binary and multilabel probabilities or logits. If preds has an extra dimension, as with multiclass scores, we take an argmax over dim=1. For multilabel data, setting the multilabel argument to True makes sure that a separate confusion matrix is computed for each label.
- Parameters
preds (Tensor) – (float or long tensor) Either an (N, ...) tensor with labels, or an (N, C, ...) tensor with labels, logits, or probabilities, where C is the number of classes.
target (Tensor) – (long tensor) Tensor of shape (N, ...) with ground truth labels.
num_classes (int) – Number of classes in the dataset.
normalize (Optional[str]) – Normalization mode for the confusion matrix. Choose from:
  None or 'none': no normalization (default)
  'true': normalization over the targets (most commonly used)
  'pred': normalization over the predictions
  'all': normalization over the whole matrix
threshold (float) – Threshold for transforming probability or logit predictions to binary (0, 1) predictions, in the case of binary or multilabel inputs. The default value of 0.5 assumes that the inputs are probabilities.
multilabel (bool) – Whether the data is multilabel.
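The threshold description says the default of 0.5 assumes probabilities. This is because the sigmoid is monotonic and sigmoid(0) equals exactly 0.5. So cutting sigmoid probabilities at 0.5 selects the same positives as cutting raw logits at 0.0. The sketch below only checks that arithmetic fact. It makes no claim about whether torchmetrics rescales logits internally. `binarize` is a hypothetical helper, and its `>=` comparison is an assumption.

```python
import math
from typing import List, Sequence


def sigmoid(z: float) -> float:
    """Logistic function mapping a logit to a probability in (0, 1)."""
    return 1.0 / (1.0 + math.exp(-z))


def binarize(values: Sequence[float], threshold: float) -> List[int]:
    """Hypothetical helper: 1 where value >= threshold (">=" is an assumption), else 0."""
    return [int(v >= threshold) for v in values]


logits = [-2.0, -0.1, 0.3, 4.0]
probs = [sigmoid(z) for z in logits]

# Monotonic sigmoid with sigmoid(0) == 0.5, so both cuts agree.
print(binarize(probs, 0.5))   # [0, 0, 1, 1]
print(binarize(logits, 0.0))  # [0, 0, 1, 1]
```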
- Example (binary data):
>>> import torch
>>> from torchmetrics.functional import confusion_matrix
>>> target = torch.tensor([1, 1, 0, 0])
>>> preds = torch.tensor([0, 1, 0, 0])
>>> confusion_matrix(preds, target, num_classes=2)
tensor([[2, 0],
        [1, 1]])
- Example (multiclass data):
>>> import torch
>>> from torchmetrics.functional import confusion_matrix
>>> target = torch.tensor([2, 1, 0, 0])
>>> preds = torch.tensor([2, 1, 0, 1])
>>> confusion_matrix(preds, target, num_classes=3)
tensor([[1, 1, 0],
        [0, 1, 0],
        [0, 0, 1]])
- Example (multilabel data):
>>> import torch
>>> from torchmetrics.functional import confusion_matrix
>>> target = torch.tensor([[0, 1, 0], [1, 0, 1]])
>>> preds = torch.tensor([[0, 0, 1], [1, 0, 1]])
>>> confusion_matrix(preds, target, num_classes=3, multilabel=True)
tensor([[[1, 0], [0, 1]],
        [[1, 0], [1, 0]],
        [[0, 1], [0, 1]]])
- Return type
Tensor