
Classwise Wrapper

Module Interface

class torchmetrics.ClasswiseWrapper(metric, labels=None)

Wrapper class that alters the output of classification metrics that return multiple values so that it includes label information.

Parameters
  • metric (Metric) – base metric that should be wrapped. It is assumed that the metric outputs a single tensor that is split along the first dimension.

  • labels (Optional[List[str]]) – list of strings indicating the different classes.
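
The key naming seen in the examples on this page (accuracy_0, accuracy_horse, and so on) can be illustrated with a small sketch. It is not the library's implementation: to_classwise_dict is a hypothetical helper, plain Python floats stand in for the tensor entries along the first dimension, and the length check is an added assumption rather than documented behavior.

```python
from typing import Dict, List, Optional, Sequence


def to_classwise_dict(
    values: Sequence[float],
    metric_name: str,
    labels: Optional[List[str]] = None,
) -> Dict[str, float]:
    """Hypothetical helper: map per-class values to '<metric>_<label>' keys."""
    if labels is None:
        # No labels given: fall back to the class index, as in 'accuracy_0'.
        return {f"{metric_name}_{i}": v for i, v in enumerate(values)}
    if len(labels) != len(values):
        # Assumption of this sketch: require exactly one label per class value.
        raise ValueError("expected one label per class value")
    # Labels given: pair each label with the value at the same position.
    return {f"{metric_name}_{lab}": v for lab, v in zip(labels, values)}


per_class = [0.5, 0.75, 0.0]
print(to_classwise_dict(per_class, "accuracy"))
print(to_classwise_dict(per_class, "accuracy", ["horse", "fish", "dog"]))
```

Labels are matched to values by position, so the order of the labels list is assumed to follow the class indices used in target.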

Example

>>> import torch
>>> _ = torch.manual_seed(42)
>>> from torchmetrics import Accuracy, ClasswiseWrapper
>>> metric = ClasswiseWrapper(Accuracy(num_classes=3, average=None))
>>> preds = torch.randn(10, 3).softmax(dim=-1)
>>> target = torch.randint(3, (10,))
>>> metric(preds, target)
{'accuracy_0': tensor(0.5000), 'accuracy_1': tensor(0.7500), 'accuracy_2': tensor(0.)}

Example (labels as list of strings):

>>> import torch
>>> from torchmetrics import Accuracy, ClasswiseWrapper
>>> metric = ClasswiseWrapper(
...    Accuracy(num_classes=3, average=None),
...    labels=["horse", "fish", "dog"]
... )
>>> preds = torch.randn(10, 3).softmax(dim=-1)
>>> target = torch.randint(3, (10,))
>>> metric(preds, target)
{'accuracy_horse': tensor(0.3333), 'accuracy_fish': tensor(0.6667), 'accuracy_dog': tensor(0.)}

Example (in metric collection):

>>> import torch
>>> from torchmetrics import Accuracy, ClasswiseWrapper, MetricCollection, Recall
>>> labels = ["horse", "fish", "dog"]
>>> metric = MetricCollection(
...     {'accuracy': ClasswiseWrapper(Accuracy(num_classes=3, average=None), labels),
...     'recall': ClasswiseWrapper(Recall(num_classes=3, average=None), labels)}
... )
>>> preds = torch.randn(10, 3).softmax(dim=-1)
>>> target = torch.randint(3, (10,))
>>> metric(preds, target)  
{'accuracy_horse': tensor(0.), 'accuracy_fish': tensor(0.3333), 'accuracy_dog': tensor(0.4000),
'recall_horse': tensor(0.), 'recall_fish': tensor(0.3333), 'recall_dog': tensor(0.4000)}

Initializes internal Module state, shared by both nn.Module and ScriptModule.

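compute()

Computes the final value of the wrapped metric from its state variables, synchronized across the distributed backend, and returns it as a dictionary with one entry per class, keyed as in the examples above (for instance accuracy_horse).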

Return type

Dict[str, Tensor]

update(*args, **kwargs)

Updates the state variables of the wrapped metric with the given arguments.

Return type

None
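
The update/compute life cycle of a wrapper like this one can be sketched without torchmetrics. Everything below is a hypothetical stand-in: ToyPerClassAccuracy and ToyClasswiseWrapper are illustrative names, plain Python lists replace tensors, and the delegation shown is an assumption about the general pattern rather than the library's source.

```python
from typing import Dict, List, Optional


class ToyPerClassAccuracy:
    """Hypothetical metric: fraction of samples of each class predicted correctly."""

    def __init__(self, num_classes: int) -> None:
        self.correct = [0] * num_classes
        self.total = [0] * num_classes

    def update(self, preds: List[int], target: List[int]) -> None:
        # Accumulate per-class counts; state persists across batches.
        for p, t in zip(preds, target):
            self.total[t] += 1
            self.correct[t] += int(p == t)

    def compute(self) -> List[float]:
        # One value per class, in class-index order.
        return [c / n if n else 0.0 for c, n in zip(self.correct, self.total)]


class ToyClasswiseWrapper:
    """Hypothetical wrapper: delegates update, labels the output of compute."""

    def __init__(self, metric: ToyPerClassAccuracy, labels: Optional[List[str]] = None) -> None:
        self.metric = metric
        self.labels = labels

    def update(self, *args, **kwargs) -> None:
        # Forward all arguments unchanged to the wrapped metric.
        self.metric.update(*args, **kwargs)

    def compute(self) -> Dict[str, float]:
        name = type(self.metric).__name__.lower()
        values = self.metric.compute()
        # Use the class index as the key suffix when no labels were given.
        keys = self.labels if self.labels is not None else range(len(values))
        return {f"{name}_{k}": v for k, v in zip(keys, values)}


wrapped = ToyClasswiseWrapper(ToyPerClassAccuracy(3), labels=["horse", "fish", "dog"])
wrapped.update(preds=[0, 1, 2, 2], target=[0, 1, 1, 2])  # first batch
wrapped.update(preds=[0, 0], target=[0, 2])              # second batch
result = wrapped.compute()
print(result)
```

In this sketch the wrapper holds no state of its own, so the values returned by compute reflect every batch passed to update since the wrapped metric was created.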

Read the Docs v: v0.8.1