
Classwise Wrapper

Module Interface

class torchmetrics.ClasswiseWrapper(metric, labels=None)

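Wrapper class that alters the output of classification metrics returning multiple values (one per class) so that it includes label information. The wrapped metric's output is converted into a dictionary with one entry per class.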

Parameters
  • metric (Metric) – base metric that should be wrapped. It is assumed that the metric outputs a single tensor that is split along the first dimension.

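  • labels (Optional[List[str]]) – list of strings naming the different classes, used as suffixes for the keys of the output dictionary. If None, the class index is used as the suffix instead.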
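The key naming can be sketched in plain Python. `to_classwise_dict` is a hypothetical helper, not part of torchmetrics. Based on the examples in this section, it assumes that the prefix is the wrapped metric's class name in lower case and that the suffix is the label, or the class index when `labels` is None. The length check is a choice of this sketch, not documented behaviour of the wrapper.

```python
from typing import Dict, List, Optional, Sequence


def to_classwise_dict(
    values: Sequence[float], prefix: str, labels: Optional[List[str]] = None
) -> Dict[str, float]:
    """Hypothetical helper mirroring the key naming seen in the examples.

    Each entry along the first dimension becomes one dictionary item keyed
    ``<prefix>_<label>``; without labels, the class index is the suffix.
    """
    if labels is None:
        labels = [str(i) for i in range(len(values))]
    # Choice of this sketch: refuse silently truncated output
    if len(labels) != len(values):
        raise ValueError("expected exactly one label per class")
    return {f"{prefix}_{label}": value for label, value in zip(labels, values)}


print(to_classwise_dict([0.5, 0.75, 0.0], "multiclassaccuracy"))
# {'multiclassaccuracy_0': 0.5, 'multiclassaccuracy_1': 0.75, 'multiclassaccuracy_2': 0.0}
print(to_classwise_dict([0.5, 0.75, 0.0], "multiclassaccuracy", ["horse", "fish", "dog"]))
# {'multiclassaccuracy_horse': 0.5, 'multiclassaccuracy_fish': 0.75, 'multiclassaccuracy_dog': 0.0}
```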

Example

>>> import torch
>>> _ = torch.manual_seed(42)
>>> from torchmetrics import ClasswiseWrapper
>>> from torchmetrics.classification import MulticlassAccuracy
>>> metric = ClasswiseWrapper(MulticlassAccuracy(num_classes=3, average=None))
>>> preds = torch.randn(10, 3).softmax(dim=-1)
>>> target = torch.randint(3, (10,))
>>> metric(preds, target)  
{'multiclassaccuracy_0': tensor(0.5000),
 'multiclassaccuracy_1': tensor(0.7500),
 'multiclassaccuracy_2': tensor(0.)}

Example (labels as list of strings):
>>> import torch
>>> from torchmetrics import ClasswiseWrapper
>>> from torchmetrics.classification import MulticlassAccuracy
>>> metric = ClasswiseWrapper(
...    MulticlassAccuracy(num_classes=3, average=None),
...    labels=["horse", "fish", "dog"]
... )
>>> preds = torch.randn(10, 3).softmax(dim=-1)
>>> target = torch.randint(3, (10,))
>>> metric(preds, target)  
{'multiclassaccuracy_horse': tensor(0.3333),
 'multiclassaccuracy_fish': tensor(0.6667),
 'multiclassaccuracy_dog': tensor(0.)}

Example (in metric collection):
>>> import torch
>>> from torchmetrics import ClasswiseWrapper, MetricCollection
>>> from torchmetrics.classification import MulticlassAccuracy, MulticlassRecall
>>> labels = ["horse", "fish", "dog"]
>>> metric = MetricCollection(
...     {'multiclassaccuracy': ClasswiseWrapper(MulticlassAccuracy(num_classes=3, average=None), labels),
...      'multiclassrecall': ClasswiseWrapper(MulticlassRecall(num_classes=3, average=None), labels)}
... )
>>> preds = torch.randn(10, 3).softmax(dim=-1)
>>> target = torch.randint(3, (10,))
>>> metric(preds, target)  
{'multiclassaccuracy_horse': tensor(0.),
 'multiclassaccuracy_fish': tensor(0.3333),
 'multiclassaccuracy_dog': tensor(0.4000),
 'multiclassrecall_horse': tensor(0.),
 'multiclassrecall_fish': tensor(0.3333),
 'multiclassrecall_dog': tensor(0.4000)}
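In the collection example, each multiclassaccuracy_* entry equals the matching multiclassrecall_* entry. This is expected rather than a bug: as we understand torchmetrics, per-class multiclass accuracy with average=None is TP / (TP + FN) for each class, which is exactly per-class recall. The plain-Python sketch below (hypothetical helper names, no torchmetrics dependency) computes both quantities in two different ways and gets the same numbers. Returning 0.0 for a class with no samples is an assumption of this sketch.

```python
from typing import List


def per_class_accuracy(preds: List[int], target: List[int], num_classes: int) -> List[float]:
    """Per-class accuracy: correct predictions among samples whose true class is c."""
    correct = [0] * num_classes
    total = [0] * num_classes
    for p, t in zip(preds, target):
        total[t] += 1              # sample of class t: counts as TP or FN for t
        correct[t] += int(p == t)  # TP for class t
    # Assumption of this sketch: a class with no samples reports 0.0
    return [c / n if n else 0.0 for c, n in zip(correct, total)]


def per_class_recall(preds: List[int], target: List[int], num_classes: int) -> List[float]:
    """Per-class recall TP / (TP + FN), read off a confusion matrix (rows = true class)."""
    cm = [[0] * num_classes for _ in range(num_classes)]
    for p, t in zip(preds, target):
        cm[t][p] += 1
    return [cm[c][c] / sum(cm[c]) if sum(cm[c]) else 0.0 for c in range(num_classes)]


# Hard predictions (e.g. the argmax of the probabilities) and targets, 3 classes
preds = [0, 1, 1, 2, 2, 0, 1]
target = [0, 1, 2, 2, 1, 1, 1]
print(per_class_accuracy(preds, target, 3))  # [1.0, 0.5, 0.5]
print(per_class_recall(preds, target, 3))    # [1.0, 0.5, 0.5]
```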


compute()

Compute the final value of the wrapped metric from its state variables (synchronized across the distributed backend) and return it as a dictionary with one entry per class, keyed by metric name and class label.

Return type

Dict[str, Tensor]

forward(*args, **kwargs)

forward serves the dual purpose of computing the metric on the current batch of inputs and adding the batch statistics to the overall accumulating metric state.

Input arguments are the same as those of the corresponding update method. The returned output has the same format as the output of compute, but its values are computed from the current batch only.

Return type

Any
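The forward/compute distinction can be illustrated with a small dependency-free toy. `ToyPerClassAccuracy` is hypothetical and only mimics the contract described above: forward returns the value for the current batch while also folding that batch into the accumulated state, and compute reports the value over everything seen since the last reset. Evaluating the batch on a fresh, empty state is just one simple way to get these semantics; torchmetrics' own implementation differs in its details.

```python
from typing import List


class ToyPerClassAccuracy:
    """Hypothetical stand-in for a Metric that accumulates per-class accuracy."""

    def __init__(self, num_classes: int) -> None:
        self.num_classes = num_classes
        self.reset()

    def reset(self) -> None:
        # Accumulated state: correct predictions and sample counts per true class
        self.correct = [0] * self.num_classes
        self.total = [0] * self.num_classes

    def update(self, preds: List[int], target: List[int]) -> None:
        for p, t in zip(preds, target):
            self.total[t] += 1
            self.correct[t] += int(p == t)

    def compute(self) -> List[float]:
        # Classes without samples report 0.0 (a choice of this sketch)
        return [c / n if n else 0.0 for c, n in zip(self.correct, self.total)]

    def forward(self, preds: List[int], target: List[int]) -> List[float]:
        # 1) value on the current batch only, from a fresh, empty state
        batch_metric = ToyPerClassAccuracy(self.num_classes)
        batch_metric.update(preds, target)
        # 2) the same batch is also added to the accumulated state
        self.update(preds, target)
        return batch_metric.compute()


metric = ToyPerClassAccuracy(num_classes=2)
print(metric.forward([0, 1], [0, 0]))  # batch 1 only: [0.5, 0.0]
print(metric.forward([1, 1], [1, 0]))  # batch 2 only: [0.0, 1.0]
print(metric.compute())                # both batches: class 0 -> 1/3, class 1 -> 1.0
```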

reset()

This method automatically resets the metric state variables to their default values.

Return type

None

update(*args, **kwargs)

Update the state variables of the wrapped metric with a new batch of inputs; the arguments are passed on to the wrapped metric's update method.

Return type

None
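A sketch of how the wrapper's methods might fit together. It assumes, consistent with the method descriptions in this section, that update and reset are passed through to the wrapped metric and that compute relabels the wrapped metric's per-class output. Both classes are hypothetical plain-Python stand-ins rather than torchmetrics code, and deriving the key prefix from the lower-cased class name is inferred from the examples.

```python
from typing import Dict, List, Optional


class ToyPerClassAccuracy:
    """Hypothetical inner metric: accumulates per-class accuracy over batches."""

    def __init__(self, num_classes: int) -> None:
        self.num_classes = num_classes
        self.reset()

    def reset(self) -> None:
        self.correct = [0] * self.num_classes
        self.total = [0] * self.num_classes

    def update(self, preds: List[int], target: List[int]) -> None:
        for p, t in zip(preds, target):
            self.total[t] += 1
            self.correct[t] += int(p == t)

    def compute(self) -> List[float]:
        # Classes without samples report 0.0 (a choice of this sketch)
        return [c / n if n else 0.0 for c, n in zip(self.correct, self.total)]


class ToyClasswiseWrapper:
    """Hypothetical sketch: delegates state handling, relabels the output."""

    def __init__(self, metric, labels: Optional[List[str]] = None) -> None:
        self.metric = metric
        self.labels = labels

    def update(self, *args, **kwargs) -> None:
        # All state lives in the wrapped metric
        self.metric.update(*args, **kwargs)

    def reset(self) -> None:
        self.metric.reset()

    def compute(self) -> Dict[str, float]:
        values = self.metric.compute()
        # Fall back to class indices when no labels were given
        labels = self.labels or [str(i) for i in range(len(values))]
        prefix = type(self.metric).__name__.lower()
        return {f"{prefix}_{lbl}": v for lbl, v in zip(labels, values)}


wrapper = ToyClasswiseWrapper(ToyPerClassAccuracy(2), labels=["cat", "dog"])
wrapper.update([0, 1], [0, 0])
wrapper.update([1, 1], [1, 0])
print(wrapper.compute())  # cat -> 1/3 (0.333...), dog -> 1.0
wrapper.reset()
print(wrapper.compute())  # {'toyperclassaccuracy_cat': 0.0, 'toyperclassaccuracy_dog': 0.0}
```

Because the state lives entirely in the wrapped metric, the wrapper in this sketch needs no state of its own and only changes how the result is presented.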