Classwise Wrapper
Module Interface
- class torchmetrics.ClasswiseWrapper(metric, labels=None)
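Wrapper for classification metrics that return one value per class (for example, metrics created with average=None). It converts that output into a dictionary whose keys combine the metric name with a class label, such as accuracy_horse, so that each value is identified by its class.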
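The renaming step can be illustrated with a small pure-Python sketch. It is not the torchmetrics implementation: `classwise_dict` is a hypothetical helper, plain floats stand in for tensors, and the `<name>_<label>` key format and index fallback are assumptions read off the examples on this page. The length check is an addition of this sketch.

```python
from typing import Dict, List, Optional, Sequence


def classwise_dict(
    name: str,
    values: Sequence[float],
    labels: Optional[List[str]] = None,
) -> Dict[str, float]:
    """Hypothetical helper: map per-class values to '<name>_<label>' keys."""
    if labels is None:
        # Without labels, fall back to the class index as the key suffix.
        labels = [str(i) for i in range(len(values))]
    if len(labels) != len(values):
        # Sketch-only safeguard: one label is required per class value.
        raise ValueError(f"expected {len(values)} labels, got {len(labels)}")
    return {f"{name}_{label}": value for label, value in zip(labels, values)}


print(classwise_dict("accuracy", [0.5, 0.75, 0.0]))
# {'accuracy_0': 0.5, 'accuracy_1': 0.75, 'accuracy_2': 0.0}
print(classwise_dict("accuracy", [0.25, 0.5, 0.0], labels=["horse", "fish", "dog"]))
# {'accuracy_horse': 0.25, 'accuracy_fish': 0.5, 'accuracy_dog': 0.0}
```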
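- Parameters
  - metric (Metric) – base metric to wrap. It is expected to return a single tensor with one entry per class along its first dimension, as classification metrics do with average=None.
  - labels (Optional[List[str]]) – list of strings naming the classes, in class-index order. If None (the default), the class index is used as the key suffix, as in accuracy_0.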
Example
>>> import torch
>>> _ = torch.manual_seed(42)
>>> from torchmetrics import Accuracy, ClasswiseWrapper
>>> metric = ClasswiseWrapper(Accuracy(num_classes=3, average=None))
>>> preds = torch.randn(10, 3).softmax(dim=-1)
>>> target = torch.randint(3, (10,))
>>> metric(preds, target)
{'accuracy_0': tensor(0.5000), 'accuracy_1': tensor(0.7500), 'accuracy_2': tensor(0.)}
- Example (labels as list of strings):
>>> import torch
>>> from torchmetrics import Accuracy, ClasswiseWrapper
>>> metric = ClasswiseWrapper(
...     Accuracy(num_classes=3, average=None),
...     labels=["horse", "fish", "dog"]
... )
>>> preds = torch.randn(10, 3).softmax(dim=-1)
>>> target = torch.randint(3, (10,))
>>> metric(preds, target)
{'accuracy_horse': tensor(0.3333), 'accuracy_fish': tensor(0.6667), 'accuracy_dog': tensor(0.)}
- Example (in metric collection):
>>> import torch
>>> from torchmetrics import Accuracy, ClasswiseWrapper, MetricCollection, Recall
>>> labels = ["horse", "fish", "dog"]
>>> metric = MetricCollection(
...     {'accuracy': ClasswiseWrapper(Accuracy(num_classes=3, average=None), labels),
...      'recall': ClasswiseWrapper(Recall(num_classes=3, average=None), labels)}
... )
>>> preds = torch.randn(10, 3).softmax(dim=-1)
>>> target = torch.randint(3, (10,))
>>> metric(preds, target)
{'accuracy_horse': tensor(0.), 'accuracy_fish': tensor(0.3333), 'accuracy_dog': tensor(0.4000), 'recall_horse': tensor(0.), 'recall_fish': tensor(0.3333), 'recall_dog': tensor(0.4000)}
- compute()
Computes the wrapped metric from its accumulated state (synchronized across processes when running distributed) and returns the per-class values as a dictionary keyed by metric name and class label.
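The update/compute split can be sketched in pure Python to show where the relabeling happens: state accumulates in the inner metric across `update` calls, and the wrapper only reshapes the result in `compute`. `ToyRecall` and `SimpleClasswiseWrapper` are hypothetical stand-ins, not torchmetrics classes. Deriving the key prefix from the lowercased class name is an assumption consistent with the examples above, and distributed synchronization is omitted.

```python
from typing import Dict, List, Optional, Sequence


class ToyRecall:
    """Hypothetical per-class metric: fraction of each class's samples predicted correctly."""

    def __init__(self, num_classes: int) -> None:
        self.correct = [0] * num_classes
        self.total = [0] * num_classes

    def update(self, preds: Sequence[int], target: Sequence[int]) -> None:
        # Accumulate counts across calls; nothing is computed yet.
        for p, t in zip(preds, target):
            self.total[t] += 1
            self.correct[t] += int(p == t)

    def compute(self) -> List[float]:
        # Classes with no samples report 0.0 instead of dividing by zero.
        return [c / n if n else 0.0 for c, n in zip(self.correct, self.total)]


class SimpleClasswiseWrapper:
    """Hypothetical wrapper: delegates update(), relabels compute() output."""

    def __init__(self, metric, labels: Optional[List[str]] = None) -> None:
        self.metric = metric
        self.labels = labels

    def update(self, preds: Sequence[int], target: Sequence[int]) -> None:
        self.metric.update(preds, target)

    def compute(self) -> Dict[str, float]:
        values = self.metric.compute()
        # Assumed naming rule: lowercased class name of the wrapped metric.
        name = type(self.metric).__name__.lower()
        labels = self.labels or [str(i) for i in range(len(values))]
        return {f"{name}_{label}": v for label, v in zip(labels, values)}


wrapped = SimpleClasswiseWrapper(ToyRecall(3), labels=["horse", "fish", "dog"])
wrapped.update([0, 1, 2, 1], [0, 1, 1, 2])  # first batch
wrapped.update([2, 2], [2, 0])              # second batch
print(wrapped.compute())
# {'toyrecall_horse': 0.5, 'toyrecall_fish': 0.5, 'toyrecall_dog': 0.5}
```

Keeping the wrapper free of its own state means it can wrap any metric exposing `update` and `compute`; only the output format changes.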