Module metrics

Base class

The base Metric class is an abstract base class that is used as the building block for all other Module metrics.

class torchmetrics.Metric(compute_on_step=True, dist_sync_on_step=False, process_group=None, dist_sync_fn=None)[source]

Base class for all metrics present in the Metrics API.

Implements add_state(), forward(), reset() and a few other things to handle distributed synchronization and per-step metric computation.

Override update() and compute() functions to implement your own metric. Use add_state() to register metric state variables which keep track of state on each call of update() and are synchronized across processes when compute() is called.
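
As a minimal sketch of a custom metric built from these hooks (the class name is hypothetical, not part of the API):

import torch
from torchmetrics import Metric

class MySum(Metric):
    def __init__(self):
        super().__init__()
        # state is synchronized across processes with a sum reduction
        self.add_state("total", default=torch.tensor(0.0), dist_reduce_fx="sum")

    def update(self, values):
        self.total += values.sum()

    def compute(self):
        return self.total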

Note

Metric state variables can either be torch.Tensors or an empty list which can be used to store torch.Tensors.

Note

Different metrics only override update() and not forward(). A call to update() is valid, but it won’t return the metric value at the current step. A call to forward() automatically calls update() and also returns the metric value at the current step.
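
For illustration, a minimal sketch using the Accuracy metric documented further down this page (values assume the default micro averaging):

>>> import torch
>>> from torchmetrics import Accuracy
>>> metric = Accuracy()
>>> metric(torch.tensor([1, 0]), torch.tensor([1, 1]))  # forward: updates state and returns the batch value
tensor(0.5000)
>>> metric.update(torch.tensor([1, 1]), torch.tensor([1, 1]))  # update: accumulates state, returns nothing
>>> metric.compute()  # metric over all batches seen so far: 3 of 4 correct
tensor(0.7500)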

Parameters
  • compute_on_step (bool) – Forward only calls update() and returns None if this is set to False. default: True

  • dist_sync_on_step (bool) – Synchronize metric state across processes at each forward() before returning the value at the step.

  • process_group (Optional[Any]) – Specify the process group on which synchronization is called. default: None (which selects the entire world)

  • dist_sync_fn (Optional[Callable]) – Callback that performs the allgather operation on the metric state. When None, DDP will be used to perform the allgather.

Initializes internal Module state, shared by both nn.Module and ScriptModule.

add_state(name, default, dist_reduce_fx=None, persistent=False)[source]

Adds metric state variable. Only used by subclasses.

Parameters
  • name (str) – The name of the state variable. The variable will then be accessible at self.name.

  • default (Union[list, Tensor]) – Default value of the state; can either be a torch.Tensor or an empty list. The state will be reset to this value when self.reset() is called.

  • dist_reduce_fx (Optional) – Function to reduce state across multiple processes in distributed mode. If the value is "sum", "mean", "cat", "min" or "max", then torch.sum, torch.mean, torch.cat, torch.min and torch.max will be used respectively, each with argument dim=0. Note that the "cat" reduction only makes sense if the state is a list, and not a tensor. A custom function can also be passed in this parameter.

  • persistent (Optional) – whether the state will be saved as part of the modules state_dict. Default is False.

Note

Setting dist_reduce_fx to None will return the metric state synchronized across different processes. However, there won’t be any reduction function applied to the synchronized metric state.

The metric states would be synced as follows:

  • If the metric state is a torch.Tensor, the synced value will be a torch.Tensor stacked across the process dimension. The original tensor state retains its dimensions, so the synchronized output will be of shape (num_process, ...).

  • If the metric state is a list, the synced value will be a list containing the combined elements from all processes.

Note

When passing a custom function to dist_reduce_fx, expect the synchronized metric state to follow the format discussed in the above note.
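
For instance, a hypothetical custom reduction that averages a tensor state over the process dimension could look like:

import torch

def mean_over_processes(state):
    # per the note above, a tensor state arrives stacked with shape (num_process, ...)
    return state.mean(dim=0)

# hypothetical usage inside a custom metric's __init__:
# self.add_state("score", default=torch.tensor(0.0), dist_reduce_fx=mean_over_processes)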

Raises
  • ValueError – If default is not a tensor or an empty list.

  • ValueError – If dist_reduce_fx is not callable or one of "mean", "sum", "cat", None.

Return type

None

clone()[source]

Make a copy of the metric.

Return type

Metric

abstract compute()[source]

Override this method to compute the final metric value from state variables synchronized across the distributed backend.

Return type

Any

double()[source]

Overrides the default method to prevent dtype casting.

Please use metric.set_dtype(dtype) instead.

Return type

Metric

float()[source]

Overrides the default method to prevent dtype casting.

Please use metric.set_dtype(dtype) instead.

Return type

Metric

forward(*args, **kwargs)[source]

Automatically calls update().

Returns the metric value over inputs if compute_on_step is True.

Return type

Any

half()[source]

Overrides the default method to prevent dtype casting.

Please use metric.set_dtype(dtype) instead.

Return type

Metric

persistent(mode=False)[source]

Method for post-init to change whether metric states should be saved to the module's state_dict.

Return type

None

reset()[source]

This method automatically resets the metric state variables to their default value.

Return type

None

set_dtype(dst_type)[source]

Special version of type for transferring all metric states to a specific dtype.

Parameters

dst_type (Union[str, torch.dtype]) – the desired type

Return type

None
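
For example (hedged usage; torch.double is one valid dtype argument):

>>> metric.set_dtype(torch.double)  # cast all metric states to float64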

state_dict(destination=None, prefix='', keep_vars=False)[source]

Returns a dictionary containing a whole state of the module.

Both parameters and persistent buffers (e.g. running averages) are included. Keys are corresponding parameter and buffer names. Parameters and buffers set to None are not included.

Returns

a dictionary containing a whole state of the module

Return type

dict

Example:

>>> module.state_dict().keys()
['bias', 'weight']

sync(dist_sync_fn=None, process_group=None, should_sync=True, distributed_available=<function jit_distributed_available>)[source]

Sync function for manually controlling when metric states should be synced across processes.

Parameters
  • dist_sync_fn (Optional[Callable]) – Function to be used to perform state synchronization

  • process_group (Optional[Any]) – Specify the process group on which synchronization is called. default: None (which selects the entire world)

  • should_sync (bool) – Whether to perform state synchronization. This only has an impact when running in a distributed setting.

  • distributed_available (Optional[Callable]) – Function to determine if we are running inside a distributed setting

Return type

None

sync_context(dist_sync_fn=None, process_group=None, should_sync=True, should_unsync=True, distributed_available=<function jit_distributed_available>)[source]

Context manager to synchronize the states between processes when running in a distributed setting and restore the local cache states after yielding.

Parameters
  • dist_sync_fn (Optional[Callable]) – Function to be used to perform state synchronization

  • process_group (Optional[Any]) – Specify the process group on which synchronization is called. default: None (which selects the entire world)

  • should_sync (bool) – Whether to perform state synchronization. This only has an impact when running in a distributed setting.

  • should_unsync (bool) – Whether to restore the cache state so that the metrics can continue to be accumulated.

  • distributed_available (Optional[Callable]) – Function to determine if we are running inside a distributed setting

Return type

Generator
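
A hedged sketch of manual synchronization in a distributed run (assumes the metric has already accumulated state on each process):

# temporarily sync states across processes, compute a global value,
# then restore the local (unsynced) states on exit
with metric.sync_context(should_sync=True, should_unsync=True):
    global_value = metric.compute()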

type(dst_type)[source]

Overrides the default method to prevent dtype casting.

Please use metric.set_dtype(dtype) instead.

Return type

Metric

unsync(should_unsync=True)[source]

Unsync function for manually controlling when metric states should be reverted to their local values.

Parameters

should_unsync (bool) – Whether to perform unsync

Return type

None

abstract update(*_, **__)[source]

Override this method to update the state variables of your metric class.

Return type

None

property device: torch.device[source]

Return the device of the metric.

Return type

device

Basic Aggregation Metrics

Torchmetrics comes with a number of metrics for aggregating basic statistics (mean, max, min, etc.) of either tensors or native Python floats.

CatMetric

class torchmetrics.CatMetric(nan_strategy='warn', compute_on_step=True, dist_sync_on_step=False, process_group=None, dist_sync_fn=None)[source]

Concatenate a stream of values.

Parameters
  • nan_strategy (Union[str, float]) – options: - 'error': a RuntimeError is raised if any nan values are encountered - 'warn': a warning is issued if any nan values are encountered and computation continues - 'ignore': all nan values are silently removed - a float: if a float is provided, any nan values will be imputed with this value

  • compute_on_step (bool) – Forward only calls update() and returns None if this is set to False. default: True

  • dist_sync_on_step (bool) – Synchronize metric state across processes at each forward() before returning the value at the step.

  • process_group (Optional[Any]) – Specify the process group on which synchronization is called. default: None (which selects the entire world)

  • dist_sync_fn (Optional[Callable]) – Callback that performs the allgather operation on the metric state. When None, DDP will be used to perform the allgather.

Raises

ValueError – If nan_strategy is not one of error, warn, ignore or a float

Example

>>> from torchmetrics import CatMetric
>>> metric = CatMetric()
>>> metric.update(1)
>>> metric.update(torch.tensor([2, 3]))
>>> metric.compute()
tensor([1., 2., 3.])

Initializes internal Module state, shared by both nn.Module and ScriptModule.

compute()[source]

Compute the aggregated value.

Return type

Tensor

update(value)[source]

Update state with data.

Parameters

value (Union[float, Tensor]) – Either a float or tensor containing data. Additional tensor dimensions will be flattened

Return type

None

MaxMetric

class torchmetrics.MaxMetric(nan_strategy='warn', compute_on_step=True, dist_sync_on_step=False, process_group=None, dist_sync_fn=None)[source]

Aggregate a stream of values into their maximum value.

Parameters
  • nan_strategy (Union[str, float]) – options: - 'error': a RuntimeError is raised if any nan values are encountered - 'warn': a warning is issued if any nan values are encountered and computation continues - 'ignore': all nan values are silently removed - a float: if a float is provided, any nan values will be imputed with this value

  • compute_on_step (bool) – Forward only calls update() and returns None if this is set to False. default: True

  • dist_sync_on_step (bool) – Synchronize metric state across processes at each forward() before returning the value at the step.

  • process_group (Optional[Any]) – Specify the process group on which synchronization is called. default: None (which selects the entire world)

  • dist_sync_fn (Optional[Callable]) – Callback that performs the allgather operation on the metric state. When None, DDP will be used to perform the allgather.

Raises

ValueError – If nan_strategy is not one of error, warn, ignore or a float

Example

>>> from torchmetrics import MaxMetric
>>> metric = MaxMetric()
>>> metric.update(1)
>>> metric.update(torch.tensor([2, 3]))
>>> metric.compute()
tensor(3.)

Initializes internal Module state, shared by both nn.Module and ScriptModule.

update(value)[source]

Update state with data.

Parameters

value (Union[float, Tensor]) – Either a float or tensor containing data. Additional tensor dimensions will be flattened

Return type

None

MeanMetric

class torchmetrics.MeanMetric(nan_strategy='warn', compute_on_step=True, dist_sync_on_step=False, process_group=None, dist_sync_fn=None)[source]

Aggregate a stream of values into their mean value.

Parameters
  • nan_strategy (Union[str, float]) – options: - 'error': a RuntimeError is raised if any nan values are encountered - 'warn': a warning is issued if any nan values are encountered and computation continues - 'ignore': all nan values are silently removed - a float: if a float is provided, any nan values will be imputed with this value

  • compute_on_step (bool) – Forward only calls update() and returns None if this is set to False. default: True

  • dist_sync_on_step (bool) – Synchronize metric state across processes at each forward() before returning the value at the step.

  • process_group (Optional[Any]) – Specify the process group on which synchronization is called. default: None (which selects the entire world)

  • dist_sync_fn (Optional[Callable]) – Callback that performs the allgather operation on the metric state. When None, DDP will be used to perform the allgather.

Raises

ValueError – If nan_strategy is not one of error, warn, ignore or a float

Example

>>> from torchmetrics import MeanMetric
>>> metric = MeanMetric()
>>> metric.update(1)
>>> metric.update(torch.tensor([2, 3]))
>>> metric.compute()
tensor(2.)

Initializes internal Module state, shared by both nn.Module and ScriptModule.

compute()[source]

Compute the aggregated value.

Return type

Tensor

update(value, weight=1.0)[source]

Update state with data.

Parameters
  • value (Union[float, Tensor]) – Either a float or tensor containing data. Additional tensor dimensions will be flattened

  • weight (Union[float, Tensor]) – Either a float or tensor containing weights for calculating the average. The shape of weight should be broadcastable with the shape of value. Defaults to 1.0, which corresponds to a simple (unweighted) average.

Return type

None
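
As a hedged illustration of the weight argument (a weighted mean over the values seen so far):

>>> from torchmetrics import MeanMetric
>>> metric = MeanMetric()
>>> metric.update(torch.tensor([2.0, 4.0]), weight=torch.tensor([3.0, 1.0]))
>>> metric.compute()  # weighted mean: (2*3 + 4*1) / (3 + 1) = 2.5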

MinMetric

class torchmetrics.MinMetric(nan_strategy='warn', compute_on_step=True, dist_sync_on_step=False, process_group=None, dist_sync_fn=None)[source]

Aggregate a stream of values into their minimum value.

Parameters
  • nan_strategy (Union[str, float]) – options: - 'error': a RuntimeError is raised if any nan values are encountered - 'warn': a warning is issued if any nan values are encountered and computation continues - 'ignore': all nan values are silently removed - a float: if a float is provided, any nan values will be imputed with this value

  • compute_on_step (bool) – Forward only calls update() and returns None if this is set to False. default: True

  • dist_sync_on_step (bool) – Synchronize metric state across processes at each forward() before returning the value at the step.

  • process_group (Optional[Any]) – Specify the process group on which synchronization is called. default: None (which selects the entire world)

  • dist_sync_fn (Optional[Callable]) – Callback that performs the allgather operation on the metric state. When None, DDP will be used to perform the allgather.

Raises

ValueError – If nan_strategy is not one of error, warn, ignore or a float

Example

>>> from torchmetrics import MinMetric
>>> metric = MinMetric()
>>> metric.update(1)
>>> metric.update(torch.tensor([2, 3]))
>>> metric.compute()
tensor(1.)

Initializes internal Module state, shared by both nn.Module and ScriptModule.

update(value)[source]

Update state with data.

Parameters

value (Union[float, Tensor]) – Either a float or tensor containing data. Additional tensor dimensions will be flattened

Return type

None

SumMetric

class torchmetrics.SumMetric(nan_strategy='warn', compute_on_step=True, dist_sync_on_step=False, process_group=None, dist_sync_fn=None)[source]

Aggregate a stream of values into their sum.

Parameters
  • nan_strategy (Union[str, float]) – options: - 'error': a RuntimeError is raised if any nan values are encountered - 'warn': a warning is issued if any nan values are encountered and computation continues - 'ignore': all nan values are silently removed - a float: if a float is provided, any nan values will be imputed with this value

  • compute_on_step (bool) – Forward only calls update() and returns None if this is set to False. default: True

  • dist_sync_on_step (bool) – Synchronize metric state across processes at each forward() before returning the value at the step.

  • process_group (Optional[Any]) – Specify the process group on which synchronization is called. default: None (which selects the entire world)

  • dist_sync_fn (Optional[Callable]) – Callback that performs the allgather operation on the metric state. When None, DDP will be used to perform the allgather.

Raises

ValueError – If nan_strategy is not one of error, warn, ignore or a float

Example

>>> from torchmetrics import SumMetric
>>> metric = SumMetric()
>>> metric.update(1)
>>> metric.update(torch.tensor([2, 3]))
>>> metric.compute()
tensor(6.)

Initializes internal Module state, shared by both nn.Module and ScriptModule.

update(value)[source]

Update state with data.

Parameters

value (Union[float, Tensor]) – Either a float or tensor containing data. Additional tensor dimensions will be flattened

Return type

None

Audio Metrics

About Audio Metrics

For the purposes of audio metrics, inputs (predictions, targets) must have the same size. If the inputs are 1D tensors the output will be a scalar. If the inputs are multi-dimensional with shape [..., time], the metric will be computed over the time dimension.

>>> import torch
>>> from torchmetrics import SNR
>>> target = torch.tensor([3.0, -0.5, 2.0, 7.0])
>>> preds = torch.tensor([2.5, 0.0, 2.0, 8.0])
>>> snr = SNR()
>>> snr_val = snr(preds, target)
>>> snr_val
tensor(16.1805)

PESQ

class torchmetrics.PESQ(fs, mode, compute_on_step=True, dist_sync_on_step=False, process_group=None, dist_sync_fn=None)[source]

This is a wrapper for the pesq package [1]. Note that input will be moved to the CPU to perform the metric calculation.

Note

Using this metric requires pesq to be installed. Either install as pip install torchmetrics[audio] or pip install pesq

Forward accepts

  • preds: shape [...,time]

  • target: shape [...,time]

Parameters
  • fs (int) – sampling frequency, should be 16000 or 8000 (Hz)

  • mode (str) – ‘wb’ (wide-band) or ‘nb’ (narrow-band)

  • keep_same_device – whether to move the pesq value to the device of preds

  • compute_on_step (bool) – Forward only calls update() and return None if this is set to False.

  • dist_sync_on_step (bool) – Synchronize metric state across processes at each forward() before returning the value at the step

  • process_group (Optional[Any]) – Specify the process group on which synchronization is called. default: None (which selects the entire world)

  • dist_sync_fn (Optional[Callable[[Tensor], Tensor]]) – Callback that performs the allgather operation on the metric state. When None, DDP will be used to perform the allgather

Raises

ModuleNotFoundError – If pesq package is not installed

Example

>>> from torchmetrics.audio import PESQ
>>> import torch
>>> g = torch.manual_seed(1)
>>> preds = torch.randn(8000)
>>> target = torch.randn(8000)
>>> nb_pesq = PESQ(8000, 'nb')
>>> nb_pesq(preds, target)
tensor(2.2076)
>>> wb_pesq = PESQ(16000, 'wb')
>>> wb_pesq(preds, target)
tensor(1.7359)

References

[1] https://github.com/ludlows/python-pesq

Initializes internal Module state, shared by both nn.Module and ScriptModule.

compute()[source]

Computes average PESQ.

Return type

Tensor

update(preds, target)[source]

Update state with predictions and targets.

Parameters
  • preds (Tensor) – Predictions from model

  • target (Tensor) – Ground truth values

Return type

None

PIT

class torchmetrics.PIT(metric_func, eval_func='max', compute_on_step=True, dist_sync_on_step=False, process_group=None, dist_sync_fn=None, **kwargs)[source]

Permutation invariant training (PIT). The PIT metric implements the Permutation Invariant Training method [1] from the speech separation field, in order to calculate audio metrics in a permutation invariant way.

Forward accepts

  • preds: shape [batch, spk, ...]

  • target: shape [batch, spk, ...]

Parameters
  • metric_func (Callable) – a metric function that accepts a batch of target and estimate, i.e. metric_func(preds[:, i, ...], target[:, j, ...]), and returns a batch of metric tensors of shape [batch]

  • eval_func (str) – the function to find the best permutation, can be ‘min’ or ‘max’, i.e. the smaller the better or the larger the better.

  • compute_on_step (bool) – Forward only calls update() and returns None if this is set to False. default: True

  • dist_sync_on_step (bool) – Synchronize metric state across processes at each forward() before returning the value at the step.

  • process_group (Optional[Any]) – Specify the process group on which synchronization is called. default: None (which selects the entire world)

  • dist_sync_fn (Optional[Callable[[Tensor], Tensor]]) – Callback that performs the allgather operation on the metric state. When None, DDP will be used to perform the allgather.

  • kwargs (Dict[str, Any]) – additional args for metric_func

Returns

average PIT metric

Example

>>> import torch
>>> from torchmetrics import PIT
>>> from torchmetrics.functional import si_snr
>>> _ = torch.manual_seed(42)
>>> preds = torch.randn(3, 2, 5) # [batch, spk, time]
>>> target = torch.randn(3, 2, 5) # [batch, spk, time]
>>> pit = PIT(si_snr, 'max')
>>> pit(preds, target)
tensor(-2.1065)

References

[1] D. Yu, M. Kolbaek, Z.-H. Tan, J. Jensen, Permutation invariant training of deep models for speaker-independent multi-talker speech separation, in: 2017 IEEE Int. Conf. Acoust. Speech Signal Process. ICASSP, IEEE, New Orleans, LA, 2017: pp. 241–245. https://doi.org/10.1109/ICASSP.2017.7952154.

Initializes internal Module state, shared by both nn.Module and ScriptModule.

compute()[source]

Computes average PIT metric.

Return type

Tensor

update(preds, target)[source]

Update state with predictions and targets.

Parameters
  • preds (Tensor) – Predictions from model

  • target (Tensor) – Ground truth values

Return type

None

SI_SDR

class torchmetrics.SI_SDR(zero_mean=False, compute_on_step=True, dist_sync_on_step=False, process_group=None, dist_sync_fn=None)[source]

Scale-invariant signal-to-distortion ratio (SI-SDR). The SI-SDR value is in general considered an overall measure of how good a source sounds.

Forward accepts

  • preds: shape [...,time]

  • target: shape [...,time]

Parameters
  • zero_mean (bool) – whether to zero-mean target and preds

  • compute_on_step (bool) – Forward only calls update() and returns None if this is set to False. default: True

  • dist_sync_on_step (bool) – Synchronize metric state across processes at each forward() before returning the value at the step.

  • process_group (Optional[Any]) – Specify the process group on which synchronization is called. default: None (which selects the entire world)

  • dist_sync_fn (Optional[Callable[[Tensor], Tensor]]) – Callback that performs the allgather operation on the metric state. When None, DDP will be used to perform the allgather.

Raises

TypeError – if target and preds have a different shape

Returns

average si-sdr value

Example

>>> import torch
>>> from torchmetrics import SI_SDR
>>> target = torch.tensor([3.0, -0.5, 2.0, 7.0])
>>> preds = torch.tensor([2.5, 0.0, 2.0, 8.0])
>>> si_sdr = SI_SDR()
>>> si_sdr_val = si_sdr(preds, target)
>>> si_sdr_val
tensor(18.4030)

References

[1] Le Roux, Jonathan, et al. “SDR half-baked or well done.” IEEE International Conference on Acoustics, Speech and Signal Processing (ICASSP) 2019.

Initializes internal Module state, shared by both nn.Module and ScriptModule.

compute()[source]

Computes average SI-SDR.

Return type

Tensor

update(preds, target)[source]

Update state with predictions and targets.

Parameters
  • preds (Tensor) – Predictions from model

  • target (Tensor) – Ground truth values

Return type

None

SI_SNR

class torchmetrics.SI_SNR(compute_on_step=True, dist_sync_on_step=False, process_group=None, dist_sync_fn=None)[source]

Scale-invariant signal-to-noise ratio (SI-SNR).

Forward accepts

  • preds: shape [...,time]

  • target: shape [...,time]

Parameters
  • compute_on_step (bool) – Forward only calls update() and returns None if this is set to False. default: True

  • dist_sync_on_step (bool) – Synchronize metric state across processes at each forward() before returning the value at the step.

  • process_group (Optional[Any]) – Specify the process group on which synchronization is called. default: None (which selects the entire world)

  • dist_sync_fn (Optional[Callable[[Tensor], Tensor]]) – Callback that performs the allgather operation on the metric state. When None, DDP will be used to perform the allgather.

Raises

TypeError – if target and preds have a different shape

Returns

average si-snr value

Example

>>> import torch
>>> from torchmetrics import SI_SNR
>>> target = torch.tensor([3.0, -0.5, 2.0, 7.0])
>>> preds = torch.tensor([2.5, 0.0, 2.0, 8.0])
>>> si_snr = SI_SNR()
>>> si_snr_val = si_snr(preds, target)
>>> si_snr_val
tensor(15.0918)

References

[1] Y. Luo and N. Mesgarani, “TaSNet: Time-Domain Audio Separation Network for Real-Time, Single-Channel Speech Separation,” 2018 IEEE International Conference on Acoustics, Speech and Signal Processing (ICASSP), 2018, pp. 696-700, doi: 10.1109/ICASSP.2018.8462116.

Initializes internal Module state, shared by both nn.Module and ScriptModule.

compute()[source]

Computes average SI-SNR.

Return type

Tensor

update(preds, target)[source]

Update state with predictions and targets.

Parameters
  • preds (Tensor) – Predictions from model

  • target (Tensor) – Ground truth values

Return type

None

SNR

class torchmetrics.SNR(zero_mean=False, compute_on_step=True, dist_sync_on_step=False, process_group=None, dist_sync_fn=None)[source]

Signal-to-noise ratio (SNR):

\text{SNR} = \frac{P_{signal}}{P_{noise}}

where P denotes the power of each signal. The SNR metric compares the level of the desired signal to the level of background noise. Therefore, a high value of SNR means that the audio is clear.

Forward accepts

  • preds: shape [..., time]

  • target: shape [..., time]

Parameters
  • zero_mean (bool) – whether to zero-mean target and preds

  • compute_on_step (bool) – Forward only calls update() and returns None if this is set to False. default: True

  • dist_sync_on_step (bool) – Synchronize metric state across processes at each forward() before returning the value at the step.

  • process_group (Optional[Any]) – Specify the process group on which synchronization is called. default: None (which selects the entire world)

  • dist_sync_fn (Optional[Callable[[Tensor], Tensor]]) – Callback that performs the allgather operation on the metric state. When None, DDP will be used to perform the allgather.

Raises

TypeError – if target and preds have a different shape

Returns

average snr value

Example

>>> import torch
>>> from torchmetrics import SNR
>>> target = torch.tensor([3.0, -0.5, 2.0, 7.0])
>>> preds = torch.tensor([2.5, 0.0, 2.0, 8.0])
>>> snr = SNR()
>>> snr_val = snr(preds, target)
>>> snr_val
tensor(16.1805)

References

[1] Le Roux, Jonathan, et al. “SDR half-baked or well done.” IEEE International Conference on Acoustics, Speech and Signal Processing (ICASSP) 2019.

Initializes internal Module state, shared by both nn.Module and ScriptModule.

compute()[source]

Computes average SNR.

Return type

Tensor

update(preds, target)[source]

Update state with predictions and targets.

Parameters
  • preds (Tensor) – Predictions from model

  • target (Tensor) – Ground truth values

Return type

None

STOI

class torchmetrics.STOI(fs, extended=False, compute_on_step=True, dist_sync_on_step=False, process_group=None, dist_sync_fn=None)[source]

STOI (Short Term Objective Intelligibility, see [2,3]), a wrapper for the pystoi package [1]. Note that input will be moved to the CPU to perform the metric calculation.

Intelligibility measure which is highly correlated with the intelligibility of degraded speech signals, e.g., due to additive noise, single-/multi-channel noise reduction, binary masking and vocoded speech as in CI simulations. The STOI measure is intrusive, i.e., a function of the clean and degraded speech signals. STOI may be a good alternative to the speech intelligibility index (SII) or the speech transmission index (STI) when you are interested in the effect of nonlinear processing of noisy speech (e.g., noise reduction or binary masking algorithms) on speech intelligibility. Description taken from Cees Taal's website (http://www.ceestaal.nl/code/).

Note

Using this metric requires pystoi to be installed. Either install as pip install torchmetrics[audio] or pip install pystoi

Forward accepts

  • preds: shape [...,time]

  • target: shape [...,time]

Parameters
  • fs (int) – sampling frequency (Hz)

  • extended (bool) – whether to use the extended STOI described in [4]

  • compute_on_step (bool) – Forward only calls update() and returns None if this is set to False. default: True

  • dist_sync_on_step (bool) – Synchronize metric state across processes at each forward() before returning the value at the step.

  • process_group (Optional[Any]) – Specify the process group on which synchronization is called. default: None (which selects the entire world)

  • dist_sync_fn (Optional[Callable[[Tensor], Tensor]]) – Callback that performs the allgather operation on the metric state. When None, DDP will be used to perform the allgather.

Returns

average STOI value

Raises

ModuleNotFoundError – If pystoi package is not installed

Example

>>> from torchmetrics.audio import STOI
>>> import torch
>>> g = torch.manual_seed(1)
>>> preds = torch.randn(8000)
>>> target = torch.randn(8000)
>>> stoi = STOI(8000, False)
>>> stoi(preds, target)
tensor(-0.0100)

References

[1] https://github.com/mpariente/pystoi

[2] C.H.Taal, R.C.Hendriks, R.Heusdens, J.Jensen ‘A Short-Time Objective Intelligibility Measure for Time-Frequency Weighted Noisy Speech’, ICASSP 2010, Texas, Dallas.

[3] C.H.Taal, R.C.Hendriks, R.Heusdens, J.Jensen ‘An Algorithm for Intelligibility Prediction of Time-Frequency Weighted Noisy Speech’, IEEE Transactions on Audio, Speech, and Language Processing, 2011.

[4] J. Jensen and C. H. Taal, ‘An Algorithm for Predicting the Intelligibility of Speech Masked by Modulated Noise Maskers’, IEEE Transactions on Audio, Speech and Language Processing, 2016.

Initializes internal Module state, shared by both nn.Module and ScriptModule.

compute()[source]

Computes average STOI.

Return type

Tensor

update(preds, target)[source]

Update state with predictions and targets.

Parameters
  • preds (Tensor) – Predictions from model

  • target (Tensor) – Ground truth values

Return type

None

Classification Metrics

Input types

For the purposes of classification metrics, inputs (predictions and targets) are split into these categories (N stands for the batch size and C for the number of classes):

| Type                                                       | preds shape | preds dtype | target shape | target dtype |
|------------------------------------------------------------|-------------|-------------|--------------|--------------|
| Binary                                                     | (N,)        | float       | (N,)         | binary*      |
| Multi-class                                                | (N,)        | int         | (N,)         | int          |
| Multi-class with logits or probabilities                   | (N, C)      | float       | (N,)         | int          |
| Multi-label                                                | (N, ...)    | float       | (N, ...)     | binary*      |
| Multi-dimensional multi-class                              | (N, ...)    | int         | (N, ...)     | int          |
| Multi-dimensional multi-class with logits or probabilities | (N, C, ...) | float       | (N, ...)     | int          |

*dtype binary means integers that are either 0 or 1

Note

All dimensions of size 1 (except N) are “squeezed out” at the beginning, so that, for example, a tensor of shape (N, 1) is treated as (N, ).

When predictions or targets are integers, it is assumed that class labels start at 0, i.e. the possible class labels are 0, 1, 2, 3, etc. Below are some examples of different input types

# Binary inputs
binary_preds  = torch.tensor([0.6, 0.1, 0.9])
binary_target = torch.tensor([1, 0, 1])

# Multi-class inputs
mc_preds  = torch.tensor([0, 2, 1])
mc_target = torch.tensor([0, 1, 2])

# Multi-class inputs with probabilities
mc_preds_probs  = torch.tensor([[0.8, 0.2, 0], [0.1, 0.2, 0.7], [0.3, 0.6, 0.1]])
mc_target_probs = torch.tensor([0, 1, 2])

# Multi-label inputs
ml_preds  = torch.tensor([[0.2, 0.8, 0.9], [0.5, 0.6, 0.1], [0.3, 0.1, 0.1]])
ml_target = torch.tensor([[0, 1, 1], [1, 0, 0], [0, 0, 0]])
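
The multi-dimensional multi-class case (hypothetical example values) has the same layout with extra trailing dimensions:

# Multi-dimensional multi-class inputs
mdmc_preds  = torch.tensor([[0, 1], [2, 0]])
mdmc_target = torch.tensor([[0, 1], [1, 0]])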

Using the multiclass parameter

In some cases, you might have inputs which appear to be (multi-dimensional) multi-class but are actually binary/multi-label - for example, if both predictions and targets are integer (binary) tensors. Or it could be the other way around: you want to treat binary/multi-label inputs as 2-class (multi-dimensional) multi-class inputs.

For these cases, the metrics where this distinction would make a difference expose the multiclass argument. Let's see how this is used with the StatScores metric as an example.

First, let’s consider the case with label predictions with 2 classes, which we want to treat as binary.

from torchmetrics.functional import stat_scores

# These inputs are supposed to be binary, but appear as multi-class
preds  = torch.tensor([0, 1, 0])
target = torch.tensor([1, 1, 0])

As you can see below, by default the inputs are treated as multi-class. We can set multiclass=False to treat the inputs as binary - which is the same as converting the predictions to float beforehand.

>>> stat_scores(preds, target, reduce='macro', num_classes=2)
tensor([[1, 1, 1, 0, 1],
        [1, 0, 1, 1, 2]])
>>> stat_scores(preds, target, reduce='macro', num_classes=1, multiclass=False)
tensor([[1, 0, 1, 1, 2]])
>>> stat_scores(preds.float(), target, reduce='macro', num_classes=1)
tensor([[1, 0, 1, 1, 2]])

Next, consider the opposite example: inputs are binary (as predictions are probabilities), but we would like to treat them as 2-class multi-class, to obtain the metric for both classes.

preds  = torch.tensor([0.2, 0.7, 0.3])
target = torch.tensor([1, 1, 0])

In this case we can set multiclass=True, to treat the inputs as multi-class.

>>> stat_scores(preds, target, reduce='macro', num_classes=1)
tensor([[1, 0, 1, 1, 2]])
>>> stat_scores(preds, target, reduce='macro', num_classes=2, multiclass=True)
tensor([[1, 1, 1, 0, 1],
        [1, 0, 1, 1, 2]])

Accuracy

class torchmetrics.Accuracy(threshold=0.5, num_classes=None, average='micro', mdmc_average='global', ignore_index=None, top_k=None, multiclass=None, subset_accuracy=False, compute_on_step=True, dist_sync_on_step=False, process_group=None, dist_sync_fn=None)[source]

Computes Accuracy:

\text{Accuracy} = \frac{1}{N}\sum_i^N 1(y_i = \hat{y}_i)

Where y is a tensor of target values, and \hat{y} is a tensor of predictions.

For multi-class and multi-dimensional multi-class data with probability or logits predictions, the parameter top_k generalizes this metric to a Top-K accuracy metric: for each sample the top-K highest probability or logit score items are considered to find the correct label.

For multi-label and multi-dimensional multi-class inputs, this metric computes the “global” accuracy by default, which counts all labels or sub-samples separately. This can be changed to subset accuracy (which requires all labels or sub-samples in the sample to be correctly predicted) by setting subset_accuracy=True.

Accepts all input types listed in Input types.

Parameters
  • num_classes (Optional[int]) – Number of classes. Necessary for 'macro', 'weighted' and None average methods.

  • threshold (float) – Threshold for transforming probability or logit predictions to binary (0,1) predictions, in the case of binary or multi-label inputs. Default value of 0.5 corresponds to input being probabilities.

  • average (str) –

    Defines the reduction that is applied. Should be one of the following:

    • 'micro' [default]: Calculate the metric globally, across all samples and classes.

    • 'macro': Calculate the metric for each class separately, and average the metrics across classes (with equal weights for each class).

    • 'weighted': Calculate the metric for each class separately, and average the metrics across classes, weighting each class by its support (tp + fn).

    • 'none' or None: Calculate the metric for each class separately, and return the metric for every class.

    • 'samples': Calculate the metric for each sample, and average the metrics across samples (with equal weights for each sample).

    Note

    What is considered a sample in the multi-dimensional multi-class case depends on the value of mdmc_average.

    Note

    If 'none' and a given class doesn’t occur in the preds or target, the value for the class will be nan.

  • mdmc_average (Optional[str]) –

    Defines how averaging is done for multi-dimensional multi-class inputs (on top of the average parameter). Should be one of the following:

    • None [default]: Should be left unchanged if your data is not multi-dimensional multi-class.

    • 'samplewise': In this case, the statistics are computed separately for each sample on the N axis, and then averaged over samples. The computation for each sample is done by treating the flattened extra axes ... (see Input types) as the N dimension within the sample, and computing the metric for the sample based on that.

    • 'global': In this case the N and ... dimensions of the inputs (see Input types) are flattened into a new N_X sample axis, i.e. the inputs are treated as if they were (N_X, C). From here on the average parameter applies as usual.

  • ignore_index (Optional[int]) – Integer specifying a target class to ignore. If given, this class index does not contribute to the returned score, regardless of reduction method. If an index is ignored, and average=None or 'none', the score for the ignored class will be returned as nan.

  • top_k (Optional[int]) –

    Number of highest probability or logit score predictions considered to find the correct label, relevant only for (multi-dimensional) multi-class inputs. The default value (None) will be interpreted as 1 for these inputs.

    Should be left at default (None) for all other types of inputs.

  • multiclass (Optional[bool]) – Used only in certain special cases, where you want to treat inputs as a different type than what they appear to be. See the parameter’s documentation section for a more detailed explanation and examples.

  • subset_accuracy (bool) –

    Whether to compute subset accuracy for multi-label and multi-dimensional multi-class inputs (has no effect for other input types).

    • For multi-label inputs, if the parameter is set to True, then all labels for each sample must be correctly predicted for the sample to count as correct. If it is set to False, then all labels are counted separately - this is equivalent to flattening inputs beforehand (i.e. preds = preds.flatten() and same for target).

    • For multi-dimensional multi-class inputs, if the parameter is set to True, then all sub-samples (on the extra axis) must be correct for the sample to be counted as correct. If it is set to False, then all sub-samples are counted separately - this is equivalent, in the case of label predictions, to flattening the inputs beforehand (i.e. preds = preds.flatten() and same for target). Note that the top_k parameter still applies in both cases, if set. A worked sketch of this flag follows the example below.

  • compute_on_step (bool) – Forward only calls update() and return None if this is set to False.

  • dist_sync_on_step (bool) – Synchronize metric state across processes at each forward() before returning the value at the step

  • process_group (Optional[Any]) – Specify the process group on which synchronization is called. default: None (which selects the entire world)

  • dist_sync_fn (Optional[Callable]) – Callback that performs the allgather operation on the metric state. When None, DDP will be used to perform the allgather

Raises
  • ValueError – If top_k is not an integer larger than 0.

  • ValueError – If average is none of "micro", "macro", "weighted", "samples", "none", None.

  • ValueError – If two different input modes are provided, e.g. using multi-label with multi-class.

  • ValueError – If top_k parameter is set for multi-label inputs.

Example

>>> import torch
>>> from torchmetrics import Accuracy
>>> target = torch.tensor([0, 1, 2, 3])
>>> preds = torch.tensor([0, 2, 1, 3])
>>> accuracy = Accuracy()
>>> accuracy(preds, target)
tensor(0.5000)
>>> target = torch.tensor([0, 1, 2])
>>> preds = torch.tensor([[0.1, 0.9, 0], [0.3, 0.1, 0.6], [0.2, 0.5, 0.3]])
>>> accuracy = Accuracy(top_k=2)
>>> accuracy(preds, target)
tensor(0.6667)
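
As a hedged sketch of the subset_accuracy flag on multi-label inputs (default threshold of 0.5):

>>> preds = torch.tensor([[0.9, 0.9], [0.9, 0.1]])
>>> target = torch.tensor([[1, 1], [0, 0]])
>>> Accuracy()(preds, target)  # labels counted separately: 3 of 4 correct
tensor(0.7500)
>>> Accuracy(subset_accuracy=True)(preds, target)  # second sample has a wrong label
tensor(0.5000)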

Initializes internal Module state, shared by both nn.Module and ScriptModule.

compute()[source]

Computes accuracy based on inputs passed in to update previously.

Return type

Tensor

update(preds, target)[source]

Update state with predictions and targets. See Input types for more information on input types.

Parameters
  • preds (Tensor) – Predictions from model (logits, probabilities, or labels)

  • target (Tensor) – Ground truth labels

Return type

None

AveragePrecision

class torchmetrics.AveragePrecision(num_classes=None, pos_label=None, average='macro', compute_on_step=True, dist_sync_on_step=False, process_group=None)[source]

Computes the average precision score, which summarises the precision-recall curve into one number. Works for both binary and multiclass problems. In the case of multiclass, the values will be calculated based on a one-vs-the-rest approach.

Forward accepts

  • preds (float tensor): (N, ...) (binary) or (N, C, ...) (multiclass) tensor with probabilities, where C is the number of classes.

  • target (long tensor): (N, ...) with integer labels

Parameters
  • num_classes (Optional[int]) – integer with number of classes. Not necessary to provide for binary problems.

  • pos_label (Optional[int]) – integer determining the positive class. Default is None, which for binary problems is translated to 1. For multiclass problems this argument should not be set, as it is iteratively changed over the range [0, num_classes-1]

  • average (Optional[str]) –

    defines the reduction that is applied in the case of multiclass and multilabel input. Should be one of the following:

    • 'macro' [default]: Calculate the metric for each class separately, and average the metrics across classes (with equal weights for each class).

    • 'micro': Calculate the metric globally, across all samples and classes. Cannot be used with multiclass input.

    • 'weighted': Calculate the metric for each class separately, and average the metrics across classes, weighting each class by its support.

    • 'none' or None: Calculate the metric for each class separately, and return the metric for every class.

  • compute_on_step (bool) – Forward only calls update() and return None if this is set to False. default: True

  • dist_sync_on_step (bool) – Synchronize metric state across processes at each forward() before returning the value at the step. default: False

  • process_group (Optional[Any]) – Specify the process group on which synchronization is called. default: None (which selects the entire world)

Example (binary case):
>>> from torchmetrics import AveragePrecision
>>> pred = torch.tensor([0, 1, 2, 3])
>>> target = torch.tensor([0, 1, 1, 1])
>>> average_precision = AveragePrecision(pos_label=1)
>>> average_precision(pred, target)
tensor(1.)
Example (multiclass case):
>>> pred = torch.tensor([[0.75, 0.05, 0.05, 0.05, 0.05],
...                      [0.05, 0.75, 0.05, 0.05, 0.05],
...                      [0.05, 0.05, 0.75, 0.05, 0.05],
...                      [0.05, 0.05, 0.05, 0.75, 0.05]])
>>> target = torch.tensor([0, 1, 3, 2])
>>> average_precision = AveragePrecision(num_classes=5, average=None)
>>> average_precision(pred, target)
[tensor(1.), tensor(1.), tensor(0.2500), tensor(0.2500), tensor(nan)]

Initializes internal Module state, shared by both nn.Module and ScriptModule.

compute()[source]

Compute the average precision score.

Return type

Union[Tensor, List[Tensor]]

Returns

tensor with average precision. If multiclass will return list of such tensors, one for each class

update(preds, target)[source]

Update state with predictions and targets.

Parameters
  • preds (Tensor) – Predictions from model

  • target (Tensor) – Ground truth values

Return type

None

AUC

class torchmetrics.AUC(reorder=False, compute_on_step=True, dist_sync_on_step=False, process_group=None, dist_sync_fn=None)[source]

Computes Area Under the Curve (AUC) using the trapezoidal rule

Forward accepts two input tensors that should be 1D and have the same number of elements

Parameters
  • reorder (bool) – AUC expects its first input to be sorted. If this is not the case, setting this argument to True will use a stable sorting algorithm to sort the input in descending order

  • compute_on_step (bool) – Forward only calls update() and return None if this is set to False.

  • dist_sync_on_step (bool) – Synchronize metric state across processes at each forward() before returning the value at the step.

  • process_group (Optional[Any]) – Specify the process group on which synchronization is called. default: None (which selects the entire world)

  • dist_sync_fn (Optional[Callable]) – Callback that performs the allgather operation on the metric state. When None, DDP will be used to perform the allgather.
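
A hedged example (the inputs are already sorted, so reorder is not needed; the value follows from the trapezoidal rule):

>>> import torch
>>> from torchmetrics import AUC
>>> x = torch.tensor([0, 1, 2, 3])
>>> y = torch.tensor([0, 1, 2, 2])
>>> auc = AUC()
>>> auc(x, y)
tensor(4.)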

Initializes internal Module state, shared by both nn.Module and ScriptModule.

compute()[source]

Computes AUC based on inputs passed in to update previously.

Return type

Tensor

update(preds, target)[source]

Update state with predictions and targets.

Parameters
  • preds (Tensor) – Predictions from model (probabilities, or labels)

  • target (Tensor) – Ground truth labels

Return type

None

AUROC

class torchmetrics.AUROC(num_classes=None, pos_label=None, average='macro', max_fpr=None, compute_on_step=True, dist_sync_on_step=False, process_group=None, dist_sync_fn=None)[source]

Compute Area Under the Receiver Operating Characteristic Curve (ROC AUC). Works for binary, multilabel and multiclass problems. In the case of multiclass, the values will be calculated based on a one-vs-the-rest approach.

Forward accepts

  • preds (float tensor): (N, ...) (binary) or (N, C, ...) (multiclass) tensor with probabilities, where C is the number of classes.

  • target (long tensor): (N, ...) or (N, C, ...) with integer labels

For non-binary input, if the preds and target tensors have the same size the input will be interpreted as multilabel, and if preds has one dimension more than the target tensor the input will be interpreted as multiclass.

Note

If either the positive class or the negative class is completely missing in the target tensor, the AUROC score is meaningless in this case and a score of 0 will be returned together with a warning.

Parameters
  • num_classes (Optional[int]) – integer with number of classes for multi-label and multiclass problems. Should be set to None for binary problems

  • pos_label (Optional[int]) – integer determining the positive class. Default is None, which for binary problems is translated to 1. For multiclass problems this argument should not be set, as it is iteratively changed over the range [0, num_classes-1]

  • average (Optional[str]) –

    • 'micro' computes metric globally. Only works for multilabel problems

    • 'macro' computes metric for each class and uniformly averages them

    • 'weighted' computes metric for each class and does a weighted-average, where each class is weighted by their support (accounts for class imbalance)

    • None computes and returns the metric per class

  • max_fpr (Optional[float]) – If not None, calculates standardized partial AUC over the range [0, max_fpr]. Should be a float between 0 and 1.

  • compute_on_step (bool) – Forward only calls update() and return None if this is set to False. default: True

  • dist_sync_on_step (bool) – Synchronize metric state across processes at each forward() before returning the value at the step.

  • process_group (Optional[Any]) – Specify the process group on which synchronization is called. default: None (which selects the entire world)

  • dist_sync_fn (Optional[Callable]) – Callback that performs the allgather operation on the metric state. When None, DDP will be used to perform the allgather

Raises
  • ValueError – If average is none of None, "macro" or "weighted".

  • ValueError – If max_fpr is not a float in the range (0, 1].

  • RuntimeError – If PyTorch version is below 1.6 since max_fpr requires torch.bucketize which is not available below 1.6.

  • ValueError – If the mode of data (binary, multi-label, multi-class) changes between batches.

Example (binary case):
>>> from torchmetrics import AUROC
>>> preds = torch.tensor([0.13, 0.26, 0.08, 0.19, 0.34])
>>> target = torch.tensor([0, 0, 1, 1, 1])
>>> auroc = AUROC(pos_label=1)
>>> auroc(preds, target)
tensor(0.5000)
Example (multiclass case):
>>> preds = torch.tensor([[0.90, 0.05, 0.05],
...                       [0.05, 0.90, 0.05],
...                       [0.05, 0.05, 0.90],
...                       [0.85, 0.05, 0.10],
...                       [0.10, 0.10, 0.80]])
>>> target = torch.tensor([0, 1, 1, 2, 2])
>>> auroc = AUROC(num_classes=3)
>>> auroc(preds, target)
tensor(0.7778)

Initializes internal Module state, shared by both nn.Module and ScriptModule.

compute()[source]

Computes AUROC based on inputs passed in to update previously.

Return type

Tensor

update(preds, target)[source]

Update state with predictions and targets.

Parameters
  • preds (Tensor) – Predictions from model (probabilities, or labels)

  • target (Tensor) – Ground truth labels

Return type

None

BinnedAveragePrecision

class torchmetrics.BinnedAveragePrecision(num_classes, thresholds=None, compute_on_step=True, dist_sync_on_step=False, process_group=None)[source]

Computes the average precision score, which summarises the precision-recall curve into one number. Works for both binary and multiclass problems. In the case of multiclass, the values will be calculated based on a one-vs-the-rest approach.

Computation is performed in constant memory by computing precision and recall for a fixed number of threshold buckets (evenly distributed between 0 and 1).

Forward accepts

  • preds (float tensor): (N, ...) (binary) or (N, C, ...) (multiclass) tensor with probabilities, where C is the number of classes.

  • target (long tensor): (N, ...) with integer labels

Parameters
  • num_classes (int) – integer with number of classes. Not necessary to provide for binary problems.

  • thresholds (Union[int, Tensor, List[float], None]) – list or tensor with specific thresholds, or a number of bins to sample linearly between 0 and 1. More bins will lead to a more detailed curve and more accurate estimates, but will be slower and consume more memory

  • compute_on_step (bool) – Forward only calls update() and return None if this is set to False. default: True

  • process_group (Optional[Any]) – Specify the process group on which synchronization is called. default: None (which selects the entire world)

Raises

ValueError – If thresholds is not an int, list or tensor

Example (binary case):
>>> from torchmetrics import BinnedAveragePrecision
>>> pred = torch.tensor([0, 1, 2, 3])
>>> target = torch.tensor([0, 1, 1, 1])
>>> average_precision = BinnedAveragePrecision(num_classes=1, thresholds=10)
>>> average_precision(pred, target)
tensor(1.0000)
Example (multiclass case):
>>> pred = torch.tensor([[0.75, 0.05, 0.05, 0.05, 0.05],
...                      [0.05, 0.75, 0.05, 0.05, 0.05],
...                      [0.05, 0.05, 0.75, 0.05, 0.05],
...                      [0.05, 0.05, 0.05, 0.75, 0.05]])
>>> target = torch.tensor([0, 1, 3, 2])
>>> average_precision = BinnedAveragePrecision(num_classes=5, thresholds=10)
>>> average_precision(pred, target)
[tensor(1.0000), tensor(1.0000), tensor(0.2500), tensor(0.2500), tensor(-0.)]

Initializes internal Module state, shared by both nn.Module and ScriptModule.

compute()[source]

Returns float tensor of size n_classes.

Return type

Union[List[Tensor], Tensor]

BinnedPrecisionRecallCurve

class torchmetrics.BinnedPrecisionRecallCurve(num_classes, thresholds=None, compute_on_step=True, dist_sync_on_step=False, process_group=None)[source]

Computes precision-recall pairs for different thresholds. Works for both binary and multiclass problems. In the case of multiclass, the values will be calculated based on a one-vs-the-rest approach.

Computation is performed in constant memory by computing precision and recall for a fixed number of threshold buckets (evenly distributed between 0 and 1).

Forward accepts

  • preds (float tensor): (N, ...) (binary) or (N, C, ...) (multiclass) tensor with probabilities, where C is the number of classes.

  • target (long tensor): (N, ...) or (N, C, ...) with integer labels

Parameters
  • num_classes (int) – integer with number of classes. For binary, set to 1.

  • thresholds (Union[int, Tensor, List[float], None]) – list or tensor with specific thresholds, or a number of bins to sample linearly between 0 and 1. More bins will lead to a more detailed curve and more accurate estimates, but will be slower and consume more memory.

  • compute_on_step (bool) – Forward only calls update() and return None if this is set to False. default: True

  • dist_sync_on_step (bool) – Synchronize metric state across processes at each forward() before returning the value at the step. default: False

  • process_group (Optional[Any]) – Specify the process group on which synchronization is called. default: None (which selects the entire world)

Raises

ValueError – If thresholds is not an int, list or tensor

Example (binary case):
>>> from torchmetrics import BinnedPrecisionRecallCurve
>>> pred = torch.tensor([0, 0.1, 0.8, 0.4])
>>> target = torch.tensor([0, 1, 1, 0])
>>> pr_curve = BinnedPrecisionRecallCurve(num_classes=1, thresholds=5)
>>> precision, recall, thresholds = pr_curve(pred, target)
>>> precision
tensor([0.5000, 0.5000, 1.0000, 1.0000, 1.0000, 1.0000])
>>> recall
tensor([1.0000, 0.5000, 0.5000, 0.5000, 0.0000, 0.0000])
>>> thresholds
tensor([0.0000, 0.2500, 0.5000, 0.7500, 1.0000])
Example (multiclass case):
>>> pred = torch.tensor([[0.75, 0.05, 0.05, 0.05, 0.05],
...                      [0.05, 0.75, 0.05, 0.05, 0.05],
...                      [0.05, 0.05, 0.75, 0.05, 0.05],
...                      [0.05, 0.05, 0.05, 0.75, 0.05]])
>>> target = torch.tensor([0, 1, 3, 2])
>>> pr_curve = BinnedPrecisionRecallCurve(num_classes=5, thresholds=3)
>>> precision, recall, thresholds = pr_curve(pred, target)
>>> precision   
[tensor([0.2500, 1.0000, 1.0000, 1.0000]),
tensor([0.2500, 1.0000, 1.0000, 1.0000]),
tensor([2.5000e-01, 1.0000e-06, 1.0000e+00, 1.0000e+00]),
tensor([2.5000e-01, 1.0000e-06, 1.0000e+00, 1.0000e+00]),
tensor([2.5000e-07, 1.0000e+00, 1.0000e+00, 1.0000e+00])]
>>> recall   
[tensor([1.0000, 1.0000, 0.0000, 0.0000]),
tensor([1.0000, 1.0000, 0.0000, 0.0000]),
tensor([1.0000, 0.0000, 0.0000, 0.0000]),
tensor([1.0000, 0.0000, 0.0000, 0.0000]),
tensor([0., 0., 0., 0.])]
>>> thresholds   
[tensor([0.0000, 0.5000, 1.0000]),
tensor([0.0000, 0.5000, 1.0000]),
tensor([0.0000, 0.5000, 1.0000]),
tensor([0.0000, 0.5000, 1.0000]),
tensor([0.0000, 0.5000, 1.0000])]

Initializes internal Module state, shared by both nn.Module and ScriptModule.

compute()[source]

Returns float tensor of size n_classes.

Return type

Union[Tuple[Tensor, Tensor, Tensor], Tuple[List[Tensor], List[Tensor], List[Tensor]]]

update(preds, target)[source]

Update state with predictions and targets.

Parameters
  • preds (Tensor) – tensor of shape (n_samples, n_classes)

  • target (Tensor) – tensor of shape (n_samples, n_classes)

Return type

None

BinnedRecallAtFixedPrecision

class torchmetrics.BinnedRecallAtFixedPrecision(num_classes, min_precision, thresholds=None, compute_on_step=True, dist_sync_on_step=False, process_group=None)[source]

Computes the highest possible recall value given the minimum precision threshold provided.

Computation is performed in constant memory by computing precision and recall for a fixed number of threshold buckets (evenly distributed between 0 and 1).

Forward accepts

  • preds (float tensor): (N, ...) (binary) or (N, C, ...) (multiclass) tensor with probabilities, where C is the number of classes.

  • target (long tensor): (N, ...) with integer labels

Parameters
  • num_classes (int) – integer with number of classes. Provide 1 for binary problems.

  • min_precision (float) – float value specifying minimum precision threshold.

  • thresholds (Union[int, Tensor, List[float], None]) – list or tensor with specific thresholds, or a number of bins to sample linearly between 0 and 1. More bins will lead to a more detailed curve and more accurate estimates, but will be slower and consume more memory

  • compute_on_step (bool) – Forward only calls update() and return None if this is set to False. default: True

  • process_group (Optional[Any]) – Specify the process group on which synchronization is called. default: None (which selects the entire world)

Raises

ValueError – If thresholds is not an int, list or tensor

Example (binary case):
>>> from torchmetrics import BinnedRecallAtFixedPrecision
>>> pred = torch.tensor([0, 0.2, 0.5, 0.8])
>>> target = torch.tensor([0, 1, 1, 0])
>>> average_precision = BinnedRecallAtFixedPrecision(num_classes=1, thresholds=10, min_precision=0.5)
>>> average_precision(pred, target)
(tensor(1.0000), tensor(0.1111))
Example (multiclass case):
>>> pred = torch.tensor([[0.75, 0.05, 0.05, 0.05, 0.05],
...                      [0.05, 0.75, 0.05, 0.05, 0.05],
...                      [0.05, 0.05, 0.75, 0.05, 0.05],
...                      [0.05, 0.05, 0.05, 0.75, 0.05]])
>>> target = torch.tensor([0, 1, 3, 2])
>>> recall_at_precision = BinnedRecallAtFixedPrecision(num_classes=5, thresholds=10, min_precision=0.5)
>>> recall_at_precision(pred, target)
(tensor([1.0000, 1.0000, 0.0000, 0.0000, 0.0000]),
tensor([6.6667e-01, 6.6667e-01, 1.0000e+06, 1.0000e+06, 1.0000e+06]))

Initializes internal Module state, shared by both nn.Module and ScriptModule.

compute()[source]

Returns a tuple of two tensors of size n_classes: the highest recall per class meeting the precision requirement, and the corresponding thresholds.

Return type

Tuple[Tensor, Tensor]

CalibrationError

class torchmetrics.CalibrationError(n_bins=15, norm='l1', compute_on_step=False, dist_sync_on_step=False, process_group=None)[source]

Computes the Top-label Calibration Error. Three different norms are implemented, each corresponding to a variation on the calibration error metric.

L1 norm (Expected Calibration Error)

\text{ECE} = \frac{1}{N}\sum_i^N \|p_i - c_i\|

Infinity norm (Maximum Calibration Error)

\text{MCE} = \max_{i} (p_i - c_i)

L2 norm (Root Mean Square Calibration Error)

\text{RMSCE} = \sqrt{\frac{1}{N}\sum_i^N (p_i - c_i)^2}

Where p_i is the top-1 prediction accuracy in bin i and c_i is the average confidence of predictions in bin i.

Note

L2-norm debiasing is not yet supported.

Parameters
  • n_bins (int) – Number of bins to use when computing probabilities and accuracies.

  • norm (str) – Norm used to compare empirical and expected probability bins. Defaults to “l1”, or Expected Calibration Error.

  • debias – Applies debiasing term, only implemented for l2 norm. Defaults to True.

  • compute_on_step (bool) – Forward only calls update() and returns None if this is set to False.

  • dist_sync_on_step (bool) – Synchronize metric state across processes at each forward() before returning the value at the step

  • process_group (Optional[Any]) – Specify the process group on which synchronization is called. default: None (which selects the entire world)
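
Since the class ships without a usage example, here is a minimal sketch of the binned L1 computation the formulas above describe, weighting each bin's |p_i - c_i| gap by the fraction of samples it contains (the standard binned form of the ECE). The function name and binning details are illustrative assumptions, not the library's internals.

import torch

def binned_ece(confidences, correct, n_bins=15):
    # Sketch of Expected Calibration Error with equal-width confidence bins.
    # confidences: max predicted probability per sample, shape (N,)
    # correct: bool tensor, True where the top-1 prediction matches the target
    edges = torch.linspace(0, 1, n_bins + 1)
    ece = torch.zeros(())
    for lo, hi in zip(edges[:-1], edges[1:]):
        in_bin = (confidences > lo) & (confidences <= hi)
        if in_bin.any():
            c_i = confidences[in_bin].mean()        # average confidence in bin
            p_i = correct[in_bin].float().mean()    # top-1 accuracy in bin
            ece += in_bin.float().mean() * (p_i - c_i).abs()
    return ece

For multiclass probabilities preds of shape (N, C), confidences would be preds.max(dim=-1).values and correct would be preds.argmax(dim=-1) == target.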

Initializes internal Module state, shared by both nn.Module and ScriptModule.

compute()[source]

Computes calibration error across all confidences and accuracies.

Returns

Calibration error across previously collected examples.

Return type

Tensor

update(preds, target)[source]

Computes top-level confidences and accuracies for the input probabilities and appends them to internal state.

Parameters
  • preds (Tensor) – Model output probabilities.

  • target (Tensor) – Ground-truth target class labels.

Return type

None

CohenKappa

class torchmetrics.CohenKappa(num_classes, weights=None, threshold=0.5, compute_on_step=True, dist_sync_on_step=False, process_group=None)[source]

Calculates Cohen’s kappa score that measures inter-annotator agreement. It is defined as

\kappa = (p_o - p_e) / (1 - p_e)

where p_o is the empirical probability of agreement and p_e is the expected agreement when both annotators assign labels randomly. Note that p_e is estimated using a per-annotator empirical prior over the class labels.

Works with binary, multiclass, and multilabel data. Accepts probabilities from a model output or integer class values in prediction. Works with multi-dimensional preds and target.

Forward accepts
  • preds (float or long tensor): (N, ...) or (N, C, ...) where C is the number of classes

  • target (long tensor): (N, ...)

If preds and target are the same shape and preds is a float tensor, we use the self.threshold argument to convert into integer labels. This is the case for binary and multi-label probabilities or logits.

If preds has an extra dimension as in the case of multi-class scores we perform an argmax on dim=1.

Parameters
  • num_classes (int) – Number of classes in the dataset.

  • weights (Optional[str]) – Weighting type to calculate the score. Choose from:

    • None or 'none': no weighting

    • 'linear': linear weighting

    • 'quadratic': quadratic weighting

  • threshold (float) – Threshold for transforming probability or logit predictions to binary (0,1) predictions, in the case of binary or multi-label inputs. Default value of 0.5 corresponds to input being probabilities.

  • compute_on_step (bool) – Forward only calls update() and returns None if this is set to False. default: True

  • dist_sync_on_step (bool) – Synchronize metric state across processes at each forward() before returning the value at the step. default: False

  • process_group (Optional[Any]) – Specify the process group on which synchronization is called. default: None (which selects the entire world)

Example

>>> from torchmetrics import CohenKappa
>>> target = torch.tensor([1, 1, 0, 0])
>>> preds = torch.tensor([0, 1, 0, 0])
>>> cohenkappa = CohenKappa(num_classes=2)
>>> cohenkappa(preds, target)
tensor(0.5000)
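
A quick hand-check of this example against the formula (a sketch, not the library's implementation): the observed agreement p_o is 3/4 and the chance agreement p_e, from the empirical class priors of both annotators, is 0.5.

import torch

target = torch.tensor([1, 1, 0, 0])
preds = torch.tensor([0, 1, 0, 0])

p_o = (preds == target).float().mean()        # observed agreement: 0.75
p_e = sum(                                    # chance agreement from priors
    (target == c).float().mean() * (preds == c).float().mean()
    for c in (0, 1)
)                                             # 0.5 * 0.75 + 0.5 * 0.25 = 0.5
kappa = (p_o - p_e) / (1 - p_e)               # (0.75 - 0.5) / 0.5 = 0.5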

Initializes internal Module state, shared by both nn.Module and ScriptModule.

compute()[source]

Computes cohen kappa score.

Return type

Tensor

update(preds, target)[source]

Update state with predictions and targets.

Parameters
  • preds (Tensor) – Predictions from model

  • target (Tensor) – Ground truth values

Return type

None

ConfusionMatrix

class torchmetrics.ConfusionMatrix(num_classes, normalize=None, threshold=0.5, multilabel=False, compute_on_step=True, dist_sync_on_step=False, process_group=None)[source]

Computes the confusion matrix. Works with binary, multiclass, and multilabel data. Accepts probabilities or logits from a model output or integer class values in prediction. Works with multi-dimensional preds and target, but it should be noted that additional dimensions will be flattened.

Forward accepts

  • preds (float or long tensor): (N, ...) or (N, C, ...) where C is the number of classes

  • target (long tensor): (N, ...)

If preds and target are the same shape and preds is a float tensor, we use the self.threshold argument to convert into integer labels. This is the case for binary and multi-label probabilities or logits.

If preds has an extra dimension as in the case of multi-class scores we perform an argmax on dim=1.

If working with multilabel data, setting the multilabel argument to True will make sure that a confusion matrix gets calculated per label.

Parameters
  • num_classes (int) – Number of classes in the dataset.

  • normalize (Optional[str]) –

    Normalization mode for confusion matrix. Choose from

    • None or 'none': no normalization (default)

    • 'true': normalization over the targets (most commonly used)

    • 'pred': normalization over the predictions

    • 'all': normalization over the whole matrix

  • threshold (float) – Threshold for transforming probability or logit predictions to binary (0,1) predictions, in the case of binary or multi-label inputs. Default value of 0.5 corresponds to input being probabilities.

  • multilabel (bool) – determines if data is multilabel or not.

  • compute_on_step (bool) – Forward only calls update() and returns None if this is set to False. default: True

  • dist_sync_on_step (bool) – Synchronize metric state across processes at each forward() before returning the value at the step. default: False

  • process_group (Optional[Any]) – Specify the process group on which synchronization is called. default: None (which selects the entire world)

Example (binary data):
>>> from torchmetrics import ConfusionMatrix
>>> target = torch.tensor([1, 1, 0, 0])
>>> preds = torch.tensor([0, 1, 0, 0])
>>> confmat = ConfusionMatrix(num_classes=2)
>>> confmat(preds, target)
tensor([[2., 0.],
        [1., 1.]])
Example (multiclass data):
>>> target = torch.tensor([2, 1, 0, 0])
>>> preds = torch.tensor([2, 1, 0, 1])
>>> confmat = ConfusionMatrix(num_classes=3)
>>> confmat(preds, target)
tensor([[1., 1., 0.],
        [0., 1., 0.],
        [0., 0., 1.]])
Example (multilabel data):
>>> target = torch.tensor([[0, 1, 0], [1, 0, 1]])
>>> preds = torch.tensor([[0, 0, 1], [1, 0, 1]])
>>> confmat = ConfusionMatrix(num_classes=3, multilabel=True)
>>> confmat(preds, target)  
tensor([[[1., 0.], [0., 1.]],
        [[1., 0.], [1., 0.]],
        [[0., 1.], [0., 1.]]])
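
To make the normalize options concrete, here is a hedged sketch reusing the multiclass example with normalize='true', under which each row (one per target class) is divided by its own support so that rows sum to 1:

import torch
from torchmetrics import ConfusionMatrix

target = torch.tensor([2, 1, 0, 0])
preds = torch.tensor([2, 1, 0, 1])
confmat = ConfusionMatrix(num_classes=3, normalize='true')
confmat(preds, target)
# expected (up to formatting): [[0.5, 0.5, 0.0],
#                               [0.0, 1.0, 0.0],
#                               [0.0, 0.0, 1.0]]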

Initializes internal Module state, shared by both nn.Module and ScriptModule.

compute()[source]

Computes confusion matrix.

Return type

Tensor

Returns

If multilabel=False this will be a [n_classes, n_classes] tensor and if multilabel=True this will be a [n_classes, 2, 2] tensor

update(preds, target)[source]

Update state with predictions and targets.

Parameters
  • preds (Tensor) – Predictions from model

  • target (Tensor) – Ground truth values

Return type

None

F1

class torchmetrics.F1(num_classes=None, threshold=0.5, average='micro', mdmc_average=None, ignore_index=None, top_k=None, multiclass=None, compute_on_step=True, dist_sync_on_step=False, process_group=None, dist_sync_fn=None)[source]

Computes the F1 metric, which is the harmonic mean of the precision and recall scores.

Works with binary, multiclass, and multilabel data. Accepts logits or probabilities from a model output or integer class values in prediction. Works with multi-dimensional preds and target.

Forward accepts

  • preds (float or long tensor): (N, ...) or (N, C, ...) where C is the number of classes

  • target (long tensor): (N, ...)

If preds and target are the same shape and preds is a float tensor, we use the self.threshold argument to convert into integer labels. This is the case for binary and multi-label probabilities or logits.

If preds has an extra dimension as in the case of multi-class scores we perform an argmax on dim=1.

Parameters
  • num_classes (Optional[int]) – Number of classes. Necessary for 'macro', 'weighted' and None average methods.

  • threshold (float) – Threshold for transforming probability or logit predictions to binary (0,1) predictions, in the case of binary or multi-label inputs. Default value of 0.5 corresponds to input being probabilities.

  • average (str) –

    Defines the reduction that is applied. Should be one of the following:

    • 'micro' [default]: Calculate the metric globally, across all samples and classes.

    • 'macro': Calculate the metric for each class separately, and average the metrics across classes (with equal weights for each class).

    • 'weighted': Calculate the metric for each class separately, and average the metrics across classes, weighting each class by its support (tp + fn).

    • 'none' or None: Calculate the metric for each class separately, and return the metric for every class.

    • 'samples': Calculate the metric for each sample, and average the metrics across samples (with equal weights for each sample).

    Note

    What is considered a sample in the multi-dimensional multi-class case depends on the value of mdmc_average.

  • mdmc_average (Optional[str]) –

    Defines how averaging is done for multi-dimensional multi-class inputs (on top of the average parameter). Should be one of the following:

    • None [default]: Should be left unchanged if your data is not multi-dimensional multi-class.

    • 'samplewise': In this case, the statistics are computed separately for each sample on the N axis, and then averaged over samples. The computation for each sample is done by treating the flattened extra axes ... (see Input types) as the N dimension within the sample, and computing the metric for the sample based on that.

    • 'global': In this case the N and ... dimensions of the inputs (see Input types) are flattened into a new N_X sample axis, i.e. the inputs are treated as if they were (N_X, C). From here on the average parameter applies as usual.

  • ignore_index (Optional[int]) – Integer specifying a target class to ignore. If given, this class index does not contribute to the returned score, regardless of reduction method. If an index is ignored, and average=None or 'none', the score for the ignored class will be returned as nan.

  • top_k (Optional[int]) –

    Number of highest probability or logit score predictions considered to find the correct label, relevant only for (multi-dimensional) multi-class inputs. The default value (None) will be interpreted as 1 for these inputs.

    Should be left at default (None) for all other types of inputs.

  • multiclass (Optional[bool]) – Used only in certain special cases, where you want to treat inputs as a different type than what they appear to be. See the parameter’s documentation section for a more detailed explanation and examples.

  • compute_on_step (bool) – Forward only calls update() and returns None if this is set to False.

  • dist_sync_on_step (bool) – Synchronize metric state across processes at each forward() before returning the value at the step

  • process_group (Optional[Any]) – Specify the process group on which synchronization is called. default: None (which selects the entire world)

  • dist_sync_fn (Optional[Callable]) – Callback that performs the allgather operation on the metric state. When None, DDP will be used to perform the allgather.

Example

>>> from torchmetrics import F1
>>> target = torch.tensor([0, 1, 2, 0, 1, 2])
>>> preds = torch.tensor([0, 2, 1, 0, 0, 1])
>>> f1 = F1(num_classes=3)
>>> f1(preds, target)
tensor(0.3333)
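
As a sanity check on this example (a sketch, not the API): for single-label multiclass inputs, micro-averaged F1 reduces to plain accuracy, since the globally aggregated precision and recall are both the fraction of correct predictions.

import torch

target = torch.tensor([0, 1, 2, 0, 1, 2])
preds = torch.tensor([0, 2, 1, 0, 0, 1])
(preds == target).float().mean()  # 2/6 = tensor(0.3333), matching the F1 above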

Initializes internal Module state, shared by both nn.Module and ScriptModule.

FBeta

class torchmetrics.FBeta(num_classes=None, beta=1.0, threshold=0.5, average='micro', mdmc_average=None, ignore_index=None, top_k=None, multiclass=None, compute_on_step=True, dist_sync_on_step=False, process_group=None, dist_sync_fn=None)[source]

Computes F-score, specifically:

F_\beta = (1 + \beta^2) * \frac{\text{precision} * \text{recall}}{(\beta^2 * \text{precision}) + \text{recall}}

Where \beta is some positive real factor. Works with binary, multiclass, and multilabel data. Accepts logit scores or probabilities from a model output or integer class values in prediction. Works with multi-dimensional preds and target.

Forward accepts

  • preds (float or long tensor): (N, ...) or (N, C, ...) where C is the number of classes

  • target (long tensor): (N, ...)

If preds and target are the same shape and preds is a float tensor, we use the self.threshold argument to convert into integer labels. This is the case for binary and multi-label logits and probabilities.

If preds has an extra dimension as in the case of multi-class scores we perform an argmax on dim=1.

Parameters
  • num_classes (Optional[int]) – Number of classes. Necessary for 'macro', 'weighted' and None average methods.

  • beta (float) – Beta coefficient in the F measure.

  • threshold (float) – Threshold for transforming probability or logit predictions to binary (0,1) predictions, in the case of binary or multi-label inputs. Default value of 0.5 corresponds to input being probabilities.

  • average (str) –

    Defines the reduction that is applied. Should be one of the following:

    • 'micro' [default]: Calculate the metric globally, across all samples and classes.

    • 'macro': Calculate the metric for each class separately, and average the metrics across classes (with equal weights for each class).

    • 'weighted': Calculate the metric for each class separately, and average the metrics across classes, weighting each class by its support (tp + fn).

    • 'none' or None: Calculate the metric for each class separately, and return the metric for every class.

    • 'samples': Calculate the metric for each sample, and average the metrics across samples (with equal weights for each sample).

    Note

    What is considered a sample in the multi-dimensional multi-class case depends on the value of mdmc_average.

    Note

    If 'none' and a given class doesn’t occur in the preds or target, the value for the class will be nan.

  • mdmc_average (Optional[str]) –

    Defines how averaging is done for multi-dimensional multi-class inputs (on top of the average parameter). Should be one of the following:

    • None [default]: Should be left unchanged if your data is not multi-dimensional multi-class.

    • 'samplewise': In this case, the statistics are computed separately for each sample on the N axis, and then averaged over samples. The computation for each sample is done by treating the flattened extra axes ... (see Input types) as the N dimension within the sample, and computing the metric for the sample based on that.

    • 'global': In this case the N and ... dimensions of the inputs (see Input types) are flattened into a new N_X sample axis, i.e. the inputs are treated as if they were (N_X, C). From here on the average parameter applies as usual.

  • ignore_index (Optional[int]) – Integer specifying a target class to ignore. If given, this class index does not contribute to the returned score, regardless of reduction method. If an index is ignored, and average=None or 'none', the score for the ignored class will be returned as nan.

  • top_k (Optional[int]) –

    Number of highest probability or logit score predictions considered to find the correct label, relevant only for (multi-dimensional) multi-class inputs. The default value (None) will be interpreted as 1 for these inputs.

    Should be left at default (None) for all other types of inputs.

  • multiclass (Optional[bool]) – Used only in certain special cases, where you want to treat inputs as a different type than what they appear to be. See the parameter’s documentation section for a more detailed explanation and examples.

  • compute_on_step (bool) – Forward only calls update() and returns None if this is set to False.

  • dist_sync_on_step (bool) – Synchronize metric state across processes at each forward() before returning the value at the step

  • process_group (Optional[Any]) – Specify the process group on which synchronization is called. default: None (which selects the entire world)

  • dist_sync_fn (Optional[Callable]) – Callback that performs the allgather operation on the metric state. When None, DDP will be used to perform the allgather.

Raises

ValueError – If average is none of "micro", "macro", "weighted", "none", None.

Example

>>> from torchmetrics import FBeta
>>> target = torch.tensor([0, 1, 2, 0, 1, 2])
>>> preds = torch.tensor([0, 2, 1, 0, 0, 1])
>>> f_beta = FBeta(num_classes=3, beta=0.5)
>>> f_beta(preds, target)
tensor(0.3333)

Initializes internal Module state, shared by both nn.Module and ScriptModule.

compute()[source]

Computes fbeta over state.

Return type

Tensor

HammingDistance

class torchmetrics.HammingDistance(threshold=0.5, compute_on_step=True, dist_sync_on_step=False, process_group=None, dist_sync_fn=None)[source]

Computes the average Hamming distance (also known as Hamming loss) between targets and predictions:

\text{Hamming distance} = \frac{1}{N \cdot L}\sum_i^N \sum_l^L 1(y_{il} \neq \hat{y}_{il})

Where y is a tensor of target values, \hat{y} is a tensor of predictions, and \bullet_{il} refers to the l-th label of the i-th sample of that tensor.

This is the same as 1-accuracy for binary data, while for all other types of inputs it treats each possible label separately - meaning that, for example, multi-class data is treated as if it were multi-label.

Accepts all input types listed in Input types.

Parameters
  • threshold (float) – Threshold for transforming probability or logit predictions to binary (0,1) predictions, in the case of binary or multi-label inputs. Default value of 0.5 corresponds to input being probabilities.

  • compute_on_step (bool) – Forward only calls update() and returns None if this is set to False.

  • dist_sync_on_step (bool) – Synchronize metric state across processes at each forward() before returning the value at the step.

  • process_group (Optional[Any]) – Specify the process group on which synchronization is called. default: None (which selects the entire world)

  • dist_sync_fn (Optional[Callable]) – Callback that performs the allgather operation on the metric state. When None, DDP will be used to perform the all gather.

Raises

ValueError – If threshold is not between 0 and 1.

Example

>>> from torchmetrics import HammingDistance
>>> target = torch.tensor([[0, 1], [1, 1]])
>>> preds = torch.tensor([[0, 1], [0, 1]])
>>> hamming_distance = HammingDistance()
>>> hamming_distance(preds, target)
tensor(0.2500)
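
The example can be verified directly against the formula (a sketch): exactly 1 of the 4 label entries disagrees, giving 1/4.

import torch

target = torch.tensor([[0, 1], [1, 1]])
preds = torch.tensor([[0, 1], [0, 1]])
(preds != target).float().mean()  # tensor(0.2500)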

Initializes internal Module state, shared by both nn.Module and ScriptModule.

compute()[source]

Computes hamming distance based on inputs passed in to update previously.

Return type

Tensor

update(preds, target)[source]

Update state with predictions and targets. See Input types for more information on input types.

Parameters
  • preds (Tensor) – Predictions from model (probabilities, logits or labels)

  • target (Tensor) – Ground truth labels

Return type

None

Hinge

class torchmetrics.Hinge(squared=False, multiclass_mode=None, compute_on_step=True, dist_sync_on_step=False, process_group=None, dist_sync_fn=None)[source]

Computes the mean Hinge loss, typically used for Support Vector Machines (SVMs). In the binary case it is defined as:

\text{Hinge loss} = \max(0, 1 - y \times \hat{y})

Where y \in {-1, 1} is the target, and \hat{y} \in \mathbb{R} is the prediction.

In the multi-class case, when multiclass_mode=None (default), multiclass_mode=MulticlassMode.CRAMMER_SINGER or multiclass_mode="crammer-singer", this metric will compute the multi-class hinge loss defined by Crammer and Singer as:

\text{Hinge loss} = \max\left(0, 1 - \hat{y}_y + \max_{i \ne y} (\hat{y}_i)\right)

Where y \in {0, ..., \mathrm{C}} is the target class (where \mathrm{C} is the number of classes), and \hat{y} \in \mathbb{R}^\mathrm{C} is the predicted output per class.

In the multi-class case when multiclass_mode=MulticlassMode.ONE_VS_ALL or multiclass_mode='one-vs-all', this metric will use a one-vs-all approach to compute the hinge loss, giving a vector of C outputs where each entry pits that class against all remaining classes.

This metric can optionally output the mean of the squared hinge loss by setting squared=True.

Only accepts inputs with preds shape of (N) (binary) or (N, C) (multi-class) and target shape of (N).

Parameters
  • squared (bool) – If True, this will compute the squared hinge loss. Otherwise, computes the regular hinge loss (default).

  • multiclass_mode (Union[str, MulticlassMode, None]) – Which approach to use for multi-class inputs (has no effect in the binary case). None (default), MulticlassMode.CRAMMER_SINGER or "crammer-singer", uses the Crammer Singer multi-class hinge loss. MulticlassMode.ONE_VS_ALL or "one-vs-all" computes the hinge loss in a one-vs-all fashion.

Raises

ValueError – If multiclass_mode is not: None, MulticlassMode.CRAMMER_SINGER, "crammer-singer", MulticlassMode.ONE_VS_ALL or "one-vs-all".

Example (binary case):
>>> import torch
>>> from torchmetrics import Hinge
>>> target = torch.tensor([0, 1, 1])
>>> preds = torch.tensor([-2.2, 2.4, 0.1])
>>> hinge = Hinge()
>>> hinge(preds, target)
tensor(0.3000)
Example (default / multiclass case):
>>> target = torch.tensor([0, 1, 2])
>>> preds = torch.tensor([[-1.0, 0.9, 0.2], [0.5, -1.1, 0.8], [2.2, -0.5, 0.3]])
>>> hinge = Hinge()
>>> hinge(preds, target)
tensor(2.9000)
Example (multiclass example, one vs all mode):
>>> target = torch.tensor([0, 1, 2])
>>> preds = torch.tensor([[-1.0, 0.9, 0.2], [0.5, -1.1, 0.8], [2.2, -0.5, 0.3]])
>>> hinge = Hinge(multiclass_mode="one-vs-all")
>>> hinge(preds, target)
tensor([2.2333, 1.5000, 1.2333])
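
Hand-checking the binary example against the definition (a sketch, not the library internals): the {0, 1} targets are mapped to {-1, 1} and the per-sample losses \max(0, 1 - y \times \hat{y}) are averaged.

import torch

target = torch.tensor([0, 1, 1])
preds = torch.tensor([-2.2, 2.4, 0.1])
y = 2 * target - 1                        # map {0, 1} -> {-1, 1}
torch.clamp(1 - y * preds, min=0).mean()  # (0 + 0 + 0.9) / 3 = tensor(0.3000)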

Initializes internal Module state, shared by both nn.Module and ScriptModule.

compute()[source]

Computes the mean (squared) hinge loss over the state accumulated by update().

Return type

Tensor

update(preds, target)[source]

Updates the metric state with predictions and targets.

Return type

None

IoU

class torchmetrics.IoU(num_classes, ignore_index=None, absent_score=0.0, threshold=0.5, reduction='elementwise_mean', compute_on_step=True, dist_sync_on_step=False, process_group=None)[source]

Computes Intersection over union, or Jaccard index:

J(A,B) = \frac{|A\cap B|}{|A\cup B|}

Where: A and B are both tensors of the same size, containing integer class values. They may be subject to conversion from input data (see description below). Note that it is different from box IoU.

Works with binary, multiclass and multi-label data. Accepts probabilities from a model output or integer class values in prediction. Works with multi-dimensional preds and target.

Forward accepts

  • preds (float or long tensor): (N, ...) or (N, C, ...) where C is the number of classes

  • target (long tensor): (N, ...)

If preds and target are the same shape and preds is a float tensor, we use the self.threshold argument to convert into integer labels. This is the case for binary and multi-label probabilities.

If preds has an extra dimension as in the case of multi-class scores we perform an argmax on dim=1.

Parameters
  • num_classes (int) – Number of classes in the dataset.

  • ignore_index (Optional[int]) – optional int specifying a target class to ignore. If given, this class index does not contribute to the returned score, regardless of reduction method. Has no effect if given an int that is not in the range [0, num_classes-1]. By default, no index is ignored, and all classes are used.

  • absent_score (float) – score to use for an individual class, if no instances of the class index were present in pred AND no instances of the class index were present in target. For example, if we have 3 classes, [0, 0] for pred, and [0, 2] for target, then class 1 would be assigned the absent_score.

  • threshold (float) – Threshold value for binary or multi-label probabilities.

  • reduction (str) –

    a method to reduce metric score over labels.

    • 'elementwise_mean': takes the mean (default)

    • 'sum': takes the sum

    • 'none': no reduction will be applied

  • compute_on_step (bool) – Forward only calls update() and returns None if this is set to False.

  • dist_sync_on_step (bool) – Synchronize metric state across processes at each forward() before returning the value at the step.

  • process_group (Optional[Any]) – Specify the process group on which synchronization is called. default: None (which selects the entire world)

Example

>>> from torchmetrics import IoU
>>> target = torch.randint(0, 2, (10, 25, 25))
>>> pred = target.clone()
>>> pred[2:5, 7:13, 9:15] = 1 - pred[2:5, 7:13, 9:15]
>>> iou = IoU(num_classes=2)
>>> iou(pred, target)
tensor(0.9660)
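
A minimal sketch of the Jaccard index for a single class, matching the formula above (illustrative, not the library's internals): intersection over union of the binary masks.

import torch

pred = torch.tensor([0, 1, 1, 1])
target = torch.tensor([0, 1, 0, 1])
intersection = ((pred == 1) & (target == 1)).sum()  # |A ∩ B| = 2
union = ((pred == 1) | (target == 1)).sum()         # |A ∪ B| = 3
intersection / union                                 # 0.6667 for class 1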

Initializes internal Module state, shared by both nn.Module and ScriptModule.

compute()[source]

Computes intersection over union (IoU)

Return type

Tensor

KLDivergence

class torchmetrics.KLDivergence(log_prob=False, reduction='mean', compute_on_step=True, dist_sync_on_step=False, process_group=None, dist_sync_fn=None)[source]

Computes the KL divergence:

D_{KL}(P||Q) = \sum_{x\in\mathcal{X}} P(x) \log\frac{P(x)}{Q(x)}

Where P and Q are probability distributions, with P usually representing a distribution over data and Q often being a prior or an approximation of P. It should be noted that the KL divergence is a non-symmetric metric, i.e. D_{KL}(P||Q) \neq D_{KL}(Q||P).

Parameters
  • p – data distribution with shape [N, d]

  • q – prior or approximate distribution with shape [N, d]

  • log_prob (bool) – bool indicating if input is log-probabilities or probabilities. If given as probabilities, will normalize to make sure the distributions sum to 1

  • reduction (Optional[str]) –

    Determines how to reduce over the N/batch dimension:

    • 'mean' [default]: Averages score across samples

    • 'sum': Sum score across samples

    • 'none' or None: Returns score per sample

Raises
  • TypeError – If log_prob is not a bool

  • ValueError – If reduction is not one of 'mean', 'sum', 'none' or None

Note

Half precision is only supported on GPU for this metric.

Example

>>> import torch
>>> from torchmetrics.functional import kl_divergence
>>> p = torch.tensor([[0.36, 0.48, 0.16]])
>>> q = torch.tensor([[1/3, 1/3, 1/3]])
>>> kl_divergence(p, q)
tensor(0.0853)
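
The example above uses the functional form; a module-based sketch with the same inputs should give the same value (assuming the default reduction='mean'):

import torch
from torchmetrics import KLDivergence

p = torch.tensor([[0.36, 0.48, 0.16]])
q = torch.tensor([[1/3, 1/3, 1/3]])
kl = KLDivergence()
kl(p, q)  # tensor(0.0853)
# same value by hand: (p * torch.log(p / q)).sum()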

Initializes internal Module state, shared by both nn.Module and ScriptModule.

compute()[source]

Computes the KL divergence over the state accumulated by update().

Return type

Tensor

update(p, q)[source]

Updates the metric state with a batch of the distributions p and q.

Return type

None

MatthewsCorrcoef

class torchmetrics.MatthewsCorrcoef(num_classes, threshold=0.5, compute_on_step=True, dist_sync_on_step=False, process_group=None, dist_sync_fn=None)[source]

Calculates the Matthews correlation coefficient, which measures the general correlation or quality of a classification. In the binary case it is defined as:

MCC = \frac{TP*TN - FP*FN}{\sqrt{(TP+FP)*(TP+FN)*(TN+FP)*(TN+FN)}}

where TP, TN, FP and FN are respectively the true positives, true negatives, false positives and false negatives. Also works in the case of multi-label or multi-class input.

Forward accepts

  • preds (float or long tensor): (N, ...) or (N, C, ...) where C is the number of classes

  • target (long tensor): (N, ...)

If preds and target are the same shape and preds is a float tensor, we use the self.threshold argument to convert into integer labels. This is the case for binary and multi-label probabilities.

If preds has an extra dimension as in the case of multi-class scores we perform an argmax on dim=1.

Parameters
  • num_classes (int) – Number of classes in the dataset.

  • threshold (float) – Threshold value for binary or multi-label probabilites. default: 0.5

  • compute_on_step (bool) – Forward only calls update() and returns None if this is set to False. default: True

  • dist_sync_on_step (bool) – Synchronize metric state across processes at each forward() before returning the value at the step. default: False

  • process_group (Optional[Any]) – Specify the process group on which synchronization is called. default: None (which selects the entire world)

  • dist_sync_fn (Optional[Callable]) – Callback that performs the allgather operation on the metric state. When None, DDP will be used to perform the allgather

Example

>>> from torchmetrics import MatthewsCorrcoef
>>> target = torch.tensor([1, 1, 0, 0])
>>> preds = torch.tensor([0, 1, 0, 0])
>>> matthews_corrcoef = MatthewsCorrcoef(num_classes=2)
>>> matthews_corrcoef(preds, target)
tensor(0.5774)
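
Hand-checking the example against the binary formula (a sketch): it has tp=1, tn=2, fp=0, fn=1.

# MCC = (tp*tn - fp*fn) / sqrt((tp+fp)(tp+fn)(tn+fp)(tn+fn))
tp, tn, fp, fn = 1.0, 2.0, 0.0, 1.0
mcc = (tp * tn - fp * fn) / ((tp + fp) * (tp + fn) * (tn + fp) * (tn + fn)) ** 0.5
# 2 / sqrt(12) = 0.5774, matching tensor(0.5774) above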

Initializes internal Module state, shared by both nn.Module and ScriptModule.

compute()[source]

Computes matthews correlation coefficient.

Return type

Tensor

update(preds, target)[source]

Update state with predictions and targets.

Parameters
  • preds (Tensor) – Predictions from model

  • target (Tensor) – Ground truth values

Return type

None

Precision

class torchmetrics.Precision(num_classes=None, threshold=0.5, average='micro', mdmc_average=None, ignore_index=None, top_k=None, multiclass=None, compute_on_step=True, dist_sync_on_step=False, process_group=None, dist_sync_fn=None)[source]

Computes Precision:

\text{Precision} = \frac{\text{TP}}{\text{TP} + \text{FP}}

Where \text{TP} and \text{FP} represent the number of true positives and false positives respectively. With the use of the top_k parameter, this metric can generalize to Precision@K.

The reduction method (how the precision scores are aggregated) is controlled by the average parameter, and additionally by the mdmc_average parameter in the multi-dimensional multi-class case. Accepts all inputs listed in Input types.

Parameters
  • num_classes (Optional[int]) – Number of classes. Necessary for 'macro', 'weighted' and None average methods.

  • threshold (float) – Threshold for transforming probability or logit predictions to binary (0,1) predictions, in the case of binary or multi-label inputs. Default value of 0.5 corresponds to input being probabilities.

  • average (str) –

    Defines the reduction that is applied. Should be one of the following:

    • 'micro' [default]: Calculate the metric globally, across all samples and classes.

    • 'macro': Calculate the metric for each class separately, and average the metrics across classes (with equal weights for each class).

    • 'weighted': Calculate the metric for each class separately, and average the metrics across classes, weighting each class by its support (tp + fn).

    • 'none' or None: Calculate the metric for each class separately, and return the metric for every class.

    • 'samples': Calculate the metric for each sample, and average the metrics across samples (with equal weights for each sample).

    Note

    What is considered a sample in the multi-dimensional multi-class case depends on the value of mdmc_average.

  • mdmc_average (Optional[str]) –

    Defines how averaging is done for multi-dimensional multi-class inputs (on top of the average parameter). Should be one of the following:

    • None [default]: Should be left unchanged if your data is not multi-dimensional multi-class.

    • 'samplewise': In this case, the statistics are computed separately for each sample on the N axis, and then averaged over samples. The computation for each sample is done by treating the flattened extra axes ... (see Input types) as the N dimension within the sample, and computing the metric for the sample based on that.

    • 'global': In this case the N and ... dimensions of the inputs (see Input types) are flattened into a new N_X sample axis, i.e. the inputs are treated as if they were (N_X, C). From here on the average parameter applies as usual.

  • ignore_index (Optional[int]) – Integer specifying a target class to ignore. If given, this class index does not contribute to the returned score, regardless of reduction method. If an index is ignored, and average=None or 'none', the score for the ignored class will be returned as nan.

  • top_k (Optional[int]) –

    Number of highest probability or logit score predictions considered to find the correct label, relevant only for (multi-dimensional) multi-class inputs. The default value (None) will be interpreted as 1 for these inputs.

    Should be left at default (None) for all other types of inputs.

  • multiclass (Optional[bool]) – Used only in certain special cases, where you want to treat inputs as a different type than what they appear to be. See the parameter’s documentation section for a more detailed explanation and examples.

  • compute_on_step (bool) – Forward only calls update() and returns None if this is set to False.

  • dist_sync_on_step (bool) – Synchronize metric state across processes at each forward() before returning the value at the step

  • process_group (Optional[Any]) – Specify the process group on which synchronization is called. default: None (which selects the entire world)

  • dist_sync_fn (Optional[Callable]) – Callback that performs the allgather operation on the metric state. When None, DDP will be used to perform the allgather.

Raises

ValueError – If average is none of "micro", "macro", "weighted", "samples", "none", None.

Example

>>> from torchmetrics import Precision
>>> preds  = torch.tensor([2, 0, 2, 1])
>>> target = torch.tensor([1, 1, 2, 0])
>>> precision = Precision(average='macro', num_classes=3)
>>> precision(preds, target)
tensor(0.1667)
>>> precision = Precision(average='micro')
>>> precision(preds, target)
tensor(0.2500)
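
To make the macro example concrete, here is a hedged hand computation (a sketch, not the library internals): per-class precision is [0, 0, 0.5], whose unweighted mean is 0.1667.

import torch

preds = torch.tensor([2, 0, 2, 1])
target = torch.tensor([1, 1, 2, 0])
per_class = torch.stack([
    ((preds == c) & (target == c)).sum() / (preds == c).sum().clamp(min=1)
    for c in range(3)
])                 # tensor([0.0000, 0.0000, 0.5000])
per_class.mean()   # tensor(0.1667)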

Initializes internal Module state, shared by both nn.Module and ScriptModule.

compute()[source]

Computes the precision score based on inputs passed in to update previously.

Return type

Tensor

Returns

The shape of the returned tensor depends on the average parameter

  • If average in ['micro', 'macro', 'weighted', 'samples'], a one-element tensor will be returned

  • If average in ['none', None], the shape will be (C,), where C stands for the number of classes

PrecisionRecallCurve

class torchmetrics.PrecisionRecallCurve(num_classes=None, pos_label=None, compute_on_step=True, dist_sync_on_step=False, process_group=None)[source]

Computes precision-recall pairs for different thresholds. Works for both binary and multiclass problems. In the case of multiclass, the values will be calculated based on a one-vs-the-rest approach.

Forward accepts

  • preds (float tensor): (N, ...) (binary) or (N, C, ...) (multiclass) tensor with probabilities, where C is the number of classes.

  • target (long tensor): (N, ...) or (N, C, ...) with integer labels

Parameters
  • num_classes (Optional[int]) – integer with number of classes for multi-label and multiclass problems. Should be set to None for binary problems

  • pos_label (Optional[int]) – integer determining the positive class. Default is None, which for binary problems is translated to 1. For multiclass problems this argument should not be set, as it is iteratively changed over the range [0, num_classes-1]

  • compute_on_step (bool) – Forward only calls update() and returns None if this is set to False. default: True

  • dist_sync_on_step (bool) – Synchronize metric state across processes at each forward() before returning the value at the step. default: False

  • process_group (Optional[Any]) – Specify the process group on which synchronization is called. default: None (which selects the entire world)

Example (binary case):
>>> from torchmetrics import PrecisionRecallCurve
>>> pred = torch.tensor([0, 1, 2, 3])
>>> target = torch.tensor([0, 1, 1, 0])
>>> pr_curve = PrecisionRecallCurve(pos_label=1)
>>> precision, recall, thresholds = pr_curve(pred, target)
>>> precision
tensor([0.6667, 0.5000, 0.0000, 1.0000])
>>> recall
tensor([1.0000, 0.5000, 0.0000, 0.0000])
>>> thresholds
tensor([1, 2, 3])
Example (multiclass case):
>>> pred = torch.tensor([[0.75, 0.05, 0.05, 0.05, 0.05],
...                      [0.05, 0.75, 0.05, 0.05, 0.05],
...                      [0.05, 0.05, 0.75, 0.05, 0.05],
...                      [0.05, 0.05, 0.05, 0.75, 0.05]])
>>> target = torch.tensor([0, 1, 3, 2])
>>> pr_curve = PrecisionRecallCurve(num_classes=5)
>>> precision, recall, thresholds = pr_curve(pred, target)
>>> precision   
[tensor([1., 1.]), tensor([1., 1.]), tensor([0.2500, 0.0000, 1.0000]),
 tensor([0.2500, 0.0000, 1.0000]), tensor([0., 1.])]
>>> recall
[tensor([1., 0.]), tensor([1., 0.]), tensor([1., 0., 0.]), tensor([1., 0., 0.]), tensor([nan, 0.])]
>>> thresholds
[tensor([0.7500]), tensor([0.7500]), tensor([0.0500, 0.7500]), tensor([0.0500, 0.7500]), tensor([0.0500])]

Initializes internal Module state, shared by both nn.Module and ScriptModule.

compute()[source]

Compute the precision-recall curve.

Return type

Union[Tuple[Tensor, Tensor, Tensor], Tuple[List[Tensor], List[Tensor], List[Tensor]]]

Returns

3-element tuple containing

precision:

tensor where element i is the precision of predictions with score >= thresholds[i] and the last element is 1. If multiclass, this is a list of such tensors, one for each class.

recall:

tensor where element i is the recall of predictions with score >= thresholds[i] and the last element is 0. If multiclass, this is a list of such tensors, one for each class.

thresholds:

Thresholds used for computing precision/recall scores

update(preds, target)[source]

Update state with predictions and targets.

Parameters
  • preds (Tensor) – Predictions from model

  • target (Tensor) – Ground truth values

Return type

None

Recall

class torchmetrics.Recall(num_classes=None, threshold=0.5, average='micro', mdmc_average=None, ignore_index=None, top_k=None, multiclass=None, compute_on_step=True, dist_sync_on_step=False, process_group=None, dist_sync_fn=None)[source]

Computes Recall:

\text{Recall} = \frac{\text{TP}}{\text{TP} + \text{FN}}

Where \text{TP} and \text{FN} represent the number of true positives and false negatives respectively. With the use of the top_k parameter, this metric can generalize to Recall@K.

The reduction method (how the recall scores are aggregated) is controlled by the average parameter, and additionally by the mdmc_average parameter in the multi-dimensional multi-class case. Accepts all inputs listed in Input types.

Parameters
  • num_classes (Optional[int]) – Number of classes. Necessary for 'macro', 'weighted' and None average methods.

  • threshold (float) – Threshold for transforming probability or logit predictions to binary (0,1) predictions, in the case of binary or multi-label inputs. Default value of 0.5 corresponds to input being probabilities.

  • average (str) –

    Defines the reduction that is applied. Should be one of the following:

    • 'micro' [default]: Calculate the metric globally, across all samples and classes.

    • 'macro': Calculate the metric for each class separately, and average the metrics across classes (with equal weights for each class).

    • 'weighted': Calculate the metric for each class separately, and average the metrics across classes, weighting each class by its support (tp + fn).

    • 'none' or None: Calculate the metric for each class separately, and return the metric for every class.

    • 'samples': Calculate the metric for each sample, and average the metrics across samples (with equal weights for each sample).

    Note

    What is considered a sample in the multi-dimensional multi-class case depends on the value of mdmc_average.

  • mdmc_average (Optional[str]) –

    Defines how averaging is done for multi-dimensional multi-class inputs (on top of the average parameter). Should be one of the following:

    • None [default]: Should be left unchanged if your data is not multi-dimensional multi-class.

    • 'samplewise': In this case, the statistics are computed separately for each sample on the N axis, and then averaged over samples. The computation for each sample is done by treating the flattened extra axes ... (see Input types) as the N dimension within the sample, and computing the metric for the sample based on that.

    • 'global': In this case the N and ... dimensions of the inputs (see Input types) are flattened into a new N_X sample axis, i.e. the inputs are treated as if they were (N_X, C). From here on the average parameter applies as usual.

  • ignore_index (Optional[int]) – Integer specifying a target class to ignore. If given, this class index does not contribute to the returned score, regardless of reduction method. If an index is ignored, and average=None or 'none', the score for the ignored class will be returned as nan.

  • top_k (Optional[int]) –

    Number of highest probability or logit score predictions considered to find the correct label, relevant only for (multi-dimensional) multi-class. The default value (None) will be interpreted as 1 for these inputs.

    Should be left at default (None) for all other types of inputs.

  • multiclass (Optional[bool]) – Used only in certain special cases, where you want to treat inputs as a different type than what they appear to be. See the parameter’s documentation section for a more detailed explanation and examples.

  • compute_on_step (bool) – Forward only calls update() and returns None if this is set to False.

  • dist_sync_on_step (bool) – Synchronize metric state across processes at each forward() before returning the value at the step

  • process_group (Optional[Any]) – Specify the process group on which synchronization is called. default: None (which selects the entire world)

  • dist_sync_fn (Optional[Callable]) – Callback that performs the allgather operation on the metric state. When None, DDP will be used to perform the allgather.

Raises

ValueError – If average is none of "micro", "macro", "weighted", "samples", "none", None.

Example

>>> from torchmetrics import Recall
>>> preds  = torch.tensor([2, 0, 2, 1])
>>> target = torch.tensor([1, 1, 2, 0])
>>> recall = Recall(average='macro', num_classes=3)
>>> recall(preds, target)
tensor(0.3333)
>>> recall = Recall(average='micro')
>>> recall(preds, target)
tensor(0.2500)

Initializes internal Module state, shared by both nn.Module and ScriptModule.

compute()[source]

Computes the recall score based on inputs passed in to update previously.

Return type

Tensor

Returns

The shape of the returned tensor depends on the average parameter

  • If average in ['micro', 'macro', 'weighted', 'samples'], a one-element tensor will be returned

  • If average in ['none', None], the shape will be (C,), where C stands for the number of classes

ROC

class torchmetrics.ROC(num_classes=None, pos_label=None, compute_on_step=True, dist_sync_on_step=False, process_group=None, dist_sync_fn=None)[source]

Computes the Receiver Operating Characteristic (ROC). Works for both binary, multiclass and multilabel problems. In the case of multiclass, the values will be calculated based on a one-vs-the-rest approach.

Forward accepts

  • preds (float tensor): (N, ...) (binary) or (N, C, ...) (multiclass/multilabel) tensor with probabilities, where C is the number of classes/labels.

  • target (long tensor): (N, ...) or (N, C, ...) with integer labels

Note

If either the positive or the negative class is completely missing in the target tensor, the ROC values are not well defined and a tensor of zeros will be returned (either fpr or tpr, depending on which class is missing) together with a warning.

Parameters
  • num_classes (Optional[int]) – integer with number of classes for multi-label and multiclass problems. Should be set to None for binary problems

  • pos_label (Optional[int]) – integer determining the positive class. Default is None, which for binary problems is translated to 1. For multiclass problems this argument should not be set, as it is iteratively changed over the range [0, num_classes-1]

  • compute_on_step (bool) – Forward only calls update() and returns None if this is set to False. default: True

  • dist_sync_on_step (bool) – Synchronize metric state across processes at each forward() before returning the value at the step. default: False

  • process_group (Optional[Any]) – Specify the process group on which synchronization is called. default: None (which selects the entire world)

  • dist_sync_fn (Optional[Callable]) – Callback that performs the allgather operation on the metric state. When None, DDP will be used to perform the allgather

Example (binary case):
>>> from torchmetrics import ROC
>>> pred = torch.tensor([0, 1, 2, 3])
>>> target = torch.tensor([0, 1, 1, 1])
>>> roc = ROC(pos_label=1)
>>> fpr, tpr, thresholds = roc(pred, target)
>>> fpr
tensor([0., 0., 0., 0., 1.])
>>> tpr
tensor([0.0000, 0.3333, 0.6667, 1.0000, 1.0000])
>>> thresholds
tensor([4, 3, 2, 1, 0])
Example (multiclass case):
>>> pred = torch.tensor([[0.75, 0.05, 0.05, 0.05],
...                      [0.05, 0.75, 0.05, 0.05],
...                      [0.05, 0.05, 0.75, 0.05],
...                      [0.05, 0.05, 0.05, 0.75]])
>>> target = torch.tensor([0, 1, 3, 2])
>>> roc = ROC(num_classes=4)
>>> fpr, tpr, thresholds = roc(pred, target)
>>> fpr
[tensor([0., 0., 1.]), tensor([0., 0., 1.]), tensor([0.0000, 0.3333, 1.0000]), tensor([0.0000, 0.3333, 1.0000])]
>>> tpr
[tensor([0., 1., 1.]), tensor([0., 1., 1.]), tensor([0., 0., 1.]), tensor([0., 0., 1.])]
>>> thresholds 
[tensor([1.7500, 0.7500, 0.0500]),
 tensor([1.7500, 0.7500, 0.0500]),
 tensor([1.7500, 0.7500, 0.0500]),
 tensor([1.7500, 0.7500, 0.0500])]
Example (multilabel case):
>>> pred = torch.tensor([[0.8191, 0.3680, 0.1138],
...                      [0.3584, 0.7576, 0.1183],
...                      [0.2286, 0.3468, 0.1338],
...                      [0.8603, 0.0745, 0.1837]])
>>> target = torch.tensor([[1, 1, 0], [0, 1, 0], [0, 0, 0], [0, 1, 1]])
>>> roc = ROC(num_classes=3, pos_label=1)
>>> fpr, tpr, thresholds = roc(pred, target)
>>> fpr 
[tensor([0.0000, 0.3333, 0.3333, 0.6667, 1.0000]),
 tensor([0., 0., 0., 1., 1.]),
 tensor([0.0000, 0.0000, 0.3333, 0.6667, 1.0000])]
>>> tpr  
[tensor([0., 0., 1., 1., 1.]),
 tensor([0.0000, 0.3333, 0.6667, 0.6667, 1.0000]),
 tensor([0., 1., 1., 1., 1.])]
>>> thresholds 
[tensor([1.8603, 0.8603, 0.8191, 0.3584, 0.2286]),
 tensor([1.7576, 0.7576, 0.3680, 0.3468, 0.0745]),
 tensor([1.1837, 0.1837, 0.1338, 0.1183, 0.1138])]

Initializes internal Module state, shared by both nn.Module and ScriptModule.

compute()[source]

Compute the receiver operating characteristic.

Return type

Union[Tuple[Tensor, Tensor, Tensor], Tuple[List[Tensor], List[Tensor], List[Tensor]]]

Returns

3-element tuple containing

fpr:

tensor with false positive rates. If multiclass, this is a list of such tensors, one for each class.

tpr:

tensor with true positive rates. If multiclass, this is a list of such tensors, one for each class.

thresholds:

thresholds used for computing the false and true positive rates

update(preds, target)[source]

Update state with predictions and targets.

Parameters
  • preds (Tensor) – Predictions from model

  • target (Tensor) – Ground truth values

Return type

None

Specificity

class torchmetrics.Specificity(num_classes=None, threshold=0.5, average='micro', mdmc_average=None, ignore_index=None, top_k=None, multiclass=None, compute_on_step=True, dist_sync_on_step=False, process_group=None, dist_sync_fn=None)[source]

Computes Specificity:

\text{Specificity} = \frac{\text{TN}}{\text{TN} + \text{FP}}

Where \text{TN} and \text{FP} represent the number of true negatives and false positives respectively. With the use of the top_k parameter, this metric can generalize to Specificity@K.

The reduction method (how the specificity scores are aggregated) is controlled by the average parameter, and additionally by the mdmc_average parameter in the multi-dimensional multi-class case. Accepts all inputs listed in Input types.

Parameters
  • num_classes (Optional[int]) – Number of classes. Necessary for 'macro', 'weighted' and None average methods.

  • threshold (float) – Threshold probability value for transforming probability predictions to binary (0,1) predictions, in the case of binary or multi-label inputs.

  • average (str) –

    Defines the reduction that is applied. Should be one of the following:

    • 'micro' [default]: Calculate the metric globally, across all samples and classes.

    • 'macro': Calculate the metric for each class separately, and average the metrics across classes (with equal weights for each class).

    • 'weighted': Calculate the metric for each class separately, and average the metrics across classes, weighting each class by its support (tn + fp).

    • 'none' or None: Calculate the metric for each class separately, and return the metric for every class.

    • 'samples': Calculate the metric for each sample, and average the metrics across samples (with equal weights for each sample).

    Note

    What is considered a sample in the multi-dimensional multi-class case depends on the value of mdmc_average.

  • mdmc_average (Optional[str]) –

    Defines how averaging is done for multi-dimensional multi-class inputs (on top of the average parameter). Should be one of the following:

    • None [default]: Should be left unchanged if your data is not multi-dimensional multi-class.

    • 'samplewise': In this case, the statistics are computed separately for each sample on the N axis, and then averaged over samples. The computation for each sample is done by treating the flattened extra axes ... (see Input types) as the N dimension within the sample, and computing the metric for the sample based on that.

    • 'global': In this case the N and ... dimensions of the inputs (see Input types) are flattened into a new N_X sample axis, i.e. the inputs are treated as if they were (N_X, C). From here on the average parameter applies as usual.

  • ignore_index (Optional[int]) – Integer specifying a target class to ignore. If given, this class index does not contribute to the returned score, regardless of reduction method. If an index is ignored, and average=None or 'none', the score for the ignored class will be returned as nan.

  • top_k (Optional[int]) –

    Number of highest probability entries for each sample to convert to 1s - relevant only for inputs with probability predictions. If this parameter is set for multi-label inputs, it will take precedence over threshold. For (multi-dim) multi-class inputs, this parameter defaults to 1.

    Should be left unset (None) for inputs with label predictions.

  • multiclass (Optional[bool]) – Used only in certain special cases, where you want to treat inputs as a different type than what they appear to be. See the parameter’s documentation section for a more detailed explanation and examples.

  • compute_on_step (bool) – Forward only calls update() and returns None if this is set to False.

  • dist_sync_on_step (bool) – Synchronize metric state across processes at each forward() before returning the value at the step

  • process_group (Optional[Any]) – Specify the process group on which synchronization is called. default: None (which selects the entire world)

  • dist_sync_fn (Optional[Callable]) – Callback that performs the allgather operation on the metric state. When None, DDP will be used to perform the allgather.

Raises

ValueError – If average is none of "micro", "macro", "weighted", "samples", "none", None.

Example

>>> from torchmetrics import Specificity
>>> preds  = torch.tensor([2, 0, 2, 1])
>>> target = torch.tensor([1, 1, 2, 0])
>>> specificity = Specificity(average='macro', num_classes=3)
>>> specificity(preds, target)
tensor(0.6111)
>>> specificity = Specificity(average='micro')
>>> specificity(preds, target)
tensor(0.6250)

Initializes internal Module state, shared by both nn.Module and ScriptModule.

compute()[source]

Computes the specificity score based on inputs passed in to update previously.

Return type

Tensor

Returns

The shape of the returned tensor depends on the average parameter

  • If average in ['micro', 'macro', 'weighted', 'samples'], a one-element tensor will be returned

  • If average in ['none', None], the shape will be (C,), where C stands for the number of classes

StatScores

class torchmetrics.StatScores(threshold=0.5, top_k=None, reduce='micro', num_classes=None, ignore_index=None, mdmc_reduce=None, multiclass=None, compute_on_step=True, dist_sync_on_step=False, process_group=None, dist_sync_fn=None)[source]

Computes the number of true positives, false positives, true negatives, false negatives. Related to Type I and Type II errors and the confusion matrix.

The reduction method (how the statistics are aggregated) is controlled by the reduce parameter, and additionally by the mdmc_reduce parameter in the multi-dimensional multi-class case.

Accepts all inputs listed in Input types.

Parameters
  • threshold (float) – Threshold for transforming probability or logit predictions to binary (0,1) predictions, in the case of binary or multi-label inputs. Default value of 0.5 corresponds to input being probabilities.

  • top_k (Optional[int]) –

    Number of highest probability or logit score predictions considered to find the correct label, relevant only for (multi-dimensional) multi-class inputs. The default value (None) will be interpreted as 1 for these inputs.

    Should be left at default (None) for all other types of inputs.

  • reduce (str) –

    Defines the reduction that is applied. Should be one of the following:

    • 'micro' [default]: Counts the statistics by summing over all [sample, class] combinations (globally). Each statistic is represented by a single integer.

    • 'macro': Counts the statistics for each class separately (over all samples). Each statistic is represented by a (C,) tensor. Requires num_classes to be set.

    • 'samples': Counts the statistics for each sample separately (over all classes). Each statistic is represented by a (N, ) 1d tensor.

    Note

    What is considered a sample in the multi-dimensional multi-class case depends on the value of mdmc_reduce.

  • num_classes (Optional[int]) – Number of classes. Necessary for (multi-dimensional) multi-class or multi-label data.

  • ignore_index (Optional[int]) – Specify a class (label) to ignore. If given, this class index does not contribute to the returned score, regardless of reduction method. If an index is ignored, and reduce='macro', the class statistics for the ignored class will all be returned as -1.

  • mdmc_reduce (Optional[str]) –

Defines how the multi-dimensional multi-class inputs are handled. Should be one of the following:

    • None [default]: Should be left unchanged if your data is not multi-dimensional multi-class (see Input types for the definition of input types).

    • 'samplewise': In this case, the statistics are computed separately for each sample on the N axis, and then the outputs are concatenated together. In each sample the extra axes ... are flattened to become the sub-sample axis, and statistics for each sample are computed by treating the sub-sample axis as the N axis for that sample.

    • 'global': In this case the N and ... dimensions of the inputs are flattened into a new N_X sample axis, i.e. the inputs are treated as if they were (N_X, C). From here on the reduce parameter applies as usual.

  • multiclass (Optional[bool]) – Used only in certain special cases, where you want to treat inputs as a different type than what they appear to be. See the parameter’s documentation section for a more detailed explanation and examples.

  • compute_on_step (bool) – Forward only calls update() and returns None if this is set to False.

  • dist_sync_on_step (bool) – Synchronize metric state across processes at each forward() before returning the value at the step

  • process_group (Optional[Any]) – Specify the process group on which synchronization is called. default: None (which selects the entire world)

  • dist_sync_fn (Optional[Callable]) – Callback that performs the allgather operation on the metric state. When None, DDP will be used to perform the allgather.

Raises
  • ValueError – If reduce is none of "micro", "macro" or "samples".

  • ValueError – If mdmc_reduce is none of None, "samplewise", "global".

  • ValueError – If reduce is set to "macro" and num_classes is not provided.

  • ValueError – If num_classes is set and ignore_index is not in the range 0 <= ignore_index < num_classes.

Example

>>> import torch
>>> from torchmetrics.classification import StatScores
>>> preds  = torch.tensor([1, 0, 2, 1])
>>> target = torch.tensor([1, 1, 2, 0])
>>> stat_scores = StatScores(reduce='macro', num_classes=3)
>>> stat_scores(preds, target)
tensor([[0, 1, 2, 1, 1],
        [1, 1, 1, 1, 2],
        [1, 0, 3, 0, 1]])
>>> stat_scores = StatScores(reduce='micro')
>>> stat_scores(preds, target)
tensor([2, 2, 6, 2, 4])

Initializes internal Module state, shared by both nn.Module and ScriptModule.

compute()[source]

Computes the stat scores based on inputs passed in to update previously.

Return type

Tensor

Returns

The metric returns a tensor of shape (..., 5), where the last dimension corresponds to [tp, fp, tn, fn, sup] (sup stands for support and equals tp + fn). The shape depends on the reduce and mdmc_reduce (in case of multi-dimensional multi-class data) parameters:

  • If the data is not multi-dimensional multi-class, then

    • If reduce='micro', the shape will be (5, )

    • If reduce='macro', the shape will be (C, 5), where C stands for the number of classes

    • If reduce='samples', the shape will be (N, 5), where N stands for the number of samples

  • If the data is multi-dimensional multi-class and mdmc_reduce='global', then

    • If reduce='micro', the shape will be (5, )

    • If reduce='macro', the shape will be (C, 5)

    • If reduce='samples', the shape will be (N*X, 5), where X stands for the product of sizes of all “extra” dimensions of the data (i.e. all dimensions except for C and N)

  • If the data is multi-dimensional multi-class and mdmc_reduce='samplewise', then

    • If reduce='micro', the shape will be (N, 5)

    • If reduce='macro', the shape will be (N, C, 5)

    • If reduce='samples', the shape will be (N, X, 5)
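The shapes above are easy to verify; a small sketch with illustrative values (the inputs are multi-dimensional multi-class labels of shape (N, ...), so with mdmc_reduce='samplewise' and reduce='macro' the result follows the documented (N, C, 5) shape):

>>> import torch
>>> from torchmetrics import StatScores
>>> preds = torch.tensor([[0, 2], [1, 2]])
>>> target = torch.tensor([[0, 1], [2, 2]])
>>> stat_scores = StatScores(reduce='macro', num_classes=3, mdmc_reduce='samplewise')
>>> stat_scores(preds, target).shape
torch.Size([2, 3, 5])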

update(preds, target)[source]

Update state with predictions and targets. See Input types for more information on input types.

Parameters
  • preds (Tensor) – Predictions from model (probabilities, logits or labels)

  • target (Tensor) – Ground truth values

Return type

None

Image Metrics

Image quality metrics can be used to assess the quality of synthetically generated images from machine learning algorithms such as Generative Adversarial Networks (GANs).

FID

class torchmetrics.FID(feature=2048, compute_on_step=False, dist_sync_on_step=False, process_group=None, dist_sync_fn=None)[source]

Calculates Fréchet inception distance (FID) which is used to assess the quality of generated images. Given by

FID = \|\mu - \mu_w\|^2 + tr(\Sigma + \Sigma_w - 2(\Sigma \Sigma_w)^{\frac{1}{2}})

where \mathcal{N}(\mu, \Sigma) is the multivariate normal distribution estimated from Inception v3 [1] features calculated on real life images and \mathcal{N}(\mu_w, \Sigma_w) is the multivariate normal distribution estimated from Inception v3 features calculated on generated (fake) images. The metric was originally proposed in [1].

Using the default feature extraction (Inception v3 using the original weights from [2]), the input is expected to be mini-batches of 3-channel RGB images of shape (3 x H x W) with dtype uint8. All images will be resized to 299 x 299 which is the size of the original training data. The boolean flag real determines if the images should update the statistics of the real distribution or the fake distribution.

Note

Using this metric requires scipy to be installed. Either install as pip install torchmetrics[image] or pip install scipy

Note

using this metric with the default feature extractor requires that torch-fidelity is installed. Either install as pip install torchmetrics[image] or pip install torch-fidelity

Note

The forward method can be used, but compute_on_step is disabled by default (the opposite of all other metrics) as this metric does not really make sense to calculate on a single batch. This means that by default forward will just call update underneath.

Parameters
  • feature (Union[int, Module]) –

    Either an integer or nn.Module:

    • an integer will indicate the inceptionv3 feature layer to choose. Can be one of the following: 64, 192, 768, 2048

    • an nn.Module for using a custom feature extractor. Expects that its forward method returns an [N,d] matrix where N is the batch size and d is the feature size.

  • compute_on_step (bool) – Forward only calls update() and returns None if this is set to False.

  • dist_sync_on_step (bool) – Synchronize metric state across processes at each forward() before returning the value at the step

  • process_group (Optional[Any]) – Specify the process group on which synchronization is called. default: None (which selects the entire world)

  • dist_sync_fn (Optional[Callable[[Tensor], List[Tensor]]]) – Callback that performs the allgather operation on the metric state. When None, DDP will be used to perform the allgather

References

[1] Rethinking the Inception Architecture for Computer Vision Christian Szegedy, Vincent Vanhoucke, Sergey Ioffe, Jonathon Shlens, Zbigniew Wojna https://arxiv.org/abs/1512.00567

[2] GANs Trained by a Two Time-Scale Update Rule Converge to a Local Nash Equilibrium, Martin Heusel, Hubert Ramsauer, Thomas Unterthiner, Bernhard Nessler, Sepp Hochreiter https://arxiv.org/abs/1706.08500

Raises
  • ValueError – If feature is set to an int (default settings) and torch-fidelity is not installed

  • ValueError – If feature is set to an int not in [64, 192, 768, 2048]

  • TypeError – If feature is not a str, int or torch.nn.Module

Example

>>> import torch
>>> _ = torch.manual_seed(123)
>>> from torchmetrics import FID
>>> fid = FID(feature=64)  
>>> # generate two slightly overlapping image intensity distributions
>>> imgs_dist1 = torch.randint(0, 200, (100, 3, 299, 299), dtype=torch.uint8)  
>>> imgs_dist2 = torch.randint(100, 255, (100, 3, 299, 299), dtype=torch.uint8)  
>>> fid.update(imgs_dist1, real=True)  
>>> fid.update(imgs_dist2, real=False)  
>>> fid.compute()  
tensor(12.7202)
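As described above, the feature argument also accepts a custom nn.Module. A minimal, hypothetical sketch reusing the tensors from the example (the pooling extractor below is illustrative only and far too weak for real evaluation):

>>> from torch import nn
>>> class PoolExtractor(nn.Module):
...     def forward(self, imgs):
...         # collapse H and W, yielding an [N, 3] feature matrix
...         return imgs.float().mean(dim=[2, 3])
>>> fid_custom = FID(feature=PoolExtractor())
>>> fid_custom.update(imgs_dist1, real=True)
>>> fid_custom.update(imgs_dist2, real=False)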

Initializes internal Module state, shared by both nn.Module and ScriptModule.

compute()[source]

Calculate FID score based on accumulated extracted features from the two distributions.

Return type

Tensor

update(imgs, real)[source]

Update the state with extracted features.

Parameters
  • imgs (Tensor) – tensor with images fed to the feature extractor

  • real (bool) – bool indicating if imgs belong to the real or the fake distribution

Return type

None

IS

class torchmetrics.IS(feature='logits_unbiased', splits=10, compute_on_step=False, dist_sync_on_step=False, process_group=None, dist_sync_fn=None)[source]

Calculates the Inception Score (IS) which is used to assess how realistic generated images are. It is defined as

IS = exp(\mathbb{E}_x KL(p(y | x) || p(y)))

where KL(p(y | x) || p(y)) is the KL divergence between the conditional distribution p(y|x) and the marginal distribution p(y). Both the conditional and the marginal distribution are estimated from features extracted from the images. The score is calculated on random splits of the images such that both a mean and standard deviation of the score are returned. The metric was originally proposed in [1].
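To make the definition concrete, the score for a single split can be evaluated directly from a matrix of class probabilities. A short sketch (all names here are illustrative, not part of the API):

>>> import torch
>>> probs = torch.softmax(torch.randn(100, 10), dim=1)  # stand-in for p(y|x) of 100 images
>>> marginal = probs.mean(dim=0)                        # p(y)
>>> kl = (probs * (probs.log() - marginal.log())).sum(dim=1)
>>> score = kl.mean().exp()                             # IS over this single split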

Using the default feature extraction (Inception v3 using the original weights from [2]), the input is expected to be mini-batches of 3-channel RGB images of shape (3 x H x W) with dtype uint8. All images will be resized to 299 x 299 which is the size of the original training data.

Note

using this metric with the default feature extractor requires that torch-fidelity is installed. Either install as pip install torchmetrics[image] or pip install torch-fidelity

Note

The forward method can be used, but compute_on_step is disabled by default (the opposite of all other metrics) as this metric does not really make sense to calculate on a single batch. This means that by default forward will just call update underneath.

Parameters
  • feature (Union[str, int, Module]) –

    Either a str, an integer or an nn.Module:

    • a str or an integer will indicate the inceptionv3 feature layer to choose. Can be one of the following: ‘logits_unbiased’, 64, 192, 768, 2048

    • an nn.Module for using a custom feature extractor. Expects that its forward method returns an [N,d] matrix where N is the batch size and d is the feature size.

  • splits (int) – integer determining how many splits the images should be divided into when computing the mean and standard deviation of the score

  • compute_on_step (bool) – Forward only calls update() and returns None if this is set to False.

  • dist_sync_on_step (bool) – Synchronize metric state across processes at each forward() before returning the value at the step

  • process_group (Optional[Any]) – Specify the process group on which synchronization is called. default: None (which selects the entire world)

  • dist_sync_fn (Optional[Callable[[Tensor], List[Tensor]]]) – Callback that performs the allgather operation on the metric state. When None, DDP will be used to perform the allgather

References

[1] Improved Techniques for Training GANs Tim Salimans, Ian Goodfellow, Wojciech Zaremba, Vicki Cheung, Alec Radford, Xi Chen https://arxiv.org/abs/1606.03498

[2] GANs Trained by a Two Time-Scale Update Rule Converge to a Local Nash Equilibrium, Martin Heusel, Hubert Ramsauer, Thomas Unterthiner, Bernhard Nessler, Sepp Hochreiter https://arxiv.org/abs/1706.08500

Raises
  • ValueError – If feature is set to a str or an int and torch-fidelity is not installed

  • ValueError – If feature is set to a str or an int and not one of [‘logits_unbiased’, 64, 192, 768, 2048]

  • TypeError – If feature is not a str, an int or torch.nn.Module

Example

>>> import torch
>>> _ = torch.manual_seed(123)
>>> from torchmetrics import IS
>>> inception = IS()  
>>> # generate some images
>>> imgs = torch.randint(0, 255, (100, 3, 299, 299), dtype=torch.uint8)  
>>> inception.update(imgs)  
>>> inception.compute()  
(tensor(1.0569), tensor(0.0113))

Initializes internal Module state, shared by both nn.Module and ScriptModule.

compute()[source]

Override this method to compute the final metric value from state variables synchronized across the distributed backend.

Return type

Tuple[Tensor, Tensor]

update(imgs)[source]

Update the state with extracted features.

Parameters

imgs (Tensor) – tensor with images fed to the feature extractor

Return type

None

KID

class torchmetrics.KID(feature=2048, subsets=100, subset_size=1000, degree=3, gamma=None, coef=1.0, compute_on_step=False, dist_sync_on_step=False, process_group=None, dist_sync_fn=None)[source]

Calculates Kernel Inception Distance (KID) which is used to access the quality of generated images. Given by

KID = MMD(f_{real}, f_{fake})^2

where MMD is the maximum mean discrepancy and f_{real}, f_{fake} are extracted features from real and fake images, see [1] for more details. In particular, calculating the MMD requires the evaluation of a polynomial kernel function k

k(x,y) = (\gamma * x^T y + coef)^{degree}

which controls the distance between two features. In practice the MMD is calculated over a number of subsets so that both the mean and the standard deviation of KID can be obtained.
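For reference, the polynomial kernel above can be written in a few lines. A sketch, assuming (as is common) that gamma defaults to one over the feature dimension:

>>> import torch
>>> def poly_kernel(x, y, degree=3, gamma=None, coef=1.0):
...     # x: [N, d], y: [M, d] feature matrices; returns the [N, M] kernel matrix
...     if gamma is None:
...         gamma = 1.0 / x.shape[1]  # assumption: default scale-length is 1/d
...     return (gamma * x @ y.T + coef) ** degree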

Using the default feature extraction (Inception v3 using the original weights from [2]), the input is expected to be mini-batches of 3-channel RGB images of shape (3 x H x W) with dtype uint8. All images will be resized to 299 x 299 which is the size of the original training data.

Note

using this metric with the default feature extractor requires that torch-fidelity is installed. Either install as pip install torchmetrics[image] or pip install torch-fidelity

Note

The forward method can be used, but compute_on_step is disabled by default (the opposite of all other metrics) as this metric does not really make sense to calculate on a single batch. This means that by default forward will just call update underneath.

Parameters
  • feature (Union[str, int, Module]) –

    Either a str, an integer or an nn.Module:

    • a str or an integer will indicate the inceptionv3 feature layer to choose. Can be one of the following: ‘logits_unbiased’, 64, 192, 768, 2048

    • an nn.Module for using a custom feature extractor. Expects that its forward method returns an [N,d] matrix where N is the batch size and d is the feature size.

  • subsets (int) – Number of subsets to calculate the mean and standard deviation scores over

  • subset_size (int) – Number of randomly picked samples in each subset

  • degree (int) – Degree of the polynomial kernel function

  • gamma (Optional[float]) – Scale-length of polynomial kernel. If set to None will be automatically set to the feature size

  • coef (float) – Bias term in the polynomial kernel.

  • compute_on_step (bool) – Forward only calls update() and returns None if this is set to False.

  • dist_sync_on_step (bool) – Synchronize metric state across processes at each forward() before returning the value at the step

  • process_group (Optional[Any]) – Specify the process group on which synchronization is called. default: None (which selects the entire world)

  • dist_sync_fn (Optional[Callable]) – Callback that performs the allgather operation on the metric state. When None, DDP will be used to perform the allgather

References

[1] Demystifying MMD GANs Mikołaj Bińkowski, Danica J. Sutherland, Michael Arbel, Arthur Gretton https://arxiv.org/abs/1801.01401

[2] GANs Trained by a Two Time-Scale Update Rule Converge to a Local Nash Equilibrium, Martin Heusel, Hubert Ramsauer, Thomas Unterthiner, Bernhard Nessler, Sepp Hochreiter https://arxiv.org/abs/1706.08500

Raises
  • ValueError – If feature is set to an int (default settings) and torch-fidelity is not installed

  • ValueError – If feature is set to an int not in [64, 192, 768, 2048]

  • ValueError – If subsets is not an integer larger than 0

  • ValueError – If subset_size is not an integer larger than 0

  • ValueError – If degree is not an integer larger than 0

  • ValueError – If gamma is neither None nor a float larger than 0

  • ValueError – If coef is not a float larger than 0

Example

>>> import torch
>>> _ = torch.manual_seed(123)
>>> from torchmetrics import KID
>>> kid = KID(subset_size=50)  
>>> # generate two slightly overlapping image intensity distributions
>>> imgs_dist1 = torch.randint(0, 200, (100, 3, 299, 299), dtype=torch.uint8)  
>>> imgs_dist2 = torch.randint(100, 255, (100, 3, 299, 299), dtype=torch.uint8)  
>>> kid.update(imgs_dist1, real=True)  
>>> kid.update(imgs_dist2, real=False)  
>>> kid_mean, kid_std = kid.compute()  
>>> print((kid_mean, kid_std))  
(tensor(0.0338), tensor(0.0025))

Initializes internal Module state, shared by both nn.Module and ScriptModule.

compute()[source]

Calculate KID score based on accumulated extracted features from the two distributions. Returns a tuple of mean and standard deviation of KID scores calculated on subsets of extracted features.

Implementation inspired by Fid Score

Return type

Tuple[Tensor, Tensor]

update(imgs, real)[source]

Update the state with extracted features.

Parameters
  • imgs (Tensor) – tensor with images fed to the feature extractor

  • real (bool) – bool indicating if imgs belong to the real or the fake distribution

Return type

None

LPIPS

class torchmetrics.LPIPS(net_type='alex', reduction='mean', compute_on_step=True, dist_sync_on_step=False, process_group=None, dist_sync_fn=None)[source]

The Learned Perceptual Image Patch Similarity (LPIPS) is used to judge the perceptual similarity between two images. LPIPS essentially computes the similarity between the activations of two image patches for some pre-defined network. This measure has been shown to match human perception well. A low LPIPS score means that the image patches are perceptually similar.

Both input image patches are expected to have shape [N, 3, H, W] and be normalized to the [-1,1] range. The minimum size of H, W depends on the chosen backbone (see net_type arg).
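For example, uint8 images can be brought into the expected range as follows (a sketch, not part of the API):

>>> import torch
>>> imgs = torch.randint(0, 256, (4, 3, 128, 128), dtype=torch.uint8)
>>> imgs = imgs.float() / 127.5 - 1.0  # scale [0, 255] -> [-1, 1]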

Note

Using this metric requires the lpips package to be installed. Either install as pip install torchmetrics[image] or pip install lpips

Note

This metric is not scriptable when using torch<1.8. Please update your PyTorch installation if this is an issue.

Parameters
  • net_type (str) – str indicating backbone network type to use. Choose between ‘alex’, ‘vgg’ or ‘squeeze’

  • reduction (str) – str indicating how to reduce over the batch dimension. Choose between ‘sum’ or ‘mean’.

  • compute_on_step (bool) – Forward only calls update() and returns None if this is set to False.

  • dist_sync_on_step (bool) – Synchronize metric state across processes at each forward() before returning the value at the step

  • process_group (Optional[Any]) – Specify the process group on which synchronization is called. default: None (which selects the entire world)

  • dist_sync_fn (Optional[Callable[[Tensor], List[Tensor]]]) – Callback that performs the allgather operation on the metric state. When None, DDP will be used to perform the allgather

Raises
  • ValueError – If lpips package is not installed

  • ValueError – If net_type is not one of "vgg", "alex" or "squeeze"

  • ValueError – If reduction is not one of "mean" or "sum"

Example

>>> import torch
>>> _ = torch.manual_seed(123)
>>> from torchmetrics import LPIPS
>>> lpips = LPIPS(net_type='vgg')
>>> img1 = torch.rand(10, 3, 100, 100)
>>> img2 = torch.rand(10, 3, 100, 100)
>>> lpips(img1, img2)
tensor([0.3566], grad_fn=<DivBackward0>)

Initializes internal Module state, shared by both nn.Module and ScriptModule.

compute()[source]

Compute final perceptual similarity metric.

Return type

Tensor

update(img1, img2)[source]

Update internal states with lpips score.

Parameters
  • img1 (Tensor) – tensor with images of shape [N, 3, H, W]

  • img2 (Tensor) – tensor with images of shape [N, 3, H, W]

Return type

None

PSNR

class torchmetrics.PSNR(data_range=None, base=10.0, reduction='elementwise_mean', dim=None, compute_on_step=True, dist_sync_on_step=False, process_group=None)[source]

Computes Peak Signal-to-Noise Ratio (PSNR):

\text{PSNR}(I, J) = 10 * \log_{10} \left(\frac{\max(I)^2}{\text{MSE}(I, J)}\right)

Where \text{MSE} denotes the mean-squared-error function.

Parameters
  • data_range (Optional[float]) – the range of the data. If None, it is determined from the data (max - min). The data_range must be given when dim is not None.

  • base (float) – a base of a logarithm to use (default: 10)

  • reduction (str) –

    a method to reduce metric score over labels.

    • 'elementwise_mean': takes the mean (default)

    • 'sum': takes the sum

    • 'none': no reduction will be applied

  • dim (Union[int, Tuple[int, …], None]) – Dimensions to reduce PSNR scores over, provided as either an integer or a list of integers. Default is None meaning scores will be reduced across all dimensions and all batches.

  • compute_on_step (bool) – Forward only calls update() and returns None if this is set to False. default: True

  • dist_sync_on_step (bool) – Synchronize metric state across processes at each forward() before returning the value at the step. default: False

  • process_group (Optional[Any]) – Specify the process group on which synchronization is called. default: None (which selects the entire world)

Raises

ValueError – If dim is not None and data_range is not given.

Example

>>> import torch
>>> from torchmetrics import PSNR
>>> psnr = PSNR()
>>> preds = torch.tensor([[0.0, 1.0], [2.0, 3.0]])
>>> target = torch.tensor([[3.0, 2.0], [1.0, 0.0]])
>>> psnr(preds, target)
tensor(2.5527)
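The value above can be reproduced by hand from the formula, estimating the data range from the data as the data_range parameter describes:

>>> mse = torch.mean((preds - target) ** 2)
>>> data_range = target.max() - target.min()  # 3.0 here
>>> 10 * torch.log10(data_range ** 2 / mse)
tensor(2.5527)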

Note

Half precision is only supported on GPU for this metric

Initializes internal Module state, shared by both nn.Module and ScriptModule.

compute()[source]

Compute peak signal-to-noise ratio over state.

Return type

Tensor

update(preds, target)[source]

Update state with predictions and targets.

Parameters
  • preds (Tensor) – Predictions from model

  • target (Tensor) – Ground truth values

Return type

None

SSIM

class torchmetrics.SSIM(kernel_size=(11, 11), sigma=(1.5, 1.5), reduction='elementwise_mean', data_range=None, k1=0.01, k2=0.03, compute_on_step=True, dist_sync_on_step=False, process_group=None)[source]

Computes Structural Similarity Index Measure (SSIM).

Parameters
  • kernel_size (Sequence[int]) – size of the gaussian kernel (default: (11, 11))

  • sigma (Sequence[float]) – Standard deviation of the gaussian kernel (default: (1.5, 1.5))

  • reduction (str) –

    a method to reduce metric score over labels.

    • 'elementwise_mean': takes the mean (default)

    • 'sum': takes the sum

    • 'none': no reduction will be applied

  • data_range (Optional[float]) – Range of the image. If None, it is determined from the image (max - min)

  • k1 (float) – Parameter of SSIM. Default: 0.01

  • k2 (float) – Parameter of SSIM. Default: 0.03

Returns

Tensor with SSIM score

Example

>>> import torch
>>> from torchmetrics import SSIM
>>> preds = torch.rand([16, 1, 16, 16])
>>> target = preds * 0.75
>>> ssim = SSIM()
>>> ssim(preds, target)
tensor(0.9219)

Initializes internal Module state, shared by both nn.Module and ScriptModule.

compute()[source]

Computes SSIM over state.

Return type

Tensor

update(preds, target)[source]

Update state with predictions and targets.

Parameters
  • preds (Tensor) – Predictions from model

  • target (Tensor) – Ground truth values

Return type

None

Detection Metrics

Object detection metrics can be used to evaluate predicted detections against given groundtruth detections on images.

MAP

class torchmetrics.MAP(class_metrics=False, compute_on_step=True, dist_sync_on_step=False, process_group=None, dist_sync_fn=None)[source]

Computes the Mean-Average-Precision (mAP) and Mean-Average-Recall (mAR) for object detection predictions. Optionally, the mAP and mAR values can be calculated per class.

Predicted boxes and targets have to be in Pascal VOC format (xmin-top left, ymin-top left, xmax-bottom right, ymax-bottom right). See the update() method for more information about the input format to this metric.

Note

This metric is a wrapper for the pycocotools, which is a standard implementation for the mAP metric for object detection. Using this metric therefore requires you to have pycocotools installed. Please install with pip install pycocotools or pip install torchmetrics[detection].

Note

This metric requires you to have torchvision version 0.8.0 or newer installed (with corresponding version 1.7.0 of torch or newer). Please install with pip install torchvision or pip install torchmetrics[detection].

Note

As the pycocotools library cannot deal with tensors directly, all results have to be transferred to the CPU; this might have a performance impact on your training.

Parameters
  • class_metrics (bool) – Option to enable per-class metrics for mAP and mAR_100. Has a performance impact. default: False

  • compute_on_step (bool) – Forward only calls update() and returns None if this is set to False.

  • dist_sync_on_step (bool) – Synchronize metric state across processes at each forward() before returning the value at the step

  • process_group (Optional[Any]) – Specify the process group on which synchronization is called. default: None (which selects the entire world)

  • dist_sync_fn (Optional[Callable]) – Callback that performs the allgather operation on the metric state. When None, DDP will be used to perform the allgather

Raises
  • ImportError – If pycocotools is not installed

  • ImportError – If torchvision is not installed or version installed is lower than 0.8.0

  • ValueError – If class_metrics is not a boolean
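A minimal usage sketch following the input format documented in update() below (requires pycocotools; the boxes, scores and labels are illustrative values only, not reference outputs):

>>> import torch
>>> from torchmetrics import MAP
>>> preds = [dict(
...     boxes=torch.tensor([[258.0, 41.0, 606.0, 285.0]]),
...     scores=torch.tensor([0.536]),
...     labels=torch.tensor([0]),
... )]
>>> target = [dict(
...     boxes=torch.tensor([[214.0, 41.0, 562.0, 285.0]]),
...     labels=torch.tensor([0]),
... )]
>>> map_metric = MAP()
>>> map_metric.update(preds, target)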

Initializes internal Module state, shared by both nn.Module and ScriptModule.

compute()[source]

Compute the Mean-Average-Precision (mAP) and Mean-Average-Recall (mAR) scores. All detections added in the update() method are included.

Note

Main map score is calculated with @[ IoU=0.50:0.95 | area=all | maxDets=100 ]

Return type

dict

Returns

dict containing

  • map: torch.Tensor

  • map_50: torch.Tensor

  • map_75: torch.Tensor

  • map_small: torch.Tensor

  • map_medium: torch.Tensor

  • map_large: torch.Tensor

  • mar_1: torch.Tensor

  • mar_10: torch.Tensor

  • mar_100: torch.Tensor

  • mar_small: torch.Tensor

  • mar_medium: torch.Tensor

  • mar_large: torch.Tensor

  • map_per_class: torch.Tensor (-1 if class metrics are disabled)

  • mar_100_per_class: torch.Tensor (-1 if class metrics are disabled)

update(preds, target)[source]

Add detections and groundtruth to the metric.

Parameters
  • preds (List[Dict[str, Tensor]]) – A list consisting of dictionaries, one per image, each containing the keys:

    • boxes: torch.FloatTensor of shape [num_boxes, 4] containing num_boxes detection boxes of the format [xmin, ymin, xmax, ymax] in absolute image coordinates.

    • scores: torch.FloatTensor of shape [num_boxes] containing detection scores for the boxes.

    • labels: torch.IntTensor of shape [num_boxes] containing 0-indexed detection classes for the boxes.

  • target (List[Dict[str, Tensor]]) – A list consisting of dictionaries, one per image, each containing the keys:

    • boxes: torch.FloatTensor of shape [num_boxes, 4] containing num_boxes groundtruth boxes of the format [xmin, ymin, xmax, ymax] in absolute image coordinates.

    • labels: torch.IntTensor of shape [num_boxes] containing 1-indexed groundtruth classes for the boxes.

Raises
  • ValueError – If preds is not of type List[Dict[str, torch.Tensor]]

  • ValueError – If target is not of type List[Dict[str, torch.Tensor]]

  • ValueError – If preds and target are not of the same length

  • ValueError – If any of preds.boxes, preds.scores and preds.labels are not of the same length

  • ValueError – If any of target.boxes and target.labels are not of the same length

  • ValueError – If any box is not type float and of length 4

  • ValueError – If any class is not type int and of length 1

  • ValueError – If any score is not type float and of length 1

Return type

None

Regression Metrics

CosineSimilarity

class torchmetrics.CosineSimilarity(reduction='sum', compute_on_step=True, dist_sync_on_step=False, process_group=None, dist_sync_fn=None)[source]

Computes the Cosine Similarity between targets and predictions:

cos_{sim}(x,y) = \frac{x \cdot y}{||x|| \cdot ||y||} =
\frac{\sum_{i=1}^n x_i y_i}{\sqrt{\sum_{i=1}^n x_i^2}\sqrt{\sum_{i=1}^n y_i^2}}

where y is a tensor of target values, and x is a tensor of predictions.

Forward accepts

  • preds (float tensor): (N,d)

  • target (float tensor): (N,d)

Parameters
  • reduction (str) – how to reduce over the batch dimension using ‘sum’, ‘mean’ or ‘none’ (taking the individual scores)

  • compute_on_step (bool) – Forward only calls update() and returns None if this is set to False.

  • dist_sync_on_step (bool) – Synchronize metric state across processes at each forward() before returning the value at the step.

  • process_group (Optional[Any]) – Specify the process group on which synchronization is called. default: None (which selects the entire world)

  • dist_sync_fn (Optional[Callable]) – Callback that performs the allgather operation on the metric state. When None, DDP will be used to perform the all gather.

Example

>>> import torch
>>> from torchmetrics import CosineSimilarity
>>> target = torch.tensor([[0, 1], [1, 1]])
>>> preds = torch.tensor([[0, 1], [0, 1]])
>>> cosine_similarity = CosineSimilarity(reduction = 'mean')
>>> cosine_similarity(preds, target)
tensor(0.8536)
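The value can be cross-checked against torch.nn.functional.cosine_similarity (casting to float, since the example tensors are integer-typed):

>>> import torch.nn.functional as F
>>> F.cosine_similarity(preds.float(), target.float(), dim=1).mean()
tensor(0.8536)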

Initializes internal Module state, shared by both nn.Module and ScriptModule.

compute()[source]

Override this method to compute the final metric value from state variables synchronized across the distributed backend.

Return type

Tensor

update(preds, target)[source]

Update metric states with predictions and targets.

Parameters
  • preds (Tensor) – Predicted tensor with shape (N,d)

  • target (Tensor) – Ground truth tensor with shape (N,d)

Return type

None

ExplainedVariance

class torchmetrics.ExplainedVariance(multioutput='uniform_average', compute_on_step=True, dist_sync_on_step=False, process_group=None, dist_sync_fn=None)[source]

Computes explained variance:

\text{ExplainedVariance} = 1 - \frac{\text{Var}(y - \hat{y})}{\text{Var}(y)}

Where y is a tensor of target values, and \hat{y} is a tensor of predictions.

Forward accepts

  • preds (float tensor): (N,) or (N, ...) (multioutput)

  • target (float tensor): (N,) or (N, ...) (multioutput)

In the case of multioutput, by default the variances will be uniformly averaged over the additional dimensions. Please see argument multioutput for changing this behavior.

Parameters
  • multioutput (str) –

    Defines aggregation in the case of multiple output scores. Can be one of the following strings (default is 'uniform_average'):

    • 'raw_values' returns full set of scores

    • 'uniform_average' scores are uniformly averaged

    • 'variance_weighted' scores are weighted by their individual variances

  • compute_on_step (bool) – Forward only calls update() and returns None if this is set to False. default: True

  • dist_sync_on_step (bool) – Synchronize metric state across processes at each forward() before returning the value at the step. default: False

  • process_group (Optional[Any]) – Specify the process group on which synchronization is called. default: None (which selects the entire world)

Raises

ValueError – If multioutput is not one of "raw_values", "uniform_average" or "variance_weighted".

Example

>>> import torch
>>> from torchmetrics import ExplainedVariance
>>> target = torch.tensor([3, -0.5, 2, 7])
>>> preds = torch.tensor([2.5, 0.0, 2, 8])
>>> explained_variance = ExplainedVariance()
>>> explained_variance(preds, target)
tensor(0.9572)
>>> target = torch.tensor([[0.5, 1], [-1, 1], [7, -6]])
>>> preds = torch.tensor([[0, 2], [-1, 2], [8, -5]])
>>> explained_variance = ExplainedVariance(multioutput='raw_values')
>>> explained_variance(preds, target)
tensor([0.9677, 1.0000])
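The single-output value can be reproduced directly from the definition, using population variances:

>>> target = torch.tensor([3, -0.5, 2, 7])
>>> preds = torch.tensor([2.5, 0.0, 2, 8])
>>> 1 - torch.var(target - preds, unbiased=False) / torch.var(target, unbiased=False)
tensor(0.9572)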

Initializes internal Module state, shared by both nn.Module and ScriptModule.

compute()[source]

Computes explained variance over state.

Return type

Union[Tensor, Sequence[Tensor]]

update(preds, target)[source]

Update state with predictions and targets.

Parameters
  • preds (Tensor) – Predictions from model

  • target (Tensor) – Ground truth values

Return type

None

MeanAbsoluteError

class torchmetrics.MeanAbsoluteError(compute_on_step=True, dist_sync_on_step=False, process_group=None, dist_sync_fn=None)[source]

Computes Mean Absolute Error (MAE):

\text{MAE} = \frac{1}{N}\sum_i^N | y_i - \hat{y_i} |

Where y is a tensor of target values, and \hat{y} is a tensor of predictions.

Parameters
  • compute_on_step (bool) – Forward only calls update() and returns None if this is set to False. default: True

  • dist_sync_on_step (bool) – Synchronize metric state across processes at each forward() before returning the value at the step. default: False

  • process_group (Optional[Any]) – Specify the process group on which synchronization is called. default: None (which selects the entire world)

Example

>>> import torch
>>> from torchmetrics import MeanAbsoluteError
>>> target = torch.tensor([3.0, -0.5, 2.0, 7.0])
>>> preds = torch.tensor([2.5, 0.0, 2.0, 8.0])
>>> mean_absolute_error = MeanAbsoluteError()
>>> mean_absolute_error(preds, target)
tensor(0.5000)
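Equivalently, by hand from the formula:

>>> (preds - target).abs().mean()
tensor(0.5000)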

Initializes internal Module state, shared by both nn.Module and ScriptModule.

compute()[source]

Computes mean absolute error over state.

Return type

Tensor

update(preds, target)[source]

Update state with predictions and targets.

Parameters
  • preds (Tensor) – Predictions from model

  • target (Tensor) – Ground truth values

Return type

None

MeanAbsolutePercentageError

class torchmetrics.MeanAbsolutePercentageError(compute_on_step=True, dist_sync_on_step=False, process_group=None, dist_sync_fn=None)[source]

Computes Mean Absolute Percentage Error (MAPE):

\text{MAPE} = \frac{1}{n}\sum_1^n\frac{|   y_i - \hat{y_i} |}{\max(\epsilon, y_i)}

Where y is a tensor of target values, and \hat{y} is a tensor of predictions.

Parameters
  • compute_on_step (bool) – Forward only calls update() and returns None if this is set to False. default: True

  • dist_sync_on_step (bool) – Synchronize metric state across processes at each forward() before returning the value at the step. default: False

  • process_group (Optional[Any]) – Specify the process group on which synchronization is called. default: None (which selects the entire world)

Note

The epsilon value is taken from scikit-learn’s implementation of MAPE.

Note

MAPE output is a non-negative floating point, and the best result is 0.0. Note that bad predictions can lead to arbitrarily large values, especially when some target values are close to 0. This MAPE implementation returns a very large number instead of inf.

Example

>>> import torch
>>> from torchmetrics import MeanAbsolutePercentageError
>>> target = torch.tensor([1, 10, 1e6])
>>> preds = torch.tensor([0.9, 15, 1.2e6])
>>> mean_abs_percentage_error = MeanAbsolutePercentageError()
>>> mean_abs_percentage_error(preds, target)
tensor(0.2667)
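By hand from the formula (the epsilon can be ignored here, since no target is close to zero):

>>> ((target - preds).abs() / target.abs()).mean()
tensor(0.2667)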

Initializes internal Module state, shared by both nn.Module and ScriptModule.

compute()[source]

Computes mean absolute percentage error over state.

Return type

Tensor

update(preds, target)[source]

Update state with predictions and targets.

Parameters
  • preds (Tensor) – Predictions from model

  • target (Tensor) – Ground truth values

Return type

None

MeanSquaredError

class torchmetrics.MeanSquaredError(compute_on_step=True, dist_sync_on_step=False, process_group=None, dist_sync_fn=None, squared=True)[source]

Computes mean squared error (MSE):

\text{MSE} = \frac{1}{N}\sum_i^N(y_i - \hat{y_i})^2

Where y is a tensor of target values, and \hat{y} is a tensor of predictions.

Parameters
  • compute_on_step (bool) – Forward only calls update() and returns None if this is set to False. default: True

  • dist_sync_on_step (bool) – Synchronize metric state across processes at each forward() before returning the value at the step. default: False

  • process_group (Optional[Any]) – Specify the process group on which synchronization is called. default: None (which selects the entire world)

  • squared (bool) – If True returns MSE value, if False returns RMSE value.

Example

>>> import torch
>>> from torchmetrics import MeanSquaredError
>>> target = torch.tensor([2.5, 5.0, 4.0, 8.0])
>>> preds = torch.tensor([3.0, 5.0, 2.5, 7.0])
>>> mean_squared_error = MeanSquaredError()
>>> mean_squared_error(preds, target)
tensor(0.8750)
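By hand from the formula; the second line shows the root of the same quantity, which is what squared=False would return:

>>> torch.mean((preds - target) ** 2)
tensor(0.8750)
>>> torch.sqrt(torch.mean((preds - target) ** 2))
tensor(0.9354)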

Initializes internal Module state, shared by both nn.Module and ScriptModule.

compute()[source]

Computes mean squared error over state.

Return type

Tensor

update(preds, target)[source]

Update state with predictions and targets.

Parameters
  • preds (Tensor) – Predictions from model

  • target (Tensor) – Ground truth values

Return type

None

MeanSquaredLogError

class torchmetrics.MeanSquaredLogError(compute_on_step=True, dist_sync_on_step=False, process_group=None, dist_sync_fn=None)[source]

Computes mean squared logarithmic error (MSLE):

\text{MSLE} = \frac{1}{N}\sum_i^N (\log_e(1 + y_i) - \log_e(1 + \hat{y_i}))^2

Where y is a tensor of target values, and \hat{y} is a tensor of predictions.

Parameters
  • compute_on_step (bool) – Forward only calls update() and returns None if this is set to False. default: True

  • dist_sync_on_step (bool) – Synchronize metric state across processes at each forward() before returning the value at the step. default: False

  • process_group (Optional[Any]) – Specify the process group on which synchronization is called. default: None (which selects the entire world)

Example

>>> import torch
>>> from torchmetrics import MeanSquaredLogError
>>> target = torch.tensor([2.5, 5, 4, 8])
>>> preds = torch.tensor([3, 5, 2.5, 7])
>>> mean_squared_log_error = MeanSquaredLogError()
>>> mean_squared_log_error(preds, target)
tensor(0.0397)

Note

Half precision is only supported on GPU for this metric
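The example value can be reproduced from the formula:

>>> torch.mean((torch.log1p(target) - torch.log1p(preds)) ** 2)
tensor(0.0397)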

Initializes internal Module state, shared by both nn.Module and ScriptModule.

compute()[source]

Compute mean squared logarithmic error over state.

Return type

Tensor

update(preds, target)[source]

Update state with predictions and targets.

Parameters
  • preds (Tensor) – Predictions from model

  • target (Tensor) – Ground truth values

Return type

None

PearsonCorrcoef

class torchmetrics.PearsonCorrcoef(compute_on_step=True, dist_sync_on_step=False, process_group=None)[source]

Computes Pearson Correlation Coefficient:

P_{corr}(x,y) = \frac{cov(x,y)}{\sigma_x \sigma_y}

Where y is a tensor of target values, and x is a tensor of predictions.

Forward accepts

  • preds (float tensor): (N,)

  • target (float tensor): (N,)

Parameters
  • compute_on_step (bool) – Forward only calls update() and returns None if this is set to False. default: True

  • dist_sync_on_step (bool) – Synchronize metric state across processes at each forward() before returning the value at the step. default: False

  • process_group (Optional[Any]) – Specify the process group on which synchronization is called. default: None (which selects the entire world)

Example

>>> import torch
>>> from torchmetrics import PearsonCorrcoef
>>> target = torch.tensor([3, -0.5, 2, 7])
>>> preds = torch.tensor([2.5, 0.0, 2, 8])
>>> pearson = PearsonCorrcoef()
>>> pearson(preds, target)
tensor(0.9849)
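By hand, via centered vectors:

>>> vx = preds - preds.mean()
>>> vy = target - target.mean()
>>> (vx * vy).sum() / (vx.norm() * vy.norm())
tensor(0.9849)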

Initializes internal Module state, shared by both nn.Module and ScriptModule.

compute()[source]

Computes pearson correlation coefficient over state.

Return type

Tensor

update(preds, target)[source]

Update state with predictions and targets.

Parameters
  • preds (Tensor) – Predictions from model

  • target (Tensor) – Ground truth values

Return type

None

R2Score

class torchmetrics.R2Score(num_outputs=1, adjusted=0, multioutput='uniform_average', compute_on_step=True, dist_sync_on_step=False, process_group=None, dist_sync_fn=None)[source]

Computes R2 score, also known as the coefficient of determination:

R^2 = 1 - \frac{SS_{res}}{SS_{tot}}

where SS_{res}=\sum_i (y_i - f(x_i))^2 is the residual sum of squares, and SS_{tot}=\sum_i (y_i - \bar{y})^2 is the total sum of squares. The adjusted R2 score can also be calculated, given by

R^2_{adj} = 1 - \frac{(1-R^2)(n-1)}{n-k-1}

where the parameter k (the number of independent regressors) should be provided as the adjusted argument.

Forward accepts

  • preds (float tensor): (N,) or (N, M) (multioutput)

  • target (float tensor): (N,) or (N, M) (multioutput)

In the case of multioutput, by default the variances will be uniformly averaged over the additional dimensions. Please see argument multioutput for changing this behavior.

Parameters
  • num_outputs (int) – Number of outputs in multioutput setting (default is 1)

  • adjusted (int) – number of independent regressors for calculating adjusted r2 score. Default 0 (standard r2 score).

  • multioutput (str) –

    Defines aggregation in the case of multiple output scores. Can be one of the following strings (default is 'uniform_average'.):

    • 'raw_values' returns full set of scores

    • 'uniform_average' scores are uniformly averaged

    • 'variance_weighted' scores are weighted by their individual variances

  • compute_on_step (bool) – Forward only calls update() and returns None if this is set to False. default: True

  • dist_sync_on_step (bool) – Synchronize metric state across processes at each forward() before returning the value at the step. default: False

  • process_group (Optional[Any]) – Specify the process group on which synchronization is called. default: None (which selects the entire world)

Raises
  • ValueError – If adjusted parameter is not an integer larger or equal to 0.

  • ValueError – If multioutput is not one of "raw_values", "uniform_average" or "variance_weighted".

Example

>>> import torch
>>> from torchmetrics import R2Score
>>> target = torch.tensor([3, -0.5, 2, 7])
>>> preds = torch.tensor([2.5, 0.0, 2, 8])
>>> r2score = R2Score()
>>> r2score(preds, target)
tensor(0.9486)
>>> target = torch.tensor([[0.5, 1], [-1, 1], [7, -6]])
>>> preds = torch.tensor([[0, 2], [-1, 2], [8, -5]])
>>> r2score = R2Score(num_outputs=2, multioutput='raw_values')
>>> r2score(preds, target)
tensor([0.9654, 0.9082])
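The single-output value can be reproduced from the definition:

>>> target = torch.tensor([3, -0.5, 2, 7])
>>> preds = torch.tensor([2.5, 0.0, 2, 8])
>>> ss_res = torch.sum((target - preds) ** 2)
>>> ss_tot = torch.sum((target - target.mean()) ** 2)
>>> 1 - ss_res / ss_tot
tensor(0.9486)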

Initializes internal Module state, shared by both nn.Module and ScriptModule.

compute()[source]

Computes r2 score over the metric states.

Return type

Tensor

update(preds, target)[source]

Update state with predictions and targets.

Parameters
  • preds (Tensor) – Predictions from model

  • target (Tensor) – Ground truth values

Return type

None

SpearmanCorrcoef

class torchmetrics.SpearmanCorrcoef(compute_on_step=True, dist_sync_on_step=False, process_group=None, dist_sync_fn=None)[source]

Computes Spearman's rank correlation coefficient:

r_s = \frac{cov(rg_x, rg_y)}{\sigma_{rg_x} \sigma_{rg_y}}

where rg_x and rg_y are the ranks associated with the variables x and y. Spearman's correlation coefficient corresponds to the standard Pearson correlation coefficient calculated on the rank variables.

Parameters
  • compute_on_step (bool) – Forward only calls update() and returns None if this is set to False. default: True

  • dist_sync_on_step (bool) – Synchronize metric state across processes at each forward() before returning the value at the step. default: False

  • process_group (Optional[Any]) – Specify the process group on which synchronization is called. default: None (which selects the entire world)

  • dist_sync_fn (Optional[Callable]) – Callback that performs the allgather operation on the metric state. When None, DDP will be used to perform the allgather

Example

>>> import torch
>>> from torchmetrics import SpearmanCorrcoef
>>> target = torch.tensor([3, -0.5, 2, 7])
>>> preds = torch.tensor([2.5, 0.0, 2, 8])
>>> spearman = SpearmanCorrcoef()
>>> spearman(preds, target)
tensor(1.0000)
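By hand: rank both tensors (the double argsort trick is valid only in the absence of ties) and apply the Pearson formula to the ranks:

>>> rx = preds.argsort().argsort().float()   # ranks of preds
>>> ry = target.argsort().argsort().float()  # ranks of target
>>> vx, vy = rx - rx.mean(), ry - ry.mean()
>>> r = (vx * vy).sum() / (vx.norm() * vy.norm())
>>> torch.allclose(r, torch.tensor(1.0))
True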

Initializes internal Module state, shared by both nn.Module and ScriptModule.

compute()[source]

Computes Spearman's correlation coefficient.

Return type

Tensor

update(preds, target)[source]

Update state with predictions and targets.

Parameters
  • preds (Tensor) – Predictions from model

  • target (Tensor) – Ground truth values

Return type

None

SymmetricMeanAbsolutePercentageError

class torchmetrics.SymmetricMeanAbsolutePercentageError(compute_on_step=True, dist_sync_on_step=False, process_group=None, dist_sync_fn=None)[source]

Computes symmetric mean absolute percentage error (SMAPE).

\text{SMAPE} = \frac{2}{n}\sum_1^n\frac{| y_i - \hat{y_i} |}{\max(| y_i | + | \hat{y_i} |, \epsilon)}

Where y is a tensor of target values, and \hat{y} is a tensor of predictions.

Parameters
  • compute_on_step (bool) – Forward only calls update() and returns None if this is set to False.

  • dist_sync_on_step (bool) – Synchronize metric state across processes at each forward() before returning the value at the step.

  • process_group (Optional[Any]) – Specify the process group on which synchronization is called. default: None (which selects the entire world)

Note

The epsilon value is taken from scikit-learn’s implementation of SMAPE.

Note

SMAPE output is a non-negative floating point between 0 and 1, and the best result is 0.0.

Example

>>> import torch
>>> from torchmetrics import SymmetricMeanAbsolutePercentageError
>>> target = torch.tensor([1, 10, 1e6])
>>> preds = torch.tensor([0.9, 15, 1.2e6])
>>> smape = SymmetricMeanAbsolutePercentageError()
>>> smape(preds, target)
tensor(0.2290)
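By hand from the formula (epsilon ignored, since no denominator is zero here):

>>> (2 * (target - preds).abs() / (target.abs() + preds.abs())).mean()
tensor(0.2290)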

Initializes internal Module state, shared by both nn.Module and ScriptModule.

compute()[source]

Computes symmetric mean absolute percentage error over state.

Return type

Tensor

update(preds, target)[source]

Update state with predictions and targets.

Parameters
  • preds (Tensor) – Predictions from model

  • target (Tensor) – Ground truth values

Return type

None

TweedieDevianceScore

class torchmetrics.TweedieDevianceScore(power=0.0, compute_on_step=True, dist_sync_on_step=False, process_group=None, dist_sync_fn=None)[source]

Computes the Tweedie Deviance Score between targets and predictions:

deviance\_score(\hat{y},y) =
\begin{cases}
(\hat{y} - y)^2, & \text{for }power=0\\
2 * (y * log(\frac{y}{\hat{y}}) + \hat{y} - y),  & \text{for }power=1\\
2 * (log(\frac{\hat{y}}{y}) + \frac{y}{\hat{y}} - 1),  & \text{for }power=2\\
2 * (\frac{(max(y,0))^{2 - power}}{(1 - power)(2 - power)} - \frac{y(\hat{y})^{1 - power}}{1 - power} + \frac{(\hat{y})^{2 - power}}{2 - power}), & \text{otherwise}
\end{cases}

where y is a tensor of target values, and \hat{y} is a tensor of predictions.

Forward accepts

  • preds (float tensor): (N,...)

  • targets (float tensor): (N,...)

Parameters
  • power (float) –

    • power < 0 : Extreme stable distribution. (Requires: preds > 0.)

    • power = 0 : Normal distribution. (Requires: targets and preds can be any real numbers.)

    • power = 1 : Poisson distribution. (Requires: targets >= 0 and preds > 0.)

    • 1 < power < 2 : Compound Poisson distribution. (Requires: targets >= 0 and preds > 0.)

    • power = 2 : Gamma distribution. (Requires: targets > 0 and preds > 0.)

    • power = 3 : Inverse Gaussian distribution. (Requires: targets > 0 and preds > 0.)

    • otherwise : Positive stable distribution. (Requires: targets > 0 and preds > 0.)

  • compute_on_step (bool) – Forward only calls update() and returns None if this is set to False.

  • dist_sync_on_step (bool) – Synchronize metric state across processes at each forward() before returning the value at the step.

  • process_group (Optional[Any]) – Specify the process group on which synchronization is called. default: None (which selects the entire world)

  • dist_sync_fn (Optional[Callable]) – Callback that performs the allgather operation on the metric state. When None, DDP will be used to perform the all gather.

Example

>>> from torchmetrics import TweedieDevianceScore
>>> targets = torch.tensor([1.0, 2.0, 3.0, 4.0])
>>> preds = torch.tensor([4.0, 3.0, 2.0, 1.0])
>>> deviance_score = TweedieDevianceScore(power=2)
>>> deviance_score(preds, targets)
tensor(1.2083)
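For power=2 the example value can be reproduced from the Gamma branch of the formula:

>>> (2 * (torch.log(preds / targets) + targets / preds - 1)).mean()
tensor(1.2083)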

Initializes internal Module state, shared by both nn.Module and ScriptModule.

compute()[source]

Override this method to compute the final metric value from state variables synchronized across the distributed backend.

Return type

Tensor

update(preds, targets)[source]

Update metric states with predictions and targets.

Parameters
  • preds (Tensor) – Predicted tensor with shape (N,...)

  • targets (Tensor) – Ground truth tensor with shape (N,...)

Return type

None

Retrieval

Input details

For the purposes of retrieval metrics, inputs (indexes, predictions and targets) must have the same size (N stands for the batch size) and the following types:

  • indexes: shape (N, ...), dtype long

  • preds: shape (N, ...), dtype float

  • target: shape (N, ...), dtype long or bool

Note

All dimensions are flattened at the beginning, so that, for example, a tensor of shape (N, M) is treated as (N * M, ).

In Information Retrieval you have a query that is compared with a variable number of documents. For each pair (Q_i, D_j), a score is computed that measures the relevance of document D w.r.t. query Q. Documents are then sorted by score and you hope that relevant documents are scored higher. target contains the labels for the documents (relevant or not).

Since a query may be compared with a variable number of documents, we use indexes to keep track of which scores belong to the set of pairs (Q_i, D_j) having the same query Q_i.

Note

Retrieval metrics are only intended to be used globally. This means that the average of the metric over each batch can be quite different from the metric computed on the whole dataset. For this reason, we suggest computing the metric only once all the examples have been provided to it. When using PyTorch Lightning, we suggest using on_step=False and on_epoch=True in self.log, or placing the metric calculation in training_epoch_end, validation_epoch_end or test_epoch_end.

>>> import torch
>>> from torchmetrics import RetrievalMAP
>>> # functional version works on a single query at a time
>>> from torchmetrics.functional import retrieval_average_precision

>>> # the first query was compared with two documents, the second with three
>>> indexes = torch.tensor([0, 0, 1, 1, 1])
>>> preds = torch.tensor([0.8, -0.4, 1.0, 1.4, 0.0])
>>> target = torch.tensor([0, 1, 0, 1, 1])

>>> map = RetrievalMAP() # or some other retrieval metric
>>> map(preds, target, indexes=indexes)
tensor(0.6667)

>>> # the previous instruction is roughly equivalent to
>>> res = []
>>> # iterate over indexes of first and second query
>>> for indexes in ([0, 1], [2, 3, 4]):
...     res.append(retrieval_average_precision(preds[indexes], target[indexes]))
>>> torch.stack(res).mean()
tensor(0.6667)

RetrievalMAP

class torchmetrics.RetrievalMAP(empty_target_action='neg', compute_on_step=True, dist_sync_on_step=False, process_group=None, dist_sync_fn=None)[source]

Computes Mean Average Precision.

Works with binary target data. Accepts float predictions from a model output.

Forward accepts

  • preds (float tensor): (N, ...)

  • target (long or bool tensor): (N, ...)

  • indexes (long tensor): (N, ...)

indexes, preds and target must have the same dimension. indexes indicate to which query a prediction belongs. Predictions will be first grouped by indexes and then MAP will be computed as the mean of the Average Precisions over each query.

Parameters
  • empty_target_action (str) –

    Specify what to do with queries that do not have at least a positive target. Choose from:

    • 'neg': those queries count as 0.0 (default)

    • 'pos': those queries count as 1.0

    • 'skip': skip those queries; if all queries are skipped, 0.0 is returned

    • 'error': raise a ValueError

  • compute_on_step (bool) – Forward only calls update() and returns None if this is set to False. default: True

  • dist_sync_on_step (bool) – Synchronize metric state across processes at each forward() before returning the value at the step. default: False

  • process_group (Optional[Any]) – Specify the process group on which synchronization is called. default: None (which selects the entire world)

  • dist_sync_fn (Optional[Callable]) – Callback that performs the allgather operation on the metric state. When None, DDP will be used to perform the allgather. default: None

Example

>>> from torch import tensor
>>> from torchmetrics import RetrievalMAP
>>> indexes = tensor([0, 0, 0, 1, 1, 1, 1])
>>> preds = tensor([0.2, 0.3, 0.5, 0.1, 0.3, 0.5, 0.2])
>>> target = tensor([False, False, True, False, True, False, True])
>>> rmap = RetrievalMAP()
>>> rmap(preds, target, indexes=indexes)
tensor(0.7917)

Initializes internal Module state, shared by both nn.Module and ScriptModule.

RetrievalMRR

class torchmetrics.RetrievalMRR(empty_target_action='neg', compute_on_step=True, dist_sync_on_step=False, process_group=None, dist_sync_fn=None)[source]

Computes Mean Reciprocal Rank.

Works with binary target data. Accepts float predictions from a model output.

Forward accepts

  • preds (float tensor): (N, ...)

  • target (long or bool tensor): (N, ...)

  • indexes (long tensor): (N, ...)

indexes, preds and target must have the same dimension. indexes indicate to which query a prediction belongs. Predictions will be first grouped by indexes and then MRR will be computed as the mean of the Reciprocal Rank over each query.

Parameters
  • empty_target_action (str) –

    Specify what to do with queries that do not have at least a positive target. Choose from:

    • 'neg': those queries count as 0.0 (default)

    • 'pos': those queries count as 1.0

    • 'skip': skip those queries; if all queries are skipped, 0.0 is returned

    • 'error': raise a ValueError

  • compute_on_step (bool) – Forward only calls update() and returns None if this is set to False. default: True

  • dist_sync_on_step (bool) – Synchronize metric state across processes at each forward() before returning the value at the step. default: False

  • process_group (Optional[Any]) – Specify the process group on which synchronization is called. default: None (which selects the entire world)

  • dist_sync_fn (Optional[Callable]) – Callback that performs the allgather operation on the metric state. When None, DDP will be used to perform the allgather. default: None

Example

>>> from torch import tensor
>>> from torchmetrics import RetrievalMRR
>>> indexes = tensor([0, 0, 0, 1, 1, 1, 1])
>>> preds = tensor([0.2, 0.3, 0.5, 0.1, 0.3, 0.5, 0.2])
>>> target = tensor([False, False, True, False, True, False, True])
>>> mrr = RetrievalMRR()
>>> mrr(preds, target, indexes=indexes)
tensor(0.7500)
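As in the worked example under Input details, the value can be cross-checked one query at a time against the functional counterpart retrieval_reciprocal_rank, which operates on a single query:

>>> import torch
>>> from torchmetrics.functional import retrieval_reciprocal_rank
>>> rrs = [retrieval_reciprocal_rank(preds[:3], target[:3]),
...        retrieval_reciprocal_rank(preds[3:], target[3:])]
>>> torch.stack(rrs).mean()
tensor(0.7500)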

Initializes internal Module state, shared by both nn.Module and ScriptModule.

RetrievalPrecision

class torchmetrics.RetrievalPrecision(empty_target_action='neg', k=None, compute_on_step=True, dist_sync_on_step=False, process_group=None, dist_sync_fn=None)[source]

Computes IR Precision.

Works with binary target data. Accepts float predictions from a model output.

Forward accepts:

  • preds (float tensor): (N, ...)

  • target (long or bool tensor): (N, ...)

  • indexes (long tensor): (N, ...)

indexes, preds and target must have the same dimension. indexes indicate to which query a prediction belongs. Predictions will be first grouped by indexes and then Precision will be computed as the mean of the Precision over each query.

Parameters
  • empty_target_action (str) –

    Specify what to do with queries that do not have at least a positive target. Choose from:

    • 'neg': those queries count as 0.0 (default)

    • 'pos': those queries count as 1.0

    • 'skip': skip those queries; if all queries are skipped, 0.0 is returned

    • 'error': raise a ValueError

  • k (Optional[int]) – consider only the top k elements for each query (default: None, which considers them all)

  • compute_on_step (bool) – Forward only calls update() and returns None if this is set to False. default: True

  • dist_sync_on_step (bool) – Synchronize metric state across processes at each forward() before returning the value at the step. default: False

  • process_group (Optional[Any]) – Specify the process group on which synchronization is called. default: None (which selects the entire world)

  • dist_sync_fn (Optional[Callable]) – Callback that performs the allgather operation on the metric state. When None, DDP will be used to perform the allgather. default: None

Raises

ValueError – If the k parameter is neither None nor an integer larger than 0

Example

>>> from torch import tensor
>>> from torchmetrics import RetrievalPrecision
>>> indexes = tensor([0, 0, 0, 1, 1, 1, 1])
>>> preds = tensor([0.2, 0.3, 0.5, 0.1, 0.3, 0.5, 0.2])
>>> target = tensor([False, False, True, False, True, False, True])
>>> p2 = RetrievalPrecision(k=2)
>>> p2(preds, target, indexes=indexes)
tensor(0.5000)
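The same per-query cross-check works here, assuming the functional counterpart retrieval_precision with its k argument:

>>> import torch
>>> from torchmetrics.functional import retrieval_precision
>>> ps = [retrieval_precision(preds[:3], target[:3], k=2),
...       retrieval_precision(preds[3:], target[3:], k=2)]
>>> torch.stack(ps).mean()
tensor(0.5000)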

Initializes internal Module state, shared by both nn.Module and ScriptModule.

RetrievalRPrecision

class torchmetrics.RetrievalRPrecision(empty_target_action='neg', compute_on_step=True, dist_sync_on_step=False, process_group=None, dist_sync_fn=None)[source]

Computes IR R-Precision.

Works with binary target data. Accepts float predictions from a model output.

Forward accepts:

  • preds (float tensor): (N, ...)

  • target (long or bool tensor): (N, ...)

  • indexes (long tensor): (N, ...)

indexes, preds and target must have the same dimension. indexes indicate to which query a prediction belongs. Predictions will be first grouped by indexes and then R-Precision will be computed as the mean of the R-Precision over each query.

Parameters
  • empty_target_action (str) –

    Specify what to do with queries that do not have at least a positive target. Choose from:

    • 'neg': those queries count as 0.0 (default)

    • 'pos': those queries count as 1.0

    • 'skip': skip those queries; if all queries are skipped, 0.0 is returned

    • 'error': raise a ValueError

  • compute_on_step (bool) – Forward only calls update() and returns None if this is set to False. default: True

  • dist_sync_on_step (bool) – Synchronize metric state across processes at each forward() before returning the value at the step. default: False

  • process_group (Optional[Any]) – Specify the process group on which synchronization is called. default: None (which selects the entire world)

  • dist_sync_fn (Optional[Callable]) – Callback that performs the allgather operation on the metric state. When None, DDP will be used to perform the allgather. default: None

Example

>>> from torchmetrics import RetrievalRPrecision
>>> indexes = tensor([0, 0, 0, 1, 1, 1, 1])
>>> preds = tensor([0.2, 0.3, 0.5, 0.1, 0.3, 0.5, 0.2])
>>> target = tensor([False, False, True, False, True, False, True])
>>> p2 = RetrievalRPrecision()
>>> p2(preds, target, indexes=indexes)
tensor(0.7500)
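
To trace this result: query 0 has one relevant document (R=1) and its top-ranked document is relevant, giving 1.0; query 1 has two relevant documents (R=2) and only one of its top-2 documents is relevant, giving 0.5. The mean over both queries is 0.75.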

Initializes internal Module state, shared by both nn.Module and ScriptModule.

RetrievalRecall

class torchmetrics.RetrievalRecall(empty_target_action='neg', k=None, compute_on_step=True, dist_sync_on_step=False, process_group=None, dist_sync_fn=None)[source]

Computes IR Recall.

Works with binary target data. Accepts float predictions from a model output.

Forward accepts:

  • preds (float tensor): (N, ...)

  • target (long or bool tensor): (N, ...)

  • indexes (long tensor): (N, ...)

indexes, preds and target must have the same dimension. indexes indicate to which query a prediction belongs. Predictions will be first grouped by indexes and then Recall will be computed as the mean of the Recall over each query.

Parameters
  • empty_target_action (str) –

    Specify what to do with queries that do not have at least a positive target. Choose from:

    • 'neg': those queries count as 0.0 (default)

    • 'pos': those queries count as 1.0

    • 'skip': skip those queries; if all queries are skipped, 0.0 is returned

    • 'error': raise a ValueError

  • k (Optional[int]) – consider only the top k elements for each query (default: None, which considers them all)

  • compute_on_step (bool) – Forward only calls update() and returns None if this is set to False. default: True

  • dist_sync_on_step (bool) – Synchronize metric state across processes at each forward() before returning the value at the step. default: False

  • process_group (Optional[Any]) – Specify the process group on which synchronization is called. default: None (which selects the entire world)

  • dist_sync_fn (Optional[Callable]) – Callback that performs the allgather operation on the metric state. When None, DDP will be used to perform the allgather. default: None

Raises

ValueError – If k is not None and not an integer larger than 0

Example

>>> from torchmetrics import RetrievalRecall
>>> indexes = tensor([0, 0, 0, 1, 1, 1, 1])
>>> preds = tensor([0.2, 0.3, 0.5, 0.1, 0.3, 0.5, 0.2])
>>> target = tensor([False, False, True, False, True, False, True])
>>> r2 = RetrievalRecall(k=2)
>>> r2(preds, target, indexes=indexes)
tensor(0.7500)
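
To trace this result: with k=2, query 0's single relevant document appears among its top-2 results, giving a recall of 1.0, while only one of query 1's two relevant documents does, giving 0.5; the mean over both queries is 0.75.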

Initializes internal Module state, shared by both nn.Module and ScriptModule.

RetrievalFallOut

class torchmetrics.RetrievalFallOut(empty_target_action='pos', k=None, compute_on_step=True, dist_sync_on_step=False, process_group=None, dist_sync_fn=None)[source]

Computes Fall-out.

Works with binary target data. Accepts float predictions from a model output.

Forward accepts:

  • preds (float tensor): (N, ...)

  • target (long or bool tensor): (N, ...)

  • indexes (long tensor): (N, ...)

indexes, preds and target must have the same dimension. indexes indicate to which query a prediction belongs. Predictions will be first grouped by indexes and then Fall-out will be computed as the mean of the Fall-out over each query.

Parameters
  • empty_target_action (str) –

    Specify what to do with queries that do not have at least a negative target. Choose from:

    • 'neg': those queries count as 0.0 (default)

    • 'pos': those queries count as 1.0

    • 'skip': skip those queries; if all queries are skipped, 0.0 is returned

    • 'error': raise a ValueError

  • k (Optional[int]) – consider only the top k elements for each query (default: None, which considers them all)

  • compute_on_step (bool) – Forward only calls update() and returns None if this is set to False. default: True

  • dist_sync_on_step (bool) – Synchronize metric state across processes at each forward() before returning the value at the step. default: False

  • process_group (Optional[Any]) – Specify the process group on which synchronization is called. default: None (which selects the entire world)

  • dist_sync_fn (Optional[Callable]) – Callback that performs the allgather operation on the metric state. When None, DDP will be used to perform the allgather. default: None

Raises

ValueError – If k is not None and not an integer larger than 0

Example

>>> from torchmetrics import RetrievalFallOut
>>> indexes = tensor([0, 0, 0, 1, 1, 1, 1])
>>> preds = tensor([0.2, 0.3, 0.5, 0.1, 0.3, 0.5, 0.2])
>>> target = tensor([False, False, True, False, True, False, True])
>>> fo = RetrievalFallOut(k=2)
>>> fo(preds, target, indexes=indexes)
tensor(0.5000)
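
To trace this result: with k=2, each query has two non-relevant documents, of which exactly one appears among the top-2 results, so the fall-out is 0.5 for both queries and the mean is 0.5.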

Initializes internal Module state, shared by both nn.Module and ScriptModule.

compute()[source]

First concatenate the state indexes, preds and target, since they were stored as lists.

After that, compute the list of groups that keeps predictions about the same query together. Finally, for each group compute the _metric if the number of negative targets is at least 1; otherwise, behave as specified by self.empty_target_action.

Return type

Tensor

RetrievalNormalizedDCG

class torchmetrics.RetrievalNormalizedDCG(empty_target_action='neg', k=None, compute_on_step=True, dist_sync_on_step=False, process_group=None, dist_sync_fn=None)[source]

Computes Normalized Discounted Cumulative Gain.

Works with binary or positive integer target data. Accepts float predictions from a model output.

Forward accepts:

  • preds (float tensor): (N, ...)

  • target (long, int, bool or float tensor): (N, ...)

  • indexes (long tensor): (N, ...)

indexes, preds and target must have the same dimension. indexes indicate to which query a prediction belongs. Predictions will be first grouped by indexes and then Normalized Discounted Cumulative Gain will be computed as the mean of the Normalized Discounted Cumulative Gain over each query.

Parameters
  • empty_target_action (str) –

    Specify what to do with queries that do not have at least a positive target. Choose from:

    • 'neg': those queries count as 0.0 (default)

    • 'pos': those queries count as 1.0

    • 'skip': skip those queries; if all queries are skipped, 0.0 is returned

    • 'error': raise a ValueError

  • k (Optional[int]) – consider only the top k elements for each query (default: None, which considers them all)

  • compute_on_step (bool) – Forward only calls update() and returns None if this is set to False. default: True

  • dist_sync_on_step (bool) – Synchronize metric state across processes at each forward() before returning the value at the step. default: False

  • process_group (Optional[Any]) – Specify the process group on which synchronization is called. default: None (which selects the entire world)

  • dist_sync_fn (Optional[Callable]) – Callback that performs the allgather operation on the metric state. When None, DDP will be used to perform the allgather. default: None

Raises

ValueError – If k is not None and not an integer larger than 0

Example

>>> from torchmetrics import RetrievalNormalizedDCG
>>> indexes = tensor([0, 0, 0, 1, 1, 1, 1])
>>> preds = tensor([0.2, 0.3, 0.5, 0.1, 0.3, 0.5, 0.2])
>>> target = tensor([False, False, True, False, True, False, True])
>>> ndcg = RetrievalNormalizedDCG()
>>> ndcg(preds, target, indexes=indexes)
tensor(0.8467)
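
To trace this result: query 0 ranks its single relevant document first, so its normalized DCG is 1.0. Query 1 ranks its two relevant documents second and third, giving a DCG of 1/log2(3) + 1/log2(4) ≈ 1.131 against an ideal DCG of 1/log2(2) + 1/log2(3) ≈ 1.631, i.e. roughly 0.693; the mean over both queries is approximately 0.8467.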

Initializes internal Module state, shared by both nn.Module and ScriptModule.

RetrievalHitRate

class torchmetrics.RetrievalHitRate(empty_target_action='neg', k=None, compute_on_step=True, dist_sync_on_step=False, process_group=None, dist_sync_fn=None)[source]

Computes IR HitRate.

Works with binary target data. Accepts float predictions from a model output.

Forward accepts:

  • preds (float tensor): (N, ...)

  • target (long or bool tensor): (N, ...)

  • indexes (long tensor): (N, ...)

indexes, preds and target must have the same dimension. indexes indicate to which query a prediction belongs. Predictions will be first grouped by indexes and then the Hit Rate will be computed as the mean of the Hit Rate over each query.

Parameters
  • empty_target_action (str) –

    Specify what to do with queries that do not have at least a positive target. Choose from:

    • 'neg': those queries count as 0.0 (default)

    • 'pos': those queries count as 1.0

    • 'skip': skip those queries; if all queries are skipped, 0.0 is returned

    • 'error': raise a ValueError

  • k (Optional[int]) – consider only the top k elements for each query (default: None, which considers them all)

  • compute_on_step (bool) – Forward only calls update() and returns None if this is set to False. default: True

  • dist_sync_on_step (bool) – Synchronize metric state across processes at each forward() before returning the value at the step. default: False

  • process_group (Optional[Any]) – Specify the process group on which synchronization is called. default: None (which selects the entire world)

  • dist_sync_fn (Optional[Callable]) – Callback that performs the allgather operation on the metric state. When None, DDP will be used to perform the allgather. default: None

Raises

ValueError – If k is not None and not an integer larger than 0

Example

>>> from torchmetrics import RetrievalHitRate
>>> indexes = tensor([0, 0, 0, 1, 1, 1, 1])
>>> preds = tensor([0.2, 0.3, 0.5, 0.1, 0.3, 0.5, 0.2])
>>> target = tensor([True, False, False, False, True, False, True])
>>> hr2 = RetrievalHitRate(k=2)
>>> hr2(preds, target, indexes=indexes)
tensor(0.5000)
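
To trace this result: with k=2, query 0's only relevant document is ranked third and therefore missed, giving 0.0, while query 1 has a relevant document among its top-2 results, giving 1.0; the mean over both queries is 0.5.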

Initializes internal Module state, shared by both nn.Module and ScriptModule.

Text

BERTScore

class torchmetrics.BERTScore(model_name_or_path=None, num_layers=None, all_layers=False, model=None, user_tokenizer=None, user_forward_fn=None, verbose=False, idf=False, device=None, max_length=512, batch_size=64, num_threads=4, return_hash=False, lang='en', rescale_with_baseline=False, baseline_path=None, baseline_url=None, compute_on_step=True, dist_sync_on_step=False, process_group=None, dist_sync_fn=None)[source]

BERTScore, from the paper BERTScore: Evaluating Text Generation with BERT, leverages the pre-trained contextual embeddings from BERT and matches words in candidate and reference sentences by cosine similarity. It has been shown to correlate with human judgment on sentence-level and system-level evaluation. Moreover, BERTScore computes precision, recall, and F1 measure, which can be useful for evaluating different language generation tasks.

This implementation follows the original implementation from bert_score.

Parameters
  • model_name_or_path (Optional[str]) – A name or a model path used to load transformers pretrained model.

  • num_layers (Optional[int]) – A layer of representation to use.

  • all_layers (bool) – An indication of whether the representation from all model’s layers should be used. If all_layers = True, the argument num_layers is ignored.

  • model (Optional[Module]) – A user’s own model. Must be an instance of torch.nn.Module.

  • user_tokenizer (Optional[Any]) – A user’s own tokenizer used with the own model. This must be an instance with the __call__ method. This method must take an iterable of sentences (List[str]) and must return a python dictionary containing “input_ids” and “attention_mask” represented by torch.Tensor. It is up to the user’s model whether “input_ids” is a torch.Tensor of input ids or embedding vectors. This tokenizer must prepend an equivalent of [CLS] token and append an equivalent of [SEP] token as the transformers tokenizer does.

  • user_forward_fn (Optional[Callable[[Module, Dict[str, Tensor]], Tensor]]) – A user’s own forward function used in combination with user_model. This function must take user_model and a python dictionary containing “input_ids” and “attention_mask” represented by torch.Tensor as an input and return the model’s output as a single torch.Tensor.

  • verbose (bool) – An indication of whether a progress bar should be displayed during the embeddings calculation.

  • idf (bool) – An indication of whether normalization using inverse document frequencies should be used.

  • device (Union[str, device, None]) – A device to be used for calculation.

  • max_length (int) – A maximum length of input sequences. Sequences longer than max_length are trimmed.

  • batch_size (int) – A batch size used for model processing.

  • num_threads (int) – A number of threads to use for a dataloader.

  • return_hash (bool) – An indication of whether the corresponding hash_code should be returned.

  • lang (str) – A language of input sentences.

  • rescale_with_baseline (bool) – An indication of whether BERTScore should be rescaled with a pre-computed baseline. When a pretrained model from transformers is used, the corresponding baseline is downloaded from the original bert-score package if available. In other cases, please specify a path to the baseline csv/tsv file, which must follow the formatting of the files from bert-score.

  • baseline_path (Optional[str]) – A path to the user’s own local csv/tsv file with the baseline scale.

  • baseline_url (Optional[str]) – A url path to the user’s own csv/tsv file with the baseline scale.

  • compute_on_step (bool) – Forward only calls update() and returns None if this is set to False. default: True

  • dist_sync_on_step (bool) – Synchronize metric state across processes at each forward() before returning the value at the step. default: False

  • process_group (Optional[Any]) – Specify the process group on which synchronization is called. default: None (which selects the entire world)

  • dist_sync_fn (Optional[Callable]) – Callback that performs the allgather operation on the metric state. When None, DDP will be used to perform the allgather

Returns

Python dictionary containing the keys precision, recall and f1 with corresponding values.

Example

>>> predictions = ["hello there", "general kenobi"]
>>> references = ["hello there", "master kenobi"]
>>> bertscore = BERTScore()
>>> bertscore.update(predictions=predictions, references=references)
>>> bertscore.compute()  
{'precision': [0.99..., 0.99...],
 'recall': [0.99..., 0.99...],
 'f1': [0.99..., 0.99...]}
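
For reference, here is a minimal sketch of the user_forward_fn contract described in the parameter list above; the function name is illustrative, and a transformers-style model whose first output holds the token embeddings is assumed:

>>> def my_forward_fn(model, batch):
...     # must return the model's output as a single torch.Tensor
...     return model(input_ids=batch["input_ids"], attention_mask=batch["attention_mask"])[0]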

Initializes internal Module state, shared by both nn.Module and ScriptModule.

compute()[source]

Calculate BERT scores.

Return type

Dict[str, Union[List[float], str]]

Returns

Python dictionary containing the keys precision, recall and f1 with corresponding values.

update(predictions, references)[source]

Store predictions/references for computing BERT scores. It is necessary to store sentences in a tokenized form to ensure that DDP mode works correctly.

Parameters
  • predictions (List[str]) – An iterable of predicted sentences.

  • references (List[str]) – An iterable of reference sentences.

Return type

None

BLEUScore

class torchmetrics.BLEUScore(n_gram=4, smooth=False, compute_on_step=True, dist_sync_on_step=False, process_group=None, dist_sync_fn=None)[source]

Calculate BLEU score of machine translated text with one or more references.

Parameters
  • n_gram (int) – Gram value ranging from 1 to 4 (default: 4)

  • smooth (bool) – Whether or not to apply smoothing – see [2]

  • compute_on_step (bool) – Forward only calls update() and returns None if this is set to False. default: True

  • dist_sync_on_step (bool) – Synchronize metric state across processes at each forward() before returning the value at the step.

  • process_group (Optional[Any]) – Specify the process group on which synchronization is called. default: None (which selects the entire world)

  • dist_sync_fn (Optional[Callable]) – Callback that performs the allgather operation on the metric state. When None, DDP will be used to perform the allgather.

Example

>>> from torchmetrics import BLEUScore
>>> translate_corpus = ['the cat is on the mat'.split()]
>>> reference_corpus = [['there is a cat on the mat'.split(), 'a cat is on the mat'.split()]]
>>> metric = BLEUScore()
>>> metric(reference_corpus, translate_corpus)
tensor(0.7598)

References

[1] BLEU: a Method for Automatic Evaluation of Machine Translation by Papineni, Kishore, Salim Roukos, Todd Ward, and Wei-Jing Zhu BLEU

[2] Automatic Evaluation of Machine Translation Quality Using Longest Common Subsequence and Skip-Bigram Statistics by Chin-Yew Lin and Franz Josef Och Machine Translation Evolution

Initializes internal Module state, shared by both nn.Module and ScriptModule.

compute()[source]

Calculate BLEU score.

Return type

Tensor

Returns

Tensor with BLEU Score

update(reference_corpus, translate_corpus)[source]

Compute Precision Scores.

Parameters
  • reference_corpus – An iterable of iterables of reference corpus

  • translate_corpus – An iterable of machine translated corpus

Return type

None

CharErrorRate

class torchmetrics.CharErrorRate(compute_on_step=True, dist_sync_on_step=False, process_group=None, dist_sync_fn=None)[source]

Character error rate (CharErrorRate) is a metric of the performance of an automatic speech recognition (ASR) system. This value indicates the percentage of characters that were incorrectly predicted. The lower the value, the better the performance of the ASR system with a CharErrorRate of 0 being a perfect score. Character error rate can then be computed as:

CharErrorRate = \frac{S + D + I}{N} = \frac{S + D + I}{S + D + C}

where:
  • S is the number of substitutions,

  • D is the number of deletions,

  • I is the number of insertions,

  • C is the number of correct characters,

  • N is the number of characters in the reference (N=S+D+C).

Compute CharErrorRate score of transcribed segments against references.

Parameters
  • compute_on_step (bool) – Forward only calls update() and returns None if this is set to False. default: True

  • dist_sync_on_step (bool) – Synchronize metric state across processes at each forward() before returning the value at the step. default: False

  • process_group (Optional[Any]) – Specify the process group on which synchronization is called. default: None (which selects the entire world)

  • dist_sync_fn (Optional[Callable]) – Callback that performs the allgather operation on the metric state. When None, DDP will be used to perform the allgather

Returns

(Tensor) Character error rate

Examples

>>> predictions = ["this is the prediction", "there is an other sample"]
>>> references = ["this is the reference", "there is another one"]
>>> metric = CharErrorRate()
>>> metric(predictions, references)
tensor(0.3415)
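
To trace this result: counting spaces, the two references contain 21 + 20 = 41 characters, and the character-level edit distances of the two pairs are 8 and 6, for a total of 14 errors; 14 / 41 ≈ 0.3415.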

Initializes internal Module state, shared by both nn.Module and ScriptModule.

compute()[source]

Calculate the character error rate.

Return type

Tensor

Returns

(Tensor) Character error rate

update(predictions, references)[source]

Store references/predictions for computing Character Error Rate scores.

Parameters
  • predictions (Union[str, List[str]]) – Transcription(s) to score as a string or list of strings

  • references (Union[str, List[str]]) – Reference(s) for each speech input as a string or list of strings

Return type

None

ROUGEScore

class torchmetrics.ROUGEScore(newline_sep=None, use_stemmer=False, rouge_keys=('rouge1', 'rouge2', 'rougeL', 'rougeLsum'), decimal_places=None, compute_on_step=True, dist_sync_on_step=False, process_group=None, dist_sync_fn=None)[source]

Calculate Rouge Score, used for automatic summarization. This implementation should imitate the behaviour of the Python ROUGE implementation from the rouge-score package.

Parameters
  • newline_sep (Optional[bool]) – Whether to separate the inputs by newlines. This argument is no longer in use; it is deprecated in v0.6 and will be removed in v0.7.

  • use_stemmer (bool) – Use Porter stemmer to strip word suffixes to improve matching.

  • rouge_keys (Union[str, Tuple[str, …]]) – A list of rouge types to calculate. Keys that are allowed are rougeL, rougeLsum, and rouge1 through rouge9.

  • decimal_places (Optional[int]) – The number of digits to round the computed values to. This argument is no longer in use; it is deprecated in v0.6 and will be removed in v0.7.

  • compute_on_step (bool) – Forward only calls update() and returns None if this is set to False. default: True

  • dist_sync_on_step (bool) – Synchronize metric state across processes at each forward() before returning the value at the step.

  • process_group (Optional[Any]) – Specify the process group on which synchronization is called. default: None (which selects the entire world)

  • dist_sync_fn (Optional[Callable]) – Callback that performs the allgather operation on the metric state. When None, DDP will be used to perform the allgather.

Example

>>> targets = "Is your name John".split()
>>> preds = "My name is John".split()
>>> rouge = ROUGEScore()   
>>> from pprint import pprint
>>> pprint(rouge(preds, targets))  
{'rouge1_fmeasure': 0.25,
 'rouge1_precision': 0.25,
 'rouge1_recall': 0.25,
 'rouge2_fmeasure': 0.0,
 'rouge2_precision': 0.0,
 'rouge2_recall': 0.0,
 'rougeL_fmeasure': 0.25,
 'rougeL_precision': 0.25,
 'rougeL_recall': 0.25,
 'rougeLsum_fmeasure': 0.25,
 'rougeLsum_precision': 0.25,
 'rougeLsum_recall': 0.25}
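
If only a subset of rouge types is needed, rouge_keys can be restricted at construction time; for example (variable name illustrative):

>>> rouge1_only = ROUGEScore(rouge_keys="rouge1")
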
Raises
  • ValueError – If the python package nltk is not installed.

  • ValueError – If any of the rouge_keys does not belong to the allowed set of keys.

References

[1] ROUGE: A Package for Automatic Evaluation of Summaries by Chin-Yew Lin Rouge Detail

Initializes internal Module state, shared by both nn.Module and ScriptModule.

compute()[source]

Aggregate the stored scores and calculate the final ROUGE score.

Return type

Dict[str, Tensor]

Returns

Python dictionary of rouge scores for each input rouge key.

update(preds, targets)[source]

Compute rouge scores.

Parameters
  • preds (Union[str, List[str]]) – An iterable of predicted sentences

  • targets (Union[str, List[str]]) – An iterable of target sentences

Return type

None

SacreBLEUScore

class torchmetrics.SacreBLEUScore(n_gram=4, smooth=False, tokenize='13a', lowercase=False, compute_on_step=True, dist_sync_on_step=False, process_group=None, dist_sync_fn=None)[source]

Calculate BLEU score [1] of machine translated text with one or more references. This implementation follows the behaviour of the SacreBLEU [2] implementation from https://github.com/mjpost/sacrebleu.

The SacreBLEU implementation differs from the NLTK BLEU implementation in tokenization techniques.

Parameters
  • n_gram (int) – Gram value ranging from 1 to 4 (default: 4)

  • smooth (bool) – Whether or not to apply smoothing – see [2]

  • tokenize (Literal['none', '13a', 'zh', 'intl', 'char']) – Tokenization technique to be used. (default: '13a') Supported tokenization: ['none', '13a', 'zh', 'intl', 'char']

  • lowercase (bool) – If True, BLEU score over lowercased text is calculated.

  • compute_on_step (bool) – Forward only calls update() and returns None if this is set to False. default: True

  • dist_sync_on_step (bool) – Synchronize metric state across processes at each forward() before returning the value at the step.

  • process_group (Optional[Any]) – Specify the process group on which synchronization is called. default: None (which selects the entire world)

  • dist_sync_fn (Optional[Callable]) – Callback that performs the allgather operation on the metric state. When None, DDP will be used to perform the allgather.

Raises
  • ValueError – If tokenize is not one of 'none', '13a', 'zh', 'intl' or 'char'

  • ValueError – If tokenize is set to 'intl' and regex is not installed

Example

>>> from torchmetrics import SacreBLEUScore
>>> translate_corpus = ['the cat is on the mat']
>>> reference_corpus = [['there is a cat on the mat', 'a cat is on the mat']]
>>> metric = SacreBLEUScore()
>>> metric(reference_corpus, translate_corpus)
tensor(0.7598)
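
Other tokenizers can be selected at construction time; for example (variable name illustrative; note that 'intl' requires the regex package):

>>> metric_intl = SacreBLEUScore(tokenize='intl', lowercase=True)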

References

[1] BLEU: a Method for Automatic Evaluation of Machine Translation by Papineni, Kishore, Salim Roukos, Todd Ward, and Wei-Jing Zhu BLEU

[2] A Call for Clarity in Reporting BLEU Scores by Matt Post.

[3] Automatic Evaluation of Machine Translation Quality Using Longest Common Subsequence and Skip-Bigram Statistics by Chin-Yew Lin and Franz Josef Och Machine Translation Evolution

Initializes internal Module state, shared by both nn.Module and ScriptModule.

update(reference_corpus, translate_corpus)[source]

Compute Precision Scores.

Parameters
  • reference_corpus (Sequence[Sequence[str]]) – An iterable of iterables of reference corpus

  • translate_corpus (Sequence[str]) – An iterable of machine translated corpus

Return type

None

WER

class torchmetrics.WER(concatenate_texts=None, compute_on_step=True, dist_sync_on_step=False, process_group=None, dist_sync_fn=None)[source]

Word error rate (WER) is a common metric of the performance of an automatic speech recognition system. This value indicates the percentage of words that were incorrectly predicted. The lower the value, the better the performance of the ASR system with a WER of 0 being a perfect score. Word error rate can then be computed as:

WER = \frac{S + D + I}{N} = \frac{S + D + I}{S + D + C}

where:
  • S is the number of substitutions,

  • D is the number of deletions,

  • I is the number of insertions,

  • C is the number of correct words,

  • N is the number of words in the reference (N=S+D+C).

Compute WER score of transcribed segments against references.

Parameters
  • concatenate_texts (Optional[bool]) – Whether to concatenate all input texts or compute WER iteratively. This argument is deprecated in v0.6 and will be removed in v0.7.

  • compute_on_step (bool) – Forward only calls update() and returns None if this is set to False. default: True

  • dist_sync_on_step (bool) – Synchronize metric state across processes at each forward() before returning the value at the step. default: False

  • process_group (Optional[Any]) – Specify the process group on which synchronization is called. default: None (which selects the entire world)

  • dist_sync_fn (Optional[Callable]) – Callback that performs the allgather operation on the metric state. When None, DDP will be used to perform the allgather

Returns

(Tensor) Word error rate

Examples

>>> predictions = ["this is the prediction", "there is an other sample"]
>>> references = ["this is the reference", "there is another one"]
>>> metric = WER()
>>> metric(predictions, references)
tensor(0.5000)
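
To trace this result: the first pair differs by a single substituted word out of four reference words, and the second pair has a word-level edit distance of 3 against its four reference words, giving (1 + 3) / 8 = 0.5.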

Initializes internal Module state, shared by both nn.Module and ScriptModule.

compute()[source]

Calculate the word error rate.

Return type

Tensor

Returns

(Tensor) Word error rate

update(predictions, references)[source]

Store references/predictions for computing Word Error Rate scores.

Parameters
  • predictions (Union[str, List[str]]) – Transcription(s) to score as a string or list of strings

  • references (Union[str, List[str]]) – Reference(s) for each speech input as a string or list of strings

Return type

None

Wrappers

Modular wrapper metrics are not metrics in themselves, but instead take a metric and alter the internal logic of the base metric.

BootStrapper

class torchmetrics.BootStrapper(base_metric, num_bootstraps=10, mean=True, std=True, quantile=None, raw=False, sampling_strategy='poisson', compute_on_step=True, dist_sync_on_step=False, process_group=None, dist_sync_fn=None)[source]

Turns a metric into a bootstrapped version of itself, which can automate the process of getting confidence intervals for metric values. This wrapper class basically keeps multiple copies of the same base metric in memory and, whenever update or forward is called, all input tensors are resampled (with replacement) along the first dimension.

Parameters
  • base_metric (Metric) – base metric class to wrap

  • num_bootstraps (int) – number of copies to make of the base metric for bootstrapping

  • mean (bool) – if True return the mean of the bootstraps

  • std (bool) – if True return the standard deviation of the bootstraps

  • quantile (Union[float, Tensor, None]) – if given, returns the quantile of the bootstraps. Can only be used with pytorch version 1.6 or higher

  • raw (bool) – if True, return all bootstrapped values

  • sampling_strategy (str) – Determines how to produce bootstrapped samplings. Either 'poisson' or 'multinomial'. If 'poisson' is chosen, the number of times each sample will be included in the bootstrap will be given by n \sim Poisson(\lambda=1), which approximates the true bootstrap distribution when the number of samples is large. If 'multinomial' is chosen, we will apply true bootstrapping at the batch level to approximate bootstrapping over the whole dataset (see the sketch after this parameter list).

  • compute_on_step (bool) – Forward only calls update() and returns None if this is set to False. default: True

  • dist_sync_on_step (bool) – Synchronize metric state across processes at each forward() before returning the value at the step

  • process_group (Optional[Any]) – Specify the process group on which synchronization is called. default: None (which selects the entire world)

  • dist_sync_fn (Optional[Callable]) – Callback that performs the allgather operation on the metric state. When None, DDP will be used to perform the allgather.
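
A minimal standalone sketch of the 'poisson' resampling idea follows; the helper below is purely illustrative and is not the wrapper's internal function. Each sample along dimension 0 is repeated n \sim Poisson(\lambda=1) times:

>>> import torch
>>> def poisson_resample(x):
...     # draw n ~ Poisson(1) for every sample and repeat it that many times
...     counts = torch.poisson(torch.ones(x.shape[0])).long()
...     return torch.repeat_interleave(x, counts, dim=0)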

Example::
>>> import torch
>>> from pprint import pprint
>>> from torchmetrics import Accuracy, BootStrapper
>>> _ = torch.manual_seed(123)
>>> base_metric = Accuracy()
>>> bootstrap = BootStrapper(base_metric, num_bootstraps=20)
>>> bootstrap.update(torch.randint(5, (20,)), torch.randint(5, (20,)))
>>> output = bootstrap.compute()
>>> pprint(output)
{'mean': tensor(0.2205), 'std': tensor(0.0859)}

compute()[source]

Computes the bootstrapped metric values.

Always returns a dict of tensors, which can contain the following keys: mean, std, quantile and raw, depending on how the class was initialized.

Return type

Dict[str, Tensor]

update(*args, **kwargs)[source]

Updates the state of the base metric.

Any tensor passed in will be bootstrapped along dimension 0

Return type

None

MetricTracker

class torchmetrics.MetricTracker(metric, maximize=True)