Explained Variance
Module Interface
- class torchmetrics.ExplainedVariance(multioutput='uniform_average', **kwargs)
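Computes explained variance:

$$\text{ExplainedVariance} = 1 - \frac{\text{Var}(y - \hat{y})}{\text{Var}(y)}$$

Where $y$ is a tensor of target values, and $\hat{y}$ is a tensor of predictions.

Forward accepts:
- preds (float tensor): (N,) or (N, ...) (multioutput)
- target (long tensor): (N,) or (N, ...) (multioutput)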
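To make the formula concrete, here is a minimal, framework-free sketch of the 1-D case using only the Python standard library. The function name `explained_variance_1d` is hypothetical and not part of torchmetrics. The handling of a constant target (zero denominator) is an assumed convention, the same one scikit-learn uses, and may not match every library exactly.

```python
from statistics import pvariance
from typing import Sequence


def explained_variance_1d(preds: Sequence[float], target: Sequence[float]) -> float:
    """Hypothetical reference sketch of 1 - Var(y - y_hat) / Var(y) for 1-D inputs."""
    if len(preds) != len(target) or len(target) == 0:
        raise ValueError("preds and target must be non-empty and of equal length")
    residuals = [y - y_hat for y, y_hat in zip(target, preds)]
    # Population variance (divide by N). The 1/N factor cancels in the ratio,
    # so using the sample variance for both terms would give the same score.
    numerator = pvariance(residuals)
    denominator = pvariance(target)
    if denominator == 0.0:
        # Assumed convention for a constant target: score 1.0 only if the
        # residuals are constant too, otherwise 0.0.
        return 1.0 if numerator == 0.0 else 0.0
    return 1.0 - numerator / denominator


target = [3, -0.5, 2, 7]
preds = [2.5, 0.0, 2, 8]
print(round(explained_variance_1d(preds, target), 4))  # 0.9572
```

With these inputs, Var(y - ŷ) = 0.3125 and Var(y) = 7.296875, so the score is 1 - 0.3125 / 7.296875 ≈ 0.9572, which agrees with the module example below.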
In the multioutput case, the variances are by default uniformly averaged over the additional dimensions. See the multioutput argument to change this behavior.
- Parameters
multioutput (str) – Defines aggregation in the case of multiple output scores. Can be one of the following strings (default is 'uniform_average'):
  - 'raw_values' returns the full set of scores
  - 'uniform_average' scores are uniformly averaged
  - 'variance_weighted' scores are weighted by their individual variances
kwargs (Dict[str, Any]) – Additional keyword arguments, see Advanced metric settings for more info.
- Raises
ValueError – If multioutput is not one of "raw_values", "uniform_average" or "variance_weighted".
Example
>>> import torch
>>> from torchmetrics import ExplainedVariance
>>> target = torch.tensor([3, -0.5, 2, 7])
>>> preds = torch.tensor([2.5, 0.0, 2, 8])
>>> explained_variance = ExplainedVariance()
>>> explained_variance(preds, target)
tensor(0.9572)

>>> target = torch.tensor([[0.5, 1], [-1, 1], [7, -6]])
>>> preds = torch.tensor([[0, 2], [-1, 2], [8, -5]])
>>> explained_variance = ExplainedVariance(multioutput='raw_values')
>>> explained_variance(preds, target)
tensor([0.9677, 1.0000])
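The three multioutput options can be illustrated with a small standard-library sketch that scores each output column separately and then aggregates. The function `explained_variance_multioutput` is hypothetical. Weighting by the per-column target variance for 'variance_weighted', and the fallbacks for constant targets, are assumptions modelled on scikit-learn's behavior rather than a statement of how torchmetrics is implemented.

```python
from statistics import pvariance
from typing import List, Sequence, Union


def explained_variance_multioutput(
    preds: Sequence[Sequence[float]],
    target: Sequence[Sequence[float]],
    multioutput: str = "uniform_average",
) -> Union[float, List[float]]:
    """Hypothetical sketch: per-column explained variance plus aggregation."""
    allowed = ("raw_values", "uniform_average", "variance_weighted")
    if multioutput not in allowed:
        raise ValueError(f"multioutput must be one of {allowed}, got {multioutput!r}")
    n_outputs = len(target[0])
    scores: List[float] = []
    target_vars: List[float] = []
    for j in range(n_outputs):
        y = [row[j] for row in target]
        y_hat = [row[j] for row in preds]
        num = pvariance([a - b for a, b in zip(y, y_hat)])  # Var(y - y_hat)
        den = pvariance(y)  # Var(y), also used as the column weight
        target_vars.append(den)
        if den == 0.0:
            # Assumed convention for a constant target column.
            scores.append(1.0 if num == 0.0 else 0.0)
        else:
            scores.append(1.0 - num / den)
    if multioutput == "raw_values":
        return scores
    if multioutput == "uniform_average":
        return sum(scores) / n_outputs
    total = sum(target_vars)
    if total == 0.0:
        # Assumption: every column is constant, so fall back to a uniform average.
        return sum(scores) / n_outputs
    return sum(v / total * s for v, s in zip(target_vars, scores))


target = [[0.5, 1], [-1, 1], [7, -6]]
preds = [[0, 2], [-1, 2], [8, -5]]
print([round(s, 4) for s in explained_variance_multioutput(preds, target, "raw_values")])  # [0.9677, 1.0]
print(round(explained_variance_multioutput(preds, target, "uniform_average"), 4))  # 0.9839
print(round(explained_variance_multioutput(preds, target, "variance_weighted"), 4))  # 0.9831
```

The 'raw_values' result agrees with the module example above. The second column scores exactly 1.0 because its residuals are constant (always -1), so Var(y - ŷ) = 0 even though the predictions are offset.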
- compute()
Computes explained variance over state.
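Because the module form is fed batch by batch and only evaluated in compute(), it needs state that can be accumulated incrementally. One way to do this is to keep running sums and recover both variances at the end via Var(x) = E[x²] - E[x]². The sketch below is hypothetical: the class and attribute names are illustrative, and whether torchmetrics keeps exactly these sums is an assumption. It handles 1-D inputs only.

```python
from typing import Sequence


class StreamingExplainedVariance:
    """Hypothetical 1-D sketch: accumulate running sums, compute the score at the end."""

    def __init__(self) -> None:
        self.n_obs = 0
        self.sum_error = 0.0           # sum of (y - y_hat)
        self.sum_squared_error = 0.0   # sum of (y - y_hat)^2
        self.sum_target = 0.0          # sum of y
        self.sum_squared_target = 0.0  # sum of y^2

    def update(self, preds: Sequence[float], target: Sequence[float]) -> None:
        if len(preds) != len(target):
            raise ValueError("preds and target must have the same length")
        for y_hat, y in zip(preds, target):
            err = y - y_hat
            self.n_obs += 1
            self.sum_error += err
            self.sum_squared_error += err * err
            self.sum_target += y
            self.sum_squared_target += y * y

    def compute(self) -> float:
        if self.n_obs == 0:
            raise ValueError("compute() called before any update()")
        n = self.n_obs
        # Var(x) = E[x^2] - E[x]^2; simple, but less numerically stable than
        # a two-pass or Welford-style update for large or badly scaled data.
        var_error = self.sum_squared_error / n - (self.sum_error / n) ** 2
        var_target = self.sum_squared_target / n - (self.sum_target / n) ** 2
        if var_target == 0.0:
            # Assumed convention for a constant target.
            return 1.0 if var_error == 0.0 else 0.0
        return 1.0 - var_error / var_target


metric = StreamingExplainedVariance()
metric.update([2.5, 0.0], [3, -0.5])  # first batch
metric.update([2, 8], [2, 7])         # second batch
print(round(metric.compute(), 4))     # 0.9572
```

Splitting the first module example into two batches gives the same 0.9572 as a single call, because the running sums do not depend on how the data are batched.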
Functional Interface
- torchmetrics.functional.explained_variance(preds, target, multioutput='uniform_average')
Computes explained variance.
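- Parameters
preds (Tensor) – predicted values
target (Tensor) – ground-truth target values
multioutput (str) – Defines aggregation in the case of multiple output scores. Can be one of the following strings (default is 'uniform_average'):
  - 'raw_values' returns the full set of scores
  - 'uniform_average' scores are uniformly averaged
  - 'variance_weighted' scores are weighted by their individual variances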
Example
>>> import torch
>>> from torchmetrics.functional import explained_variance
>>> target = torch.tensor([3, -0.5, 2, 7])
>>> preds = torch.tensor([2.5, 0.0, 2, 8])
>>> explained_variance(preds, target)
tensor(0.9572)

>>> target = torch.tensor([[0.5, 1], [-1, 1], [7, -6]])
>>> preds = torch.tensor([[0, 2], [-1, 2], [8, -5]])
>>> explained_variance(preds, target, multioutput='raw_values')
tensor([0.9677, 1.0000])