ROC

Module Interface

class torchmetrics.ROC(num_classes=None, pos_label=None, compute_on_step=None, **kwargs)

Computes the Receiver Operating Characteristic (ROC). Works for binary, multiclass, and multilabel problems. In the multiclass case, the values are calculated using a one-vs-the-rest approach: each class in turn is treated as the positive class and all other classes as negative.

Forward accepts

• preds (float tensor): (N, ...) (binary) or (N, C, ...) (multiclass/multilabel) tensor with probabilities, where C is the number of classes/labels.

• target (long tensor): (N, ...) or (N, C, ...) with integer labels
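
To make the one-vs-the-rest idea concrete, here is a minimal pure-Python sketch. It is not the torchmetrics implementation: the names binary_roc and one_vs_rest_roc are hypothetical, plain lists stand in for tensors, and the threshold sweep (every distinct score, plus one leading threshold above the maximum) is an assumption chosen because it is consistent with the example outputs on this page.

```python
from typing import List, Sequence, Tuple

Curve = Tuple[List[float], List[float], List[float]]


def binary_roc(scores: Sequence[float], labels: Sequence[int]) -> Curve:
    """Return (fpr, tpr, thresholds) for 0/1 labels (hypothetical helper)."""
    n_pos = sum(labels)
    n_neg = len(labels) - n_pos
    # Every distinct score is a candidate threshold, highest first.
    distinct = sorted(set(scores), reverse=True)
    # Start at (0, 0) with a threshold above every score: nothing is positive yet.
    fpr, tpr, thresholds = [0.0], [0.0], [distinct[0] + 1]
    for thr in distinct:
        tp = sum(1 for s, y in zip(scores, labels) if s >= thr and y == 1)
        fp = sum(1 for s, y in zip(scores, labels) if s >= thr and y == 0)
        # If a class is absent its rate is undefined; this sketch reports 0.0.
        tpr.append(tp / n_pos if n_pos else 0.0)
        fpr.append(fp / n_neg if n_neg else 0.0)
        thresholds.append(thr)
    return fpr, tpr, thresholds


def one_vs_rest_roc(preds, target, num_classes: int) -> List[Curve]:
    """One curve per class: class c is positive, every other class is negative."""
    curves = []
    for c in range(num_classes):
        class_scores = [row[c] for row in preds]              # column c of (N, C)
        class_labels = [1 if t == c else 0 for t in target]   # binarized target
        curves.append(binary_roc(class_scores, class_labels))
    return curves


preds = [[0.75, 0.05, 0.05, 0.05],
         [0.05, 0.75, 0.05, 0.05],
         [0.05, 0.05, 0.75, 0.05],
         [0.05, 0.05, 0.05, 0.75]]
target = [0, 1, 3, 2]
for fpr, tpr, thresholds in one_vs_rest_roc(preds, target, num_classes=4):
    print(fpr, tpr, thresholds)
```

With the multiclass data used in the examples on this page, the sketch yields one (fpr, tpr, thresholds) triple per class, which mirrors the list-of-tensors layout that ROC returns for multiclass input.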

Note

If either the positive class or the negative class is completely missing from the target tensor, the ROC values are not well-defined. In that case a tensor of zeros is returned (for either fpr or tpr, depending on which class is missing), together with a warning.
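
The reason behind this note is a zero denominator: tpr divides by the number of positive samples and fpr by the number of negative samples. The sketch below illustrates that guard in plain Python. The helper name safe_rate and the warning text are hypothetical and are not taken from torchmetrics.

```python
import warnings
from typing import List, Sequence


def safe_rate(counts: Sequence[int], class_size: int, name: str) -> List[float]:
    """Divide per-threshold counts by the class size, or fall back to zeros.

    Hypothetical helper: `counts` holds the true-positive (or false-positive)
    count at each threshold, `class_size` the number of positives (or negatives).
    """
    if class_size == 0:
        # The rate would be 0/0, so warn and return zeros instead.
        warnings.warn(f"No {name} samples in target; returning zeros.", UserWarning)
        return [0.0 for _ in counts]
    return [c / class_size for c in counts]


# Target contains only negatives: the positive class is missing, so tpr is zeroed.
tp_counts = [0, 0, 0]
tpr = safe_rate(tp_counts, class_size=0, name="positive")
```

With no positive samples it is tpr that becomes all zeros; with no negative samples the same guard would apply to fpr.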

Parameters
Example (binary case):
>>> import torch
>>> from torchmetrics import ROC
>>> pred = torch.tensor([0, 1, 2, 3])
>>> target = torch.tensor([0, 1, 1, 1])
>>> roc = ROC(pos_label=1)
>>> fpr, tpr, thresholds = roc(pred, target)
>>> fpr
tensor([0., 0., 0., 0., 1.])
>>> tpr
tensor([0.0000, 0.3333, 0.6667, 1.0000, 1.0000])
>>> thresholds
tensor([4, 3, 2, 1, 0])

Example (multiclass case):
>>> pred = torch.tensor([[0.75, 0.05, 0.05, 0.05],
...                      [0.05, 0.75, 0.05, 0.05],
...                      [0.05, 0.05, 0.75, 0.05],
...                      [0.05, 0.05, 0.05, 0.75]])
>>> target = torch.tensor([0, 1, 3, 2])
>>> roc = ROC(num_classes=4)
>>> fpr, tpr, thresholds = roc(pred, target)
>>> fpr
[tensor([0., 0., 1.]), tensor([0., 0., 1.]), tensor([0.0000, 0.3333, 1.0000]), tensor([0.0000, 0.3333, 1.0000])]
>>> tpr
[tensor([0., 1., 1.]), tensor([0., 1., 1.]), tensor([0., 0., 1.]), tensor([0., 0., 1.])]
>>> thresholds
[tensor([1.7500, 0.7500, 0.0500]),
tensor([1.7500, 0.7500, 0.0500]),
tensor([1.7500, 0.7500, 0.0500]),
tensor([1.7500, 0.7500, 0.0500])]

Example (multilabel case):
>>> pred = torch.tensor([[0.8191, 0.3680, 0.1138],
...                      [0.3584, 0.7576, 0.1183],
...                      [0.2286, 0.3468, 0.1338],
...                      [0.8603, 0.0745, 0.1837]])
>>> target = torch.tensor([[1, 1, 0], [0, 1, 0], [0, 0, 0], [0, 1, 1]])
>>> roc = ROC(num_classes=3, pos_label=1)
>>> fpr, tpr, thresholds = roc(pred, target)
>>> fpr
[tensor([0.0000, 0.3333, 0.3333, 0.6667, 1.0000]),
tensor([0., 0., 0., 1., 1.]),
tensor([0.0000, 0.0000, 0.3333, 0.6667, 1.0000])]
>>> tpr
[tensor([0., 0., 1., 1., 1.]),
tensor([0.0000, 0.3333, 0.6667, 0.6667, 1.0000]),
tensor([0., 1., 1., 1., 1.])]
>>> thresholds
[tensor([1.8603, 0.8603, 0.8191, 0.3584, 0.2286]),
tensor([1.7576, 0.7576, 0.3680, 0.3468, 0.0745]),
tensor([1.1837, 0.1837, 0.1338, 0.1183, 0.1138])]


Initializes internal Module state, shared by both nn.Module and ScriptModule.

compute()

Return type
Returns

3-element tuple containing:

• fpr: tensor with false positive rates. If multiclass or multilabel, this is a list of such tensors, one for each class/label.

• tpr: tensor with true positive rates. If multiclass or multilabel, this is a list of such tensors, one for each class/label.

• thresholds: tensor with the thresholds used for computing the false and true positive rates. If multiclass or multilabel, this is a list of such tensors, one for each class/label.

update(preds, target)

Update state with predictions and targets.

Parameters
Return type

None
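
The split between update() and compute() can be pictured with the small stand-in below. It is a sketch only: TinyROC is a hypothetical class, not torchmetrics.ROC, it handles just the binary case with plain lists, and it assumes that update() merely stores each batch while compute() builds the curve from everything stored so far.

```python
from typing import List, Tuple


class TinyROC:
    """Hypothetical stand-in for the update/compute pattern (binary case only)."""

    def __init__(self) -> None:
        self.scores: List[float] = []
        self.labels: List[int] = []

    def update(self, preds, target) -> None:
        # Only store the batch; no curve is computed here.
        self.scores.extend(preds)
        self.labels.extend(target)

    def compute(self) -> Tuple[List[float], List[float], List[float]]:
        n_pos = sum(self.labels)
        n_neg = len(self.labels) - n_pos
        distinct = sorted(set(self.scores), reverse=True)
        # Leading point (0, 0) with a threshold above every stored score.
        fpr, tpr, thresholds = [0.0], [0.0], [distinct[0] + 1]
        for thr in distinct:
            # Labels of all samples predicted positive at this threshold.
            hits = [y for s, y in zip(self.scores, self.labels) if s >= thr]
            tpr.append(sum(hits) / n_pos if n_pos else 0.0)
            fpr.append((len(hits) - sum(hits)) / n_neg if n_neg else 0.0)
            thresholds.append(thr)
        return fpr, tpr, thresholds


roc = TinyROC()
roc.update([0, 1], [0, 1])  # first batch
roc.update([2, 3], [1, 1])  # second batch
fpr, tpr, thresholds = roc.compute()
print(fpr, tpr, thresholds)
```

Feeding the binary example data in two batches gives the same curve as a single call over all four samples, which is the point of keeping the state separate from the computation.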

Functional Interface

torchmetrics.functional.roc(preds, target, num_classes=None, pos_label=None, sample_weights=None)

Computes the Receiver Operating Characteristic (ROC). Works with binary, multiclass, and multilabel input.

Note

If either the positive class or the negative class is completely missing from the target tensor, the ROC values are not well-defined. In that case a tensor of zeros is returned (for either fpr or tpr, depending on which class is missing), together with a warning.

Parameters
Return type
Returns

3-element tuple containing:

• fpr: tensor with false positive rates. If multiclass or multilabel, this is a list of such tensors, one for each class/label.

• tpr: tensor with true positive rates. If multiclass or multilabel, this is a list of such tensors, one for each class/label.

• thresholds: tensor with the thresholds used for computing the false and true positive rates. If multiclass or multilabel, this is a list of such tensors, one for each class/label.

Example (binary case):
>>> import torch
>>> from torchmetrics.functional import roc
>>> pred = torch.tensor([0, 1, 2, 3])
>>> target = torch.tensor([0, 1, 1, 1])
>>> fpr, tpr, thresholds = roc(pred, target, pos_label=1)
>>> fpr
tensor([0., 0., 0., 0., 1.])
>>> tpr
tensor([0.0000, 0.3333, 0.6667, 1.0000, 1.0000])
>>> thresholds
tensor([4, 3, 2, 1, 0])

Example (multiclass case):
>>> from torchmetrics.functional import roc
>>> pred = torch.tensor([[0.75, 0.05, 0.05, 0.05],
...                      [0.05, 0.75, 0.05, 0.05],
...                      [0.05, 0.05, 0.75, 0.05],
...                      [0.05, 0.05, 0.05, 0.75]])
>>> target = torch.tensor([0, 1, 3, 2])
>>> fpr, tpr, thresholds = roc(pred, target, num_classes=4)
>>> fpr
[tensor([0., 0., 1.]), tensor([0., 0., 1.]), tensor([0.0000, 0.3333, 1.0000]), tensor([0.0000, 0.3333, 1.0000])]
>>> tpr
[tensor([0., 1., 1.]), tensor([0., 1., 1.]), tensor([0., 0., 1.]), tensor([0., 0., 1.])]
>>> thresholds
[tensor([1.7500, 0.7500, 0.0500]),
tensor([1.7500, 0.7500, 0.0500]),
tensor([1.7500, 0.7500, 0.0500]),
tensor([1.7500, 0.7500, 0.0500])]

Example (multilabel case):
>>> from torchmetrics.functional import roc
>>> pred = torch.tensor([[0.8191, 0.3680, 0.1138],
...                      [0.3584, 0.7576, 0.1183],
...                      [0.2286, 0.3468, 0.1338],
...                      [0.8603, 0.0745, 0.1837]])
>>> target = torch.tensor([[1, 1, 0], [0, 1, 0], [0, 0, 0], [0, 1, 1]])
>>> fpr, tpr, thresholds = roc(pred, target, num_classes=3, pos_label=1)
>>> fpr
[tensor([0.0000, 0.3333, 0.3333, 0.6667, 1.0000]),
tensor([0., 0., 0., 1., 1.]),
tensor([0.0000, 0.0000, 0.3333, 0.6667, 1.0000])]
>>> tpr
[tensor([0., 0., 1., 1., 1.]), tensor([0.0000, 0.3333, 0.6667, 0.6667, 1.0000]), tensor([0., 1., 1., 1., 1.])]
>>> thresholds
[tensor([1.8603, 0.8603, 0.8191, 0.3584, 0.2286]),
tensor([1.7576, 0.7576, 0.3680, 0.3468, 0.0745]),
tensor([1.1837, 0.1837, 0.1338, 0.1183, 0.1138])]


© Copyright (c) 2020-2022, PyTorchLightning et al. Revision 60f1f185.
