Classwise Wrapper

Module Interface

class torchmetrics.wrappers.ClasswiseWrapper(metric, labels=None, prefix=None, postfix=None)

Wrapper metric for altering the output of classification metrics.

This metric works together with classification metrics that return multiple values (one value per class), so that label information can be automatically included in the output keys.

Parameters:
  • metric (Metric) – base metric that should be wrapped. It is assumed that the metric outputs a single tensor that is split along the first dimension.

  • labels (Optional[List[str]]) – list of class names, one per class, used in the output keys in place of the class index. If not provided, the class index is used.

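  • prefix (Optional[str]) – string that is prepended to each output key. If neither prefix nor postfix is given, the lowercased class name of the wrapped metric followed by an underscore (e.g. multiclassaccuracy_) is used as prefix.

  • postfix (Optional[str]) – string that is appended to each output key.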
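The naming rule implied by the examples below can be written out as a small, dependency-free function. classwise_keys is a hypothetical helper for illustration only, not part of torchmetrics; it also assumes that a prefix and a postfix are both applied when both are given, a case the examples do not show.

```python
from typing import List, Optional


def classwise_keys(
    metric_class_name: str,
    num_classes: int,
    labels: Optional[List[str]] = None,
    prefix: Optional[str] = None,
    postfix: Optional[str] = None,
) -> List[str]:
    """Hypothetical helper mirroring the key names shown in the examples."""
    if not prefix and not postfix:
        # Default: lowercased class name of the wrapped metric plus "_"
        prefix, postfix = f"{metric_class_name.lower()}_", ""
    else:
        # Assumption: both are applied when both are given
        prefix, postfix = prefix or "", postfix or ""
    # Fall back to the class index when no labels are given
    names = labels if labels is not None else [str(i) for i in range(num_classes)]
    return [f"{prefix}{name}{postfix}" for name in names]


print(classwise_keys("MulticlassAccuracy", 3))
# ['multiclassaccuracy_0', 'multiclassaccuracy_1', 'multiclassaccuracy_2']
print(classwise_keys("MulticlassAccuracy", 3, prefix="acc-"))
# ['acc-0', 'acc-1', 'acc-2']
print(classwise_keys("MulticlassAccuracy", 3, labels=["horse", "fish", "dog"]))
# ['multiclassaccuracy_horse', 'multiclassaccuracy_fish', 'multiclassaccuracy_dog']
```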

Example:

Basic example where the output of a metric is unwrapped into a dictionary whose keys end in the class index:

>>> import torch
>>> _ = torch.manual_seed(42)
>>> from torchmetrics.wrappers import ClasswiseWrapper
>>> from torchmetrics.classification import MulticlassAccuracy
>>> metric = ClasswiseWrapper(MulticlassAccuracy(num_classes=3, average=None))
>>> preds = torch.randn(10, 3).softmax(dim=-1)
>>> target = torch.randint(3, (10,))
>>> metric(preds, target)  
{'multiclassaccuracy_0': tensor(0.5000),
 'multiclassaccuracy_1': tensor(0.7500),
 'multiclassaccuracy_2': tensor(0.)}
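To make explicit what each entry holds: as far as we understand, with average=None MulticlassAccuracy reports, for each class, the fraction of that class's samples whose highest-scoring prediction is correct, i.e. per-class recall (consistent with the identical accuracy and recall values in the MetricCollection example further down). The sketch below recomputes this in plain Python on a tiny hand-made batch. per_class_accuracy is a hypothetical helper, and reporting 0.0 for a class with no samples is an assumption of the sketch.

```python
from typing import Dict, List


def per_class_accuracy(probs: List[List[float]], target: List[int], num_classes: int) -> Dict[str, float]:
    """Hypothetical re-implementation of per-class accuracy, keyed like the wrapper's default output."""
    # Predicted class = index of the largest probability in each row (ties -> lowest index)
    preds = [max(range(num_classes), key=lambda c: row[c]) for row in probs]
    result: Dict[str, float] = {}
    for c in range(num_classes):
        support = sum(1 for t in target if t == c)
        correct = sum(1 for p, t in zip(preds, target) if t == c and p == c)
        # Assumption: a class without samples is reported as 0.0
        result[f"multiclassaccuracy_{c}"] = correct / support if support else 0.0
    return result


probs = [[0.7, 0.2, 0.1], [0.1, 0.8, 0.1], [0.3, 0.4, 0.3], [0.2, 0.2, 0.6]]
target = [0, 1, 0, 2]
print(per_class_accuracy(probs, target, num_classes=3))
# {'multiclassaccuracy_0': 0.5, 'multiclassaccuracy_1': 1.0, 'multiclassaccuracy_2': 1.0}
```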

Example:

Using custom names via prefix and postfix:

>>> import torch
>>> _ = torch.manual_seed(42)
>>> from torchmetrics.wrappers import ClasswiseWrapper
>>> from torchmetrics.classification import MulticlassAccuracy
>>> metric_pre = ClasswiseWrapper(MulticlassAccuracy(num_classes=3, average=None), prefix="acc-")
>>> metric_post = ClasswiseWrapper(MulticlassAccuracy(num_classes=3, average=None), postfix="-acc")
>>> preds = torch.randn(10, 3).softmax(dim=-1)
>>> target = torch.randint(3, (10,))
>>> metric_pre(preds, target)  
{'acc-0': tensor(0.5000),
 'acc-1': tensor(0.7500),
 'acc-2': tensor(0.)}
>>> metric_post(preds, target)  
{'0-acc': tensor(0.5000),
 '1-acc': tensor(0.7500),
 '2-acc': tensor(0.)}

Example:

Providing labels as a list of strings:

>>> import torch
>>> from torchmetrics.wrappers import ClasswiseWrapper
>>> from torchmetrics.classification import MulticlassAccuracy
>>> metric = ClasswiseWrapper(
...    MulticlassAccuracy(num_classes=3, average=None),
...    labels=["horse", "fish", "dog"]
... )
>>> preds = torch.randn(10, 3).softmax(dim=-1)
>>> target = torch.randint(3, (10,))
>>> metric(preds, target)  
{'multiclassaccuracy_horse': tensor(0.3333),
 'multiclassaccuracy_fish': tensor(0.6667),
 'multiclassaccuracy_dog': tensor(0.)}
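Because labels are matched to classes by position, a mislabeled or duplicated entry is easy to miss in the output. A pre-flight check such as the hypothetical check_labels below can be run in your own code before constructing the wrapper; it makes no claim about how ClasswiseWrapper itself reacts to such input.

```python
from typing import List, Sequence


def check_labels(labels: Sequence[str], num_classes: int) -> List[str]:
    """Hypothetical pre-flight check for the `labels` argument (not part of torchmetrics)."""
    labels = list(labels)
    if not all(isinstance(lab, str) for lab in labels):
        raise TypeError("labels must be strings")
    if len(labels) != num_classes:
        raise ValueError(f"expected {num_classes} labels, got {len(labels)}")
    if len(set(labels)) != len(labels):
        # Duplicate labels would map two classes onto the same dictionary key
        raise ValueError(f"labels must be unique, got {labels}")
    return labels


labels = check_labels(["horse", "fish", "dog"], num_classes=3)
print(labels)  # ['horse', 'fish', 'dog']
```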

Example:

ClasswiseWrapper can also be used in combination with MetricCollection. In this case, the per-class outputs of all wrapped metrics are flattened into a single dictionary:

>>> import torch
>>> from torchmetrics import MetricCollection
>>> from torchmetrics.wrappers import ClasswiseWrapper
>>> from torchmetrics.classification import MulticlassAccuracy, MulticlassRecall
>>> labels = ["horse", "fish", "dog"]
>>> metric = MetricCollection(
...     {'multiclassaccuracy': ClasswiseWrapper(MulticlassAccuracy(num_classes=3, average=None), labels),
...     'multiclassrecall': ClasswiseWrapper(MulticlassRecall(num_classes=3, average=None), labels)}
... )
>>> preds = torch.randn(10, 3).softmax(dim=-1)
>>> target = torch.randint(3, (10,))
>>> metric(preds, target)  
{'multiclassaccuracy_horse': tensor(0.),
 'multiclassaccuracy_fish': tensor(0.3333),
 'multiclassaccuracy_dog': tensor(0.4000),
 'multiclassrecall_horse': tensor(0.),
 'multiclassrecall_fish': tensor(0.3333),
 'multiclassrecall_dog': tensor(0.4000)}
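Conceptually, the flattening amounts to merging the per-metric dictionaries into one. The sketch below does this with plain floats standing in for tensors and illustrative values; flatten_results is hypothetical, and raising on a duplicate key is our own choice, not documented MetricCollection behavior. Collisions are a real risk with the default naming: two ClasswiseWrapper instances around the same metric class without a prefix would produce identical keys, which is one reason to give them distinct prefixes.

```python
from typing import Dict, Mapping


def flatten_results(results: Mapping[str, Mapping[str, float]]) -> Dict[str, float]:
    """Hypothetical merge of per-metric result dicts into one flat dict."""
    flat: Dict[str, float] = {}
    for metric_name, per_class in results.items():
        for key, value in per_class.items():
            # Refuse to silently overwrite a value produced by another metric
            if key in flat:
                raise KeyError(f"duplicate key {key!r} from metric {metric_name!r}")
            flat[key] = value
    return flat


# Illustrative values only
results = {
    "multiclassaccuracy": {"multiclassaccuracy_horse": 0.5, "multiclassaccuracy_fish": 1.0},
    "multiclassrecall": {"multiclassrecall_horse": 0.5, "multiclassrecall_fish": 1.0},
}
print(flatten_results(results))
```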
compute()

Compute metric.

Return type:

Dict[str, Tensor]

forward(*args, **kwargs)

Calculate on batch and accumulate to global state.

Return type:

Any

plot(val=None, ax=None)

Plot a single or multiple values from the metric.

Parameters:
  • val (Union[Tensor, Sequence[Tensor], None]) – Either a single result from calling metric.forward or metric.compute, or a list of these results. If no value is provided, metric.compute is called automatically and its result is plotted.

  • ax (Optional[Axes]) – A matplotlib axis object. If provided, the plot is added to that axis.

Return type:

Tuple[Figure, Union[Axes, ndarray]]

Returns:

Figure and Axes object

Raises:

ModuleNotFoundError – If matplotlib is not installed

>>> # Example plotting a single value
>>> import torch
>>> from torchmetrics.wrappers import ClasswiseWrapper
>>> from torchmetrics.classification import MulticlassAccuracy
>>> metric = ClasswiseWrapper(MulticlassAccuracy(num_classes=3, average=None))
>>> metric.update(torch.randint(3, (20,)), torch.randint(3, (20,)))
>>> fig_, ax_ = metric.plot()
>>> # Example plotting multiple values
>>> import torch
>>> from torchmetrics.wrappers import ClasswiseWrapper
>>> from torchmetrics.classification import MulticlassAccuracy
>>> metric = ClasswiseWrapper(MulticlassAccuracy(num_classes=3, average=None))
>>> values = []
>>> for _ in range(3):
...     values.append(metric(torch.randint(3, (20,)), torch.randint(3, (20,))))
>>> fig_, ax_ = metric.plot(values)
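If you prefer to build your own plot from the list of dictionaries collected in the loop above, one option is to regroup them into one series per key. to_series is a hypothetical helper; it calls float() on each value, which works both for plain numbers and for the 0-d tensors in the wrapper's output.

```python
from typing import Dict, List, Mapping


def to_series(values: List[Mapping[str, float]]) -> Dict[str, List[float]]:
    """Hypothetical helper: regroup per-step result dicts into one series per key."""
    series: Dict[str, List[float]] = {}
    for step in values:
        for key, value in step.items():
            # float() also converts 0-d tensors to Python floats
            series.setdefault(key, []).append(float(value))
    return series


steps = [
    {"multiclassaccuracy_0": 0.5, "multiclassaccuracy_1": 1.0},
    {"multiclassaccuracy_0": 0.25, "multiclassaccuracy_1": 0.0},
]
print(to_series(steps))
# {'multiclassaccuracy_0': [0.5, 0.25], 'multiclassaccuracy_1': [1.0, 0.0]}
```

Each resulting series can then be drawn against the step index, for instance with matplotlib's Axes.plot.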
reset()

Reset metric.

Return type:

None

update(*args, **kwargs)

Update state.

Return type:

None
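The division of labour between update, compute, forward and reset can be illustrated with a toy, dependency-free class. ToyClasswiseAccuracy is hypothetical and is not how torchmetrics implements these methods; it takes class indices instead of probabilities for brevity and only mimics the contract described above: update accumulates state, compute returns a per-class dictionary from the accumulated state, forward returns the result for the current batch while also accumulating it, and reset clears the state.

```python
from typing import Dict, List


class ToyClasswiseAccuracy:
    """Hypothetical stand-in (not torchmetrics) illustrating update/compute/forward/reset."""

    def __init__(self, num_classes: int) -> None:
        self.num_classes = num_classes
        self.reset()

    def reset(self) -> None:
        # Global state: correct predictions and number of samples per class
        self.correct = [0] * self.num_classes
        self.support = [0] * self.num_classes

    def update(self, preds: List[int], target: List[int]) -> None:
        # Accumulate counts from one batch of predicted and true class indices
        for p, t in zip(preds, target):
            self.support[t] += 1
            self.correct[t] += int(p == t)

    def compute(self) -> Dict[str, float]:
        # Per-class accuracy from the accumulated state (0.0 for unseen classes)
        return {
            f"toy_{c}": self.correct[c] / self.support[c] if self.support[c] else 0.0
            for c in range(self.num_classes)
        }

    def forward(self, preds: List[int], target: List[int]) -> Dict[str, float]:
        # Result for this batch only, computed on a fresh instance...
        batch = ToyClasswiseAccuracy(self.num_classes)
        batch.update(preds, target)
        # ...while the same batch is also added to the global state
        self.update(preds, target)
        return batch.compute()


m = ToyClasswiseAccuracy(num_classes=2)
print(m.forward([0, 1], [0, 0]))  # {'toy_0': 0.5, 'toy_1': 0.0}
print(m.forward([1], [1]))        # {'toy_0': 0.0, 'toy_1': 1.0}
print(m.compute())                # {'toy_0': 0.5, 'toy_1': 1.0}
m.reset()
print(m.compute())                # {'toy_0': 0.0, 'toy_1': 0.0}
```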