AUROC
Module Interface
- class torchmetrics.AUROC(num_classes=None, pos_label=None, average='macro', max_fpr=None, **kwargs)
Compute Area Under the Receiver Operating Characteristic Curve (ROC AUC). Works for binary, multilabel and multiclass problems. In the multiclass case, the values are calculated using a one-vs-the-rest approach.
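To make the one-vs-the-rest idea concrete, here is a plain-Python sketch. It is not the torchmetrics implementation, which builds the ROC curve from thresholds and integrates it; it uses the equivalent pairwise (Mann-Whitney) definition: the fraction of (positive, negative) pairs in which the positive sample gets the higher score, with ties counted as one half. The function names `binary_auroc` and `one_vs_rest_auroc` are hypothetical, and the missing-class branch imitates the documented "score 0 plus a warning" behavior.

```python
import warnings
from typing import List, Sequence


def binary_auroc(scores: Sequence[float], labels: Sequence[int]) -> float:
    """Pairwise (Mann-Whitney) ROC AUC; a positive/negative tie counts as one half."""
    pos = [s for s, y in zip(scores, labels) if y == 1]
    neg = [s for s, y in zip(scores, labels) if y == 0]
    if not pos or not neg:
        # Imitates the documented behavior: score 0 plus a warning when a class is missing.
        warnings.warn("No positive or no negative samples in target; returning 0.")
        return 0.0
    wins = 0.0
    for p in pos:
        for n in neg:
            if p > n:
                wins += 1.0
            elif p == n:
                wins += 0.5
    return wins / (len(pos) * len(neg))


def one_vs_rest_auroc(probs: Sequence[Sequence[float]], target: Sequence[int]) -> List[float]:
    """Per-class AUROC: class c is the positive class, every other class is negative."""
    num_classes = len(probs[0])
    per_class = []
    for c in range(num_classes):
        scores = [row[c] for row in probs]
        labels = [1 if t == c else 0 for t in target]
        per_class.append(binary_auroc(scores, labels))
    return per_class


# Same data as the multiclass example on this page.
probs = [[0.90, 0.05, 0.05],
         [0.05, 0.90, 0.05],
         [0.05, 0.05, 0.90],
         [0.85, 0.05, 0.10],
         [0.10, 0.10, 0.80]]
target = [0, 1, 1, 2, 2]
per_class = one_vs_rest_auroc(probs, target)
macro = sum(per_class) / len(per_class)
print(per_class, round(macro, 4))  # [1.0, 0.6666666666666666, 0.6666666666666666] 0.7778
```

Uniformly averaging the per-class values is the `'macro'` reduction; it reproduces the `tensor(0.7778)` of the multiclass example.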
Forward accepts
- preds (float tensor): (N, ...) (binary) or (N, C, ...) (multiclass) tensor with probabilities, where C is the number of classes.
- target (long tensor): (N, ...) or (N, C, ...) with integer labels.
For non-binary input, if the preds and target tensors have the same size, the input will be interpreted as multilabel; if preds has one dimension more than the target tensor, the input will be interpreted as multiclass.
Note
If either the positive class or the negative class is completely missing from the target tensor, the AUROC score is meaningless; in this case a score of 0 is returned together with a warning.
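The shape rule for non-binary input can be written as a small, hypothetical helper that decides between multilabel and multiclass from the tensor shapes alone. It only illustrates the rule as stated here (with the class axis assumed at dimension 1, as in the (N, C, ...) layout); torchmetrics' own input validation performs further checks that this sketch omits.

```python
from typing import Sequence


def infer_non_binary_mode(preds_shape: Sequence[int], target_shape: Sequence[int]) -> str:
    """Classify non-binary input as 'multilabel' or 'multiclass' from shapes only."""
    p_shape, t_shape = tuple(preds_shape), tuple(target_shape)
    if p_shape == t_shape:
        # (N, C, ...) scores against (N, C, ...) 0/1 labels
        return "multilabel"
    if (
        len(p_shape) == len(t_shape) + 1
        and p_shape[0] == t_shape[0]
        and p_shape[2:] == t_shape[1:]
    ):
        # (N, C, ...) probabilities against (N, ...) integer class labels
        return "multiclass"
    raise ValueError(f"incompatible shapes: preds {p_shape}, target {t_shape}")


print(infer_non_binary_mode((5, 3), (5, 3)))        # multilabel
print(infer_non_binary_mode((5, 3), (5,)))          # multiclass
print(infer_non_binary_mode((5, 3, 4), (5, 4)))     # multiclass
```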
- Parameters
num_classes (Optional[int]) – integer with the number of classes for multilabel and multiclass problems. Should be set to None for binary problems.
pos_label (Optional[int]) – integer determining the positive class. Default is None, which for binary problems is translated to 1. For multiclass problems this argument should not be set, as it is iteratively changed in the range [0, num_classes-1].
average (Optional[str]) – how the per-class scores are reduced:
  - 'micro' computes the metric globally. Only works for multilabel problems.
  - 'macro' computes the metric for each class and uniformly averages them.
  - 'weighted' computes the metric for each class and does a weighted average, where each class is weighted by its support (accounts for class imbalance).
  - None computes and returns the metric per class.
max_fpr (Optional[float]) – If not None, calculates the standardized partial AUC over the range [0, max_fpr]. Should be a float between 0 and 1.
kwargs (Dict[str, Any]) – Additional keyword arguments, see Advanced metric settings for more info.
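The `average` options can be illustrated on multilabel data, the only case where `'micro'` applies. The plain-Python sketch below is an assumption-laden illustration, not torchmetrics code: the names `pairwise_auroc` and `average_auroc` are hypothetical, AUROC is computed with the pairwise definition (ties count one half), and every label is assumed to have both positive and negative samples.

```python
from typing import List, Optional, Sequence, Union


def pairwise_auroc(scores: Sequence[float], labels: Sequence[int]) -> float:
    """Fraction of (positive, negative) pairs ranked correctly; ties count 0.5."""
    pos = [s for s, y in zip(scores, labels) if y == 1]
    neg = [s for s, y in zip(scores, labels) if y == 0]
    wins = sum(1.0 if p > n else 0.5 if p == n else 0.0 for p in pos for n in neg)
    return wins / (len(pos) * len(neg))


def average_auroc(
    probs: Sequence[Sequence[float]],
    target: Sequence[Sequence[int]],
    average: Optional[str] = "macro",
) -> Union[float, List[float]]:
    """Reduce per-label AUROC for multilabel data; probs and target are (N, C)."""
    num_labels = len(probs[0])
    if average == "micro":
        # Pool every (sample, label) decision into one binary problem.
        flat_scores = [p for row in probs for p in row]
        flat_labels = [t for row in target for t in row]
        return pairwise_auroc(flat_scores, flat_labels)
    per_class = [
        pairwise_auroc([row[c] for row in probs], [row[c] for row in target])
        for c in range(num_labels)
    ]
    if average is None:
        return per_class
    if average == "macro":
        return sum(per_class) / num_labels
    if average == "weighted":
        # Support = number of positive samples of each label.
        support = [sum(row[c] for row in target) for c in range(num_labels)]
        return sum(a * s for a, s in zip(per_class, support)) / sum(support)
    raise ValueError(f"Unsupported average: {average!r}")


probs = [[0.9, 0.2], [0.3, 0.8], [0.6, 0.4], [0.1, 0.7]]
target = [[1, 0], [0, 1], [1, 1], [1, 0]]
for avg in (None, "macro", "weighted", "micro"):
    print(avg, average_auroc(probs, target, avg))
```

Here label 0 has three positives and label 1 has two, so `'weighted'` (0.7) pulls the result toward label 0's score of 2/3, while `'macro'` (about 0.7083) treats both labels equally.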
- Raises
ValueError – If average is none of None, "macro" or "weighted".
ValueError – If max_fpr is not a float in the range (0, 1].
RuntimeError – If the PyTorch version is below 1.6, since max_fpr requires torch.bucketize, which is not available below 1.6.
ValueError – If the mode of data (binary, multilabel, multiclass) changes between batches.
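For binary input, the standardized partial AUC behind `max_fpr` can be sketched in plain Python. The sketch follows the McClish standardization that scikit-learn's `roc_auc_score` also uses: the ROC area up to `max_fpr`, with the TPR linearly interpolated at `max_fpr`, is rescaled so that chance-level ranking scores 0.5 and a perfect ranking scores 1. The helper names are hypothetical, `bisect.bisect_right` stands in for `torch.bucketize(..., right=True)`, and torchmetrics may differ in details; treat it as an illustration of the formula rather than the library's code.

```python
import bisect
from typing import List, Sequence, Tuple


def roc_points(scores: Sequence[float], labels: Sequence[int]) -> Tuple[List[float], List[float]]:
    """ROC points (fpr, tpr), sweeping the threshold from the highest score down."""
    n_pos = sum(1 for y in labels if y == 1)
    n_neg = len(labels) - n_pos
    ranked = sorted(zip(scores, labels), key=lambda pair: pair[0], reverse=True)
    fpr, tpr = [0.0], [0.0]
    tp = fp = 0
    i = 0
    while i < len(ranked):
        threshold = ranked[i][0]
        # Samples with tied scores cross the threshold together (one diagonal ROC step).
        while i < len(ranked) and ranked[i][0] == threshold:
            if ranked[i][1] == 1:
                tp += 1
            else:
                fp += 1
            i += 1
        fpr.append(fp / n_neg)
        tpr.append(tp / n_pos)
    return fpr, tpr


def trapezoid(x: Sequence[float], y: Sequence[float]) -> float:
    """Area under the piecewise-linear curve through the points (x[k], y[k])."""
    return sum((x[k + 1] - x[k]) * (y[k] + y[k + 1]) / 2 for k in range(len(x) - 1))


def standardized_partial_auc(scores: Sequence[float], labels: Sequence[int], max_fpr: float) -> float:
    if not isinstance(max_fpr, float) or not 0.0 < max_fpr <= 1.0:
        raise ValueError("max_fpr should be a float in the range (0, 1]")
    fpr, tpr = roc_points(scores, labels)
    if max_fpr == 1.0:
        return trapezoid(fpr, tpr)  # the full AUROC
    # Index of the first ROC point with fpr > max_fpr.
    stop = bisect.bisect_right(fpr, max_fpr)
    # Linearly interpolate the TPR at exactly fpr == max_fpr.
    weight = (max_fpr - fpr[stop - 1]) / (fpr[stop] - fpr[stop - 1])
    interp_tpr = tpr[stop - 1] + weight * (tpr[stop] - tpr[stop - 1])
    partial_auc = trapezoid(fpr[:stop] + [max_fpr], tpr[:stop] + [interp_tpr])
    # McClish standardization: chance level maps to 0.5, a perfect ranking to 1.
    min_area = 0.5 * max_fpr ** 2
    max_area = max_fpr
    return 0.5 * (1 + (partial_auc - min_area) / (max_area - min_area))


scores = [0.13, 0.26, 0.08, 0.19, 0.34]
labels = [0, 0, 1, 1, 1]
print(standardized_partial_auc(scores, labels, 1.0))
print(standardized_partial_auc(scores, labels, 0.5))
```

With `max_fpr=1.0` the formula reduces to the ordinary AUROC, which is why that case is returned directly.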
- Example (binary case):
>>> import torch
>>> from torchmetrics import AUROC
>>> preds = torch.tensor([0.13, 0.26, 0.08, 0.19, 0.34])
>>> target = torch.tensor([0, 0, 1, 1, 1])
>>> auroc = AUROC(pos_label=1)
>>> auroc(preds, target)
tensor(0.5000)
- Example (multiclass case):
>>> preds = torch.tensor([[0.90, 0.05, 0.05],
...                       [0.05, 0.90, 0.05],
...                       [0.05, 0.05, 0.90],
...                       [0.85, 0.05, 0.10],
...                       [0.10, 0.10, 0.80]])
>>> target = torch.tensor([0, 1, 1, 2, 2])
>>> auroc = AUROC(num_classes=3)
>>> auroc(preds, target)
tensor(0.7778)
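As a hand check of the multiclass output, the one-vs-the-rest scores can be worked out with the pairwise form of AUROC. Here P_c and N_c are the samples whose target is and is not class c, and s_{ic} is the probability sample i assigns to class c. Ties count one half; this pairwise form equals the trapezoidal area under the ROC curve, although torchmetrics computes the curve rather than the pairs. Class 0 has one positive (0.90) that beats all four negatives; class 1's positives score 0.90 and 0.05 against negatives 0.05, 0.05 and 0.10; class 2's positives score 0.10 and 0.80 against negatives 0.05, 0.05 and 0.90.

```latex
\mathrm{AUROC}_c = \frac{1}{|P_c|\,|N_c|} \sum_{i \in P_c} \sum_{j \in N_c}
  \left( \mathbb{1}[s_{ic} > s_{jc}] + \tfrac{1}{2}\,\mathbb{1}[s_{ic} = s_{jc}] \right)

\mathrm{AUROC}_0 = \frac{4}{4} = 1, \qquad
\mathrm{AUROC}_1 = \frac{3 + 0.5 + 0.5 + 0}{6} = \frac{2}{3}, \qquad
\mathrm{AUROC}_2 = \frac{2 + 2}{6} = \frac{2}{3}

\text{macro} = \frac{1}{3}\left(1 + \frac{2}{3} + \frac{2}{3}\right) = \frac{7}{9} \approx 0.7778
```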
Functional Interface
- torchmetrics.functional.auroc(preds, target, num_classes=None, pos_label=None, average='macro', max_fpr=None, sample_weights=None)
Compute Area Under the Receiver Operating Characteristic Curve (ROC AUC).
For non-binary input, if the preds and target tensors have the same size, the input will be interpreted as multilabel; if preds has one dimension more than the target tensor, the input will be interpreted as multiclass.
Note
If either the positive class or the negative class is completely missing from the target tensor, the AUROC score is meaningless; in this case a score of 0 is returned together with a warning.
- Parameters
preds (Tensor) – predictions from model (logits or probabilities)
target (Tensor) – ground-truth labels
num_classes (Optional[int]) – integer with the number of classes for multilabel and multiclass problems. Should be set to None for binary problems.
pos_label (Optional[int]) – integer determining the positive class. Default is None, which for binary problems is translated to 1. For multiclass problems this argument should not be set, as it is iteratively changed in the range [0, num_classes-1].
average (Optional[str]) – how the per-class scores are reduced:
  - 'micro' computes the metric globally. Only works for multilabel problems.
  - 'macro' computes the metric for each class and uniformly averages them.
  - 'weighted' computes the metric for each class and does a weighted average, where each class is weighted by its support (accounts for class imbalance).
  - None computes and returns the metric per class.
max_fpr (Optional[float]) – If not None, calculates the standardized partial AUC over the range [0, max_fpr]. Should be a float between 0 and 1.
sample_weights (Optional[Sequence]) – sample weights for each data point
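One common way to read `sample_weights` is as weights on the ROC counts: each (positive, negative) pair then contributes the product of its two weights. A sample-weighted ROC curve integrated with the trapezoidal rule reduces to exactly this pair-weighted statistic, so the plain-Python sketch below illustrates the idea under that assumption. It is not torchmetrics' code path, the name `weighted_auroc` is hypothetical, and the zero-weight branch imitates the documented "score 0 plus a warning" behavior.

```python
import warnings
from typing import Optional, Sequence


def weighted_auroc(
    scores: Sequence[float],
    labels: Sequence[int],
    sample_weights: Optional[Sequence[float]] = None,
) -> float:
    """Pair-weighted ROC AUC: each (positive, negative) pair counts with weight w_pos * w_neg."""
    if sample_weights is None:
        sample_weights = [1.0] * len(scores)
    pos = [(s, w) for s, y, w in zip(scores, labels, sample_weights) if y == 1]
    neg = [(s, w) for s, y, w in zip(scores, labels, sample_weights) if y == 0]
    total_pos = sum(w for _, w in pos)
    total_neg = sum(w for _, w in neg)
    if total_pos == 0 or total_neg == 0:
        warnings.warn("No weighted positive or negative samples; returning 0.")
        return 0.0
    agreement = 0.0
    for s_pos, w_pos in pos:
        for s_neg, w_neg in neg:
            if s_pos > s_neg:
                agreement += w_pos * w_neg
            elif s_pos == s_neg:
                agreement += 0.5 * w_pos * w_neg  # ties count one half
    return agreement / (total_pos * total_neg)


scores = [0.13, 0.26, 0.08, 0.19, 0.34]
labels = [0, 0, 1, 1, 1]
print(weighted_auroc(scores, labels))                             # 0.5 (unit weights)
print(weighted_auroc(scores, labels, [1.0, 1.0, 1.0, 1.0, 3.0]))  # 0.7 (top positive up-weighted)
```

Giving a sample weight 0 has the same effect as dropping it, which is a quick sanity check for any weighted implementation.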
- Raises
ValueError – If max_fpr is not a float in the range (0, 1].
RuntimeError – If the PyTorch version is below 1.6, since max_fpr requires torch.bucketize, which is not available below 1.6.
ValueError – If max_fpr is not None and the mode is not binary, since partial AUC computation is not available for multilabel/multiclass input.
ValueError – If average is none of None, "macro" or "weighted".
- Example (binary case):
>>> import torch
>>> from torchmetrics.functional import auroc
>>> preds = torch.tensor([0.13, 0.26, 0.08, 0.19, 0.34])
>>> target = torch.tensor([0, 0, 1, 1, 1])
>>> auroc(preds, target, pos_label=1)
tensor(0.5000)
- Example (multiclass case):
>>> preds = torch.tensor([[0.90, 0.05, 0.05],
...                       [0.05, 0.90, 0.05],
...                       [0.05, 0.05, 0.90],
...                       [0.85, 0.05, 0.10],
...                       [0.10, 0.10, 0.80]])
>>> target = torch.tensor([0, 1, 1, 2, 2])
>>> auroc(preds, target, num_classes=3)
tensor(0.7778)
- Return type