Using Classification Metrics
Input types
For the purposes of classification metrics, inputs (predictions and targets) are split into these categories (N stands for the batch size and C for the number of classes):
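| Type | preds shape | preds dtype | target shape | target dtype |
|---|---|---|---|---|
| Binary | (N,) | float | (N,) | binary (int, 0 or 1) |
| Multiclass | (N,) | int | (N,) | int |
| Multiclass with logits or probabilities | (N, C) | float | (N,) | int |
| Multilabel | (N, …) | float | (N, …) | binary (int, 0 or 1) |
| Multidimensional multiclass | (N, …) | int | (N, …) | int |
| Multidimensional multiclass with logits or probabilities | (N, C, …) | float | (N, …) | int |
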
Note
All dimensions of size 1 (except N) are “squeezed out” at the beginning, so that, for example, a tensor of shape (N, 1) is treated as (N,).
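The squeezing rule in the note can be illustrated on shapes alone. The following is a minimal plain-Python sketch: `squeeze_except_batch` is a hypothetical helper, not a torchmetrics function, and it works on shape tuples rather than on tensors.

```python
def squeeze_except_batch(shape):
    """Drop every size-1 dimension except the first (batch) dimension."""
    if not shape:
        return shape
    batch, rest = shape[0], shape[1:]
    # Keep the batch dimension even if it is 1; filter the remaining ones.
    return (batch,) + tuple(dim for dim in rest if dim != 1)


print(squeeze_except_batch((8, 1)))        # (8,)
print(squeeze_except_batch((1, 1, 4)))     # (1, 4)
print(squeeze_except_batch((8, 3, 1, 5)))  # (8, 3, 5)
```

Note that a batch of size 1 is preserved, so a single sample still keeps its leading N dimension.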
When predictions or targets are integers, it is assumed that class labels start at 0, i.e. the possible class labels are 0, 1, 2, 3, etc. Below are some examples of the different input types:
import torch

# Binary inputs: float predictions, targets restricted to 0 or 1
binary_preds = torch.tensor([0.6, 0.1, 0.9])
binary_target = torch.tensor([1, 0, 1])
# Multiclass inputs: integer class labels for both predictions and targets
mc_preds = torch.tensor([0, 2, 1])
mc_target = torch.tensor([0, 1, 2])
# Multiclass inputs with probabilities: one row of C probabilities per sample
mc_preds_probs = torch.tensor([[0.8, 0.2, 0], [0.1, 0.2, 0.7], [0.3, 0.6, 0.1]])
mc_target_probs = torch.tensor([0, 1, 2])
# Multilabel inputs: one independent probability and 0/1 target per label
ml_preds = torch.tensor([[0.2, 0.8, 0.9], [0.5, 0.6, 0.1], [0.3, 0.1, 0.1]])
ml_target = torch.tensor([[0, 1, 1], [1, 0, 0], [0, 0, 0]])
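As a rough illustration of how the categories above can be told apart, here is a simplified plain-Python sketch. It assumes that float predictions denote probabilities or logits and integer predictions denote labels, and that shapes have already had their size-1 dimensions squeezed out. `infer_input_type` is a hypothetical helper: it ignores the tensor values and the `multiclass` argument, so it should not be read as the library's actual logic.

```python
def infer_input_type(preds_shape, target_shape, preds_is_float):
    """Name the input category from shapes and the dtype kind of preds."""
    if len(preds_shape) == len(target_shape) + 1:
        # preds carry an extra class dimension C right after N.
        if preds_shape[:1] + preds_shape[2:] != target_shape:
            raise ValueError("preds and target shapes are incompatible")
        if not preds_is_float:
            raise ValueError("an extra class dimension requires float preds")
        if len(target_shape) == 1:
            return "multiclass with logits or probabilities"
        return "multidimensional multiclass with logits or probabilities"
    if preds_shape != target_shape:
        raise ValueError("preds and target shapes are incompatible")
    if len(preds_shape) == 1:
        return "binary" if preds_is_float else "multiclass"
    return "multilabel" if preds_is_float else "multidimensional multiclass"


# Shapes of the example tensors above (N = 3, C = 3).
print(infer_input_type((3,), (3,), preds_is_float=True))      # binary
print(infer_input_type((3,), (3,), preds_is_float=False))     # multiclass
print(infer_input_type((3, 3), (3,), preds_is_float=True))    # multiclass with logits or probabilities
print(infer_input_type((3, 3), (3, 3), preds_is_float=True))  # multilabel
```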
Using the multiclass parameter
In some cases, you might have inputs which appear to be (multidimensional) multiclass but are actually binary/multilabel: for example, if both predictions and targets are integer (binary) tensors. Or it could be the other way around: you want to treat binary/multilabel inputs as 2-class (multidimensional) multiclass inputs.
For these cases, the metrics where this distinction would make a difference expose the multiclass argument. Let’s see how this is used, taking the StatScores metric as an example.
First, let’s consider the case with label predictions with 2 classes, which we want to treat as binary.
import torch
from torchmetrics.functional import stat_scores

# These inputs are supposed to be binary, but appear as multiclass
preds = torch.tensor([0, 1, 0])
target = torch.tensor([1, 1, 0])
As you can see below, by default the inputs are treated as multiclass. We can set multiclass=False to treat the inputs as binary, which is the same as converting the predictions to float beforehand.
>>> stat_scores(preds, target, reduce='macro', num_classes=2)
tensor([[1, 1, 1, 0, 1],
        [1, 0, 1, 1, 2]])
>>> stat_scores(preds, target, reduce='macro', num_classes=1, multiclass=False)
tensor([[1, 0, 1, 1, 2]])
>>> stat_scores(preds.float(), target, reduce='macro', num_classes=1)
tensor([[1, 0, 1, 1, 2]])
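To see where these numbers come from, the counts can be reproduced by hand. The sketch below is plain Python, not the library's implementation; `stat_scores_per_class` is a hypothetical helper, and it assumes each row is ordered as [tp, fp, tn, fn, support], which is consistent with the outputs above.

```python
def stat_scores_per_class(preds, target, num_classes):
    """Return one [tp, fp, tn, fn, support] row per class for integer labels."""
    rows = []
    for c in range(num_classes):
        pairs = list(zip(preds, target))
        tp = sum(p == c and t == c for p, t in pairs)
        fp = sum(p == c and t != c for p, t in pairs)
        tn = sum(p != c and t != c for p, t in pairs)
        fn = sum(p != c and t == c for p, t in pairs)
        # Support is the number of samples whose target is class c.
        rows.append([tp, fp, tn, fn, tp + fn])
    return rows


preds = [0, 1, 0]
target = [1, 1, 0]
rows = stat_scores_per_class(preds, target, num_classes=2)
print(rows)      # [[1, 1, 1, 0, 1], [1, 0, 1, 1, 2]]
print(rows[1:])  # [[1, 0, 1, 1, 2]]
```

The multiclass view reports one row per class, while the binary view corresponds to keeping only the row for the positive class (label 1).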
Next, consider the opposite example: inputs are binary (as predictions are probabilities), but we would like to treat them as 2-class multiclass, to obtain the metric for both classes.
preds = torch.tensor([0.2, 0.7, 0.3])
target = torch.tensor([1, 1, 0])
In this case we can set multiclass=True to treat the inputs as multiclass.
>>> stat_scores(preds, target, reduce='macro', num_classes=1)
tensor([[1, 0, 1, 1, 2]])
>>> stat_scores(preds, target, reduce='macro', num_classes=2, multiclass=True)
tensor([[1, 1, 1, 0, 1],
        [1, 0, 1, 1, 2]])
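One way to picture the binary-to-multiclass conversion is as thresholding followed by a two-column encoding. The sketch below is plain Python with hypothetical helpers (`probs_to_labels`, `labels_to_two_class`); the 0.5 threshold is an assumption that happens to reproduce the labels [0, 1, 0] used in the earlier example, and the library may perform the conversion differently.

```python
def probs_to_labels(probs, threshold=0.5):
    """Turn binary probabilities into 0/1 labels (threshold is an assumption)."""
    return [1 if p >= threshold else 0 for p in probs]


def labels_to_two_class(labels):
    """Encode each 0/1 label as a row [is_class_0, is_class_1]."""
    return [[1 - y, y] for y in labels]


preds = [0.2, 0.7, 0.3]
labels = probs_to_labels(preds)
print(labels)                       # [0, 1, 0]
print(labels_to_two_class(labels))  # [[1, 0], [0, 1], [1, 0]]
```

With the predictions expressed per class in this way, statistics can be computed for class 0 and class 1 separately, which is why the multiclass output has two rows.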