Total Variation (TV)
Module Interface
- class torchmetrics.TotalVariation(reduction='sum', **kwargs)
Computes Total Variation loss (TV).
As input to forward and update, the metric accepts the following input:
- img (Tensor): A tensor of shape (N, C, H, W) consisting of images
As output of forward and compute, the metric returns the following output:
- tv (Tensor): if reduction is 'sum' or 'mean', a scalar float tensor with the summed or averaged TV value over samples; if reduction is 'none' or None, a tensor of shape (N,) with the TV value of each sample
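A minimal, dependency-free sketch of the per-sample quantity, assuming the anisotropic (L1) definition: for each image, sum the absolute differences between vertically and horizontally adjacent pixels over all channels, i.e. TV(x) = Σ_c Σ_i Σ_j ( |x[c, i+1, j] − x[c, i, j]| + |x[c, i, j+1] − x[c, i, j]| ), counting only neighbours that exist. The helper name per_sample_tv and the nested-list input are assumptions for illustration; torchmetrics itself operates on tensors.

```python
def per_sample_tv(img):
    """Anisotropic (L1) total variation for each image in a batch.

    Hypothetical helper, not part of torchmetrics. `img` is a nested list
    with shape (N, C, H, W); the result is a list of N floats.
    """
    scores = []
    for sample in img:                  # loop over the N images
        total = 0.0
        for channel in sample:          # loop over the C channels
            h, w = len(channel), len(channel[0])
            for i in range(h):
                for j in range(w):
                    if i + 1 < h:       # difference to the pixel below
                        total += abs(channel[i + 1][j] - channel[i][j])
                    if j + 1 < w:       # difference to the pixel on the right
                        total += abs(channel[i][j + 1] - channel[i][j])
        scores.append(total)
    return scores


# One image, one channel, 2 x 3 pixels.
img = [[[[0.0, 1.0, 3.0],
         [2.0, 2.0, 2.0]]]]
print(per_sample_tv(img))  # [7.0]: vertical 2 + 1 + 1, horizontal 1 + 2
```

A constant image has a TV of zero, and every additional edge or bit of noise increases the score, which is why TV is commonly used as a smoothness penalty.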
- Parameters
reduction (Literal['mean', 'sum', 'none', None]) – a method to reduce the metric score over samples:
  - 'mean': takes the mean over samples
  - 'sum': takes the sum over samples
  - None or 'none': returns the score per sample
kwargs (Any) – Additional keyword arguments, see Advanced metric settings for more info.
- Raises
ValueError – If reduction is not one of 'sum', 'mean', 'none' or None
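The reduction options and the ValueError listed above can be summarized with a small helper operating on a plain list of per-sample scores. reduce_scores is a hypothetical name; it mirrors the documented behaviour and is not the torchmetrics implementation.

```python
def reduce_scores(scores, reduction="sum"):
    """Combine per-sample scores as the documented `reduction` options describe.

    Hypothetical helper for illustration only; not torchmetrics code.
    """
    if reduction == "sum":
        return sum(scores)
    if reduction == "mean":
        return sum(scores) / len(scores)
    if reduction is None or reduction == "none":
        return list(scores)  # one score per sample, i.e. shape (N,)
    raise ValueError(
        f"Expected reduction to be one of 'sum', 'mean', 'none' or None, got {reduction!r}"
    )


scores = [2.0, 4.0, 6.0]
print(reduce_scores(scores))          # 12.0
print(reduce_scores(scores, "mean"))  # 4.0
print(reduce_scores(scores, None))    # [2.0, 4.0, 6.0]
```

Note that the default is 'sum', so with default settings the value grows with both the batch size and the image size.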
Example
>>> import torch
>>> from torchmetrics import TotalVariation
>>> _ = torch.manual_seed(42)
>>> tv = TotalVariation()
>>> img = torch.rand(5, 3, 28, 28)
>>> tv(img)
tensor(7546.8018)
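The printed value in the example is consistent with an L1 sum over neighbouring pixels. For independent U, V drawn uniformly from [0, 1), E|U − V| = 1/3. Each 28 × 28 channel has 27 × 28 vertical and 28 × 27 horizontal neighbour pairs, so the expected sum over 5 images with 3 channels is about 7560, close to the reported 7546.8018. The script below only reproduces this back-of-the-envelope arithmetic; it does not call torchmetrics.

```python
from fractions import Fraction

n, c, h, w = 5, 3, 28, 28                        # shape used in the example
pairs_per_channel = (h - 1) * w + h * (w - 1)    # vertical + horizontal neighbour pairs
total_pairs = n * c * pairs_per_channel
expected_abs_diff = Fraction(1, 3)               # E|U - V| for U, V ~ Uniform[0, 1)
expected_tv = total_pairs * expected_abs_diff    # expected value with reduction='sum'

print(total_pairs)         # 22680
print(float(expected_tv))  # 7560.0
```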
Functional Interface
- torchmetrics.functional.total_variation(img, reduction='sum')
Computes total variation loss.
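- Parameters
img (Tensor) – A tensor of shape (N, C, H, W) consisting of images
reduction (Literal['mean', 'sum', 'none', None]) – a method to reduce the score over samples: 'mean' takes the mean, 'sum' takes the sum, and None or 'none' returns the score per sample
- Return type
Tensor
- Returns
A scalar tensor containing the total variation, or a tensor of shape (N,) with per-sample values if reduction is 'none' or None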
- Raises
ValueError – If reduction is not one of 'sum', 'mean', 'none' or None
RuntimeError – If img is not a 4D tensor
Example
>>> import torch
>>> from torchmetrics.functional import total_variation
>>> _ = torch.manual_seed(42)
>>> img = torch.rand(5, 3, 28, 28)
>>> total_variation(img)
tensor(7546.8018)
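The functional form scores a single batch per call, while the module form accumulates state over successive update calls and applies the reduction in compute. A minimal sketch, assuming the module keeps a running total and sample count as its update/compute split suggests: for 'mean', the accumulated total is divided by the number of samples seen, which differs from averaging per-batch means when batch sizes differ. RunningTV is a hypothetical name, not the torchmetrics class, and it takes precomputed per-sample scores rather than images.

```python
class RunningTV:
    """Hypothetical accumulator illustrating update/compute; not torchmetrics code."""

    def __init__(self, reduction="sum"):
        self.reduction = reduction
        self.total = 0.0  # running sum of per-sample scores
        self.count = 0    # number of samples seen so far

    def update(self, batch_scores):
        # batch_scores: per-sample TV values for one batch
        self.total += sum(batch_scores)
        self.count += len(batch_scores)

    def compute(self):
        if self.reduction == "sum":
            return self.total
        if self.reduction == "mean":
            return self.total / self.count
        raise ValueError("this sketch only supports 'sum' and 'mean'")


batch_a = [1.0, 2.0, 3.0]  # three samples
batch_b = [10.0]           # one sample

metric = RunningTV(reduction="mean")
metric.update(batch_a)
metric.update(batch_b)
print(metric.compute())  # 4.0 = (1 + 2 + 3 + 10) / 4

mean_of_batch_means = (sum(batch_a) / len(batch_a) + sum(batch_b) / len(batch_b)) / 2
print(mean_of_batch_means)  # 6.0, a different quantity
```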