Classwise Wrapper
Module Interface
- class torchmetrics.ClasswiseWrapper(metric, labels=None)
Wrapper class that alters the output of classification metrics returning multiple values (one per class) so that the output includes label information.
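- Parameters
metric: the base metric to wrap; in the examples below, a classification metric configured with average=None so that it returns one value per class
labels: optional list of strings naming the classes; when left as None, the class index is used in the output keys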
Example
>>> import torch
>>> _ = torch.manual_seed(42)
>>> from torchmetrics import Accuracy, ClasswiseWrapper
>>> metric = ClasswiseWrapper(Accuracy(num_classes=3, average=None))
>>> preds = torch.randn(10, 3).softmax(dim=-1)
>>> target = torch.randint(3, (10,))
>>> metric(preds, target)
{'accuracy_0': tensor(0.5000), 'accuracy_1': tensor(0.7500), 'accuracy_2': tensor(0.)}
- Example (labels as list of strings):
>>> import torch
>>> from torchmetrics import Accuracy, ClasswiseWrapper
>>> metric = ClasswiseWrapper(
...     Accuracy(num_classes=3, average=None),
...     labels=["horse", "fish", "dog"]
... )
>>> preds = torch.randn(10, 3).softmax(dim=-1)
>>> target = torch.randint(3, (10,))
>>> metric(preds, target)
{'accuracy_horse': tensor(0.3333), 'accuracy_fish': tensor(0.6667), 'accuracy_dog': tensor(0.)}
- Example (in metric collection):
>>> import torch
>>> from torchmetrics import Accuracy, ClasswiseWrapper, MetricCollection, Recall
>>> labels = ["horse", "fish", "dog"]
>>> metric = MetricCollection(
...     {'accuracy': ClasswiseWrapper(Accuracy(num_classes=3, average=None), labels),
...      'recall': ClasswiseWrapper(Recall(num_classes=3, average=None), labels)}
... )
>>> preds = torch.randn(10, 3).softmax(dim=-1)
>>> target = torch.randint(3, (10,))
>>> metric(preds, target)
{'accuracy_horse': tensor(0.), 'accuracy_fish': tensor(0.3333), 'accuracy_dog': tensor(0.4000),
 'recall_horse': tensor(0.), 'recall_fish': tensor(0.3333), 'recall_dog': tensor(0.4000)}
Initializes internal Module state, shared by both nn.Module and ScriptModule.
- compute()
Override this method to compute the final metric value from state variables synchronized across the distributed backend.
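The examples above suggest that the wrapper turns a per-class result into a dictionary whose keys combine the lowercased metric name with either the class index or the supplied label. The following is a minimal sketch of that relabeling step only, written in plain Python with floats standing in for tensors. The function name `label_classwise` is hypothetical and is not part of torchmetrics.

```python
from typing import Dict, List, Optional, Sequence


def label_classwise(
    metric_name: str,
    values: Sequence[float],
    labels: Optional[List[str]] = None,
) -> Dict[str, float]:
    """Turn a per-class result into a dict keyed by ``<metric>_<label>``.

    Hypothetical helper for illustration; not the torchmetrics implementation.
    """
    if labels is not None and len(labels) != len(values):
        raise ValueError("expected exactly one label per class value")
    prefix = metric_name.lower()
    if labels is None:
        # No labels given: fall back to the class index in the key.
        return {f"{prefix}_{i}": v for i, v in enumerate(values)}
    # Labels given: pair each per-class value with its label.
    return {f"{prefix}_{lab}": v for lab, v in zip(labels, values)}


per_class = [0.5, 0.75, 0.0]
print(label_classwise("Accuracy", per_class))
print(label_classwise("Accuracy", per_class, ["horse", "fish", "dog"]))
```

Because every key carries the metric name as a prefix, results from several wrapped metrics can be merged into one flat dictionary without collisions, which is what the metric collection example relies on.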
- forward(*args, **kwargs)
forward serves a dual purpose: it computes the metric on the current batch of inputs and also adds the batch statistics to the overall accumulating metric state. Input arguments are exactly the same as those of the corresponding update method. The returned output is exactly the same as the output of compute.
- Return type
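The dual role of forward can be illustrated with a toy metric. This is a sketch under stated assumptions: `ToyAccuracy` is a hypothetical pure-Python stand-in, and the way it obtains the batch value (a fresh temporary instance) is one simple option, not necessarily how torchmetrics does it internally.

```python
from typing import List


class ToyAccuracy:
    """Toy stand-in for a metric; not the torchmetrics implementation."""

    def __init__(self) -> None:
        self.correct = 0
        self.total = 0

    def update(self, preds: List[int], target: List[int]) -> None:
        # Accumulate batch statistics into the running state.
        self.correct += sum(p == t for p, t in zip(preds, target))
        self.total += len(target)

    def compute(self) -> float:
        # Final value from everything accumulated so far.
        return self.correct / self.total if self.total else 0.0

    def forward(self, preds: List[int], target: List[int]) -> float:
        # 1) Add the batch to the accumulated state.
        self.update(preds, target)
        # 2) Compute the value for this batch alone on a fresh instance.
        batch_only = ToyAccuracy()
        batch_only.update(preds, target)
        return batch_only.compute()


metric = ToyAccuracy()
first = metric.forward([0, 1, 2, 2], [0, 1, 1, 1])  # 2 of 4 correct in this batch
second = metric.forward([1, 1], [1, 1])             # 2 of 2 correct in this batch
overall = metric.compute()                          # 4 of 6 correct across both batches
print(first, second, overall)
```

Each forward call returns the value for its own batch, while compute reports the value over all batches seen since the last reset.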
- reset()
This method automatically resets the metric state variables to their default value.
- Return type
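Resetting state variables to their defaults can be sketched as follows. `ToyStatefulMetric` is hypothetical; the method name `add_state` is borrowed from the torchmetrics `Metric` base class, but the body shown here is an assumption made for illustration.

```python
import copy
from typing import Any, Dict


class ToyStatefulMetric:
    """Hypothetical minimal metric that remembers its default state."""

    def __init__(self) -> None:
        self._defaults: Dict[str, Any] = {}

    def add_state(self, name: str, default: Any) -> None:
        # Keep a private copy of the default so later mutation cannot alter it.
        self._defaults[name] = copy.deepcopy(default)
        setattr(self, name, copy.deepcopy(default))

    def reset(self) -> None:
        # Restore every registered state variable to its default value.
        for name, default in self._defaults.items():
            setattr(self, name, copy.deepcopy(default))


m = ToyStatefulMetric()
m.add_state("total", 0)
m.add_state("seen", [])
m.total += 5
m.seen.append("batch-0")
m.reset()
print(m.total, m.seen)
```

Copying the defaults on both registration and reset matters for mutable states such as lists: without the copy, appending to the live state would silently change the stored default as well.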