Mean Squared Error (MSE)
Module Interface
- class torchmetrics.MeanSquaredError(squared=True, **kwargs)
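Computes mean squared error (MSE):

$$\text{MSE} = \frac{1}{N}\sum_{i=1}^{N}\left(y_i - \hat{y}_i\right)^2$$

Where $y$ is a tensor of target values, $\hat{y}$ is a tensor of predictions, and $N$ is the number of elements.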
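To make the formula concrete, here is a minimal pure-Python sketch of it. The helper name mse_reference is hypothetical and is not part of torchmetrics; it only averages the squared differences between targets and predictions, using the same numbers as the module example further down.

```python
def mse_reference(preds, target):
    """Hypothetical plain-Python MSE: mean of (y_i - y_hat_i)^2 over all elements."""
    if len(preds) != len(target):
        raise ValueError("preds and target must have the same length")
    if not target:
        raise ValueError("at least one element is required")
    n = len(target)
    return sum((y - y_hat) ** 2 for y_hat, y in zip(preds, target)) / n


target = [2.5, 5.0, 4.0, 8.0]
preds = [3.0, 5.0, 2.5, 7.0]
# Squared errors: 0.25, 0.0, 2.25, 1.0 -> sum 3.5 -> 3.5 / 4
print(mse_reference(preds, target))  # 0.875
```

The result matches the tensor(0.8750) returned by MeanSquaredError for the same inputs.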
As input to forward and update, the metric accepts the following input:
- preds (Tensor): Predictions from the model
- target (Tensor): Ground truth values

As output of forward and compute, the metric returns the following output:
- mean_squared_error (Tensor): A tensor with the mean squared error
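Because the module interface separates update from compute, the metric can accumulate over several batches before reporting a value. The sketch below is a hypothetical plain-Python class, not the torchmetrics implementation; the state names sum_squared_error and total are illustrative assumptions. It keeps a running sum of squared errors and an element count, so compute returns the MSE over every element seen.

```python
class StreamingMSE:
    """Hypothetical accumulator illustrating the update/compute pattern."""

    def __init__(self):
        self.reset()

    def reset(self):
        # Running state: sum of squared errors and number of elements seen.
        self.sum_squared_error = 0.0
        self.total = 0

    def update(self, preds, target):
        if len(preds) != len(target):
            raise ValueError("preds and target must have the same length")
        self.sum_squared_error += sum((y - p) ** 2 for p, y in zip(preds, target))
        self.total += len(target)

    def compute(self):
        if self.total == 0:
            raise ValueError("compute() called before any update()")
        # Divide the pooled sum by the pooled count, not by the number of batches.
        return self.sum_squared_error / self.total


metric = StreamingMSE()
metric.update([3.0, 5.0], [2.5, 5.0])  # batch 1: squared errors 0.25, 0.0
metric.update([2.5, 7.0], [4.0, 8.0])  # batch 2: squared errors 2.25, 1.0
print(metric.compute())                # 0.875, same as a single call on all data
```

Pooling sums matters when batch sizes differ: with a 1-element batch (MSE 0.25) and a 3-element batch (MSE 3.25 / 3, about 1.083), averaging the two batch MSEs gives about 0.667, while the pooled MSE over all four elements is still 0.875.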
- Parameters
squared (bool) – If True, returns the MSE value; if False, returns the RMSE value.
kwargs (Any) – Additional keyword arguments, see Advanced metric settings for more info.
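The squared switch only changes the final step: RMSE is the square root of MSE, which puts the error back into the units of the target. The following plain-Python sketch uses a hypothetical helper, mse_with_flag, that mirrors the flag's documented behavior; it is not the torchmetrics code.

```python
import math


def mse_with_flag(preds, target, squared=True):
    """Hypothetical helper: MSE if squared is True, RMSE if squared is False."""
    if len(preds) != len(target) or not target:
        raise ValueError("preds and target must be non-empty and equally long")
    mse = sum((y - p) ** 2 for p, y in zip(preds, target)) / len(target)
    # RMSE = sqrt(MSE), expressed in the same units as the target values.
    return mse if squared else math.sqrt(mse)


preds = [3.0, 5.0, 2.5, 7.0]
target = [2.5, 5.0, 4.0, 8.0]
print(mse_with_flag(preds, target))                 # 0.875
print(mse_with_flag(preds, target, squared=False))  # sqrt(0.875), about 0.9354
```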
Example
>>> import torch
>>> from torchmetrics import MeanSquaredError
>>> target = torch.tensor([2.5, 5.0, 4.0, 8.0])
>>> preds = torch.tensor([3.0, 5.0, 2.5, 7.0])
>>> mean_squared_error = MeanSquaredError()
>>> mean_squared_error(preds, target)
tensor(0.8750)
Functional Interface
- torchmetrics.functional.mean_squared_error(preds, target, squared=True)
Computes mean squared error.
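- Parameters
preds (Tensor) – Estimated values (predictions).
target (Tensor) – Ground truth values.
squared (bool) – If True, returns the MSE value; if False, returns the RMSE value.
- Return type
Tensor
- Returns
Tensor with MSE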
Example
>>> import torch
>>> from torchmetrics.functional import mean_squared_error
>>> x = torch.tensor([0., 1, 2, 3])
>>> y = torch.tensor([0., 1, 2, 2])
>>> mean_squared_error(x, y)
tensor(0.2500)