Total Variation (TV)
Module Interface
- class torchmetrics.TotalVariation(reduction='sum', **kwargs)
Computes Total Variation loss (TV).
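For reference, here is a sketch of the quantity computed for each image. It assumes the anisotropic (L1) form: absolute differences between vertically and horizontally adjacent pixels, summed over channels. Here x_{c,i,j} is the pixel of channel c at row i and column j of one image of size C × H × W, and x_n is the n-th image of the batch.

```latex
\mathrm{TV}(x) = \sum_{c=1}^{C}\sum_{i=1}^{H-1}\sum_{j=1}^{W} \lvert x_{c,i+1,j} - x_{c,i,j} \rvert
               + \sum_{c=1}^{C}\sum_{i=1}^{H}\sum_{j=1}^{W-1} \lvert x_{c,i,j+1} - x_{c,i,j} \rvert

\text{reduction='sum'}: \sum_{n=1}^{N} \mathrm{TV}(x_n), \qquad
\text{reduction='mean'}: \frac{1}{N}\sum_{n=1}^{N} \mathrm{TV}(x_n)
```

For example, a single-channel 2 × 2 image with rows (0, 1) and (2, 3) has vertical differences |2 - 0| + |3 - 1| = 4 and horizontal differences |1 - 0| + |3 - 2| = 2, so its TV is 6.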
As input to forward and update the metric accepts the following input:
- img (Tensor): A tensor of shape (N, C, H, W) consisting of images.
As output of forward and compute the metric returns the following output:
- tv (Tensor): if reduction is 'sum' or 'mean', a scalar float tensor holding the sum or the mean of the per-sample TV values; if reduction is 'none' or None, a tensor of shape (N,) with the TV value of each sample.
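Both output shapes follow from computing one TV value per image. Below is a minimal NumPy sketch of that per-sample step. It assumes the anisotropic L1 definition, and total_variation_per_sample is a hypothetical helper, not part of torchmetrics.

```python
import numpy as np


def total_variation_per_sample(img: np.ndarray) -> np.ndarray:
    """Anisotropic L1 total variation of each image in an (N, C, H, W) batch.

    Hypothetical NumPy re-implementation for illustration; returns shape (N,).
    """
    if img.ndim != 4:
        # Mirrors the documented RuntimeError for non-4D input.
        raise RuntimeError(f"Expected a 4D array of shape (N, C, H, W), got shape {img.shape}")
    # Differences between vertically adjacent pixels (along H) ...
    diff_h = img[:, :, 1:, :] - img[:, :, :-1, :]
    # ... and between horizontally adjacent pixels (along W).
    diff_w = img[:, :, :, 1:] - img[:, :, :, :-1]
    # Sum the absolute differences over channels, rows and columns of each sample.
    return np.abs(diff_h).sum(axis=(1, 2, 3)) + np.abs(diff_w).sum(axis=(1, 2, 3))


img = np.arange(4, dtype=np.float64).reshape(1, 1, 2, 2)  # one image, rows (0, 1) and (2, 3)
print(total_variation_per_sample(img))  # [6.]
```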
- Parameters
reduction (Literal['mean', 'sum', 'none', None]) – A method to reduce the metric score over samples (default 'sum'):
  - 'mean': takes the mean over samples
  - 'sum': takes the sum over samples
  - None or 'none': returns the score per sample
kwargs (Any) – Additional keyword arguments, see Advanced metric settings for more info.
- Raises
ValueError – If reduction is not one of 'sum', 'mean', 'none' or None.
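The reduction options and the ValueError can be sketched as follows. The sketch operates on per-sample scores of shape (N,). reduce_scores is a hypothetical helper that mirrors the documented behavior. torchmetrics returns tensors, whereas this NumPy sketch returns a Python float for the scalar cases.

```python
from typing import Literal, Optional, Union

import numpy as np


def reduce_scores(
    scores: np.ndarray,
    reduction: Optional[Literal["mean", "sum", "none"]] = "sum",
) -> Union[float, np.ndarray]:
    """Hypothetical helper: reduce per-sample TV scores of shape (N,)."""
    if reduction == "mean":
        return float(scores.sum() / scores.shape[0])
    if reduction == "sum":
        return float(scores.sum())
    if reduction is None or reduction == "none":
        return scores  # one value per sample, shape (N,)
    raise ValueError(f"Expected reduction to be 'sum', 'mean', 'none' or None, got {reduction!r}")


scores = np.array([6.0, 2.0, 4.0])
print(reduce_scores(scores, "sum"))   # 12.0
print(reduce_scores(scores, "mean"))  # 4.0
print(reduce_scores(scores, None))    # [6. 2. 4.]
```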
Example
>>> import torch
>>> from torchmetrics import TotalVariation
>>> _ = torch.manual_seed(42)
>>> tv = TotalVariation()
>>> img = torch.rand(5, 3, 28, 28)
>>> tv(img)
tensor(7546.8018)
Functional Interface
- torchmetrics.functional.total_variation(img, reduction='sum')
Computes total variation loss.
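- Parameters
img (Tensor) – A tensor of shape (N, C, H, W) consisting of images.
reduction (Literal['mean', 'sum', 'none', None]) – A method to reduce the metric score over samples (default 'sum'): 'mean' takes the mean over samples, 'sum' takes the sum over samples, and None or 'none' returns the score per sample.
- Return type
Tensor
- Returns
A scalar tensor containing the total variation if reduction is 'sum' or 'mean', or a tensor of shape (N,) with the per-sample values if reduction is 'none' or None.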
- Raises
ValueError – If reduction is not one of 'sum', 'mean', 'none' or None.
RuntimeError – If img is not a 4D tensor.
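Because the function expects a batch, a single image of shape (C, H, W) needs a leading batch dimension first. With torch tensors, img.unsqueeze(0) does this. The following minimal NumPy sketch shows the same shape handling. as_batch is a hypothetical helper.

```python
import numpy as np


def as_batch(img: np.ndarray) -> np.ndarray:
    """Hypothetical helper: return img as an (N, C, H, W) array.

    A single (C, H, W) image becomes a batch of one; 4D input passes through.
    """
    if img.ndim == 3:
        return img[np.newaxis, ...]  # add the batch dimension, N = 1
    if img.ndim == 4:
        return img
    # Any other rank cannot be read as images, analogous to the documented RuntimeError.
    raise RuntimeError(f"Expected a 3D or 4D array, got shape {img.shape}")


single = np.zeros((3, 28, 28))
print(as_batch(single).shape)  # (1, 3, 28, 28)
```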
Example
>>> import torch
>>> from torchmetrics.functional import total_variation
>>> _ = torch.manual_seed(42)
>>> img = torch.rand(5, 3, 28, 28)
>>> total_variation(img)
tensor(7546.8018)