Spatial Distortion Index
Module Interface
- class torchmetrics.image.SpatialDistortionIndex(norm_order=1, window_size=7, reduction='elementwise_mean', **kwargs)
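Compute Spatial Distortion Index (SpatialDistortionIndex), also known as D_s.
The metric measures the spatial distortion of a high resolution multispectral image (for example, a pansharpened result) with respect to the panchromatic and low resolution multispectral images it is compared against.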
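For reference, the usual D_s definition from the pansharpening quality literature (the spatial term of the QNR index) can be written as follows. The mapping to this API is an assumption: \hat{M}_l is band l of preds, \tilde{M}_l is band l of ms, P_l and \tilde{P}_l are the matching channels of pan and of its degraded low resolution version (or pan_lr), q is norm_order, L is the number of bands, and Q is the universal image quality index built from means \mu, variances \sigma^2 and covariance \sigma_{xy}.

```latex
D_s = \sqrt[q]{\frac{1}{L}\sum_{l=1}^{L}\left|Q\left(\hat{M}_l, P_l\right) - Q\left(\tilde{M}_l, \tilde{P}_l\right)\right|^{q}},
\qquad
Q(x, y) = \frac{4\,\sigma_{xy}\,\mu_x\,\mu_y}{\left(\sigma_x^2 + \sigma_y^2\right)\left(\mu_x^2 + \mu_y^2\right)}
```

A value of 0 means the relation between each band and the panchromatic image is preserved across scales; larger values indicate more spatial distortion.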
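As input to forward and update the metric accepts the following input:
- preds (Tensor): High resolution multispectral image of shape (N,C,H,W).
- target (Dict): A dictionary containing the following keys:
  - ms (Tensor): Low resolution multispectral image of shape (N,C,H',W').
  - pan (Tensor): High resolution panchromatic image of shape (N,C,H,W).
  - pan_lr (Tensor): (optional) Low resolution panchromatic image of shape (N,C,H',W').
where H and W must be multiples of H' and W'.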
As output of forward and compute the metric returns the following output:
- sdi (Tensor): if reduction!='none', returns a float scalar tensor with the average SDI value over samples, else returns a tensor of shape (N,) with SDI values per sample.
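The shape constraints above can be checked before any tensor work. The following is a minimal sketch, assuming shapes are given as plain (N, C, H, W) tuples; check_sdi_shapes is a hypothetical helper, not a torchmetrics function, and it also applies the batch, channel and spatial-size rules listed under Raises in the functional interface.

```python
from typing import Optional, Sequence, Tuple


def check_sdi_shapes(
    preds: Sequence[int],
    ms: Sequence[int],
    pan: Sequence[int],
    pan_lr: Optional[Sequence[int]] = None,
) -> Tuple[int, int]:
    """Check (N, C, H, W) shape tuples against the documented constraints.

    Hypothetical helper (not part of torchmetrics). Returns the integer
    resolution ratios (H // H', W // W').
    """
    shapes = [tuple(preds), tuple(ms), tuple(pan)]
    if pan_lr is not None:
        shapes.append(tuple(pan_lr))
    # Every input must be a 4D (N, C, H, W) shape.
    if any(len(s) != 4 for s in shapes):
        raise ValueError("every input must have shape (N, C, H, W)")
    # Batch and channel sizes must agree across all inputs.
    if len({s[:2] for s in shapes}) != 1:
        raise ValueError("every input must share the same N and C")
    # High resolution inputs share (H, W); low resolution inputs share (H', W').
    if shapes[0][2:] != shapes[2][2:]:
        raise ValueError("preds and pan must have the same H and W")
    if pan_lr is not None and shapes[1][2:] != shapes[3][2:]:
        raise ValueError("ms and pan_lr must have the same H' and W'")
    (h, w), (h_lr, w_lr) = shapes[0][2:], shapes[1][2:]
    # H and W must be integer multiples of H' and W'.
    if h % h_lr or w % w_lr:
        raise ValueError("H and W must be multiples of H' and W'")
    return h // h_lr, w // w_lr
```

With the shapes used in the examples on this page, check_sdi_shapes((16, 3, 32, 32), (16, 3, 16, 16), (16, 3, 32, 32)) returns (2, 2).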
- Parameters:
  - norm_order (int) – Order of the norm applied on the difference.
  - window_size (int) – Window size of the filter applied to degrade the high resolution panchromatic image.
  - reduction (Literal['elementwise_mean', 'sum', 'none']) – A method to reduce metric score over labels.
    - 'elementwise_mean': takes the mean (default)
    - 'sum': takes the sum
    - 'none': no reduction will be applied
  - kwargs (Any) – Additional keyword arguments, see Advanced metric settings for more info.
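How norm_order and reduction can interact is sketched below in plain Python. This assumes, following the standard D_s definition, that each per-band difference |Q_fused - Q_reference| is raised to norm_order, reduced, and then taken to the power 1 / norm_order; combine_band_differences is a hypothetical helper, the per-band Q values are taken as given, and the reduction axis here (bands) is illustrative rather than a statement about the library's internals.

```python
from typing import List, Literal, Sequence, Union


def combine_band_differences(
    q_fused: Sequence[float],
    q_reference: Sequence[float],
    norm_order: int = 1,
    reduction: Literal["elementwise_mean", "sum", "none"] = "elementwise_mean",
) -> Union[float, List[float]]:
    """Combine per-band quality values into a D_s-style score.

    Assumed form: reduce(|q_fused - q_reference| ** q) ** (1 / q).
    Hypothetical helper, not the torchmetrics implementation.
    """
    if len(q_fused) != len(q_reference) or not q_fused:
        raise ValueError("inputs must be non-empty and of equal length")
    # |difference| ** q for every band.
    diffs = [abs(a - b) ** norm_order for a, b in zip(q_fused, q_reference)]
    if reduction == "elementwise_mean":
        return (sum(diffs) / len(diffs)) ** (1 / norm_order)
    if reduction == "sum":
        return sum(diffs) ** (1 / norm_order)
    if reduction == "none":
        # No reduction: undo the power band by band.
        return [d ** (1 / norm_order) for d in diffs]
    raise ValueError(f"unknown reduction: {reduction!r}")
```

With norm_order=1 and 'elementwise_mean' this is the mean absolute difference; larger norm_order values weight the worst bands more heavily.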
Example
>>> import torch
>>> _ = torch.manual_seed(42)
>>> from torchmetrics.image import SpatialDistortionIndex
>>> preds = torch.rand([16, 3, 32, 32])
>>> target = {
...     'ms': torch.rand([16, 3, 16, 16]),
...     'pan': torch.rand([16, 3, 32, 32]),
... }
>>> sdi = SpatialDistortionIndex()
>>> sdi(preds, target)
tensor(0.0090)
- plot(val=None, ax=None)
Plot a single or multiple values from the metric.
- Parameters:
  - val (Union[Tensor, Sequence[Tensor], None]) – Either a single result from calling metric.forward or metric.compute or a list of these results. If no value is provided, will automatically call metric.compute and plot that result.
  - ax (Optional[Axes]) – A matplotlib axis object. If provided, will add the plot to that axis.
- Returns:
  Figure and Axes object
- Raises:
ModuleNotFoundError – If matplotlib is not installed
>>> # Example plotting a single value
>>> import torch
>>> _ = torch.manual_seed(42)
>>> from torchmetrics.image import SpatialDistortionIndex
>>> preds = torch.rand([16, 3, 32, 32])
>>> target = {
...     'ms': torch.rand([16, 3, 16, 16]),
...     'pan': torch.rand([16, 3, 32, 32]),
... }
>>> metric = SpatialDistortionIndex()
>>> metric.update(preds, target)
>>> fig_, ax_ = metric.plot()
>>> # Example plotting multiple values
>>> import torch
>>> _ = torch.manual_seed(42)
>>> from torchmetrics.image import SpatialDistortionIndex
>>> preds = torch.rand([16, 3, 32, 32])
>>> target = {
...     'ms': torch.rand([16, 3, 16, 16]),
...     'pan': torch.rand([16, 3, 32, 32]),
... }
>>> metric = SpatialDistortionIndex()
>>> values = []
>>> for _ in range(10):
...     values.append(metric(preds, target))
>>> fig_, ax_ = metric.plot(values)
Functional Interface
- torchmetrics.functional.image.spatial_distortion_index(preds, ms, pan, pan_lr=None, norm_order=1, window_size=7, reduction='elementwise_mean')
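Calculate Spatial Distortion Index (SpatialDistortionIndex), also known as D_s.
The metric measures the spatial distortion of a high resolution multispectral image (for example, a pansharpened result) with respect to the panchromatic and low resolution multispectral images it is compared against.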
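D_s is built from the universal image quality index Q, which combines correlation, luminance similarity and contrast similarity into one number in [-1, 1]. The following is a minimal sketch that computes Q from global statistics of two flattened bands; global_uqi is a hypothetical helper. The library's own universal_image_quality_index works on local windows (it takes kernel_size and sigma arguments), so this global version is a simplification and will not reproduce its values.

```python
from statistics import fmean
from typing import Sequence


def global_uqi(x: Sequence[float], y: Sequence[float]) -> float:
    """Universal image quality index Q(x, y) over the whole signal.

    Simplification: statistics are global, not taken over local windows.
    """
    if len(x) != len(y) or len(x) == 0:
        raise ValueError("x and y must be non-empty and of equal length")
    mean_x, mean_y = fmean(x), fmean(y)
    var_x = fmean([(a - mean_x) ** 2 for a in x])
    var_y = fmean([(b - mean_y) ** 2 for b in y])
    cov_xy = fmean([(a - mean_x) * (b - mean_y) for a, b in zip(x, y)])
    denominator = (var_x + var_y) * (mean_x ** 2 + mean_y ** 2)
    if denominator == 0:
        # Both inputs constant, or both zero-mean: Q is 0/0 and undefined.
        raise ValueError("Q is undefined for these inputs")
    return 4 * cov_xy * mean_x * mean_y / denominator
```

Identical bands give Q = 1, and the normalisation of the variances (1/n or 1/(n-1)) cancels out of the ratio. Doubling one band (y = 2x) keeps perfect correlation but lowers Q to 0.64 because luminance and contrast no longer match.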
- Parameters:
  - preds (Tensor) – High resolution multispectral image.
  - ms (Tensor) – Low resolution multispectral image.
  - pan (Tensor) – High resolution panchromatic image.
  - pan_lr (Optional[Tensor]) – Low resolution panchromatic image.
  - norm_order (int) – Order of the norm applied on the difference.
  - window_size (int) – Window size of the filter applied to degrade the high resolution panchromatic image.
  - reduction (Literal['elementwise_mean', 'sum', 'none']) – A method to reduce metric score over labels.
    - 'elementwise_mean': takes the mean (default)
    - 'sum': takes the sum
    - 'none': no reduction will be applied
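window_size sets the size of the low-pass filter used to degrade pan, which is presumably needed only when pan_lr is not supplied (an assumption based on pan_lr being optional). The following is a minimal sketch for one channel stored as nested lists, assuming a box (mean) filter with replicate borders followed by keeping every ratio-th pixel. box_filter and degrade_pan are hypothetical helpers, and the exact filter and resampling torchmetrics uses may differ.

```python
from typing import List

Image = List[List[float]]


def box_filter(img: Image, window_size: int) -> Image:
    """Mean filter over a window_size x window_size neighbourhood.

    Borders are handled by clamping coordinates (replicate padding).
    """
    h, w = len(img), len(img[0])
    r = window_size // 2
    offsets = range(-r, window_size - r)  # exactly window_size offsets
    out: Image = []
    for i in range(h):
        row = []
        for j in range(w):
            vals = [
                img[min(max(i + di, 0), h - 1)][min(max(j + dj, 0), w - 1)]
                for di in offsets
                for dj in offsets
            ]
            row.append(sum(vals) / len(vals))
        out.append(row)
    return out


def degrade_pan(pan: Image, window_size: int, ratio: int) -> Image:
    """Low-pass filter one pan channel, then keep every ratio-th pixel."""
    if window_size <= 0 or ratio <= 0:
        raise ValueError("window_size and ratio must be positive")
    smoothed = box_filter(pan, window_size)
    return [row[::ratio] for row in smoothed[::ratio]]
```

Filtering before subsampling suppresses detail that the low resolution grid cannot represent; a larger window_size gives a blurrier degraded pan.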
- Return type:
  Tensor
- Returns:
  Tensor with SpatialDistortionIndex score
- Raises:
  - TypeError – If preds, ms, pan and pan_lr do not have the same data type.
  - ValueError – If preds, ms, pan and pan_lr do not have BxCxHxW shape.
  - ValueError – If preds, ms, pan and pan_lr do not have the same batch and channel sizes.
  - ValueError – If preds and pan do not have the same spatial dimensions.
  - ValueError – If ms and pan_lr do not have the same spatial dimensions.
  - ValueError – If the spatial dimensions of preds and pan are not multiples of those of ms.
  - ValueError – If norm_order is not a positive integer.
  - ValueError – If window_size is not a positive integer.
Example
>>> import torch
>>> from torchmetrics.functional.image import spatial_distortion_index
>>> _ = torch.manual_seed(42)
>>> preds = torch.rand([16, 3, 32, 32])
>>> ms = torch.rand([16, 3, 16, 16])
>>> pan = torch.rand([16, 3, 32, 32])
>>> spatial_distortion_index(preds, ms, pan)
tensor(0.0090)