Spectral Distortion Index

Module Interface

class torchmetrics.image.SpectralDistortionIndex(p=1, reduction='elementwise_mean', **kwargs)

Compute Spectral Distortion Index (SpectralDistortionIndex), also known as D_lambda.

The metric is used to compare the spectral distortion between two images.
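For reference, the pansharpening literature (the QNR, "quality with no reference", protocol) usually defines D_lambda as an order-p mean, over all ordered band pairs, of the change in inter-band similarity. In the notation below, x_l and y_l are band l of preds and target, L is the number of bands, and Q is the universal image quality index. That torchmetrics follows exactly this form, with Q evaluated by its own UQI implementation, is an assumption here.

```latex
D_\lambda = \left( \frac{1}{L(L-1)} \sum_{l=1}^{L} \sum_{\substack{r=1 \\ r \neq l}}^{L}
            \bigl| Q(x_l, x_r) - Q(y_l, y_r) \bigr|^{p} \right)^{1/p},
\qquad
Q(a, b) = \frac{4\,\sigma_{ab}\,\mu_a \mu_b}{(\sigma_a^2 + \sigma_b^2)(\mu_a^2 + \mu_b^2)}
```

A value of 0 means that every band pair has the same Q similarity in preds as in target, i.e. the inter-band relationships are preserved.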

As input to forward and update, the metric accepts the following input:

• preds (Tensor): Low resolution multispectral image of shape (N,C,H,W)

• target (Tensor): High resolution fused image of shape (N,C,H,W)

As output of forward and compute, the metric returns the following output:

• sdi (Tensor): if reduction != 'none', returns a float scalar tensor with the SDI values reduced over samples (averaged by default); otherwise returns a tensor of shape (N,) with per-sample SDI values

Parameters:
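• p (int) – Order of the mean used to combine the band-pair spectral differences; larger values put more weight on large spectral differences (default: 1)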

• reduction (Literal['elementwise_mean', 'sum', 'none']) –

A method to reduce the metric score over labels:

• 'elementwise_mean': takes the mean (default)

• 'sum': takes the sum

• 'none': no reduction will be applied

• kwargs (Any) – Additional keyword arguments, see Advanced metric settings for more info.
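The three reduction options can be read as the following plain-Python sketch, assuming the metric first produces one SDI score per sample, as the description of the sdi output states. The helper reduce_scores is hypothetical; torchmetrics performs the equivalent step on tensors internally.

```python
from typing import List, Literal, Union

Reduction = Literal["elementwise_mean", "sum", "none"]


def reduce_scores(
    scores: List[float], reduction: Reduction = "elementwise_mean"
) -> Union[float, List[float]]:
    """Hypothetical helper: collapse per-sample SDI scores as the reduction option describes."""
    if reduction == "elementwise_mean":
        return sum(scores) / len(scores)  # one averaged value
    if reduction == "sum":
        return sum(scores)  # one summed value
    if reduction == "none":
        return list(scores)  # one value per sample, i.e. shape (N,)
    raise ValueError(f"Unsupported reduction: {reduction!r}")


per_sample = [0.1, 0.2, 0.3]  # made-up SDI scores for N = 3 samples
print(reduce_scores(per_sample), reduce_scores(per_sample, "sum"), reduce_scores(per_sample, "none"))
```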

Example

>>> import torch
>>> _ = torch.manual_seed(42)
>>> from torchmetrics.image import SpectralDistortionIndex
>>> preds = torch.rand([16, 3, 16, 16])
>>> target = torch.rand([16, 3, 16, 16])
>>> sdi = SpectralDistortionIndex()
>>> sdi(preds, target)
tensor(0.0234)

plot(val=None, ax=None)

Plot a single or multiple values from the metric.

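Parameters:
• val – Either a single result from calling forward or compute, or a list of these results. If no value is provided, compute is called automatically and that result is plotted.

• ax – A matplotlib axis object. If provided, the plot is added to that axis.

Returns:

Figure and Axes object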

Raises:

ModuleNotFoundError – If matplotlib is not installed

>>> # Example plotting a single value
>>> import torch
>>> _ = torch.manual_seed(42)
>>> from torchmetrics.image import SpectralDistortionIndex
>>> preds = torch.rand([16, 3, 16, 16])
>>> target = torch.rand([16, 3, 16, 16])
>>> metric = SpectralDistortionIndex()
>>> metric.update(preds, target)
>>> fig_, ax_ = metric.plot()

>>> # Example plotting multiple values
>>> import torch
>>> _ = torch.manual_seed(42)
>>> from torchmetrics.image import SpectralDistortionIndex
>>> preds = torch.rand([16, 3, 16, 16])
>>> target = torch.rand([16, 3, 16, 16])
>>> metric = SpectralDistortionIndex()
>>> values = []
>>> for _ in range(10):
...     values.append(metric(preds, target))
>>> fig_, ax_ = metric.plot(values)


Functional Interface

torchmetrics.functional.image.spectral_distortion_index(preds, target, p=1, reduction='elementwise_mean')

Calculate Spectral Distortion Index (SpectralDistortionIndex), also known as D_lambda.

The metric is used to compare the spectral distortion between two images.
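A minimal, dependency-free sketch of the D_lambda computation for a single sample may make the definition concrete. It assumes each image is given as a list of L bands, each band flattened to a list of floats, and it uses a global (whole-band) Q index. torchmetrics instead computes Q with its universal image quality index, which we believe uses a sliding Gaussian window, so the numbers will not match the library's output. The names q_index and d_lambda are hypothetical and not part of torchmetrics.

```python
from statistics import fmean
from typing import List, Sequence

Band = Sequence[float]  # one spectral band, flattened to a 1-D list of pixel values


def q_index(a: Band, b: Band) -> float:
    """Global (whole-band) universal image quality index of two equal-length bands."""
    mu_a, mu_b = fmean(a), fmean(b)
    n = len(a)
    var_a = sum((v - mu_a) ** 2 for v in a) / n
    var_b = sum((v - mu_b) ** 2 for v in b) / n
    cov = sum((u - mu_a) * (v - mu_b) for u, v in zip(a, b)) / n
    # Assumes non-constant bands with non-zero means; otherwise the denominator is 0.
    return 4 * cov * mu_a * mu_b / ((var_a + var_b) * (mu_a**2 + mu_b**2))


def d_lambda(preds: List[Band], target: List[Band], p: int = 1) -> float:
    """Hypothetical D_lambda sketch for one sample with L bands per image."""
    num_bands = len(preds)
    if num_bands != len(target) or num_bands < 2:
        raise ValueError("preds and target need the same number of bands (at least 2)")
    total = 0.0
    for i in range(num_bands):
        for j in range(num_bands):
            if i == j:
                continue
            # Similarity of band pair (i, j) in preds vs. the same pair in target
            diff = abs(q_index(preds[i], preds[j]) - q_index(target[i], target[j]))
            total += diff**p
    # Order-p mean over the L * (L - 1) ordered band pairs
    return (total / (num_bands * (num_bands - 1))) ** (1 / p)


x = [1.0, 2.0, 3.0, 4.0]
y = [2.0, 4.0, 6.0, 8.0]
print(d_lambda([x, y], [x, x]))
```

In this toy input the two bands of preds have Q(x, y) = 0.64, while the two identical bands of target have Q = 1, so each band pair contributes a difference of 0.36 and the result is about 0.36 for any p.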

Parameters:
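• preds (Tensor) – Low resolution multispectral image of shape (N,C,H,W)

• target (Tensor) – High resolution fused image of shape (N,C,H,W)

• p (int) – Order of the mean used to combine the band-pair spectral differences; larger values put more weight on large spectral differences (default: 1)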

• reduction (Literal['elementwise_mean', 'sum', 'none']) –

A method to reduce the metric score over labels:

• 'elementwise_mean': takes the mean (default)

• 'sum': takes the sum

• 'none': no reduction will be applied
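To see what p does, note that (assuming the standard D_lambda definition) the absolute differences between band-pair Q values are combined with a mean of order p. The hypothetical helper power_mean below isolates that step: with one small and one large difference, raising p moves the result away from the arithmetic mean and toward the larger value.

```python
from typing import Sequence


def power_mean(diffs: Sequence[float], p: int = 1) -> float:
    """Hypothetical helper: order-p mean of absolute differences, as in the D_lambda aggregation."""
    # bool is a subclass of int in Python, so reject it explicitly
    if isinstance(p, bool) or not isinstance(p, int) or p <= 0:
        raise ValueError("p must be a positive integer")
    return (sum(abs(d) ** p for d in diffs) / len(diffs)) ** (1 / p)


diffs = [0.1, 0.3]  # one small and one large band-pair difference
for p in (1, 2, 4):
    print(p, round(power_mean(diffs, p), 4))
```

For these differences the results are about 0.2, 0.224 and 0.253 for p = 1, 2 and 4; when all differences are equal, p has no effect.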

Return type:

Tensor

Returns:

Tensor with SpectralDistortionIndex score

Raises:
• TypeError – If preds and target don’t have the same data type.

• ValueError – If preds and target don’t have BxCxHxW shape.

• ValueError – If p is not a positive integer.
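The three documented error conditions can be summarized by the following sketch. It works on plain shape tuples and dtype names rather than tensors; the helper name check_sdi_inputs and the order of the checks are assumptions, not the torchmetrics source.

```python
from typing import Sequence


def check_sdi_inputs(
    preds_shape: Sequence[int],
    target_shape: Sequence[int],
    preds_dtype: str,
    target_dtype: str,
    p: int,
) -> None:
    """Hypothetical validation mirroring the documented Raises entries."""
    if preds_dtype != target_dtype:
        raise TypeError(f"preds and target must share a dtype, got {preds_dtype} and {target_dtype}")
    if len(preds_shape) != 4 or len(target_shape) != 4:
        raise ValueError(
            f"preds and target must be BxCxHxW, got {tuple(preds_shape)} and {tuple(target_shape)}"
        )
    # bool is a subclass of int in Python, so reject it explicitly
    if isinstance(p, bool) or not isinstance(p, int) or p <= 0:
        raise ValueError(f"p must be a positive integer, got {p!r}")


# Valid call: same dtype, 4-D shapes, positive integer p (no exception)
check_sdi_inputs((16, 3, 16, 16), (16, 3, 16, 16), "float32", "float32", 1)
```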

Example

>>> import torch
>>> from torchmetrics.functional.image import spectral_distortion_index
>>> _ = torch.manual_seed(42)
>>> preds = torch.rand([16, 3, 16, 16])
>>> target = torch.rand([16, 3, 16, 16])
>>> spectral_distortion_index(preds, target)
tensor(0.0234)


