
Multi-output Wrapper

Module Interface

class torchmetrics.MultioutputWrapper(base_metric, num_outputs, output_dim=-1, remove_nans=True, squeeze_outputs=True)

Wrap a base metric to enable it to support multiple outputs.

Several torchmetrics metrics, such as torchmetrics.regression.spearman.SpearmanCorrcoef, lack support for multioutput mode. This class wraps such metrics so that one metric is computed per output. Unlike metrics with native multioutput support, it does not aggregate across outputs: if you set num_outputs to 2, .compute() returns a list of 2 tensors, each with the shape the metric returns when not wrapped.
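To make the "one metric per output, no aggregation" behaviour concrete, here is a plain-Python sketch that does not use torchmetrics. The helpers per_output and r2_score are hypothetical illustrations, not torchmetrics API, and the inputs are assumed to be lists of rows with one column per output. With the data from the R2Score example on this page, it should return the same two values, as a list.

```python
from typing import Callable, List, Sequence

Row = Sequence[float]


def r2_score(preds: Sequence[float], target: Sequence[float]) -> float:
    """Plain coefficient of determination for a single output: 1 - SS_res / SS_tot."""
    mean = sum(target) / len(target)
    ss_res = sum((t - p) ** 2 for p, t in zip(preds, target))
    ss_tot = sum((t - mean) ** 2 for t in target)
    return 1.0 - ss_res / ss_tot


def per_output(
    metric_fn: Callable[[List[float], List[float]], float],
    preds: Sequence[Row],
    target: Sequence[Row],
    num_outputs: int,
) -> List[float]:
    """Hypothetical helper: apply metric_fn to each output column independently."""
    results = []
    for i in range(num_outputs):
        # Column i of every row is the slice belonging to output i.
        p_i = [row[i] for row in preds]
        t_i = [row[i] for row in target]
        results.append(metric_fn(p_i, t_i))
    # One value per output; nothing is averaged across outputs.
    return results


target = [[0.5, 1], [-1, 1], [7, -6]]
preds = [[0, 2], [-1, 2], [8, -5]]
scores = per_output(r2_score, preds, target, num_outputs=2)
print([round(s, 4) for s in scores])  # [0.9654, 0.9082]
```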

In addition to enabling multioutput support for metrics that lack it, this class also supports, albeit in a crude fashion, dealing with missing labels (or other data). When remove_nans is True (the default), then on each update and separately for each output, the wrapper removes every row that contains a NaN in any of the tensors passed for that output. For example, suppose a user wraps torchmetrics.regression.r2.R2Score in MultioutputWrapper with 2 outputs, one of which occasionally has missing labels. Because NaN removal happens on a per-output basis, a row with a missing label in that output is dropped only from that output's metric; the other output still receives the row.
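A minimal sketch of per-output NaN filtering in plain Python, assuming list-of-rows inputs. drop_nan_rows_per_output is a hypothetical helper written for illustration; the wrapper itself works on tensors along dimension 0, but the idea is the same: a missing label in one output removes that row only from that output.

```python
import math
from typing import List, Sequence, Tuple

Row = Sequence[float]


def drop_nan_rows_per_output(
    preds: Sequence[Row], target: Sequence[Row], num_outputs: int
) -> List[Tuple[List[float], List[float]]]:
    """Hypothetical helper: for each output i, keep rows where column i is NaN-free.

    Rows are filtered independently per output, so a NaN in output 1
    does not remove that row from output 0.
    """
    per_output = []
    for i in range(num_outputs):
        kept_p, kept_t = [], []
        for p_row, t_row in zip(preds, target):
            if math.isnan(p_row[i]) or math.isnan(t_row[i]):
                continue  # drop this row for output i only
            kept_p.append(p_row[i])
            kept_t.append(t_row[i])
        per_output.append((kept_p, kept_t))
    return per_output


nan = float("nan")
target = [[0.5, 1.0], [-1.0, nan], [7.0, -6.0]]  # output 1 has a missing label in row 1
preds = [[0.0, 2.0], [-1.0, 2.0], [8.0, -5.0]]
for i, (p, t) in enumerate(drop_nan_rows_per_output(preds, target, num_outputs=2)):
    print(i, p, t)
# 0 [0.0, -1.0, 8.0] [0.5, -1.0, 7.0]
# 1 [2.0, -5.0] [1.0, -6.0]
```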

Parameters
  • base_metric (Metric) – Metric being wrapped.

  • num_outputs (int) – Number of outputs, i.e. the expected size of the output dimension. One copy of the base metric is tracked per output.

  • output_dim (int) – Dimension along which the outputs are laid out. While this provides some flexibility, the same output dimension index is used for every input to update. This applies even to metrics such as Accuracy, where the labels can have a different number of dimensions than the predictions. If the outputs are the last dimension of every input, setting output_dim to -1 works around this, because -1 refers to the last dimension of each input even when the inputs differ in rank.

  • remove_nans (bool) – Whether to remove, for each output, the rows containing NaNs from the values passed through to the corresponding underlying metric. Proper operation requires all tensors passed to update to have shape (N, ...), where N is the length of the batch or dataset being passed in.

  • squeeze_outputs (bool) – If True, squeezes the size-1 dimension left after index_select is applied. This is unnecessary but harmless for metrics such as R2Score, and useful for certain classification metrics that can't handle an extra size-1 dimension.
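The output_dim=-1 workaround can be illustrated with nested Python lists shaped like the BinnedAveragePrecision example on this page (preds of shape (3, 3, 2), target of shape (3, 2)). take_last is a hypothetical helper that picks index i along the innermost axis and drops that axis, which is roughly what selecting one output and then squeezing (squeeze_outputs=True) produces. It is a sketch, not the torchmetrics code path.

```python
from typing import Any, List


def take_last(x: List[Any], i: int) -> Any:
    """Hypothetical helper: select index i along the last axis and drop that axis.

    Works for inputs of different rank, because "last axis" is resolved
    separately for each input, just like output_dim=-1.
    """
    if isinstance(x[0], list):
        return [take_last(row, i) for row in x]
    return x[i]


target = [[1, 2], [2, 0], [1, 2]]  # shape (3, 2)
preds = [                          # shape (3, 3, 2)
    [[0.1, 0.8], [0.8, 0.05], [0.1, 0.15]],
    [[0.1, 0.1], [0.2, 0.3], [0.7, 0.6]],
    [[0.002, 0.4], [0.95, 0.45], [0.048, 0.15]],
]
print(take_last(target, 0))  # [1, 2, 1], shape (3,)
print(take_last(preds, 0))   # [[0.1, 0.8, 0.1], [0.1, 0.2, 0.7], [0.002, 0.95, 0.048]], shape (3, 3)
```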

Example

>>> # Mimic R2Score with multioutput='raw_values':
>>> import torch
>>> from torchmetrics import MultioutputWrapper, R2Score
>>> target = torch.tensor([[0.5, 1], [-1, 1], [7, -6]])
>>> preds = torch.tensor([[0, 2], [-1, 2], [8, -5]])
>>> r2score = MultioutputWrapper(R2Score(), 2)
>>> r2score(preds, target)
[tensor(0.9654), tensor(0.9082)]
>>> # Classification metric where prediction and label tensors have different shapes.
>>> from torchmetrics import BinnedAveragePrecision
>>> target = torch.tensor([[1, 2], [2, 0], [1, 2]])
>>> preds = torch.tensor([
...     [[.1, .8], [.8, .05], [.1, .15]],
...     [[.1, .1], [.2, .3], [.7, .6]],
...     [[.002, .4], [.95, .45], [.048, .15]]
... ])
>>> binned_avg_precision = MultioutputWrapper(BinnedAveragePrecision(3, thresholds=5), 2)
>>> binned_avg_precision(preds, target)
[[tensor(-0.), tensor(1.0000), tensor(1.0000)], [tensor(0.3333), tensor(-0.), tensor(0.6667)]]


compute()

Compute the wrapped metric for each output and return the results as a list, one entry per output.

Return type

List[Tensor]

forward(*args, **kwargs)

Call each underlying metric's forward method on its slice of the inputs and collect the results into a list. Returns None instead if the underlying forward calls return None.

We override this method to ensure that state variables get copied over on the underlying metrics.

Return type

Any

reset()

Reset all underlying metrics.

Return type

None

update(*args, **kwargs)

Update each underlying metric with the corresponding output.

Return type

None
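The update/compute/reset contract above can be sketched with a toy wrapper. ToyMultioutputWrapper and SumSquaredError are hypothetical classes written for illustration, assuming the base metric is deep-copied once per output; they omit forward, NaN handling, and tensor support.

```python
import copy
from typing import List, Sequence


class SumSquaredError:
    """Hypothetical stand-in for a base metric with update/compute/reset."""

    def __init__(self) -> None:
        self.total = 0.0

    def update(self, preds: Sequence[float], target: Sequence[float]) -> None:
        self.total += sum((p - t) ** 2 for p, t in zip(preds, target))

    def compute(self) -> float:
        return self.total

    def reset(self) -> None:
        self.total = 0.0


class ToyMultioutputWrapper:
    """Hypothetical, simplified wrapper: one deep copy of base_metric per output."""

    def __init__(self, base_metric, num_outputs: int) -> None:
        self.metrics = [copy.deepcopy(base_metric) for _ in range(num_outputs)]

    def update(self, preds: Sequence[Sequence[float]], target: Sequence[Sequence[float]]) -> None:
        # Route column i of preds/target to the i-th underlying metric.
        for i, metric in enumerate(self.metrics):
            metric.update([row[i] for row in preds], [row[i] for row in target])

    def compute(self) -> List[float]:
        # One result per output, returned as a list; nothing is averaged.
        return [metric.compute() for metric in self.metrics]

    def reset(self) -> None:
        # Reset every underlying metric.
        for metric in self.metrics:
            metric.reset()


wrapper = ToyMultioutputWrapper(SumSquaredError(), num_outputs=2)
wrapper.update([[0, 2], [-1, 2]], [[0.5, 1], [-1, 1]])
wrapper.update([[8, -5]], [[7, -6]])
print(wrapper.compute())  # [1.25, 3.0]
wrapper.reset()
print(wrapper.compute())  # [0.0, 0.0]
```

State accumulates across update calls independently per output, and reset clears all copies at once.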

Documentation version: v0.8.1