
Welcome to TorchMetrics

TorchMetrics is a collection of 100+ PyTorch metrics implementations and an easy-to-use API to create custom metrics. It offers:

  • A standardized interface to increase reproducibility

  • Reduces boilerplate

  • Distributed-training compatible

  • Rigorously tested

  • Automatic accumulation over batches

  • Automatic synchronization between multiple devices

You can use TorchMetrics in any PyTorch model, or within PyTorch Lightning to enjoy the following additional benefits:

  • Your data will always be placed on the same device as your metrics

  • You can log Metric objects directly in Lightning to reduce even more boilerplate
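
To illustrate the last point, here is a minimal, hypothetical sketch of logging a Metric object directly inside a LightningModule. The module, layer sizes and logging options are illustrative only and are not part of TorchMetrics itself:

import torch
import torchmetrics
from lightning.pytorch import LightningModule  # or: from pytorch_lightning import LightningModule

class LitClassifier(LightningModule):
    def __init__(self):
        super().__init__()
        # toy model: a single linear layer over 32 features and 5 classes
        self.layer = torch.nn.Linear(32, 5)
        self.train_acc = torchmetrics.classification.Accuracy(task="multiclass", num_classes=5)

    def forward(self, x):
        return self.layer(x)

    def training_step(self, batch, batch_idx):
        x, y = batch
        logits = self(x)
        loss = torch.nn.functional.cross_entropy(logits, y)
        self.train_acc(logits, y)
        # logging the Metric object itself lets Lightning accumulate and reset it each epoch
        self.log("train_acc", self.train_acc, on_step=True, on_epoch=True)
        return loss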



Quick Start


Install

You can install TorchMetrics using pip or conda:

# Python Package Index (PyPI)
pip install torchmetrics
# Conda
conda install -c conda-forge torchmetrics

If there is no PyTorch wheel for your OS or Python version, you can also compile PyTorch from source:

# Optional: disable CUDA if you do not need GPU support
export USE_CUDA=0  # keeps the build simple
# you can install the latest state from master
pip install git+https://github.com/pytorch/pytorch.git
# OR set a particular PyTorch release
pip install git+https://github.com/pytorch/pytorch.git@<release-tag>
# and finalize with installing TorchMetrics
pip install torchmetrics

Using TorchMetrics

Functional metrics

Similar to torch.nn, most metrics have both a class-based and a functional version. The functional versions implement the basic operations required for computing each metric. They are simple Python functions that take torch.Tensors as input and return the corresponding metric as a torch.Tensor. The code snippet below shows a simple example of calculating accuracy using the functional interface:

import torch
# import our library
import torchmetrics

# simulate a classification problem
preds = torch.randn(10, 5).softmax(dim=-1)
target = torch.randint(5, (10,))

acc = torchmetrics.functional.accuracy(preds, target, task="multiclass", num_classes=5)

Module metrics

Nearly all functional metrics have a corresponding class-based metric that calls its functional counterpart underneath. The class-based metrics are characterized by having one or more internal metric states (similar to the parameters of a PyTorch module) that allow them to offer additional functionalities:

  • Accumulation of multiple batches

  • Automatic synchronization between multiple devices

  • Metric arithmetic

The code below shows how to use the class-based interface:

import torch
# import our library
import torchmetrics

# initialize metric
metric = torchmetrics.classification.Accuracy(task="multiclass", num_classes=5)

n_batches = 10
for i in range(n_batches):
    # simulate a classification problem
    preds = torch.randn(10, 5).softmax(dim=-1)
    target = torch.randint(5, (10,))
    # metric on current batch
    acc = metric(preds, target)
    print(f"Accuracy on batch {i}: {acc}")

# metric on all batches using custom accumulation
acc = metric.compute()
print(f"Accuracy on all data: {acc}")

# Resetting internal state so that the metric is ready for new data
metric.reset()

Implementing your own metric

Implementing your own metric is as easy as subclassing a torch.nn.Module. Simply subclass Metric and do the following:

  1. Implement __init__, where you call self.add_state for every internal state that is needed for the metric's computations

  2. Implement update, where all the logic that is necessary for updating metric states goes

  3. Implement compute, where the final metric computation happens

For practical examples and more info about implementing a metric, please see this page.

Development Environment

TorchMetrics provides a Devcontainer configuration for Visual Studio Code that uses a Docker container as a pre-configured development environment. This avoids the struggle of setting up a development environment and makes it reproducible and consistent. Please follow the installation instructions and familiarize yourself with the container tutorials if you want to use them. In order to use GPUs, you can enable them within the .devcontainer/devcontainer.json file.

Structure Overview

TorchMetrics is a Metrics API created for easy metric development and usage in PyTorch and PyTorch Lightning. It is rigorously tested for all edge cases and includes a growing list of common metric implementations.

The metrics API provides update(), compute(), reset() functions to the user. The metric base class inherits torch.nn.Module which allows us to call metric(...) directly. The forward() method of the base Metric class serves the dual purpose of calling update() on its input and simultaneously returning the value of the metric over the provided input.

These metrics work with DDP in PyTorch and PyTorch Lightning by default. When .compute() is called in distributed mode, the internal state of each metric is synced and reduced across each process, so that the logic present in .compute() is applied to state information from all processes.

This metrics API is independent of PyTorch Lightning. Metrics can directly be used in PyTorch as shown in the example:

from torchmetrics.classification import BinaryAccuracy

train_accuracy = BinaryAccuracy()
valid_accuracy = BinaryAccuracy()

for epoch in range(epochs):
    for batch_idx, (x, y) in enumerate(train_data):
        y_hat = model(x)

        # training step accuracy
        batch_acc = train_accuracy(y_hat, y)
        print(f"Accuracy of batch {batch_idx} is {batch_acc}")

    for x, y in valid_data:
        y_hat = model(x)
        valid_accuracy.update(y_hat, y)

    # total accuracy over all training batches
    total_train_accuracy = train_accuracy.compute()

    # total accuracy over all validation batches
    total_valid_accuracy = valid_accuracy.compute()

    print(f"Training acc for epoch {epoch}: {total_train_accuracy}")
    print(f"Validation acc for epoch {epoch}: {total_valid_accuracy}")

    # Reset metric states after each epoch
    train_accuracy.reset()
    valid_accuracy.reset()

Note

Metrics contain internal states that keep track of the data seen so far. Do not mix metric states across training, validation and testing. It is highly recommended to re-initialize the metric per mode as shown in the examples above.

Note

Metric states are not added to the model's state_dict by default. To change this, after initializing the metric, the method .persistent(mode) can be used to enable (mode=True) or disable (mode=False) this behaviour.
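
A small sketch of what this looks like in practice (BinaryAccuracy is used here purely as an example; the exact state names in the printed state_dict depend on the metric):

from torchmetrics.classification import BinaryAccuracy

metric = BinaryAccuracy()
print(metric.state_dict())  # metric states are not included by default

metric.persistent(True)     # include metric states in the state_dict from now on
print(metric.state_dict())  # now also contains the metric states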

Note

Due to specialized logic around metric states, we generally do not recommend initializing metrics inside other metrics (nested metrics), as this can lead to unexpected behaviour. Instead, consider subclassing a metric or using torchmetrics.MetricCollection.

Metrics and devices

Metrics are simple subclasses of Module and their metric states behave similarly to the buffers and parameters of modules. This means that metric states should be moved to the same device as the input of the metric:

import torch
from torchmetrics.classification import BinaryAccuracy

target = torch.tensor([1, 1, 0, 0], device=torch.device("cuda", 0))
preds = torch.tensor([0, 1, 0, 0], device=torch.device("cuda", 0))

# Metric states are always initialized on CPU and need to be moved to
# the correct device
metric = BinaryAccuracy().to(torch.device("cuda", 0))
out = metric(preds, target)
print(out.device)  # cuda:0

However, when properly defined inside a Module or LightningModule, the metric will automatically be moved to the same device as the module when .to(device) is called. Being properly defined means that the metric is correctly identified as a child module of the model (check the .children() attribute of the model). Therefore, metrics cannot be placed in native Python lists and dicts, as they will not be correctly identified as child modules. Instead of a list use ModuleList, and instead of a dict use ModuleDict. Furthermore, when working with multiple metrics, the MetricCollection class provided by TorchMetrics can also be used to wrap multiple metrics.

import torch
from torch import nn

from torchmetrics import MetricCollection
from torchmetrics.classification import BinaryAccuracy

class MyModule(torch.nn.Module):
    def __init__(self):
        super().__init__()
        ...
        # valid ways metrics will be identified as child modules
        self.metric1 = BinaryAccuracy()
        self.metric2 = nn.ModuleList([BinaryAccuracy()])
        self.metric3 = nn.ModuleDict({'accuracy': BinaryAccuracy()})
        self.metric4 = MetricCollection([BinaryAccuracy()])  # torchmetrics built-in collection class

    def forward(self, batch):
        data, target = batch
        preds = ...  # compute predictions from data with your model
        val1 = self.metric1(preds, target)
        val2 = self.metric2[0](preds, target)
        val3 = self.metric3['accuracy'](preds, target)
        val4 = self.metric4(preds, target)

You can always check which device the metric is located on using the .device property.

Metrics and memory management

As stated before, metrics have states and those states take up a certain amount of memory depending on the metric. In general metrics can be divided into two categories when we talk about memory management:

  • Metrics with tensor states: These metrics only have states that are instances of Tensor. When these kinds of metrics are updated, the values of those tensors are updated. Importantly, the size of the tensors is constant, meaning that regardless of how much data is passed to the metric, its memory footprint will not change.

  • Metrics with list states: These metrics have at least one state that is a list, to which tensors are appended as the metric is updated. Importantly, the size of the list is therefore not constant and will grow as the metric is updated. The growth depends on the particular metric (some metrics only need to store a single value per sample, others much more).

You can always check the current metric state by accessing the .metric_state property, and checking if any of the states are lists.

import torch
from torchmetrics.regression import SpearmanCorrCoef

gen = torch.manual_seed(42)
metric = SpearmanCorrCoef()
metric(torch.rand(2,), torch.rand(2,))
print(metric.metric_state)
metric(torch.rand(2,), torch.rand(2,))
print(metric.metric_state)
{'preds': [tensor([0.8823, 0.9150])], 'target': [tensor([0.3829, 0.9593])]}
{'preds': [tensor([0.8823, 0.9150]), tensor([0.3904, 0.6009])], 'target': [tensor([0.3829, 0.9593]), tensor([0.2566, 0.7936])]}

In general we have a few recommendations for memory management:

  • When done with a metric, we always recommend calling the reset method. The reason is that the Python garbage collector can struggle to fully clean up the metric states otherwise. In the worst case, this can lead to a memory leak if multiple instances of the same metric are created for different purposes in the same script.

  • It is better to reuse the same instance of a metric than to initialize a new one. Calling the reset method returns the metric to its initial state, so it can be used to reuse the same instance. However, we still highly recommend using different instances for training, validation and testing.

  • If you only need results at the batch level (i.e. no aggregation), or you have a small dataset that can be evaluated in a single iteration, we recommend using the functional API instead, as it does not keep an internal state and memory is therefore freed after each call.

See Advanced metric settings for different advanced settings for controlling the memory footprint of metrics.

Metrics in Distributed Data Parallel (DDP) mode

When using metrics in Distributed Data Parallel (DDP) mode, one should be aware that DDP will add additional samples to your dataset if the size of your dataset is not equally divisible by batch_size * num_processors. The added samples will always be replicates of datapoints already in your dataset. This is done to ensure an equal load for all processes. However, this has the consequence that the calculated metric value will be slightly biased towards those replicated samples, leading to a wrong result.

During training and/or validation this may not be important; however, when evaluating the test dataset it is highly recommended to run on only a single GPU or to use a join context in conjunction with DDP to prevent this behaviour.

Metrics and 16-bit precision

Most metrics in our collection can be used with 16-bit precision (torch.half) tensors. However, we have found the following limitations:

  • In general, PyTorch had better support for 16-bit precision much earlier on GPU than on CPU. Therefore, we recommend that anyone who wants to use metrics with half precision on CPU upgrades to at least PyTorch v1.6, where support for operations such as addition, subtraction, multiplication etc. was added.

  • Some metrics do not work at all in half precision on CPU; this is explicitly stated in their docstrings.

You can always check the precision/dtype of the metric by checking the .dtype property.
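
For example, a quick sketch of checking and changing the precision of a metric (BinaryAccuracy is just an example choice):

import torch
from torchmetrics.classification import BinaryAccuracy

metric = BinaryAccuracy()
print(metric.dtype)           # torch.float32 by default
metric.set_dtype(torch.half)  # cast all metric states to 16-bit precision
print(metric.dtype)           # torch.float16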

Metric Arithmetic

Metrics support most of Python's built-in operators for arithmetic, logic and bitwise operations.

For example, for a metric that should return the sum of two different metrics, implementing a new metric is unnecessary overhead. It can instead be done with:

first_metric = MyFirstMetric()
second_metric = MySecondMetric()

new_metric = first_metric + second_metric

new_metric.update(*args, **kwargs) now calls update of first_metric and second_metric. It forwards all positional arguments but forwards only the keyword arguments that are available in the respective metric's update declaration. Similarly, new_metric.compute() now calls compute of first_metric and second_metric and adds the results up. It is important to note that all implemented operations always return a new metric object. This means that the line first_metric == second_metric will not return a bool indicating if first_metric and second_metric are the same metric, but will return a new metric that checks whether first_metric.compute() == second_metric.compute().
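
As a concrete sketch of the pattern described above (the choice of BinaryAccuracy and BinaryPrecision is purely illustrative), a composed metric is updated and computed like any other metric:

import torch
from torchmetrics.classification import BinaryAccuracy, BinaryPrecision

first_metric = BinaryAccuracy()
second_metric = BinaryPrecision()
new_metric = first_metric + second_metric

preds = torch.tensor([0, 1, 1, 0])
target = torch.tensor([0, 1, 0, 0])

new_metric.update(preds, target)  # updates both underlying metrics
print(new_metric.compute())       # first_metric.compute() + second_metric.compute()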

This pattern is implemented for the following operators (with a being a metric and b being a metric, tensor, integer or float):

  • Addition (a + b)

  • Bitwise AND (a & b)

  • Equality (a == b)

  • Floordivision (a // b)

  • Greater Equal (a >= b)

  • Greater (a > b)

  • Less Equal (a <= b)

  • Less (a < b)

  • Matrix Multiplication (a @ b)

  • Modulo (a % b)

  • Multiplication (a * b)

  • Inequality (a != b)

  • Bitwise OR (a | b)

  • Power (a ** b)

  • Subtraction (a - b)

  • True Division (a / b)

  • Bitwise XOR (a ^ b)

  • Absolute Value (abs(a))

  • Inversion (~a)

  • Negative Value (neg(a))

  • Positive Value (pos(a))

  • Indexing (a[0])

Note

Some of these operations are only fully supported from PyTorch v1.4 onwards; explicitly, we found: add, mul, rmatmul, rsub, rmod.

MetricCollection

In many cases it is beneficial to evaluate the model output by multiple metrics. In this case the MetricCollection class may come in handy. It accepts a sequence of metrics and wraps these into a single callable metric class, with the same interface as any other metric.

Example:

import torch
from torchmetrics import MetricCollection
from torchmetrics.classification import MulticlassAccuracy, MulticlassPrecision, MulticlassRecall

target = torch.tensor([0, 2, 0, 2, 0, 1, 0, 2])
preds = torch.tensor([2, 1, 2, 0, 1, 2, 2, 2])
metric_collection = MetricCollection([
    MulticlassAccuracy(num_classes=3, average="micro"),
    MulticlassPrecision(num_classes=3, average="macro"),
    MulticlassRecall(num_classes=3, average="macro")
])
print(metric_collection(preds, target))
{'MulticlassAccuracy': tensor(0.1250),
 'MulticlassPrecision': tensor(0.0667),
 'MulticlassRecall': tensor(0.1111)}

Similarly it can also reduce the amount of code required to log multiple metrics inside your LightningModule. In most cases we just have to replace self.log with self.log_dict.

from torchmetrics import MetricCollection
from torchmetrics.classification import MulticlassAccuracy, MulticlassPrecision, MulticlassRecall

class MyModule(LightningModule):
    def __init__(self, num_classes: int):
        super().__init__()
        metrics = MetricCollection([
            MulticlassAccuracy(num_classes), MulticlassPrecision(num_classes), MulticlassRecall(num_classes)
        ])
        self.train_metrics = metrics.clone(prefix='train_')
        self.valid_metrics = metrics.clone(prefix='val_')

    def training_step(self, batch, batch_idx):
        x, y = batch
        logits = self(x)
        # ...
        output = self.train_metrics(logits, y)
        # use log_dict instead of log
        # metrics are logged with keys: train_MulticlassAccuracy, train_MulticlassPrecision and train_MulticlassRecall
        self.log_dict(output)

    def validation_step(self, batch, batch_idx):
        x, y = batch
        logits = self(x)
        # ...
        self.valid_metrics.update(logits, y)

    def on_validation_epoch_end(self):
        # use log_dict instead of log
        # metrics are logged with keys: val_MulticlassAccuracy, val_MulticlassPrecision and val_MulticlassRecall
        output = self.valid_metrics.compute()
        self.log_dict(output)
        # remember to reset metrics at the end of the epoch
        self.valid_metrics.reset()

Note

By default, MetricCollection assumes that all the metrics in the collection have the same call signature. If this is not the case, input that should be given to different metrics can be given as keyword arguments to the collection.

An additional advantage of using the MetricCollection object is that it will automatically try to reduce the computations needed by finding groups of metrics that share the same underlying metric state. If such a group of metrics is found, only one of them is actually updated and the updated state is broadcast to the rest of the metrics within the group. In the example above, this leads to a 2x-3x lower computational cost compared to disabling this feature, in the case of the validation metrics where only update is called (this feature does not work in combination with forward). However, this speedup comes with a fixed cost upfront, where the state groups have to be determined after the first update. In case the groups are known beforehand, these can also be set manually to avoid the extra cost of the dynamic search. See the compute_groups argument in the class docs below for more information on this topic.

class torchmetrics.MetricCollection(metrics, *additional_metrics, prefix=None, postfix=None, compute_groups=True)[source]

MetricCollection class can be used to chain metrics that have the same call pattern into one single class.

Parameters:
  • metrics (Union[Metric, Sequence[Metric], Dict[str, Metric]]) –

    One of the following

    • list or tuple (sequence): if metrics are passed in as a list or tuple, will use the metrics class name as key for output dict. Therefore, two metrics of the same class cannot be chained this way.

    • arguments: similar to passing in as a list, metrics passed in as arguments will use their metric class name as key for the output dict.

    • dict: if metrics are passed in as a dict, will use each key in the dict as key for output dict. Use this format if you want to chain together multiple of the same metric with different parameters. Note that the keys in the output dict will be sorted alphabetically.

  • prefix (Optional[str]) – a string to append in front of the keys of the output dict

  • postfix (Optional[str]) – a string to append after the keys of the output dict

  • compute_groups (Union[bool, List[List[str]]]) – By default the MetricCollection will try to reduce the computations needed for the metrics in the collection by checking if they belong to the same compute group. All metrics in a compute group share the same metric state and therefore only differ in their compute step, e.g. accuracy, precision and recall can all be computed from the true positives/negatives and false positives/negatives. By default, this argument is True, which enables this feature. Set this argument to False to disable this behaviour. It can also be set to a list of lists of metrics for setting the compute groups yourself.

Note

The compute groups feature can significantly speed up the calculation of metrics under the right conditions. First, the feature is only available when calling the update method and not when calling the forward method, due to the internal logic of forward preventing this. Secondly, since the compute groups share metric states by reference, calling .items(), .values() etc. on the metric collection will break this reference, and a copy of the states is instead returned in this case (the reference will be re-established on the next call to update).

Note

Metric collections can be nested at initialization (see last example), but the output of the collection will still be a single flattened dictionary combining the prefix and postfix arguments from the nested collection.

Raises:
  • ValueError – If one of the elements of metrics is not an instance of Metric.

  • ValueError – If two elements in metrics have the same name.

  • ValueError – If metrics is not a list, tuple or a dict.

  • ValueError – If metrics is dict and additional_metrics are passed in.

  • ValueError – If prefix is set and it is not a string.

  • ValueError – If postfix is set and it is not a string.

Example::

In the most basic case, the metrics can be passed in as a list or tuple. The keys of the output dict will be the same as the class name of the metric:

>>> from torch import tensor
>>> from pprint import pprint
>>> from torchmetrics import MetricCollection
>>> from torchmetrics.regression import MeanSquaredError
>>> from torchmetrics.classification import MulticlassAccuracy, MulticlassPrecision, MulticlassRecall
>>> target = tensor([0, 2, 0, 2, 0, 1, 0, 2])
>>> preds = tensor([2, 1, 2, 0, 1, 2, 2, 2])
>>> metrics = MetricCollection([MulticlassAccuracy(num_classes=3, average='micro'),
...                             MulticlassPrecision(num_classes=3, average='macro'),
...                             MulticlassRecall(num_classes=3, average='macro')])
>>> metrics(preds, target)  
{'MulticlassAccuracy': tensor(0.1250),
 'MulticlassPrecision': tensor(0.0667),
 'MulticlassRecall': tensor(0.1111)}
Example::

Alternatively, metrics can be passed in as arguments. The keys of the output dict will be the same as the class name of the metric:

>>> metrics = MetricCollection(MulticlassAccuracy(num_classes=3, average='micro'),
...                            MulticlassPrecision(num_classes=3, average='macro'),
...                            MulticlassRecall(num_classes=3, average='macro'))
>>> metrics(preds, target)  
{'MulticlassAccuracy': tensor(0.1250),
 'MulticlassPrecision': tensor(0.0667),
 'MulticlassRecall': tensor(0.1111)}
Example::

If multiple of the same metric class (with different parameters) should be chained together, metrics can be passed in as a dict and the output dict will have the same keys as the input dict:

>>> metrics = MetricCollection({'micro_recall': MulticlassRecall(num_classes=3, average='micro'),
...                             'macro_recall': MulticlassRecall(num_classes=3, average='macro')})
>>> same_metric = metrics.clone()
>>> pprint(metrics(preds, target))
{'macro_recall': tensor(0.1111), 'micro_recall': tensor(0.1250)}
>>> pprint(same_metric(preds, target))
{'macro_recall': tensor(0.1111), 'micro_recall': tensor(0.1250)}
Example::

Metric collections can also be nested up to a single time. The output of the collection will still be a single dict with the prefix and postfix arguments from the nested collection:

>>> metrics = MetricCollection([
...     MetricCollection([
...         MulticlassAccuracy(num_classes=3, average='macro'),
...         MulticlassPrecision(num_classes=3, average='macro')
...     ], postfix='_macro'),
...     MetricCollection([
...         MulticlassAccuracy(num_classes=3, average='micro'),
...         MulticlassPrecision(num_classes=3, average='micro')
...     ], postfix='_micro'),
... ], prefix='valmetrics/')
>>> pprint(metrics(preds, target))  
{'valmetrics/MulticlassAccuracy_macro': tensor(0.1111),
 'valmetrics/MulticlassAccuracy_micro': tensor(0.1250),
 'valmetrics/MulticlassPrecision_macro': tensor(0.0667),
 'valmetrics/MulticlassPrecision_micro': tensor(0.1250)}
Example::

The compute_groups argument allow you to specify which metrics should share metric state. By default, this will automatically be derived but can also be set manually.

>>> metrics = MetricCollection(
...     MulticlassRecall(num_classes=3, average='macro'),
...     MulticlassPrecision(num_classes=3, average='macro'),
...     MeanSquaredError(),
...     compute_groups=[['MulticlassRecall', 'MulticlassPrecision'], ['MeanSquaredError']]
... )
>>> metrics.update(preds, target)
>>> pprint(metrics.compute())
{'MeanSquaredError': tensor(2.3750), 'MulticlassPrecision': tensor(0.0667), 'MulticlassRecall': tensor(0.1111)}
>>> pprint(metrics.compute_groups)
{0: ['MulticlassRecall', 'MulticlassPrecision'], 1: ['MeanSquaredError']}
add_metrics(metrics, *additional_metrics)[source]

Add new metrics to Metric Collection.

Return type:

None

clone(prefix=None, postfix=None)[source]

Make a copy of the metric collection.

Parameters:
  • prefix (Optional[str]) – a string to append in front of the metric keys

  • postfix (Optional[str]) – a string to append after the keys of the output dict.

Return type:

MetricCollection

items(keep_base=False, copy_state=True)[source]

Return an iterable of the ModuleDict key/value pairs.

Parameters:
  • keep_base (bool) – Whether to add prefix/postfix on the collection.

  • copy_state (bool) – If metric states should be copied between metrics in the same compute group or just passed by reference

Return type:

Iterable[Tuple[str, Metric]]

keys(keep_base=False)[source]

Return an iterable of the ModuleDict key.

Parameters:

keep_base (bool) – Whether to add prefix/postfix on the items collection.

Return type:

Iterable[Hashable]

persistent(mode=True)[source]

Change if metric states should be saved to its state_dict after initialization.

Return type:

None

plot(val=None, ax=None, together=False)[source]

Plot a single or multiple values from the metric.

The plot method has two modes of operation. If argument together is set to False (default), the .plot method of each metric will be called individually and the result will be list of figures. If together is set to True, the values of all metrics will instead be plotted in the same figure.

Parameters:
  • val (Union[Dict, Sequence[Dict], None]) – Either a single result from calling metric.forward or metric.compute or a list of these results. If no value is provided, will automatically call metric.compute and plot that result.

  • ax (Union[Axes, Sequence[Axes], None]) – Either a single instance of a matplotlib axis object or a sequence of matplotlib axis objects. If provided, will add the plots to the provided axis objects. If not provided, will create new axes. If argument together is set to True, a single object is expected. If together is set to False, the number of axis objects needs to be the same length as the number of metrics in the collection.

  • together (bool) – If True, will plot all metrics in the same axis. If False, will plot each metric in a separate axis.

Return type:

Sequence[Tuple[Figure, Union[Axes, ndarray]]]

Returns:

Either a single tuple of Figure and Axes objects or a sequence of such tuples, one for each metric in the collection.

Raises:
  • ModuleNotFoundError – If matplotlib is not installed

  • ValueError – If together is not a bool

  • ValueError – If ax is not an instance of a matplotlib axis object or a sequence of matplotlib axis objects

>>> # Example plotting a single value
>>> import torch
>>> from torchmetrics import MetricCollection
>>> from torchmetrics.classification import BinaryAccuracy, BinaryPrecision, BinaryRecall
>>> metrics = MetricCollection([BinaryAccuracy(), BinaryPrecision(), BinaryRecall()])
>>> metrics.update(torch.rand(10), torch.randint(2, (10,)))
>>> fig_ax_ = metrics.plot()
>>> # Example plotting multiple values
>>> import torch
>>> from torchmetrics import MetricCollection
>>> from torchmetrics.classification import BinaryAccuracy, BinaryPrecision, BinaryRecall
>>> metrics = MetricCollection([BinaryAccuracy(), BinaryPrecision(), BinaryRecall()])
>>> values = []
>>> for _ in range(10):
...     values.append(metrics(torch.rand(10), torch.randint(2, (10,))))
>>> fig_, ax_ = metrics.plot(values, together=True)
[Figure: all three metrics plotted together in one axis]
reset()[source]

Call reset for each metric sequentially.

Return type:

None

set_dtype(dst_type)[source]

Transfer all metric state to specific dtype. Special version of standard type method.

Parameters:

dst_type (Union[str, dtype]) – the desired type as torch.dtype or string.

Return type:

MetricCollection

values(copy_state=True)[source]

Return an iterable of the ModuleDict values.

Parameters:

copy_state (bool) – If metric states should be copied between metrics in the same compute group or just passed by reference

Return type:

Iterable[Metric]

property compute_groups: Dict[int, List[str]]

Return a dict with the current compute groups in the collection.

Module vs Functional Metrics

The functional metrics follow the simple paradigm input in, output out. This means they don’t provide any advanced mechanisms for syncing across DDP nodes or aggregation over batches. They simply compute the metric value based on the given inputs.

Also, the integration within other parts of PyTorch Lightning will never be as tight as with the Module-based interface. If you just want to compute the values, the functional metrics are the way to go. However, if you are looking for the best integration and user experience, please consider also using the Module interface.

Metrics and differentiability

Metrics support backpropagation, if all computations involved in the metric calculation are differentiable. All modular metric classes have the property is_differentiable that determines if a metric is differentiable or not.

However, note that the cached state is detached from the computational graph and cannot be back-propagated. Not doing this would mean storing the computational graph for each update call, which can lead to out-of-memory errors. In practice this means that:

MyMetric.is_differentiable  # returns True if metric is differentiable
metric = MyMetric()
val = metric(pred, target)  # this value can be back-propagated
val = metric.compute()  # this value cannot be back-propagated

A functional metric is differentiable if its corresponding modular metric is differentiable.
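
The following sketch illustrates this behaviour with MeanSquaredError (chosen here because it is a differentiable metric); the exact values are irrelevant:

import torch
from torchmetrics.regression import MeanSquaredError

metric = MeanSquaredError()
preds = torch.randn(10, requires_grad=True)
target = torch.randn(10)

val = metric(preds, target)    # value from forward is attached to the graph
val.backward()                 # gradients flow back to preds
print(preds.grad is not None)  # True

cached = metric.compute()      # computed from the detached, cached state
print(cached.requires_grad)    # False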

Metrics and hyperparameter optimization

If you want to directly optimize a metric it needs to support backpropagation (see section above). However, if you are just interested in using a metric for hyperparameter tuning and are not sure if the metric should be maximized or minimized, all modular metric classes have the higher_is_better property that can be used to determine this:

# returns True because accuracy is optimal when it is maximized
torchmetrics.classification.Accuracy.higher_is_better

# returns False because the mean squared error is optimal when it is minimized
torchmetrics.MeanSquaredError.higher_is_better

Advanced metric settings

The following is a list of additional arguments that can be given to any metric class (in the **kwargs argument) that will alter how metric states are stored and synced.

If you are running metrics on GPU and are encountering that you are running out of GPU VRAM then the following argument can help:

  • compute_on_cpu: will automatically move the metric states to cpu after calling update, making sure that GPU memory is not filling up. The consequence will be that the compute method will be called on CPU instead of GPU. Only applies to metric states that are lists.

  • compute_with_cache: This argument indicates if the result after calling the compute method should be cached. By default this is True, meaning that repeated calls to compute (with no change to the metric state in between) do not recompute the metric but just return the cache. Setting it to False means the metric is recomputed every time compute is called, but it can also help free a bit of memory.

If you are running in a distributed environment, TorchMetrics will automatically take care of the distributed synchronization for you. However, the following three keyword arguments can be given to any metric class for further control over the distributed aggregation:

  • sync_on_compute: This argument is a bool that indicates if the metrics should automatically sync between devices whenever the compute method is called. By default this is True, but by setting this to False you can manually control when the synchronization happens.

  • dist_sync_on_step: This argument is a bool that indicates if the metric should synchronize between different devices every time forward is called. Setting this to True is in general not recommended, as synchronization is an expensive operation to do after each batch.

  • process_group: By default we synchronize across the world, i.e. all processes being computed on. You can provide a torch._C._distributed_c10d.ProcessGroup in this argument to specify exactly what devices should be synchronized over.

  • dist_sync_fn: By default we use torch.distributed.all_gather() to perform the synchronization between devices. Provide another callable function for this argument to perform custom distributed synchronization.
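
A brief sketch of how these keyword arguments are passed to a metric constructor (the specific metric and values are only examples; SpearmanCorrCoef is chosen because it has list states):

from torchmetrics.regression import SpearmanCorrCoef

metric = SpearmanCorrCoef(
    compute_on_cpu=True,      # move the (list) states to CPU after each update
    sync_on_compute=False,    # do not sync automatically when compute is called
    dist_sync_on_step=False,  # do not sync on every forward call
)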

Plotting

Note

The visualization/plotting interface of TorchMetrics requires matplotlib to be installed. Install with either pip install matplotlib or pip install 'torchmetrics[visual]'. If the latter option is chosen, the SciencePlots package is also installed and all plots in TorchMetrics will default to using that style.

TorchMetrics comes with built-in support for quick visualization of your metrics: simply use the .plot method that all modular metrics implement. This method provides a consistent interface for basic plotting of all metrics.

metric = AnyMetricYouLike()
for i in range(num_updates):
    metric.update(preds[i], target[i])
fig, ax = metric.plot()

.plot will always return two objects: fig is an instance of Figure, which contains figure-level attributes, and ax is an instance of Axes, which contains all the elements of the plot. These two objects allow you to change attributes of the plot after it is created. For example, if you want to make the font size of the x-axis a bit bigger, give the figure a nice title and finally save it, building on the above example it could be done like this:

ax.tick_params(axis="x", labelsize=20)  # larger x-axis tick labels
fig.suptitle("This is a nice plot")
fig.savefig("my_awesome_plot.png")

If you want to include a Torchmetrics plot in a bigger figure that has subfigures and subaxes, all .plot methods support an optional ax argument where you can pass in the subaxes you want the plot to be inserted into:

import matplotlib.pyplot as plt

# combine plotting of two metrics into one figure
fig, ax = plt.subplots(nrows=1, ncols=2)
metric1 = Metric1()
metric2 = Metric2()
for i in range(num_updates):
    metric1.update(preds[i], target[i])
    metric2.update(preds[i], target[i])
metric1.plot(ax=ax[0])
metric2.plot(ax=ax[1])

Plotting a single step

At the most basic level, the .plot method can be used to plot the value from a single step. This can be done in two ways:

  • .plot is called with no input; internally metric.compute() is called and that value is plotted

  • .plot is called on a single value returned by the metric, for example from metric.forward()

In both cases it will generate a plot like this (Accuracy as an example):

metric = torchmetrics.Accuracy(task="binary")
for _ in range(num_updates):
    metric.update(torch.rand(10,), torch.randint(2, (10,)))
fig, ax = metric.plot()
[Figure: single-step plot of binary accuracy]

A single point plot is not that informative in itself, but if available we will try to include additional information such as the lower and upper bounds the particular metric can take and whether the metric should be minimized or maximized to be optimal. This is true for all metrics that return a scalar tensor. Some metrics return multiple values (such as a tensor with multiple elements or a dict of scalar tensors), and in that case calling .plot will return a figure similar to this:

metric = torchmetrics.Accuracy(task="multiclass", num_classes=3, average=None)
for _ in range(num_updates):
    metric.update(torch.randint(3, (10,)), torch.randint(3, (10,)))
fig, ax = metric.plot()
[Figure: per-class plot of multiclass accuracy]

Here, each element is assumed to be an independent metric and plotted as its own point for comparison. The above is true for all metrics that return a scalar tensor, but if the metric returns a tensor with multiple elements then the .plot method will return a specialized plot for that particular metric. Take for example the ConfusionMatrix metric:

metric = torchmetrics.ConfusionMatrix(task="multiclass", num_classes=3)
for _ in range(num_updates):
    metric.update(torch.randint(3, (10,)), torch.randint(3, (10,)))
fig, ax = metric.plot()
[Figure: confusion matrix plot]

If you prefer to use the functional interface of TorchMetrics, you can also plot the values returned by the functional interface. However, you would still need to initialize the corresponding metric class to get the information about the metric:

plot_class = torchmetrics.Accuracy(task="multiclass", num_classes=3)
value = torchmetrics.functional.accuracy(
    torch.randint(3, (10,)), torch.randint(3, (10,)), task="multiclass", num_classes=3
)
fig, ax = plot_class.plot(value)

Plotting multi steps

In the above examples we have only plotted a single step/single value, but it is also possible to plot multiple steps from the same metric. This is often the case when training a machine learning model, where you are tracking one or multiple metrics that you want to plot as they are changing over time. This can be done by providing a sequence of outputs from any metric, computed using metric.forward or metric.compute. For example, if we want to plot the accuracy of a model over time, we could do it like this:

metric = torchmetrics.Accuracy(task="binary")
values = [ ]
for step in range(num_steps):
    for _ in range(num_updates):
        metric.update(preds(step), target(step))
    values.append(metric.compute())  # save value
    metric.reset()
fig, ax = metric.plot(values)
[Figure: binary accuracy plotted over multiple steps]

Do note that metrics that do not return simple scalar tensors, such as ConfusionMatrix and ROC, which have specialized visualizations, do not support plotting multiple steps out of the box; the user needs to manually plot the values for each step.

Plotting a collection of metrics

MetricCollection also supports the .plot method; by default it works by returning a collection of plots for all its members. Thus, instead of returning a single (fig, ax) pair, calling the .plot method of MetricCollection will return a sequence of such pairs, one for each member in the collection. In the following example we are forming a collection of binary classification metrics and redirecting the output of .plot to different subplots:

collection = torchmetrics.MetricCollection(
    torchmetrics.Accuracy(task="binary"),
    torchmetrics.Recall(task="binary"),
    torchmetrics.Precision(task="binary"),
)
fig, ax = plt.subplots(nrows=1, ncols=3)
values = [ ]
for step in range(num_steps):
    for _ in range(num_updates):
        collection.update(preds(step), target(step))
    values.append(collection.compute())
    collection.reset()
collection.plot(val=values, ax=ax)
[Figure: per-metric plots of the collection over multiple steps, one subplot per metric]

However, the plot method of MetricCollection also supports an additional argument called together that will automatically try to plot all the metrics in the collection together in the same plot (with appropriate labels). This is only possible if all the metrics in the collection return a scalar tensor.

collection = torchmetrics.MetricCollection(
    torchmetrics.Accuracy(task="binary"),
    torchmetrics.Recall(task="binary"),
    torchmetrics.Precision(task="binary"),
)
values = [ ]
fig, ax = plt.subplots(figsize=(6.8, 4.8))
for step in range(num_steps):
    for _ in range(num_updates):
        collection.update(preds(step), target(step))
    values.append(collection.compute())
    collection.reset()
collection.plot(val=values, together=True)
[Figure: all collection metrics plotted together in a single axis]

Advanced example

In the following we are going to show how to use the .plot method to create a more advanced plot. We are going to combine the functionality of several metrics using MetricCollection and plot them together. In addition we are going to rely on MetricTracker to keep track of the metrics over multiple steps.

# Define a collection that is a mix of metrics that return a scalar tensor and metrics that do not
confmat = torchmetrics.ConfusionMatrix(task="binary")
roc = torchmetrics.ROC(task="binary")
collection = torchmetrics.MetricCollection(
    torchmetrics.Accuracy(task="binary"),
    torchmetrics.Recall(task="binary"),
    torchmetrics.Precision(task="binary"),
    confmat,
    roc,
)

# Define a tracker over the collection to easily keep track of the metrics over multiple steps
tracker = torchmetrics.wrappers.MetricTracker(collection)

# Run "training" loop
for step in range(num_steps):
    tracker.increment()
    for _ in range(num_updates):
        tracker.update(preds(step), target(step))

# Extract all metrics from all steps
all_results = tracker.compute_all()

# Construct a single figure with an appropriate layout for all metrics
fig = plt.figure(layout="constrained")
ax1 = plt.subplot(2, 2, 1)
ax2 = plt.subplot(2, 2, 2)
ax3 = plt.subplot(2, 2, (3, 4))

# For ConfusionMatrix and ROC we just plot the last step; note how we call the plot method of those metrics directly
confmat.plot(val=all_results[-1]['BinaryConfusionMatrix'], ax=ax1)
roc.plot(all_results[-1]["BinaryROC"], ax=ax2)

# For the remaining metrics we plot the full history, but we need to extract the scalar values from the results
scalar_results = [
    {k: v for k, v in ar.items() if isinstance(v, torch.Tensor) and v.numel() == 1} for ar in all_results
]
tracker.plot(val=scalar_results, ax=ax3)
[Figure: confusion matrix and ROC for the last step, plus the tracked scalar metrics over all steps]

Implementing a Metric

To implement your own custom metric, subclass the base Metric class and implement the following methods:

  • __init__(): Each state variable should be called using self.add_state(...).

  • update(): Any code needed to update the state given any inputs to the metric.

  • compute(): Computes a final value from the state of the metric.

We provide the remaining interface, such as reset() that will make sure to correctly reset all metric states that have been added using add_state. You should therefore not implement reset() yourself. Additionally, adding metric states with add_state will make sure that states are correctly synchronized in distributed settings (DDP). To see how metric states are synchronized across distributed processes, refer to add_state() docs from the base Metric class.

Example implementation:

import torch
from torch import Tensor

from torchmetrics import Metric

class MyAccuracy(Metric):
    def __init__(self):
        super().__init__()
        self.add_state("correct", default=torch.tensor(0), dist_reduce_fx="sum")
        self.add_state("total", default=torch.tensor(0), dist_reduce_fx="sum")

    def update(self, preds: Tensor, target: Tensor):
        # preds and target are expected to already have the same shape
        assert preds.shape == target.shape

        self.correct += torch.sum(preds == target)
        self.total += target.numel()

    def compute(self):
        return self.correct.float() / self.total

Additionally you may want to set the class properties: is_differentiable, higher_is_better and full_state_update. Note that none of them are strictly required for the metric to work.

from typing import Optional

from torchmetrics import Metric

class MyMetric(Metric):
    # Set to True if the metric is differentiable, else set to False
    is_differentiable: Optional[bool] = None

    # Set to True if the metric reaches its optimal value when maximized.
    # Set to False if it is optimal when minimized.
    higher_is_better: Optional[bool] = True

    # Set to True if the metric during 'update' requires access to the global metric
    # state for its calculations. If not, setting this to False indicates that all
    # batch states are independent and we will optimize the runtime of 'forward'
    full_state_update: bool = True

Finally, from torchmetrics v1.0.0 onwards, we also support plotting of metrics through the .plot method. By default this method will raise NotImplementedError, but it can be implemented by the user to provide a custom plot for the metric. For any metric that returns a simple scalar tensor, or a dict of scalar tensors, the internal ._plot method can be used; it provides the common plotting functionality for most metrics in torchmetrics.

from typing import Optional, Sequence, Union

from torch import Tensor

from torchmetrics import Metric
from torchmetrics.utilities.plot import _AX_TYPE, _PLOT_OUT_TYPE

class MyMetric(Metric):
    ...

    def plot(
        self, val: Optional[Union[Tensor, Sequence[Tensor]]] = None, ax: Optional[_AX_TYPE] = None
    ) -> _PLOT_OUT_TYPE:
        return self._plot(val, ax)

If the metric returns a more complex output, a custom implementation of the plot method is required. For more details on the plotting API, see this page.

Internal implementation details

This section briefly describes how metrics work internally. We encourage looking at the source code for more info. Internally, TorchMetrics wraps the user defined update() and compute() method. We do this to automatically synchronize and reduce metric states across multiple devices. More precisely, calling update() does the following internally:

  1. Clears computed cache.

  2. Calls user-defined update().

Similarly, calling compute() does the following internally:

  1. Syncs metric states between processes.

  2. Reduce gathered metric states.

  3. Calls the user defined compute() method on the gathered metric states.

  4. Cache computed result.

From a user’s standpoint this has one important side-effect: computed results are cached. This means that no matter how many times compute is called one after another, it will continue to return the same result. The cache is first emptied on the next call to update.
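
A short sketch of this caching behaviour (BinaryAccuracy is only an example):

import torch
from torchmetrics.classification import BinaryAccuracy

metric = BinaryAccuracy()
metric.update(torch.tensor([1, 0, 1]), torch.tensor([1, 1, 1]))
first = metric.compute()   # computed and cached
second = metric.compute()  # returned from the cache, no recomputation
assert torch.equal(first, second)

metric.update(torch.tensor([1, 1, 1]), torch.tensor([1, 1, 1]))  # empties the cache
third = metric.compute()   # recomputed from the updated state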

forward serves the dual purpose of both returning the metric on the current data and updating the internal metric state for accumulating over multiple batches. The forward() method achieves this by combining calls to update, compute and reset. Depending on the class property full_state_update, forward can behave in two ways:

  1. If full_state_update is True, it indicates that the metric during update requires access to the full metric state, and we therefore need to do two calls to update to ensure that the metric is calculated correctly

    1. Calls update() to update the global metric state (for accumulation over multiple batches)

    2. Caches the global state.

    3. Calls reset() to clear global metric state.

    4. Calls update() to update local metric state.

    5. Calls compute() to calculate metric for current batch.

    6. Restores the global state.

  2. If full_state_update is False (default), the metric state of one batch is completely independent of the state of other batches, which means that we only need to call update once.

    1. Caches the global state.

    2. Calls reset() to reset the metric to its default state.

    3. Calls update() to update the state with the local batch statistics.

    4. Calls compute() to calculate the metric for the current batch.

    5. Reduces the global state and the batch state into a single state that becomes the new global state.

If implementing your own metric, we recommend trying out the metric with full_state_update class property set to both True and False. If the results are equal, then setting it to False will usually give the best performance.


class torchmetrics.Metric(**kwargs)[source]

Base class for all metrics present in the Metrics API.

This class is inherited by all metrics and implements the following functionality:

  1. Handles the transfer of metric states to the correct device

  2. Handles the synchronization of metric states across processes

The three core methods of the base class are

  • add_state()

  • forward()

  • reset()

which should almost never be overwritten by child classes. Instead, the following methods should be overwritten:

  • update()

  • compute()

Parameters:

kwargs (Any) –

additional keyword arguments, see Advanced metric settings for more info.

  • compute_on_cpu: If metric state should be stored on CPU during computations. Only works for list states.

  • dist_sync_on_step: If metric state should synchronize on forward(). Default is False

  • process_group: The process group on which the synchronization is called. Default is the world.

  • dist_sync_fn: Function that performs the allgather operation on the metric state. Default is a custom implementation that calls torch.distributed.all_gather internally.

  • distributed_available_fn: Function that checks if the distributed backend is available. Defaults to a check of torch.distributed.is_available() and torch.distributed.is_initialized().

  • sync_on_compute: If metric state should synchronize when compute is called. Default is True

  • compute_with_cache: If results from compute should be cached. Default is True

add_state(name, default, dist_reduce_fx=None, persistent=False)[source]

Add metric state variable. Only used by subclasses.

Metric state variables are either a Tensor or an empty list, which can be appended to by the metric. Each state variable must have a unique name associated with it. State variables are accessible as attributes of the metric, i.e., if name is "my_state" then its value can be accessed from an instance metric as metric.my_state. Metric states behave like buffers and parameters of Module as they are also updated when .to() is called. Unlike parameters and buffers, metric states are not by default saved in the module's state_dict.

Parameters:
  • name (str) – The name of the state variable. The variable will then be accessible at self.name.

  • default (Union[list, Tensor]) – Default value of the state; can either be a Tensor or an empty list. The state will be reset to this value when self.reset() is called.

  • dist_reduce_fx (Optional) – Function to reduce state across multiple processes in distributed mode. If value is "sum", "mean", "cat", "min" or "max" we will use torch.sum, torch.mean, torch.cat, torch.min and torch.max respectively, each with argument dim=0. Note that the "cat" reduction only makes sense if the state is a list, and not a tensor. The user can also pass a custom function in this parameter.

  • persistent (Optional) – whether the state will be saved as part of the modules state_dict. Default is False.

Return type:

None

Note

Setting dist_reduce_fx to None will return the metric state synchronized across different processes. However, there won’t be any reduction function applied to the synchronized metric state.

The metric states would be synced as follows

  • If the metric state is a Tensor, the synced value will be a Tensor stacked across the process dimension. The original Tensor metric state retains its dimensions, so the synchronized output will be of shape (num_process, ...).

  • If the metric state is a list, the synced value will be a list containing the combined elements from all processes.

Note

When passing a custom function to dist_reduce_fx, expect the synchronized metric state to follow the format discussed in the above note.

Raises:
  • ValueError – If default is not a tensor or an empty list.

  • ValueError – If dist_reduce_fx is not callable or one of "mean", "sum", "cat", "min", "max" or None.

clone()[source]

Make a copy of the metric.

Return type:

Metric

abstract compute()[source]

Override this method to compute the final metric value.

This method will automatically synchronize state variables when running in distributed backend.

Return type:

Any

double()[source]

Override default and prevent dtype casting.

Please use Metric.set_dtype() instead.

Return type:

Metric

float()[source]

Override default and prevent dtype casting.

Please use Metric.set_dtype() instead.

Return type:

Metric

forward(*args, **kwargs)[source]

Aggregate and evaluate batch input directly.

Serves the dual purpose of both computing the metric on the current batch of inputs and adding the batch statistics to the overall accumulating metric state. The input arguments are the exact same as those of the corresponding update method. The returned output is the exact same as the output of compute.

Parameters:
  • args (Any) – Any arguments as required by the metric update method.

  • kwargs (Any) – Any keyword arguments as required by the metric update method.

Return type:

Any

Returns:

The output of the compute method evaluated on the current batch.

Raises:

TorchMetricsUserError – If the metric is already synced and forward is called again.

half()[source]

Override default and prevent dtype casting.

Please use Metric.set_dtype() instead.

Return type:

Metric

persistent(mode=False)[source]

Change post-init if metric states should be saved to its state_dict.

Return type:

None

plot(*_, **__)[source]

Override this method to plot the metric value.

Return type:

Any

reset()[source]

Reset metric state variables to their default value.

Return type:

None

set_dtype(dst_type)[source]

Transfer all metric state to specific dtype. Special version of standard type method.

Parameters:

dst_type (Union[str, dtype]) – the desired type as string or dtype object

Return type:

Metric

state_dict(destination=None, prefix='', keep_vars=False)[source]

Get the current state of the metric as a dictionary.

Parameters:
  • destination (Optional[Dict[str, Any]]) – Optional dictionary, that if provided, the state of module will be updated into the dict and the same object is returned. Otherwise, an OrderedDict will be created and returned.

  • prefix (str) – optional string, a prefix added to parameter and buffer names to compose the keys in state_dict.

  • keep_vars (bool) – by default the Tensors returned in the state dict are detached from autograd. If set to True, detaching will not be performed.

Return type:

Dict[str, Any]

sync(dist_sync_fn=None, process_group=None, should_sync=True, distributed_available=None)[source]

Sync function for manually controlling when metrics states should be synced across processes.

Parameters:
  • dist_sync_fn (Optional[Callable]) – Function to be used to perform states synchronization

  • process_group (Optional[Any]) – Specify the process group on which synchronization is called. default: None (which selects the entire world)

  • should_sync (bool) – Whether to apply state synchronization. This will have an impact only when running in a distributed setting.

  • distributed_available (Optional[Callable]) – Function to determine if we are running inside a distributed setting

Raises:

TorchMetricsUserError – If the metric is already synced and sync is called again.

Return type:

None

sync_context(dist_sync_fn=None, process_group=None, should_sync=True, should_unsync=True, distributed_available=None)[source]

Context manager to synchronize states.

This context manager is used in a distributed setting and makes sure that the local cache states are restored after yielding the synchronized state.

Parameters:
  • dist_sync_fn (Optional[Callable]) – Function to be used to perform states synchronization

  • process_group (Optional[Any]) – Specify the process group on which synchronization is called. default: None (which selects the entire world)

  • should_sync (bool) – Whether to apply state synchronization. This only has an impact when running in a distributed setting.

  • should_unsync (bool) – Whether to restore the cache state so that the metrics can continue to be accumulated.

  • distributed_available (Optional[Callable]) – Function to determine if we are running inside a distributed setting

Return type:

Generator

type(dst_type)[source]

Override default and prevent dtype casting.

Please use Metric.set_dtype() instead.

Return type:

Metric

unsync(should_unsync=True)[source]

Unsync function for manually controlling when metric states should be reverted back to their local states.

Parameters:

should_unsync (bool) – Whether to perform unsync

Return type:

None
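The sketch below (not from the original documentation) shows the intended call pattern for sync, sync_context and unsync: inside the context the metric states are replaced by the globally synchronized ones, and the local states are restored afterwards so that accumulation can continue. Outside a distributed run the context is effectively a no-op.

import torch
from torchmetrics.classification import Accuracy

metric = Accuracy(task="binary")

# each process accumulates its own local batches
metric.update(torch.tensor([1, 0, 1]), torch.tensor([1, 1, 1]))

# temporarily swap in the synced (reduced) states to compute a global value
with metric.sync_context():
    global_acc = metric.compute()

# local states are restored here, so accumulation can continue as before
metric.update(torch.tensor([0, 1]), torch.tensor([0, 1]))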

abstract update(*_, **__)[source]

Override this method to update the state variables of your metric class.

Return type:

None

property device: device

Return the device of the metric.

property metric_state: Dict[str, Union[List[Tensor], Tensor]]

Get the current state of the metric.

property update_called: bool

Returns True if update or forward has been called since initialization or the last reset.

property update_count: int

Get the number of times update and/or forward has been called since initialization or last reset.
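A small illustration of these properties (not part of the original docs; the state names shown belong to the binary Accuracy metric and differ per metric):

>>> import torch
>>> from torchmetrics.classification import Accuracy
>>> metric = Accuracy(task="binary")
>>> metric.update(torch.tensor([1, 0, 1]), torch.tensor([1, 1, 1]))
>>> metric.update_called, metric.update_count
(True, 1)
>>> sorted(metric.metric_state)  # names of the internal metric states
['fn', 'fp', 'tn', 'tp']
>>> metric.device
device(type='cpu')
>>> metric.reset()
>>> metric.update_called, metric.update_count
(False, 0)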

Contributing your metric to TorchMetrics

Want to contribute a metric you have implemented? Great, we are always open to adding more metrics to torchmetrics as long as they serve a general purpose. However, to keep all our metrics consistent, we request that the implementation and tests are formatted in the following way:

  1. Start by reading our contribution guidelines.

  2. First implement the functional backend. This takes care of all the logic that goes into the metric. The code should be put into a single file placed under torchmetrics/functional/"domain"/"new_metric".py, where domain is the type of metric (classification, regression, nlp, etc.) and new_metric is the name of the metric. In this file, there should be the following three functions (a rough skeleton is sketched after the note below):

  1. _new_metric_update(...): everything that has to do with type/shape checking and all logic required before distributed syncing needs to go here.

  2. _new_metric_compute(...): all remaining logic.

  3. new_metric(...): essentially wraps the _update and _compute private functions into one public function that makes up the functional interface for the metric.

Note

The functional accuracy metric is a great example of this division of logic.
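For illustration, a rough skeleton of such a file could look as follows (all names and the toy statistics are placeholders; this is a sketch, not the actual accuracy implementation):

# torchmetrics/functional/"domain"/"new_metric".py -- illustrative skeleton only
from typing import Tuple

import torch
from torch import Tensor


def _new_metric_update(preds: Tensor, target: Tensor) -> Tuple[Tensor, Tensor]:
    """Type/shape checking and all logic required before distributed syncing."""
    if preds.shape != target.shape:
        raise ValueError("preds and target must have the same shape")
    numerator = torch.sum(preds == target)
    denominator = torch.tensor(target.numel())
    return numerator, denominator


def _new_metric_compute(numerator: Tensor, denominator: Tensor) -> Tensor:
    """All remaining logic, applied to the (possibly synced) statistics."""
    return numerator.float() / denominator


def new_metric(preds: Tensor, target: Tensor) -> Tensor:
    """Public functional interface wrapping the two private helpers."""
    numerator, denominator = _new_metric_update(preds, target)
    return _new_metric_compute(numerator, denominator)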

  3. In a corresponding file placed in torchmetrics/"domain"/"new_metric".py, create the module interface (a sketch follows the note below):

  1. Create a new module metric by subclassing torchmetrics.Metric.

  2. In the __init__ of the module, call self.add_state for as many metric states as are needed for the metric to properly accumulate metric statistics.

  3. The module interface should essentially call the private _new_metric_update(...) in its update method and similarly the _new_metric_compute(...) function in its compute. No logic should really be implemented in the module interface; we do this to avoid having duplicate code to maintain.

Note

The module Accuracy metric that corresponds to the above functional example showcases these steps.
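Continuing the illustrative skeleton above, the corresponding module interface could look roughly like this (a sketch only; the statistics, state names and inline update logic are placeholders, and a real implementation would delegate to the private functional helpers):

# torchmetrics/"domain"/"new_metric".py -- illustrative skeleton only
import torch
from torch import Tensor

from torchmetrics import Metric


class NewMetric(Metric):
    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        # one add_state call per statistic that has to be accumulated and synced
        self.add_state("numerator", default=torch.tensor(0), dist_reduce_fx="sum")
        self.add_state("denominator", default=torch.tensor(0), dist_reduce_fx="sum")

    def update(self, preds: Tensor, target: Tensor) -> None:
        # a real implementation would call _new_metric_update(preds, target) here
        self.numerator += torch.sum(preds == target)
        self.denominator += target.numel()

    def compute(self) -> Tensor:
        # a real implementation would call _new_metric_compute(...) here
        return self.numerator.float() / self.denominator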

  4. Remember to add bindings to the different relevant __init__ files.

  5. Testing is key to keeping torchmetrics trustworthy. This is why we have a very rigid testing protocol. This means that, in most cases, we require the metric to be tested against some other common framework (sklearn, scipy, etc.).

  1. Create a testing file in unittests/"domain"/test_"new_metric".py. Only one file is needed as it is intended to test both the functional and module interface.

  2. In that file, start by defining a number of test inputs that your metric should be evaluated on.

  3. Create a test class class NewMetric(MetricTester) that inherits from tests.helpers.testers.MetricTester. This test class should essentially implement the test_"new_metric"_class and test_"new_metric"_fn methods, which test the module interface and the functional interface respectively (a sketch is shown after the note below).

  4. The test class should be parameterized (using @pytest.mark.parametrize) by the different test inputs defined initially. Additionally, the test_"new_metric"_class method should also be parameterized with a ddp parameter so that it gets tested in a distributed setting. If your metric has additional parameters, then make sure to also parameterize these so that different combinations of inputs and parameters get tested.

  5. (optional) If your metric raises any exception, please add tests that showcase this.

Note

The test file for accuracy metric shows how to implement such tests.
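A heavily simplified sketch of such a test file is shown below. It reuses the existing BinaryAccuracy metric as a stand-in for a new metric; the MetricTester helpers run_class_metric_test and run_functional_metric_test come from the torchmetrics test suite, but their exact import path and argument names may differ between versions:

# unittests/"domain"/test_"new_metric".py -- illustrative sketch only
import pytest
import torch
from sklearn.metrics import accuracy_score

from torchmetrics.classification import BinaryAccuracy
from torchmetrics.functional.classification import binary_accuracy
from tests.helpers.testers import MetricTester  # import path may differ between versions

# test inputs the metric should be evaluated on: (num_batches, batch_size)
_preds = torch.randint(2, (4, 32))
_target = torch.randint(2, (4, 32))


def _reference_metric(preds, target):
    # reference implementation from another framework (here: sklearn)
    return accuracy_score(target.numpy().ravel(), preds.numpy().ravel())


@pytest.mark.parametrize("preds, target", [(_preds, _target)])
class TestNewMetric(MetricTester):
    @pytest.mark.parametrize("ddp", [False, True])
    def test_new_metric_class(self, preds, target, ddp):
        # tests the module interface, optionally in a distributed (ddp) setting
        self.run_class_metric_test(ddp, preds, target, BinaryAccuracy, _reference_metric)

    def test_new_metric_fn(self, preds, target):
        # tests the functional interface
        self.run_functional_metric_test(preds, target, binary_accuracy, _reference_metric)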

If you can only figure out part of the steps, do not be afraid to send a PR. We would much rather receive working metrics that are not formatted exactly like our codebase than not receive them at all. Formatting can always be applied later. We will gladly guide you and/or help implement the remaining parts :]

TorchMetrics in PyTorch Lightning

TorchMetrics was originally created as part of PyTorch Lightning, a powerful deep learning research framework designed for scaling models without boilerplate.

Note

TorchMetrics always offers compatibility with the last 2 major PyTorch Lightning versions, but we recommend always keeping both frameworks up to date for the best experience.

While TorchMetrics was built to be used with native PyTorch, using TorchMetrics with Lightning offers additional benefits:

  • Modular metrics are automatically placed on the correct device when properly defined inside a LightningModule. This means that your data will always be placed on the same device as your metrics. No need to call .to(device) anymore!

  • Native support for logging metrics in Lightning using self.log inside your LightningModule.

  • The .reset() method of the metric will automatically be called at the end of an epoch.

The example below shows how to use a metric in your LightningModule:

class MyModel(LightningModule):

    def __init__(self, num_classes):
        ...
        self.accuracy = torchmetrics.classification.Accuracy(task="multiclass", num_classes=num_classes)

    def training_step(self, batch, batch_idx):
        x, y = batch
        preds = self(x)
        ...
        # log step metric
        self.accuracy(preds, y)
        self.log('train_acc_step', self.accuracy)
        ...

    def on_train_epoch_end(self):
        # log epoch metric
        self.log('train_acc_epoch', self.accuracy)

Metric logging in Lightning happens through the self.log or self.log_dict method. Both methods only support the logging of scalar tensors. While the vast majority of metrics in torchmetrics return a scalar tensor, some metrics such as ConfusionMatrix, ROC, MeanAveragePrecision and ROUGEScore return outputs that are non-scalar tensors (often dicts or lists of tensors) and should therefore be dealt with separately. For information about the return type and shape, please look at the documentation of the compute method for each metric you want to log.
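For example, a non-scalar metric such as ConfusionMatrix can be computed manually at the end of the epoch and only scalar quantities derived from it logged (a hypothetical sketch, not from the official docs):

class MyModel(LightningModule):

    def __init__(self, num_classes):
        ...
        self.confmat = torchmetrics.classification.ConfusionMatrix(task="multiclass", num_classes=num_classes)

    def validation_step(self, batch, batch_idx):
        x, y = batch
        preds = self(x)
        ...
        self.confmat.update(preds, y)

    def on_validation_epoch_end(self):
        confmat = self.confmat.compute()  # (num_classes, num_classes) tensor; cannot be passed to self.log
        per_class_recall = confmat.diag() / confmat.sum(dim=1)
        self.log_dict({f"val_recall_class_{i}": v for i, v in enumerate(per_class_recall)})
        self.confmat.reset()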

Logging TorchMetrics

Logging metrics can be done in two ways: either by logging the metric object directly or by logging the computed metric values. When Metric objects, which return a scalar tensor, are logged directly in Lightning using the LightningModule self.log method, Lightning will log the metric based on the on_step and on_epoch flags present in self.log(...). If on_epoch is True, the logger automatically logs the end-of-epoch metric value by calling .compute().

Note

sync_dist, sync_dist_group and reduce_fx flags from self.log(...) don’t affect the metric logging in any manner. The metric class contains its own distributed synchronization logic.

This, however, is only true for metrics that inherit from the base class Metric; the functional metric API provides no built-in support for distributed synchronization or reduction.

class MyModule(LightningModule):

    def __init__(self, num_classes):
        ...
        self.train_acc = torchmetrics.classification.Accuracy(task="multiclass", num_classes=num_classes)
        self.valid_acc = torchmetrics.classification.Accuracy(task="multiclass", num_classes=num_classes)

    def training_step(self, batch, batch_idx):
        x, y = batch
        preds = self(x)
        ...
        self.train_acc(preds, y)
        self.log('train_acc', self.train_acc, on_step=True, on_epoch=False)

    def validation_step(self, batch, batch_idx):
        x, y = batch
        logits = self(x)
        ...
        self.valid_acc(logits, y)
        self.log('valid_acc', self.valid_acc, on_step=True, on_epoch=True)

As an alternative to logging the metric object and letting Lightning take care of when to reset the metric etc., you can also manually log the output of the metrics.

class MyModule(LightningModule):

    def __init__(self, num_classes):
        ...
        self.train_acc = torchmetrics.classification.Accuracy(task="multiclass", num_classes=num_classes)
        self.valid_acc = torchmetrics.classification.Accuracy(task="multiclass", num_classes=num_classes)

    def training_step(self, batch, batch_idx):
        x, y = batch
        preds = self(x)
        ...
        batch_value = self.train_acc(preds, y)
        self.log('train_acc_step', batch_value)

    def on_train_epoch_end(self):
        self.train_acc.reset()

    def validation_step(self, batch, batch_idx):
        x, y = batch
        logits = self(x)
        ...
        self.valid_acc.update(logits, y)

    def on_validation_epoch_end(self):
        self.log('valid_acc_epoch', self.valid_acc.compute())
        self.valid_acc.reset()

Note that logging metrics this way requires you to manually reset the metrics at the end of the epoch. In general, we recommend logging the metric object to make sure that metrics are correctly computed and reset. Additionally, we highly recommend that the two ways of logging are not mixed, as doing so can lead to wrong results.

Note

When using any Modular metric, calling self.metric(...) or self.metric.forward(...) serves the dual purpose of calling self.metric.update() on its input and simultaneously returning the metric value over the provided input. So if you are logging a metric only on epoch-level (as in the example above), it is recommended to call self.metric.update() directly to avoid the extra computation.

class MyModule(LightningModule):

    def __init__(self, num_classes):
        ...
        self.valid_acc = torchmetrics.classification.Accuracy(task="multiclass", num_classes=num_classes)

    def validation_step(self, batch, batch_idx):
        x, y = batch
        logits = self(x)
        ...
        self.valid_acc.update(logits, y)
        self.log('valid_acc', self.valid_acc, on_step=True, on_epoch=True)

Common Pitfalls

The following contains a list of pitfalls to be aware of:

  • Modular metrics contain internal states that should belong to only one DataLoader. In case you are using multiple DataLoaders, it is recommended to initialize a separate modular metric instance for each DataLoader and use them separately. The same holds for using separate metrics for training, validation and testing.

class MyModule(LightningModule):

    def __init__(self, num_classes):
        ...
        self.val_acc = nn.ModuleList(
            [torchmetrics.classification.Accuracy(task="multiclass", num_classes=num_classes) for _ in range(2)]
        )

    def val_dataloader(self):
        return [DataLoader(...), DataLoader(...)]

    def validation_step(self, batch, batch_idx, dataloader_idx):
        x, y = batch
        preds = self(x)
        ...
        self.val_acc[dataloader_idx](preds, y)
        self.log('val_acc', self.val_acc[dataloader_idx])
  • Mixing the two logging methods by calling self.log("val", self.metric) in the {training}/{val}/{test}_step method and then calling self.log("val", self.metric.compute()) in the corresponding {training}/{val}/{test}_epoch_end method. Because the object is logged in the first case, Lightning will reset the metric before calling the second line, leading to errors or nonsensical results.

  • Calling self.log("val", self.metric(preds, target)) with the intention of logging the metric object. Because self.metric(preds, target) corresponds to calling the forward method, this will return a tensor and not the metric object. Such logging will be wrong in this case. Instead, it is important to separate the call into two separate lines:

def training_step(self, batch, batch_idx):
    x, y = batch
    preds = self(x)
    ...
    # log step metric
    self.accuracy(preds, y)  # compute metrics
    self.log('train_acc_step', self.accuracy)  # log metric object
  • Using the MetricTracker wrapper with Lightning is a special case, because the wrapper itself is not a metric, i.e. it does not inherit from the base Metric class but instead from ModuleList. Thus, to log the output of this metric one needs to manually log the returned values (not the object) using self.log, and for epoch-level logging this should be done in the appropriate on_***_epoch_end method, as sketched below.
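A hypothetical sketch of this pattern, assuming a MetricTracker from torchmetrics.wrappers wrapped around an accuracy metric (the methods increment, compute and best_metric come from the wrapper's documented API, but double-check them against your installed version):

class MyModule(LightningModule):

    def __init__(self, num_classes):
        ...
        self.val_tracker = torchmetrics.wrappers.MetricTracker(
            torchmetrics.classification.Accuracy(task="multiclass", num_classes=num_classes)
        )

    def on_validation_epoch_start(self):
        self.val_tracker.increment()  # start tracking a new epoch

    def validation_step(self, batch, batch_idx):
        x, y = batch
        preds = self(x)
        ...
        self.val_tracker.update(preds, y)

    def on_validation_epoch_end(self):
        # log the returned values, not the tracker object itself
        self.log("val_acc", self.val_tracker.compute())
        best_acc, _ = self.val_tracker.best_metric(return_step=True)
        self.log("val_acc_best", best_acc)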

Concatenation

Module Interface

class torchmetrics.aggregation.CatMetric(nan_strategy='warn', **kwargs)[source]

Concatenate a stream of values.

As input to forward and update the metric accepts the following input

  • value (float or Tensor): a single float or a tensor of float values with arbitrary shape (...,).

As output of forward and compute the metric returns the following output

  • agg (Tensor): float tensor with the concatenated values of all inputs received

Parameters:
  • nan_strategy (Union[str, float]) – options: 'error': if any nan values are encountered, a RuntimeError is raised; 'warn': if any nan values are encountered, a warning is given and computation continues; 'ignore': all nan values are silently removed; a float: if a float is provided, any nan values are imputed with this value

  • kwargs (Any) – Additional keyword arguments, see Advanced metric settings for more info.

Raises:

ValueError – If nan_strategy is not one of error, warn, ignore or a float

Example

>>> from torch import tensor
>>> from torchmetrics.aggregation import CatMetric
>>> metric = CatMetric()
>>> metric.update(1)
>>> metric.update(tensor([2, 3]))
>>> metric.compute()
tensor([1., 2., 3.])

Maximum

Module Interface

class torchmetrics.aggregation.MaxMetric(nan_strategy='warn', **kwargs)[source]

Aggregate a stream of values into their maximum value.

As input to forward and update the metric accepts the following input

  • value (float or Tensor): a single float or a tensor of float values with arbitrary shape (...,).

As output of forward and compute the metric returns the following output

  • agg (Tensor): scalar float tensor with aggregated maximum value over all inputs received

Parameters:
  • nan_strategy (Union[str, float]) – options: 'error': if any nan values are encountered, a RuntimeError is raised; 'warn': if any nan values are encountered, a warning is given and computation continues; 'ignore': all nan values are silently removed; a float: if a float is provided, any nan values are imputed with this value

  • kwargs (Any) – Additional keyword arguments, see Advanced metric settings for more info.

Raises:

ValueError – If nan_strategy is not one of error, warn, ignore or a float

Example

>>> from torch import tensor
>>> from torchmetrics.aggregation import MaxMetric
>>> metric = MaxMetric()
>>> metric.update(1)
>>> metric.update(tensor([2, 3]))
>>> metric.compute()
tensor(3.)
plot(val=None, ax=None)[source]

Plot a single or multiple values from the metric.

Parameters:
  • val (Union[Tensor, Sequence[Tensor], None]) – Either a single result from calling metric.forward or metric.compute or a list of these results. If no value is provided, will automatically call metric.compute and plot that result.

  • ax (Optional[Axes]) – A matplotlib axis object. If provided, the plot will be added to that axis

Return type:

Tuple[Figure, Union[Axes, ndarray]]

Returns:

Figure and Axes object

Raises:

ModuleNotFoundError – If matplotlib is not installed

>>> # Example plotting a single value
>>> from torchmetrics.aggregation import MaxMetric
>>> metric = MaxMetric()
>>> metric.update([1, 2, 3])
>>> fig_, ax_ = metric.plot()
>>> # Example plotting multiple values
>>> from torchmetrics.aggregation import MaxMetric
>>> metric = MaxMetric()
>>> values = [ ]
>>> for i in range(10):
...     values.append(metric(i))
>>> fig_, ax_ = metric.plot(values)

Mean

Module Interface

class torchmetrics.aggregation.MeanMetric(nan_strategy='warn', **kwargs)[source]

Aggregate a stream of values into their mean value.

As input to forward and update the metric accepts the following input

  • value (float or Tensor): a single float or a tensor of float values with arbitrary shape (...,).

  • weight (float or Tensor): a single float or a tensor of float values with arbitrary shape (...,). Needs to be broadcastable with the shape of the value tensor.

As output of forward and compute the metric returns the following output

  • agg (Tensor): scalar float tensor with aggregated (weighted) mean over all inputs received

Parameters:

nan_strategy (Union[str, float]) –

options:
  • 'error': if any nan values are encountered, a RuntimeError is raised

  • 'warn': if any nan values are encountered, a warning is given and computation continues

  • 'ignore': all nan values are silently removed

  • a float: if a float is provided, any nan values are imputed with this value

kwargs: Additional keyword arguments, see Advanced metric settings for more info.

Raises:

ValueError – If nan_strategy is not one of error, warn, ignore or a float

Example

>>> from torch import tensor
>>> from torchmetrics.aggregation import MeanMetric
>>> metric = MeanMetric()
>>> metric.update(1)
>>> metric.update(tensor([2, 3]))
>>> metric.compute()
tensor(2.)
plot(val=None, ax=None)[source]

Plot a single or multiple values from the metric.

Parameters:
  • val (Union[Tensor, Sequence[Tensor], None]) – Either a single result from calling metric.forward or metric.compute or a list of these results. If no value is provided, will automatically call metric.compute and plot that result.

  • ax (Optional[Axes]) – A matplotlib axis object. If provided, the plot will be added to that axis

Return type:

Tuple[Figure, Union[Axes, ndarray]]

Returns:

Figure and Axes object

Raises:

ModuleNotFoundError – If matplotlib is not installed

>>> # Example plotting a single value
>>> from torchmetrics.aggregation import MeanMetric
>>> metric = MeanMetric()
>>> metric.update([1, 2, 3])
>>> fig_, ax_ = metric.plot()
>>> # Example plotting multiple values
>>> from torchmetrics.aggregation import MeanMetric
>>> metric = MeanMetric()
>>> values = [ ]
>>> for i in range(10):
...     values.append(metric([i, i+1]))
>>> fig_, ax_ = metric.plot(values)

Minimum

Module Interface

class torchmetrics.aggregation.MinMetric(nan_strategy='warn', **kwargs)[source]

Aggregate a stream of values into their minimum value.

As input to forward and update the metric accepts the following input

  • value (float or Tensor): a single float or a tensor of float values with arbitrary shape (...,).

As output of forward and compute the metric returns the following output

  • agg (Tensor): scalar float tensor with aggregated minimum value over all inputs received

Parameters:
  • nan_strategy (Union[str, float]) – options: 'error': if any nan values are encountered, a RuntimeError is raised; 'warn': if any nan values are encountered, a warning is given and computation continues; 'ignore': all nan values are silently removed; a float: if a float is provided, any nan values are imputed with this value

  • kwargs (Any) – Additional keyword arguments, see Advanced metric settings for more info.

Raises:

ValueError – If nan_strategy is not one of error, warn, ignore or a float

Example

>>> from torch import tensor
>>> from torchmetrics.aggregation import MinMetric
>>> metric = MinMetric()
>>> metric.update(1)
>>> metric.update(tensor([2, 3]))
>>> metric.compute()
tensor(1.)
plot(val=None, ax=None)[source]

Plot a single or multiple values from the metric.

Parameters:
  • val (Union[Tensor, Sequence[Tensor], None]) – Either a single result from calling metric.forward or metric.compute or a list of these results. If no value is provided, will automatically call metric.compute and plot that result.

  • ax (Optional[Axes]) – A matplotlib axis object. If provided, the plot will be added to that axis

Return type:

Tuple[Figure, Union[Axes, ndarray]]

Returns:

Figure and Axes object

Raises:

ModuleNotFoundError – If matplotlib is not installed

>>> # Example plotting a single value
>>> from torchmetrics.aggregation import MinMetric
>>> metric = MinMetric()
>>> metric.update([1, 2, 3])
>>> fig_, ax_ = metric.plot()
>>> # Example plotting multiple values
>>> from torchmetrics.aggregation import MinMetric
>>> metric = MinMetric()
>>> values = [ ]
>>> for i in range(10):
...     values.append(metric(i))
>>> fig_, ax_ = metric.plot(values)

Running Mean

Module Interface

class torchmetrics.aggregation.RunningMean(window=5, nan_strategy='warn', **kwargs)[source]

Aggregate a stream of values into their mean over a running window.

Using this metric compared to MeanMetric allows for calculating metrics over a running window of values, instead of the whole history of values. This is beneficial when you want to get a better estimate of the metric during training and don’t want to wait for the whole training to finish to get epoch level estimates.

As input to forward and update the metric accepts the following input

  • value (float or Tensor): a single float or a tensor of float values with arbitrary shape (...,).

As output of forward and compute the metric returns the following output

  • agg (Tensor): scalar float tensor with the aggregated mean over the last window of inputs received

Parameters:
  • window (int) – The size of the running window.

  • nan_strategy (Union[str, float]) – options: 'error': if any nan values are encountered, a RuntimeError is raised; 'warn': if any nan values are encountered, a warning is given and computation continues; 'ignore': all nan values are silently removed; a float: if a float is provided, any nan values are imputed with this value

  • kwargs (Any) – Additional keyword arguments, see Advanced metric settings for more info.

Raises:

ValueError – If nan_strategy is not one of error, warn, ignore or a float

Example

>>> from torch import tensor
>>> from torchmetrics.aggregation import RunningMean
>>> metric = RunningMean(window=3)
>>> for i in range(6):
...     current_val = metric(tensor([i]))
...     running_val = metric.compute()
...     total_val = tensor(sum(list(range(i+1)))) / (i+1)  # total mean over all samples
...     print(f"{current_val=}, {running_val=}, {total_val=}")
current_val=tensor(0.), running_val=tensor(0.), total_val=tensor(0.)
current_val=tensor(1.), running_val=tensor(0.5000), total_val=tensor(0.5000)
current_val=tensor(2.), running_val=tensor(1.), total_val=tensor(1.)
current_val=tensor(3.), running_val=tensor(2.), total_val=tensor(1.5000)
current_val=tensor(4.), running_val=tensor(3.), total_val=tensor(2.)
current_val=tensor(5.), running_val=tensor(4.), total_val=tensor(2.5000)

Running Sum

Module Interface

class torchmetrics.aggregation.RunningSum(window=5, nan_strategy='warn', **kwargs)[source]

Aggregate a stream of values into their sum over a running window.

Using this metric compared to SumMetric allows for calculating metrics over a running window of values, instead of the whole history of values. This is beneficial when you want to get a better estimate of the metric during training and don’t want to wait for the whole training to finish to get epoch level estimates.

As input to forward and update the metric accepts the following input

  • value (float or Tensor): a single float or a tensor of float values with arbitrary shape (...,).

As output of forward and compute the metric returns the following output

  • agg (Tensor): scalar float tensor with the aggregated sum over the last window of inputs received

Parameters:
  • window (int) – The size of the running window.

  • nan_strategy (Union[str, float]) – options: 'error': if any nan values are encountered, a RuntimeError is raised; 'warn': if any nan values are encountered, a warning is given and computation continues; 'ignore': all nan values are silently removed; a float: if a float is provided, any nan values are imputed with this value

  • kwargs (Any) – Additional keyword arguments, see Advanced metric settings for more info.

Raises:

ValueError – If nan_strategy is not one of error, warn, ignore or a float

Example

>>> from torch import tensor
>>> from torchmetrics.aggregation import RunningSum
>>> metric = RunningSum(window=3)
>>> for i in range(6):
...     current_val = metric(tensor([i]))
...     running_val = metric.compute()
...     total_val = tensor(sum(list(range(i+1))))  # total sum over all samples
...     print(f"{current_val=}, {running_val=}, {total_val=}")
current_val=tensor(0.), running_val=tensor(0.), total_val=tensor(0)
current_val=tensor(1.), running_val=tensor(1.), total_val=tensor(1)
current_val=tensor(2.), running_val=tensor(3.), total_val=tensor(3)
current_val=tensor(3.), running_val=tensor(6.), total_val=tensor(6)
current_val=tensor(4.), running_val=tensor(9.), total_val=tensor(10)
current_val=tensor(5.), running_val=tensor(12.), total_val=tensor(15)

Sum

Module Interface

class torchmetrics.aggregation.SumMetric(nan_strategy='warn', **kwargs)[source]

Aggregate a stream of values into their sum.

As input to forward and update the metric accepts the following input

  • value (float or Tensor): a single float or a tensor of float values with arbitrary shape (...,).

As output of forward and compute the metric returns the following output

  • agg (Tensor): scalar float tensor with aggregated sum over all inputs received

Parameters:
  • nan_strategy (Union[str, float]) – options: 'error': if any nan values are encountered, a RuntimeError is raised; 'warn': if any nan values are encountered, a warning is given and computation continues; 'ignore': all nan values are silently removed; a float: if a float is provided, any nan values are imputed with this value

  • kwargs (Any) – Additional keyword arguments, see Advanced metric settings for more info.

Raises:

ValueError – If nan_strategy is not one of error, warn, ignore or a float

Example

>>> from torch import tensor
>>> from torchmetrics.aggregation import SumMetric
>>> metric = SumMetric()
>>> metric.update(1)
>>> metric.update(tensor([2, 3]))
>>> metric.compute()
tensor(6.)
plot(val=None, ax=None)[source]

Plot a single or multiple values from the metric.

Parameters:
  • val (Union[Tensor, Sequence[Tensor], None]) – Either a single result from calling metric.forward or metric.compute or a list of these results. If no value is provided, will automatically call metric.compute and plot that result.

  • ax (Optional[Axes]) – A matplotlib axis object. If provided, the plot will be added to that axis

Return type:

Tuple[Figure, Union[Axes, ndarray]]

Returns:

Figure and Axes object

Raises:

ModuleNotFoundError – If matplotlib is not installed

>>> # Example plotting a single value
>>> from torchmetrics.aggregation import SumMetric
>>> metric = SumMetric()
>>> metric.update([1, 2, 3])
>>> fig_, ax_ = metric.plot()
>>> # Example plotting multiple values
>>> from torch import rand, randint
>>> from torchmetrics.aggregation import SumMetric
>>> metric = SumMetric()
>>> values = [ ]
>>> for i in range(10):
...     values.append(metric([i, i+1]))
>>> fig_, ax_ = metric.plot(values)

Complex Scale-Invariant Signal-to-Noise Ratio (C-SI-SNR)

Module Interface

class torchmetrics.audio.ComplexScaleInvariantSignalNoiseRatio(zero_mean=False, **kwargs)[source]

Calculate the Complex scale-invariant signal-to-noise ratio (C-SI-SNR) metric for evaluating the quality of audio.

As input to forward and update the metric accepts the following input

  • preds (Tensor): real float tensor with shape (...,frequency,time,2) or complex float tensor with shape (..., frequency,time)

  • target (Tensor): real float tensor with shape (...,frequency,time,2) or complex float tensor with shape (..., frequency,time)

As output of forward and compute the metric returns the following output

  • c_si_snr (Tensor): float scalar tensor with average C-SI-SNR value over samples

Parameters:
  • zero_mean (bool) – When set to True, the mean of all signals is subtracted prior to computation of the metrics

  • kwargs (Any) – Additional keyword arguments, see Advanced metric settings for more info.

Raises:
  • ValueError – If zero_mean is not a bool

  • TypeError – If preds does not have the shape (..., frequency, time, 2) (after being converted to real if it is complex), or if preds and target do not have the same shape.

Example

>>> import torch
>>> from torch import tensor
>>> from torchmetrics.audio import ComplexScaleInvariantSignalNoiseRatio
>>> g = torch.manual_seed(1)
>>> preds = torch.randn((1,257,100,2))
>>> target = torch.randn((1,257,100,2))
>>> c_si_snr = ComplexScaleInvariantSignalNoiseRatio()
>>> c_si_snr(preds, target)
tensor(-63.4849)
plot(val=None, ax=None)[source]

Plot a single or multiple values from the metric.

Parameters:
  • val (Union[Tensor, Sequence[Tensor], None]) – Either a single result from calling metric.forward or metric.compute or a list of these results. If no value is provided, will automatically call metric.compute and plot that result.

  • ax (Optional[Axes]) – A matplotlib axis object. If provided, the plot will be added to that axis

Return type:

Tuple[Figure, Union[Axes, ndarray]]

Returns:

Figure and Axes object

Raises:

ModuleNotFoundError – If matplotlib is not installed

>>> # Example plotting a single value
>>> import torch
>>> from torchmetrics.audio import ComplexScaleInvariantSignalNoiseRatio
>>> metric = ComplexScaleInvariantSignalNoiseRatio()
>>> metric.update(torch.rand(1,257,100,2), torch.rand(1,257,100,2))
>>> fig_, ax_ = metric.plot()
>>> # Example plotting multiple values
>>> import torch
>>> from torchmetrics.audio import ComplexScaleInvariantSignalNoiseRatio
>>> metric = ComplexScaleInvariantSignalNoiseRatio()
>>> values = [ ]
>>> for _ in range(10):
...     values.append(metric(torch.rand(1,257,100,2), torch.rand(1,257,100,2)))
>>> fig_, ax_ = metric.plot(values)

Functional Interface

torchmetrics.functional.audio.complex_scale_invariant_signal_noise_ratio(preds, target, zero_mean=False)[source]

Complex scale-invariant signal-to-noise ratio (C-SI-SNR).

Parameters:
  • preds (Tensor) – real float tensor with shape (...,frequency,time,2) or complex float tensor with shape (..., frequency,time)

  • target (Tensor) – real float tensor with shape (...,frequency,time,2) or complex float tensor with shape (..., frequency,time)

  • zero_mean (bool) – When set to True, the mean of all signals is subtracted prior to computation of the metrics

Return type:

Tensor

Returns:

Float tensor with shape (...,) of C-SI-SNR values per sample

Raises:

RuntimeError – If preds does not have the shape (...,frequency,time,2) (after being converted to real if it is complex), or if preds and target do not have the same shape.

Example

>>> import torch
>>> from torchmetrics.functional.audio import complex_scale_invariant_signal_noise_ratio
>>> g = torch.manual_seed(1)
>>> preds = torch.randn((1,257,100,2))
>>> target = torch.randn((1,257,100,2))
>>> complex_scale_invariant_signal_noise_ratio(preds, target)
tensor([-63.4849])

Perceptual Evaluation of Speech Quality (PESQ)

Module Interface

class torchmetrics.audio.pesq.PerceptualEvaluationSpeechQuality(fs, mode, n_processes=1, **kwargs)[source]

Calculate Perceptual Evaluation of Speech Quality (PESQ).

It is a recognized industry standard for audio quality that takes into consideration characteristics such as audio sharpness, call volume, background noise, clipping, audio interference, etc. PESQ returns a score between -0.5 and 4.5, with higher scores indicating better quality.

This metric is a wrapper for the pesq package. Note that input will be moved to cpu to perform the metric calculation.

As input to forward and update the metric accepts the following input

  • preds (Tensor): float tensor with shape (...,time)

  • target (Tensor): float tensor with shape (...,time)

As output of forward and compute the metric returns the following output

  • pesq (Tensor): float tensor with shape (...,) of PESQ value per sample

Note

Using this metric requires you to have pesq installed. Either install it as pip install torchmetrics[audio] or pip install pesq. pesq will compile against your currently installed version of numpy, meaning that if you upgrade numpy at some point in the future you will most likely have to reinstall pesq.

Parameters:
  • fs (int) – sampling frequency, should be 16000 or 8000 (Hz)

  • mode (str) – 'wb' (wide-band) or 'nb' (narrow-band)

  • keep_same_device – whether to move the pesq value to the device of preds

  • n_processes (int) – integer specifying the number of processes to run in parallel for the metric calculation. Only applies to batches of data and only if the multiprocessing package is installed.

  • kwargs (Any) – Additional keyword arguments, see Advanced metric settings for more info.

Raises:

ModuleNotFoundError – If pesq package is not installed

Example

>>> import torch
>>> from torchmetrics.audio import PerceptualEvaluationSpeechQuality
>>> g = torch.manual_seed(1)
>>> preds = torch.randn(8000)
>>> target = torch.randn(8000)
>>> nb_pesq = PerceptualEvaluationSpeechQuality(8000, 'nb')
>>> nb_pesq(preds, target)
tensor(2.2076)
>>> wb_pesq = PerceptualEvaluationSpeechQuality(16000, 'wb')
>>> wb_pesq(preds, target)
tensor(1.7359)
plot(val=None, ax=None)[source]

Plot a single or multiple values from the metric.

Parameters:
  • val (Union[Tensor, Sequence[Tensor], None]) – Either a single result from calling metric.forward or metric.compute or a list of these results. If no value is provided, will automatically call metric.compute and plot that result.

  • ax (Optional[Axes]) – A matplotlib axis object. If provided, the plot will be added to that axis

Return type:

Tuple[Figure, Union[Axes, ndarray]]

Returns:

Figure and Axes object

Raises:

ModuleNotFoundError – If matplotlib is not installed

>>> # Example plotting a single value
>>> import torch
>>> from torchmetrics.audio import PerceptualEvaluationSpeechQuality
>>> metric = PerceptualEvaluationSpeechQuality(8000, 'nb')
>>> metric.update(torch.rand(8000), torch.rand(8000))
>>> fig_, ax_ = metric.plot()
>>> # Example plotting multiple values
>>> import torch
>>> from torchmetrics.audio import PerceptualEvaluationSpeechQuality
>>> metric = PerceptualEvaluationSpeechQuality(8000, 'nb')
>>> values = [ ]
>>> for _ in range(10):
...     values.append(metric(torch.rand(8000), torch.rand(8000)))
>>> fig_, ax_ = metric.plot(values)

Functional Interface

torchmetrics.functional.audio.pesq.perceptual_evaluation_speech_quality(preds, target, fs, mode, keep_same_device=False, n_processes=1)[source]

Calculate Perceptual Evaluation of Speech Quality (PESQ).

It is a recognized industry standard for audio quality that takes into consideration characteristics such as audio sharpness, call volume, background noise, clipping, audio interference, etc. PESQ returns a score between -0.5 and 4.5, with higher scores indicating better quality.

This metric is a wrapper for the pesq package. Note that input will be moved to cpu to perform the metric calculation.

Note

Using this metric requires you to have pesq installed. Either install it as pip install torchmetrics[audio] or pip install pesq. Note that pesq will compile against your currently installed version of numpy, meaning that if you upgrade numpy at some point in the future you will most likely have to reinstall pesq.

Parameters:
  • preds (Tensor) – float tensor with shape (...,time)

  • target (Tensor) – float tensor with shape (...,time)

  • fs (int) – sampling frequency, should be 16000 or 8000 (Hz)

  • mode (str) – 'wb' (wide-band) or 'nb' (narrow-band)

  • keep_same_device (bool) – whether to move the pesq value to the device of preds

  • n_processes (int) – integer specifying the number of processes to run in parallel for the metric calculation. Only applies to batches of data and only if the multiprocessing package is installed.

Return type:

Tensor

Returns:

Float tensor with shape (...,) of PESQ values per sample

Raises:

ModuleNotFoundError – If pesq package is not installed

Example

>>> import torch
>>> from torch import randn
>>> from torchmetrics.functional.audio.pesq import perceptual_evaluation_speech_quality
>>> g = torch.manual_seed(1)
>>> preds = randn(8000)
>>> target = randn(8000)
>>> perceptual_evaluation_speech_quality(preds, target, 8000, 'nb')
tensor(2.2076)
>>> perceptual_evaluation_speech_quality(preds, target, 16000, 'wb')
tensor(1.7359)

Permutation Invariant Training (PIT)

Module Interface

class torchmetrics.audio.PermutationInvariantTraining(metric_func, mode='speaker-wise', eval_func='max', **kwargs)[source]

Calculate Permutation invariant training (PIT).

This metric can evaluate models for speaker independent multi-talker speech separation in a permutation invariant way.

As input to forward and update the metric accepts the following input

  • preds (Tensor): float tensor with shape (batch_size,num_speakers,...)

  • target (Tensor): float tensor with shape (batch_size,num_speakers,...)

As output of forward and compute the metric returns the following output

  • pit (Tensor): float scalar tensor with the average of the best metric value over samples

Parameters:
  • metric_func (Callable) –

    a metric function that accepts a batch of target and estimate tensors.

    if mode == 'speaker-wise', then metric_func(preds[:, i, ...], target[:, j, ...]) is called and expected to return a batch of metric tensors (batch,);

    if mode == 'permutation-wise', then metric_func(preds[:, p, ...], target[:, :, ...]) is called, where p is one possible permutation, e.g. [0,1] or [1,0] for the 2-speaker case, and expected to return a batch of metric tensors (batch,);

  • mode (Literal['speaker-wise', 'permutation-wise']) – can be ‘speaker-wise’ or ‘permutation-wise’.

  • eval_func (Literal['max', 'min']) – the function to find the best permutation, can be ‘min’ or ‘max’, i.e. the smaller the better or the larger the better.

  • kwargs (Any) – Additional keyword arguments for either the metric_func or distributed communication, see Advanced metric settings for more info.

Example

>>> import torch
>>> from torchmetrics.audio import PermutationInvariantTraining
>>> from torchmetrics.functional.audio import scale_invariant_signal_noise_ratio
>>> _ = torch.manual_seed(42)
>>> preds = torch.randn(3, 2, 5) # [batch, spk, time]
>>> target = torch.randn(3, 2, 5) # [batch, spk, time]
>>> pit = PermutationInvariantTraining(scale_invariant_signal_noise_ratio,
...     mode="speaker-wise", eval_func="max")
>>> pit(preds, target)
tensor(-2.1065)
plot(val=None, ax=None)[source]

Plot a single or multiple values from the metric.

Parameters:
  • val (Union[Tensor, Sequence[Tensor], None]) – Either a single result from calling metric.forward or metric.compute or a list of these results. If no value is provided, will automatically call metric.compute and plot that result.

  • ax (Optional[Axes]) – A matplotlib axis object. If provided, the plot will be added to that axis

Return type:

Tuple[Figure, Union[Axes, ndarray]]

Returns:

Figure and Axes object

Raises:

ModuleNotFoundError – If matplotlib is not installed

>>> # Example plotting a single value
>>> import torch
>>> from torchmetrics.audio import PermutationInvariantTraining
>>> from torchmetrics.functional.audio import scale_invariant_signal_noise_ratio
>>> preds = torch.randn(3, 2, 5) # [batch, spk, time]
>>> target = torch.randn(3, 2, 5) # [batch, spk, time]
>>> metric = PermutationInvariantTraining(scale_invariant_signal_noise_ratio,
...     mode="speaker-wise", eval_func="max")
>>> metric.update(preds, target)
>>> fig_, ax_ = metric.plot()
>>> # Example plotting multiple values
>>> import torch
>>> from torchmetrics.audio import PermutationInvariantTraining
>>> from torchmetrics.functional.audio import scale_invariant_signal_noise_ratio
>>> preds = torch.randn(3, 2, 5) # [batch, spk, time]
>>> target = torch.randn(3, 2, 5) # [batch, spk, time]
>>> metric = PermutationInvariantTraining(scale_invariant_signal_noise_ratio,
...     mode="speaker-wise", eval_func="max")
>>> values = [ ]
>>> for _ in range(10):
...     values.append(metric(preds, target))
>>> fig_, ax_ = metric.plot(values)

Functional Interface

torchmetrics.functional.audio.permutation_invariant_training(preds, target, metric_func, mode='speaker-wise', eval_func='max', **kwargs)[source]

Calculate Permutation invariant training (PIT).

This metric can evaluate models for speaker independent multi-talker speech separation in a permutation invariant way.

Parameters:
  • preds (Tensor) – float tensor with shape (batch_size,num_speakers,...)

  • target (Tensor) – float tensor with shape (batch_size,num_speakers,...)

  • metric_func (Callable) –

    a metric function that accepts a batch of target and estimate tensors. if mode == 'speaker-wise', then metric_func(preds[:, i, ...], target[:, j, ...]) is called and expected to return a batch of metric tensors (batch,);

    if mode == 'permutation-wise', then metric_func(preds[:, p, ...], target[:, :, ...]) is called, where p is one possible permutation, e.g. [0,1] or [1,0] for the 2-speaker case, and expected to return a batch of metric tensors (batch,);

  • mode (Literal['speaker-wise', 'permutation-wise']) – can be ‘speaker-wise’ or ‘permutation-wise’.

  • eval_func (Literal['max', 'min']) – the function to find the best permutation, can be 'min' or 'max', i.e. the smaller the better or the larger the better.

  • kwargs (Any) – Additional args for metric_func

Return type:

Tuple[Tensor, Tensor]

Returns:

Tuple of two tensors. The first tensor, with shape (batch,), contains the best metric value for each sample, and the second tensor, with shape (batch, num_speakers), contains the best permutation.

Example

>>> import torch
>>> from torchmetrics.functional.audio import permutation_invariant_training, pit_permutate
>>> from torchmetrics.functional.audio import scale_invariant_signal_distortion_ratio
>>> # [batch, spk, time]
>>> preds = torch.tensor([[[-0.0579,  0.3560, -0.9604], [-0.1719,  0.3205,  0.2951]]])
>>> target = torch.tensor([[[ 1.0958, -0.1648,  0.5228], [-0.4100,  1.1942, -0.5103]]])
>>> best_metric, best_perm = permutation_invariant_training(
...     preds, target, scale_invariant_signal_distortion_ratio,
...     mode="speaker-wise", eval_func="max")
>>> best_metric
tensor([-5.1091])
>>> best_perm
tensor([[0, 1]])
>>> pit_permutate(preds, best_perm)
tensor([[[-0.0579,  0.3560, -0.9604],
         [-0.1719,  0.3205,  0.2951]]])

Scale-Invariant Signal-to-Distortion Ratio (SI-SDR)

Module Interface

class torchmetrics.audio.ScaleInvariantSignalDistortionRatio(zero_mean=False, **kwargs)[source]

Scale-invariant signal-to-distortion ratio (SI-SDR).

The SI-SDR value is in general considered an overall measure of how good a source sounds.

As input to forward and update the metric accepts the following input

  • preds (Tensor): float tensor with shape (...,time)

  • target (Tensor): float tensor with shape (...,time)

As output of forward and compute the metric returns the following output

  • si_sdr (Tensor): float scalar tensor with average SI-SDR value over samples

Parameters:
  • zero_mean (bool) – whether to zero-mean target and preds before computing the metric

  • kwargs (Any) – Additional keyword arguments, see Advanced metric settings for more info.

Raises:

TypeError – if target and preds have a different shape

Example

>>> from torch import tensor
>>> from torchmetrics.audio import ScaleInvariantSignalDistortionRatio
>>> target = tensor([3.0, -0.5, 2.0, 7.0])
>>> preds = tensor([2.5, 0.0, 2.0, 8.0])
>>> si_sdr = ScaleInvariantSignalDistortionRatio()
>>> si_sdr(preds, target)
tensor(18.4030)
plot(val=None, ax=None)[source]

Plot a single or multiple values from the metric.

Parameters:
  • val (Union[Tensor, Sequence[Tensor], None]) – Either a single result from calling metric.forward or metric.compute or a list of these results. If no value is provided, will automatically call metric.compute and plot that result.

  • ax (Optional[Axes]) – A matplotlib axis object. If provided, the plot will be added to that axis

Return type:

Tuple[Figure, Union[Axes, ndarray]]

Returns:

Figure and Axes object

Raises:

ModuleNotFoundError – If matplotlib is not installed

>>> # Example plotting a single value
>>> import torch
>>> from torchmetrics.audio import ScaleInvariantSignalDistortionRatio
>>> target = torch.randn(5)
>>> preds = torch.randn(5)
>>> metric = ScaleInvariantSignalDistortionRatio()
>>> metric.update(preds, target)
>>> fig_, ax_ = metric.plot()
>>> # Example plotting multiple values
>>> import torch
>>> from torchmetrics.audio import ScaleInvariantSignalDistortionRatio
>>> target = torch.randn(5)
>>> preds = torch.randn(5)
>>> metric = ScaleInvariantSignalDistortionRatio()
>>> values = [ ]
>>> for _ in range(10):
...     values.append(metric(preds, target))
>>> fig_, ax_ = metric.plot(values)

Functional Interface

torchmetrics.functional.audio.scale_invariant_signal_distortion_ratio(preds, target, zero_mean=False)[source]

Scale-invariant signal-to-distortion ratio (SI-SDR).

The SI-SDR value is in general considered an overall measure of how good a source sounds.

Parameters:
  • preds (Tensor) – float tensor with shape (...,time)

  • target (Tensor) – float tensor with shape (...,time)

  • zero_mean (bool) – whether to zero-mean target and preds before computing the metric

Return type:

Tensor

Returns:

Float tensor with shape (...,) of SDR values per sample

Raises:

RuntimeError – If preds and target do not have the same shape

Example

>>> import torch
>>> from torchmetrics.functional.audio import scale_invariant_signal_distortion_ratio
>>> target = torch.tensor([3.0, -0.5, 2.0, 7.0])
>>> preds = torch.tensor([2.5, 0.0, 2.0, 8.0])
>>> scale_invariant_signal_distortion_ratio(preds, target)
tensor(18.4030)

Scale-Invariant Signal-to-Noise Ratio (SI-SNR)

Module Interface

class torchmetrics.audio.ScaleInvariantSignalNoiseRatio(**kwargs)[source]

Calculate the Scale-invariant signal-to-noise ratio (SI-SNR) metric for evaluating the quality of audio.

As input to forward and update the metric accepts the following input

  • preds (Tensor): float tensor with shape (...,time)

  • target (Tensor): float tensor with shape (...,time)

As output of forward and compute the metric returns the following output

  • si_snr (Tensor): float scalar tensor with average SI-SNR value over samples

Parameters:

kwargs (Any) – Additional keyword arguments, see Advanced metric settings for more info.

Raises:

TypeError – if target and preds have a different shape

Example

>>> import torch
>>> from torch import tensor
>>> from torchmetrics.audio import ScaleInvariantSignalNoiseRatio
>>> target = tensor([3.0, -0.5, 2.0, 7.0])
>>> preds = tensor([2.5, 0.0, 2.0, 8.0])
>>> si_snr = ScaleInvariantSignalNoiseRatio()
>>> si_snr(preds, target)
tensor(15.0918)
plot(val=None, ax=None)[source]

Plot a single or multiple values from the metric.

Parameters:
  • val (Union[Tensor, Sequence[Tensor], None]) – Either a single result from calling metric.forward or metric.compute or a list of these results. If no value is provided, will automatically call metric.compute and plot that result.

  • ax (Optional[Axes]) – A matplotlib axis object. If provided, the plot will be added to that axis

Return type:

Tuple[Figure, Union[Axes, ndarray]]

Returns:

Figure and Axes object

Raises:

ModuleNotFoundError – If matplotlib is not installed

>>> # Example plotting a single value
>>> import torch
>>> from torchmetrics.audio import ScaleInvariantSignalNoiseRatio
>>> metric = ScaleInvariantSignalNoiseRatio()
>>> metric.update(torch.rand(4), torch.rand(4))
>>> fig_, ax_ = metric.plot()
>>> # Example plotting multiple values
>>> import torch
>>> from torchmetrics.audio import ScaleInvariantSignalNoiseRatio
>>> metric = ScaleInvariantSignalNoiseRatio()
>>> values = [ ]
>>> for _ in range(10):
...     values.append(metric(torch.rand(4), torch.rand(4)))
>>> fig_, ax_ = metric.plot(values)

Functional Interface

torchmetrics.functional.audio.scale_invariant_signal_noise_ratio(preds, target)[source]

Scale-invariant signal-to-noise ratio (SI-SNR).

Parameters:
  • preds (Tensor) – float tensor with shape (...,time)

  • target (Tensor) – float tensor with shape (...,time)

Return type:

Tensor

Returns:

Float tensor with shape (...,) of SI-SNR values per sample

Raises:

RuntimeError – If preds and target do not have the same shape

Example

>>> import torch
>>> from torchmetrics.functional.audio import scale_invariant_signal_noise_ratio
>>> target = torch.tensor([3.0, -0.5, 2.0, 7.0])
>>> preds = torch.tensor([2.5, 0.0, 2.0, 8.0])
>>> scale_invariant_signal_noise_ratio(preds, target)
tensor(15.0918)

Short-Time Objective Intelligibility (STOI)

Module Interface

class torchmetrics.audio.stoi.ShortTimeObjectiveIntelligibility(fs, extended=False, **kwargs)[source]

Calculate STOI (Short-Time Objective Intelligibility) metric for evaluating speech signals.

Intelligibility measure which is highly correlated with the intelligibility of degraded speech signals, e.g., due to additive noise, single-/multi-channel noise reduction, binary masking and vocoded speech as in CI simulations. The STOI measure is intrusive, i.e., a function of the clean and degraded speech signals. STOI may be a good alternative to the speech intelligibility index (SII) or the speech transmission index (STI) when you are interested in the effect of nonlinear processing of noisy speech, e.g., noise reduction or binary masking algorithms, on speech intelligibility. Description taken from Cees Taal’s website; for further details see STOI ref1 and STOI ref2.

This metric is a wrapper for the pystoi package. As the backend implementation only supports calculations on CPU, all input will automatically be moved to CPU to perform the metric calculation before being moved back to the original device.

As input to forward and update the metric accepts the following input

  • preds (Tensor): float tensor with shape (...,time)

  • target (Tensor): float tensor with shape (...,time)

As output of forward and compute the metric returns the following output

  • stoi (Tensor): float scalar tensor

Note

Using this metric requires you to have pystoi installed. Either install it as pip install torchmetrics[audio] or pip install pystoi.

Parameters:
  • fs (int) – sampling frequency (Hz)

  • extended (bool) – whether to use the extended STOI described in STOI ref3

  • kwargs (Any) – Additional keyword arguments, see Advanced metric settings for more info.

Raises:

ModuleNotFoundError – If pystoi package is not installed

Example

>>> import torch
>>> from torchmetrics.audio import ShortTimeObjectiveIntelligibility
>>> g = torch.manual_seed(1)
>>> preds = torch.randn(8000)
>>> target = torch.randn(8000)
>>> stoi = ShortTimeObjectiveIntelligibility(8000, False)
>>> stoi(preds, target)
tensor(-0.0100)
compute()[source]

Compute metric.

Return type:

Tensor

plot(val=None, ax=None)[source]

Plot a single or multiple values from the metric.

Parameters:
  • val (Union[Tensor, Sequence[Tensor], None]) – Either a single result from calling metric.forward or metric.compute or a list of these results. If no value is provided, will automatically call metric.compute and plot that result.

  • ax (Optional[Axes]) – A matplotlib axis object. If provided, the plot will be added to that axis

Return type:

Tuple[Figure, Union[Axes, ndarray]]

Returns:

Figure and Axes object

Raises:

ModuleNotFoundError – If matplotlib is not installed

>>> # Example plotting a single value
>>> import torch
>>> from torchmetrics.audio import ShortTimeObjectiveIntelligibility
>>> g = torch.manual_seed(1)
>>> preds = torch.randn(8000)
>>> target = torch.randn(8000)
>>> metric = ShortTimeObjectiveIntelligibility(8000, False)
>>> metric.update(preds, target)
>>> fig_, ax_ = metric.plot()
>>> # Example plotting multiple values
>>> import torch
>>> from torchmetrics.audio import ShortTimeObjectiveIntelligibility
>>> metric = ShortTimeObjectiveIntelligibility(8000, False)
>>> g = torch.manual_seed(1)
>>> preds = torch.randn(8000)
>>> target = torch.randn(8000)
>>> values = [ ]
>>> for _ in range(10):
...     values.append(metric(preds, target))
>>> fig_, ax_ = metric.plot(values)
update(preds, target)[source]

Update state with predictions and targets.

Return type:

None

Functional Interface

torchmetrics.functional.audio.stoi.short_time_objective_intelligibility(preds, target, fs, extended=False, keep_same_device=False)[source]

Calculate STOI (Short-Time Objective Intelligibility) metric for evaluating speech signals.

Intelligibility measure which is highly correlated with the intelligibility of degraded speech signals, e.g., due to additive noise, single-/multi-channel noise reduction, binary masking and vocoded speech as in CI simulations. The STOI measure is intrusive, i.e., a function of the clean and degraded speech signals. STOI may be a good alternative to the speech intelligibility index (SII) or the speech transmission index (STI) when you are interested in the effect of nonlinear processing of noisy speech, e.g., noise reduction or binary masking algorithms, on speech intelligibility. Description taken from Cees Taal’s website; for further details see STOI ref1 and STOI ref2.

This metric is a wrapper for the pystoi package. As the backend implementation only supports calculations on CPU, all input will automatically be moved to CPU to perform the metric calculation before being moved back to the original device.

Note

Using this metric requires you to have pystoi installed. Either install it as pip install torchmetrics[audio] or pip install pystoi.

Parameters:
  • preds (Tensor) – float tensor with shape (...,time)

  • target (Tensor) – float tensor with shape (...,time)

  • fs (int) – sampling frequency (Hz)

  • extended (bool) – whether to use the extended STOI described in STOI ref3.

  • keep_same_device (bool) – whether to move the stoi value to the device of preds

Return type:

Tensor

Returns:

stoi value of shape […]

Raises:

ModuleNotFoundError – If pystoi package is not installed

Example

>>> import torch
>>> from torchmetrics.functional.audio.stoi import short_time_objective_intelligibility
>>> g = torch.manual_seed(1)
>>> preds = torch.randn(8000)
>>> target = torch.randn(8000)
>>> short_time_objective_intelligibility(preds, target, 8000).float()
tensor(-0.0100)

Signal to Distortion Ratio (SDR)

Module Interface

class torchmetrics.audio.SignalDistortionRatio(use_cg_iter=None, filter_length=512, zero_mean=False, load_diag=None, **kwargs)[source]

Calculate Signal to Distortion Ratio (SDR) metric.

See SDR ref1 and SDR ref2 for details on the metric.

As input to forward and update the metric accepts the following input

  • preds (Tensor): float tensor with shape (...,time)

  • target (Tensor): float tensor with shape (...,time)

As output of forward and compute the metric returns the following output

  • sdr (Tensor): float scalar tensor with average SDR value over samples

Parameters:
  • use_cg_iter (Optional[int]) – If provided, conjugate gradient descent is used to solve for the distortion filter coefficients instead of direct Gaussian elimination, which requires that fast-bss-eval is installed and pytorch version >= 1.8. This can speed up the computation of the metrics in case the filters are long. Using a value of 10 here has been shown to provide good accuracy in most cases and is sufficient when using this loss to train neural separation networks.

  • filter_length (int) – The length of the distortion filter allowed

  • zero_mean (bool) – When set to True, the mean of all signals is subtracted prior to computation of the metrics

  • load_diag (Optional[float]) – If provided, this small value is added to the diagonal coefficients of the system metrics when solving for the filter coefficients. This can help stabilize the metric in the case where some reference signals may sometimes be zero

  • kwargs (Any) – Additional keyword arguments, see Advanced metric settings for more info.

Example

>>> import torch
>>> from torchmetrics.audio import SignalDistortionRatio
>>> g = torch.manual_seed(1)
>>> preds = torch.randn(8000)
>>> target = torch.randn(8000)
>>> sdr = SignalDistortionRatio()
>>> sdr(preds, target)
tensor(-12.0589)
>>> # use with pit
>>> from torchmetrics.audio import PermutationInvariantTraining
>>> from torchmetrics.functional.audio import signal_distortion_ratio
>>> preds = torch.randn(4, 2, 8000)  # [batch, spk, time]
>>> target = torch.randn(4, 2, 8000)
>>> pit = PermutationInvariantTraining(signal_distortion_ratio,
...     mode="speaker-wise", eval_func="max")
>>> pit(preds, target)
tensor(-11.6051)
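The use_cg_iter and load_diag options described in the parameter list above can be combined as follows. This is only a usage sketch (it assumes the optional fast-bss-eval package is installed, as noted above):

>>> import torch
>>> from torchmetrics.audio import SignalDistortionRatio
>>> preds = torch.randn(8000)
>>> target = torch.randn(8000)
>>> # conjugate-gradient solver with 10 iterations and a small diagonal loading term
>>> sdr_cg = SignalDistortionRatio(use_cg_iter=10, load_diag=1e-8)
>>> val = sdr_cg(preds, target)  # scalar tensor; the CG solver approximates the direct solution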
plot(val=None, ax=None)[source]

Plot a single or multiple values from the metric.

Parameters:
  • val (Union[Tensor, Sequence[Tensor], None]) – Either a single result from calling metric.forward or metric.compute or a list of these results. If no value is provided, will automatically call metric.compute and plot that result.

  • ax (Optional[Axes]) – A matplotlib axis object. If provided, will add the plot to that axis

Return type:

Tuple[Figure, Union[Axes, ndarray]]

Returns:

Figure and Axes object

Raises:

ModuleNotFoundError – If matplotlib is not installed

>>> # Example plotting a single value
>>> import torch
>>> from torchmetrics.audio import SignalDistortionRatio
>>> metric = SignalDistortionRatio()
>>> metric.update(torch.rand(8000), torch.rand(8000))
>>> fig_, ax_ = metric.plot()
_images/signal_distortion_ratio-1.png
>>> # Example plotting multiple values
>>> import torch
>>> from torchmetrics.audio import SignalDistortionRatio
>>> metric = SignalDistortionRatio()
>>> values = [ ]
>>> for _ in range(10):
...     values.append(metric(torch.rand(8000), torch.rand(8000)))
>>> fig_, ax_ = metric.plot(values)
_images/signal_distortion_ratio-2.png

Functional Interface

torchmetrics.functional.audio.signal_distortion_ratio(preds, target, use_cg_iter=None, filter_length=512, zero_mean=False, load_diag=None)[source]

Calculate Signal to Distortion Ratio (SDR) metric. See SDR ref1 and SDR ref2 for details on the metric.

Parameters:
  • preds (Tensor) – float tensor with shape (...,time)

  • target (Tensor) – float tensor with shape (...,time)

  • use_cg_iter (Optional[int]) – If provided, conjugate gradient descent is used to solve for the distortion filter coefficients instead of direct Gaussian elimination; this requires that fast-bss-eval is installed and PyTorch version >= 1.8. It can speed up the computation of the metric when the filters are long. Using a value of 10 here has been shown to provide good accuracy in most cases and is sufficient when using this loss to train neural separation networks.

  • filter_length (int) – The length of the distortion filter allowed

  • zero_mean (bool) – When set to True, the mean of all signals is subtracted prior to computation of the metrics

  • load_diag (Optional[float]) – If provided, this small value is added to the diagonal coefficients of the system matrices when solving for the filter coefficients. This can help stabilize the metric in the case where some reference signals may sometimes be zero

Return type:

Tensor

Returns:

Float tensor with shape (...,) of SDR values per sample

Raises:

RuntimeError – If preds and target do not have the same shape

Example

>>> import torch
>>> from torchmetrics.functional.audio import signal_distortion_ratio
>>> g = torch.manual_seed(1)
>>> preds = torch.randn(8000)
>>> target = torch.randn(8000)
>>> signal_distortion_ratio(preds, target)
tensor(-12.0589)
>>> # use with permutation_invariant_training
>>> from torchmetrics.functional.audio import permutation_invariant_training
>>> preds = torch.randn(4, 2, 8000)  # [batch, spk, time]
>>> target = torch.randn(4, 2, 8000)
>>> best_metric, best_perm = permutation_invariant_training(preds, target, signal_distortion_ratio)
>>> best_metric
tensor([-11.6375, -11.4358, -11.7148, -11.6325])
>>> best_perm
tensor([[1, 0],
        [0, 1],
        [1, 0],
        [0, 1]])

Signal-to-Noise Ratio (SNR)

Module Interface

class torchmetrics.audio.SignalNoiseRatio(zero_mean=False, **kwargs)[source]

Calculate Signal-to-noise ratio (SNR) metric for evaluating the quality of audio.

\[\text{SNR} = \frac{P_{signal}}{P_{noise}}\]

where \(P\) denotes the power of each signal. The SNR metric compares the level of the desired signal to the level of background noise. Therefore, a high value of SNR means that the audio is clear.

As input to forward and update the metric accepts the following input

  • preds (Tensor): float tensor with shape (...,time)

  • target (Tensor): float tensor with shape (...,time)

As output of forward and compute the metric returns the following output

  • snr (Tensor): float scalar tensor with average SNR value over samples

Parameters:
  • zero_mean (bool) – whether to zero-mean target and preds before computing the metric

  • kwargs (Any) – Additional keyword arguments, see Advanced metric settings for more info.

Raises:

TypeError – if target and preds have a different shape

Example

>>> from torch import tensor
>>> from torchmetrics.audio import SignalNoiseRatio
>>> target = tensor([3.0, -0.5, 2.0, 7.0])
>>> preds = tensor([2.5, 0.0, 2.0, 8.0])
>>> snr = SignalNoiseRatio()
>>> snr(preds, target)
tensor(16.1805)
plot(val=None, ax=None)[source]

Plot a single or multiple values from the metric.

Parameters:
  • val (Union[Tensor, Sequence[Tensor], None]) – Either a single result from calling metric.forward or metric.compute or a list of these results. If no value is provided, will automatically call metric.compute and plot that result.

  • ax (Optional[Axes]) – A matplotlib axis object. If provided, will add the plot to that axis

Return type:

Tuple[Figure, Union[Axes, ndarray]]

Returns:

Figure and Axes object

Raises:

ModuleNotFoundError – If matplotlib is not installed

>>> # Example plotting a single value
>>> import torch
>>> from torchmetrics.audio import SignalNoiseRatio
>>> metric = SignalNoiseRatio()
>>> metric.update(torch.rand(4), torch.rand(4))
>>> fig_, ax_ = metric.plot()
_images/signal_noise_ratio-1.png
>>> # Example plotting multiple values
>>> import torch
>>> from torchmetrics.audio import SignalNoiseRatio
>>> metric = SignalNoiseRatio()
>>> values = [ ]
>>> for _ in range(10):
...     values.append(metric(torch.rand(4), torch.rand(4)))
>>> fig_, ax_ = metric.plot(values)
_images/signal_noise_ratio-2.png

Functional Interface

torchmetrics.functional.audio.signal_noise_ratio(preds, target, zero_mean=False)[source]

Calculate Signal-to-noise ratio (SNR) metric for evaluating the quality of audio.

\[\text{SNR} = \frac{P_{signal}}{P_{noise}}\]

where \(P\) denotes the power of each signal. The SNR metric compares the level of the desired signal to the level of background noise. Therefore, a high value of SNR means that the audio is clear.

Parameters:
  • preds (Tensor) – float tensor with shape (...,time)

  • target (Tensor) – float tensor with shape (...,time)

  • zero_mean (bool) – whether to zero-mean target and preds before computing the metric

Return type:

Tensor

Returns:

Float tensor with shape (...,) of SNR values per sample

Raises:

RuntimeError – If preds and target does not have the same shape

Example

>>> import torch
>>> from torchmetrics.functional.audio import signal_noise_ratio
>>> target = torch.tensor([3.0, -0.5, 2.0, 7.0])
>>> preds = torch.tensor([2.5, 0.0, 2.0, 8.0])
>>> signal_noise_ratio(preds, target)
tensor(16.1805)
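The value reported above is the power ratio expressed in decibels. A quick hand computation (a sketch of the definition, not the library code, assuming zero_mean=False) reproduces it:

>>> import torch
>>> target = torch.tensor([3.0, -0.5, 2.0, 7.0])
>>> preds = torch.tensor([2.5, 0.0, 2.0, 8.0])
>>> noise = target - preds
>>> 10 * torch.log10(target.pow(2).sum() / noise.pow(2).sum())
tensor(16.1805)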

Source Aggregated Signal-to-Distortion Ratio (SA-SDR)

Module Interface

class torchmetrics.audio.sdr.SourceAggregatedSignalDistortionRatio(scale_invariant=True, zero_mean=False, **kwargs)[source]

Source-aggregated signal-to-distortion ratio (SA-SDR).

The SA-SDR was proposed to provide a stable gradient for meeting-style source separation, where one-speaker and multiple-speaker scenes coexist.

As input to forward and update the metric accepts the following input

  • preds (Tensor): float tensor with shape (..., spk, time)

  • target (Tensor): float tensor with shape (..., spk, time)

As output of forward and compute the metric returns the following output

  • sa_sdr (Tensor): float scalar tensor with average SA-SDR value over samples

Parameters:
  • scale_invariant (bool) – if True, scale the targets of different speakers with the same alpha

  • zero_mean (bool) – whether to zero-mean target and preds before computing the metric

  • kwargs (Any) – Additional keyword arguments, see Advanced metric settings for more info.

Example

>>> import torch
>>> from torchmetrics.audio import SourceAggregatedSignalDistortionRatio
>>> g = torch.manual_seed(1)
>>> preds = torch.randn(2, 8000) # [..., spk, time]
>>> target = torch.randn(2, 8000)
>>> sasdr = SourceAggregatedSignalDistortionRatio()
>>> sasdr(preds, target)
tensor(-41.6579)
>>> # use with pit
>>> from torchmetrics.audio import PermutationInvariantTraining
>>> from torchmetrics.functional.audio import source_aggregated_signal_distortion_ratio
>>> preds = torch.randn(4, 2, 8000)  # [batch, spk, time]
>>> target = torch.randn(4, 2, 8000)
>>> pit = PermutationInvariantTraining(source_aggregated_signal_distortion_ratio,
...     mode="permutation-wise", eval_func="max")
>>> pit(preds, target)
tensor(-41.2790)
plot(val=None, ax=None)[source]

Plot a single or multiple values from the metric.

Parameters:
  • val (Union[Tensor, Sequence[Tensor], None]) – Either a single result from calling metric.forward or metric.compute or a list of these results. If no value is provided, will automatically call metric.compute and plot that result.

  • ax (Optional[Axes]) – A matplotlib axis object. If provided, will add the plot to that axis

Return type:

Tuple[Figure, Union[Axes, ndarray]]

Returns:

Figure and Axes object

Raises:

ModuleNotFoundError – If matplotlib is not installed

>>> # Example plotting a single value
>>> import torch
>>> from torchmetrics.audio import SourceAggregatedSignalDistortionRatio
>>> metric = SourceAggregatedSignalDistortionRatio()
>>> metric.update(torch.rand(2,8000), torch.rand(2,8000))
>>> fig_, ax_ = metric.plot()
_images/source_aggregated_signal_distortion_ratio-1.png
>>> # Example plotting multiple values
>>> import torch
>>> from torchmetrics.audio import SourceAggregatedSignalDistortionRatio
>>> metric = SourceAggregatedSignalDistortionRatio()
>>> values = [ ]
>>> for _ in range(10):
...     values.append(metric(torch.rand(2,8000), torch.rand(2,8000)))
>>> fig_, ax_ = metric.plot(values)
_images/source_aggregated_signal_distortion_ratio-2.png

Functional Interface

torchmetrics.functional.audio.sdr.source_aggregated_signal_distortion_ratio(preds, target, scale_invariant=True, zero_mean=False)[source]

Source-aggregated signal-to-distortion ratio (SA-SDR).

The SA-SDR was proposed to provide a stable gradient for meeting-style source separation, where one-speaker and multiple-speaker scenes coexist.

Parameters:
  • preds (Tensor) – float tensor with shape (..., spk, time)

  • target (Tensor) – float tensor with shape (..., spk, time)

  • scale_invariant (bool) – if True, scale the targets of different speakers with the same alpha

  • zero_mean (bool) – whether to zero-mean target and preds before computing the metric

Return type:

Tensor

Returns:

SA-SDR with shape (...)

Example

>>> import torch
>>> from torchmetrics.functional.audio import source_aggregated_signal_distortion_ratio
>>> g = torch.manual_seed(1)
>>> preds = torch.randn(2, 8000)  # [..., spk, time]
>>> target = torch.randn(2, 8000)
>>> source_aggregated_signal_distortion_ratio(preds, target)
tensor(-41.6579)
>>> # use with permutation_invariant_training
>>> from torchmetrics.functional.audio import permutation_invariant_training
>>> preds = torch.randn(4, 2, 8000)  # [batch, spk, time]
>>> target = torch.randn(4, 2, 8000)
>>> best_metric, best_perm = permutation_invariant_training(preds, target,
...     source_aggregated_signal_distortion_ratio, mode="permutation-wise")
>>> best_metric
tensor([-37.9511, -41.9124, -42.7369, -42.5155])
>>> best_perm
tensor([[1, 0],
        [1, 0],
        [0, 1],
        [1, 0]])

Speech-to-Reverberation Modulation Energy Ratio (SRMR)

Module Interface

class torchmetrics.audio.srmr.SpeechReverberationModulationEnergyRatio(fs, n_cochlear_filters=23, low_freq=125, min_cf=4, max_cf=None, norm=False, fast=False, **kwargs)[source]

Calculate Speech-to-Reverberation Modulation Energy Ratio (SRMR).

SRMR is a non-intrusive metric for speech quality and intelligibility based on a modulation spectral representation of the speech signal. This code is translated from SRMRToolbox and SRMRpy.

As input to forward and update the metric accepts the following input

  • preds (Tensor): float tensor with shape (...,time)

As output of forward and compute the metric returns the following output

  • srmr (Tensor): float scalar tensor

Note

Using this metric requires you to have gammatone and torchaudio installed. Either install as pip install torchmetrics[audio] or pip install torchaudio and pip install git+https://github.com/detly/gammatone.

Note

This implementation is experimental and might not be consistent with the MATLAB implementation SRMRToolbox, especially the fast implementation. The slow versions, a) fast=False, norm=False, max_cf=128, b) fast=False, norm=True, max_cf=30, have a relatively small inconsistency.

Parameters:
  • fs (int) – the sampling rate

  • n_cochlear_filters (int) – Number of filters in the acoustic filterbank

  • low_freq (float) – determines the frequency cutoff for the corresponding gammatone filterbank.

  • min_cf (float) – Center frequency in Hz of the first modulation filter.

  • max_cf (Optional[float]) – Center frequency in Hz of the last modulation filter. If None is given, then 128 Hz will be used for norm==False and 30 Hz for norm==True.

  • norm (bool) – Use modulation spectrum energy normalization

  • fast (bool) – Use the faster version based on the gammatonegram. Note: this argument is inherited from SRMRpy. As the translated code is based on PyTorch, setting fast=True may actually slow down the calculation of this metric on GPU.

Raises:

ModuleNotFoundError – If gammatone or torchaudio package is not installed

Example

>>> import torch
>>> from torchmetrics.audio import SpeechReverberationModulationEnergyRatio
>>> g = torch.manual_seed(1)
>>> preds = torch.randn(8000)
>>> srmr = SpeechReverberationModulationEnergyRatio(8000)
>>> srmr(preds)
tensor(0.3354)
plot(val=None, ax=None)[source]

Plot a single or multiple values from the metric.

Parameters:
  • val (Union[Tensor, Sequence[Tensor], None]) – Either a single result from calling metric.forward or metric.compute or a list of these results. If no value is provided, will automatically call metric.compute and plot that result.

  • ax (Optional[Axes]) – A matplotlib axis object. If provided, will add the plot to that axis

Return type:

Tuple[Figure, Union[Axes, ndarray]]

Returns:

Figure and Axes object

Raises:

ModuleNotFoundError – If matplotlib is not installed

>>> # Example plotting a single value
>>> import torch
>>> from torchmetrics.audio import SpeechReverberationModulationEnergyRatio
>>> metric = SpeechReverberationModulationEnergyRatio(8000)
>>> metric.update(torch.rand(8000))
>>> fig_, ax_ = metric.plot()
_images/speech_reverberation_modulation_energy_ratio-1.png
>>> # Example plotting multiple values
>>> import torch
>>> from torchmetrics.audio import SpeechReverberationModulationEnergyRatio
>>> metric = SpeechReverberationModulationEnergyRatio(8000)
>>> values = [ ]
>>> for _ in range(10):
...     values.append(metric(torch.rand(8000)))
>>> fig_, ax_ = metric.plot(values)
_images/speech_reverberation_modulation_energy_ratio-2.png

Functional Interface

torchmetrics.functional.audio.srmr.speech_reverberation_modulation_energy_ratio(preds, fs, n_cochlear_filters=23, low_freq=125, min_cf=4, max_cf=None, norm=False, fast=False)[source]

Calculate Speech-to-Reverberation Modulation Energy Ratio (SRMR).

SRMR is a non-intrusive metric for speech quality and intelligibility based on a modulation spectral representation of the speech signal. This code is translated from SRMRToolbox and SRMRpy.

Parameters:
  • preds (Tensor) – shape (..., time)

  • fs (int) – the sampling rate

  • n_cochlear_filters (int) – Number of filters in the acoustic filterbank

  • low_freq (float) – determines the frequency cutoff for the corresponding gammatone filterbank.

  • min_cf (float) – Center frequency in Hz of the first modulation filter.

  • max_cf (Optional[float]) – Center frequency in Hz of the last modulation filter. If None is given, then 128 Hz will be used for norm==False and 30 Hz for norm==True.

  • norm (bool) – Use modulation spectrum energy normalization

  • fast (bool) – Use the faster version based on the gammatonegram. Note: this argument is inherited from SRMRpy. As the translated code is based on PyTorch, setting fast=True may actually slow down the calculation of this metric on GPU.

Note

Using this metric requires you to have gammatone and torchaudio installed. Either install as pip install torchmetrics[audio] or pip install torchaudio and pip install git+https://github.com/detly/gammatone.

Note

This implementation is experimental and might not be consistent with the MATLAB implementation SRMRToolbox, especially the fast implementation. The slow versions, a) fast=False, norm=False, max_cf=128, b) fast=False, norm=True, max_cf=30, have a relatively small inconsistency.

Return type:

Tensor

Returns:

Scalar tensor with srmr value with shape (...)

Raises:

ModuleNotFoundError – If gammatone or torchaudio package is not installed

Example

>>> import torch
>>> from torchmetrics.functional.audio import speech_reverberation_modulation_energy_ratio
>>> g = torch.manual_seed(1)
>>> preds = torch.randn(8000)
>>> speech_reverberation_modulation_energy_ratio(preds, 8000)
tensor([0.3354], dtype=torch.float64)

Accuracy

Module Interface

class torchmetrics.Accuracy(**kwargs)[source]

Compute Accuracy.

\[\text{Accuracy} = \frac{1}{N}\sum_i^N 1(y_i = \hat{y}_i)\]

Where \(y\) is a tensor of target values, and \(\hat{y}\) is a tensor of predictions.

This module is a simple wrapper to get the task specific versions of this metric, which is done by setting the task argument to either 'binary', 'multiclass' or 'multilabel'. See the documentation of BinaryAccuracy, MulticlassAccuracy and MultilabelAccuracy for the specific details of each argument's influence and examples.
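Because the wrapper dispatches on the task argument at construction time, the object it returns is an instance of the corresponding task-specific class; a small sketch:

>>> from torchmetrics import Accuracy
>>> from torchmetrics.classification import BinaryAccuracy, MulticlassAccuracy
>>> isinstance(Accuracy(task="binary"), BinaryAccuracy)
True
>>> isinstance(Accuracy(task="multiclass", num_classes=4), MulticlassAccuracy)
True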

Legacy Example:
>>> from torch import tensor
>>> target = tensor([0, 1, 2, 3])
>>> preds = tensor([0, 2, 1, 3])
>>> accuracy = Accuracy(task="multiclass", num_classes=4)
>>> accuracy(preds, target)
tensor(0.5000)
>>> target = tensor([0, 1, 2])
>>> preds = tensor([[0.1, 0.9, 0], [0.3, 0.1, 0.6], [0.2, 0.5, 0.3]])
>>> accuracy = Accuracy(task="multiclass", num_classes=3, top_k=2)
>>> accuracy(preds, target)
tensor(0.6667)
static __new__(cls, task, threshold=0.5, num_classes=None, num_labels=None, average='micro', multidim_average='global', top_k=1, ignore_index=None, validate_args=True, **kwargs)[source]

Initialize task metric.

Return type:

Metric

BinaryAccuracy
class torchmetrics.classification.BinaryAccuracy(threshold=0.5, multidim_average='global', ignore_index=None, validate_args=True, **kwargs)[source]

Compute Accuracy for binary tasks.

\[\text{Accuracy} = \frac{1}{N}\sum_i^N 1(y_i = \hat{y}_i)\]

Where \(y\) is a tensor of target values, and \(\hat{y}\) is a tensor of predictions.

As input to forward and update the metric accepts the following input:

  • preds (Tensor): An int or float tensor of shape (N, ...). If preds is a floating point tensor with values outside [0,1] range we consider the input to be logits and will auto apply sigmoid per element. Additionally, we convert to int tensor with thresholding using the value in threshold.

  • target (Tensor): An int tensor of shape (N, ...)

As output to forward and compute the metric returns the following output:

  • ba (Tensor): If multidim_average is set to global, metric returns a scalar value. If multidim_average is set to samplewise, the metric returns (N,) vector consisting of a scalar value per sample.

Additional dimension ... will be flattened into the batch dimension.

Parameters:
  • threshold (float) – Threshold for transforming probability to binary {0,1} predictions

  • multidim_average (Literal['global', 'samplewise']) –

    Defines how additional dimensions ... should be handled. Should be one of the following:

    • global: Additional dimensions are flattened along the batch dimension

    • samplewise: Statistic will be calculated independently for each sample on the N axis. The statistics in this case are calculated over the additional dimensions.

  • ignore_index (Optional[int]) – Specifies a target value that is ignored and does not contribute to the metric calculation

  • validate_args (bool) – bool indicating if input arguments and tensors should be validated for correctness. Set to False for faster computations.

Example (preds is int tensor):
>>> from torch import tensor
>>> from torchmetrics.classification import BinaryAccuracy
>>> target = tensor([0, 1, 0, 1, 0, 1])
>>> preds = tensor([0, 0, 1, 1, 0, 1])
>>> metric = BinaryAccuracy()
>>> metric(preds, target)
tensor(0.6667)
Example (preds is float tensor):
>>> from torchmetrics.classification import BinaryAccuracy
>>> target = tensor([0, 1, 0, 1, 0, 1])
>>> preds = tensor([0.11, 0.22, 0.84, 0.73, 0.33, 0.92])
>>> metric = BinaryAccuracy()
>>> metric(preds, target)
tensor(0.6667)
Example (multidim tensors):
>>> from torchmetrics.classification import BinaryAccuracy
>>> target = tensor([[[0, 1], [1, 0], [0, 1]], [[1, 1], [0, 0], [1, 0]]])
>>> preds = tensor([[[0.59, 0.91], [0.91, 0.99], [0.63, 0.04]],
...                 [[0.38, 0.04], [0.86, 0.780], [0.45, 0.37]]])
>>> metric = BinaryAccuracy(multidim_average='samplewise')
>>> metric(preds, target)
tensor([0.3333, 0.1667])
plot(val=None, ax=None)[source]

Plot a single or multiple values from the metric.

Parameters:
  • val (Union[Tensor, Sequence[Tensor], None]) – Either a single result from calling metric.forward or metric.compute or a list of these results. If no value is provided, will automatically call metric.compute and plot that result.

  • ax (Optional[Axes]) – A matplotlib axis object. If provided, will add the plot to that axis

Return type:

Tuple[Figure, Union[Axes, ndarray]]

Returns:

Figure object and Axes object

Raises:

ModuleNotFoundError – If matplotlib is not installed

>>> from torch import rand, randint
>>> # Example plotting a single value
>>> from torchmetrics.classification import BinaryAccuracy
>>> metric = BinaryAccuracy()
>>> metric.update(rand(10), randint(2,(10,)))
>>> fig_, ax_ = metric.plot()
_images/accuracy-1.png
>>> from torch import rand, randint
>>> # Example plotting multiple values
>>> from torchmetrics.classification import BinaryAccuracy
>>> metric = BinaryAccuracy()
>>> values = [ ]
>>> for _ in range(10):
...     values.append(metric(rand(10), randint(2,(10,))))
>>> fig_, ax_ = metric.plot(values)
_images/accuracy-2.png
MulticlassAccuracy
class torchmetrics.classification.MulticlassAccuracy(num_classes, top_k=1, average='macro', multidim_average='global', ignore_index=None, validate_args=True, **kwargs)[source]

Compute Accuracy for multiclass tasks.

\[\text{Accuracy} = \frac{1}{N}\sum_i^N 1(y_i = \hat{y}_i)\]

Where \(y\) is a tensor of target values, and \(\hat{y}\) is a tensor of predictions.

As input to forward and update the metric accepts the following input:

  • preds (Tensor): An int tensor of shape (N, ...) or float tensor of shape (N, C, ...). If preds is a floating point we apply torch.argmax along the C dimension to automatically convert probabilities/logits into an int tensor.

  • target (Tensor): An int tensor of shape (N, ...)

As output to forward and compute the metric returns the following output:

  • mca (Tensor): A tensor with the accuracy score whose returned shape depends on the average and multidim_average arguments:

    • If multidim_average is set to global:

      • If average='micro'/'macro'/'weighted', the output will be a scalar tensor

      • If average=None/'none', the shape will be (C,)

    • If multidim_average is set to samplewise:

      • If average='micro'/'macro'/'weighted', the shape will be (N,)

      • If average=None/'none', the shape will be (N, C)

Parameters:
  • num_classes (int) – Integer specifying the number of classes

  • average (Optional[Literal['micro', 'macro', 'weighted', 'none']]) –

    Defines the reduction that is applied over labels. Should be one of the following:

    • micro: Sum statistics over all labels

    • macro: Calculate statistics for each label and average them

    • weighted: calculates statistics for each label and computes weighted average using their support

    • "none" or None: calculates statistic for each label and applies no reduction

  • top_k (int) – Number of highest probability or logit score predictions considered to find the correct label. Only works when preds contain probabilities/logits.

  • multidim_average (Literal['global', 'samplewise']) –

    Defines how additional dimensions ... should be handled. Should be one of the following:

    • global: Additional dimensions are flattened along the batch dimension

    • samplewise: Statistic will be calculated independently for each sample on the N axis. The statistics in this case are calculated over the additional dimensions.

  • ignore_index (Optional[int]) – Specifies a target value that is ignored and does not contribute to the metric calculation

  • validate_args (bool) – bool indicating if input arguments and tensors should be validated for correctness. Set to False for faster computations.

Example (preds is int tensor):
>>> from torch import tensor
>>> from torchmetrics.classification import MulticlassAccuracy
>>> target = tensor([2, 1, 0, 0])
>>> preds = tensor([2, 1, 0, 1])
>>> metric = MulticlassAccuracy(num_classes=3)
>>> metric(preds, target)
tensor(0.8333)
>>> mca = MulticlassAccuracy(num_classes=3, average=None)
>>> mca(preds, target)
tensor([0.5000, 1.0000, 1.0000])
Example (preds is float tensor):
>>> from torchmetrics.classification import MulticlassAccuracy
>>> target = tensor([2, 1, 0, 0])
>>> preds = tensor([[0.16, 0.26, 0.58],
...                 [0.22, 0.61, 0.17],
...                 [0.71, 0.09, 0.20],
...                 [0.05, 0.82, 0.13]])
>>> metric = MulticlassAccuracy(num_classes=3)
>>> metric(preds, target)
tensor(0.8333)
>>> mca = MulticlassAccuracy(num_classes=3, average=None)
>>> mca(preds, target)
tensor([0.5000, 1.0000, 1.0000])
Example (multidim tensors):
>>> from torchmetrics.classification import MulticlassAccuracy
>>> target = tensor([[[0, 1], [2, 1], [0, 2]], [[1, 1], [2, 0], [1, 2]]])
>>> preds = tensor([[[0, 2], [2, 0], [0, 1]], [[2, 2], [2, 1], [1, 0]]])
>>> metric = MulticlassAccuracy(num_classes=3, multidim_average='samplewise')
>>> metric(preds, target)
tensor([0.5000, 0.2778])
>>> mca = MulticlassAccuracy(num_classes=3, multidim_average='samplewise', average=None)
>>> mca(preds, target)
tensor([[1.0000, 0.0000, 0.5000],
        [0.0000, 0.3333, 0.5000]])
plot(val=None, ax=None)[source]

Plot a single or multiple values from the metric.

Parameters:
  • val (Union[Tensor, Sequence[Tensor], None]) – Either a single result from calling metric.forward or metric.compute or a list of these results. If no value is provided, will automatically call metric.compute and plot that result.

  • ax (Optional[Axes]) – A matplotlib axis object. If provided, will add the plot to that axis

Return type:

Tuple[Figure, Union[Axes, ndarray]]

Returns:

Figure object and Axes object

Raises:

ModuleNotFoundError – If matplotlib is not installed

>>> from torch import randint
>>> # Example plotting a single value per class
>>> from torchmetrics.classification import MulticlassAccuracy
>>> metric = MulticlassAccuracy(num_classes=3, average=None)
>>> metric.update(randint(3, (20,)), randint(3, (20,)))
>>> fig_, ax_ = metric.plot()
_images/accuracy-3.png
>>> from torch import randint
>>> # Example plotting a multiple values per class
>>> from torchmetrics.classification import MulticlassAccuracy
>>> metric = MulticlassAccuracy(num_classes=3, average=None)
>>> values = []
>>> for _ in range(20):
...     values.append(metric(randint(3, (20,)), randint(3, (20,))))
>>> fig_, ax_ = metric.plot(values)
_images/accuracy-4.png
MultilabelAccuracy
class torchmetrics.classification.MultilabelAccuracy(num_labels, threshold=0.5, average='macro', multidim_average='global', ignore_index=None, validate_args=True, **kwargs)[source]

Compute Accuracy for multilabel tasks.

\[\text{Accuracy} = \frac{1}{N}\sum_i^N 1(y_i = \hat{y}_i)\]

Where \(y\) is a tensor of target values, and \(\hat{y}\) is a tensor of predictions.

As input to forward and update the metric accepts the following input:

  • preds (Tensor): An int or float tensor of shape (N, C, ...). If preds is a floating point tensor with values outside [0,1] range we consider the input to be logits and will auto apply sigmoid per element. Additionally, we convert to int tensor with thresholding using the value in threshold.

  • target (Tensor): An int tensor of shape (N, C, ...)

As output to forward and compute the metric returns the following output:

  • mla (Tensor): A tensor with the accuracy score whose returned shape depends on the average and multidim_average arguments:

    • If multidim_average is set to global:

      • If average='micro'/'macro'/'weighted', the output will be a scalar tensor

      • If average=None/'none', the shape will be (C,)

    • If multidim_average is set to samplewise:

      • If average='micro'/'macro'/'weighted', the shape will be (N,)

      • If average=None/'none', the shape will be (N, C)

Parameters:
  • num_labels (int) – Integer specifying the number of labels

  • threshold (float) – Threshold for transforming probability to binary (0,1) predictions

  • average (Optional[Literal['micro', 'macro', 'weighted', 'none']]) –

    Defines the reduction that is applied over labels. Should be one of the following:

    • micro: Sum statistics over all labels

    • macro: Calculate statistics for each label and average them

    • weighted: calculates statistics for each label and computes weighted average using their support

    • "none" or None: calculates statistic for each label and applies no reduction

  • multidim_average (Literal['global', 'samplewise']) –

    Defines how additional dimensions ... should be handled. Should be one of the following:

    • global: Additional dimensions are flattened along the batch dimension

    • samplewise: Statistic will be calculated independently for each sample on the N axis. The statistics in this case are calculated over the additional dimensions.

  • ignore_index (Optional[int]) – Specifies a target value that is ignored and does not contribute to the metric calculation

  • validate_args (bool) – bool indicating if input arguments and tensors should be validated for correctness. Set to False for faster computations.

Example (preds is int tensor):
>>> from torch import tensor
>>> from torchmetrics.classification import MultilabelAccuracy
>>> target = tensor([[0, 1, 0], [1, 0, 1]])
>>> preds = tensor([[0, 0, 1], [1, 0, 1]])
>>> metric = MultilabelAccuracy(num_labels=3)
>>> metric(preds, target)
tensor(0.6667)
>>> mla = MultilabelAccuracy(num_labels=3, average=None)
>>> mla(preds, target)
tensor([1.0000, 0.5000, 0.5000])
Example (preds is float tensor):
>>> from torchmetrics.classification import MultilabelAccuracy
>>> target = tensor([[0, 1, 0], [1, 0, 1]])
>>> preds = tensor([[0.11, 0.22, 0.84], [0.73, 0.33, 0.92]])
>>> metric = MultilabelAccuracy(num_labels=3)
>>> metric(preds, target)
tensor(0.6667)
>>> mla = MultilabelAccuracy(num_labels=3, average=None)
>>> mla(preds, target)
tensor([1.0000, 0.5000, 0.5000])
Example (multidim tensors):
>>> from torchmetrics.classification import MultilabelAccuracy
>>> target = tensor([[[0, 1], [1, 0], [0, 1]], [[1, 1], [0, 0], [1, 0]]])
>>> preds = tensor(
...     [
...         [[0.59, 0.91], [0.91, 0.99], [0.63, 0.04]],
...         [[0.38, 0.04], [0.86, 0.780], [0.45, 0.37]],
...     ]
... )
>>> mla = MultilabelAccuracy(num_labels=3, multidim_average='samplewise')
>>> mla(preds, target)
tensor([0.3333, 0.1667])
>>> mla = MultilabelAccuracy(num_labels=3, multidim_average='samplewise', average=None)
>>> mla(preds, target)
tensor([[0.5000, 0.5000, 0.0000],
        [0.0000, 0.0000, 0.5000]])
plot(val=None, ax=None)[source]

Plot a single or multiple values from the metric.

Parameters:
  • val (Union[Tensor, Sequence[Tensor], None]) – Either a single result from calling metric.forward or metric.compute or a list of these results. If no value is provided, will automatically call metric.compute and plot that result.

  • ax (Optional[Axes]) – A matplotlib axis object. If provided, will add the plot to that axis

Return type:

Tuple[Figure, Union[Axes, ndarray]]

Returns:

Figure and Axes object

Raises:

ModuleNotFoundError – If matplotlib is not installed

>>> from torch import rand, randint
>>> # Example plotting a single value
>>> from torchmetrics.classification import MultilabelAccuracy
>>> metric = MultilabelAccuracy(num_labels=3)
>>> metric.update(randint(2, (20, 3)), randint(2, (20, 3)))
>>> fig_, ax_ = metric.plot()
_images/accuracy-5.png
>>> from torch import rand, randint
>>> # Example plotting multiple values
>>> from torchmetrics.classification import MultilabelAccuracy
>>> metric = MultilabelAccuracy(num_labels=3)
>>> values = [ ]
>>> for _ in range(10):
...     values.append(metric(randint(2, (20, 3)), randint(2, (20, 3))))
>>> fig_, ax_ = metric.plot(values)
_images/accuracy-6.png

Functional Interface

torchmetrics.functional.classification.accuracy(preds, target, task, threshold=0.5, num_classes=None, num_labels=None, average='micro', multidim_average='global', top_k=1, ignore_index=None, validate_args=True)[source]

Compute Accuracy.

\[\text{Accuracy} = \frac{1}{N}\sum_i^N 1(y_i = \hat{y}_i)\]

Where \(y\) is a tensor of target values, and \(\hat{y}\) is a tensor of predictions.

This function is a simple wrapper to get the task specific versions of this metric, which is done by setting the task argument to either 'binary', 'multiclass' or 'multilabel'. See the documentation of binary_accuracy(), multiclass_accuracy() and multilabel_accuracy() for the specific details of each argument's influence and examples.

Legacy Example:
>>> from torch import tensor
>>> target = tensor([0, 1, 2, 3])
>>> preds = tensor([0, 2, 1, 3])
>>> accuracy(preds, target, task="multiclass", num_classes=4)
tensor(0.5000)
>>> target = tensor([0, 1, 2])
>>> preds = tensor([[0.1, 0.9, 0], [0.3, 0.1, 0.6], [0.2, 0.5, 0.3]])
>>> accuracy(preds, target, task="multiclass", num_classes=3, top_k=2)
tensor(0.6667)
binary_accuracy
torchmetrics.functional.classification.binary_accuracy(preds, target, threshold=0.5, multidim_average='global', ignore_index=None, validate_args=True)[source]

Compute Accuracy for binary tasks.

\[\text{Accuracy} = \frac{1}{N}\sum_i^N 1(y_i = \hat{y}_i)\]

Where \(y\) is a tensor of target values, and \(\hat{y}\) is a tensor of predictions.

Accepts the following input tensors:

  • preds (int or float tensor): (N, ...). If preds is a floating point tensor with values outside [0,1] range we consider the input to be logits and will auto apply sigmoid per element. Additionally, we convert to int tensor with thresholding using the value in threshold.

  • target (int tensor): (N, ...)

Parameters:
  • preds (Tensor) – Tensor with predictions

  • target (Tensor) – Tensor with true labels

  • threshold (float) – Threshold for transforming probability to binary {0,1} predictions

  • multidim_average (Literal['global', 'samplewise']) –

    Defines how additional dimensions ... should be handled. Should be one of the following:

    • global: Additional dimensions are flattened along the batch dimension

    • samplewise: Statistic will be calculated independently for each sample on the N axis. The statistics in this case are calculated over the additional dimensions.

  • ignore_index (Optional[int]) – Specifies a target value that is ignored and does not contribute to the metric calculation

  • validate_args (bool) – bool indicating if input arguments and tensors should be validated for correctness. Set to False for faster computations.

Return type:

Tensor

Returns:

If multidim_average is set to global, the metric returns a scalar value. If multidim_average is set to samplewise, the metric returns (N,) vector consisting of a scalar value per sample.

Example (preds is int tensor):
>>> from torch import tensor
>>> from torchmetrics.functional.classification import binary_accuracy
>>> target = tensor([0, 1, 0, 1, 0, 1])
>>> preds = tensor([0, 0, 1, 1, 0, 1])
>>> binary_accuracy(preds, target)
tensor(0.6667)
Example (preds is float tensor):
>>> from torchmetrics.functional.classification import binary_accuracy
>>> target = tensor([0, 1, 0, 1, 0, 1])
>>> preds = tensor([0.11, 0.22, 0.84, 0.73, 0.33, 0.92])
>>> binary_accuracy(preds, target)
tensor(0.6667)
Example (multidim tensors):
>>> from torchmetrics.functional.classification import binary_accuracy
>>> target = tensor([[[0, 1], [1, 0], [0, 1]], [[1, 1], [0, 0], [1, 0]]])
>>> preds = tensor([[[0.59, 0.91], [0.91, 0.99], [0.63, 0.04]],
...                 [[0.38, 0.04], [0.86, 0.780], [0.45, 0.37]]])
>>> binary_accuracy(preds, target, multidim_average='samplewise')
tensor([0.3333, 0.1667])
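Since preds values outside the [0,1] range are treated as logits (sigmoid is applied automatically), passing raw logits or their explicit sigmoid gives the same result. A minimal sketch with hand-picked values (expected outputs computed by hand):

>>> import torch
>>> from torchmetrics.functional.classification import binary_accuracy
>>> target = torch.tensor([0, 1, 1, 0])
>>> logits = torch.tensor([-2.0, 1.5, 0.3, -0.1])  # values outside [0, 1] are treated as logits
>>> binary_accuracy(logits, target)
tensor(1.)
>>> binary_accuracy(torch.sigmoid(logits), target)  # explicit probabilities, same result
tensor(1.)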
multiclass_accuracy
torchmetrics.functional.classification.multiclass_accuracy(preds, target, num_classes, average='macro', top_k=1, multidim_average='global', ignore_index=None, validate_args=True)[source]

Compute Accuracy for multiclass tasks.

\[\text{Accuracy} = \frac{1}{N}\sum_i^N 1(y_i = \hat{y}_i)\]

Where \(y\) is a tensor of target values, and \(\hat{y}\) is a tensor of predictions.

Accepts the following input tensors:

  • preds: (N, ...) (int tensor) or (N, C, ...) (float tensor). If preds is a floating point we apply torch.argmax along the C dimension to automatically convert probabilities/logits into an int tensor.

  • target (int tensor): (N, ...)

Parameters:
  • preds (Tensor) – Tensor with predictions

  • target (Tensor) – Tensor with true labels

  • num_classes (int) – Integer specifying the number of classes

  • average (Optional[Literal['micro', 'macro', 'weighted', 'none']]) –

    Defines the reduction that is applied over labels. Should be one of the following:

    • micro: Sum statistics over all labels

    • macro: Calculate statistics for each label and average them

    • weighted: calculates statistics for each label and computes weighted average using their support

    • "none" or None: calculates statistic for each label and applies no reduction

  • top_k (int) – Number of highest probability or logit score predictions considered to find the correct label. Only works when preds contain probabilities/logits.

  • multidim_average (Literal['global', 'samplewise']) –

    Defines how additional dimensions ... should be handled. Should be one of the following:

    • global: Additional dimensions are flattened along the batch dimension

    • samplewise: Statistic will be calculated independently for each sample on the N axis. The statistics in this case are calculated over the additional dimensions.

  • ignore_index (Optional[int]) – Specifies a target value that is ignored and does not contribute to the metric calculation

  • validate_args (bool) – bool indicating if input arguments and tensors should be validated for correctness. Set to False for faster computations.

Returns:

  • If multidim_average is set to global:

    • If average='micro'/'macro'/'weighted', the output will be a scalar tensor

    • If average=None/'none', the shape will be (C,)

  • If multidim_average is set to samplewise:

    • If average='micro'/'macro'/'weighted', the shape will be (N,)

    • If average=None/'none', the shape will be (N, C)

Return type:

The returned shape depends on the average and multidim_average arguments

Example (preds is int tensor):
>>> from torch import tensor
>>> from torchmetrics.functional.classification import multiclass_accuracy
>>> target = tensor([2, 1, 0, 0])
>>> preds = tensor([2, 1, 0, 1])
>>> multiclass_accuracy(preds, target, num_classes=3)
tensor(0.8333)
>>> multiclass_accuracy(preds, target, num_classes=3, average=None)
tensor([0.5000, 1.0000, 1.0000])
Example (preds is float tensor):
>>> from torchmetrics.functional.classification import multiclass_accuracy
>>> target = tensor([2, 1, 0, 0])
>>> preds = tensor([[0.16, 0.26, 0.58],
...                 [0.22, 0.61, 0.17],
...                 [0.71, 0.09, 0.20],
...                 [0.05, 0.82, 0.13]])
>>> multiclass_accuracy(preds, target, num_classes=3)
tensor(0.8333)
>>> multiclass_accuracy(preds, target, num_classes=3, average=None)
tensor([0.5000, 1.0000, 1.0000])
Example (multidim tensors):
>>> from torchmetrics.functional.classification import multiclass_accuracy
>>> target = tensor([[[0, 1], [2, 1], [0, 2]], [[1, 1], [2, 0], [1, 2]]])
>>> preds = tensor([[[0, 2], [2, 0], [0, 1]], [[2, 2], [2, 1], [1, 0]]])
>>> multiclass_accuracy(preds, target, num_classes=3, multidim_average='samplewise')
tensor([0.5000, 0.2778])
>>> multiclass_accuracy(preds, target, num_classes=3, multidim_average='samplewise', average=None)
tensor([[1.0000, 0.0000, 0.5000],
        [0.0000, 0.3333, 0.5000]])
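The top_k argument described above changes what counts as a correct prediction when preds contain probabilities or logits. A minimal sketch (values chosen for illustration, expected outputs computed by hand):

>>> from torch import tensor
>>> from torchmetrics.functional.classification import multiclass_accuracy
>>> target = tensor([0, 1, 2])
>>> preds = tensor([[0.1, 0.9, 0.0],
...                 [0.3, 0.1, 0.6],
...                 [0.2, 0.5, 0.3]])
>>> multiclass_accuracy(preds, target, num_classes=3, average="micro")
tensor(0.)
>>> multiclass_accuracy(preds, target, num_classes=3, average="micro", top_k=2)
tensor(0.6667)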
multilabel_accuracy
torchmetrics.functional.classification.multilabel_accuracy(preds, target, num_labels, threshold=0.5, average='macro', multidim_average='global', ignore_index=None, validate_args=True)[source]

Compute Accuracy for multilabel tasks.

\[\text{Accuracy} = \frac{1}{N}\sum_i^N 1(y_i = \hat{y}_i)\]

Where \(y\) is a tensor of target values, and \(\hat{y}\) is a tensor of predictions.

Accepts the following input tensors:

  • preds (int or float tensor): (N, C, ...). If preds is a floating point tensor with values outside [0,1] range we consider the input to be logits and will auto apply sigmoid per element. Additionally, we convert to int tensor with thresholding using the value in threshold.

  • target (int tensor): (N, C, ...)

Parameters:
  • preds (Tensor) – Tensor with predictions

  • target (Tensor) – Tensor with true labels

  • num_labels (int) – Integer specifying the number of labels

  • threshold (float) – Threshold for transforming probability to binary (0,1) predictions

  • average (Optional[Literal['micro', 'macro', 'weighted', 'none']]) –

    Defines the reduction that is applied over labels. Should be one of the following:

    • micro: Sum statistics over all labels

    • macro: Calculate statistics for each label and average them

    • weighted: calculates statistics for each label and computes weighted average using their support

    • "none" or None: calculates statistic for each label and applies no reduction

  • multidim_average (Literal['global', 'samplewise']) –

    Defines how additional dimensions ... should be handled. Should be one of the following:

    • global: Additional dimensions are flattened along the batch dimension

    • samplewise: Statistic will be calculated independently for each sample on the N axis. The statistics in this case are calculated over the additional dimensions.

  • ignore_index (Optional[int]) – Specifies a target value that is ignored and does not contribute to the metric calculation

  • validate_args (bool) – bool indicating if input arguments and tensors should be validated for correctness. Set to False for faster computations.

Returns:

  • If multidim_average is set to global:

    • If average='micro'/'macro'/'weighted', the output will be a scalar tensor

    • If average=None/'none', the shape will be (C,)

  • If multidim_average is set to samplewise:

    • If average='micro'/'macro'/'weighted', the shape will be (N,)

    • If average=None/'none', the shape will be (N, C)

Return type:

The returned shape depends on the average and multidim_average arguments

Example (preds is int tensor):
>>> from torch import tensor
>>> from torchmetrics.functional.classification import multilabel_accuracy
>>> target = tensor([[0, 1, 0], [1, 0, 1]])
>>> preds = tensor([[0, 0, 1], [1, 0, 1]])
>>> multilabel_accuracy(preds, target, num_labels=3)
tensor(0.6667)
>>> multilabel_accuracy(preds, target, num_labels=3, average=None)
tensor([1.0000, 0.5000, 0.5000])
Example (preds is float tensor):
>>> from torchmetrics.functional.classification import multilabel_accuracy
>>> target = tensor([[0, 1, 0], [1, 0, 1]])
>>> preds = tensor([[0.11, 0.22, 0.84], [0.73, 0.33, 0.92]])
>>> multilabel_accuracy(preds, target, num_labels=3)
tensor(0.6667)
>>> multilabel_accuracy(preds, target, num_labels=3, average=None)
tensor([1.0000, 0.5000, 0.5000])
Example (multidim tensors):
>>> from torchmetrics.functional.classification import multilabel_accuracy
>>> target = tensor([[[0, 1], [1, 0], [0, 1]], [[1, 1], [0, 0], [1, 0]]])
>>> preds = tensor([[[0.59, 0.91], [0.91, 0.99], [0.63, 0.04]],
...                 [[0.38, 0.04], [0.86, 0.780], [0.45, 0.37]]])
>>> multilabel_accuracy(preds, target, num_labels=3, multidim_average='samplewise')
tensor([0.3333, 0.1667])
>>> multilabel_accuracy(preds, target, num_labels=3, multidim_average='samplewise', average=None)
tensor([[0.5000, 0.5000, 0.0000],
        [0.0000, 0.0000, 0.5000]])

AUROC

Module Interface

class torchmetrics.AUROC(**kwargs)[source]

Compute Area Under the Receiver Operating Characteristic Curve (ROC AUC).

The AUROC score summarizes the ROC curve into a single number that describes the performance of a model for multiple thresholds at the same time. Notably, an AUROC score of 1 is a perfect score and an AUROC score of 0.5 corresponds to random guessing.

This module is a simple wrapper to get the task specific versions of this metric, which is done by setting the task argument to either 'binary', 'multiclass' or 'multilabel'. See the documentation of BinaryAUROC, MulticlassAUROC and MultilabelAUROC for the specific details of each argument's influence and examples.

Legacy Example:
>>> from torch import tensor
>>> preds = tensor([0.13, 0.26, 0.08, 0.19, 0.34])
>>> target = tensor([0, 0, 1, 1, 1])
>>> auroc = AUROC(task="binary")
>>> auroc(preds, target)
tensor(0.5000)
>>> preds = tensor([[0.90, 0.05, 0.05],
...                       [0.05, 0.90, 0.05],
...                       [0.05, 0.05, 0.90],
...                       [0.85, 0.05, 0.10],
...                       [0.10, 0.10, 0.80]])
>>> target = tensor([0, 1, 1, 2, 2])
>>> auroc = AUROC(task="multiclass", num_classes=3)
>>> auroc(preds, target)
tensor(0.7778)
static __new__(cls, task, thresholds=None, num_classes=None, num_labels=None, average='macro', max_fpr=None, ignore_index=None, validate_args=True, **kwargs)[source]

Initialize task metric.

Return type:

Metric

BinaryAUROC
class torchmetrics.classification.BinaryAUROC(max_fpr=None, thresholds=None, ignore_index=None, validate_args=True, **kwargs)[source]

Compute Area Under the Receiver Operating Characteristic Curve (ROC AUC) for binary tasks.

The AUROC score summarizes the ROC curve into a single number that describes the performance of a model for multiple thresholds at the same time. Notably, an AUROC score of 1 is a perfect score and an AUROC score of 0.5 corresponds to random guessing.

As input to forward and update the metric accepts the following input:

  • preds (Tensor): A float tensor of shape (N, ...) containing probabilities or logits for each observation. If preds has values outside [0,1] range we consider the input to be logits and will auto apply sigmoid per element.

  • target (Tensor): An int tensor of shape (N, ...) containing ground truth labels, and therefore only contain {0,1} values (except if ignore_index is specified). The value 1 always encodes the positive class.

As output to forward and compute the metric returns the following output:

  • b_auroc (Tensor): A single scalar with the auroc score.

Additional dimension ... will be flattened into the batch dimension.

The implementation supports calculating the metric in both a non-binned but accurate version and a binned version that is less accurate but more memory efficient. Setting the thresholds argument to None will activate the non-binned version that uses memory of size \(\mathcal{O}(n_{samples})\) whereas setting the thresholds argument to either an integer, list or a 1d tensor will use a binned version that uses memory of size \(\mathcal{O}(n_{thresholds})\) (constant memory).

Parameters:
  • max_fpr (Optional[float]) – If not None, calculates standardized partial AUC over the range [0, max_fpr].

  • thresholds (Union[int, List[float], Tensor, None]) –

    Can be one of:

    • If set to None, will use a non-binned approach where thresholds are dynamically calculated from all the data. Most accurate but also most memory consuming approach.

    • If set to an int (larger than 1), will use that number of thresholds linearly spaced from 0 to 1 as bins for the calculation.

    • If set to a list of floats, will use the indicated thresholds in the list as bins for the calculation

    • If set to a 1d tensor of floats, will use the indicated thresholds in the tensor as bins for the calculation.

  • validate_args (bool) – bool indicating if input arguments and tensors should be validated for correctness. Set to False for faster computations.

  • kwargs (Any) – Additional keyword arguments, see Advanced metric settings for more info.

Example

>>> from torch import tensor
>>> from torchmetrics.classification import BinaryAUROC
>>> preds = tensor([0, 0.5, 0.7, 0.8])
>>> target = tensor([0, 1, 1, 0])
>>> metric = BinaryAUROC(thresholds=None)
>>> metric(preds, target)
tensor(0.5000)
>>> b_auroc = BinaryAUROC(thresholds=5)
>>> b_auroc(preds, target)
tensor(0.5000)
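As noted above, the thresholds argument also accepts an explicit 1d tensor of bin edges, which keeps memory constant regardless of the number of samples; a usage sketch:

>>> import torch
>>> from torchmetrics.classification import BinaryAUROC
>>> metric = BinaryAUROC(thresholds=torch.linspace(0, 1, steps=101))
>>> metric.update(torch.rand(1000), torch.randint(2, (1000,)))
>>> score = metric.compute()  # scalar tensor with the (binned, approximate) AUROC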
plot(val=None, ax=None)[source]

Plot a single or multiple values from the metric.

Parameters:
  • val (Union[Tensor, Sequence[Tensor], None]) – Either a single result from calling metric.forward or metric.compute or a list of these results. If no value is provided, will automatically call metric.compute and plot that result.

  • ax (Optional[Axes]) – A matplotlib axis object. If provided, will add the plot to that axis

Return type:

Tuple[Figure, Union[Axes, ndarray]]

Returns:

Figure and Axes object

Raises:

ModuleNotFoundError – If matplotlib is not installed

>>> # Example plotting a single
>>> import torch
>>> from torchmetrics.classification import BinaryAUROC
>>> metric = BinaryAUROC()
>>> metric.update(torch.rand(20,), torch.randint(2, (20,)))
>>> fig_, ax_ = metric.plot()
_images/auroc-1.png
>>> # Example plotting multiple values
>>> import torch
>>> from torchmetrics.classification import BinaryAUROC
>>> metric = BinaryAUROC()
>>> values = [ ]
>>> for _ in range(10):
...     values.append(metric(torch.rand(20,), torch.randint(2, (20,))))
>>> fig_, ax_ = metric.plot(values)
_images/auroc-2.png
MulticlassAUROC
class torchmetrics.classification.MulticlassAUROC(num_classes, average='macro', thresholds=None, ignore_index=None, validate_args=True, **kwargs)[source]

Compute Area Under the Receiver Operating Characteristic Curve (ROC AUC) for multiclass tasks.

The AUROC score summarizes the ROC curve into a single number that describes the performance of a model for multiple thresholds at the same time. Notably, an AUROC score of 1 is a perfect score and an AUROC score of 0.5 corresponds to random guessing.

For multiclass the metric is calculated by iteratively treating each class as the positive class and all other classes as the negative, which is referred to as the one-vs-rest approach. One-vs-one is currently not supported by this metric. By default the reported metric is then the average over all classes, but this behavior can be changed by setting the average argument.
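For illustration, the one-vs-rest behaviour can be reproduced by applying the binary metric per class; this is only a sketch using the public functional API, not the internal implementation:

>>> import torch
>>> from torchmetrics.functional.classification import binary_auroc, multiclass_auroc
>>> preds = torch.tensor([[0.90, 0.05, 0.05],
...                       [0.05, 0.90, 0.05],
...                       [0.05, 0.05, 0.90],
...                       [0.85, 0.05, 0.10],
...                       [0.10, 0.10, 0.80]])
>>> target = torch.tensor([0, 1, 1, 2, 2])
>>> per_class = multiclass_auroc(preds, target, num_classes=3, average=None)
>>> # one-vs-rest by hand: class c is the positive class, every other class is negative
>>> ovr = torch.stack([binary_auroc(preds[:, c], (target == c).long()) for c in range(3)])
>>> torch.allclose(per_class, ovr)
True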

As input to forward and update the metric accepts the following input:

  • preds (Tensor): A float tensor of shape (N, C, ...) containing probabilities or logits for each observation. If preds has values outside [0,1] range we consider the input to be logits and will auto apply softmax per sample.

  • target (Tensor): An int tensor of shape (N, ...) containing ground truth labels, and therefore only contain values in the [0, n_classes-1] range (except if ignore_index is specified).

As output to forward and compute the metric returns the following output:

  • mc_auroc (Tensor): If average=None|"none" then a 1d tensor of shape (n_classes, ) will be returned with auroc score per class. If average="macro"|"weighted" then a single scalar is returned.

Additional dimension ... will be flattened into the batch dimension.

The implementation supports calculating the metric in both a non-binned but accurate version and a binned version that is less accurate but more memory efficient. Setting the thresholds argument to None will activate the non-binned version that uses memory of size \(\mathcal{O}(n_{samples})\) whereas setting the thresholds argument to either an integer, list or a 1d tensor will use a binned version that uses memory of size \(\mathcal{O}(n_{thresholds} \times n_{classes})\) (constant memory).

Parameters:
  • num_classes (int) – Integer specifying the number of classes

  • average (Optional[Literal['macro', 'weighted', 'none']]) –

    Defines the reduction that is applied over classes. Should be one of the following:

    • macro: Calculate score for each class and average them

    • weighted: calculates score for each class and computes weighted average using their support

    • "none" or None: calculates score for each class and applies no reduction

  • thresholds (Union[int, List[float], Tensor, None]) –

    Can be one of:

    • If set to None, will use a non-binned approach where thresholds are dynamically calculated from all the data. Most accurate but also most memory consuming approach.

    • If set to an int (larger than 1), will use that number of thresholds linearly spaced from 0 to 1 as bins for the calculation.

    • If set to a list of floats, will use the indicated thresholds in the list as bins for the calculation

    • If set to a 1d tensor of floats, will use the indicated thresholds in the tensor as bins for the calculation.

  • validate_args (bool) – bool indicating if input arguments and tensors should be validated for correctness. Set to False for faster computations.

  • kwargs (Any) – Additional keyword arguments, see Advanced metric settings for more info.

Example

>>> from torch import tensor
>>> from torchmetrics.classification import MulticlassAUROC
>>> preds = tensor([[0.75, 0.05, 0.05, 0.05, 0.05],
...                 [0.05, 0.75, 0.05, 0.05, 0.05],
...                 [0.05, 0.05, 0.75, 0.05, 0.05],
...                 [0.05, 0.05, 0.05, 0.75, 0.05]])
>>> target = tensor([0, 1, 3, 2])
>>> metric = MulticlassAUROC(num_classes=5, average="macro", thresholds=None)
>>> metric(preds, target)
tensor(0.5333)
>>> mc_auroc = MulticlassAUROC(num_classes=5, average=None, thresholds=None)
>>> mc_auroc(preds, target)
tensor([1.0000, 1.0000, 0.3333, 0.3333, 0.0000])
>>> mc_auroc = MulticlassAUROC(num_classes=5, average="macro", thresholds=5)
>>> mc_auroc(preds, target)
tensor(0.5333)
>>> mc_auroc = MulticlassAUROC(num_classes=5, average=None, thresholds=5)
>>> mc_auroc(preds, target)
tensor([1.0000, 1.0000, 0.3333, 0.3333, 0.0000])
plot(val=None, ax=None)[source]

Plot a single or multiple values from the metric.

Parameters:
  • val (Union[Tensor, Sequence[Tensor], None]) – Either a single result from calling metric.forward or metric.compute or a list of these results. If no value is provided, will automatically call metric.compute and plot that result.

  • ax (Optional[Axes]) – A matplotlib axis object. If provided, will add the plot to that axis

Return type:

Tuple[Figure, Union[Axes, ndarray]]

Returns:

Figure and Axes object

Raises:

ModuleNotFoundError – If matplotlib is not installed

>>> # Example plotting a single
>>> import torch
>>> from torchmetrics.classification import MulticlassAUROC
>>> metric = MulticlassAUROC(num_classes=3)
>>> metric.update(torch.randn(20, 3), torch.randint(3,(20,)))
>>> fig_, ax_ = metric.plot()
_images/auroc-3.png
>>> # Example plotting multiple values
>>> import torch
>>> from torchmetrics.classification import MulticlassAUROC
>>> metric = MulticlassAUROC(num_classes=3)
>>> values = [ ]
>>> for _ in range(10):
...     values.append(metric(torch.randn(20, 3), torch.randint(3, (20,))))
>>> fig_, ax_ = metric.plot(values)
_images/auroc-4.png
MultilabelAUROC
class torchmetrics.classification.MultilabelAUROC(num_labels, average='macro', thresholds=None, ignore_index=None, validate_args=True, **kwargs)[source]

Compute Area Under the Receiver Operating Characteristic Curve (ROC AUC) for multilabel tasks.

The AUROC score summarizes the ROC curve into a single number that describes the performance of a model for multiple thresholds at the same time. Notably, an AUROC score of 1 is a perfect score and an AUROC score of 0.5 corresponds to random guessing.

As input to forward and update the metric accepts the following input:

  • preds (Tensor): A float tensor of shape (N, C, ...) containing probabilities or logits for each observation. If preds has values outside [0,1] range we consider the input to be logits and will auto apply sigmoid per element.

  • target (Tensor): An int tensor of shape (N, C, ...) containing ground truth labels, and therefore only contain {0,1} values (except if ignore_index is specified).

As output to forward and compute the metric returns the following output:

  • ml_auroc (Tensor): If average=None|"none" then a 1d tensor of shape (n_labels, ) will be returned with auroc score per label. If average="micro"|"macro"|"weighted" then a single scalar is returned.

Additional dimension ... will be flattened into the batch dimension.

The implementation supports calculating the metric both in a non-binned but accurate version and in a binned version that is less accurate but more memory efficient. Setting the thresholds argument to None will activate the non-binned version that uses memory of size \(\mathcal{O}(n_{samples})\), whereas setting the thresholds argument to either an integer, a list or a 1d tensor will use a binned version that uses memory of size \(\mathcal{O}(n_{thresholds} \times n_{labels})\) (constant memory).

Parameters:
  • num_labels (int) – Integer specifying the number of labels

  • average (Optional[Literal['micro', 'macro', 'weighted', 'none']]) –

    Defines the reduction that is applied over labels. Should be one of the following:

    • micro: Sum score over all labels

    • macro: Calculate score for each label and average them

    • weighted: calculates score for each label and computes weighted average using their support

    • "none" or None: calculates score for each label and applies no reduction

  • thresholds (Union[int, List[float], Tensor, None]) –

    Can be one of:

    • If set to None, will use a non-binned approach where thresholds are dynamically calculated from all the data. Most accurate but also most memory consuming approach.

    • If set to an int (larger than 1), will use that number of thresholds linearly spaced from 0 to 1 as bins for the calculation.

    • If set to a list of floats, will use the indicated thresholds in the list as bins for the calculation.

    • If set to a 1d tensor of floats, will use the indicated thresholds in the tensor as bins for the calculation.

  • validate_args (bool) – bool indicating if input arguments and tensors should be validated for correctness. Set to False for faster computations.

  • kwargs (Any) – Additional keyword arguments, see Advanced metric settings for more info.

Example

>>> from torch import tensor
>>> from torchmetrics.classification import MultilabelAUROC
>>> preds = tensor([[0.75, 0.05, 0.35],
...                 [0.45, 0.75, 0.05],
...                 [0.05, 0.55, 0.75],
...                 [0.05, 0.65, 0.05]])
>>> target = tensor([[1, 0, 1],
...                  [0, 0, 0],
...                  [0, 1, 1],
...                  [1, 1, 1]])
>>> ml_auroc = MultilabelAUROC(num_labels=3, average="macro", thresholds=None)
>>> ml_auroc(preds, target)
tensor(0.6528)
>>> ml_auroc = MultilabelAUROC(num_labels=3, average=None, thresholds=None)
>>> ml_auroc(preds, target)
tensor([0.6250, 0.5000, 0.8333])
>>> ml_auroc = MultilabelAUROC(num_labels=3, average="macro", thresholds=5)
>>> ml_auroc(preds, target)
tensor(0.6528)
>>> ml_auroc = MultilabelAUROC(num_labels=3, average=None, thresholds=5)
>>> ml_auroc(preds, target)
tensor([0.6250, 0.5000, 0.8333])
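A 1d tensor of thresholds behaves the same way as a list. A minimal sketch, again reusing the preds and target from the example above:

import torch
# binned evaluation with thresholds supplied as a 1d tensor
ml_auroc = MultilabelAUROC(num_labels=3, average="macro", thresholds=torch.linspace(0, 1, 5))
score = ml_auroc(preds, target)  # expected to match the thresholds=5 result above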
plot(val=None, ax=None)[source]

Plot a single or multiple values from the metric.

Parameters:
  • val (Union[Tensor, Sequence[Tensor], None]) – Either a single result from calling metric.forward or metric.compute or a list of these results. If no value is provided, will automatically call metric.compute and plot that result.

  • ax (Optional[Axes]) – A matplotlib axis object. If provided, will add the plot to that axis.

Return type:

Tuple[Figure, Union[Axes, ndarray]]

Returns:

Figure and Axes object

Raises:

ModuleNotFoundError – If matplotlib is not installed

>>> # Example plotting a single
>>> import torch
>>> from torchmetrics.classification import MultilabelAUROC
>>> metric = MultilabelAUROC(num_labels=3)
>>> metric.update(torch.rand(20,3), torch.randint(2, (20,3)))
>>> fig_, ax_ = metric.plot()
_images/auroc-5.png
>>> # Example plotting multiple values
>>> import torch
>>> from torchmetrics.classification import MultilabelAUROC
>>> metric = MultilabelAUROC(num_labels=3)
>>> values = [ ]
>>> for _ in range(10):
...     values.append(metric(torch.rand(20,3), torch.randint(2, (20,3))))
>>> fig_, ax_ = metric.plot(values)
_images/auroc-6.png

Functional Interface

torchmetrics.functional.auroc(preds, target, task, thresholds=None, num_classes=None, num_labels=None, average='macro', max_fpr=None, ignore_index=None, validate_args=True)[source]

Compute Area Under the Receiver Operating Characteristic Curve (ROC AUC).

The AUROC score summarizes the ROC curve into a single number that describes the performance of a model for multiple thresholds at the same time. Notably, an AUROC score of 1 is a perfect score and an AUROC score of 0.5 corresponds to random guessing.

This function is a simple wrapper to get the task-specific versions of this metric, which is done by setting the task argument to either 'binary', 'multiclass' or 'multilabel'. See the documentation of binary_auroc(), multiclass_auroc() and multilabel_auroc() for the specific details of each argument's influence and examples.

Return type:

Optional[Tensor]

Legacy Example:
>>> preds = torch.tensor([0.13, 0.26, 0.08, 0.19, 0.34])
>>> target = torch.tensor([0, 0, 1, 1, 1])
>>> auroc(preds, target, task='binary')
tensor(0.5000)
>>> preds = torch.tensor([[0.90, 0.05, 0.05],
...                       [0.05, 0.90, 0.05],
...                       [0.05, 0.05, 0.90],
...                       [0.85, 0.05, 0.10],
...                       [0.10, 0.10, 0.80]])
>>> target = torch.tensor([0, 1, 1, 2, 2])
>>> auroc(preds, target, task='multiclass', num_classes=3)
tensor(0.7778)
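For completeness, the wrapper also dispatches to the multilabel case when task='multilabel' is combined with num_labels. A minimal, self-contained sketch:

import torch
from torchmetrics.functional import auroc

preds = torch.tensor([[0.75, 0.05, 0.35],
                      [0.45, 0.75, 0.05],
                      [0.05, 0.55, 0.75],
                      [0.05, 0.65, 0.05]])
target = torch.tensor([[1, 0, 1],
                       [0, 0, 0],
                       [0, 1, 1],
                       [1, 1, 1]])
# dispatches to multilabel_auroc(); should agree with the MultilabelAUROC example above
score = auroc(preds, target, task="multilabel", num_labels=3, average="macro")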
binary_auroc
torchmetrics.functional.classification.binary_auroc(preds, target, max_fpr=None, thresholds=None, ignore_index=None, validate_args=True)[source]

Compute Area Under the Receiver Operating Characteristic Curve (ROC AUC) for binary tasks.

The AUROC score summarizes the ROC curve into a single number that describes the performance of a model for multiple thresholds at the same time. Notably, an AUROC score of 1 is a perfect score and an AUROC score of 0.5 corresponds to random guessing.

Accepts the following input tensors:

  • preds (float tensor): (N, ...). Preds should be a tensor containing probabilities or logits for each observation. If preds has values outside [0,1] range we consider the input to be logits and will auto apply sigmoid per element.

  • target (int tensor): (N, ...). Target should be a tensor containing ground truth labels, and therefore only contain {0,1} values (except if ignore_index is specified). The value 1 always encodes the positive class.

Additional dimension ... will be flattened into the batch dimension.

The implementation supports calculating the metric both in a non-binned but accurate version and in a binned version that is less accurate but more memory efficient. Setting the thresholds argument to None will activate the non-binned version that uses memory of size \(\mathcal{O}(n_{samples})\), whereas setting the thresholds argument to either an integer, a list or a 1d tensor will use a binned version that uses memory of size \(\mathcal{O}(n_{thresholds})\) (constant memory).

Parameters:
  • preds (Tensor) – Tensor with predictions

  • target (Tensor) – Tensor with true labels

  • max_fpr (Optional[float]) – If not None, calculates standardized partial AUC over the range [0, max_fpr].

  • thresholds (Union[int, List[float], Tensor, None]) –

    Can be one of:

    • If set to None, will use a non-binned approach where thresholds are dynamically calculated from all the data. Most accurate but also most memory consuming approach.

    • If set to an int (larger than 1), will use that number of thresholds linearly spaced from 0 to 1 as bins for the calculation.

    • If set to a list of floats, will use the indicated thresholds in the list as bins for the calculation.

    • If set to a 1d tensor of floats, will use the indicated thresholds in the tensor as bins for the calculation.

  • ignore_index (Optional[int]) – Specifies a target value that is ignored and does not contribute to the metric calculation

  • validate_args (bool) – bool indicating if input arguments and tensors should be validated for correctness. Set to False for faster computations.

Return type:

Tensor

Returns:

A single scalar with the auroc score

Example

>>> from torchmetrics.functional.classification import binary_auroc
>>> preds = torch.tensor([0, 0.5, 0.7, 0.8])
>>> target = torch.tensor([0, 1, 1, 0])
>>> binary_auroc(preds, target, thresholds=None)
tensor(0.5000)
>>> binary_auroc(preds, target, thresholds=5)
tensor(0.5000)
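The max_fpr argument restricts the score to a standardized partial AUC over the range [0, max_fpr]. A short sketch reusing the preds and target from the example above (the value 0.8 is just an illustrative choice):

# standardized partial AUC computed only over the false-positive-rate range [0, 0.8]
partial_score = binary_auroc(preds, target, max_fpr=0.8)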
multiclass_auroc
torchmetrics.functional.classification.multiclass_auroc(preds, target, num_classes, average='macro', thresholds=None, ignore_index=None, validate_args=True)[source]

Compute Area Under the Receiver Operating Characteristic Curve (ROC AUC) for multiclass tasks.

The AUROC score summarizes the ROC curve into a single number that describes the performance of a model for multiple thresholds at the same time. Notably, an AUROC score of 1 is a perfect score and an AUROC score of 0.5 corresponds to random guessing.

Accepts the following input tensors:

  • preds (float tensor): (N, C, ...). Preds should be a tensor containing probabilities or logits for each observation. If preds has values outside [0,1] range we consider the input to be logits and will auto apply softmax per sample.

  • target (int tensor): (N, ...). Target should be a tensor containing ground truth labels, and therefore only contain values in the [0, n_classes-1] range (except if ignore_index is specified).

Additional dimension ... will be flattened into the batch dimension.

The implementation supports calculating the metric both in a non-binned but accurate version and in a binned version that is less accurate but more memory efficient. Setting the thresholds argument to None will activate the non-binned version that uses memory of size \(\mathcal{O}(n_{samples})\), whereas setting the thresholds argument to either an integer, a list or a 1d tensor will use a binned version that uses memory of size \(\mathcal{O}(n_{thresholds} \times n_{classes})\) (constant memory).

Parameters:
  • preds (Tensor) – Tensor with predictions

  • target (Tensor) – Tensor with true labels

  • num_classes (int) – Integer specifying the number of classes

  • average (Optional[Literal['macro', 'weighted', 'none']]) –

    Defines the reduction that is applied over classes. Should be one of the following:

    • macro: Calculate score for each class and average them

    • weighted: calculates score for each class and computes weighted average using their support

    • "none" or None: calculates score for each class and applies no reduction

  • thresholds (Union[int, List[float], Tensor, None]) –

    Can be one of:

    • If set to None, will use a non-binned approach where thresholds are dynamically calculated from all the data. Most accurate but also most memory consuming approach.

    • If set to an int (larger than 1), will use that number of thresholds linearly spaced from 0 to 1 as bins for the calculation.

    • If set to a list of floats, will use the indicated thresholds in the list as bins for the calculation.

    • If set to a 1d tensor of floats, will use the indicated thresholds in the tensor as bins for the calculation.

  • ignore_index (Optional[int]) – Specifies a target value that is ignored and does not contribute to the metric calculation

  • validate_args (bool) – bool indicating if input arguments and tensors should be validated for correctness. Set to False for faster computations.

Return type:

Tensor

Returns:

If average=None|"none" then a 1d tensor of shape (n_classes, ) will be returned with auroc score per class. If average="macro"|"weighted" then a single scalar is returned.

Example

>>> from torchmetrics.functional.classification import multiclass_auroc
>>> preds = torch.tensor([[0.75, 0.05, 0.05, 0.05, 0.05],
...                       [0.05, 0.75, 0.05, 0.05, 0.05],
...                       [0.05, 0.05, 0.75, 0.05, 0.05],
...                       [0.05, 0.05, 0.05, 0.75, 0.05]])
>>> target = torch.tensor([0, 1, 3, 2])
>>> multiclass_auroc(preds, target, num_classes=5, average="macro", thresholds=None)
tensor(0.5333)
>>> multiclass_auroc(preds, target, num_classes=5, average=None, thresholds=None)
tensor([1.0000, 1.0000, 0.3333, 0.3333, 0.0000])
>>> multiclass_auroc(preds, target, num_classes=5, average="macro", thresholds=5)
tensor(0.5333)
>>> multiclass_auroc(preds, target, num_classes=5, average=None, thresholds=5)
tensor([1.0000, 1.0000, 0.3333, 0.3333, 0.0000])
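The ignore_index argument drops the specified target value from the computation. A minimal sketch reusing the preds from the example above (the value -1 is just an illustrative placeholder for entries that should be skipped):

# the last sample carries the ignored value -1 and does not contribute to the score
target_with_ignore = torch.tensor([0, 1, 3, -1])
score = multiclass_auroc(preds, target_with_ignore, num_classes=5, average="macro", ignore_index=-1)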
multilabel_auroc
torchmetrics.functional.classification.multilabel_auroc(preds, target, num_labels, average='macro', thresholds=None, ignore_index=None, validate_args=True)[source]

Compute Area Under the Receiver Operating Characteristic Curve (ROC AUC) for multilabel tasks.

The AUROC score summarizes the ROC curve into a single number that describes the performance of a model for multiple thresholds at the same time. Notably, an AUROC score of 1 is a perfect score and an AUROC score of 0.5 corresponds to random guessing.

Accepts the following input tensors:

  • preds (float tensor): (N, C, ...). Preds should be a tensor containing probabilities or logits for each observation. If preds has values outside [0,1] range we consider the input to be logits and will auto apply sigmoid per element.

  • target (int tensor): (N, C, ...). Target should be a tensor containing ground truth labels, and therefore only contain {0,1} values (except if ignore_index is specified).

Additional dimension ... will be flattened into the batch dimension.

The implementation supports calculating the metric both in a non-binned but accurate version and in a binned version that is less accurate but more memory efficient. Setting the thresholds argument to None will activate the non-binned version that uses memory of size \(\mathcal{O}(n_{samples})\), whereas setting the thresholds argument to either an integer, a list or a 1d tensor will use a binned version that uses memory of size \(\mathcal{O}(n_{thresholds} \times n_{labels})\) (constant memory).

Parameters:
  • preds (Tensor) – Tensor with predictions

  • target (Tensor) – Tensor with true labels

  • num_labels (int) – Integer specifying the number of labels

  • average (Optional[Literal['micro', 'macro', 'weighted', 'none']]) –

    Defines the reduction that is applied over labels. Should be one of the following:

    • micro: Sum score over all labels

    • macro: Calculate score for each label and average them

    • weighted: calculates score for each label and computes weighted average using their support

    • "none" or None: calculates score for each label and applies no reduction

  • thresholds (Union[int, List[float], Tensor, None]) –

    Can be one of:

    • If set to None, will use a non-binned approach where thresholds are dynamically calculated from all the data. Most accurate but also most memory consuming approach.

    • If set to an int (larger than 1), will use that number of thresholds linearly spaced from 0 to 1 as bins for the calculation.

    • If set to a list of floats, will use the indicated thresholds in the list as bins for the calculation.

    • If set to a 1d tensor of floats, will use the indicated thresholds in the tensor as bins for the calculation.

  • ignore_index (Optional[int]) – Specifies a target value that is ignored and does not contribute to the metric calculation

  • validate_args (bool) – bool indicating if input arguments and tensors should be validated for correctness. Set to False for faster computations.

Return type:

Tensor

Returns:

If average=None|"none" then a 1d tensor of shape (n_labels, ) will be returned with auroc score per label. If average="micro"|"macro"|"weighted" then a single scalar is returned.

Example

>>> from torchmetrics.functional.classification import multilabel_auroc
>>> preds = torch.tensor([[0.75, 0.05, 0.35],
...                       [0.45, 0.75, 0.05],
...                       [0.05, 0.55, 0.75],
...                       [0.05, 0.65, 0.05]])
>>> target = torch.tensor([[1, 0, 1],
...                        [0, 0, 0],
...                        [0, 1, 1],
...                        [1, 1, 1]])
>>> multilabel_auroc(preds, target, num_labels=3, average="macro", thresholds=None)
tensor(0.6528)
>>> multilabel_auroc(preds, target, num_labels=3, average=None, thresholds=None)
tensor([0.6250, 0.5000, 0.8333])
>>> multilabel_auroc(preds, target, num_labels=3, average="macro", thresholds=5)
tensor(0.6528)
>>> multilabel_auroc(preds, target, num_labels=3, average=None, thresholds=5)
tensor([0.6250, 0.5000, 0.8333])
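Besides macro averaging, average='micro' pools every (sample, label) pair into a single binary problem before computing the score. A brief sketch reusing the preds and target from the example above:

# micro averaging: all label columns are flattened into one overall AUROC computation
micro_score = multilabel_auroc(preds, target, num_labels=3, average="micro", thresholds=None)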

Average Precision

Module Interface

class torchmetrics.AveragePrecision(**kwargs)[source]

Compute the average precision (AP) score.

The AP score summarizes a precision-recall curve as a weighted mean of precisions at each threshold, with the difference in recall from the previous threshold as weight:

\[AP = \sum_{n} (R_n - R_{n-1}) P_n\]

where \(P_n\) and \(R_n\) are the precision and recall at threshold index \(n\), respectively. This value is equivalent to the area under the precision-recall curve (AUPRC).

This function is a simple wrapper to get the task-specific versions of this metric, which is done by setting the task argument to either 'binary', 'multiclass' or 'multilabel'. See the documentation of BinaryAveragePrecision, MulticlassAveragePrecision and MultilabelAveragePrecision for the specific details of each argument's influence and examples.

Legacy Example:
>>> from torch import tensor
>>> pred = tensor([0, 0.1, 0.8, 0.4])
>>> target = tensor([0, 1, 1, 1])
>>> average_precision = AveragePrecision(task="binary")
>>> average_precision(pred, target)
tensor(1.)
>>> pred = tensor([[0.75, 0.05, 0.05, 0.05, 0.05],
...                [0.05, 0.75, 0.05, 0.05, 0.05],
...                [0.05, 0.05, 0.75, 0.05, 0.05],
...                [0.05, 0.05, 0.05, 0.75, 0.05]])
>>> target = tensor([0, 1, 3, 2])
>>> average_precision = AveragePrecision(task="multiclass", num_classes=5, average=None)
>>> average_precision(pred, target)
tensor([1.0000, 1.0000, 0.2500, 0.2500,    nan])
static __new__(cls, task, thresholds=None, num_classes=None, num_labels=None, average='macro', ignore_index=None, validate_args=True, **kwargs)[source]

Initialize task metric.

Return type:

Metric

BinaryAveragePrecision
class torchmetrics.classification.BinaryAveragePrecision(thresholds=None, ignore_index=None, validate_args=True, **kwargs)[source]

Compute the average precision (AP) score for binary tasks.

The AP score summarizes a precision-recall curve as a weighted mean of precisions at each threshold, with the difference in recall from the previous threshold as weight:

\[AP = \sum_{n} (R_n - R_{n-1}) P_n\]

where \(P_n\) and \(R_n\) are the precision and recall at threshold index \(n\), respectively. This value is equivalent to the area under the precision-recall curve (AUPRC).

As input to forward and update the metric accepts the following input:

  • preds (Tensor): A float tensor of shape (N, ...) containing probabilities or logits for each observation. If preds has values outside [0,1] range we consider the input to be logits and will auto apply sigmoid per element.

  • target (Tensor): An int tensor of shape (N, ...) containing ground truth labels, and therefore only contain {0,1} values (except if ignore_index is specified). The value 1 always encodes the positive class.

As output to forward and compute the metric returns the following output:

  • bap (Tensor): A single scalar with the average precision score

Additional dimension ... will be flattened into the batch dimension.

The implementation supports calculating the metric both in a non-binned but accurate version and in a binned version that is less accurate but more memory efficient. Setting the thresholds argument to None will activate the non-binned version that uses memory of size \(\mathcal{O}(n_{samples})\), whereas setting the thresholds argument to either an integer, a list or a 1d tensor will use a binned version that uses memory of size \(\mathcal{O}(n_{thresholds})\) (constant memory).

Parameters:
  • thresholds (Union[int, List[float], Tensor, None]) –

    Can be one of:

    • If set to None, will use a non-binned approach where thresholds are dynamically calculated from all the data. Most accurate but also most memory consuming approach.

    • If set to an int (larger than 1), will use that number of thresholds linearly spaced from 0 to 1 as bins for the calculation.

    • If set to a list of floats, will use the indicated thresholds in the list as bins for the calculation.

    • If set to a 1d tensor of floats, will use the indicated thresholds in the tensor as bins for the calculation.

  • validate_args (bool) – bool indicating if input arguments and tensors should be validated for correctness. Set to False for faster computations.

  • kwargs (Any) – Additional keyword arguments, see Advanced metric settings for more info.

Example

>>> from torch import tensor
>>> from torchmetrics.classification import BinaryAveragePrecision
>>> preds = tensor([0, 0.5, 0.7, 0.8])
>>> target = tensor([0, 1, 1, 0])
>>> metric = BinaryAveragePrecision(thresholds=None)
>>> metric(preds, target)
tensor(0.5833)
>>> bap = BinaryAveragePrecision(thresholds=5)
>>> bap(preds, target)
tensor(0.6667)
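For intuition, the non-binned score above can be reproduced directly from the definition \(AP = \sum_{n} (R_n - R_{n-1}) P_n\). The rough sketch below sweeps each prediction score as a threshold and accumulates the weighted precisions; it is an illustrative check under those assumptions, not the library's internal implementation:

import torch

def manual_average_precision(preds, target):
    # sort by score (high to low) and accumulate (R_n - R_{n-1}) * P_n
    order = torch.argsort(preds, descending=True)
    labels = target[order].float()
    tp = torch.cumsum(labels, dim=0)                      # true positives at each cut-off
    precision = tp / torch.arange(1, len(labels) + 1)
    recall = tp / labels.sum()
    prev_recall = torch.cat([torch.zeros(1), recall[:-1]])
    return torch.sum((recall - prev_recall) * precision)

ap = manual_average_precision(torch.tensor([0, 0.5, 0.7, 0.8]), torch.tensor([0, 1, 1, 0]))
# ap is approximately 0.5833, in line with BinaryAveragePrecision(thresholds=None) above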
plot(val=None, ax=None)[source]

Plot a single or multiple values from the metric.

Parameters:
  • val (Union[Tensor, Sequence[Tensor], None]) – Either a single result from calling metric.forward or metric.compute or a list of these results. If no value is provided, will automatically call metric.compute and plot that result.

  • ax (Optional[Axes]) – A matplotlib axis object. If provided, will add the plot to that axis.

Return type:

Tuple[Figure, Union[Axes, ndarray]]

Returns:

Figure and Axes object

Raises:

ModuleNotFoundError – If matplotlib is not installed

>>> # Example plotting a single
>>> import torch
>>> from torchmetrics.classification import BinaryAveragePrecision
>>> metric = BinaryAveragePrecision()
>>> metric.update(torch.rand(20,), torch.randint(2, (20,)))
>>> fig_, ax_ = metric.plot()
_images/average_precision-1.png
>>> # Example plotting multiple values
>>> import torch
>>> from torchmetrics.classification import BinaryAveragePrecision
>>> metric = BinaryAveragePrecision()
>>> values = [ ]
>>> for _ in range(10):
...     values.append(metric(torch.rand(20,), torch.randint(2, (20,))))
>>> fig_, ax_ = metric.plot(values)
_images/average_precision-2.png
MulticlassAveragePrecision
class torchmetrics.classification.MulticlassAveragePrecision(num_classes, average='macro', thresholds=None, ignore_index=None, validate_args=True, **kwargs)[source]

Compute the average precision (AP) score for multiclass tasks.

The AP score summarizes a precision-recall curve as a weighted mean of precisions at each threshold, with the difference in recall from the previous threshold as weight:

\[AP = \sum_{n} (R_n - R_{n-1}) P_n\]

where \(P_n\) and \(R_n\) are the precision and recall at threshold index \(n\), respectively. This value is equivalent to the area under the precision-recall curve (AUPRC).

For multiclass the metric is calculated by iteratively treating each class as the positive class and all other classes as the negative class, which is referred to as the one-vs-rest approach. One-vs-one is currently not supported by this metric. By default the reported metric is then the average over all classes, but this behavior can be changed by setting the average argument.
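The one-vs-rest reduction can be sketched by scoring each class column against a binarized target with the binary metric. This is a conceptual illustration only; the class handles the reduction internally, including edge cases such as classes without positive samples:

import torch
from torchmetrics.functional.classification import binary_average_precision, multiclass_average_precision

preds = torch.randn(20, 5).softmax(dim=-1)
target = torch.arange(20) % 5  # every class appears, so each one-vs-rest problem has positives

per_class = multiclass_average_precision(preds, target, num_classes=5, average=None)
# conceptually, class c is scored as a binary problem: preds[:, c] against (target == c)
one_vs_rest = torch.stack(
    [binary_average_precision(preds[:, c], (target == c).long()) for c in range(5)]
)
# per_class and one_vs_rest should agree (up to numerical precision)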

As input to forward and update the metric accepts the following input:

  • preds (Tensor): A float tensor of shape (N, C, ...) containing probabilities or logits for each observation. If preds has values outside [0,1] range we consider the input to be logits and will auto apply softmax per sample.

  • target (Tensor): An int tensor of shape (N, ...) containing ground truth labels, and therefore only contain values in the [0, n_classes-1] range (except if ignore_index is specified).

As output to forward and compute the metric returns the following output:

  • mcap (Tensor): If average=None|"none" then a 1d tensor of shape (n_classes, ) will be returned with AP score per class. If average="macro"|"weighted" then a single scalar is returned.

Additional dimension ... will be flattened into the batch dimension.

The implementation supports calculating the metric both in a non-binned but accurate version and in a binned version that is less accurate but more memory efficient. Setting the thresholds argument to None will activate the non-binned version that uses memory of size \(\mathcal{O}(n_{samples})\), whereas setting the thresholds argument to either an integer, a list or a 1d tensor will use a binned version that uses memory of size \(\mathcal{O}(n_{thresholds} \times n_{classes})\) (constant memory).

Parameters:
  • num_classes (int) – Integer specifying the number of classes

  • average (Optional[Literal['macro', 'weighted', 'none']]) –

    Defines the reduction that is applied over classes. Should be one of the following:

    • macro: Calculate score for each class and average them

    • weighted: calculates score for each class and computes weighted average using their support

    • "none" or None: calculates score for each class and applies no reduction

  • thresholds (Union[int, List[float], Tensor, None]) –

    Can be one of:

    • If set to None, will use a non-binned approach where thresholds are dynamically calculated from all the data. Most accurate but also most memory consuming approach.

    • If set to an int (larger than 1), will use that number of thresholds linearly spaced from 0 to 1 as bins for the calculation.

    • If set to a list of floats, will use the indicated thresholds in the list as bins for the calculation.

    • If set to a 1d tensor of floats, will use the indicated thresholds in the tensor as bins for the calculation.

  • validate_args (bool) – bool indicating if input arguments and tensors should be validated for correctness. Set to False for faster computations.

  • kwargs (Any) – Additional keyword arguments, see Advanced metric settings for more info.

Example

>>> from torch import tensor
>>> from torchmetrics.classification import MulticlassAveragePrecision
>>> preds = tensor([[0.75, 0.05, 0.05, 0.05, 0.05],
...                 [0.05, 0.75, 0.05, 0.05, 0.05],
...                 [0.05, 0.05, 0.75, 0.05, 0.05],
...                 [0.05, 0.05, 0.05, 0.75, 0.05]])
>>> target = tensor([0, 1, 3, 2])
>>> metric = MulticlassAveragePrecision(num_classes=5, average="macro", thresholds=None)
>>> metric(preds, target)
tensor(0.6250)
>>> mcap = MulticlassAveragePrecision(num_classes=5, average=None, thresholds=None)
>>> mcap(preds, target)
tensor([1.0000, 1.0000, 0.2500, 0.2500,    nan])
>>> mcap = MulticlassAveragePrecision(num_classes=5, average="macro", thresholds=5)
>>> mcap(preds, target)
tensor(0.5000)
>>> mcap = MulticlassAveragePrecision(num_classes=5, average=None, thresholds=5)
>>> mcap(preds, target)
tensor([1.0000, 1.0000, 0.2500, 0.2500, -0.0000])
plot(val=None, ax=None)[source]

Plot a single or multiple values from the metric.

Parameters:
  • val (Union[Tensor, Sequence[Tensor], None]) – Either a single result from calling metric.forward or metric.compute or a list of these results. If no value is provided, will automatically call metric.compute and plot that result.

  • ax (Optional[Axes]) – A matplotlib axis object. If provided, will add the plot to that axis.

Return type:

Tuple[Figure, Union[Axes, ndarray]]

Returns:

Figure and Axes object

Raises:

ModuleNotFoundError – If matplotlib is not installed

>>> # Example plotting a single
>>> import torch
>>> from torchmetrics.classification import MulticlassAveragePrecision
>>> metric = MulticlassAveragePrecision(num_classes=3)
>>> metric.update(torch.randn(20, 3), torch.randint(3,(20,)))
>>> fig_, ax_ = metric.plot()
_images/average_precision-3.png
>>> # Example plotting multiple values
>>> import torch
>>> from torchmetrics.classification import MulticlassAveragePrecision
>>> metric = MulticlassAveragePrecision(num_classes=3)
>>> values = [ ]
>>> for _ in range(10):
...     values.append(metric(torch.randn(20, 3), torch.randint(3, (20,))))
>>> fig_, ax_ = metric.plot(values)
_images/average_precision-4.png
MultilabelAveragePrecision
class torchmetrics.classification.MultilabelAveragePrecision(num_labels, average='macro', thresholds=None, ignore_index=None, validate_args=True, **kwargs)[source]

Compute the average precision (AP) score for multilabel tasks.

The AP score summarizes a precision-recall curve as a weighted mean of precisions at each threshold, with the difference in recall from the previous threshold as weight:

\[AP = \sum_{n} (R_n - R_{n-1}) P_n\]

where \(P_n\) and \(R_n\) are the precision and recall at threshold index \(n\), respectively. This value is equivalent to the area under the precision-recall curve (AUPRC).

As input to forward and update the metric accepts the following input:

  • preds (Tensor): A float tensor of shape (N, C, ...) containing probabilities or logits for each observation. If preds has values outside [0,1] range we consider the input to be logits and will auto apply sigmoid per element.

  • target (Tensor): An int tensor of shape (N, C, ...) containing ground truth labels, and therefore only contain {0,1} values (except if ignore_index is specified).

As output to forward and compute the metric returns the following output:

  • mlap (Tensor): If average=None|"none" then a 1d tensor of shape (n_labels, ) will be returned with AP score per label. If average="micro"|"macro"|"weighted" then a single scalar is returned.

Additional dimension ... will be flattened into the batch dimension.

The implementation supports calculating the metric both in a non-binned but accurate version and in a binned version that is less accurate but more memory efficient. Setting the thresholds argument to None will activate the non-binned version that uses memory of size \(\mathcal{O}(n_{samples})\), whereas setting the thresholds argument to either an integer, a list or a 1d tensor will use a binned version that uses memory of size \(\mathcal{O}(n_{thresholds} \times n_{labels})\) (constant memory).

Parameters:
  • num_labels (int) – Integer specifying the number of labels

  • average (Optional[Literal['micro', 'macro', 'weighted', 'none']]) –

    Defines the reduction that is applied over labels. Should be one of the following:

    • micro: Sum score over all labels

    • macro: Calculate score for each label and average them

    • weighted: calculates score for each label and computes weighted average using their support

    • "none" or None: calculates score for each label and applies no reduction

  • thresholds (Union[int, List[float], Tensor, None]) –

    Can be one of:

    • If set to None, will use a non-binned approach where thresholds are dynamically calculated from all the data. Most accurate but also most memory consuming approach.

    • If set to an int (larger than 1), will use that number of thresholds linearly spaced from 0 to 1 as bins for the calculation.

    • If set to a list of floats, will use the indicated thresholds in the list as bins for the calculation.

    • If set to a 1d tensor of floats, will use the indicated thresholds in the tensor as bins for the calculation.

  • validate_args (bool) – bool indicating if input arguments and tensors should be validated for correctness. Set to False for faster computations.

  • kwargs (Any) – Additional keyword arguments, see Advanced metric settings for more info.

Example

>>> from torch import tensor
>>> from torchmetrics.classification import MultilabelAveragePrecision
>>> preds = tensor([[0.75, 0.05, 0.35],
...                 [0.45, 0.75, 0.05],
...                 [0.05, 0.55, 0.75],
...                 [0.05, 0.65, 0.05]])
>>> target = tensor([[1, 0, 1],
...                  [0, 0, 0],
...                  [0, 1, 1],
...                  [1, 1, 1]])
>>> metric = MultilabelAveragePrecision(num_labels=3, average="macro", thresholds=None)
>>> metric(preds, target)
tensor(0.7500)
>>> mlap = MultilabelAveragePrecision(num_labels=3, average=None, thresholds=None)
>>> mlap(preds, target)
tensor([0.7500, 0.5833, 0.9167])
>>> mlap = MultilabelAveragePrecision(num_labels=3, average="macro", thresholds=5)
>>> mlap(preds, target)
tensor(0.7778)
>>> mlap = MultilabelAveragePrecision(num_labels=3, average=None, thresholds=5)
>>> mlap(preds, target)
tensor([0.7500, 0.6667, 0.9167])
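The weighted option scales each label's score by its support (the number of positive samples for that label) before averaging. A brief sketch reusing the preds and target from the example above:

# support-weighted average over labels
weighted_ap = MultilabelAveragePrecision(num_labels=3, average="weighted", thresholds=None)
score = weighted_ap(preds, target)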
plot(val=None, ax=None)[source]

Plot a single or multiple values from the metric.

Parameters:
  • val (Union[Tensor, Sequence[Tensor], None]) – Either a single result from calling metric.forward or metric.compute or a list of these results. If no value is provided, will automatically call metric.compute and plot that result.

  • ax (Optional[Axes]) – A matplotlib axis object. If provided, will add the plot to that axis.

Return type:

Tuple[Figure, Union[Axes, ndarray]]

Returns:

Figure and Axes object

Raises:

ModuleNotFoundError – If matplotlib is not installed

>>> # Example plotting a single
>>> import torch
>>> from torchmetrics.classification import MultilabelAveragePrecision
>>> metric = MultilabelAveragePrecision(num_labels=3)
>>> metric.update(torch.rand(20,3), torch.randint(2, (20,3)))
>>> fig_, ax_ = metric.plot()
_images/average_precision-5.png
>>> # Example plotting multiple values
>>> import torch
>>> from torchmetrics.classification import MultilabelAveragePrecision
>>> metric = MultilabelAveragePrecision(num_labels=3)
>>> values = [ ]
>>> for _ in range(10):
...     values.append(metric(torch.rand(20,3), torch.randint(2, (20,3))))
>>> fig_, ax_ = metric.plot(values)
_images/average_precision-6.png

Functional Interface

torchmetrics.functional.average_precision(preds, target, task, thresholds=None, num_classes=None, num_labels=None, average='macro', ignore_index=None, validate_args=True)[source]

Compute the average precision (AP) score.

The AP score summarizes a precision-recall curve as a weighted mean of precisions at each threshold, with the difference in recall from the previous threshold as weight:

\[AP = \sum_{n} (R_n - R_{n-1}) P_n\]

where \(P_n\) and \(R_n\) are the precision and recall at threshold index \(n\), respectively. This value is equivalent to the area under the precision-recall curve (AUPRC).

This function is a simple wrapper to get the task-specific versions of this metric, which is done by setting the task argument to either 'binary', 'multiclass' or 'multilabel'. See the documentation of binary_average_precision(), multiclass_average_precision() and multilabel_average_precision() for the specific details of each argument's influence and examples.

Return type:

Optional[Tensor]

Legacy Example:
>>> from torchmetrics.functional.classification import average_precision
>>> pred = torch.tensor([0.0, 1.0, 2.0, 3.0])
>>> target = torch.tensor([0, 1, 1, 1])
>>> average_precision(pred, target, task="binary")
tensor(1.)
>>> pred = torch.tensor([[0.75, 0.05, 0.05, 0.05, 0.05],
...                      [0.05, 0.75, 0.05, 0.05, 0.05],
...                      [0.05, 0.05, 0.75, 0.05, 0.05],
...                      [0.05, 0.05, 0.05, 0.75, 0.05]])
>>> target = torch.tensor([0, 1, 3, 2])
>>> average_precision(pred, target, task="multiclass", num_classes=5, average=None)
tensor([1.0000, 1.0000, 0.2500, 0.2500,    nan])
binary_average_precision
torchmetrics.functional.classification.binary_average_precision(preds, target, thresholds=None, ignore_index=None, validate_args=True)[source]

Compute the average precision (AP) score for binary tasks.

The AP score summarizes a precision-recall curve as a weighted mean of precisions at each threshold, with the difference in recall from the previous threshold as weight:

\[AP = \sum_{n} (R_n - R_{n-1}) P_n\]

where \(P_n\) and \(R_n\) are the precision and recall at threshold index \(n\), respectively. This value is equivalent to the area under the precision-recall curve (AUPRC).

Accepts the following input tensors:

  • preds (float tensor): (N, ...). Preds should be a tensor containing probabilities or logits for each observation. If preds has values outside [0,1] range we consider the input to be logits and will auto apply sigmoid per element.

  • target (int tensor): (N, ...). Target should be a tensor containing ground truth labels, and therefore only contain {0,1} values (except if ignore_index is specified). The value 1 always encodes the positive class.

Additional dimension ... will be flattened into the batch dimension.

The implementation supports calculating the metric both in a non-binned but accurate version and in a binned version that is less accurate but more memory efficient. Setting the thresholds argument to None will activate the non-binned version that uses memory of size \(\mathcal{O}(n_{samples})\), whereas setting the thresholds argument to either an integer, a list or a 1d tensor will use a binned version that uses memory of size \(\mathcal{O}(n_{thresholds})\) (constant memory).

Parameters:
  • preds (Tensor) – Tensor with predictions

  • target (Tensor) – Tensor with true labels

  • thresholds (Union[int, List[float], Tensor, None]) –

    Can be one of:

    • If set to None, will use a non-binned approach where thresholds are dynamically calculated from all the data. Most accurate but also most memory consuming approach.

    • If set to an int (larger than 1), will use that number of thresholds linearly spaced from 0 to 1 as bins for the calculation.

    • If set to a list of floats, will use the indicated thresholds in the list as bins for the calculation.

    • If set to a 1d tensor of floats, will use the indicated thresholds in the tensor as bins for the calculation.

  • ignore_index (Optional[int]) – Specifies a target value that is ignored and does not contribute to the metric calculation

  • validate_args (bool) – bool indicating if input arguments and tensors should be validated for correctness. Set to False for faster computations.

Return type:

Tensor

Returns:

A single scalar with the average precision score

Example

>>> from torchmetrics.functional.classification import binary_average_precision
>>> preds = torch.tensor([0, 0.5, 0.7, 0.8])
>>> target = torch.tensor([0, 1, 1, 0])
>>> binary_average_precision(preds, target, thresholds=None)
tensor(0.5833)
>>> binary_average_precision(preds, target, thresholds=5)
tensor(0.6667)
multiclass_average_precision
torchmetrics.functional.classification.multiclass_average_precision(preds, target, num_classes, average='macro', thresholds=None, ignore_index=None, validate_args=True)[source]

Compute the average precision (AP) score for multiclass tasks.

The AP score summarizes a precision-recall curve as a weighted mean of precisions at each threshold, with the difference in recall from the previous threshold as weight:

\[AP = \sum_{n} (R_n - R_{n-1}) P_n\]

where \(P_n\) and \(R_n\) are the precision and recall at threshold index \(n\), respectively. This value is equivalent to the area under the precision-recall curve (AUPRC).

Accepts the following input tensors:

  • preds (float tensor): (N, C, ...). Preds should be a tensor containing probabilities or logits for each observation. If preds has values outside [0,1] range we consider the input to be logits and will auto apply softmax per sample.

  • target (int tensor): (N, ...). Target should be a tensor containing ground truth labels, and therefore only contain values in the [0, n_classes-1] range (except if ignore_index is specified).

Additional dimension ... will be flattened into the batch dimension.

The implementation supports calculating the metric both in a non-binned but accurate version and in a binned version that is less accurate but more memory efficient. Setting the thresholds argument to None will activate the non-binned version that uses memory of size \(\mathcal{O}(n_{samples})\), whereas setting the thresholds argument to either an integer, a list or a 1d tensor will use a binned version that uses memory of size \(\mathcal{O}(n_{thresholds} \times n_{classes})\) (constant memory).

Parameters:
  • preds (Tensor) – Tensor with predictions

  • target (Tensor) – Tensor with true labels

  • num_classes (int) – Integer specifying the number of classes

  • average (Optional[Literal['macro', 'weighted', 'none']]) –

    Defines the reduction that is applied over classes. Should be one of the following:

    • macro: Calculate score for each class and average them

    • weighted: calculates score for each class and computes weighted average using their support

    • "none" or None: calculates score for each class and applies no reduction

  • thresholds (Union[int, List[float], Tensor, None]) –

    Can be one of:

    • If set to None, will use a non-binned approach where thresholds are dynamically calculated from all the data. Most accurate but also most memory consuming approach.

    • If set to an int (larger than 1), will use that number of thresholds linearly spaced from 0 to 1 as bins for the calculation.

    • If set to a list of floats, will use the indicated thresholds in the list as bins for the calculation.

    • If set to a 1d tensor of floats, will use the indicated thresholds in the tensor as bins for the calculation.

  • ignore_index (Optional[int]) – Specifies a target value that is ignored and does not contribute to the metric calculation

  • validate_args (bool) – bool indicating if input arguments and tensors should be validated for correctness. Set to False for faster computations.

Return type:

Tensor

Returns:

If average=None|"none" then a 1d tensor of shape (n_classes, ) will be returned with AP score per class. If average="macro"|"weighted" then a single scalar is returned.

Example

>>> from torchmetrics.functional.classification import multiclass_average_precision
>>> preds = torch.tensor([[0.75, 0.05, 0.05, 0.05, 0.05],
...                       [0.05, 0.75, 0.05, 0.05, 0.05],
...                       [0.05, 0.05, 0.75, 0.05, 0.05],
...                       [0.05, 0.05, 0.05, 0.75, 0.05]])
>>> target = torch.tensor([0, 1, 3, 2])
>>> multiclass_average_precision(preds, target, num_classes=5, average="macro", thresholds=None)
tensor(0.6250)
>>> multiclass_average_precision(preds, target, num_classes=5, average=None, thresholds=None)
tensor([1.0000, 1.0000, 0.2500, 0.2500,    nan])
>>> multiclass_average_precision(preds, target, num_classes=5, average="macro", thresholds=5)
tensor(0.5000)
>>> multiclass_average_precision(preds, target, num_classes=5, average=None, thresholds=5)
tensor([1.0000, 1.0000, 0.2500, 0.2500, -0.0000])
multilabel_average_precision
torchmetrics.functional.classification.multilabel_average_precision(preds, target, num_labels, average='macro', thresholds=None, ignore_index=None, validate_args=True)[source]

Compute the average precision (AP) score for multilabel tasks.

The AP score summarizes a precision-recall curve as a weighted mean of precisions at each threshold, with the difference in recall from the previous threshold as weight:

\[AP = \sum_{n} (R_n - R_{n-1}) P_n\]

where \(P_n\) and \(R_n\) are the precision and recall at threshold index \(n\), respectively. This value is equivalent to the area under the precision-recall curve (AUPRC).

Accepts the following input tensors:

  • preds (float tensor): (N, C, ...). Preds should be a tensor containing probabilities or logits for each observation. If preds has values outside [0,1] range we consider the input to be logits and will auto apply sigmoid per element.

  • target (int tensor): (N, C, ...). Target should be a tensor containing ground truth labels, and therefore only contain {0,1} values (except if ignore_index is specified).

Additional dimension ... will be flattened into the batch dimension.

The implementation supports calculating the metric both in a non-binned but accurate version and in a binned version that is less accurate but more memory efficient. Setting the thresholds argument to None will activate the non-binned version that uses memory of size \(\mathcal{O}(n_{samples})\), whereas setting the thresholds argument to either an integer, a list or a 1d tensor will use a binned version that uses memory of size \(\mathcal{O}(n_{thresholds} \times n_{labels})\) (constant memory).

Parameters:
  • preds (Tensor) – Tensor with predictions

  • target (Tensor) – Tensor with true labels

  • num_labels (int) – Integer specifying the number of labels

  • average (Optional[Literal['micro', 'macro', 'weighted', 'none']]) –

    Defines the reduction that is applied over labels. Should be one of the following:

    • micro: Sum score over all labels

    • macro: Calculate score for each label and average them

    • weighted: calculates score for each label and computes weighted average using their support

    • "none" or None: calculates score for each label and applies no reduction

  • thresholds (Union[int, List[float], Tensor, None]) –

    Can be one of:

    • If set to None, will use a non-binned approach where thresholds are dynamically calculated from all the data. Most accurate but also most memory consuming approach.

    • If set to an int (larger than 1), will use that number of thresholds linearly spaced from 0 to 1 as bins for the calculation.

    • If set to a list of floats, will use the indicated thresholds in the list as bins for the calculation.

    • If set to a 1d tensor of floats, will use the indicated thresholds in the tensor as bins for the calculation.

  • ignore_index (Optional[int]) – Specifies a target value that is ignored and does not contribute to the metric calculation

  • validate_args (bool) – bool indicating if input arguments and tensors should be validated for correctness. Set to False for faster computations.

Return type:

Tensor

Returns:

If average=None|"none" then a 1d tensor of shape (n_labels, ) will be returned with AP score per label. If average="micro"|"macro"|"weighted" then a single scalar is returned.

Example

>>> from torchmetrics.functional.classification import multilabel_average_precision
>>> preds = torch.tensor([[0.75, 0.05, 0.35],
...                       [0.45, 0.75, 0.05],
...                       [0.05, 0.55, 0.75],
...                       [0.05, 0.65, 0.05]])
>>> target = torch.tensor([[1, 0, 1],
...                        [0, 0, 0],
...                        [0, 1, 1],
...                        [1, 1, 1]])
>>> multilabel_average_precision(preds, target, num_labels=3, average="macro", thresholds=None)
tensor(0.7500)
>>> multilabel_average_precision(preds, target, num_labels=3, average=None, thresholds=None)
tensor([0.7500, 0.5833, 0.9167])
>>> multilabel_average_precision(preds, target, num_labels=3, average="macro", thresholds=5)
tensor(0.7778)
>>> multilabel_average_precision(preds, target, num_labels=3, average=None, thresholds=5)
tensor([0.7500, 0.6667, 0.9167])

Calibration Error

Module Interface

class torchmetrics.CalibrationError(**kwargs)[source]

Top-label Calibration Error.

The expected calibration error can be used to quantify how well a given model is calibrated, e.g. how well the predicted output probabilities of the model match the actual probabilities of the ground truth distribution. Three different norms are implemented, each corresponding to variations on the calibration error metric.

\[\text{ECE} = \sum_i^N b_i \|(p_i - c_i)\|, \text{L1 norm (Expected Calibration Error)}\]
\[\text{MCE} = \max_{i} (p_i - c_i), \text{Infinity norm (Maximum Calibration Error)}\]
\[\text{RMSCE} = \sqrt{\sum_i^N b_i(p_i - c_i)^2}, \text{L2 norm (Root Mean Square Calibration Error)}\]

Where \(p_i\) is the top-1 prediction accuracy in bin \(i\), \(c_i\) is the average confidence of predictions in bin \(i\), and \(b_i\) is the fraction of data points in bin \(i\). Bins are constructed in a uniform way in the [0,1] range.

This function is a simple wrapper to get the task-specific versions of this metric, which is done by setting the task argument to either 'binary' or 'multiclass'. See the documentation of BinaryCalibrationError and MulticlassCalibrationError for the specific details of each argument's influence and examples.

static __new__(cls, task, n_bins=15, norm='l1', num_classes=None, ignore_index=None, validate_args=True, **kwargs)[source]

Initialize task metric.

Return type:

Metric

BinaryCalibrationError
class torchmetrics.classification.BinaryCalibrationError(n_bins=15, norm='l1', ignore_index=None, validate_args=True, **kwargs)[source]

Top-label Calibration Error for binary tasks.

The expected calibration error can be used to quantify how well a given model is calibrated, e.g. how well the predicted output probabilities of the model match the actual probabilities of the ground truth distribution. Three different norms are implemented, each corresponding to variations on the calibration error metric.

\[\text{ECE} = \sum_i^N b_i \|(p_i - c_i)\|, \text{L1 norm (Expected Calibration Error)}\]
\[\text{MCE} = \max_{i} (p_i - c_i), \text{Infinity norm (Maximum Calibration Error)}\]
\[\text{RMSCE} = \sqrt{\sum_i^N b_i(p_i - c_i)^2}, \text{L2 norm (Root Mean Square Calibration Error)}\]

Where \(p_i\) is the top-1 prediction accuracy in bin \(i\), \(c_i\) is the average confidence of predictions in bin \(i\), and \(b_i\) is the fraction of data points in bin \(i\). Bins are constructed in a uniform way in the [0,1] range.

As input to forward and update the metric accepts the following input:

  • preds (Tensor): A float tensor of shape (N, ...) containing probabilities or logits for each observation. If preds has values outside [0,1] range we consider the input to be logits and will auto apply sigmoid per element.

  • target (Tensor): An int tensor of shape (N, ...) containing ground truth labels, and therefore only contain {0,1} values (except if ignore_index is specified). The value 1 always encodes the positive class.

As output to forward and compute the metric returns the following output:

  • bce (Tensor): A scalar tensor containing the calibration error

Additional dimension ... will be flattened into the batch dimension.

Parameters:
  • n_bins (int) – Number of bins to use when computing the metric.

  • norm (Literal['l1', 'l2', 'max']) – Norm used to compare empirical and expected probability bins.

  • ignore_index (Optional[int]) – Specifies a target value that is ignored and does not contribute to the metric calculation

  • validate_args (bool) – bool indicating if input arguments and tensors should be validated for correctness. Set to False for faster computations.

  • kwargs (Any) – Additional keyword arguments, see Advanced metric settings for more info.

Example

>>> from torch import tensor
>>> from torchmetrics.classification import BinaryCalibrationError
>>> preds = tensor([0.25, 0.25, 0.55, 0.75, 0.75])
>>> target = tensor([0, 0, 1, 1, 1])
>>> metric = BinaryCalibrationError(n_bins=2, norm='l1')
>>> metric(preds, target)
tensor(0.2900)
>>> bce = BinaryCalibrationError(n_bins=2, norm='l2')
>>> bce(preds, target)
tensor(0.2918)
>>> bce = BinaryCalibrationError(n_bins=2, norm='max')
>>> bce(preds, target)
tensor(0.3167)
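To connect the numbers above to the ECE formula, the rough sketch below recomputes the n_bins=2, norm='l1' value by hand, under the assumption that for binary inputs the confidence is max(p, 1 - p) and the predicted label is p >= 0.5 (an illustrative check, not the library's internal code):

import torch

preds = torch.tensor([0.25, 0.25, 0.55, 0.75, 0.75])
target = torch.tensor([0, 0, 1, 1, 1])

confidence = torch.maximum(preds, 1 - preds)          # top-label confidence
correct = ((preds >= 0.5).long() == target).float()   # top-label accuracy

# two uniform bins on [0, 1]: [0, 0.5] and (0.5, 1]
bin_idx = torch.bucketize(confidence, torch.tensor([0.5]))
ece = torch.tensor(0.0)
for b in range(2):
    mask = bin_idx == b
    if mask.any():
        ece += mask.float().mean() * (correct[mask].mean() - confidence[mask].mean()).abs()
# ece is approximately 0.29, in line with the n_bins=2, norm='l1' result above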
plot(val=None, ax=None)[source]

Plot a single or multiple values from the metric.

Parameters:
  • val (Union[Tensor, Sequence[Tensor], None]) – Either a single result from calling metric.forward or metric.compute or a list of these results. If no value is provided, will automatically call metric.compute and plot that result.

  • ax (Optional[Axes]) – A matplotlib axis object. If provided, will add the plot to that axis.

Return type:

Tuple[Figure, Union[Axes, ndarray]]

Returns:

Figure object and Axes object

Raises:

ModuleNotFoundError – If matplotlib is not installed

>>> from torch import rand, randint
>>> # Example plotting a single value
>>> from torchmetrics.classification import BinaryCalibrationError
>>> metric = BinaryCalibrationError(n_bins=2, norm='l1')
>>> metric.update(rand(10), randint(2,(10,)))
>>> fig_, ax_ = metric.plot()
_images/calibration_error-1.png
>>> from torch import rand, randint
>>> # Example plotting multiple values
>>> from torchmetrics.classification import BinaryCalibrationError
>>> metric = BinaryCalibrationError(n_bins=2, norm='l1')
>>> values = [ ]
>>> for _ in range(10):
...     values.append(metric(rand(10), randint(2,(10,))))
>>> fig_, ax_ = metric.plot(values)
_images/calibration_error-2.png
MulticlassCalibrationError
class torchmetrics.classification.MulticlassCalibrationError(num_classes, n_bins=15, norm='l1', ignore_index=None, validate_args=True, **kwargs)[source]

Top-label Calibration Error for multiclass tasks.

The expected calibration error can be used to quantify how well a given model is calibrated, e.g. how well the predicted output probabilities of the model match the actual probabilities of the ground truth distribution. Three different norms are implemented, each corresponding to variations on the calibration error metric.

\[\text{ECE} = \sum_i^N b_i \|(p_i - c_i)\|, \text{L1 norm (Expected Calibration Error)}\]
\[\text{MCE} = \max_{i} (p_i - c_i), \text{Infinity norm (Maximum Calibration Error)}\]
\[\text{RMSCE} = \sqrt{\sum_i^N b_i(p_i - c_i)^2}, \text{L2 norm (Root Mean Square Calibration Error)}\]

Where \(p_i\) is the top-1 prediction accuracy in bin \(i\), \(c_i\) is the average confidence of predictions in bin \(i\), and \(b_i\) is the fraction of data points in bin \(i\). Bins are constructed in a uniform way in the [0,1] range.

As input to forward and update the metric accepts the following input:

  • preds (Tensor): A float tensor of shape (N, C, ...) containing probabilities or logits for each observation. If preds has values outside [0,1] range we consider the input to be logits and will auto apply softmax per sample.

  • target (Tensor): An int tensor of shape (N, ...) containing ground truth labels, and therefore only contain values in the [0, n_classes-1] range (except if ignore_index is specified).

Note

Additional dimension ... will be flattened into the batch dimension.

As output to forward and compute the metric returns the following output:

  • mcce (Tensor): A scalar tensor containing the calibration error

Parameters:
  • num_classes (int) – Integer specifying the number of classes

  • n_bins (int) – Number of bins to use when computing the metric.

  • norm (Literal['l1', 'l2', 'max']) – Norm used to compare empirical and expected probability bins.

  • ignore_index (Optional[int]) – Specifies a target value that is ignored and does not contribute to the metric calculation

  • validate_args (bool) – bool indicating if input arguments and tensors should be validated for correctness. Set to False for faster computations.

  • kwargs (Any) – Additional keyword arguments, see Advanced metric settings for more info.

Example

>>> from torch import tensor
>>> from torchmetrics.classification import MulticlassCalibrationError
>>> preds = tensor([[0.25, 0.20, 0.55],
...                 [0.55, 0.05, 0.40],
...                 [0.10, 0.30, 0.60],
...                 [0.90, 0.05, 0.05]])
>>> target = tensor([0, 1, 2, 0])
>>> metric = MulticlassCalibrationError(num_classes=3, n_bins=3, norm='l1')
>>> metric(preds, target)
tensor(0.2000)
>>> mcce = MulticlassCalibrationError(num_classes=3, n_bins=3, norm='l2')
>>> mcce(preds, target)
tensor(0.2082)
>>> mcce = MulticlassCalibrationError(num_classes=3, n_bins=3, norm='max')
>>> mcce(preds, target)
tensor(0.2333)
plot(val=None, ax=None)[source]

Plot a single or multiple values from the metric.

Parameters:
  • val (Union[Tensor, Sequence[Tensor], None]) – Either a single result from calling metric.forward or metric.compute or a list of these results. If no value is provided, will automatically call metric.compute and plot that result.

  • ax (Optional[Axes]) – A matplotlib axis object. If provided, will add the plot to that axis.

Return type:

Tuple[Figure, Union[Axes, ndarray]]

Returns:

Figure object and Axes object

Raises:

ModuleNotFoundError – If matplotlib is not installed

>>> from torch import randn, randint
>>> # Example plotting a single value
>>> from torchmetrics.classification import MulticlassCalibrationError
>>> metric = MulticlassCalibrationError(num_classes=3, n_bins=3, norm='l1')
>>> metric.update(randn(20,3).softmax(dim=-1), randint(3, (20,)))
>>> fig_, ax_ = metric.plot()
_images/calibration_error-3.png
>>> from torch import randn, randint
>>> # Example plotting multiple values
>>> from torchmetrics.classification import MulticlassCalibrationError
>>> metric = MulticlassCalibrationError(num_classes=3, n_bins=3, norm='l1')
>>> values = []
>>> for _ in range(20):
...     values.append(metric(randn(20,3).softmax(dim=-1), randint(3, (20,))))
>>> fig_, ax_ = metric.plot(values)

Functional Interface

torchmetrics.functional.calibration_error(preds, target, task, n_bins=15, norm='l1', num_classes=None, ignore_index=None, validate_args=True)[source]

Top-label Calibration Error.

The expected calibration error can be used to quantify how well a given model is calibrated, i.e. how well the predicted output probabilities of the model match the actual probabilities of the ground truth distribution. Three different norms are implemented, each corresponding to variations on the calibration error metric.

\[\text{ECE} = \sum_i^N b_i \|(p_i - c_i)\|, \text{L1 norm (Expected Calibration Error)}\]
\[\text{MCE} = \max_{i} (p_i - c_i), \text{Infinity norm (Maximum Calibration Error)}\]
\[\text{RMSCE} = \sqrt{\sum_i^N b_i(p_i - c_i)^2}, \text{L2 norm (Root Mean Square Calibration Error)}\]

Where \(p_i\) is the top-1 prediction accuracy in bin \(i\), \(c_i\) is the average confidence of predictions in bin \(i\), and \(b_i\) is the fraction of data points in bin \(i\). Bins are constructed uniformly in the [0,1] range.

This function is a simple wrapper to get the task specific versions of this metric, which is done by setting the task argument to either 'binary' or 'multiclass'. See the documentation of binary_calibration_error() and multiclass_calibration_error() for the specific details of how each argument influences the metric and for examples.
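
Since this wrapper only dispatches on task, a call mirrors the task-specific functions. A brief illustration with random inputs (so no fixed output is shown):

>>> from torch import randn, randint
>>> from torchmetrics.functional import calibration_error
>>> preds = randn(20, 3).softmax(dim=-1)
>>> target = randint(3, (20,))
>>> ce = calibration_error(preds, target, task="multiclass", num_classes=3, n_bins=5)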

binary_calibration_error
torchmetrics.functional.classification.binary_calibration_error(preds, target, n_bins=15, norm='l1', ignore_index=None, validate_args=True)[source]

Top-label Calibration Error for binary tasks.

The expected calibration error can be used to quantify how well a given model is calibrated, i.e. how well the predicted output probabilities of the model match the actual probabilities of the ground truth distribution. Three different norms are implemented, each corresponding to variations on the calibration error metric.

\[\text{ECE} = \sum_i^N b_i \|(p_i - c_i)\|, \text{L1 norm (Expected Calibration Error)}\]
\[\text{MCE} = \max_{i} (p_i - c_i), \text{Infinity norm (Maximum Calibration Error)}\]
\[\text{RMSCE} = \sqrt{\sum_i^N b_i(p_i - c_i)^2}, \text{L2 norm (Root Mean Square Calibration Error)}\]

Where \(p_i\) is the top-1 prediction accuracy in bin \(i\), \(c_i\) is the average confidence of predictions in bin \(i\), and \(b_i\) is the fraction of data points in bin \(i\). Bins are constructed uniformly in the [0,1] range.

Accepts the following input tensors:

  • preds (float tensor): (N, ...). Preds should be a tensor containing probabilities or logits for each observation. If preds has values outside [0,1] range we consider the input to be logits and will auto apply sigmoid per element.

  • target (int tensor): (N, ...). Target should be a tensor containing ground truth labels, and therefore only contain {0,1} values (except if ignore_index is specified). The value 1 always encodes the positive class.

Additional dimension ... will be flattened into the batch dimension.

Parameters:
  • preds (Tensor) – Tensor with predictions

  • target (Tensor) – Tensor with true labels

  • n_bins (int) – Number of bins to use when computing the metric.

  • norm (Literal['l1', 'l2', 'max']) – Norm used to compare empirical and expected probability bins.

  • ignore_index (Optional[int]) – Specifies a target value that is ignored and does not contribute to the metric calculation

  • validate_args (bool) – bool indicating if input arguments and tensors should be validated for correctness. Set to False for faster computations.

Return type:

Tensor

Example

>>> from torchmetrics.functional.classification import binary_calibration_error
>>> preds = torch.tensor([0.25, 0.25, 0.55, 0.75, 0.75])
>>> target = torch.tensor([0, 0, 1, 1, 1])
>>> binary_calibration_error(preds, target, n_bins=2, norm='l1')
tensor(0.2900)
>>> binary_calibration_error(preds, target, n_bins=2, norm='l2')
tensor(0.2918)
>>> binary_calibration_error(preds, target, n_bins=2, norm='max')
tensor(0.3167)
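
Since values outside the [0,1] range are treated as logits (with a sigmoid applied per element), passing logit-transformed probabilities is expected to reproduce the scores above. A quick sanity check, not part of the official examples:

>>> logits = torch.logit(preds)   # values now fall outside [0,1], so they are treated as logits
>>> same_as_l1 = binary_calibration_error(logits, target, n_bins=2, norm='l1')
>>> # same_as_l1 is expected to match tensor(0.2900) from the probability inputs above
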
multiclass_calibration_error
torchmetrics.functional.classification.multiclass_calibration_error(preds, target, num_classes, n_bins=15, norm='l1', ignore_index=None, validate_args=True)[source]

Top-label Calibration Error for multiclass tasks.

The expected calibration error can be used to quantify how well a given model is calibrated, i.e. how well the predicted output probabilities of the model match the actual probabilities of the ground truth distribution. Three different norms are implemented, each corresponding to variations on the calibration error metric.

\[\text{ECE} = \sum_i^N b_i \|(p_i - c_i)\|, \text{L1 norm (Expected Calibration Error)}\]
\[\text{MCE} = \max_{i} (p_i - c_i), \text{Infinity norm (Maximum Calibration Error)}\]
\[\text{RMSCE} = \sqrt{\sum_i^N b_i(p_i - c_i)^2}, \text{L2 norm (Root Mean Square Calibration Error)}\]

Where \(p_i\) is the top-1 prediction accuracy in bin \(i\), \(c_i\) is the average confidence of predictions in bin \(i\), and \(b_i\) is the fraction of data points in bin \(i\). Bins are constructed uniformly in the [0,1] range.

Accepts the following input tensors:

  • preds (float tensor): (N, C, ...). Preds should be a tensor containing probabilities or logits for each observation. If preds has values outside [0,1] range we consider the input to be logits and will auto apply softmax per sample.

  • target (int tensor): (N, ...). Target should be a tensor containing ground truth labels, and therefore only contain values in the [0, n_classes-1] range (except if ignore_index is specified).

Additional dimension ... will be flattened into the batch dimension.

Parameters:
  • preds (Tensor) – Tensor with predictions

  • target (Tensor) – Tensor with true labels

  • num_classes (int) – Integer specifying the number of classes

  • n_bins (int) – Number of bins to use when computing the metric.

  • norm (Literal['l1', 'l2', 'max']) – Norm used to compare empirical and expected probability bins.

  • ignore_index (Optional[int]) – Specifies a target value that is ignored and does not contribute to the metric calculation

  • validate_args (bool) – bool indicating if input arguments and tensors should be validated for correctness. Set to False for faster computations.

Return type:

Tensor

Example

>>> from torchmetrics.functional.classification import multiclass_calibration_error
>>> preds = torch.tensor([[0.25, 0.20, 0.55],
...                       [0.55, 0.05, 0.40],
...                       [0.10, 0.30, 0.60],
...                       [0.90, 0.05, 0.05]])
>>> target = torch.tensor([0, 1, 2, 0])
>>> multiclass_calibration_error(preds, target, num_classes=3, n_bins=3, norm='l1')
tensor(0.2000)
>>> multiclass_calibration_error(preds, target, num_classes=3, n_bins=3, norm='l2')
tensor(0.2082)
>>> multiclass_calibration_error(preds, target, num_classes=3, n_bins=3, norm='max')
tensor(0.2333)

Cohen Kappa

Module Interface

class torchmetrics.CohenKappa(**kwargs)[source]

Calculate Cohen’s kappa score that measures inter-annotator agreement.

\[\kappa = (p_o - p_e) / (1 - p_e)\]

where \(p_o\) is the empirical probability of agreement and \(p_e\) is the expected agreement when both annotators assign labels randomly. Note that \(p_e\) is estimated using a per-annotator empirical prior over the class labels.

This function is a simple wrapper to get the task specific versions of this metric, which is done by setting the task argument to either 'binary' or 'multiclass'. See the documentation of BinaryCohenKappa and MulticlassCohenKappa for the specific details of how each argument influences the metric and for examples.

Legacy Example:
>>> from torch import tensor
>>> target = tensor([1, 1, 0, 0])
>>> preds = tensor([0, 1, 0, 0])
>>> cohenkappa = CohenKappa(task="multiclass", num_classes=2)
>>> cohenkappa(preds, target)
tensor(0.5000)
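
Worked through by hand for the example above: the confusion matrix is [[2, 0], [1, 1]], so the observed agreement is \(p_o = 3/4\). The predictions assign class 0 with probability 3/4 and class 1 with probability 1/4, while the targets are split evenly, so the chance agreement is \(p_e = 0.75 \cdot 0.5 + 0.25 \cdot 0.5 = 0.5\), giving

\[\kappa = (0.75 - 0.5) / (1 - 0.5) = 0.5\]

which matches the tensor(0.5000) returned above.
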
static __new__(cls, task, threshold=0.5, num_classes=None, weights=None, ignore_index=None, validate_args=True, **kwargs)[source]

Initialize task metric.

Return type:

Metric

BinaryCohenKappa
class torchmetrics.classification.BinaryCohenKappa(threshold=0.5, ignore_index=None, weights=None, validate_args=True, **kwargs)[source]

Calculate Cohen’s kappa score that measures inter-annotator agreement for binary tasks.

\[\kappa = (p_o - p_e) / (1 - p_e)\]

where \(p_o\) is the empirical probability of agreement and \(p_e\) is the expected agreement when both annotators assign labels randomly. Note that \(p_e\) is estimated using a per-annotator empirical prior over the class labels.

As input to forward and update the metric accepts the following input:

  • preds (Tensor): An int or float tensor of shape (N, ...). If preds is a floating point tensor with values outside [0,1] range we consider the input to be logits and will auto apply sigmoid per element. Additionally, we convert to int tensor with thresholding using the value in threshold.

  • target (Tensor): An int tensor of shape (N, ...).

Note

Additional dimension ... will be flattened into the batch dimension.

As output to forward and compute the metric returns the following output:

  • bck (Tensor): A tensor containing the Cohen’s kappa score

Parameters:
  • threshold (float) – Threshold for transforming probability to binary (0,1) predictions

  • ignore_index (Optional[int]) – Specifies a target value that is ignored and does not contribute to the metric calculation

  • weights (Optional[Literal['linear', 'quadratic', 'none']]) –

    Weighting type to calculate the score. Choose from:

    • None or 'none': no weighting

    • 'linear': linear weighting

    • 'quadratic': quadratic weighting

  • validate_args (bool) – bool indicating if input arguments and tensors should be validated for correctness. Set to False for faster computations.

  • kwargs (Any) – Additional keyword arguments, see Advanced metric settings for more info.

Example (preds is int tensor):
>>> from torch import tensor
>>> from torchmetrics.classification import BinaryCohenKappa
>>> target = tensor([1, 1, 0, 0])
>>> preds = tensor([0, 1, 0, 0])
>>> metric = BinaryCohenKappa()
>>> metric(preds, target)
tensor(0.5000)
Example (preds is float tensor):
>>> from torchmetrics.classification import BinaryCohenKappa
>>> target = tensor([1, 1, 0, 0])
>>> preds = tensor([0.35, 0.85, 0.48, 0.01])
>>> metric = BinaryCohenKappa()
>>> metric(preds, target)
tensor(0.5000)
plot(val=None, ax=None)[source]

Plot a single or multiple values from the metric.

Parameters:
  • val (Union[Tensor, Sequence[Tensor], None]) – Either a single result from calling metric.forward or metric.compute or a list of these results. If no value is provided, will automatically call metric.compute and plot that result.

  • ax (Optional[Axes]) – A matplotlib axis object. If provided, the plot will be added to that axis

Return type:

Tuple[Figure, Union[Axes, ndarray]]

Returns:

Figure object and Axes object

Raises:

ModuleNotFoundError – If matplotlib is not installed

>>> from torch import rand, randint
>>> # Example plotting a single value
>>> from torchmetrics.classification import BinaryCohenKappa
>>> metric = BinaryCohenKappa()
>>> metric.update(rand(10), randint(2,(10,)))
>>> fig_, ax_ = metric.plot()
>>> from torch import rand, randint
>>> # Example plotting multiple values
>>> from torchmetrics.classification import BinaryCohenKappa
>>> metric = BinaryCohenKappa()
>>> values = [ ]
>>> for _ in range(10):
...     values.append(metric(rand(10), randint(2,(10,))))
>>> fig_, ax_ = metric.plot(values)
MulticlassCohenKappa
class torchmetrics.classification.MulticlassCohenKappa(num_classes, ignore_index=None, weights=None, validate_args=True, **kwargs)[source]

Calculate Cohen’s kappa score that measures inter-annotator agreement for multiclass tasks.

\[\kappa = (p_o - p_e) / (1 - p_e)\]

where \(p_o\) is the empirical probability of agreement and \(p_e\) is the expected agreement when both annotators assign labels randomly. Note that \(p_e\) is estimated using a per-annotator empirical prior over the class labels.

As input to forward and update the metric accepts the following input:

  • preds (Tensor): Either an int tensor of shape (N, ...) or a float tensor of shape (N, C, ...). If preds is a floating point tensor we apply torch.argmax along the C dimension to automatically convert probabilities/logits into an int tensor.

  • target (Tensor): An int tensor of shape (N, ...).

Note

Additional dimension ... will be flattened into the batch dimension.

As output to forward and compute the metric returns the following output:

  • mcck (Tensor): A tensor containing the Cohen’s kappa score

Parameters:
  • num_classes (int) – Integer specifying the number of classes

  • ignore_index (Optional[int]) – Specifies a target value that is ignored and does not contribute to the metric calculation

  • weights (Optional[Literal['linear', 'quadratic', 'none']]) –

    Weighting type to calculate the score. Choose from:

    • None or 'none': no weighting

    • 'linear': linear weighting

    • 'quadratic': quadratic weighting

  • validate_args (bool) – bool indicating if input arguments and tensors should be validated for correctness. Set to False for faster computations.

  • kwargs (Any) – Additional keyword arguments, see Advanced metric settings for more info.

Example (pred is integer tensor):
>>> from torch import tensor
>>> from torchmetrics.classification import MulticlassCohenKappa
>>> target = tensor([2, 1, 0, 0])
>>> preds = tensor([2, 1, 0, 1])
>>> metric = MulticlassCohenKappa(num_classes=3)
>>> metric(preds, target)
tensor(0.6364)
Example (pred is float tensor):
>>> from torchmetrics.classification import MulticlassCohenKappa
>>> target = tensor([2, 1, 0, 0])
>>> preds = tensor([[0.16, 0.26, 0.58],
...                 [0.22, 0.61, 0.17],
...                 [0.71, 0.09, 0.20],
...                 [0.05, 0.82, 0.13]])
>>> metric = MulticlassCohenKappa(num_classes=3)
>>> metric(preds, target)
tensor(0.6364)
plot(val=None, ax=None)[source]

Plot a single or multiple values from the metric.

Parameters:
  • val (Union[Tensor, Sequence[Tensor], None]) – Either a single result from calling metric.forward or metric.compute or a list of these results. If no value is provided, will automatically call metric.compute and plot that result.

  • ax (Optional[Axes]) – A matplotlib axis object. If provided, the plot will be added to that axis

Return type:

Tuple[Figure, Union[Axes, ndarray]]

Returns:

Figure object and Axes object

Raises:

ModuleNotFoundError – If matplotlib is not installed

>>> from torch import randn, randint
>>> # Example plotting a single value
>>> from torchmetrics.classification import MulticlassCohenKappa
>>> metric = MulticlassCohenKappa(num_classes=3)
>>> metric.update(randn(20,3).softmax(dim=-1), randint(3, (20,)))
>>> fig_, ax_ = metric.plot()
>>> from torch import randn, randint
>>> # Example plotting multiple values
>>> from torchmetrics.classification import MulticlassCohenKappa
>>> metric = MulticlassCohenKappa(num_classes=3)
>>> values = []
>>> for _ in range(20):
...     values.append(metric(randn(20,3).softmax(dim=-1), randint(3, (20,))))
>>> fig_, ax_ = metric.plot(values)

Functional Interface

cohen_kappa
torchmetrics.functional.cohen_kappa(preds, target, task, threshold=0.5, num_classes=None, weights=None, ignore_index=None, validate_args=True)[source]

Calculate Cohen’s kappa score that measures inter-annotator agreement. It is defined as

\[\kappa = (p_o - p_e) / (1 - p_e)\]

where \(p_o\) is the empirical probability of agreement and \(p_e\) is the expected agreement when both annotators assign labels randomly. Note that \(p_e\) is estimated using a per-annotator empirical prior over the class labels.

This function is a simple wrapper to get the task specific versions of this metric, which is done by setting the task argument to either 'binary' or 'multiclass'. See the documentation of binary_cohen_kappa() and multiclass_cohen_kappa() for the specific details of how each argument influences the metric and for examples.

Legacy Example:
>>> from torch import tensor
>>> target = tensor([1, 1, 0, 0])
>>> preds = tensor([0, 1, 0, 0])
>>> cohen_kappa(preds, target, task="multiclass", num_classes=2)
tensor(0.5000)
binary_cohen_kappa
torchmetrics.functional.classification.binary_cohen_kappa(preds, target, threshold=0.5, weights=None, ignore_index=None, validate_args=True)[source]

Calculate Cohen’s kappa score that measures inter-annotator agreement for binary tasks.

\[\kappa = (p_o - p_e) / (1 - p_e)\]

where \(p_o\) is the empirical probability of agreement and \(p_e\) is the expected agreement when both annotators assign labels randomly. Note that \(p_e\) is estimated using a per-annotator empirical prior over the class labels.

Accepts the following input tensors:

  • preds (int or float tensor): (N, ...). If preds is a floating point tensor with values outside [0,1] range we consider the input to be logits and will auto apply sigmoid per element. Additionally, we convert to int tensor with thresholding using the value in threshold.

  • target (int tensor): (N, ...)

Additional dimension ... will be flattened into the batch dimension.

Parameters:
  • preds (Tensor) – Tensor with predictions

  • target (Tensor) – Tensor with true labels

  • threshold (float) – Threshold for transforming probability to binary (0,1) predictions

  • weights (Optional[Literal['linear', 'quadratic', 'none']]) –

    Weighting type to calculate the score. Choose from:

    • None or 'none': no weighting

    • 'linear': linear weighting

    • 'quadratic': quadratic weighting

  • ignore_index (Optional[int]) – Specifies a target value that is ignored and does not contribute to the metric calculation

  • validate_args (bool) – bool indicating if input arguments and tensors should be validated for correctness. Set to False for faster computations.

  • kwargs – Additional keyword arguments, see Advanced metric settings for more info.

Return type:

Tensor

Example (preds is int tensor):
>>> from torch import tensor
>>> from torchmetrics.functional.classification import binary_cohen_kappa
>>> target = tensor([1, 1, 0, 0])
>>> preds = tensor([0, 1, 0, 0])
>>> binary_cohen_kappa(preds, target)
tensor(0.5000)
Example (preds is float tensor):
>>> from torchmetrics.functional.classification import binary_cohen_kappa
>>> target = tensor([1, 1, 0, 0])
>>> preds = tensor([0.35, 0.85, 0.48, 0.01])
>>> binary_cohen_kappa(preds, target)
tensor(0.5000)
multiclass_cohen_kappa
torchmetrics.functional.classification.multiclass_cohen_kappa(preds, target, num_classes, weights=None, ignore_index=None, validate_args=True)[source]

Calculate Cohen’s kappa score that measures inter-annotator agreement for multiclass tasks.

\[\kappa = (p_o - p_e) / (1 - p_e)\]

where \(p_o\) is the empirical probability of agreement and \(p_e\) is the expected agreement when both annotators assign labels randomly. Note that \(p_e\) is estimated using a per-annotator empirical prior over the class labels.

Accepts the following input tensors:

  • preds: (N, ...) (int tensor) or (N, C, ...) (float tensor). If preds is a floating point tensor we apply torch.argmax along the C dimension to automatically convert probabilities/logits into an int tensor.

  • target (int tensor): (N, ...)

Additional dimension ... will be flattened into the batch dimension.

Parameters:
  • preds (Tensor) – Tensor with predictions

  • target (Tensor) – Tensor with true labels

  • num_classes (int) – Integer specifying the number of classes

  • weights (Optional[Literal['linear', 'quadratic', 'none']]) –

    Weighting type to calculate the score. Choose from:

    • None or 'none': no weighting

    • 'linear': linear weighting

    • 'quadratic': quadratic weighting

  • ignore_index (Optional[int]) – Specifies a target value that is ignored and does not contribute to the metric calculation

  • validate_args (bool) – bool indicating if input arguments and tensors should be validated for correctness. Set to False for faster computations.

  • kwargs – Additional keyword arguments, see Advanced metric settings for more info.

Return type:

Tensor

Example (pred is integer tensor):
>>> from torch import tensor
>>> from torchmetrics.functional.classification import multiclass_cohen_kappa
>>> target = tensor([2, 1, 0, 0])
>>> preds = tensor([2, 1, 0, 1])
>>> multiclass_cohen_kappa(preds, target, num_classes=3)
tensor(0.6364)
Example (pred is float tensor):
>>> from torchmetrics.functional.classification import multiclass_cohen_kappa
>>> target = tensor([2, 1, 0, 0])
>>> preds = tensor([[0.16, 0.26, 0.58],
...                 [0.22, 0.61, 0.17],
...                 [0.71, 0.09, 0.20],
...                 [0.05, 0.82, 0.13]])
>>> multiclass_cohen_kappa(preds, target, num_classes=3)
tensor(0.6364)

Confusion Matrix

Module Interface

class torchmetrics.ConfusionMatrix(**kwargs)[source]

Compute the confusion matrix.

This function is a simple wrapper to get the task specific versions of this metric, which is done by setting the task argument to either 'binary', 'multiclass' or 'multilabel'. See the documentation of BinaryConfusionMatrix, MulticlassConfusionMatrix and MultilabelConfusionMatrix for the specific details of how each argument influences the metric and for examples.

Legacy Example:
>>> from torch import tensor
>>> target = tensor([1, 1, 0, 0])
>>> preds = tensor([0, 1, 0, 0])
>>> confmat = ConfusionMatrix(task="binary", num_classes=2)
>>> confmat(preds, target)
tensor([[2, 0],
        [1, 1]])
>>> target = tensor([2, 1, 0, 0])
>>> preds = tensor([2, 1, 0, 1])
>>> confmat = ConfusionMatrix(task="multiclass", num_classes=3)
>>> confmat(preds, target)
tensor([[1, 1, 0],
        [0, 1, 0],
        [0, 0, 1]])
>>> target = tensor([[0, 1, 0], [1, 0, 1]])
>>> preds = tensor([[0, 0, 1], [1, 0, 1]])
>>> confmat = ConfusionMatrix(task="multilabel", num_labels=3)
>>> confmat(preds, target)
tensor([[[1, 0], [0, 1]],
        [[1, 0], [1, 0]],
        [[0, 1], [0, 1]]])
static __new__(cls, task, threshold=0.5, num_classes=None, num_labels=None, normalize=None, ignore_index=None, validate_args=True, **kwargs)[source]

Initialize task metric.

Return type:

Metric

BinaryConfusionMatrix
class torchmetrics.classification.BinaryConfusionMatrix(threshold=0.5, ignore_index=None, normalize=None, validate_args=True, **kwargs)[source]

Compute the confusion matrix for binary tasks.

The confusion matrix \(C\) is constructed such that \(C_{i, j}\) is equal to the number of observations known to be in class \(i\) but predicted to be in class \(j\). Thus row indices of the confusion matrix correspond to the true class labels and column indices correspond to the predicted class labels.

For binary tasks, the confusion matrix is a 2x2 matrix with the following structure:

  • \(C_{0, 0}\): True negatives

  • \(C_{0, 1}\): False positives

  • \(C_{1, 0}\): False negatives

  • \(C_{1, 1}\): True positives

As input to forward and update the metric accepts the following input:

  • preds (Tensor): An int or float tensor of shape (N, ...). If preds is a floating point tensor with values outside [0,1] range we consider the input to be logits and will auto apply sigmoid per element. Additionally, we convert to int tensor with thresholding using the value in threshold.

  • target (Tensor): An int tensor of shape (N, ...).

As output to forward and compute the metric returns the following output:

  • confusion_matrix (Tensor): A tensor containing a (2, 2) matrix

Additional dimension ... will be flattened into the batch dimension.

Parameters:
  • threshold (float) – Threshold for transforming probability to binary (0,1) predictions

  • ignore_index (Optional[int]) – Specifies a target value that is ignored and does not contribute to the metric calculation

  • normalize (Optional[Literal['true', 'pred', 'all', 'none']]) –

    Normalization mode for confusion matrix. Choose from:

    • None or 'none': no normalization (default)

    • 'true': normalization over the targets (most commonly used)

    • 'pred': normalization over the predictions

    • 'all': normalization over the whole matrix

  • validate_args (bool) – bool indicating if input arguments and tensors should be validated for correctness. Set to False for faster computations.

  • kwargs (Any) – Additional keyword arguments, see Advanced metric settings for more info.

Example (preds is int tensor):
>>> from torchmetrics.classification import BinaryConfusionMatrix
>>> target = torch.tensor([1, 1, 0, 0])
>>> preds = torch.tensor([0, 1, 0, 0])
>>> bcm = BinaryConfusionMatrix()
>>> bcm(preds, target)
tensor([[2, 0],
        [1, 1]])
Example (preds is float tensor):
>>> from torchmetrics.classification import BinaryConfusionMatrix
>>> target = torch.tensor([1, 1, 0, 0])
>>> preds = torch.tensor([0.35, 0.85, 0.48, 0.01])
>>> bcm = BinaryConfusionMatrix()
>>> bcm(preds, target)
tensor([[2, 0],
        [1, 1]])
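
Because the 2x2 result is laid out as [[tn, fp], [fn, tp]], the four counts can be unpacked directly. A small illustrative sketch (not a TorchMetrics API), continuing the example above:

>>> tn, fp, fn, tp = bcm(preds, target).flatten()
>>> # here tn=2, fp=0, fn=1, tp=1
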
plot(val=None, ax=None, add_text=True, labels=None)[source]

Plot a single or multiple values from the metric.

Parameters:
  • val (Optional[Tensor]) – Either a single result from calling metric.forward or metric.compute or a list of these results. If no value is provided, will automatically call metric.compute and plot that result.

  • ax (Optional[Axes]) – A matplotlib axis object. If provided, the plot will be added to that axis

  • add_text (bool) – if the value of each cell should be added to the plot

  • labels (Optional[List[str]]) – a list of strings, if provided will be added to the plot to indicate the different classes

Return type:

Tuple[Figure, Union[Axes, ndarray]]

Returns:

Figure and Axes object

Raises:

ModuleNotFoundError – If matplotlib is not installed

>>> from torch import randint
>>> from torchmetrics.classification import BinaryConfusionMatrix
>>> metric = BinaryConfusionMatrix()
>>> metric.update(randint(2, (20,)), randint(2, (20,)))
>>> fig_, ax_ = metric.plot()
MulticlassConfusionMatrix
class torchmetrics.classification.MulticlassConfusionMatrix(num_classes, ignore_index=None, normalize=None, validate_args=True, **kwargs)[source]

Compute the confusion matrix for multiclass tasks.

The confusion matrix \(C\) is constructed such that \(C_{i, j}\) is equal to the number of observations known to be in class \(i\) but predicted to be in class \(j\). Thus row indices of the confusion matrix correspond to the true class labels and column indices correspond to the predicted class labels.

For multiclass tasks, the confusion matrix is a NxN matrix, where:

  • \(C_{i, i}\) represents the number of true positives for class \(i\)

  • \(\sum_{j=1, j\neq i}^N C_{i, j}\) represents the number of false negatives for class \(i\)

  • \(\sum_{i=1, i\neq j}^N C_{i, j}\) represents the number of false positives for class \(i\)

  • the sum of the remaining cells in the matrix represents the number of true negatives for class \(i\)

As input to forward and update the metric accepts the following input:

  • preds (Tensor): An int tensor of shape (N, ...) or a float tensor of shape (N, C, ...). If preds is a floating point tensor we apply torch.argmax along the C dimension to automatically convert probabilities/logits into an int tensor.

  • target (Tensor): An int tensor of shape (N, ...).

As output to forward and compute the metric returns the following output:

  • confusion_matrix: [num_classes, num_classes] matrix

Parameters:
  • num_classes (int) – Integer specifying the number of classes

  • ignore_index (Optional[int]) – Specifies a target value that is ignored and does not contribute to the metric calculation

  • normalize (Optional[Literal['true', 'pred', 'all', 'none']]) –

    Normalization mode for confusion matrix. Choose from:

    • None or 'none': no normalization (default)

    • 'true': normalization over the targets (most commonly used)

    • 'pred': normalization over the predictions

    • 'all': normalization over the whole matrix

  • validate_args (bool) – bool indicating if input arguments and tensors should be validated for correctness. Set to False for faster computations.

  • kwargs (Any) – Additional keyword arguments, see Advanced metric settings for more info.

Example (pred is integer tensor):
>>> from torch import tensor
>>> from torchmetrics.classification import MulticlassConfusionMatrix
>>> target = tensor([2, 1, 0, 0])
>>> preds = tensor([2, 1, 0, 1])
>>> metric = MulticlassConfusionMatrix(num_classes=3)
>>> metric(preds, target)
tensor([[1, 1, 0],
        [0, 1, 0],
        [0, 0, 1]])
Example (pred is float tensor):
>>> from torchmetrics.classification import MulticlassConfusionMatrix
>>> target = tensor([2, 1, 0, 0])
>>> preds = tensor([[0.16, 0.26, 0.58],
...                 [0.22, 0.61, 0.17],
...                 [0.71, 0.09, 0.20],
...                 [0.05, 0.82, 0.13]])
>>> metric = MulticlassConfusionMatrix(num_classes=3)
>>> metric(preds, target)
tensor([[1, 1, 0],
        [0, 1, 0],
        [0, 0, 1]])
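
Per-class statistics can be read off the matrix, for example recall by dividing the diagonal by the row sums. A small illustrative sketch (not a TorchMetrics API), continuing the example above:

>>> confmat = metric(preds, target)
>>> recall_per_class = confmat.diag() / confmat.sum(dim=1)   # rows correspond to true classes
>>> # for the matrix above this gives tensor([0.5000, 1.0000, 1.0000])
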
plot(val=None, ax=None, add_text=True, labels=None)[source]

Plot a single or multiple values from the metric.

Parameters:
  • val (Optional[Tensor]) – Either a single result from calling metric.forward or metric.compute or a list of these results. If no value is provided, will automatically call metric.compute and plot that result.

  • ax (Optional[Axes]) – A matplotlib axis object. If provided, the plot will be added to that axis

  • add_text (bool) – if the value of each cell should be added to the plot

  • labels (Optional[List[str]]) – a list of strings, if provided will be added to the plot to indicate the different classes

Return type:

Tuple[Figure, Union[Axes, ndarray]]

Returns:

Figure and Axes object

Raises:

ModuleNotFoundError – If matplotlib is not installed

>>> from torch import randint
>>> from torchmetrics.classification import MulticlassConfusionMatrix
>>> metric = MulticlassConfusionMatrix(num_classes=5)
>>> metric.update(randint(5, (20,)), randint(5, (20,)))
>>> fig_, ax_ = metric.plot()
MultilabelConfusionMatrix
class torchmetrics.classification.MultilabelConfusionMatrix(num_labels, threshold=0.5, ignore_index=None, normalize=None, validate_args=True, **kwargs)[source]

Compute the confusion matrix for multilabel tasks.

The confusion matrix \(C\) is constructed such that \(C_{i, j}\) is equal to the number of observations known to be in class \(i\) but predicted to be in class \(j\). Thus row indices of the confusion matrix correspond to the true class labels and column indices correspond to the predicted class labels.

For multilabel tasks, the confusion matrix is a Nx2x2 tensor, where each 2x2 matrix corresponds to the confusion for that label. The structure of each 2x2 matrix is as follows:

  • \(C_{0, 0}\): True negatives

  • \(C_{0, 1}\): False positives

  • \(C_{1, 0}\): False negatives

  • \(C_{1, 1}\): True positives

As input to forward and update the metric accepts the following input:

  • preds (int or float tensor): (N, C, ...). If preds is a floating point tensor with values outside [0,1] range we consider the input to be logits and will auto apply sigmoid per element. Additionally, we convert to int tensor with thresholding using the value in threshold.

  • target (int tensor): (N, C, ...)

As output to forward and compute the metric returns the following output:

  • confusion_matrix (Tensor): A [num_labels, 2, 2] matrix

Parameters:
  • num_labels (int) – Integer specifying the number of labels

  • threshold (float) – Threshold for transforming probability to binary (0,1) predictions

  • ignore_index (Optional[int]) – Specifies a target value that is ignored and does not contribute to the metric calculation

  • normalize (Optional[Literal['true', 'pred', 'all', 'none']]) –

    Normalization mode for confusion matrix. Choose from:

    • None or 'none': no normalization (default)

    • 'true': normalization over the targets (most commonly used)

    • 'pred': normalization over the predictions

    • 'all': normalization over the whole matrix

  • validate_args (bool) – bool indicating if input arguments and tensors should be validated for correctness. Set to False for faster computations.

  • kwargs (Any) – Additional keyword arguments, see Advanced metric settings for more info.

Example (preds is int tensor):
>>> from torch import tensor
>>> from torchmetrics.classification import MultilabelConfusionMatrix
>>> target = tensor([[0, 1, 0], [1, 0, 1]])
>>> preds = tensor([[0, 0, 1], [1, 0, 1]])
>>> metric = MultilabelConfusionMatrix(num_labels=3)
>>> metric(preds, target)
tensor([[[1, 0], [0, 1]],
        [[1, 0], [1, 0]],
        [[0, 1], [0, 1]]])
Example (preds is float tensor):
>>> from torchmetrics.classification import MultilabelConfusionMatrix
>>> target = tensor([[0, 1, 0], [1, 0, 1]])
>>> preds = tensor([[0.11, 0.22, 0.84], [0.73, 0.33, 0.92]])
>>> metric = MultilabelConfusionMatrix(num_labels=3)
>>> metric(preds, target)
tensor([[[1, 0], [0, 1]],
        [[1, 0], [1, 0]],
        [[0, 1], [0, 1]]])
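
Each 2x2 block follows the [[tn, fp], [fn, tp]] layout per label, so per-label count vectors can be pulled out in one go. A small illustrative sketch (not a TorchMetrics API), continuing the example above:

>>> confmat = metric(preds, target)                     # shape [num_labels, 2, 2]
>>> tn, fp, fn, tp = confmat.flatten(1).unbind(dim=1)   # one count vector per label
>>> # for the example above: tp = tensor([1, 0, 1]) and fp = tensor([0, 0, 1])
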
plot(val=None, ax=None, add_text=True, labels=None)[source]

Plot a single or multiple values from the metric.

Parameters:
  • val (Optional[Tensor]) – Either a single result from calling metric.forward or metric.compute or a list of these results. If no value is provided, will automatically call metric.compute and plot that result.

  • ax (Optional[Axes]) – A matplotlib axis object. If provided, the plot will be added to that axis

  • add_text (bool) – if the value of each cell should be added to the plot

  • labels (Optional[List[str]]) – a list of strings, if provided will be added to the plot to indicate the different classes

Return type:

Tuple[Figure, Union[Axes, ndarray]]

Returns:

Figure and Axes object

Raises:

ModuleNotFoundError – If matplotlib is not installed

>>> from torch import randint
>>> from torchmetrics.classification import MultilabelConfusionMatrix
>>> metric = MultilabelConfusionMatrix(num_labels=3)
>>> metric.update(randint(2, (20, 3)), randint(2, (20, 3)))
>>> fig_, ax_ = metric.plot()

Functional Interface

confusion_matrix
torchmetrics.functional.confusion_matrix(preds, target, task, threshold=0.5, num_classes=None, num_labels=None, normalize=None, ignore_index=None, validate_args=True)[source]

Compute the confusion matrix.

This function is a simple wrapper to get the task specific versions of this metric, which is done by setting the task argument to either 'binary', 'multiclass' or 'multilabel'. See the documentation of binary_confusion_matrix(), multiclass_confusion_matrix() and multilabel_confusion_matrix() for the specific details of how each argument influences the metric and for examples.

Return type:

Tensor

Legacy Example:
>>> from torch import tensor
>>> from torchmetrics.classification import ConfusionMatrix
>>> target = tensor([1, 1, 0, 0])
>>> preds = tensor([0, 1, 0, 0])
>>> confmat = ConfusionMatrix(task="binary")
>>> confmat(preds, target)
tensor([[2, 0],
        [1, 1]])
>>> target = tensor([2, 1, 0, 0])
>>> preds = tensor([2, 1, 0, 1])
>>> confmat = ConfusionMatrix(task="multiclass", num_classes=3)
>>> confmat(preds, target)
tensor([[1, 1, 0],
        [0, 1, 0],
        [0, 0, 1]])
>>> target = tensor([[0, 1, 0], [1, 0, 1]])
>>> preds = tensor([[0, 0, 1], [1, 0, 1]])
>>> confmat = ConfusionMatrix(task="multilabel", num_labels=3)
>>> confmat(preds, target)
tensor([[[1, 0], [0, 1]],
        [[1, 0], [1, 0]],
        [[0, 1], [0, 1]]])
binary_confusion_matrix
torchmetrics.functional.classification.binary_confusion_matrix(preds, target, threshold=0.5, normalize=None, ignore_index=None, validate_args=True)[source]

Compute the confusion matrix for binary tasks.

Accepts the following input tensors:

  • preds (int or float tensor): (N, ...). If preds is a floating point tensor with values outside [0,1] range we consider the input to be logits and will auto apply sigmoid per element. Additionally, we convert to int tensor with thresholding using the value in threshold.

  • target (int tensor): (N, ...)

Additional dimension ... will be flattened into the batch dimension.

Parameters:
  • preds (Tensor) – Tensor with predictions

  • target (Tensor) – Tensor with true labels

  • threshold (float) – Threshold for transforming probability to binary (0,1) predictions

  • normalize (Optional[Literal['true', 'pred', 'all', 'none']]) –

    Normalization mode for confusion matrix. Choose from:

    • None or 'none': no normalization (default)

    • 'true': normalization over the targets (most commonly used)

    • 'pred': normalization over the predictions

    • 'all': normalization over the whole matrix

  • ignore_index (Optional[int]) – Specifies a target value that is ignored and does not contribute to the metric calculation

  • validate_args (bool) – bool indicating if input arguments and tensors should be validated for correctness. Set to False for faster computations.

Return type:

Tensor

Returns:

A [2, 2] tensor

Example (preds is int tensor):
>>> from torch import tensor
>>> from torchmetrics.functional.classification import binary_confusion_matrix
>>> target = tensor([1, 1, 0, 0])
>>> preds = tensor([0, 1, 0, 0])
>>> binary_confusion_matrix(preds, target)
tensor([[2, 0],
        [1, 1]])
Example (preds is float tensor):
>>> from torchmetrics.functional.classification import binary_confusion_matrix
>>> target = tensor([1, 1, 0, 0])
>>> preds = tensor([0.35, 0.85, 0.48, 0.01])
>>> binary_confusion_matrix(preds, target)
tensor([[2, 0],
        [1, 1]])
multiclass_confusion_matrix
torchmetrics.functional.classification.multiclass_confusion_matrix(preds, target, num_classes, normalize=None, ignore_index=None, validate_args=True)[source]

Compute the confusion matrix for multiclass tasks.

Accepts the following input tensors:

  • preds: (N, ...) (int tensor) or (N, C, ...) (float tensor). If preds is a floating point tensor we apply torch.argmax along the C dimension to automatically convert probabilities/logits into an int tensor.

  • target (int tensor): (N, ...)

Additional dimension ... will be flattened into the batch dimension.

Parameters:
  • preds (Tensor) – Tensor with predictions

  • target (Tensor) – Tensor with true labels

  • num_classes (int) – Integer specifying the number of classes

  • normalize (Optional[Literal['true', 'pred', 'all', 'none']]) –

    Normalization mode for confusion matrix. Choose from:

    • None or 'none': no normalization (default)

    • 'true': normalization over the targets (most commonly used)

    • 'pred': normalization over the predictions

    • 'all': normalization over the whole matrix

  • ignore_index (Optional[int]) – Specifies a target value that is ignored and does not contribute to the metric calculation

  • validate_args (bool) – bool indicating if input arguments and tensors should be validated for correctness. Set to False for faster computations.

Return type:

Tensor

Returns:

A [num_classes, num_classes] tensor

Example (pred is integer tensor):
>>> from torch import tensor
>>> from torchmetrics.functional.classification import multiclass_confusion_matrix
>>> target = tensor([2, 1, 0, 0])
>>> preds = tensor([2, 1, 0, 1])
>>> multiclass_confusion_matrix(preds, target, num_classes=3)
tensor([[1, 1, 0],
        [0, 1, 0],
        [0, 0, 1]])
Example (pred is float tensor):
>>> from torchmetrics.functional.classification import multiclass_confusion_matrix
>>> target = tensor([2, 1, 0, 0])
>>> preds = tensor([[0.16, 0.26, 0.58],
...                 [0.22, 0.61, 0.17],
...                 [0.71, 0.09, 0.20],
...                 [0.05, 0.82, 0.13]])
>>> multiclass_confusion_matrix(preds, target, num_classes=3)
tensor([[1, 1, 0],
        [0, 1, 0],
        [0, 0, 1]])
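
With normalize='true' each row is divided by its row sum, so the rows express per-class recall. A brief sketch continuing the example above (values hand-checked; the printed tensor formatting may differ):

>>> normed = multiclass_confusion_matrix(preds, target, num_classes=3, normalize='true')
>>> # rows now sum to 1: [[0.5, 0.5, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 1.0]]
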
multilabel_confusion_matrix
torchmetrics.functional.classification.multilabel_confusion_matrix(preds, target, num_labels, threshold=0.5, normalize=None, ignore_index=None, validate_args=True)[source]

Compute the confusion matrix for multilabel tasks.

Accepts the following input tensors:

  • preds (int or float tensor): (N, C, ...). If preds is a floating point tensor with values outside [0,1] range we consider the input to be logits and will auto apply sigmoid per element. Additionally, we convert to int tensor with thresholding using the value in threshold.

  • target (int tensor): (N, C, ...)

Additional dimension ... will be flattened into the batch dimension.

Parameters:
  • preds (Tensor) – Tensor with predictions

  • target (Tensor) – Tensor with true labels

  • num_labels (int) – Integer specifying the number of labels

  • threshold (float) – Threshold for transforming probability to binary (0,1) predictions

  • normalize (Optional[Literal['true', 'pred', 'all', 'none']]) –

    Normalization mode for confusion matrix. Choose from:

    • None or 'none': no normalization (default)

    • 'true': normalization over the targets (most commonly used)

    • 'pred': normalization over the predictions

    • 'all': normalization over the whole matrix

  • ignore_index (Optional[int]) – Specifies a target value that is ignored and does not contribute to the metric calculation

  • validate_args (bool) – bool indicating if input arguments and tensors should be validated for correctness. Set to False for faster computations.

Return type:

Tensor

Returns:

A [num_labels, 2, 2] tensor

Example (preds is int tensor):
>>> from torch import tensor
>>> from torchmetrics.functional.classification import multilabel_confusion_matrix
>>> target = tensor([[0, 1, 0], [1, 0, 1]])
>>> preds = tensor([[0, 0, 1], [1, 0, 1]])
>>> multilabel_confusion_matrix(preds, target, num_labels=3)
tensor([[[1, 0], [0, 1]],
        [[1, 0], [1, 0]],
        [[0, 1], [0, 1]]])
Example (preds is float tensor):
>>> from torchmetrics.functional.classification import multilabel_confusion_matrix
>>> target = tensor([[0, 1, 0], [1, 0, 1]])
>>> preds = tensor([[0.11, 0.22, 0.84], [0.73, 0.33, 0.92]])
>>> multilabel_confusion_matrix(preds, target, num_labels=3)
tensor([[[1, 0], [0, 1]],
        [[1, 0], [1, 0]],
        [[0, 1], [0, 1]]])

Coverage Error

Module Interface

class torchmetrics.classification.MultilabelCoverageError(num_labels, ignore_index=None, validate_args=True, **kwargs)[source]

Compute Multilabel coverage error.

The score measures how far we need to go through the ranked scores to cover all true labels. The best value is equal to the average number of labels in the target tensor per sample.

As input to forward and update the metric accepts the following input:

  • preds (Tensor): A float tensor of shape (N, C, ...). Preds should be a tensor containing probabilities or logits for each observation. If preds has values outside [0,1] range we consider the input to be logits and will auto apply sigmoid per element.

  • target (Tensor): An int tensor of shape (N, C, ...). Target should be a tensor containing ground truth labels, and therefore only contain {0,1} values (except if ignore_index is specified).

Note

Additional dimension ... will be flattened into the batch dimension.

As output to forward and compute the metric returns the following output:

  • mlce (Tensor): A tensor containing the multilabel coverage error.

Parameters:
  • num_labels (int) – Integer specifying the number of labels

  • ignore_index (Optional[int]) – Specifies a target value that is ignored and does not contribute to the metric calculation

  • validate_args (bool) – bool indicating if input arguments and tensors should be validated for correctness. Set to False for faster computations.

Example

>>> from torchmetrics.classification import MultilabelCoverageError
>>> _ = torch.manual_seed(42)
>>> preds = torch.rand(10, 5)
>>> target = torch.randint(2, (10, 5))
>>> mlce = MultilabelCoverageError(num_labels=5)
>>> mlce(preds, target)
tensor(3.9000)
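
As a toy illustration of the definition: for a single sample whose two true labels receive the highest and second-highest scores, the ranked list must be traversed to depth 2, so the coverage error is 2. This is a hand-checked sketch rather than an official example:

>>> toy_preds = torch.tensor([[0.8, 0.3, 0.6]])
>>> toy_target = torch.tensor([[1, 0, 1]])
>>> toy_mlce = MultilabelCoverageError(num_labels=3)
>>> toy_mlce(toy_preds, toy_target)   # expected value: 2.0
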
plot(val=None, ax=None)[source]

Plot a single or multiple values from the metric.

Parameters:
  • val (Union[Tensor, Sequence[Tensor], None]) – Either a single result from calling metric.forward or metric.compute or a list of these results. If no value is provided, will automatically call metric.compute and plot that result.

  • ax (Optional[Axes]) – A matplotlib axis object. If provided, the plot will be added to that axis

Return type:

Tuple[Figure, Union[Axes, ndarray]]

Returns:

Figure object and Axes object

Raises:

ModuleNotFoundError – If matplotlib is not installed

>>> from torch import rand, randint
>>> # Example plotting a single value
>>> from torchmetrics.classification import MultilabelCoverageError
>>> metric = MultilabelCoverageError(num_labels=3)
>>> metric.update(rand(20, 3), randint(2, (20, 3)))
>>> fig_, ax_ = metric.plot()
>>> from torch import rand, randint
>>> # Example plotting multiple values
>>> from torchmetrics.classification import MultilabelCoverageError
>>> metric = MultilabelCoverageError(num_labels=3)
>>> values = [ ]
>>> for _ in range(10):
...     values.append(metric(rand(20, 3), randint(2, (20, 3))))
>>> fig_, ax_ = metric.plot(values)

Functional Interface

torchmetrics.functional.classification.multilabel_coverage_error(preds, target, num_labels, ignore_index=None, validate_args=True)[source]

Compute multilabel coverage error [1].

The score measures how far we need to go through the ranked scores to cover all true labels. The best value is equal to the average number of labels in the target tensor per sample.

Accepts the following input tensors:

  • preds (float tensor): (N, C, ...). Preds should be a tensor containing probabilities or logits for each observation. If preds has values outside [0,1] range we consider the input to be logits and will auto apply sigmoid per element.

  • target (int tensor): (N, C, ...). Target should be a tensor containing ground truth labels, and therefore only contain {0,1} values (except if ignore_index is specified).

Additional dimension ... will be flattened into the batch dimension.

Parameters:
  • preds (Tensor) – Tensor with predictions

  • target (Tensor) – Tensor with true labels

  • num_labels (int) – Integer specifying the number of labels

  • ignore_index (Optional[int]) – Specifies a target value that is ignored and does not contribute to the metric calculation

  • validate_args (bool) – bool indicating if input arguments and tensors should be validated for correctness. Set to False for faster computations.

Return type:

Tensor

Example

>>> from torchmetrics.functional.classification import multilabel_coverage_error
>>> _ = torch.manual_seed(42)
>>> preds = torch.rand(10, 5)
>>> target = torch.randint(2, (10, 5))
>>> multilabel_coverage_error(preds, target, num_labels=5)
tensor(3.9000)

References

[1] Tsoumakas, G., Katakis, I., & Vlahavas, I. (2010). Mining multi-label data. In Data mining and knowledge discovery handbook (pp. 667-685). Springer US.

Dice

Module Interface

class torchmetrics.Dice(zero_division=0, num_classes=None, threshold=0.5, average='micro', mdmc_average='global', ignore_index=None, top_k=None, multiclass=None, **kwargs)[source]

Compute Dice.

\[\text{Dice} = \frac{\text{2 * TP}}{\text{2 * TP} + \text{FP} + \text{FN}}\]

Where \(\text{TP}\), \(\text{FP}\) and \(\text{FN}\) represent the number of true positives, false positives and false negatives, respectively.

It is recommended to set ignore_index to the index of the background class.

The reduction method (how the dice scores are aggregated) is controlled by the average parameter, and additionally by the mdmc_average parameter in the multi-dimensional multi-class case.

As input to forward and update the metric accepts the following input:

  • preds (Tensor): Predictions from model (probabilities, logits or labels)

  • target (Tensor): Ground truth values

As output to forward and compute the metric returns the following output:

  • dice (Tensor): A tensor containing the dice score.

    • If average in ['micro', 'macro', 'weighted', 'samples'], a one-element tensor will be returned

    • If average in ['none', None], the shape will be (C,), where C stands for the number of classes

Parameters:
  • num_classes – Number of classes. Necessary for 'macro' and None average methods.

  • threshold – Threshold for transforming probability or logit predictions to binary (0,1) predictions, in the case of binary or multi-label inputs. Default value of 0.5 corresponds to input being probabilities.

  • zero_division – The value to use for the score if denominator equals zero.

  • average

    Defines the reduction that is applied. Should be one of the following:

    • 'micro' [default]: Calculate the metric globally, across all samples and classes.

    • 'macro': Calculate the metric for each class separately, and average the metrics across classes (with equal weights for each class).

    • 'weighted': Calculate the metric for each class separately, and average the metrics across classes, weighting each class by its support (tp + fn).

    • 'none' or None: Calculate the metric for each class separately, and return the metric for every class.

    • 'samples': Calculate the metric for each sample, and average the metrics across samples (with equal weights for each sample).

    Note

    What is considered a sample in the multi-dimensional multi-class case depends on the value of mdmc_average.

  • mdmc_average

    Defines how averaging is done for multi-dimensional multi-class inputs (on top of the average parameter). Should be one of the following:

    • None [default]: Should be left unchanged if your data is not multi-dimensional multi-class.

    • 'samplewise': In this case, the statistics are computed separately for each sample on the N axis, and then averaged over samples. The computation for each sample is done by treating the flattened extra axes ... as the N dimension within the sample, and computing the metric for the sample based on that.

    • 'global': In this case the N and ... dimensions of the inputs are flattened into a new N_X sample axis, i.e. the inputs are treated as if they were (N_X, C). From here on the average parameter applies as usual.

  • ignore_index – Integer specifying a target class to ignore. If given, this class index does not contribute to the returned score, regardless of reduction method. If an index is ignored, and average=None or 'none', the score for the ignored class will be returned as nan.

  • top_k – Number of the highest probability or logit score predictions considered for finding the correct label, relevant only for (multi-dimensional) multi-class inputs. The default value (None) will be interpreted as 1 for these inputs. Should be left at default (None) for all other types of inputs.

  • multiclass – Used only in certain special cases, where you want to treat inputs as a different type than what they appear to be.

  • kwargs – Additional keyword arguments, see Advanced metric settings for more info.

Raises:
  • ValueError – If average is not one of "micro", "macro", "samples", "none" or None.

  • ValueError – If mdmc_average is not one of None, "samplewise", "global".

  • ValueError – If average is set but num_classes is not provided.

  • ValueError – If num_classes is set and ignore_index is not in the range [0, num_classes).

Example

>>> from torch import tensor
>>> from torchmetrics.classification import Dice
>>> preds  = tensor([2, 0, 2, 1])
>>> target = tensor([1, 1, 2, 0])
>>> dice = Dice(average='micro')
>>> dice(preds, target)
tensor(0.2500)
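
Worked through by hand for the example above: only the third prediction matches its target, so under micro-averaging TP = 1, while each of the three wrong predictions contributes one false positive (for the predicted class) and one false negative (for the true class), giving FP = FN = 3 and

\[\text{Dice} = \frac{2 \cdot 1}{2 \cdot 1 + 3 + 3} = 0.25\]

which matches the tensor(0.2500) returned above.
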
plot(val=None, ax=None)[source]

Plot a single or multiple values from the metric.

Parameters:
  • val (Union[Tensor, Sequence[Tensor], None]) – Either a single result from calling metric.forward or metric.compute or a list of these results. If no value is provided, will automatically call metric.compute and plot that result.

  • ax (Optional[Axes]) – A matplotlib axis object. If provided, the plot will be added to that axis

Return type:

Tuple[Figure, Union[Axes, ndarray]]

Returns:

Figure object and Axes object

Raises:

ModuleNotFoundError – If matplotlib is not installed

>>> # Example plotting a single value
>>> from torch import randint
>>> from torchmetrics.classification import Dice
>>> metric = Dice()
>>> metric.update(randint(2,(10,)), randint(2,(10,)))
>>> fig_, ax_ = metric.plot()
>>> # Example plotting multiple values
>>> from torch import randint
>>> from torchmetrics.classification import Dice
>>> metric = Dice()
>>> values = [ ]
>>> for _ in range(10):
...     values.append(metric(randint(2,(10,)), randint(2,(10,))))
>>> fig_, ax_ = metric.plot(values)

Functional Interface

torchmetrics.functional.dice(preds, target, zero_division=0, average='micro', mdmc_average='global', threshold=0.5, top_k=None, num_classes=None, multiclass=None, ignore_index=None)[source]

Compute Dice.

\[\text{Dice} = \frac{\text{2 * TP}}{\text{2 * TP} + \text{FP} + \text{FN}}\]

Where \(\text{TP}\), \(\text{FP}\) and \(\text{FN}\) represent the number of true positives, false positives and false negatives, respectively.

It is recommended to set ignore_index to the index of the background class.

The reduction method (how the dice scores are aggregated) is controlled by the average parameter, and additionally by the mdmc_average parameter in the multi-dimensional multi-class case.

Parameters:
  • preds (Tensor) – Predictions from model (probabilities, logits or labels)

  • target (Tensor) – Ground truth values

  • zero_division (int) – The value to use for the score if denominator equals zero

  • average (Optional[str]) –

    Defines the reduction that is applied. Should be one of the following:

    • 'micro' [default]: Calculate the metric globally, across all samples and classes.

    • 'macro': Calculate the metric for each class separately, and average the metrics across classes (with equal weights for each class).

    • 'weighted': Calculate the metric for each class separately, and average the metrics across classes, weighting each class by its support (tp + fn).

    • 'none' or None: Calculate the metric for each class separately, and return the metric for every class.

    • 'samples': Calculate the metric for each sample, and average the metrics across samples (with equal weights for each sample).

    Note

    What is considered a sample in the multi-dimensional multi-class case depends on the value of mdmc_average.

    Note

    If 'none' and a given class doesn’t occur in the preds or target, the value for the class will be nan.

  • mdmc_average (Optional[str]) –

    Defines how averaging is done for multi-dimensional multi-class inputs (on top of the average parameter). Should be one of the following:

    • None [default]: Should be left unchanged if your data is not multi-dimensional multi-class.

    • 'samplewise': In this case, the statistics are computed separately for each sample on the N axis, and then averaged over samples. The computation for each sample is done by treating the flattened extra axes ... as the N dimension within the sample, and computing the metric for the sample based on that.

    • 'global': In this case the N and ... dimensions of the inputs are flattened into a new N_X sample axis, i.e. the inputs are treated as if they were (N_X, C). From here on the average parameter applies as usual.

  • ignore_index (Optional[int]) – Integer specifying a target class to ignore. If given, this class index does not contribute to the returned score, regardless of reduction method. If an index is ignored, and average=None or 'none', the score for the ignored class will be returned as nan.

  • num_classes (Optional[int]) – Number of classes. Necessary for 'macro', 'weighted' and None average methods.

  • threshold (float) – Threshold for transforming probability or logit predictions to binary (0,1) predictions, in the case of binary or multi-label inputs. Default value of 0.5 corresponds to input being probabilities.

  • top_k (Optional[int]) –

    Number of the highest probability or logit score predictions considered for finding the correct label, relevant only for (multi-dimensional) multi-class inputs. The default value (None) will be interpreted as 1 for these inputs.

    Should be left at default (None) for all other types of inputs.

  • multiclass (Optional[bool]) – Used only in certain special cases, where you want to treat inputs as a different type than what they appear to be.

Return type:

Tensor

Returns:

The shape of the returned tensor depends on the average parameter

  • If average in ['micro', 'macro', 'weighted', 'samples'], a one-element tensor will be returned

  • If average in ['none', None], the shape will be (C,), where C stands for the number of classes

Raises:
  • ValueError – If average is not one of "micro", "macro", "weighted", "samples", "none" or None

  • ValueError – If mdmc_average is not one of None, "samplewise", "global".

  • ValueError – If average is set but num_classes is not provided.

  • ValueError – If num_classes is set and ignore_index is not in the range [0, num_classes).

Example

>>> from torchmetrics.functional.classification import dice
>>> preds = torch.tensor([2, 0, 2, 1])
>>> target = torch.tensor([1, 1, 2, 0])
>>> dice(preds, target, average='micro')
tensor(0.2500)
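
As a hand-computed sketch (the values below are worked out by counting TP, FP and FN for each class by hand, not taken from the library's doctests), the same predictions give per-class Dice scores of 0, 0 and 2/3, so the unweighted macro average works out to roughly 0.2222:

>>> import torch
>>> from torchmetrics.functional.classification import dice
>>> preds = torch.tensor([2, 0, 2, 1])
>>> target = torch.tensor([1, 1, 2, 0])
>>> dice(preds, target, average=None, num_classes=3)
tensor([0.0000, 0.0000, 0.6667])
>>> dice(preds, target, average='macro', num_classes=3)
tensor(0.2222)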

Exact Match

Module Interface

class torchmetrics.ExactMatch(**kwargs)[source]

Compute Exact match (also known as subset accuracy).

Exact Match is a stricter version of accuracy where all labels have to match exactly for the sample to be correctly classified.

This module is a simple wrapper to get the task-specific versions of this metric, which is done by setting the task argument to either 'multiclass' or 'multilabel'. See the documentation of MulticlassExactMatch and MultilabelExactMatch for the specific details of each argument's influence and examples.

Legacy Example:
>>> from torch import tensor
>>> target = tensor([[[0, 1], [2, 1], [0, 2]], [[1, 1], [2, 0], [1, 2]]])
>>> preds = tensor([[[0, 1], [2, 1], [0, 2]], [[2, 2], [2, 1], [1, 0]]])
>>> metric = ExactMatch(task="multiclass", num_classes=3, multidim_average='global')
>>> metric(preds, target)
tensor(0.5000)
>>> target = tensor([[[0, 1], [2, 1], [0, 2]], [[1, 1], [2, 0], [1, 2]]])
>>> preds = tensor([[[0, 1], [2, 1], [0, 2]], [[2, 2], [2, 1], [1, 0]]])
>>> metric = ExactMatch(task="multiclass", num_classes=3, multidim_average='samplewise')
>>> metric(preds, target)
tensor([1., 0.])
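
The same wrapper also dispatches to the multilabel case. A short sketch mirroring the MultilabelExactMatch example further below, so the expected output is again 0.5:

>>> from torch import tensor
>>> target = tensor([[0, 1, 0], [1, 0, 1]])
>>> preds = tensor([[0, 0, 1], [1, 0, 1]])
>>> metric = ExactMatch(task="multilabel", num_labels=3)
>>> metric(preds, target)
tensor(0.5000)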
static __new__(cls, task, threshold=0.5, num_classes=None, num_labels=None, multidim_average='global', ignore_index=None, validate_args=True, **kwargs)[source]

Initialize task metric.

Return type:

Metric

MulticlassExactMatch
class torchmetrics.classification.MulticlassExactMatch(num_classes, multidim_average='global', ignore_index=None, validate_args=True, **kwargs)[source]

Compute Exact match (also known as subset accuracy) for multiclass tasks.

Exact Match is a stricter version of accuracy where all labels have to match exactly for the sample to be correctly classified.

As input to forward and update the metric accepts the following input:

  • preds (Tensor): An int tensor of shape (N, ...) or float tensor of shape (N, C, ...). If preds is a floating point tensor we apply torch.argmax along the C dimension to automatically convert probabilities/logits into an int tensor.

  • target (Tensor): An int tensor of shape (N, ...).

As output to forward and compute the metric returns the following output:

  • mcem (Tensor): A tensor whose returned shape depends on the multidim_average argument:

    • If multidim_average is set to global the output will be a scalar tensor

    • If multidim_average is set to samplewise the output will be a tensor of shape (N,)

Parameters:
  • num_classes (int) – Integer specifying the number of classes

  • multidim_average (Literal['global', 'samplewise']) –

    Defines how additional dimensions ... should be handled. Should be one of the following:

    • global: Additional dimensions are flattened along the batch dimension

    • samplewise: Statistic will be calculated independently for each sample on the N axis. The statistics in this case are calculated over the additional dimensions.

  • ignore_index (Optional[int]) – Specifies a target value that is ignored and does not contribute to the metric calculation

  • validate_args (bool) – bool indicating if input arguments and tensors should be validated for correctness. Set to False for faster computations.

Example (multidim tensors):
>>> from torch import tensor
>>> from torchmetrics.classification import MulticlassExactMatch
>>> target = tensor([[[0, 1], [2, 1], [0, 2]], [[1, 1], [2, 0], [1, 2]]])
>>> preds = tensor([[[0, 1], [2, 1], [0, 2]], [[2, 2], [2, 1], [1, 0]]])
>>> metric = MulticlassExactMatch(num_classes=3, multidim_average='global')
>>> metric(preds, target)
tensor(0.5000)
Example (multidim tensors):
>>> from torchmetrics.classification import MulticlassExactMatch
>>> target = tensor([[[0, 1], [2, 1], [0, 2]], [[1, 1], [2, 0], [1, 2]]])
>>> preds = tensor([[[0, 1], [2, 1], [0, 2]], [[2, 2], [2, 1], [1, 0]]])
>>> metric = MulticlassExactMatch(num_classes=3, multidim_average='samplewise')
>>> metric(preds, target)
tensor([1., 0.])
plot(val=None, ax=None)[source]

Plot a single or multiple values from the metric.

Parameters:
  • val (Union[Tensor, Sequence[Tensor], None]) – Either a single result from calling metric.forward or metric.compute or a list of these results. If no value is provided, will automatically call metric.compute and plot that result.

  • ax (Optional[Axes]) – A matplotlib axis object. If provided, will add plot to that axis

Return type:

Tuple[Figure, Union[Axes, ndarray]]

Returns:

Figure object and Axes object

Raises:

ModuleNotFoundError – If matplotlib is not installed

>>> # Example plotting a single value per class
>>> from torch import randint
>>> from torchmetrics.classification import MulticlassExactMatch
>>> metric = MulticlassExactMatch(num_classes=3)
>>> metric.update(randint(3, (20,5)), randint(3, (20,5)))
>>> fig_, ax_ = metric.plot()
>>> from torch import randint
>>> # Example plotting multiple values per class
>>> from torchmetrics.classification import MulticlassExactMatch
>>> metric = MulticlassExactMatch(num_classes=3)
>>> values = []
>>> for _ in range(20):
...     values.append(metric(randint(3, (20,5)), randint(3, (20,5))))
>>> fig_, ax_ = metric.plot(values)
MultilabelExactMatch
class torchmetrics.classification.MultilabelExactMatch(num_labels, threshold=0.5, multidim_average='global', ignore_index=None, validate_args=True, **kwargs)[source]

Compute Exact match (also known as subset accuracy) for multilabel tasks.

Exact Match is a stricter version of accuracy where all labels have to match exactly for the sample to be correctly classified.

As input to forward and update the metric accepts the following input:

  • preds (Tensor): An int tensor or float tensor of shape (N, C, ...). If preds is a floating point tensor with values outside the [0,1] range we consider the input to be logits and will auto apply sigmoid per element. Additionally, we convert to int tensor with thresholding using the value in threshold.

  • target (Tensor): An int tensor of shape (N, C, ...).

As output to forward and compute the metric returns the following output:

  • mlem (Tensor): A tensor whose returned shape depends on the multidim_average argument:

    • If multidim_average is set to global the output will be a scalar tensor

    • If multidim_average is set to samplewise the output will be a tensor of shape (N,)

Parameters:
  • num_labels (int) – Integer specifying the number of labels

  • threshold (float) – Threshold for transforming probability to binary (0,1) predictions

  • multidim_average (Literal['global', 'samplewise']) –

    Defines how additional dimensions ... should be handled. Should be one of the following:

    • global: Additional dimensions are flattened along the batch dimension

    • samplewise: Statistic will be calculated independently for each sample on the N axis. The statistics in this case are calculated over the additional dimensions.

  • ignore_index (Optional[int]) – Specifies a target value that is ignored and does not contribute to the metric calculation

  • validate_args (bool) – bool indicating if input arguments and tensors should be validated for correctness. Set to False for faster computations.

Example (preds is int tensor):
>>> from torch import tensor
>>> from torchmetrics.classification import MultilabelExactMatch
>>> target = tensor([[0, 1, 0], [1, 0, 1]])
>>> preds = tensor([[0, 0, 1], [1, 0, 1]])
>>> metric = MultilabelExactMatch(num_labels=3)
>>> metric(preds, target)
tensor(0.5000)
Example (preds is float tensor):
>>> from torchmetrics.classification import MultilabelExactMatch
>>> target = tensor([[0, 1, 0], [1, 0, 1]])
>>> preds = tensor([[0.11, 0.22, 0.84], [0.73, 0.33, 0.92]])
>>> metric = MultilabelExactMatch(num_labels=3)
>>> metric(preds, target)
tensor(0.5000)
Example (multidim tensors):
>>> from torchmetrics.classification import MultilabelExactMatch
>>> target = tensor([[[0, 1], [1, 0], [0, 1]], [[1, 1], [0, 0], [1, 0]]])
>>> preds = tensor([[[0.59, 0.91], [0.91, 0.99], [0.63, 0.04]],
...                 [[0.38, 0.04], [0.86, 0.780], [0.45, 0.37]]])
>>> metric = MultilabelExactMatch(num_labels=3, multidim_average='samplewise')
>>> metric(preds, target)
tensor([0., 0.])
plot(val=None, ax=None)[source]

Plot a single or multiple values from the metric.

Parameters:
  • val (Union[Tensor, Sequence[Tensor], None]) – Either a single result from calling metric.forward or metric.compute or a list of these results. If no value is provided, will automatically call metric.compute and plot that result.

  • ax (Optional[Axes]) – A matplotlib axis object. If provided, will add plot to that axis

Return type:

Tuple[Figure, Union[Axes, ndarray]]

Returns:

Figure and Axes object

Raises:

ModuleNotFoundError – If matplotlib is not installed

>>> # Example plotting a single value
>>> from torch import rand, randint
>>> from torchmetrics.classification import MultilabelExactMatch
>>> metric = MultilabelExactMatch(num_labels=3)
>>> metric.update(randint(2, (20, 3, 5)), randint(2, (20, 3, 5)))
>>> fig_, ax_ = metric.plot()
>>> # Example plotting multiple values
>>> from torch import rand, randint
>>> from torchmetrics.classification import MultilabelExactMatch
>>> metric = MultilabelExactMatch(num_labels=3)
>>> values = [ ]
>>> for _ in range(10):
...     values.append(metric(randint(2, (20, 3, 5)), randint(2, (20, 3, 5))))
>>> fig_, ax_ = metric.plot(values)

Functional Interface

exact_match
torchmetrics.functional.classification.exact_match(preds, target, task, num_classes=None, num_labels=None, threshold=0.5, multidim_average='global', ignore_index=None, validate_args=True)[source]

Compute Exact match (also known as subset accuracy).

Exact Match is a stricter version of accuracy where all classes/labels have to match exactly for the sample to be correctly classified.

This function is a simple wrapper to get the task-specific versions of this metric, which is done by setting the task argument to either 'multiclass' or 'multilabel'. See the documentation of multiclass_exact_match() and multilabel_exact_match() for the specific details of each argument's influence and examples.

Return type:

Tensor

Legacy Example:
>>> from torch import tensor
>>> target = tensor([[[0, 1], [2, 1], [0, 2]], [[1, 1], [2, 0], [1, 2]]])
>>> preds = tensor([[[0, 1], [2, 1], [0, 2]], [[2, 2], [2, 1], [1, 0]]])
>>> exact_match(preds, target, task="multiclass", num_classes=3, multidim_average='global')
tensor(0.5000)
>>> target = tensor([[[0, 1], [2, 1], [0, 2]], [[1, 1], [2, 0], [1, 2]]])
>>> preds = tensor([[[0, 1], [2, 1], [0, 2]], [[2, 2], [2, 1], [1, 0]]])
>>> exact_match(preds, target, task="multiclass", num_classes=3, multidim_average='samplewise')
tensor([1., 0.])
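
The wrapper handles the 'multilabel' task the same way. A brief sketch mirroring the multilabel_exact_match() example below, so the expected result is again 0.5:

>>> from torch import tensor
>>> from torchmetrics.functional.classification import exact_match
>>> target = tensor([[0, 1, 0], [1, 0, 1]])
>>> preds = tensor([[0, 0, 1], [1, 0, 1]])
>>> exact_match(preds, target, task="multilabel", num_labels=3)
tensor(0.5000)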
multiclass_exact_match
torchmetrics.functional.classification.multiclass_exact_match(preds, target, num_classes, multidim_average='global', ignore_index=None, validate_args=True)[source]

Compute Exact match (also known as subset accuracy) for multiclass tasks.

Exact Match is a stricter version of accuracy where all labels have to match exactly for the sample to be correctly classified.

Accepts the following input tensors:

  • preds: (N, ...) (int tensor) or (N, C, ...) (float tensor). If preds is a floating point tensor we apply torch.argmax along the C dimension to automatically convert probabilities/logits into an int tensor.

  • target (int tensor): (N, ...)

Parameters:
  • preds (Tensor) – Tensor with predictions

  • target (Tensor) – Tensor with true labels

  • num_classes (int) – Integer specifying the number of classes

  • multidim_average (Literal['global', 'samplewise']) –

    Defines how additional dimensions ... should be handled. Should be one of the following:

    • global: Additional dimensions are flattened along the batch dimension

    • samplewise: Statistic will be calculated independently for each sample on the N axis. The statistics in this case are calculated over the additional dimensions.

  • ignore_index (Optional[int]) – Specifies a target value that is ignored and does not contribute to the metric calculation

  • validate_args (bool) – bool indicating if input arguments and tensors should be validated for correctness. Set to False for faster computations.

Returns:

  • If multidim_average is set to global the output will be a scalar tensor

  • If multidim_average is set to samplewise the output will be a tensor of shape (N,)

Return type:

The returned shape depends on the multidim_average argument

Example (multidim tensors):
>>> from torch import tensor
>>> from torchmetrics.functional.classification import multiclass_exact_match
>>> target = tensor([[[0, 1], [2, 1], [0, 2]], [[1, 1], [2, 0], [1, 2]]])
>>> preds = tensor([[[0, 1], [2, 1], [0, 2]], [[2, 2], [2, 1], [1, 0]]])
>>> multiclass_exact_match(preds, target, num_classes=3, multidim_average='global')
tensor(0.5000)
Example (multidim tensors):
>>> from torchmetrics.functional.classification import multiclass_exact_match
>>> target = tensor([[[0, 1], [2, 1], [0, 2]], [[1, 1], [2, 0], [1, 2]]])
>>> preds = tensor([[[0, 1], [2, 1], [0, 2]], [[2, 2], [2, 1], [1, 0]]])
>>> multiclass_exact_match(preds, target, num_classes=3, multidim_average='samplewise')
tensor([1., 0.])
multilabel_exact_match
torchmetrics.functional.classification.multilabel_exact_match(preds, target, num_labels, threshold=0.5, multidim_average='global', ignore_index=None, validate_args=True)[source]

Compute Exact match (also known as subset accuracy) for multilabel tasks.

Exact Match is a stricter version of accuracy where all labels have to match exactly for the sample to be correctly classified.

Accepts the following input tensors:

  • preds (int or float tensor): (N, C, ...). If preds is a floating point tensor with values outside the [0,1] range we consider the input to be logits and will auto apply sigmoid per element. Additionally, we convert to int tensor with thresholding using the value in threshold.

  • target (int tensor): (N, C, ...)

Parameters:
  • preds (Tensor) – Tensor with predictions

  • target (Tensor) – Tensor with true labels

  • num_labels (int) – Integer specifying the number of labels

  • threshold (float) – Threshold for transforming probability to binary (0,1) predictions

  • multidim_average (Literal['global', 'samplewise']) –

    Defines how additional dimensions ... should be handled. Should be one of the following:

    • global: Additional dimensions are flattened along the batch dimension

    • samplewise: Statistic will be calculated independently for each sample on the N axis. The statistics in this case are calculated over the additional dimensions.

  • ignore_index (Optional[int]) – Specifies a target value that is ignored and does not contribute to the metric calculation

  • validate_args (bool) – bool indicating if input arguments and tensors should be validated for correctness. Set to False for faster computations.

Returns:

  • If multidim_average is set to global the output will be a scalar tensor

  • If multidim_average is set to samplewise the output will be a tensor of shape (N,)

Return type:

The returned shape depends on the multidim_average argument

Example (preds is int tensor):
>>> from torch import tensor
>>> from torchmetrics.functional.classification import multilabel_exact_match
>>> target = tensor([[0, 1, 0], [1, 0, 1]])
>>> preds = tensor([[0, 0, 1], [1, 0, 1]])
>>> multilabel_exact_match(preds, target, num_labels=3)
tensor(0.5000)
Example (preds is float tensor):
>>> from torchmetrics.functional.classification import multilabel_exact_match
>>> target = tensor([[0, 1, 0], [1, 0, 1]])
>>> preds = tensor([[0.11, 0.22, 0.84], [0.73, 0.33, 0.92]])
>>> multilabel_exact_match(preds, target, num_labels=3)
tensor(0.5000)
Example (multidim tensors):
>>> from torchmetrics.functional.classification import multilabel_exact_match
>>> target = tensor([[[0, 1], [1, 0], [0, 1]], [[1, 1], [0, 0], [1, 0]]])
>>> preds = tensor([[[0.59, 0.91], [0.91, 0.99], [0.63, 0.04]],
...                 [[0.38, 0.04], [0.86, 0.780], [0.45, 0.37]]])
>>> multilabel_exact_match(preds, target, num_labels=3, multidim_average='samplewise')
tensor([0., 0.])

F-1 Score

Module Interface

class torchmetrics.F1Score(**kwargs)[source]

Compute F-1 score.

\[F_{1} = 2\frac{\text{precision} * \text{recall}}{\text{precision} + \text{recall}}\]

The metric is only properly defined when \(\text{TP} + \text{FP} \neq 0 \wedge \text{TP} + \text{FN} \neq 0\) where \(\text{TP}\), \(\text{FP}\) and \(\text{FN}\) represent the number of true positives, false positives and false negatives respectively. If this case is encountered for any class/label, the metric for that class/label will be set to 0 and the overall metric may therefore be affected in turn.

This function is a simple wrapper to get the task-specific versions of this metric, which is done by setting the task argument to either 'binary', 'multiclass' or 'multilabel'. See the documentation of BinaryF1Score, MulticlassF1Score and MultilabelF1Score for the specific details of each argument's influence and examples.

Legacy Example:
>>> from torch import tensor
>>> target = tensor([0, 1, 2, 0, 1, 2])
>>> preds = tensor([0, 2, 1, 0, 0, 1])
>>> f1 = F1Score(task="multiclass", num_classes=3)
>>> f1(preds, target)
tensor(0.3333)
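
The wrapper behaves the same way for binary inputs. A small sketch mirroring the BinaryF1Score example further below, where the expected score is 0.6667:

>>> from torch import tensor
>>> target = tensor([0, 1, 0, 1, 0, 1])
>>> preds = tensor([0, 0, 1, 1, 0, 1])
>>> f1 = F1Score(task="binary")
>>> f1(preds, target)
tensor(0.6667)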
static __new__(cls, task, threshold=0.5, num_classes=None, num_labels=None, average='micro', multidim_average='global', top_k=1, ignore_index=None, validate_args=True, **kwargs)[source]

Initialize task metric.

Return type:

Metric

BinaryF1Score
class torchmetrics.classification.BinaryF1Score(threshold=0.5, multidim_average='global', ignore_index=None, validate_args=True, **kwargs)[source]

Compute F-1 score for binary tasks.

\[F_{1} = 2\frac{\text{precision} * \text{recall}}{\text{precision} + \text{recall}}\]

The metric is only properly defined when \(\text{TP} + \text{FP} \neq 0 \wedge \text{TP} + \text{FN} \neq 0\) where \(\text{TP}\), \(\text{FP}\) and \(\text{FN}\) represent the number of true positives, false positives and false negatives respectively. If this case is encountered a score of 0 is returned.

As input to forward and update the metric accepts the following input:

  • preds (Tensor): An int or float tensor of shape (N, ...). If preds is a floating point tensor with values outside the [0,1] range we consider the input to be logits and will auto apply sigmoid per element. Additionally, we convert to int tensor with thresholding using the value in threshold.

  • target (Tensor): An int tensor of shape (N, ...)

As output to forward and compute the metric returns the following output:

  • bf1s (Tensor): A tensor whose returned shape depends on the multidim_average argument:

    • If multidim_average is set to global, the metric returns a scalar value.

    • If multidim_average is set to samplewise, the metric returns an (N,) vector consisting of a scalar value per sample.

Parameters:
  • threshold (float) – Threshold for transforming probability to binary {0,1} predictions

  • multidim_average (Literal['global', 'samplewise']) –

    Defines how additional dimensions ... should be handled. Should be one of the following:

    • global: Additional dimensions are flattened along the batch dimension

    • samplewise: Statistic will be calculated independently for each sample on the N axis. The statistics in this case are calculated over the additional dimensions.

  • ignore_index (Optional[int]) – Specifies a target value that is ignored and does not contribute to the metric calculation

  • validate_args (bool) – bool indicating if input arguments and tensors should be validated for correctness. Set to False for faster computations.

Example (preds is int tensor):
>>> from torch import tensor
>>> from torchmetrics.classification import BinaryF1Score
>>> target = tensor([0, 1, 0, 1, 0, 1])
>>> preds = tensor([0, 0, 1, 1, 0, 1])
>>> metric = BinaryF1Score()
>>> metric(preds, target)
tensor(0.6667)
Example (preds is float tensor):
>>> from torchmetrics.classification import BinaryF1Score
>>> target = tensor([0, 1, 0, 1, 0, 1])
>>> preds = tensor([0.11, 0.22, 0.84, 0.73, 0.33, 0.92])
>>> metric = BinaryF1Score()
>>> metric(preds, target)
tensor(0.6667)
Example (multidim tensors):
>>> from torchmetrics.classification import BinaryF1Score
>>> target = tensor([[[0, 1], [1, 0], [0, 1]], [[1, 1], [0, 0], [1, 0]]])
>>> preds = tensor([[[0.59, 0.91], [0.91, 0.99],  [0.63, 0.04]],
...                 [[0.38, 0.04], [0.86, 0.780], [0.45, 0.37]]])
>>> metric = BinaryF1Score(multidim_average='samplewise')
>>> metric(preds, target)
tensor([0.5000, 0.0000])
plot(val=None, ax=None)[source]

Plot a single or multiple values from the metric.

Parameters:
  • val (Union[Tensor, Sequence[Tensor], None]) – Either a single result from calling metric.forward or metric.compute or a list of these results. If no value is provided, will automatically call metric.compute and plot that result.

  • ax (Optional[Axes]) – A matplotlib axis object. If provided, will add plot to that axis

Return type:

Tuple[Figure, Union[Axes, ndarray]]

Returns:

Figure object and Axes object

Raises:

ModuleNotFoundError – If matplotlib is not installed

>>> from torch import rand, randint
>>> # Example plotting a single value
>>> from torchmetrics.classification import BinaryF1Score
>>> metric = BinaryF1Score()
>>> metric.update(rand(10), randint(2,(10,)))
>>> fig_, ax_ = metric.plot()
>>> from torch import rand, randint
>>> # Example plotting multiple values
>>> from torchmetrics.classification import BinaryF1Score
>>> metric = BinaryF1Score()
>>> values = [ ]
>>> for _ in range(10):
...     values.append(metric(rand(10), randint(2,(10,))))
>>> fig_, ax_ = metric.plot(values)
MulticlassF1Score
class torchmetrics.classification.MulticlassF1Score(num_classes, top_k=1, average='macro', multidim_average='global', ignore_index=None, validate_args=True, **kwargs)[source]

Compute F-1 score for multiclass tasks.

\[F_{1} = 2\frac{\text{precision} * \text{recall}}{\text{precision} + \text{recall}}\]

The metric is only properly defined when \(\text{TP} + \text{FP} \neq 0 \wedge \text{TP} + \text{FN} \neq 0\) where \(\text{TP}\), \(\text{FP}\) and \(\text{FN}\) represent the number of true positives, false positives and false negatives respectively. If this case is encountered for any class, the metric for that class will be set to 0 and the overall metric may therefore be affected in turn.

As input to forward and update the metric accepts the following input:

  • preds (Tensor): An int tensor of shape (N, ...) or float tensor of shape (N, C, ...). If preds is a floating point tensor we apply torch.argmax along the C dimension to automatically convert probabilities/logits into an int tensor.

  • target (Tensor): An int tensor of shape (N, ...)

As output to forward and compute the metric returns the following output:

  • mcf1s (Tensor): A tensor whose returned shape depends on the average and multidim_average arguments:

    • If multidim_average is set to global:

      • If average='micro'/'macro'/'weighted', the output will be a scalar tensor

      • If average=None/'none', the shape will be (C,)

    • If multidim_average is set to samplewise:

      • If average='micro'/'macro'/'weighted', the shape will be (N,)

      • If average=None/'none', the shape will be (N, C)

Parameters:
  • preds – Tensor with predictions

  • target – Tensor with true labels

  • num_classes (int) – Integer specifying the number of classes

  • average (Optional[Literal['micro', 'macro', 'weighted', 'none']]) –

    Defines the reduction that is applied over labels. Should be one of the following:

    • micro: Sum statistics over all labels

    • macro: Calculate statistics for each label and average them

    • weighted: calculates statistics for each label and computes weighted average using their support

    • "none" or None: calculates statistic for each label and applies no reduction

  • top_k (int) – Number of highest probability or logit score predictions considered to find the correct label. Only works when preds contain probabilities/logits.

  • multidim_average (Literal['global', 'samplewise']) –

    Defines how additional dimensions ... should be handled. Should be one of the following:

    • global: Additional dimensions are flattened along the batch dimension

    • samplewise: Statistic will be calculated independently for each sample on the N axis. The statistics in this case are calculated over the additional dimensions.

  • ignore_index (Optional[int]) – Specifies a target value that is ignored and does not contribute to the metric calculation

  • validate_args (bool) – bool indicating if input arguments and tensors should be validated for correctness. Set to False for faster computations.

Example (preds is int tensor):
>>> from torch import tensor
>>> from torchmetrics.classification import MulticlassF1Score
>>> target = tensor([2, 1, 0, 0])
>>> preds = tensor([2, 1, 0, 1])
>>> metric = MulticlassF1Score(num_classes=3)
>>> metric(preds, target)
tensor(0.7778)
>>> mcf1s = MulticlassF1Score(num_classes=3, average=None)
>>> mcf1s(preds, target)
tensor([0.6667, 0.6667, 1.0000])
Example (preds is float tensor):
>>> from torchmetrics.classification import MulticlassF1Score
>>> target = tensor([2, 1, 0, 0])
>>> preds = tensor([[0.16, 0.26, 0.58],
...                 [0.22, 0.61, 0.17],
...                 [0.71, 0.09, 0.20],
...                 [0.05, 0.82, 0.13]])
>>> metric = MulticlassF1Score(num_classes=3)
>>> metric(preds, target)
tensor(0.7778)
>>> mcf1s = MulticlassF1Score(num_classes=3, average=None)
>>> mcf1s(preds, target)
tensor([0.6667, 0.6667, 1.0000])
Example (multidim tensors):
>>> from torchmetrics.classification import MulticlassF1Score
>>> target = tensor([[[0, 1], [2, 1], [0, 2]], [[1, 1], [2, 0], [1, 2]]])
>>> preds = tensor([[[0, 2], [2, 0], [0, 1]], [[2, 2], [2, 1], [1, 0]]])
>>> metric = MulticlassF1Score(num_classes=3, multidim_average='samplewise')
>>> metric(preds, target)
tensor([0.4333, 0.2667])
>>> mcf1s = MulticlassF1Score(num_classes=3, multidim_average='samplewise', average=None)
>>> mcf1s(preds, target)
tensor([[0.8000, 0.0000, 0.5000],
        [0.0000, 0.4000, 0.4000]])
plot(val=None, ax=None)[source]

Plot a single or multiple values from the metric.

Parameters:
  • val (Union[Tensor, Sequence[Tensor], None]) – Either a single result from calling metric.forward or metric.compute or a list of these results. If no value is provided, will automatically call metric.compute and plot that result.

  • ax (Optional[Axes]) – A matplotlib axis object. If provided, will add plot to that axis

Return type:

Tuple[Figure, Union[Axes, ndarray]]

Returns:

Figure object and Axes object

Raises:

ModuleNotFoundError – If matplotlib is not installed

>>> from torch import randint
>>> # Example plotting a single value per class
>>> from torchmetrics.classification import MulticlassF1Score
>>> metric = MulticlassF1Score(num_classes=3, average=None)
>>> metric.update(randint(3, (20,)), randint(3, (20,)))
>>> fig_, ax_ = metric.plot()
>>> from torch import randint
>>> # Example plotting multiple values per class
>>> from torchmetrics.classification import MulticlassF1Score
>>> metric = MulticlassF1Score(num_classes=3, average=None)
>>> values = []
>>> for _ in range(20):
...     values.append(metric(randint(3, (20,)), randint(3, (20,))))
>>> fig_, ax_ = metric.plot(values)
MultilabelF1Score
class torchmetrics.classification.MultilabelF1Score(num_labels, threshold=0.5, average='macro', multidim_average='global', ignore_index=None, validate_args=True, **kwargs)[source]

Compute F-1 score for multilabel tasks.

\[F_{1} = 2\frac{\text{precision} * \text{recall}}{\text{precision} + \text{recall}}\]

The metric is only properly defined when \(\text{TP} + \text{FP} \neq 0 \wedge \text{TP} + \text{FN} \neq 0\) where \(\text{TP}\), \(\text{FP}\) and \(\text{FN}\) represent the number of true positives, false positives and false negatives respectively. If this case is encountered for any label, the metric for that label will be set to 0 and the overall metric may therefore be affected in turn.

As input to forward and update the metric accepts the following input:

  • preds (Tensor): An int or float tensor of shape (N, C, ...). If preds is a floating point tensor with values outside the [0,1] range we consider the input to be logits and will auto apply sigmoid per element. Additionally, we convert to int tensor with thresholding using the value in threshold.

  • target (Tensor): An int tensor of shape (N, C, ...).

As output to forward and compute the metric returns the following output:

  • mlf1s (Tensor): A tensor whose returned shape depends on the average and multidim_average arguments:

    • If multidim_average is set to global:

      • If average='micro'/'macro'/'weighted', the output will be a scalar tensor

      • If average=None/'none', the shape will be (C,)

    • If multidim_average is set to samplewise:

      • If average='micro'/'macro'/'weighted', the shape will be (N,)

      • If average=None/'none', the shape will be (N, C)

Parameters:
  • num_labels (int) – Integer specifying the number of labels

  • threshold (float) – Threshold for transforming probability to binary (0,1) predictions

  • average (Optional[Literal['micro', 'macro', 'weighted', 'none']]) –

    Defines the reduction that is applied over labels. Should be one of the following:

    • micro: Sum statistics over all labels

    • macro: Calculate statistics for each label and average them

    • weighted: calculates statistics for each label and computes weighted average using their support

    • "none" or None: calculates statistic for each label and applies no reduction

  • multidim_average (Literal['global', 'samplewise']) –

    Defines how additional dimensions ... should be handled. Should be one of the following:

    • global: Additional dimensions are flattened along the batch dimension

    • samplewise: Statistic will be calculated independently for each sample on the N axis. The statistics in this case are calculated over the additional dimensions.

  • ignore_index (Optional[int]) – Specifies a target value that is ignored and does not contribute to the metric calculation

  • validate_args (bool) – bool indicating if input arguments and tensors should be validated for correctness. Set to False for faster computations.

Example (preds is int tensor):
>>> from torch import tensor
>>> from torchmetrics.classification import MultilabelF1Score
>>> target = tensor([[0, 1, 0], [1, 0, 1]])
>>> preds = tensor([[0, 0, 1], [1, 0, 1]])
>>> metric = MultilabelF1Score(num_labels=3)
>>> metric(preds, target)
tensor(0.5556)
>>> mlf1s = MultilabelF1Score(num_labels=3, average=None)
>>> mlf1s(preds, target)
tensor([1.0000, 0.0000, 0.6667])
Example (preds is float tensor):
>>> from torchmetrics.classification import MultilabelF1Score
>>> target = tensor([[0, 1, 0], [1, 0, 1]])
>>> preds = tensor([[0.11, 0.22, 0.84], [0.73, 0.33, 0.92]])
>>> metric = MultilabelF1Score(num_labels=3)
>>> metric(preds, target)
tensor(0.5556)
>>> mlf1s = MultilabelF1Score(num_labels=3, average=None)
>>> mlf1s(preds, target)
tensor([1.0000, 0.0000, 0.6667])
Example (multidim tensors):
>>> from torchmetrics.classification import MultilabelF1Score
>>> target = tensor([[[0, 1], [1, 0], [0, 1]], [[1, 1], [0, 0], [1, 0]]])
>>> preds = tensor([[[0.59, 0.91], [0.91, 0.99],  [0.63, 0.04]],
...                 [[0.38, 0.04], [0.86, 0.780], [0.45, 0.37]]])
>>> metric = MultilabelF1Score(num_labels=3, multidim_average='samplewise')
>>> metric(preds, target)
tensor([0.4444, 0.0000])
>>> mlf1s = MultilabelF1Score(num_labels=3, multidim_average='samplewise', average=None)
>>> mlf1s(preds, target)
tensor([[0.6667, 0.6667, 0.0000],
        [0.0000, 0.0000, 0.0000]])
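
Since average defaults to 'macro' here, a micro average over the same inputs gives a different number. A hand-computed sketch (pooled over all labels the counts are TP=2, FP=1 and FN=1, so F1 = 4/6):

>>> from torch import tensor
>>> from torchmetrics.classification import MultilabelF1Score
>>> target = tensor([[0, 1, 0], [1, 0, 1]])
>>> preds = tensor([[0, 0, 1], [1, 0, 1]])
>>> metric = MultilabelF1Score(num_labels=3, average='micro')
>>> metric(preds, target)
tensor(0.6667)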
plot(val=None, ax=None)[source]

Plot a single or multiple values from the metric.

Parameters:
  • val (Union[Tensor, Sequence[Tensor], None]) – Either a single result from calling metric.forward or metric.compute or a list of these results. If no value is provided, will automatically call metric.compute and plot that result.

  • ax (Optional[Axes]) – A matplotlib axis object. If provided, will add plot to that axis

Return type:

Tuple[Figure, Union[Axes, ndarray]]

Returns:

Figure and Axes object

Raises:

ModuleNotFoundError – If matplotlib is not installed

>>> from torch import rand, randint
>>> # Example plotting a single value
>>> from torchmetrics.classification import MultilabelF1Score
>>> metric = MultilabelF1Score(num_labels=3)
>>> metric.update(randint(2, (20, 3)), randint(2, (20, 3)))
>>> fig_, ax_ = metric.plot()
>>> from torch import rand, randint
>>> # Example plotting multiple values
>>> from torchmetrics.classification import MultilabelF1Score
>>> metric = MultilabelF1Score(num_labels=3)
>>> values = [ ]
>>> for _ in range(10):
...     values.append(metric(randint(2, (20, 3)), randint(2, (20, 3))))
>>> fig_, ax_ = metric.plot(values)

Functional Interface

f1_score
torchmetrics.functional.f1_score(preds, target, task, threshold=0.5, num_classes=None, num_labels=None, average='micro', multidim_average='global', top_k=1, ignore_index=None, validate_args=True)[source]

Compute F-1 score.

\[F_{1} = 2\frac{\text{precision} * \text{recall}}{\text{precision} + \text{recall}}\]

This function is a simple wrapper to get the task-specific versions of this metric, which is done by setting the task argument to either 'binary', 'multiclass' or 'multilabel'. See the documentation of binary_f1_score(), multiclass_f1_score() and multilabel_f1_score() for the specific details of each argument's influence and examples.

Return type:

Tensor

Legacy Example:
>>> from torch import tensor
>>> target = tensor([0, 1, 2, 0, 1, 2])
>>> preds = tensor([0, 2, 1, 0, 0, 1])
>>> f1_score(preds, target, task="multiclass", num_classes=3)
tensor(0.3333)
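
The same wrapper dispatches to the binary case as well. A short sketch mirroring the binary_f1_score() example below, where the expected score is 0.6667:

>>> from torch import tensor
>>> from torchmetrics.functional import f1_score
>>> target = tensor([0, 1, 0, 1, 0, 1])
>>> preds = tensor([0, 0, 1, 1, 0, 1])
>>> f1_score(preds, target, task="binary")
tensor(0.6667)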
binary_f1_score
torchmetrics.functional.classification.binary_f1_score(preds, target, threshold=0.5, multidim_average='global', ignore_index=None, validate_args=True)[source]

Compute F-1 score for binary tasks.

\[F_{1} = 2\frac{\text{precision} * \text{recall}}{\text{precision} + \text{recall}}\]

Accepts the following input tensors:

  • preds (int or float tensor): (N, ...). If preds is a floating point tensor with values outside the [0,1] range we consider the input to be logits and will auto apply sigmoid per element. Additionally, we convert to int tensor with thresholding using the value in threshold.

  • target (int tensor): (N, ...)

Parameters:
  • preds (Tensor) – Tensor with predictions

  • target (Tensor) – Tensor with true labels

  • threshold (float) – Threshold for transforming probability to binary {0,1} predictions

  • multidim_average (Literal['global', 'samplewise']) –

    Defines how additional dimensions ... should be handled. Should be one of the following:

    • global: Additional dimensions are flattened along the batch dimension

    • samplewise: Statistic will be calculated independently for each sample on the N axis. The statistics in this case are calculated over the additional dimensions.

  • ignore_index (Optional[int]) – Specifies a target value that is ignored and does not contribute to the metric calculation

  • validate_args (bool) – bool indicating if input arguments and tensors should be validated for correctness. Set to False for faster computations.

Return type:

Tensor

Returns:

If multidim_average is set to global, the metric returns a scalar value. If multidim_average is set to samplewise, the metric returns an (N,) vector consisting of a scalar value per sample.

Example (preds is int tensor):
>>> from torch import tensor
>>> from torchmetrics.functional.classification import binary_f1_score
>>> target = tensor([0, 1, 0, 1, 0, 1])
>>> preds = tensor([0, 0, 1, 1, 0, 1])
>>> binary_f1_score(preds, target)
tensor(0.6667)
Example (preds is float tensor):
>>> from torchmetrics.functional.classification import binary_f1_score
>>> target = tensor([0, 1, 0, 1, 0, 1])
>>> preds = tensor([0.11, 0.22, 0.84, 0.73, 0.33, 0.92])
>>> binary_f1_score(preds, target)
tensor(0.6667)
Example (multidim tensors):
>>> from torchmetrics.functional.classification import binary_f1_score
>>> target = tensor([[[0, 1], [1, 0], [0, 1]], [[1, 1], [0, 0], [1, 0]]])
>>> preds = tensor([[[0.59, 0.91], [0.91, 0.99], [0.63, 0.04]],
...                 [[0.38, 0.04], [0.86, 0.780], [0.45, 0.37]]])
>>> binary_f1_score(preds, target, multidim_average='samplewise')
tensor([0.5000, 0.0000])
multiclass_f1_score
torchmetrics.functional.classification.multiclass_f1_score(preds, target, num_classes, average='macro', top_k=1, multidim_average='global', ignore_index=None, validate_args=True)[source]

Compute F-1 score for multiclass tasks.

\[F_{1} = 2\frac{\text{precision} * \text{recall}}{\text{precision} + \text{recall}}\]

Accepts the following input tensors:

  • preds: (N, ...) (int tensor) or (N, C, ...) (float tensor). If preds is a floating point tensor we apply torch.argmax along the C dimension to automatically convert probabilities/logits into an int tensor.

  • target (int tensor): (N, ...)

Parameters:
  • preds (Tensor) – Tensor with predictions

  • target (Tensor) – Tensor with true labels

  • num_classes (int) – Integer specifying the number of classes

  • average (Optional[Literal['micro', 'macro', 'weighted', 'none']]) –

    Defines the reduction that is applied over labels. Should be one of the following:

    • micro: Sum statistics over all labels

    • macro: Calculate statistics for each label and average them

    • weighted: calculates statistics for each label and computes weighted average using their support

    • "none" or None: calculates statistic for each label and applies no reduction

  • top_k (int) – Number of highest probability or logit score predictions considered to find the correct label. Only works when preds contain probabilities/logits.

  • multidim_average (Literal['global', 'samplewise']) –

    Defines how additional dimensions ... should be handled. Should be one of the following:

    • global: Additional dimensions are flattened along the batch dimension

    • samplewise: Statistic will be calculated independently for each sample on the N axis. The statistics in this case are calculated over the additional dimensions.

  • ignore_index (Optional[int]) – Specifies a target value that is ignored and does not contribute to the metric calculation

  • validate_args (bool) – bool indicating if input arguments and tensors should be validated for correctness. Set to False for faster computations.

Returns:

  • If multidim_average is set to global:

    • If average='micro'/'macro'/'weighted', the output will be a scalar tensor

    • If average=None/'none', the shape will be (C,)

  • If multidim_average is set to samplewise:

    • If average='micro'/'macro'/'weighted', the shape will be (N,)

    • If average=None/'none', the shape will be (N, C)

Return type:

The returned shape depends on the average and multidim_average arguments

Example (preds is int tensor):
>>> from torch import tensor
>>> from torchmetrics.functional.classification import multiclass_f1_score
>>> target = tensor([2, 1, 0, 0])
>>> preds = tensor([2, 1, 0, 1])
>>> multiclass_f1_score(preds, target, num_classes=3)
tensor(0.7778)
>>> multiclass_f1_score(preds, target, num_classes=3, average=None)
tensor([0.6667, 0.6667, 1.0000])
Example (preds is float tensor):
>>> from torchmetrics.functional.classification import multiclass_f1_score
>>> target = tensor([2, 1, 0, 0])
>>> preds = tensor([[0.16, 0.26, 0.58],
...                 [0.22, 0.61, 0.17],
...                 [0.71, 0.09, 0.20],
...                 [0.05, 0.82, 0.13]])
>>> multiclass_f1_score(preds, target, num_classes=3)
tensor(0.7778)
>>> multiclass_f1_score(preds, target, num_classes=3, average=None)
tensor([0.6667, 0.6667, 1.0000])
Example (multidim tensors):
>>> from torchmetrics.functional.classification import multiclass_f1_score
>>> target = tensor([[[0, 1], [2, 1], [0, 2]], [[1, 1], [2, 0], [1, 2]]])
>>> preds = tensor([[[0, 2], [2, 0], [0, 1]], [[2, 2], [2, 1], [1, 0]]])
>>> multiclass_f1_score(preds, target, num_classes=3, multidim_average='samplewise')
tensor([0.4333, 0.2667])
>>> multiclass_f1_score(preds, target, num_classes=3, multidim_average='samplewise', average=None)
tensor([[0.8000, 0.0000, 0.5000],
        [0.0000, 0.4000, 0.4000]])
multilabel_f1_score
torchmetrics.functional.classification.multilabel_f1_score(preds, target, num_labels, threshold=0.5, average='macro', multidim_average='global', ignore_index=None, validate_args=True)[source]

Compute F-1 score for multilabel tasks.

\[F_{1} = 2\frac{\text{precision} * \text{recall}}{\text{precision} + \text{recall}}\]

Accepts the following input tensors:

  • preds (int or float tensor): (N, C, ...). If preds is a floating point tensor with values outside the [0,1] range we consider the input to be logits and will auto apply sigmoid per element. Additionally, we convert to int tensor with thresholding using the value in threshold.

  • target (int tensor): (N, C, ...)

Parameters:
  • preds (Tensor) – Tensor with predictions

  • target (Tensor) – Tensor with true labels

  • num_labels (int) – Integer specifying the number of labels

  • threshold (float) – Threshold for transforming probability to binary (0,1) predictions

  • average (Optional[Literal['micro', 'macro', 'weighted', 'none']]) –

    Defines the reduction that is applied over labels. Should be one of the following:

    • micro: Sum statistics over all labels

    • macro: Calculate statistics for each label and average them

    • weighted: calculates statistics for each label and computes weighted average using their support

    • "none" or None: calculates statistic for each label and applies no reduction

  • multidim_average (Literal['global', 'samplewise']) –

    Defines how additional dimensions ... should be handled. Should be one of the following:

    • global: Additional dimensions are flattened along the batch dimension

    • samplewise: Statistic will be calculated independently for each sample on the N axis. The statistics in this case are calculated over the additional dimensions.

  • ignore_index (Optional[int]) – Specifies a target value that is ignored and does not contribute to the metric calculation

  • validate_args (bool) – bool indicating if input arguments and tensors should be validated for correctness. Set to False for faster computations.

Returns:

  • If multidim_average is set to global:

    • If average='micro'/'macro'/'weighted', the output will be a scalar tensor

    • If average=None/'none', the shape will be (C,)

  • If multidim_average is set to samplewise:

    • If average='micro'/'macro'/'weighted', the shape will be (N,)

    • If average=None/'none', the shape will be (N, C)

Return type:

The returned shape depends on the average and multidim_average arguments

Example (preds is int tensor):
>>> from torch import tensor
>>> from torchmetrics.functional.classification import multilabel_f1_score
>>> target = tensor([[0, 1, 0], [1, 0, 1]])
>>> preds = tensor([[0, 0, 1], [1, 0, 1]])
>>> multilabel_f1_score(preds, target, num_labels=3)
tensor(0.5556)
>>> multilabel_f1_score(preds, target, num_labels=3, average=None)
tensor([1.0000, 0.0000, 0.6667])
Example (preds is float tensor):
>>> from torchmetrics.functional.classification import multilabel_f1_score
>>> target = tensor([[0, 1, 0], [1, 0, 1]])
>>> preds = tensor([[0.11, 0.22, 0.84], [0.73, 0.33, 0.92]])
>>> multilabel_f1_score(preds, target, num_labels=3)
tensor(0.5556)
>>> multilabel_f1_score(preds, target, num_labels=3, average=None)
tensor([1.0000, 0.0000, 0.6667])
Example (multidim tensors):
>>> from torchmetrics.functional.classification import multilabel_f1_score
>>> target = tensor([[[0, 1], [1, 0], [0, 1]], [[1, 1], [0, 0], [1, 0]]])
>>> preds = tensor([[[0.59, 0.91], [0.91, 0.99], [0.63, 0.04]],
...                 [[0.38, 0.04], [0.86, 0.780], [0.45, 0.37]]])
>>> multilabel_f1_score(preds, target, num_labels=3, multidim_average='samplewise')
tensor([0.4444, 0.0000])
>>> multilabel_f1_score(preds, target, num_labels=3, multidim_average='samplewise', average=None)
tensor([[0.6667, 0.6667, 0.0000],
        [0.0000, 0.0000, 0.0000]])

F-Beta Score

Module Interface

class torchmetrics.FBetaScore(**kwargs)[source]

Compute F-score metric.

\[F_{\beta} = (1 + \beta^2) * \frac{\text{precision} * \text{recall}} {(\beta^2 * \text{precision}) + \text{recall}}\]

The metric is only properly defined when \(\text{TP} + \text{FP} \neq 0 \wedge \text{TP} + \text{FN} \neq 0\) where \(\text{TP}\), \(\text{FP}\) and \(\text{FN}\) represent the number of true positives, false positives and false negatives respectively. If this case is encountered for any class/label, the metric for that class/label will be set to 0 and the overall metric may therefore be affected in turn.

This function is a simple wrapper to get the task-specific versions of this metric, which is done by setting the task argument to either 'binary', 'multiclass' or 'multilabel'. See the documentation of BinaryFBetaScore, MulticlassFBetaScore and MultilabelFBetaScore for the specific details of each argument's influence and examples.

Legacy Example:
>>> from torch import tensor
>>> target = tensor([0, 1, 2, 0, 1, 2])
>>> preds = tensor([0, 2, 1, 0, 0, 1])
>>> f_beta = FBetaScore(task="multiclass", num_classes=3, beta=0.5)
>>> f_beta(preds, target)
tensor(0.3333)
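
For the binary task the wrapper forwards to BinaryFBetaScore. A small sketch mirroring the BinaryFBetaScore example further below; precision and recall are both 2/3 on these inputs, so every F-beta works out to the same 0.6667:

>>> from torch import tensor
>>> target = tensor([0, 1, 0, 1, 0, 1])
>>> preds = tensor([0, 0, 1, 1, 0, 1])
>>> f_beta = FBetaScore(task="binary", beta=2.0)
>>> f_beta(preds, target)
tensor(0.6667)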
static __new__(cls, task, beta=1.0, threshold=0.5, num_classes=None, num_labels=None, average='micro', multidim_average='global', top_k=1, ignore_index=None, validate_args=True, **kwargs)[source]

Initialize task metric.

Return type:

Metric

BinaryFBetaScore
class torchmetrics.classification.BinaryFBetaScore(beta, threshold=0.5, multidim_average='global', ignore_index=None, validate_args=True, **kwargs)[source]

Compute F-score metric for binary tasks.

\[F_{\beta} = (1 + \beta^2) * \frac{\text{precision} * \text{recall}} {(\beta^2 * \text{precision}) + \text{recall}}\]

The metric is only properly defined when \(\text{TP} + \text{FP} \neq 0 \wedge \text{TP} + \text{FN} \neq 0\) where \(\text{TP}\), \(\text{FP}\) and \(\text{FN}\) represent the number of true positives, false positives and false negatives respectively. If this case is encountered a score of 0 is returned.

As input to forward and update the metric accepts the following input:

  • preds (Tensor): An int tensor or float tensor of shape (N, ...). If preds is a floating point tensor with values outside the [0,1] range we consider the input to be logits and will auto apply sigmoid per element. Additionally, we convert to int tensor with thresholding using the value in threshold.

  • target (Tensor): An int tensor of shape (N, ...).

As output to forward and compute the metric returns the following output:

  • bfbs (Tensor): A tensor whose returned shape depends on the multidim_average argument:

    • If multidim_average is set to global the output will be a scalar tensor

    • If multidim_average is set to samplewise the output will be a tensor of shape (N,) consisting of a scalar value per sample.

Parameters:
  • beta (float) – Weighting between precision and recall in calculation. Setting to 1 corresponds to equal weight

  • threshold (float) – Threshold for transforming probability to binary {0,1} predictions

  • multidim_average (Literal['global', 'samplewise']) –

    Defines how additional dimensions ... should be handled. Should be one of the following:

    • global: Additional dimensions are flattened along the batch dimension

    • samplewise: Statistic will be calculated independently for each sample on the N axis. The statistics in this case are calculated over the additional dimensions.

  • ignore_index (Optional[int]) – Specifies a target value that is ignored and does not contribute to the metric calculation

  • validate_args (bool) – bool indicating if input arguments and tensors should be validated for correctness. Set to False for faster computations.

Example (preds is int tensor):
>>> from torch import tensor
>>> from torchmetrics.classification import BinaryFBetaScore
>>> target = tensor([0, 1, 0, 1, 0, 1])
>>> preds = tensor([0, 0, 1, 1, 0, 1])
>>> metric = BinaryFBetaScore(beta=2.0)
>>> metric(preds, target)
tensor(0.6667)
Example (preds is float tensor):
>>> from torchmetrics.classification import BinaryFBetaScore
>>> target = tensor([0, 1, 0, 1, 0, 1])
>>> preds = tensor([0.11, 0.22, 0.84, 0.73, 0.33, 0.92])
>>> metric = BinaryFBetaScore(beta=2.0)
>>> metric(preds, target)
tensor(0.6667)
Example (multidim tensors):
>>> from torchmetrics.classification import BinaryFBetaScore
>>> target = tensor([[[0, 1], [1, 0], [0, 1]], [[1, 1], [0, 0], [1, 0]]])
>>> preds = tensor([[[0.59, 0.91], [0.91, 0.99],  [0.63, 0.04]],
...                 [[0.38, 0.04], [0.86, 0.780], [0.45, 0.37]]])
>>> metric = BinaryFBetaScore(beta=2.0, multidim_average='samplewise')
>>> metric(preds, target)
tensor([0.5882, 0.0000])
plot(val=None, ax=None)[source]

Plot a single or multiple values from the metric.

Parameters:
  • val (Union[Tensor, Sequence[Tensor], None]) – Either a single result from calling metric.forward or metric.compute or a list of these results. If no value is provided, will automatically call metric.compute and plot that result.

  • ax (Optional[Axes]) – A matplotlib axis object. If provided, will add plot to that axis

Return type:

Tuple[Figure, Union[Axes, ndarray]]

Returns:

Figure object and Axes object

Raises:

ModuleNotFoundError – If matplotlib is not installed

>>> from torch import rand, randint
>>> # Example plotting a single value
>>> from torchmetrics.classification import BinaryFBetaScore
>>> metric = BinaryFBetaScore(beta=2.0)
>>> metric.update(rand(10), randint(2,(10,)))
>>> fig_, ax_ = metric.plot()
>>> from torch import rand, randint
>>> # Example plotting multiple values
>>> from torchmetrics.classification import BinaryFBetaScore
>>> metric = BinaryFBetaScore(beta=2.0)
>>> values = [ ]
>>> for _ in range(10):
...     values.append(metric(rand(10), randint(2,(10,))))
>>> fig_, ax_ = metric.plot(values)
MulticlassFBetaScore
class torchmetrics.classification.MulticlassFBetaScore(beta, num_classes, top_k=1, average='macro', multidim_average='global', ignore_index=None, validate_args=True, **kwargs)[source]

Compute F-score metric for multiclass tasks.

\[F_{\beta} = (1 + \beta^2) * \frac{\text{precision} * \text{recall}} {(\beta^2 * \text{precision}) + \text{recall}}\]

The metric is only properly defined when \(\text{TP} + \text{FP} \neq 0 \wedge \text{TP} + \text{FN} \neq 0\) where \(\text{TP}\), \(\text{FP}\) and \(\text{FN}\) represent the number of true positives, false positives and false negatives respectively. If this case is encountered for any class, the metric for that class will be set to 0 and the overall metric may therefore be affected in turn.

As input to forward and update the metric accepts the following input:

  • preds (Tensor): An int tensor of shape (N, ...) or float tensor of shape (N, C, ...). If preds is a floating point tensor we apply torch.argmax along the C dimension to automatically convert probabilities/logits into an int tensor.

  • target (Tensor): An int tensor of shape (N, ...).

As output to forward and compute the metric returns the following output:

  • mcfbs (Tensor): A tensor whose returned shape depends on the average and multidim_average arguments:

    • If multidim_average is set to global:

      • If average='micro'/'macro'/'weighted', the output will be a scalar tensor

      • If average=None/'none', the shape will be (C,)

    • If multidim_average is set to samplewise:

      • If average='micro'/'macro'/'weighted', the shape will be (N,)

      • If average=None/'none', the shape will be (N, C)

Parameters:
  • beta (float) – Weighting between precision and recall in calculation. Setting to 1 corresponds to equal weight

  • num_classes (int) – Integer specifying the number of classes

  • average (Optional[Literal['micro', 'macro', 'weighted', 'none']]) –

    Defines the reduction that is applied over labels. Should be one of the following:

    • micro: Sum statistics over all labels

    • macro: Calculate statistics for each label and average them

    • weighted: calculates statistics for each label and computes weighted average using their support

    • "none" or None: calculates statistic for each label and applies no reduction

  • top_k (int) – Number of highest probability or logit score predictions considered to find the correct label. Only works when preds contain probabilities/logits.

  • multidim_average (Literal['global', 'samplewise']) –

    Defines how additional dimensions ... should be handled. Should be one of the following:

    • global: Additional dimensions are flattened along the batch dimension

    • samplewise: Statistic will be calculated independently for each sample on the N axis. The statistics in this case are calculated over the additional dimensions.

  • ignore_index (Optional[int]) – Specifies a target value that is ignored and does not contribute to the metric calculation

  • validate_args (bool) – bool indicating if input arguments and tensors should be validated for correctness. Set to False for faster computations.

Example (preds is int tensor):
>>> from torch import tensor
>>> from torchmetrics.classification import MulticlassFBetaScore
>>> target = tensor([2, 1, 0, 0])
>>> preds = tensor([2, 1, 0, 1])
>>> metric = MulticlassFBetaScore(beta=2.0, num_classes=3)
>>> metric(preds, target)
tensor(0.7963)
>>> mcfbs = MulticlassFBetaScore(beta=2.0, num_classes=3, average=None)
>>> mcfbs(preds, target)
tensor([0.5556, 0.8333, 1.0000])
Example (preds is float tensor):
>>> from torchmetrics.classification import MulticlassFBetaScore
>>> target = tensor([2, 1, 0, 0])
>>> preds = tensor([[0.16, 0.26, 0.58],
...                 [0.22, 0.61, 0.17],
...                 [0.71, 0.09, 0.20],
...                 [0.05, 0.82, 0.13]])
>>> metric = MulticlassFBetaScore(beta=2.0, num_classes=3)
>>> metric(preds, target)
tensor(0.7963)
>>> mcfbs = MulticlassFBetaScore(beta=2.0, num_classes=3, average=None)
>>> mcfbs(preds, target)
tensor([0.5556, 0.8333, 1.0000])
Example (multidim tensors):
>>> from torchmetrics.classification import MulticlassFBetaScore
>>> target = tensor([[[0, 1], [2, 1], [0, 2]], [[1, 1], [2, 0], [1, 2]]])
>>> preds = tensor([[[0, 2], [2, 0], [0, 1]], [[2, 2], [2, 1], [1, 0]]])
>>> metric = MulticlassFBetaScore(beta=2.0, num_classes=3, multidim_average='samplewise')
>>> metric(preds, target)
tensor([0.4697, 0.2706])
>>> mcfbs = MulticlassFBetaScore(beta=2.0, num_classes=3, multidim_average='samplewise', average=None)
>>> mcfbs(preds, target)
tensor([[0.9091, 0.0000, 0.5000],
        [0.0000, 0.3571, 0.4545]])
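
With micro averaging the statistics are pooled over all classes first, so precision and recall both reduce to the overall accuracy (3 of 4 predictions correct) and every F-beta collapses to the same value. A hand-computed sketch:

>>> from torch import tensor
>>> from torchmetrics.classification import MulticlassFBetaScore
>>> target = tensor([2, 1, 0, 0])
>>> preds = tensor([2, 1, 0, 1])
>>> metric = MulticlassFBetaScore(beta=2.0, num_classes=3, average='micro')
>>> metric(preds, target)
tensor(0.7500)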
plot(val=None, ax=None)[source]

Plot a single or multiple values from the metric.

Parameters:
  • val (Union[Tensor, Sequence[Tensor], None]) – Either a single result from calling metric.forward or metric.compute or a list of these results. If no value is provided, will automatically call metric.compute and plot that result.

  • ax (Optional[Axes]) – A matplotlib axis object. If provided, will add plot to that axis

Return type:

Tuple[Figure, Union[Axes, ndarray]]

Returns:

Figure object and Axes object

Raises:

ModuleNotFoundError – If matplotlib is not installed

>>> from torch import randint
>>> # Example plotting a single value per class
>>> from torchmetrics.classification import MulticlassFBetaScore
>>> metric = MulticlassFBetaScore(num_classes=3, beta=2.0, average=None)
>>> metric.update(randint(3, (20,)), randint(3, (20,)))
>>> fig_, ax_ = metric.plot()
>>> from torch import randint
>>> # Example plotting a multiple values per class
>>> from torchmetrics.classification import MulticlassFBetaScore
>>> metric = MulticlassFBetaScore(num_classes=3, beta=2.0, average=None)
>>> values = []
>>> for _ in range(20):
...     values.append(metric(randint(3, (20,)), randint(3, (20,))))
>>> fig_, ax_ = metric.plot(values)
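The beta argument controls the precision/recall trade-off in the formula above: beta < 1 weights precision more heavily, beta > 1 weights recall more heavily, and beta = 1 reduces to the ordinary F1 score. A minimal illustrative sketch (not part of the official examples), reusing the inputs from the examples above:

from torch import tensor
from torchmetrics.classification import MulticlassFBetaScore

target = tensor([2, 1, 0, 0])
preds = tensor([2, 1, 0, 1])

# Same data as the examples above, evaluated with three different beta values.
f_half = MulticlassFBetaScore(beta=0.5, num_classes=3)(preds, target)  # precision-leaning
f_one = MulticlassFBetaScore(beta=1.0, num_classes=3)(preds, target)   # plain F1
f_two = MulticlassFBetaScore(beta=2.0, num_classes=3)(preds, target)   # recall-leaning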
MultilabelFBetaScore
class torchmetrics.classification.MultilabelFBetaScore(beta, num_labels, threshold=0.5, average='macro', multidim_average='global', ignore_index=None, validate_args=True, **kwargs)[source]

Compute F-score metric for multilabel tasks.

\[F_{\beta} = (1 + \beta^2) * \frac{\text{precision} * \text{recall}} {(\beta^2 * \text{precision}) + \text{recall}}\]

The metric is only properly defined when \(\text{TP} + \text{FP} \neq 0 \wedge \text{TP} + \text{FN} \neq 0\), where \(\text{TP}\), \(\text{FP}\) and \(\text{FN}\) represent the number of true positives, false positives and false negatives respectively. If this condition is not met for a given label, the metric for that label is set to 0 and the overall metric may therefore be affected in turn.

As input to forward and update the metric accepts the following input:

  • preds (Tensor): An int or float tensor of shape (N, C, ...). If preds is a floating point tensor with values outside [0,1] range we consider the input to be logits and will auto apply sigmoid per element. Additionally, we convert to int tensor with thresholding using the value in threshold.

  • target (Tensor): An int tensor of shape (N, C, ...).

As output to forward and compute the metric returns the following output:

  • mlfbs (Tensor): A tensor whose returned shape depends on the average and multidim_average arguments:

    • If multidim_average is set to global:

      • If average='micro'/'macro'/'weighted', the output will be a scalar tensor

      • If average=None/'none', the shape will be (C,)

    • If multidim_average is set to samplewise:

      • If average='micro'/'macro'/'weighted', the shape will be (N,)

      • If average=None/'none', the shape will be (N, C)

Parameters:
  • beta (float) – Weighting between precision and recall in calculation. Setting to 1 corresponds to equal weight

  • num_labels (int) – Integer specifying the number of labels

  • threshold (float) – Threshold for transforming probability to binary (0,1) predictions

  • average (Optional[Literal['micro', 'macro', 'weighted', 'none']]) –

    Defines the reduction that is applied over labels. Should be one of the following:

    • micro: Sum statistics over all labels

    • macro: Calculate statistics for each label and average them

    • weighted: calculates statistics for each label and computes weighted average using their support

    • "none" or None: calculates statistic for each label and applies no reduction

  • multidim_average (Literal['global', 'samplewise']) –

    Defines how additional dimensions ... should be handled. Should be one of the following:

    • global: Additional dimensions are flattened along the batch dimension

    • samplewise: Statistic will be calculated independently for each sample on the N axis. The statistics in this case are calculated over the additional dimensions.

  • ignore_index (Optional[int]) – Specifies a target value that is ignored and does not contribute to the metric calculation

  • validate_args (bool) – bool indicating if input arguments and tensors should be validated for correctness. Set to False for faster computations.

Example (preds is int tensor):
>>> from torch import tensor
>>> from torchmetrics.classification import MultilabelFBetaScore
>>> target = tensor([[0, 1, 0], [1, 0, 1]])
>>> preds = tensor([[0, 0, 1], [1, 0, 1]])
>>> metric = MultilabelFBetaScore(beta=2.0, num_labels=3)
>>> metric(preds, target)
tensor(0.6111)
>>> mlfbs = MultilabelFBetaScore(beta=2.0, num_labels=3, average=None)
>>> mlfbs(preds, target)
tensor([1.0000, 0.0000, 0.8333])
Example (preds is float tensor):
>>> from torchmetrics.classification import MultilabelFBetaScore
>>> target = tensor([[0, 1, 0], [1, 0, 1]])
>>> preds = tensor([[0.11, 0.22, 0.84], [0.73, 0.33, 0.92]])
>>> metric = MultilabelFBetaScore(beta=2.0, num_labels=3)
>>> metric(preds, target)
tensor(0.6111)
>>> mlfbs = MultilabelFBetaScore(beta=2.0, num_labels=3, average=None)
>>> mlfbs(preds, target)
tensor([1.0000, 0.0000, 0.8333])
Example (multidim tensors):
>>> from torchmetrics.classification import MultilabelFBetaScore
>>> target = tensor([[[0, 1], [1, 0], [0, 1]], [[1, 1], [0, 0], [1, 0]]])
>>> preds = tensor([[[0.59, 0.91], [0.91, 0.99],  [0.63, 0.04]],
...                 [[0.38, 0.04], [0.86, 0.780], [0.45, 0.37]]])
>>> metric = MultilabelFBetaScore(num_labels=3, beta=2.0, multidim_average='samplewise')
>>> metric(preds, target)
tensor([0.5556, 0.0000])
>>> mlfbs = MultilabelFBetaScore(num_labels=3, beta=2.0, multidim_average='samplewise', average=None)
>>> mlfbs(preds, target)
tensor([[0.8333, 0.8333, 0.0000],
        [0.0000, 0.0000, 0.0000]])
plot(val=None, ax=None)[source]

Plot a single or multiple values from the metric.

Parameters:
  • val (Union[Tensor, Sequence[Tensor], None]) – Either a single result from calling metric.forward or metric.compute or a list of these results. If no value is provided, will automatically call metric.compute and plot that result.

  • ax (Optional[Axes]) – A matplotlib axis object. If provided, will add the plot to that axis

Return type:

Tuple[Figure, Union[Axes, ndarray]]

Returns:

Figure and Axes object

Raises:

ModuleNotFoundError – If matplotlib is not installed

>>> from torch import rand, randint
>>> # Example plotting a single value
>>> from torchmetrics.classification import MultilabelFBetaScore
>>> metric = MultilabelFBetaScore(num_labels=3, beta=2.0)
>>> metric.update(randint(2, (20, 3)), randint(2, (20, 3)))
>>> fig_, ax_ = metric.plot()
>>> from torch import rand, randint
>>> # Example plotting multiple values
>>> from torchmetrics.classification import MultilabelFBetaScore
>>> metric = MultilabelFBetaScore(num_labels=3, beta=2.0)
>>> values = [ ]
>>> for _ in range(10):
...     values.append(metric(randint(2, (20, 3)), randint(2, (20, 3))))
>>> fig_, ax_ = metric.plot(values)

Functional Interface

fbeta_score
torchmetrics.functional.fbeta_score(preds, target, task, beta=1.0, threshold=0.5, num_classes=None, num_labels=None, average='micro', multidim_average='global', top_k=1, ignore_index=None, validate_args=True)[source]

Compute F-score metric. Return type: Tensor

\[F_{\beta} = (1 + \beta^2) * \frac{\text{precision} * \text{recall}} {(\beta^2 * \text{precision}) + \text{recall}}\]

This function is a simple wrapper to get the task-specific version of this metric, which is done by setting the task argument to either 'binary', 'multiclass' or 'multilabel'. See the documentation of binary_fbeta_score(), multiclass_fbeta_score() and multilabel_fbeta_score() for the specific details of how each argument influences the metric and for examples.

Legacy Example:
>>> from torch import tensor
>>> target = tensor([0, 1, 2, 0, 1, 2])
>>> preds = tensor([0, 2, 1, 0, 0, 1])
>>> fbeta_score(preds, target, task="multiclass", num_classes=3, beta=0.5)
tensor(0.3333)
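As a sanity check on the formula above, the score can also be reconstructed from separately computed precision and recall. The sketch below is for illustration only (it is not one of the official examples); it uses the documented binary_precision and binary_recall functions from torchmetrics.functional.classification:

import torch
from torch import tensor
from torchmetrics.functional.classification import (
    binary_fbeta_score,
    binary_precision,
    binary_recall,
)

target = tensor([0, 1, 0, 1, 0, 1])
preds = tensor([0, 0, 1, 1, 0, 1])
beta = 2.0

precision = binary_precision(preds, target)
recall = binary_recall(preds, target)
# Plug precision and recall into the F-beta formula by hand ...
manual = (1 + beta**2) * precision * recall / (beta**2 * precision + recall)
# ... and compare with the library result.
assert torch.isclose(manual, binary_fbeta_score(preds, target, beta=beta))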
binary_fbeta_score
torchmetrics.functional.classification.binary_fbeta_score(preds, target, beta, threshold=0.5, multidim_average='global', ignore_index=None, validate_args=True)[source]

Compute F-score metric for binary tasks.

\[F_{\beta} = (1 + \beta^2) * \frac{\text{precision} * \text{recall}} {(\beta^2 * \text{precision}) + \text{recall}}\]

Accepts the following input tensors:

  • preds (int or float tensor): (N, ...). If preds is a floating point tensor with values outside [0,1] range we consider the input to be logits and will auto apply sigmoid per element. Additionally, we convert to int tensor with thresholding using the value in threshold.

  • target (int tensor): (N, ...)

Parameters:
  • preds (Tensor) – Tensor with predictions

  • target (Tensor) – Tensor with true labels

  • beta (float) – Weighting between precision and recall in calculation. Setting to 1 corresponds to equal weight

  • threshold (float) – Threshold for transforming probability to binary {0,1} predictions

  • multidim_average (Literal['global', 'samplewise']) –

    Defines how additional dimensions ... should be handled. Should be one of the following:

    • global: Additional dimensions are flattened along the batch dimension

    • samplewise: Statistic will be calculated independently for each sample on the N axis. The statistics in this case are calculated over the additional dimensions.

  • ignore_index (Optional[int]) – Specifies a target value that is ignored and does not contribute to the metric calculation

  • validate_args (bool) – bool indicating if input arguments and tensors should be validated for correctness. Set to False for faster computations.

Return type:

Tensor

Returns:

If multidim_average is set to global, the metric returns a scalar value. If multidim_average is set to samplewise, the metric returns an (N,) vector consisting of a scalar value per sample.

Example (preds is int tensor):
>>> from torch import tensor
>>> from torchmetrics.functional.classification import binary_fbeta_score
>>> target = tensor([0, 1, 0, 1, 0, 1])
>>> preds = tensor([0, 0, 1, 1, 0, 1])
>>> binary_fbeta_score(preds, target, beta=2.0)
tensor(0.6667)
Example (preds is float tensor):
>>> from torchmetrics.functional.classification import binary_fbeta_score
>>> target = tensor([0, 1, 0, 1, 0, 1])
>>> preds = tensor([0.11, 0.22, 0.84, 0.73, 0.33, 0.92])
>>> binary_fbeta_score(preds, target, beta=2.0)
tensor(0.6667)
Example (multidim tensors):
>>> from torchmetrics.functional.classification import binary_fbeta_score
>>> target = tensor([[[0, 1], [1, 0], [0, 1]], [[1, 1], [0, 0], [1, 0]]])
>>> preds = tensor([[[0.59, 0.91], [0.91, 0.99], [0.63, 0.04]],
...                 [[0.38, 0.04], [0.86, 0.780], [0.45, 0.37]]])
>>> binary_fbeta_score(preds, target, beta=2.0, multidim_average='samplewise')
tensor([0.5882, 0.0000])
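With the default multidim_average='global', the extra dimensions are flattened into the batch dimension before the statistics are computed, so the result should agree with calling the metric on explicitly reshaped one-dimensional tensors. A short illustrative sketch of this equivalence (an assumption based on the description above, not an official example):

from torch import tensor
from torchmetrics.functional.classification import binary_fbeta_score

target = tensor([[[0, 1], [1, 0], [0, 1]], [[1, 1], [0, 0], [1, 0]]])
preds = tensor([[[0.59, 0.91], [0.91, 0.99], [0.63, 0.04]],
                [[0.38, 0.04], [0.86, 0.78], [0.45, 0.37]]])

# Default multidim_average='global': extra dimensions are flattened internally.
score_global = binary_fbeta_score(preds, target, beta=2.0)
# The same computation on manually flattened tensors should give the same value.
score_flat = binary_fbeta_score(preds.reshape(-1), target.reshape(-1), beta=2.0)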
multiclass_fbeta_score
torchmetrics.functional.classification.multiclass_fbeta_score(preds, target, beta, num_classes, average='macro', top_k=1, multidim_average='global', ignore_index=None, validate_args=True)[source]

Compute F-score metric for multiclass tasks.

\[F_{\beta} = (1 + \beta^2) * \frac{\text{precision} * \text{recall}} {(\beta^2 * \text{precision}) + \text{recall}}\]

Accepts the following input tensors:

  • preds: (N, ...) (int tensor) or (N, C, ...) (float tensor). If preds is a floating point tensor we apply torch.argmax along the C dimension to automatically convert probabilities/logits into an int tensor.

  • target (int tensor): (N, ...)

Parameters:
  • preds (Tensor) – Tensor with predictions

  • target (Tensor) – Tensor with true labels

  • beta (float) – Weighting between precision and recall in calculation. Setting to 1 corresponds to equal weight

  • num_classes (int) – Integer specifying the number of classes

  • average (Optional[Literal['micro', 'macro', 'weighted', 'none']]) –

    Defines the reduction that is applied over labels. Should be one of the following:

    • micro: Sum statistics over all labels

    • macro: Calculate statistics for each label and average them

    • weighted: calculates statistics for each label and computes weighted average using their support

    • "none" or None: calculates statistic for each label and applies no reduction

  • top_k (int) – Number of highest probability or logit score predictions considered to find the correct label. Only works when preds contain probabilities/logits.

  • multidim_average (Literal['global', 'samplewise']) –

    Defines how additional dimensions ... should be handled. Should be one of the following:

    • global: Additional dimensions are flattened along the batch dimension

    • samplewise: Statistic will be calculated independently for each sample on the N axis. The statistics in this case are calculated over the additional dimensions.

  • ignore_index (Optional[int]) – Specifies a target value that is ignored and does not contribute to the metric calculation

  • validate_args (bool) – bool indicating if input arguments and tensors should be validated for correctness. Set to False for faster computations.

Returns:

  • If multidim_average is set to global:

    • If average='micro'/'macro'/'weighted', the output will be a scalar tensor

    • If average=None/'none', the shape will be (C,)

  • If multidim_average is set to samplewise:

    • If average='micro'/'macro'/'weighted', the shape will be (N,)

    • If average=None/'none', the shape will be (N, C)

Return type:

The returned shape depends on the average and multidim_average arguments

Example (preds is int tensor):
>>> from torch import tensor
>>> from torchmetrics.functional.classification import multiclass_fbeta_score
>>> target = tensor([2, 1, 0, 0])
>>> preds = tensor([2, 1, 0, 1])
>>> multiclass_fbeta_score(preds, target, beta=2.0, num_classes=3)
tensor(0.7963)
>>> multiclass_fbeta_score(preds, target, beta=2.0, num_classes=3, average=None)
tensor([0.5556, 0.8333, 1.0000])
Example (preds is float tensor):
>>> from torchmetrics.functional.classification import multiclass_fbeta_score
>>> target = tensor([2, 1, 0, 0])
>>> preds = tensor([[0.16, 0.26, 0.58],
...                 [0.22, 0.61, 0.17],
...                 [0.71, 0.09, 0.20],
...                 [0.05, 0.82, 0.13]])
>>> multiclass_fbeta_score(preds, target, beta=2.0, num_classes=3)
tensor(0.7963)
>>> multiclass_fbeta_score(preds, target, beta=2.0, num_classes=3, average=None)
tensor([0.5556, 0.8333, 1.0000])
Example (multidim tensors):
>>> from torchmetrics.functional.classification import multiclass_fbeta_score
>>> target = tensor([[[0, 1], [2, 1], [0, 2]], [[1, 1], [2, 0], [1, 2]]])
>>> preds = tensor([[[0, 2], [2, 0], [0, 1]], [[2, 2], [2, 1], [1, 0]]])
>>> multiclass_fbeta_score(preds, target, beta=2.0, num_classes=3, multidim_average='samplewise')
tensor([0.4697, 0.2706])
>>> multiclass_fbeta_score(preds, target, beta=2.0, num_classes=3, multidim_average='samplewise', average=None)
tensor([[0.9091, 0.0000, 0.5000],
        [0.0000, 0.3571, 0.4545]])
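The top_k argument documented above only applies when preds contains probabilities or logits: with top_k=2 the true class merely needs to appear among the two highest scores for a sample to count as a true positive for that class. A short illustrative sketch (not an official example), reusing the float preds from above:

from torch import tensor
from torchmetrics.functional.classification import multiclass_fbeta_score

target = tensor([2, 1, 0, 0])
preds = tensor([[0.16, 0.26, 0.58],
                [0.22, 0.61, 0.17],
                [0.71, 0.09, 0.20],
                [0.05, 0.82, 0.13]])

# Default behaviour considers only the single highest score per sample.
score_top1 = multiclass_fbeta_score(preds, target, beta=2.0, num_classes=3)
# With top_k=2, the second-highest score is also taken into account.
score_top2 = multiclass_fbeta_score(preds, target, beta=2.0, num_classes=3, top_k=2)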
multilabel_fbeta_score
torchmetrics.functional.classification.multilabel_fbeta_score(preds, target, beta, num_labels, threshold=0.5, average='macro', multidim_average='global', ignore_index=None, validate_args=True)[source]

Compute F-score metric for multilabel tasks.

\[F_{\beta} = (1 + \beta^2) * \frac{\text{precision} * \text{recall}} {(\beta^2 * \text{precision}) + \text{recall}}\]

Accepts the following input tensors:

  • preds (int or float tensor): (N, C, ...). If preds is a floating point tensor with values outside [0,1] range we consider the input to be logits and will auto apply sigmoid per element. Additionally, we convert to int tensor with thresholding using the value in threshold.

  • target (int tensor): (N, C, ...)

Parameters:
  • preds (Tensor) – Tensor with predictions

  • target (Tensor) – Tensor with true labels

  • beta (float) – Weighting between precision and recall in calculation. Setting to 1 corresponds to equal weight

  • num_labels (int) – Integer specifying the number of labels

  • threshold (float) – Threshold for transforming probability to binary (0,1) predictions

  • average (Optional[Literal['micro', 'macro', 'weighted', 'none']]) –

    Defines the reduction that is applied over labels. Should be one of the following:

    • micro: Sum statistics over all labels

    • macro: Calculate statistics for each label and average them

    • weighted: calculates statistics for each label and computes weighted average using their support

    • "none" or None: calculates statistic for each label and applies no reduction

  • multidim_average (Literal['global', 'samplewise']) –

    Defines how additional dimensions ... should be handled. Should be one of the following:

    • global: Additional dimensions are flattened along the batch dimension

    • samplewise: Statistic will be calculated independently for each sample on the N axis. The statistics in this case are calculated over the additional dimensions.

  • ignore_index (Optional[int]) – Specifies a target value that is ignored and does not contribute to the metric calculation

  • validate_args (bool) – bool indicating if input arguments and tensors should be validated for correctness. Set to False for faster computations.

Returns:

  • If multidim_average is set to global:

    • If average='micro'/'macro'/'weighted', the output will be a scalar tensor

    • If average=None/'none', the shape will be (C,)

  • If multidim_average is set to samplewise:

    • If average='micro'/'macro'/'weighted', the shape will be (N,)

    • If average=None/'none', the shape will be (N, C)

Return type:

The returned shape depends on the average and multidim_average arguments

Example (preds is int tensor):
>>> from torch import tensor
>>> from torchmetrics.functional.classification import multilabel_fbeta_score
>>> target = tensor([[0, 1, 0], [1, 0, 1]])
>>> preds = tensor([[0, 0, 1], [1, 0, 1]])
>>> multilabel_fbeta_score(preds, target, beta=2.0, num_labels=3)
tensor(0.6111)
>>> multilabel_fbeta_score(preds, target, beta=2.0, num_labels=3, average=None)
tensor([1.0000, 0.0000, 0.8333])
Example (preds is float tensor):
>>> from torchmetrics.functional.classification import multilabel_fbeta_score
>>> target = tensor([[0, 1, 0], [1, 0, 1]])
>>> preds = tensor([[0.11, 0.22, 0.84], [0.73, 0.33, 0.92]])
>>> multilabel_fbeta_score(preds, target, beta=2.0, num_labels=3)
tensor(0.6111)
>>> multilabel_fbeta_score(preds, target, beta=2.0, num_labels=3, average=None)
tensor([1.0000, 0.0000, 0.8333])
Example (multidim tensors):
>>> from torchmetrics.functional.classification import multilabel_fbeta_score
>>> target = tensor([[[0, 1], [1, 0], [0, 1]], [[1, 1], [0, 0], [1, 0]]])
>>> preds = tensor([[[0.59, 0.91], [0.91, 0.99], [0.63, 0.04]],
...                 [[0.38, 0.04], [0.86, 0.780], [0.45, 0.37]]])
>>> multilabel_fbeta_score(preds, target, num_labels=3, beta=2.0, multidim_average='samplewise')
tensor([0.5556, 0.0000])
>>> multilabel_fbeta_score(preds, target, num_labels=3, beta=2.0, multidim_average='samplewise', average=None)
tensor([[0.8333, 0.8333, 0.0000],
        [0.0000, 0.0000, 0.0000]])
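As noted above, float preds with values outside the [0,1] range are treated as logits and passed through a sigmoid before thresholding, so passing raw logits or their sigmoid-transformed probabilities should yield the same score. A small illustrative sketch (the logit values below are made up for demonstration):

import torch
from torchmetrics.functional.classification import multilabel_fbeta_score

target = torch.tensor([[0, 1, 0], [1, 0, 1]])
logits = torch.tensor([[-2.0, 1.5, 0.3], [1.2, -0.7, 2.4]])  # hypothetical raw model outputs

# Values outside [0, 1] are detected, so a sigmoid is applied internally ...
score_from_logits = multilabel_fbeta_score(logits, target, beta=2.0, num_labels=3)
# ... which should be equivalent to applying the sigmoid manually beforehand.
score_from_probs = multilabel_fbeta_score(torch.sigmoid(logits), target, beta=2.0, num_labels=3)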

Group Fairness

Module Interface

BinaryFairness
class torchmetrics.classification.BinaryFairness(num_groups, task='all', threshold=0.5, ignore_index=None, validate_args=True, **kwargs)[source]

Computes Demographic parity and Equal opportunity ratio for binary classification problems.

Accepts the following input tensors:

  • preds (int or float tensor): (N, ...). If preds is a floating point tensor with values outside [0,1] range we consider the input to be logits and will auto apply sigmoid per element. Additionally, we convert to int tensor with thresholding using the value in threshold.

  • groups (int tensor): (N, ...). The group identifiers should be 0, 1, ..., (num_groups - 1).

  • target (int tensor): (N, ...).

The additional dimensions are flattened along the batch dimension.

This class computes the ratio between positivity rates and true positives rates for different groups. If more than two groups are present, the disparity between the lowest and highest group is reported. A disparity between positivity rates indicates a potential violation of demographic parity, and between true positive rates indicates a potential violation of equal opportunity.

The lowest rate is divided by the highest, so a lower value means more discrimination against the group in the numerator. This is reflected in the returned dict, whose keys have the form {metric}_{identifier_low_group}_{identifier_high_group}.

Parameters:
  • num_groups (int) – The number of groups.

  • task (Literal['demographic_parity', 'equal_opportunity', 'all']) – The task to compute. Can be either demographic_parity, equal_opportunity or all.

  • threshold (float) – Threshold for transforming probability to binary {0,1} predictions.

  • ignore_index (Optional[int]) – Specifies a target value that is ignored and does not contribute to the metric calculation

  • validate_args (bool) – bool indicating if input arguments and tensors should be validated for correctness. Set to False for faster computations.

  • kwargs (Any) – Additional keyword arguments, see Advanced metric settings for more info.

Returns:

The metric returns a dict where the key identifies the metric and the groups with the lowest and highest true positive rates, in the form {metric}_{identifier_low_group}_{identifier_high_group}. The value is a tensor with the disparity rate.

Example (preds is int tensor):
>>> from torchmetrics.classification import BinaryFairness
>>> target = torch.tensor([0, 1, 0, 1, 0, 1])
>>> preds = torch.tensor([0, 1, 0, 1, 0, 1])
>>> groups = torch.tensor([0, 1, 0, 1, 0, 1])
>>> metric = BinaryFairness(2)
>>> metric(preds, target, groups)
{'DP_0_1': tensor(0.), 'EO_0_1': tensor(0.)}
Example (preds is float tensor):
>>> from torchmetrics.classification import BinaryFairness
>>> target = torch.tensor([0, 1, 0, 1, 0, 1])
>>> preds = torch.tensor([0.11, 0.84, 0.22, 0.73, 0.33, 0.92])
>>> groups = torch.tensor([0, 1, 0, 1, 0, 1])
>>> metric = BinaryFairness(2)
>>> metric(preds, target, groups)
{'DP_0_1': tensor(0.), 'EO_0_1': tensor(0.)}
plot(val=None, ax=None)[source]

Plot a single or multiple values from the metric.

Parameters:
  • val (Union[Tensor, Sequence[Tensor], None]) – Either a single result from calling metric.forward or metric.compute or a list of these results. If no value is provided, will automatically call metric.compute and plot that result.

  • ax (Optional[Axes]) – A matplotlib axis object. If provided, will add the plot to that axis

Return type:

Tuple[Figure, Union[Axes, ndarray]]

Returns:

Figure object and Axes object

Raises:

ModuleNotFoundError – If matplotlib is not installed

>>> import torch
>>> _ = torch.manual_seed(42)
>>> # Example plotting a single value
>>> from torchmetrics.classification import BinaryFairness
>>> metric = BinaryFairness(2)
>>> metric.update(torch.rand(20), torch.randint(2,(20,)), torch.randint(2,(20,)))
>>> fig_, ax_ = metric.plot()
>>> import torch
>>> _ = torch.manual_seed(42)
>>> # Example plotting multiple values
>>> from torchmetrics.classification import BinaryFairness
>>> metric = BinaryFairness(2)
>>> values = [ ]
>>> for _ in range(10):
...     values.append(metric(torch.rand(20), torch.randint(2,(20,)), torch.ones(20).long()))
>>> fig_, ax_ = metric.plot(values)
BinaryGroupStatRates
class torchmetrics.classification.BinaryGroupStatRates(num_groups, threshold=0.5, ignore_index=None, validate_args=True, **kwargs)[source]

Computes the true/false positive and true/false negative rates for binary classification by group.

Related to Type I and Type II errors.

Accepts the following input tensors:

  • preds (int or float tensor): (N, ...). If preds is a floating point tensor with values outside [0,1] range we consider the input to be logits and will auto apply sigmoid per element. Additionally, we convert to int tensor with thresholding using the value in threshold.

  • target (int tensor): (N, ...).

  • groups (int tensor): (N, ...). The group identifiers should be 0, 1, ..., (num_groups - 1).

The additional dimensions are flattened along the batch dimension.

Parameters:
  • num_groups (int) – The number of groups.

  • threshold (float) – Threshold for transforming probability to binary {0,1} predictions.

  • ignore_index (Optional[int]) – Specifies a target value that is ignored and does not contribute to the metric calculation

  • validate_args (bool) – bool indicating if input arguments and tensors should be validated for correctness. Set to False for faster computations.

  • kwargs (Any) – Additional keyword arguments, see Advanced metric settings for more info.

Returns:

The metric returns a dict with a group identifier as key and a tensor with the tp, fp, tn and fn rates as value.

Example (preds is int tensor):
>>> from torchmetrics.classification import BinaryGroupStatRates
>>> target = torch.tensor([0, 1, 0, 1, 0, 1])
>>> preds = torch.tensor([0, 1, 0, 1, 0, 1])
>>> groups = torch.tensor([0, 1, 0, 1, 0, 1])
>>> metric = BinaryGroupStatRates(num_groups=2)
>>> metric(preds, target, groups)
{'group_0': tensor([0., 0., 1., 0.]), 'group_1': tensor([1., 0., 0., 0.])}
Example (preds is float tensor):
>>> from torchmetrics.classification import BinaryGroupStatRates
>>> target = torch.tensor([0, 1, 0, 1, 0, 1])
>>> preds = torch.tensor([0.11, 0.84, 0.22, 0.73, 0.33, 0.92])
>>> groups = torch.tensor([0, 1, 0, 1, 0, 1])
>>> metric = BinaryGroupStatRates(num_groups=2)
>>> metric(preds, target, groups)
{'group_0': tensor([0., 0., 1., 0.]), 'group_1': tensor([1., 0., 0., 0.])}

Functional Interface

binary_fairness
torchmetrics.functional.classification.binary_fairness(preds, target, groups, task='all', threshold=0.5, ignore_index=None, validate_args=True)[source]

Compute the Demographic parity and/or Equal opportunity ratio for binary classification problems.

This is done by setting the task argument to either 'demographic_parity', 'equal_opportunity' or 'all'. See the documentation of demographic_parity() and equal_opportunity() for the specific details of how each argument influences the metric and for examples.

Parameters:
  • preds (Tensor) – Tensor with predictions.

  • target (Tensor) – Tensor with true labels (not required for demographic_parity).

  • groups (Tensor) – Tensor with group identifiers. The group identifiers should be 0, 1, ..., (num_groups - 1).

  • task (Literal['demographic_parity', 'equal_opportunity', 'all']) – The task to compute. Can be either demographic_parity, equal_opportunity or all.

  • threshold (float) – Threshold for transforming probability to binary {0,1} predictions.

  • ignore_index (Optional[int]) – Specifies a target value that is ignored and does not contribute to the metric calculation

  • validate_args (bool) – bool indicating if input arguments and tensors should be validated for correctness. Set to False for faster computations.

Return type:

Dict[str, Tensor]

demographic_parity
torchmetrics.functional.classification.demographic_parity(preds, groups, threshold=0.5, ignore_index=None, validate_args=True)[source]

Demographic parity compares the positivity rates between all groups.

If more than two groups are present, the disparity between the lowest and highest group is reported. The lowest positivity rate is divided by the highest, so a lower value means more discrimination against the group in the numerator. This is reflected in the returned dict, whose key has the form DP_{identifier_low_group}_{identifier_high_group}.

\[\text{DP} = \dfrac{\min_a PR_a}{\max_a PR_a}.\]

where \(PR_a\) represents the positivity rate of group \(a\).

Accepts the following input tensors:

  • preds (int or float tensor): (N, ...). If preds is a floating point tensor with values outside [0,1] range we consider the input to be logits and will auto apply sigmoid per element. Additionally, we convert to int tensor with thresholding using the value in threshold.

  • groups (int tensor): (N, ...). The group identifiers should be 0, 1, ..., (num_groups - 1).

  • target (int tensor): (N, ...).

The additional dimensions are flattened along the batch dimension.

Parameters:
  • preds (Tensor) – Tensor with predictions.

  • groups (Tensor) – Tensor with group identifiers. The group identifiers should be 0, 1, ..., (num_groups - 1).

  • threshold (float) – Threshold for transforming probability to binary {0,1} predictions.

  • ignore_index (Optional[int]) – Specifies a target value that is ignored and does not contribute to the metric calculation

  • validate_args (bool) – bool indicating if input arguments and tensors should be validated for correctness. Set to False for faster computations.

Return type:

Dict[str, Tensor]

Returns:

The metric returns a dict where the key identifies the groups with the lowest and highest positivity rates, in the form DP_{identifier_low_group}_{identifier_high_group}. The value is a tensor with the DP rate.

Example (preds is int tensor):
>>> from torchmetrics.functional.classification import demographic_parity
>>> preds = torch.tensor([0, 1, 0, 1, 0, 1])
>>> groups = torch.tensor([0, 1, 0, 1, 0, 1])
>>> demographic_parity(preds, groups)
{'DP_0_1': tensor(0.)}
Example (preds is float tensor):
>>> from torchmetrics.functional.classification import demographic_parity
>>> preds = torch.tensor([0.11, 0.84, 0.22, 0.73, 0.33, 0.92])
>>> groups = torch.tensor([0, 1, 0, 1, 0, 1])
>>> demographic_parity(preds, groups)
{'DP_0_1': tensor(0.)}
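The ratio defined above can also be reproduced by hand from the per-group positivity rates. An illustrative sketch, assuming two groups (the data below is made up for demonstration):

import torch
from torchmetrics.functional.classification import demographic_parity

preds = torch.tensor([0, 0, 1, 1, 1, 1])
groups = torch.tensor([0, 0, 0, 1, 1, 1])

# Positivity rate = fraction of positive predictions within each group.
pr_group0 = preds[groups == 0].float().mean()
pr_group1 = preds[groups == 1].float().mean()
dp_manual = torch.min(pr_group0, pr_group1) / torch.max(pr_group0, pr_group1)

# dp_manual should match the value stored in the dict returned by the metric.
result = demographic_parity(preds, groups)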
equal_opportunity
torchmetrics.functional.classification.equal_opportunity(preds, target, groups, threshold=0.5, ignore_index=None, validate_args=True)[source]

Equal opportunity compares the true positive rates between all groups.

If more than two groups are present, the disparity between the lowest and highest group is reported. The lowest true positive rate is divided by the highest, so a lower value means more discrimination against the group in the numerator. This is reflected in the returned dict, whose key has the form EO_{identifier_low_group}_{identifier_high_group}.

\[\text{EO} = \dfrac{\min_a TPR_a}{\max_a TPR_a}.\]

where \(TPR_a\) represents the true positive rate of group \(a\).

Accepts the following input tensors:

  • preds (int or float tensor): (N, ...). If preds is a floating point tensor with values outside [0,1] range we consider the input to be logits and will auto apply sigmoid per element. Additionally, we convert to int tensor with thresholding using the value in threshold.

  • target (int tensor): (N, ...).

  • groups (int tensor): (N, ...). The group identifiers should be 0, 1, ..., (num_groups - 1).

The additional dimensions are flattened along the batch dimension.

Parameters:
  • preds (Tensor) – Tensor with predictions.

  • target (Tensor) – Tensor with true labels.

  • groups (Tensor) – Tensor with group identifiers. The group identifiers should be 0, 1, ..., (num_groups - 1).

  • threshold (float) – Threshold for transforming probability to binary {0,1} predictions.

  • ignore_index (Optional[int]) – Specifies a target value that is ignored and does not contribute to the metric calculation

  • validate_args (bool) – bool indicating if input arguments and tensors should be validated for correctness. Set to False for faster computations.

Return type:

Dict[str, Tensor]

Returns:

The metric returns a dict where the key identifies the groups with the lowest and highest true positive rates, in the form EO_{identifier_low_group}_{identifier_high_group}. The value is a tensor with the EO rate.

Example (preds is int tensor):
>>> from torchmetrics.functional.classification import equal_opportunity
>>> target = torch.tensor([0, 1, 0, 1, 0, 1])
>>> preds = torch.tensor([0, 1, 0, 1, 0, 1])
>>> groups = torch.tensor([0, 1, 0, 1, 0, 1])
>>> equal_opportunity(preds, target, groups)
{'EO_0_1': tensor(0.)}
Example (preds is float tensor):
>>> from torchmetrics.functional.classification import equal_opportunity
>>> target = torch.tensor([0, 1, 0, 1, 0, 1])
>>> preds = torch.tensor([0.11, 0.84, 0.22, 0.73, 0.33, 0.92])
>>> groups = torch.tensor([0, 1, 0, 1, 0, 1])
>>> equal_opportunity(preds, target, groups)
{'EO_0_1': tensor(0.)}
binary_groups_stat_rates
torchmetrics.functional.classification.binary_groups_stat_rates(preds, target, groups, num_groups, threshold=0.5, ignore_index=None, validate_args=True)[source]

Compute the true/false positive and true/false negative rates for binary classification by group.

Related to Type I and Type II errors.

Accepts the following input tensors:

  • preds (int or float tensor): (N, ...). If preds is a floating point tensor with values outside [0,1] range we consider the input to be logits and will auto apply sigmoid per element. Additionally, we convert to int tensor with thresholding using the value in threshold.

  • target (int tensor): (N, ...).

  • groups (int tensor): (N, ...). The group identifiers should be 0, 1, ..., (num_groups - 1).

The additional dimensions are flattened along the batch dimension.

Parameters:
  • preds (Tensor) – Tensor with predictions.

  • target (Tensor) – Tensor with true labels.

  • groups (Tensor) – Tensor with group identifiers. The group identifiers should be 0, 1, ..., (num_groups - 1).

  • num_groups (int) – The number of groups.

  • threshold (float) – Threshold for transforming probability to binary {0,1} predictions.

  • ignore_index (Optional[int]) – Specifies a target value that is ignored and does not contribute to the metric calculation

  • validate_args (bool) – bool indicating if input arguments and tensors should be validated for correctness. Set to False for faster computations.

Return type:

Dict[str, Tensor]

Returns:

The metric returns a dict with a group identifier as key and a tensor with the tp, fp, tn and fn rates as value.

Example (preds is int tensor):
>>> from torchmetrics.functional.classification import binary_groups_stat_rates
>>> target = torch.tensor([0, 1, 0, 1, 0, 1])
>>> preds = torch.tensor([0, 1, 0, 1, 0, 1])
>>> groups = torch.tensor([0, 1, 0, 1, 0, 1])
>>> binary_groups_stat_rates(preds, target, groups, 2)
{'group_0': tensor([0., 0., 1., 0.]), 'group_1': tensor([1., 0., 0., 0.])}
Example (preds is float tensor):
>>> from torchmetrics.functional.classification import binary_groups_stat_rates
>>> target = torch.tensor([0, 1, 0, 1, 0, 1])
>>> preds = torch.tensor([0.11, 0.84, 0.22, 0.73, 0.33, 0.92])
>>> groups = torch.tensor([0, 1, 0, 1, 0, 1])
>>> binary_groups_stat_rates(preds, target, groups, 2)
{'group_0': tensor([0., 0., 1., 0.]), 'group_1': tensor([1., 0., 0., 0.])}
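The group-wise rates returned above can be combined into the fairness ratios from the previous sections: a group's positivity rate is tp + fp, and its true positive rate is tp / (tp + fn), using the [tp, fp, tn, fn] ordering of the returned tensors. A hedged sketch for illustration only (the data is made up, and the derivation assumes every group contains at least one positive target):

import torch
from torchmetrics.functional.classification import binary_groups_stat_rates

target = torch.tensor([0, 1, 0, 1, 0, 1])
preds = torch.tensor([0, 1, 1, 1, 0, 1])
groups = torch.tensor([0, 0, 0, 1, 1, 1])

stats = binary_groups_stat_rates(preds, target, groups, 2)

# Each value is a tensor of [tp, fp, tn, fn] rates for that group.
positivity = {g: s[0] + s[1] for g, s in stats.items()}    # demographic parity ingredient
tpr = {g: s[0] / (s[0] + s[3]) for g, s in stats.items()}  # equal opportunity ingredient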

Hamming Distance

Module Interface

class torchmetrics.HammingDistance(**kwargs)[source]

Compute the average Hamming distance (also known as Hamming loss).

\[\text{Hamming distance} = \frac{1}{N \cdot L} \sum_i^N \sum_l^L 1(y_{il} \neq \hat{y}_{il})\]

Where \(y\) is a tensor of target values, \(\hat{y}\) is a tensor of predictions, and \(\bullet_{il}\) refers to the \(l\)-th label of the \(i\)-th sample of that tensor.

This function is a simple wrapper to get the task-specific version of this metric, which is done by setting the task argument to either 'binary', 'multiclass' or 'multilabel'. See the documentation of BinaryHammingDistance, MulticlassHammingDistance and MultilabelHammingDistance for the specific details of how each argument influences the metric and for examples.

Legacy Example:
>>> from torch import tensor
>>> target = tensor([[0, 1], [1, 1]])
>>> preds = tensor([[0, 1], [0, 1]])
>>> hamming_distance = HammingDistance(task="multilabel", num_labels=2)
>>> hamming_distance(preds, target)
tensor(0.2500)
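For tensors that are already binarized, the formula above reduces to the fraction of positions where prediction and target disagree. A minimal sketch for illustration only, reusing the inputs from the legacy example:

from torch import tensor

target = tensor([[0, 1], [1, 1]])
preds = tensor([[0, 1], [0, 1]])

# One of the four entries differs, so the average Hamming distance is 0.25,
# in line with the tensor(0.2500) shown above for this input.
manual = (preds != target).float().mean()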
static __new__(cls, task, threshold=0.5, num_classes=None, num_labels=None, average='micro', multidim_average='global', top_k=1, ignore_index=None, validate_args=True, **kwargs)[source]

Initialize task metric.

Return type:

Metric

BinaryHammingDistance
class torchmetrics.classification.BinaryHammingDistance(threshold=0.5, multidim_average='global', ignore_index=None, validate_args=True, **kwargs)[source]

Compute the average Hamming distance (also known as Hamming loss) for binary tasks.

\[\text{Hamming distance} = \frac{1}{N \cdot L} \sum_i^N \sum_l^L 1(y_{il} \neq \hat{y}_{il})\]

Where \(y\) is a tensor of target values, \(\hat{y}\) is a tensor of predictions, and \(\bullet_{il}\) refers to the \(l\)-th label of the \(i\)-th sample of that tensor.

As input to forward and update the metric accepts the following input:

  • preds (Tensor): An int or float tensor of shape (N, ...). If preds is a floating point tensor with values outside [0,1] range we consider the input to be logits and will auto apply sigmoid per element. Additionally, we convert to int tensor with thresholding using the value in threshold.

  • target (Tensor): An int tensor of shape (N, ...).

As output to forward and compute the metric returns the following output:

  • bhd (Tensor): A tensor whose returned shape depends on the multidim_average argument:

    • If multidim_average is set to global, the metric returns a scalar value.

    • If multidim_average is set to samplewise, the metric returns an (N,) vector consisting of a scalar value per sample.

Parameters:
  • threshold (float) – Threshold for transforming probability to binary {0,1} predictions

  • multidim_average (Literal['global', 'samplewise']) –

    Defines how additional dimensions ... should be handled. Should be one of the following:

    • global: Additional dimensions are flattened along the batch dimension

    • samplewise: Statistic will be calculated independently for each sample on the N axis. The statistics in this case are calculated over the additional dimensions.

  • ignore_index (Optional[int]) – Specifies a target value that is ignored and does not contribute to the metric calculation

  • validate_args (bool) – bool indicating if input arguments and tensors should be validated for correctness. Set to False for faster computations.

Example (preds is int tensor):
>>> from torch import tensor
>>> from torchmetrics.classification import BinaryHammingDistance
>>> target = tensor([0, 1, 0, 1, 0, 1])
>>> preds = tensor([0, 0, 1, 1, 0, 1])
>>> metric = BinaryHammingDistance()
>>> metric(preds, target)
tensor(0.3333)
Example (preds is float tensor):
>>> from torchmetrics.classification import BinaryHammingDistance
>>> target = tensor([0, 1, 0, 1, 0, 1])
>>> preds = tensor([0.11, 0.22, 0.84, 0.73, 0.33, 0.92])
>>> metric = BinaryHammingDistance()
>>> metric(preds, target)
tensor(0.3333)
Example (multidim tensors):
>>> from torchmetrics.classification import BinaryHammingDistance
>>> target = tensor([[[0, 1], [1, 0], [0, 1]], [[1, 1], [0, 0], [1, 0]]])
>>> preds = tensor([[[0.59, 0.91], [0.91, 0.99],  [0.63, 0.04]],
...                 [[0.38, 0.04], [0.86, 0.780], [0.45, 0.37]]])
>>> metric = BinaryHammingDistance(multidim_average='samplewise')
>>> metric(preds, target)
tensor([0.6667, 0.8333])
plot(val=None, ax=None)[source]

Plot a single or multiple values from the metric.

Parameters:
  • val (Union[Tensor, Sequence[Tensor], None]) – Either a single result from calling metric.forward or metric.compute or a list of these results. If no value is provided, will automatically call metric.compute and plot that result.

  • ax (Optional[Axes]) – A matplotlib axis object. If provided, will add the plot to that axis

Return type:

Tuple[Figure, Union[Axes, ndarray]]

Returns:

Figure object and Axes object

Raises:

ModuleNotFoundError – If matplotlib is not installed

>>> # Example plotting a single value
>>> from torch import rand, randint
>>> from torchmetrics.classification import BinaryHammingDistance
>>> metric = BinaryHammingDistance()
>>> metric.update(rand(10), randint(2,(10,)))
>>> fig_, ax_ = metric.plot()
>>> # Example plotting multiple values
>>> from torch import rand, randint
>>> from torchmetrics.classification import BinaryHammingDistance
>>> metric = BinaryHammingDistance()
>>> values = [ ]
>>> for _ in range(10):
...     values.append(metric(rand(10), randint(2,(10,))))
>>> fig_, ax_ = metric.plot(values)
MulticlassHammingDistance
class torchmetrics.classification.MulticlassHammingDistance(num_classes, top_k=1, average='macro', multidim_average='global', ignore_index=None, validate_args=True, **kwargs)[source]

Compute the average Hamming distance (also known as Hamming loss) for multiclass tasks.

\[\text{Hamming distance} = \frac{1}{N \cdot L} \sum_i^N \sum_l^L 1(y_{il} \neq \hat{y}_{il})\]

Where \(y\) is a tensor of target values, \(\hat{y}\) is a tensor of predictions, and \(\bullet_{il}\) refers to the \(l\)-th label of the \(i\)-th sample of that tensor.

As input to forward and update the metric accepts the following input:

  • preds (Tensor): An int tensor of shape (N, ...) or float tensor of shape (N, C, ...). If preds is a floating point tensor we apply torch.argmax along the C dimension to automatically convert probabilities/logits into an int tensor.

  • target (Tensor): An int tensor of shape (N, ...).

As output to forward and compute the metric returns the following output:

  • mchd (Tensor): A tensor whose returned shape depends on the average and multidim_average arguments:

    • If multidim_average is set to global:

      • If average='micro'/'macro'/'weighted', the output will be a scalar tensor

      • If average=None/'none', the shape will be (C,)

    • If multidim_average is set to samplewise:

      • If average='micro'/'macro'/'weighted', the shape will be (N,)

      • If average=None/'none', the shape will be (N, C)

Parameters:
  • num_classes (int) – Integer specifying the number of classes

  • average (Optional[Literal['micro', 'macro', 'weighted', 'none']]) –

    Defines the reduction that is applied over labels. Should be one of the following:

    • micro: Sum statistics over all labels

    • macro: Calculate statistics for each label and average them

    • weighted: calculates statistics for each label and computes weighted average using their support

    • "none" or None: calculates statistic for each label and applies no reduction

  • top_k (int) – Number of highest probability or logit score predictions considered to find the correct label. Only works when preds contain probabilities/logits.

  • multidim_average (Literal['global', 'samplewise']) –

    Defines how additional dimensions ... should be handled. Should be one of the following:

    • global: Additional dimensions are flattened along the batch dimension

    • samplewise: Statistic will be calculated independently for each sample on the N axis. The statistics in this case are calculated over the additional dimensions.

  • ignore_index (Optional[int]) – Specifies a target value that is ignored and does not contribute to the metric calculation

  • validate_args (bool) – bool indicating if input arguments and tensors should be validated for correctness. Set to False for faster computations.

Example (preds is int tensor):
>>> from torch import tensor
>>> from torchmetrics.classification import MulticlassHammingDistance
>>> target = tensor([2, 1, 0, 0])
>>> preds = tensor([2, 1, 0, 1])
>>> metric = MulticlassHammingDistance(num_classes=3)
>>> metric(preds, target)
tensor(0.1667)
>>> mchd = MulticlassHammingDistance(num_classes=3, average=None)
>>> mchd(preds, target)
tensor([0.5000, 0.0000, 0.0000])
Example (preds is float tensor):
>>> from torchmetrics.classification import MulticlassHammingDistance
>>> target = tensor([2, 1, 0, 0])
>>> preds = tensor([[0.16, 0.26, 0.58],
...                 [0.22, 0.61, 0.17],
...                 [0.71, 0.09, 0.20],
...                 [0.05, 0.82, 0.13]])
>>> metric = MulticlassHammingDistance(num_classes=3)
>>> metric(preds, target)
tensor(0.1667)
>>> mchd = MulticlassHammingDistance(num_classes=3, average=None)
>>> mchd(preds, target)
tensor([0.5000, 0.0000, 0.0000])
Example (multidim tensors):
>>> from torchmetrics.classification import MulticlassHammingDistance
>>> target = tensor([[[0, 1], [2, 1], [0, 2]], [[1, 1], [2, 0], [1, 2]]])
>>> preds = tensor([[[0, 2], [2, 0], [0, 1]], [[2, 2], [2, 1], [1, 0]]])
>>> metric = MulticlassHammingDistance(num_classes=3, multidim_average='samplewise')
>>> metric(preds, target)
tensor([0.5000, 0.7222])
>>> mchd = MulticlassHammingDistance(num_classes=3, multidim_average='samplewise', average=None)
>>> mchd(preds, target)
tensor([[0.0000, 1.0000, 0.5000],
        [1.0000, 0.6667, 0.5000]])
plot(val=None, ax=None)[source]

Plot a single or multiple values from the metric.

Parameters:
  • val (Union[Tensor, Sequence[Tensor], None]) – Either a single result from calling metric.forward or metric.compute or a list of these results. If no value is provided, will automatically call metric.compute and plot that result.

  • ax (Optional[Axes]) – A matplotlib axis object. If provided, will add the plot to that axis

Return type:

Tuple[Figure, Union[Axes, ndarray]]

Returns:

Figure object and Axes object

Raises:

ModuleNotFoundError – If matplotlib is not installed

>>> # Example plotting a single value per class
>>> from torch import randint
>>> from torchmetrics.classification import MulticlassHammingDistance
>>> metric = MulticlassHammingDistance(num_classes=3, average=None)
>>> metric.update(randint(3, (20,)), randint(3, (20,)))
>>> fig_, ax_ = metric.plot()
>>> # Example plotting a multiple values per class
>>> from torch import randint
>>> from torchmetrics.classification import MulticlassHammingDistance
>>> metric = MulticlassHammingDistance(num_classes=3, average=None)
>>> values = []
>>> for _ in range(20):
...     values.append(metric(randint(3, (20,)), randint(3, (20,))))
>>> fig_, ax_ = metric.plot(values)
MultilabelHammingDistance
class torchmetrics.classification.MultilabelHammingDistance(num_labels, threshold=0.5, average='macro', multidim_average='global', ignore_index=None, validate_args=True, **kwargs)[source]

Compute the average Hamming distance (also known as Hamming loss) for multilabel tasks.

\[\text{Hamming distance} = \frac{1}{N \cdot L} \sum_i^N \sum_l^L 1(y_{il} \neq \hat{y}_{il})\]

Where \(y\) is a tensor of target values, \(\hat{y}\) is a tensor of predictions, and \(\bullet_{il}\) refers to the \(l\)-th label of the \(i\)-th sample of that tensor.

As input to forward and update the metric accepts the following input:

  • preds (Tensor): An int tensor or float tensor of shape (N, C, ...). If preds is a floating point tensor with values outside [0,1] range we consider the input to be logits and will auto apply sigmoid per element. Additionally, we convert to int tensor with thresholding using the value in threshold.

  • target (Tensor): An int tensor of shape (N, C, ...).

As output to forward and compute the metric returns the following output:

  • mlhd (Tensor): A tensor whose returned shape depends on the average and multidim_average arguments:

    • If multidim_average is set to global:

      • If average='micro'/'macro'/'weighted', the output will be a scalar tensor

      • If average=None/'none', the shape will be (C,)

    • If multidim_average is set to samplewise:

      • If average='micro'/'macro'/'weighted', the shape will be (N,)

      • If average=None/'none', the shape will be (N, C)

Parameters:
  • num_labels (int) – Integer specifying the number of labels

  • threshold (float) – Threshold for transforming probability to binary (0,1) predictions

  • average (Optional[Literal['micro', 'macro', 'weighted', 'none']]) –

    Defines the reduction that is applied over labels. Should be one of the following:

    • micro: Sum statistics over all labels

    • macro: Calculate statistics for each label and average them

    • weighted: calculates statistics for each label and computes weighted average using their support

    • "none" or None: calculates statistic for each label and applies no reduction

  • multidim_average (Literal['global', 'samplewise']) –

    Defines how additional dimensions ... should be handled. Should be one of the following:

    • global: Additional dimensions are flattened along the batch dimension

    • samplewise: Statistic will be calculated independently for each sample on the N axis. The statistics in this case are calculated over the additional dimensions.

  • ignore_index (Optional[int]) – Specifies a target value that is ignored and does not contribute to the metric calculation

  • validate_args (bool) – bool indicating if input arguments and tensors should be validated for correctness. Set to False for faster computations.

Example (preds is int tensor):
>>> from torch import tensor
>>> from torchmetrics.classification import MultilabelHammingDistance
>>> target = tensor([[0, 1, 0], [1, 0, 1]])
>>> preds = tensor([[0, 0, 1], [1, 0, 1]])
>>> metric = MultilabelHammingDistance(num_labels=3)
>>> metric(preds, target)
tensor(0.3333)
>>> mlhd = MultilabelHammingDistance(num_labels=3, average=None)
>>> mlhd(preds, target)
tensor([0.0000, 0.5000, 0.5000])
Example (preds is float tensor):
>>> from torchmetrics.classification import MultilabelHammingDistance
>>> target = tensor([[0, 1, 0], [1, 0, 1]])
>>> preds = tensor([[0.11, 0.22, 0.84], [0.73, 0.33, 0.92]])
>>> metric = MultilabelHammingDistance(num_labels=3)
>>> metric(preds, target)
tensor(0.3333)
>>> mlhd = MultilabelHammingDistance(num_labels=3, average=None)
>>> mlhd(preds, target)
tensor([0.0000, 0.5000, 0.5000])
Example (multidim tensors):
>>> from torchmetrics.classification import MultilabelHammingDistance
>>> target = tensor([[[0, 1], [1, 0], [0, 1]], [[1, 1], [0, 0], [1, 0]]])
>>> preds = tensor([[[0.59, 0.91], [0.91, 0.99], [0.63, 0.04]],
...                 [[0.38, 0.04], [0.86, 0.780], [0.45, 0.37]]])
>>> metric = MultilabelHammingDistance(num_labels=3, multidim_average='samplewise')
>>> metric(preds, target)
tensor([0.6667, 0.8333])
>>> mlhd = MultilabelHammingDistance(num_labels=3, multidim_average='samplewise', average=None)
>>> mlhd(preds, target)
tensor([[0.5000, 0.5000, 1.0000],
        [1.0000, 1.0000, 0.5000]])
plot(val=None, ax=None)[source]

Plot a single or multiple values from the metric.

Parameters:
  • val (Union[Tensor, Sequence[Tensor], None]) – Either a single result from calling metric.forward or metric.compute or a list of these results. If no value is provided, will automatically call metric.compute and plot that result.

  • ax (Optional[Axes]) – A matplotlib axis object. If provided, will add the plot to that axis

Return type:

Tuple[Figure, Union[Axes, ndarray]]

Returns:

Figure and Axes object

Raises:

ModuleNotFoundError – If matplotlib is not installed

>>> # Example plotting a single value
>>> from torch import rand, randint
>>> from torchmetrics.classification import MultilabelHammingDistance
>>> metric = MultilabelHammingDistance(num_labels=3)
>>> metric.update(randint(2, (20, 3)), randint(2, (20, 3)))
>>> fig_, ax_ = metric.plot()
>>> # Example plotting multiple values
>>> from torch import rand, randint
>>> from torchmetrics.classification import MultilabelHammingDistance
>>> metric = MultilabelHammingDistance(num_labels=3)
>>> values = [ ]
>>> for _ in range(10):
...     values.append(metric(randint(2, (20, 3)), randint(2, (20, 3))))
>>> fig_, ax_ = metric.plot(values)

Functional Interface

hamming_distance
torchmetrics.functional.hamming_distance(preds, target, task, threshold=0.5, num_classes=None, num_labels=None, average='micro', multidim_average='global', top_k=1, ignore_index=None, validate_args=True)[source]

Compute the average Hamming distance (also known as Hamming loss). Return type: Tensor

\[\text{Hamming distance} = \frac{1}{N \cdot L} \sum_i^N \sum_l^L 1(y_{il} \neq \hat{y}_{il})\]

Where \(y\) is a tensor of target values, \(\hat{y}\) is a tensor of predictions, and \(\bullet_{il}\) refers to the \(l\)-th label of the \(i\)-th sample of that tensor.

This function is a simple wrapper to get the task-specific version of this metric, which is done by setting the task argument to either 'binary', 'multiclass' or 'multilabel'. See the documentation of binary_hamming_distance(), multiclass_hamming_distance() and multilabel_hamming_distance() for the specific details of how each argument influences the metric and for examples.

Legacy Example:
>>> from torch import tensor
>>> target = tensor([[0, 1], [1, 1]])
>>> preds = tensor([[0, 1], [0, 1]])
>>> hamming_distance(preds, target, task="binary")
tensor(0.2500)
binary_hamming_distance
torchmetrics.functional.classification.binary_hamming_distance(preds, target, threshold=0.5, multidim_average='global', ignore_index=None, validate_args=True)[source]

Compute the average Hamming distance (also known as Hamming loss) for binary tasks.

\[\text{Hamming distance} = \frac{1}{N \cdot L} \sum_i^N \sum_l^L 1(y_{il} \neq \hat{y}_{il})\]

Where \(y\) is a tensor of target values, \(\hat{y}\) is a tensor of predictions, and \(\bullet_{il}\) refers to the \(l\)-th label of the \(i\)-th sample of that tensor.

Accepts the following input tensors:

  • preds (int or float tensor): (N, ...). If preds is a floating point tensor with values outside [0,1] range we consider the input to be logits and will auto apply sigmoid per element. Additionally, we convert to int tensor with thresholding using the value in threshold.

  • target (int tensor): (N, ...)

Parameters:
  • preds (Tensor) – Tensor with predictions

  • target (Tensor) – Tensor with true labels

  • threshold (float) – Threshold for transforming probability to binary {0,1} predictions

  • multidim_average (Literal['global', 'samplewise']) –

    Defines how additional dimensions ... should be handled. Should be one of the following:

    • global: Additional dimensions are flattened along the batch dimension

    • samplewise: Statistic will be calculated independently for each sample on the N axis. The statistics in this case are calculated over the additional dimensions.

  • ignore_index (Optional[int]) – Specifies a target value that is ignored and does not contribute to the metric calculation

  • validate_args (bool) – bool indicating if input arguments and tensors should be validated for correctness. Set to False for faster computations.

Return type:

Tensor

Returns:

If multidim_average is set to global, the metric returns a scalar value. If multidim_average is set to samplewise, the metric returns an (N,) vector consisting of a scalar value per sample.

Example (preds is int tensor):
>>> from torch import tensor
>>> from torchmetrics.functional.classification import binary_hamming_distance
>>> target = tensor([0, 1, 0, 1, 0, 1])
>>> preds = tensor([0, 0, 1, 1, 0, 1])
>>> binary_hamming_distance(preds, target)
tensor(0.3333)
Example (preds is float tensor):
>>> from torchmetrics.functional.classification import binary_hamming_distance
>>> target = tensor([0, 1, 0, 1, 0, 1])
>>> preds = tensor([0.11, 0.22, 0.84, 0.73, 0.33, 0.92])
>>> binary_hamming_distance(preds, target)
tensor(0.3333)
Example (multidim tensors):
>>> from torchmetrics.functional.classification import binary_hamming_distance
>>> target = tensor([[[0, 1], [1, 0], [0, 1]], [[1, 1], [0, 0], [1, 0]]])
>>> preds = tensor([[[0.59, 0.91], [0.91, 0.99], [0.63, 0.04]],
...                 [[0.38, 0.04], [0.86, 0.780], [0.45, 0.37]]])
>>> binary_hamming_distance(preds, target, multidim_average='samplewise')
tensor([0.6667, 0.8333])
multiclass_hamming_distance
torchmetrics.functional.classification.multiclass_hamming_distance(preds, target, num_classes, average='macro', top_k=1, multidim_average='global', ignore_index=None, validate_args=True)[source]

Compute the average Hamming distance (also known as Hamming loss) for multiclass tasks.

\[\text{Hamming distance} = \frac{1}{N \cdot L} \sum_i^N \sum_l^L 1(y_{il} \neq \hat{y}_{il})\]

Where \(y\) is a tensor of target values, \(\hat{y}\) is a tensor of predictions, and \(\bullet_{il}\) refers to the \(l\)-th label of the \(i\)-th sample of that tensor.

Accepts the following input tensors:

  • preds: (N, ...) (int tensor) or (N, C, ...) (float tensor). If preds is a floating point tensor we apply torch.argmax along the C dimension to automatically convert probabilities/logits into an int tensor.

  • target (int tensor): (N, ...)

Parameters:
  • preds (Tensor) – Tensor with predictions

  • target (Tensor) – Tensor with true labels

  • num_classes (int) – Integer specifying the number of classes

  • average (Optional[Literal['micro', 'macro', 'weighted', 'none']]) –

    Defines the reduction that is applied over labels. Should be one of the following:

    • micro: Sum statistics over all labels

    • macro: Calculate statistics for each label and average them

    • weighted: calculates statistics for each label and computes weighted average using their support

    • "none" or None: calculates statistic for each label and applies no reduction

  • top_k (int) – Number of highest probability or logit score predictions considered to find the correct label. Only works when preds contain probabilities/logits.

  • multidim_average (Literal['global', 'samplewise']) –

    Defines how additional dimensions ... should be handled. Should be one of the following:

    • global: Additional dimensions are flattened along the batch dimension

    • samplewise: Statistic will be calculated independently for each sample on the N axis. The statistics in this case are calculated over the additional dimensions.

  • ignore_index (Optional[int]) – Specifies a target value that is ignored and does not contribute to the metric calculation

  • validate_args (bool) – bool indicating if input arguments and tensors should be validated for correctness. Set to False for faster computations.

Returns:

  • If multidim_average is set to global:

    • If average='micro'/'macro'/'weighted', the output will be a scalar tensor

    • If average=None/'none', the shape will be (C,)

  • If multidim_average is set to samplewise:

    • If average='micro'/'macro'/'weighted', the shape will be (N,)

    • If average=None/'none', the shape will be (N, C)

Return type:

Tensor. The exact shape depends on the average and multidim_average arguments, as described under Returns above.

Example (preds is int tensor):
>>> from torch import tensor
>>> from torchmetrics.functional.classification import multiclass_hamming_distance
>>> target = tensor([2, 1, 0, 0])
>>> preds = tensor([2, 1, 0, 1])
>>> multiclass_hamming_distance(preds, target, num_classes=3)
tensor(0.1667)
>>> multiclass_hamming_distance(preds, target, num_classes=3, average=None)
tensor([0.5000, 0.0000, 0.0000])
Example (preds is float tensor):
>>> from torchmetrics.functional.classification import multiclass_hamming_distance
>>> target = tensor([2, 1, 0, 0])
>>> preds = tensor([[0.16, 0.26, 0.58],
...                 [0.22, 0.61, 0.17],
...                 [0.71, 0.09, 0.20],
...                 [0.05, 0.82, 0.13]])
>>> multiclass_hamming_distance(preds, target, num_classes=3)
tensor(0.1667)
>>> multiclass_hamming_distance(preds, target, num_classes=3, average=None)
tensor([0.5000, 0.0000, 0.0000])
Example (multidim tensors):
>>> from torchmetrics.functional.classification import multiclass_hamming_distance
>>> target = tensor([[[0, 1], [2, 1], [0, 2]], [[1, 1], [2, 0], [1, 2]]])
>>> preds = tensor([[[0, 2], [2, 0], [0, 1]], [[2, 2], [2, 1], [1, 0]]])
>>> multiclass_hamming_distance(preds, target, num_classes=3, multidim_average='samplewise')
tensor([0.5000, 0.7222])
>>> multiclass_hamming_distance(preds, target, num_classes=3, multidim_average='samplewise', average=None)
tensor([[0.0000, 1.0000, 0.5000],
        [1.0000, 0.6667, 0.5000]])
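For the multiclass case, the numbers in the first example above can be recovered with plain tensor operations: the per-class value reported with average=None coincides with one minus the per-class recall, and average='macro' is simply their mean. The following is a hedged sanity check, not the library's implementation:

import torch

# hedged recomputation of the first multiclass example (illustration only)
target = torch.tensor([2, 1, 0, 0])
preds = torch.tensor([2, 1, 0, 1])

per_class = []
for c in range(3):
    tp = ((preds == c) & (target == c)).sum()   # correctly predicted samples of class c
    fn = ((preds != c) & (target == c)).sum()   # missed samples of class c
    per_class.append(fn / (tp + fn))            # 1 - recall for class c
per_class = torch.stack(per_class)
print(per_class)         # tensor([0.5000, 0.0000, 0.0000])
print(per_class.mean())  # tensor(0.1667), i.e. average='macro'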
multilabel_hamming_distance
torchmetrics.functional.classification.multilabel_hamming_distance(preds, target, num_labels, threshold=0.5, average='macro', multidim_average='global', ignore_index=None, validate_args=True)[source]

Compute the average Hamming distance (also known as Hamming loss) for multilabel tasks.

\[\text{Hamming distance} = \frac{1}{N \cdot L} \sum_i^N \sum_l^L 1(y_{il} \neq \hat{y}_{il})\]

Where \(y\) is a tensor of target values, \(\hat{y}\) is a tensor of predictions, and \(\bullet_{il}\) refers to the \(l\)-th label of the \(i\)-th sample of that tensor.

Accepts the following input tensors:

  • preds (int or float tensor): (N, C, ...). If preds is a floating point tensor with values outside [0,1] range we consider the input to be logits and will auto apply sigmoid per element. Additionally, we convert to an int tensor with thresholding using the value in threshold.

  • target (int tensor): (N, C, ...)

Parameters:
  • preds (Tensor) – Tensor with predictions

  • target (Tensor) – Tensor with true labels

  • num_labels (int) – Integer specifying the number of labels

  • threshold (float) – Threshold for transforming probability to binary (0,1) predictions

  • average (Optional[Literal['micro', 'macro', 'weighted', 'none']]) –

    Defines the reduction that is applied over labels. Should be one of the following:

    • micro: Sum statistics over all labels

    • macro: Calculate statistics for each label and average them

    • weighted: calculates statistics for each label and computes weighted average using their support

    • "none" or None: calculates statistic for each label and applies no reduction

  • multidim_average (Literal['global', 'samplewise']) –

    Defines how additional dimensions ... should be handled. Should be one of the following:

    • global: Additional dimensions are flattened along the batch dimension

    • samplewise: Statistic will be calculated independently for each sample on the N axis. The statistics in this case are calculated over the additional dimensions.

  • ignore_index (Optional[int]) – Specifies a target value that is ignored and does not contribute to the metric calculation

  • validate_args (bool) – bool indicating if input arguments and tensors should be validated for correctness. Set to False for faster computations.

Returns:

  • If multidim_average is set to global:

    • If average='micro'/'macro'/'weighted', the output will be a scalar tensor

    • If average=None/'none', the shape will be (C,)

  • If multidim_average is set to samplewise:

    • If average='micro'/'macro'/'weighted', the shape will be (N,)

    • If average=None/'none', the shape will be (N, C)

Return type:

Tensor. The exact shape depends on the average and multidim_average arguments, as described under Returns above.

Example (preds is int tensor):
>>> from torch import tensor
>>> from torchmetrics.functional.classification import multilabel_hamming_distance
>>> target = tensor([[0, 1, 0], [1, 0, 1]])
>>> preds = tensor([[0, 0, 1], [1, 0, 1]])
>>> multilabel_hamming_distance(preds, target, num_labels=3)
tensor(0.3333)
>>> multilabel_hamming_distance(preds, target, num_labels=3, average=None)
tensor([0.0000, 0.5000, 0.5000])
Example (preds is float tensor):
>>> from torchmetrics.functional.classification import multilabel_hamming_distance
>>> target = tensor([[0, 1, 0], [1, 0, 1]])
>>> preds = tensor([[0.11, 0.22, 0.84], [0.73, 0.33, 0.92]])
>>> multilabel_hamming_distance(preds, target, num_labels=3)
tensor(0.3333)
>>> multilabel_hamming_distance(preds, target, num_labels=3, average=None)
tensor([0.0000, 0.5000, 0.5000])
Example (multidim tensors):
>>> from torchmetrics.functional.classification import multilabel_hamming_distance
>>> target = tensor([[[0, 1], [1, 0], [0, 1]], [[1, 1], [0, 0], [1, 0]]])
>>> preds = tensor([[[0.59, 0.91], [0.91, 0.99], [0.63, 0.04]],
...                 [[0.38, 0.04], [0.86, 0.780], [0.45, 0.37]]])
>>> multilabel_hamming_distance(preds, target, num_labels=3, multidim_average='samplewise')
tensor([0.6667, 0.8333])
>>> multilabel_hamming_distance(preds, target, num_labels=3, multidim_average='samplewise', average=None)
tensor([[0.5000, 0.5000, 1.0000],
        [1.0000, 1.0000, 0.5000]])
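The multilabel values above have a simple reading: per label, the Hamming distance is the fraction of mismatched entries in that label's column, and average='macro' takes their mean. A minimal check of the first example, for illustration only:

import torch

# hedged check of the first multilabel example (not the library code)
target = torch.tensor([[0, 1, 0], [1, 0, 1]])
preds = torch.tensor([[0, 0, 1], [1, 0, 1]])

per_label = (preds != target).float().mean(dim=0)  # mismatch rate per label column
print(per_label)         # tensor([0.0000, 0.5000, 0.5000])
print(per_label.mean())  # tensor(0.3333), i.e. average='macro'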

Hinge Loss

Module Interface

class torchmetrics.HingeLoss(**kwargs)[source]

Compute the mean Hinge loss typically used for Support Vector Machines (SVMs).

This function is a simple wrapper to get the task specific versions of this metric, which is done by setting the task argument to either 'binary' or 'multiclass'. See the documentation of BinaryHingeLoss and MulticlassHingeLoss for the specific details of each argument's influence and examples.

Legacy Example:
>>> from torch import tensor
>>> target = tensor([0, 1, 1])
>>> preds = tensor([0.5, 0.7, 0.1])
>>> hinge = HingeLoss(task="binary")
>>> hinge(preds, target)
tensor(0.9000)
>>> target = tensor([0, 1, 2])
>>> preds = tensor([[-1.0, 0.9, 0.2], [0.5, -1.1, 0.8], [2.2, -0.5, 0.3]])
>>> hinge = HingeLoss(task="multiclass", num_classes=3)
>>> hinge(preds, target)
tensor(1.5551)
>>> target = tensor([0, 1, 2])
>>> preds = tensor([[-1.0, 0.9, 0.2], [0.5, -1.1, 0.8], [2.2, -0.5, 0.3]])
>>> hinge = HingeLoss(task="multiclass", num_classes=3, multiclass_mode="one-vs-all")
>>> hinge(preds, target)
tensor([1.3743, 1.1945, 1.2359])
static __new__(cls, task, num_classes=None, squared=False, multiclass_mode='crammer-singer', ignore_index=None, validate_args=True, **kwargs)[source]

Initialize task metric.

Return type:

Metric

BinaryHingeLoss
class torchmetrics.classification.BinaryHingeLoss(squared=False, ignore_index=None, validate_args=True, **kwargs)[source]

Compute the mean Hinge loss typically used for Support Vector Machines (SVMs) for binary tasks.

\[\text{Hinge loss} = \max(0, 1 - y \times \hat{y})\]

Where \(y \in \{-1, 1\}\) is the target, and \(\hat{y} \in \mathbb{R}\) is the prediction.

As input to forward and update the metric accepts the following input:

  • preds (Tensor): A float tensor of shape (N, ...). Preds should be a tensor containing probabilities or logits for each observation. If preds has values outside [0,1] range we consider the input to be logits and will auto apply sigmoid per element.

  • target (Tensor): An int tensor of shape (N, ...). Target should be a tensor containing ground truth labels, and therefore only contain {0,1} values (except if ignore_index is specified). The value 1 always encodes the positive class.

Note

Additional dimension ... will be flattened into the batch dimension.

As output to forward and compute the metric returns the following output:

  • bhl (Tensor): A tensor containing the hinge loss.

Parameters:
  • squared (bool) – If True, this will compute the squared hinge loss. Otherwise, computes the regular hinge loss.

  • ignore_index (Optional[int]) – Specifies a target value that is ignored and does not contribute to the metric calculation

  • validate_args (bool) – bool indicating if input arguments and tensors should be validated for correctness. Set to False for faster computations.

  • kwargs (Any) – Additional keyword arguments, see Advanced metric settings for more info.

Example

>>> from torchmetrics.classification import BinaryHingeLoss
>>> preds = torch.tensor([0.25, 0.25, 0.55, 0.75, 0.75])
>>> target = torch.tensor([0, 0, 1, 1, 1])
>>> bhl = BinaryHingeLoss()
>>> bhl(preds, target)
tensor(0.6900)
>>> bhl = BinaryHingeLoss(squared=True)
>>> bhl(preds, target)
tensor(0.6905)
plot(val=None, ax=None)[source]

Plot a single or multiple values from the metric.

Parameters:
  • val (Union[Tensor, Sequence[Tensor], None]) – Either a single result from calling metric.forward or metric.compute or a list of these results. If no value is provided, will automatically call metric.compute and plot that result.

  • ax (Optional[Axes]) – A matplotlib axis object. If provided, will add the plot to that axis

Return type:

Tuple[Figure, Union[Axes, ndarray]]

Returns:

Figure object and Axes object

Raises:

ModuleNotFoundError – If matplotlib is not installed

>>> # Example plotting a single value
>>> from torch import rand, randint
>>> from torchmetrics.classification import BinaryHingeLoss
>>> metric = BinaryHingeLoss()
>>> metric.update(rand(10), randint(2,(10,)))
>>> fig_, ax_ = metric.plot()
_images/hinge_loss-1.png
>>> # Example plotting multiple values
>>> from torch import rand, randint
>>> from torchmetrics.classification import BinaryHingeLoss
>>> metric = BinaryHingeLoss()
>>> values = [ ]
>>> for _ in range(10):
...     values.append(metric(rand(10), randint(2,(10,))))
>>> fig_, ax_ = metric.plot(values)
_images/hinge_loss-2.png
MulticlassHingeLoss
class torchmetrics.classification.MulticlassHingeLoss(num_classes, squared=False, multiclass_mode='crammer-singer', ignore_index=None, validate_args=True, **kwargs)[source]

Compute the mean Hinge loss typically used for Support Vector Machines (SVMs) for multiclass tasks.

The metric can be computed in two ways. Either, the definition by Crammer and Singer is used:

\[\text{Hinge loss} = \max\left(0, 1 - \hat{y}_y + \max_{i \ne y} (\hat{y}_i)\right)\]

Where \(y \in \{0, ..., \mathrm{C} - 1\}\) is the target class (where \(\mathrm{C}\) is the number of classes), and \(\hat{y} \in \mathbb{R}^\mathrm{C}\) is the predicted output per class. Alternatively, the metric can also be computed in a one-vs-all approach, where each class is valued against all other classes in a binary fashion.

As input to forward and update the metric accepts the following input:

  • preds (Tensor): A float tensor of shape (N, C, ...). Preds should be a tensor containing probabilities or logits for each observation. If preds has values outside [0,1] range we consider the input to be logits and will auto apply softmax per sample.

  • target (Tensor): An int tensor of shape (N, ...). Target should be a tensor containing ground truth labels, and therefore only contain values in the [0, n_classes-1] range (except if ignore_index is specified).

Note

Additional dimension ... will be flattened into the batch dimension.

As output to forward and compute the metric returns the following output:

  • mchl (Tensor): A tensor containing the multi-class hinge loss.

Parameters:
  • num_classes (int) – Integer specifying the number of classes

  • squared (bool) – If True, this will compute the squared hinge loss. Otherwise, computes the regular hinge loss.

  • multiclass_mode (Literal['crammer-singer', 'one-vs-all']) – Determines how to compute the metric

  • ignore_index (Optional[int]) – Specifies a target value that is ignored and does not contribute to the metric calculation

  • validate_args (bool) – bool indicating if input arguments and tensors should be validated for correctness. Set to False for faster computations.

  • kwargs (Any) – Additional keyword arguments, see Advanced metric settings for more info.

Example

>>> from torchmetrics.classification import MulticlassHingeLoss
>>> preds = torch.tensor([[0.25, 0.20, 0.55],
...                       [0.55, 0.05, 0.40],
...                       [0.10, 0.30, 0.60],
...                       [0.90, 0.05, 0.05]])
>>> target = torch.tensor([0, 1, 2, 0])
>>> mchl = MulticlassHingeLoss(num_classes=3)
>>> mchl(preds, target)
tensor(0.9125)
>>> mchl = MulticlassHingeLoss(num_classes=3, squared=True)
>>> mchl(preds, target)
tensor(1.1131)
>>> mchl = MulticlassHingeLoss(num_classes=3, multiclass_mode='one-vs-all')
>>> mchl(preds, target)
tensor([0.8750, 1.1250, 1.1000])
plot(val=None, ax=None)[source]

Plot a single or multiple values from the metric.

Parameters:
  • val (Union[Tensor, Sequence[Tensor], None]) – Either a single result from calling metric.forward or metric.compute or a list of these results. If no value is provided, will automatically call metric.compute and plot that result.

  • ax (Optional[Axes]) – A matplotlib axis object. If provided, will add the plot to that axis

Return type:

Tuple[Figure, Union[Axes, ndarray]]

Returns:

Figure object and Axes object

Raises:

ModuleNotFoundError – If matplotlib is not installed

>>> # Example plotting a single value per class
>>> from torch import randint, randn
>>> from torchmetrics.classification import MulticlassHingeLoss
>>> metric = MulticlassHingeLoss(num_classes=3)
>>> metric.update(randn(20, 3), randint(3, (20,)))
>>> fig_, ax_ = metric.plot()
_images/hinge_loss-3.png
>>> # Example plotting multiple values per class
>>> from torch import randint, randn
>>> from torchmetrics.classification import MulticlassHingeLoss
>>> metric = MulticlassHingeLoss(num_classes=3)
>>> values = []
>>> for _ in range(20):
...     values.append(metric(randn(20, 3), randint(3, (20,))))
>>> fig_, ax_ = metric.plot(values)
_images/hinge_loss-4.png

Functional Interface

torchmetrics.functional.hinge_loss(preds, target, task, num_classes=None, squared=False, multiclass_mode='crammer-singer', ignore_index=None, validate_args=True)[source]

Compute the mean Hinge loss typically used for Support Vector Machines (SVMs).

This function is a simple wrapper to get the task specific versions of this metric, which is done by setting the task argument to either 'binary' or 'multiclass'. See the documentation of binary_hinge_loss() and multiclass_hinge_loss() for the specific details of each argument's influence and examples.

Return type:

Tensor

Legacy Example:
>>> from torch import tensor
>>> target = tensor([0, 1, 1])
>>> preds = tensor([0.5, 0.7, 0.1])
>>> hinge_loss(preds, target, task="binary")
tensor(0.9000)
>>> target = tensor([0, 1, 2])
>>> preds = tensor([[-1.0, 0.9, 0.2], [0.5, -1.1, 0.8], [2.2, -0.5, 0.3]])
>>> hinge_loss(preds, target, task="multiclass", num_classes=3)
tensor(1.5551)
>>> target = tensor([0, 1, 2])
>>> preds = tensor([[-1.0, 0.9, 0.2], [0.5, -1.1, 0.8], [2.2, -0.5, 0.3]])
>>> hinge_loss(preds, target, task="multiclass", num_classes=3, multiclass_mode="one-vs-all")
tensor([1.3743, 1.1945, 1.2359])
binary_hinge_loss
torchmetrics.functional.classification.binary_hinge_loss(preds, target, squared=False, ignore_index=None, validate_args=False)[source]

Compute the mean Hinge loss typically used for Support Vector Machines (SVMs) for binary tasks.

\[\text{Hinge loss} = \max(0, 1 - y \times \hat{y})\]

Where \(y \in \{-1, 1\}\) is the target, and \(\hat{y} \in \mathbb{R}\) is the prediction.

Accepts the following input tensors:

  • preds (float tensor): (N, ...). Preds should be a tensor containing probabilities or logits for each observation. If preds has values outside [0,1] range we consider the input to be logits and will auto apply sigmoid per element.

  • target (int tensor): (N, ...). Target should be a tensor containing ground truth labels, and therefore only contain {0,1} values (except if ignore_index is specified). The value 1 always encodes the positive class.

Additional dimension ... will be flattened into the batch dimension.

Parameters:
  • preds (Tensor) – Tensor with predictions

  • target (Tensor) – Tensor with true labels

  • squared (bool) – If True, this will compute the squared hinge loss. Otherwise, computes the regular hinge loss.

  • ignore_index (Optional[int]) – Specifies a target value that is ignored and does not contribute to the metric calculation

  • validate_args (bool) – bool indicating if input arguments and tensors should be validated for correctness. Set to False for faster computations.

Return type:

Tensor

Example

>>> from torch import tensor
>>> from torchmetrics.functional.classification import binary_hinge_loss
>>> preds = tensor([0.25, 0.25, 0.55, 0.75, 0.75])
>>> target = tensor([0, 0, 1, 1, 1])
>>> binary_hinge_loss(preds, target)
tensor(0.6900)
>>> binary_hinge_loss(preds, target, squared=True)
tensor(0.6905)
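The example output can be verified directly from the formula above: targets in {0, 1} are mapped to {-1, +1}, the in-range predictions are used as-is as the real-valued prediction, and the (optionally squared) hinge terms are averaged. A hedged recomputation, not the library's internal code:

import torch

# hedged check of the binary hinge loss example above (illustration only)
preds = torch.tensor([0.25, 0.25, 0.55, 0.75, 0.75])
target = torch.tensor([0, 0, 1, 1, 1])

y = 2 * target.float() - 1                   # map {0, 1} -> {-1, +1}
losses = torch.clamp(1 - y * preds, min=0)   # max(0, 1 - y * y_hat) per sample
print(losses.mean())                         # tensor(0.6900)
print(losses.pow(2).mean())                  # tensor(0.6905), i.e. squared=True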
multiclass_hinge_loss
torchmetrics.functional.classification.multiclass_hinge_loss(preds, target, num_classes, squared=False, multiclass_mode='crammer-singer', ignore_index=None, validate_args=False)[source]

Compute the mean Hinge loss typically used for Support Vector Machines (SVMs) for multiclass tasks.

The metric can be computed in two ways. Either, the definition by Crammer and Singer is used:

\[\text{Hinge loss} = \max\left(0, 1 - \hat{y}_y + \max_{i \ne y} (\hat{y}_i)\right)\]

Where \(y \in \{0, ..., \mathrm{C} - 1\}\) is the target class (where \(\mathrm{C}\) is the number of classes), and \(\hat{y} \in \mathbb{R}^\mathrm{C}\) is the predicted output per class. Alternatively, the metric can also be computed in a one-vs-all approach, where each class is valued against all other classes in a binary fashion.

Accepts the following input tensors:

  • preds (float tensor): (N, C, ...). Preds should be a tensor containing probabilities or logits for each observation. If preds has values outside [0,1] range we consider the input to be logits and will auto apply softmax per sample.

  • target (int tensor): (N, ...). Target should be a tensor containing ground truth labels, and therefore only contain values in the [0, n_classes-1] range (except if ignore_index is specified).

Additional dimension ... will be flattened into the batch dimension.

Parameters:
  • preds (Tensor) – Tensor with predictions

  • target (Tensor) – Tensor with true labels

  • num_classes (int) – Integer specifying the number of classes

  • squared (bool) – If True, this will compute the squared hinge loss. Otherwise, computes the regular hinge loss.

  • multiclass_mode (Literal['crammer-singer', 'one-vs-all']) – Determines how to compute the metric

  • ignore_index (Optional[int]) – Specifies a target value that is ignored and does not contribute to the metric calculation

  • validate_args (bool) – bool indicating if input arguments and tensors should be validated for correctness. Set to False for faster computations.

Return type:

Tensor

Example

>>> from torch import tensor
>>> from torchmetrics.functional.classification import multiclass_hinge_loss
>>> preds = tensor([[0.25, 0.20, 0.55],
...                 [0.55, 0.05, 0.40],
...                 [0.10, 0.30, 0.60],
...                 [0.90, 0.05, 0.05]])
>>> target = tensor([0, 1, 2, 0])
>>> multiclass_hinge_loss(preds, target, num_classes=3)
tensor(0.9125)
>>> multiclass_hinge_loss(preds, target, num_classes=3, squared=True)
tensor(1.1131)
>>> multiclass_hinge_loss(preds, target, num_classes=3, multiclass_mode='one-vs-all')
tensor([0.8750, 1.1250, 1.1000])
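The Crammer-Singer value above can be recomputed by hand: per sample the loss is max(0, 1 - (score of the true class - best score among the other classes)), averaged over samples. The snippet below is an illustrative check, not the library implementation:

import torch

# hedged recomputation of the Crammer-Singer example (illustration only)
preds = torch.tensor([[0.25, 0.20, 0.55],
                      [0.55, 0.05, 0.40],
                      [0.10, 0.30, 0.60],
                      [0.90, 0.05, 0.05]])
target = torch.tensor([0, 1, 2, 0])

idx = torch.arange(len(target))
true_score = preds[idx, target]              # score assigned to the true class
masked = preds.clone()
masked[idx, target] = float("-inf")          # exclude the true class
best_other = masked.max(dim=1).values        # best competing class score
loss = torch.clamp(1 - (true_score - best_other), min=0).mean()
print(loss)                                  # tensor(0.9125)

In the 'one-vs-all' mode, the binary hinge rule is instead applied once per class, which yields a vector with one value per class as in the example above.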

Jaccard Index

Module Interface

class torchmetrics.JaccardIndex(**kwargs)[source]

Calculate the Jaccard index.

The Jaccard index (also known as the intersection over union or Jaccard similarity coefficient) is a statistic that can be used to determine the similarity and diversity of a sample set. It is defined as the size of the intersection divided by the union of the sample sets:

\[J(A,B) = \frac{|A\cap B|}{|A\cup B|}\]

This function is a simple wrapper to get the task specific versions of this metric, which is done by setting the task argument to either 'binary', 'multiclass' or 'multilabel'. See the documentation of BinaryJaccardIndex, MulticlassJaccardIndex and MultilabelJaccardIndex for the specific details of each argument's influence and examples.

Legacy Example:
>>> from torch import randint, tensor
>>> target = randint(0, 2, (10, 25, 25))
>>> pred = tensor(target)
>>> pred[2:5, 7:13, 9:15] = 1 - pred[2:5, 7:13, 9:15]
>>> jaccard = JaccardIndex(task="multiclass", num_classes=2)
>>> jaccard(pred, target)
tensor(0.9660)
static __new__(cls, task, threshold=0.5, num_classes=None, num_labels=None, average='macro', ignore_index=None, validate_args=True, **kwargs)[source]

Initialize task metric.

Return type:

Metric

BinaryJaccardIndex
class torchmetrics.classification.BinaryJaccardIndex(threshold=0.5, ignore_index=None, validate_args=True, **kwargs)[source]

Calculate the Jaccard index for binary tasks.

The Jaccard index (also known as the intersection over union or Jaccard similarity coefficient) is a statistic that can be used to determine the similarity and diversity of a sample set. It is defined as the size of the intersection divided by the union of the sample sets:

\[J(A,B) = \frac{|A\cap B|}{|A\cup B|}\]

As input to forward and update the metric accepts the following input:

  • preds (Tensor): An int or float tensor of shape (N, ...). If preds is a floating point tensor with values outside [0,1] range we consider the input to be logits and will auto apply sigmoid per element. Additionally, we convert to an int tensor with thresholding using the value in threshold.

  • target (Tensor): An int tensor of shape (N, ...).

Note

Additional dimension ... will be flattened into the batch dimension.

As output to forward and compute the metric returns the following output:

  • bji (Tensor): A tensor containing the Binary Jaccard Index.

Parameters:
  • threshold (float) – Threshold for transforming probability to binary (0,1) predictions

  • ignore_index (Optional[int]) – Specifies a target value that is ignored and does not contribute to the metric calculation

  • validate_args (bool) – bool indicating if input arguments and tensors should be validated for correctness. Set to False for faster computations.

  • kwargs (Any) – Additional keyword arguments, see Advanced metric settings for more info.

Example (preds is int tensor):
>>> from torch import tensor
>>> from torchmetrics.classification import BinaryJaccardIndex
>>> target = tensor([1, 1, 0, 0])
>>> preds = tensor([0, 1, 0, 0])
>>> metric = BinaryJaccardIndex()
>>> metric(preds, target)
tensor(0.5000)
Example (preds is float tensor):
>>> from torchmetrics.classification import BinaryJaccardIndex
>>> target = tensor([1, 1, 0, 0])
>>> preds = tensor([0.35, 0.85, 0.48, 0.01])
>>> metric = BinaryJaccardIndex()
>>> metric(preds, target)
tensor(0.5000)
plot(val=None, ax=None)[source]

Plot a single or multiple values from the metric.

Parameters:
  • val (Union[Tensor, Sequence[Tensor], None]) – Either a single result from calling metric.forward or metric.compute or a list of these results. If no value is provided, will automatically call metric.compute and plot that result.

  • ax (Optional[Axes]) – A matplotlib axis object. If provided, will add the plot to that axis

Return type:

Tuple[Figure, Union[Axes, ndarray]]

Returns:

Figure object and Axes object

Raises:

ModuleNotFoundError – If matplotlib is not installed

>>> # Example plotting a single value
>>> from torch import rand, randint
>>> from torchmetrics.classification import BinaryJaccardIndex
>>> metric = BinaryJaccardIndex()
>>> metric.update(rand(10), randint(2,(10,)))
>>> fig_, ax_ = metric.plot()
_images/jaccard_index-1.png
>>> # Example plotting multiple values
>>> from torch import rand, randint
>>> from torchmetrics.classification import BinaryJaccardIndex
>>> metric = BinaryJaccardIndex()
>>> values = [ ]
>>> for _ in range(10):
...     values.append(metric(rand(10), randint(2,(10,))))
>>> fig_, ax_ = metric.plot(values)
_images/jaccard_index-2.png
MulticlassJaccardIndex
class torchmetrics.classification.MulticlassJaccardIndex(num_classes, average='macro', ignore_index=None, validate_args=True, **kwargs)[source]

Calculate the Jaccard index for multiclass tasks.

The Jaccard index (also known as the intersection over union or Jaccard similarity coefficient) is a statistic that can be used to determine the similarity and diversity of a sample set. It is defined as the size of the intersection divided by the union of the sample sets:

\[J(A,B) = \frac{|A\cap B|}{|A\cup B|}\]

As input to forward and update the metric accepts the following input:

  • preds (Tensor): An int tensor of shape (N, ...) or float tensor of shape (N, C, ...). If preds is a floating point tensor we apply torch.argmax along the C dimension to automatically convert probabilities/logits into an int tensor.

  • target (Tensor): An int tensor of shape (N, ...).

Note

Additional dimension ... will be flattened into the batch dimension.

As output to forward and compute the metric returns the following output:

  • mcji (Tensor): A tensor containing the Multi-class Jaccard Index.

Parameters:
  • num_classes (int) – Integer specifying the number of classes

  • ignore_index (Optional[int]) – Specifies a target value that is ignored and does not contribute to the metric calculation

  • average (Optional[Literal['micro', 'macro', 'weighted', 'none']]) –

    Defines the reduction that is applied over labels. Should be one of the following:

    • micro: Sum statistics over all labels

    • macro: Calculate statistics for each label and average them

    • weighted: calculates statistics for each label and computes weighted average using their support

    • "none" or None: calculates statistic for each label and applies no reduction

  • validate_args (bool) – bool indicating if input arguments and tensors should be validated for correctness. Set to False for faster computations.

  • kwargs (Any) – Additional keyword arguments, see Advanced metric settings for more info.

Example (pred is integer tensor):
>>> from torch import tensor
>>> from torchmetrics.classification import MulticlassJaccardIndex
>>> target = tensor([2, 1, 0, 0])
>>> preds = tensor([2, 1, 0, 1])
>>> metric = MulticlassJaccardIndex(num_classes=3)
>>> metric(preds, target)
tensor(0.6667)
Example (pred is float tensor):
>>> from torchmetrics.classification import MulticlassJaccardIndex
>>> target = tensor([2, 1, 0, 0])
>>> preds = tensor([[0.16, 0.26, 0.58],
...                 [0.22, 0.61, 0.17],
...                 [0.71, 0.09, 0.20],
...                 [0.05, 0.82, 0.13]])
>>> metric = MulticlassJaccardIndex(num_classes=3)
>>> metric(preds, target)
tensor(0.6667)
plot(val=None, ax=None)[source]

Plot a single or multiple values from the metric.

Parameters:
  • val (Union[Tensor, Sequence[Tensor], None]) – Either a single result from calling metric.forward or metric.compute or a list of these results. If no value is provided, will automatically call metric.compute and plot that result.

  • ax (Optional[Axes]) – A matplotlib axis object. If provided, will add the plot to that axis

Return type:

Tuple[Figure, Union[Axes, ndarray]]

Returns:

Figure object and Axes object

Raises:

ModuleNotFoundError – If matplotlib is not installed

>>> # Example plotting a single value per class
>>> from torch import randint
>>> from torchmetrics.classification import MulticlassJaccardIndex
>>> metric = MulticlassJaccardIndex(num_classes=3, average=None)
>>> metric.update(randint(3, (20,)), randint(3, (20,)))
>>> fig_, ax_ = metric.plot()
_images/jaccard_index-3.png
>>> # Example plotting multiple values per class
>>> from torch import randint
>>> from torchmetrics.classification import MulticlassJaccardIndex
>>> metric = MulticlassJaccardIndex(num_classes=3, average=None)
>>> values = []
>>> for _ in range(20):
...     values.append(metric(randint(3, (20,)), randint(3, (20,))))
>>> fig_, ax_ = metric.plot(values)
_images/jaccard_index-4.png
MultilabelJaccardIndex
class torchmetrics.classification.MultilabelJaccardIndex(num_labels, threshold=0.5, average='macro', ignore_index=None, validate_args=True, **kwargs)[source]

Calculate the Jaccard index for multilabel tasks.

The Jaccard index (also known as the intersection over union or Jaccard similarity coefficient) is a statistic that can be used to determine the similarity and diversity of a sample set. It is defined as the size of the intersection divided by the union of the sample sets:

\[J(A,B) = \frac{|A\cap B|}{|A\cup B|}\]

As input to forward and update the metric accepts the following input:

  • preds (Tensor): An int or float tensor of shape (N, C, ...). If preds is a floating point tensor with values outside [0,1] range we consider the input to be logits and will auto apply sigmoid per element. Additionally, we convert to an int tensor with thresholding using the value in threshold.

  • target (Tensor): An int tensor of shape (N, C, ...)

Note

Additional dimension ... will be flattened into the batch dimension.

As output to forward and compute the metric returns the following output:

  • mlji (Tensor): A tensor containing the Multi-label Jaccard Index.

Parameters:
  • num_labels (int) – Integer specifying the number of labels

  • threshold (float) – Threshold for transforming probability to binary (0,1) predictions

  • ignore_index (Optional[int]) – Specifies a target value that is ignored and does not contribute to the metric calculation

  • average (Optional[Literal['micro', 'macro', 'weighted', 'none']]) –

    Defines the reduction that is applied over labels. Should be one of the following:

    • micro: Sum statistics over all labels

    • macro: Calculate statistics for each label and average them

    • weighted: calculates statistics for each label and computes weighted average using their support

    • "none" or None: calculates statistic for each label and applies no reduction

  • validate_args (bool) – bool indicating if input arguments and tensors should be validated for correctness. Set to False for faster computations.

  • kwargs (Any) – Additional keyword arguments, see Advanced metric settings for more info.

Example (preds is int tensor):
>>> from torch import tensor
>>> from torchmetrics.classification import MultilabelJaccardIndex
>>> target = tensor([[0, 1, 0], [1, 0, 1]])
>>> preds = tensor([[0, 0, 1], [1, 0, 1]])
>>> metric = MultilabelJaccardIndex(num_labels=3)
>>> metric(preds, target)
tensor(0.5000)
Example (preds is float tensor):
>>> from torchmetrics.classification import MultilabelJaccardIndex
>>> target = tensor([[0, 1, 0], [1, 0, 1]])
>>> preds = tensor([[0.11, 0.22, 0.84], [0.73, 0.33, 0.92]])
>>> metric = MultilabelJaccardIndex(num_labels=3)
>>> metric(preds, target)
tensor(0.5000)
plot(val=None, ax=None)[source]

Plot a single or multiple values from the metric.

Parameters:
  • val (Union[Tensor, Sequence[Tensor], None]) – Either a single result from calling metric.forward or metric.compute or a list of these results. If no value is provided, will automatically call metric.compute and plot that result.

  • ax (Optional[Axes]) – A matplotlib axis object. If provided, will add the plot to that axis

Return type:

Tuple[Figure, Union[Axes, ndarray]]

Returns:

Figure and Axes object

Raises:

ModuleNotFoundError – If matplotlib is not installed

>>> # Example plotting a single value
>>> from torch import rand, randint
>>> from torchmetrics.classification import MultilabelJaccardIndex
>>> metric = MultilabelJaccardIndex(num_labels=3)
>>> metric.update(randint(2, (20, 3)), randint(2, (20, 3)))
>>> fig_, ax_ = metric.plot()
_images/jaccard_index-5.png
>>> # Example plotting multiple values
>>> from torch import rand, randint
>>> from torchmetrics.classification import MultilabelJaccardIndex
>>> metric = MultilabelJaccardIndex(num_labels=3)
>>> values = [ ]
>>> for _ in range(10):
...     values.append(metric(randint(2, (20, 3)), randint(2, (20, 3))))
>>> fig_, ax_ = metric.plot(values)
_images/jaccard_index-6.png

Functional Interface

jaccard_index
torchmetrics.functional.jaccard_index(preds, target, task, threshold=0.5, num_classes=None, num_labels=None, average='macro', ignore_index=None, validate_args=True)[source]

Calculate the Jaccard index.

The Jaccard index (also known as the intersection over union or Jaccard similarity coefficient) is a statistic that can be used to determine the similarity and diversity of a sample set. It is defined as the size of the intersection divided by the union of the sample sets:

\[J(A,B) = \frac{|A\cap B|}{|A\cup B|}\]

This function is a simple wrapper to get the task specific versions of this metric, which is done by setting the task argument to either 'binary', 'multiclass' or 'multilabel'. See the documentation of binary_jaccard_index(), multiclass_jaccard_index() and multilabel_jaccard_index() for the specific details of each argument's influence and examples.

Return type:

Tensor

Legacy Example:
>>> from torch import randint, tensor
>>> target = randint(0, 2, (10, 25, 25))
>>> pred = tensor(target)
>>> pred[2:5, 7:13, 9:15] = 1 - pred[2:5, 7:13, 9:15]
>>> jaccard_index(pred, target, task="multiclass", num_classes=2)
tensor(0.9660)
binary_jaccard_index
torchmetrics.functional.classification.binary_jaccard_index(preds, target, threshold=0.5, ignore_index=None, validate_args=True)[source]

Calculate the Jaccard index for binary tasks.

The Jaccard index (also known as the intersection over union or Jaccard similarity coefficient) is a statistic that can be used to determine the similarity and diversity of a sample set. It is defined as the size of the intersection divided by the union of the sample sets:

\[J(A,B) = \frac{|A\cap B|}{|A\cup B|}\]

Accepts the following input tensors:

  • preds (int or float tensor): (N, ...). If preds is a floating point tensor with values outside [0,1] range we consider the input to be logits and will auto apply sigmoid per element. Additionally, we convert to an int tensor with thresholding using the value in threshold.

  • target (int tensor): (N, ...)

Additional dimension ... will be flattened into the batch dimension.

Parameters:
  • preds (Tensor) – Tensor with predictions

  • target (Tensor) – Tensor with true labels

  • threshold (float) – Threshold for transforming probability to binary (0,1) predictions

  • ignore_index (Optional[int]) – Specifies a target value that is ignored and does not contribute to the metric calculation

  • validate_args (bool) – bool indicating if input arguments and tensors should be validated for correctness. Set to False for faster computations.

  • kwargs – Additional keyword arguments, see Advanced metric settings for more info.

Return type:

Tensor

Example (preds is int tensor):
>>> from torch import tensor
>>> from torchmetrics.functional.classification import binary_jaccard_index
>>> target = tensor([1, 1, 0, 0])
>>> preds = tensor([0, 1, 0, 0])
>>> binary_jaccard_index(preds, target)
tensor(0.5000)
Example (preds is float tensor):
>>> from torchmetrics.functional.classification import binary_jaccard_index
>>> target = tensor([1, 1, 0, 0])
>>> preds = tensor([0.35, 0.85, 0.48, 0.01])
>>> binary_jaccard_index(preds, target)
tensor(0.5000)
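The 0.5 above can be checked with the set definition of the Jaccard index: on thresholded binary predictions it reduces to TP / (TP + FP + FN). A hedged sanity check, not the library's internal code:

import torch

# hedged check of the binary Jaccard index example (illustration only)
target = torch.tensor([1, 1, 0, 0])
preds = torch.tensor([0, 1, 0, 0])

tp = ((preds == 1) & (target == 1)).sum()   # intersection of predicted and true positives
fp = ((preds == 1) & (target == 0)).sum()
fn = ((preds == 0) & (target == 1)).sum()
print(tp / (tp + fp + fn))                  # tensor(0.5000)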
multiclass_jaccard_index
torchmetrics.functional.classification.multiclass_jaccard_index(preds, target, num_classes, average='macro', ignore_index=None, validate_args=True)[source]

Calculate the Jaccard index for multiclass tasks.

The Jaccard index (also known as the intersection over union or Jaccard similarity coefficient) is a statistic that can be used to determine the similarity and diversity of a sample set. It is defined as the size of the intersection divided by the union of the sample sets:

\[J(A,B) = \frac{|A\cap B|}{|A\cup B|}\]

Accepts the following input tensors:

  • preds: (N, ...) (int tensor) or (N, C, ...) (float tensor). If preds is a floating point tensor we apply torch.argmax along the C dimension to automatically convert probabilities/logits into an int tensor.

  • target (int tensor): (N, ...)

Additional dimension ... will be flattened into the batch dimension.

Parameters:
  • preds (Tensor) – Tensor with predictions

  • target (Tensor) – Tensor with true labels

  • num_classes (int) – Integer specifying the number of classes

  • average (Optional[Literal['micro', 'macro', 'weighted', 'none']]) –

    Defines the reduction that is applied over labels. Should be one of the following:

    • micro: Sum statistics over all labels

    • macro: Calculate statistics for each label and average them

    • weighted: calculates statistics for each label and computes weighted average using their support

    • "none" or None: calculates statistic for each label and applies no reduction

  • ignore_index (Optional[int]) – Specifies a target value that is ignored and does not contribute to the metric calculation

  • validate_args (bool) – bool indicating if input arguments and tensors should be validated for correctness. Set to False for faster computations.

  • kwargs – Additional keyword arguments, see Advanced metric settings for more info.

Return type:

Tensor

Example (pred is integer tensor):
>>> from torch import tensor
>>> from torchmetrics.functional.classification import multiclass_jaccard_index
>>> target = tensor([2, 1, 0, 0])
>>> preds = tensor([2, 1, 0, 1])
>>> multiclass_jaccard_index(preds, target, num_classes=3)
tensor(0.6667)
Example (pred is float tensor):
>>> from torchmetrics.functional.classification import multiclass_jaccard_index
>>> target = tensor([2, 1, 0, 0])
>>> preds = tensor([[0.16, 0.26, 0.58],
...                 [0.22, 0.61, 0.17],
...                 [0.71, 0.09, 0.20],
...                 [0.05, 0.82, 0.13]])
>>> multiclass_jaccard_index(preds, target, num_classes=3)
tensor(0.6667)
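The 0.6667 above can be recovered per class: for each class c, the Jaccard index is the size of {preds == c AND target == c} divided by the size of {preds == c OR target == c}, and average='macro' is the mean over classes. A hedged check, for illustration only:

import torch

# hedged per-class check of the multiclass Jaccard example (not the library code)
target = torch.tensor([2, 1, 0, 0])
preds = torch.tensor([2, 1, 0, 1])

scores = []
for c in range(3):
    inter = ((preds == c) & (target == c)).sum()
    union = ((preds == c) | (target == c)).sum()
    scores.append(inter / union)
scores = torch.stack(scores)
print(scores)         # tensor([0.5000, 0.5000, 1.0000])
print(scores.mean())  # tensor(0.6667), i.e. average='macro'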
multilabel_jaccard_index
torchmetrics.functional.classification.multilabel_jaccard_index(preds, target, num_labels, threshold=0.5, average='macro', ignore_index=None, validate_args=True)[source]

Calculate the Jaccard index for multilabel tasks.

The Jaccard index (also known as the intersection over union or Jaccard similarity coefficient) is a statistic that can be used to determine the similarity and diversity of a sample set. It is defined as the size of the intersection divided by the union of the sample sets:

\[J(A,B) = \frac{|A\cap B|}{|A\cup B|}\]

Accepts the following input tensors:

  • preds (int or float tensor): (N, C, ...). If preds is a floating point tensor with values outside [0,1] range we consider the input to be logits and will auto apply sigmoid per element. Additionally, we convert to an int tensor with thresholding using the value in threshold.

  • target (int tensor): (N, C, ...)

Additional dimension ... will be flattened into the batch dimension.

Parameters:
  • preds (Tensor) – Tensor with predictions

  • target (Tensor) – Tensor with true labels

  • num_labels (int) – Integer specifying the number of labels

  • threshold (float) – Threshold for transforming probability to binary (0,1) predictions

  • average (Optional[Literal['micro', 'macro', 'weighted', 'none']]) –

    Defines the reduction that is applied over labels. Should be one of the following:

    • micro: Sum statistics over all labels

    • macro: Calculate statistics for each label and average them

    • weighted: calculates statistics for each label and computes weighted average using their support

    • "none" or None: calculates statistic for each label and applies no reduction

  • ignore_index (Optional[int]) – Specifies a target value that is ignored and does not contribute to the metric calculation

  • validate_args (bool) – bool indicating if input arguments and tensors should be validated for correctness. Set to False for faster computations.

  • kwargs – Additional keyword arguments, see Advanced metric settings for more info.

Return type:

Tensor

Example (preds is int tensor):
>>> from torch import tensor
>>> from torchmetrics.functional.classification import multilabel_jaccard_index
>>> target = tensor([[0, 1, 0], [1, 0, 1]])
>>> preds = tensor([[0, 0, 1], [1, 0, 1]])
>>> multilabel_jaccard_index(preds, target, num_labels=3)
tensor(0.5000)
Example (preds is float tensor):
>>> from torchmetrics.functional.classification import multilabel_jaccard_index
>>> target = tensor([[0, 1, 0], [1, 0, 1]])
>>> preds = tensor([[0.11, 0.22, 0.84], [0.73, 0.33, 0.92]])
>>> multilabel_jaccard_index(preds, target, num_labels=3)
tensor(0.5000)

Label Ranking Average Precision

Module Interface

class torchmetrics.classification.MultilabelRankingAveragePrecision(num_labels, ignore_index=None, validate_args=True, **kwargs)[source]

Compute label ranking average precision score for multilabel data [1].

For each ground truth label assigned to each sample, the score is the fraction of labels ranked at or above it (by predicted score) that are themselves ground truth labels; these fractions are averaged over all ground truth labels and samples. The best score is 1.

As input to forward and update the metric accepts the following input:

  • preds (Tensor): A float tensor of shape (N, C, ...). Preds should be a tensor containing probabilities or logits for each observation. If preds has values outside [0,1] range we consider the input to be logits and will auto apply sigmoid per element.

  • target (Tensor): An int tensor of shape (N, C, ...). Target should be a tensor containing ground truth labels, and therefore only contain {0,1} values (except if ignore_index is specified).

Note

Additional dimension ... will be flattened into the batch dimension.

As output to forward and compute the metric returns the following output:

  • mlrap (Tensor): A tensor containing the multilabel ranking average precision.

Parameters:
  • num_labels (int) – Integer specifying the number of labels

  • ignore_index (Optional[int]) – Specifies a target value that is ignored and does not contribute to the metric calculation

  • validate_args (bool) – bool indicating if input arguments and tensors should be validated for correctness. Set to False for faster computations.

Example

>>> from torchmetrics.classification import MultilabelRankingAveragePrecision
>>> _ = torch.manual_seed(42)
>>> preds = torch.rand(10, 5)
>>> target = torch.randint(2, (10, 5))
>>> mlrap = MultilabelRankingAveragePrecision(num_labels=5)
>>> mlrap(preds, target)
tensor(0.7744)
plot(val=None, ax=None)[source]

Plot a single or multiple values from the metric.

Parameters:
  • val (Union[Tensor, Sequence[Tensor], None]) – Either a single result from calling metric.forward or metric.compute or a list of these results. If no value is provided, will automatically call metric.compute and plot that result.

  • ax (Optional[Axes]) – A matplotlib axis object. If provided, will add the plot to that axis

Return type:

Tuple[Figure, Union[Axes, ndarray]]

Returns:

Figure object and Axes object

Raises:

ModuleNotFoundError – If matplotlib is not installed

>>> from torch import rand, randint
>>> # Example plotting a single value
>>> from torchmetrics.classification import MultilabelRankingAveragePrecision
>>> metric = MultilabelRankingAveragePrecision(num_labels=3)
>>> metric.update(rand(20, 3), randint(2, (20, 3)))
>>> fig_, ax_ = metric.plot()
_images/label_ranking_average_precision-1.png
>>> from torch import rand, randint
>>> # Example plotting multiple values
>>> from torchmetrics.classification import MultilabelRankingAveragePrecision
>>> metric = MultilabelRankingAveragePrecision(num_labels=3)
>>> values = [ ]
>>> for _ in range(10):
...     values.append(metric(rand(20, 3), randint(2, (20, 3))))
>>> fig_, ax_ = metric.plot(values)
_images/label_ranking_average_precision-2.png

Functional Interface

torchmetrics.functional.classification.multilabel_ranking_average_precision(preds, target, num_labels, ignore_index=None, validate_args=True)[source]

Compute label ranking average precision score for multilabel data [1].

For each ground truth label assigned to each sample, the score is the fraction of labels ranked at or above it (by predicted score) that are themselves ground truth labels; these fractions are averaged over all ground truth labels and samples. The best score is 1. A minimal sketch of this definition is given after the example below.

Accepts the following input tensors:

  • preds (float tensor): (N, C, ...). Preds should be a tensor containing probabilities or logits for each observation. If preds has values outside [0,1] range we consider the input to be logits and will auto apply sigmoid per element.

  • target (int tensor): (N, C, ...). Target should be a tensor containing ground truth labels, and therefore only contain {0,1} values (except if ignore_index is specified).

Additional dimension ... will be flattened into the batch dimension.

Parameters:
  • preds (Tensor) – Tensor with predictions

  • target (Tensor) – Tensor with true labels

  • num_labels (int) – Integer specifying the number of labels

  • ignore_index (Optional[int]) – Specifies a target value that is ignored and does not contribute to the metric calculation

  • validate_args (bool) – bool indicating if input arguments and tensors should be validated for correctness. Set to False for faster computations.

Return type:

Tensor

Example

>>> from torchmetrics.functional.classification import multilabel_ranking_average_precision
>>> _ = torch.manual_seed(42)
>>> preds = torch.rand(10, 5)
>>> target = torch.randint(2, (10, 5))
>>> multilabel_ranking_average_precision(preds, target, num_labels=5)
tensor(0.7744)
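The sketch below spells out the definition with explicit loops (illustration only, not the library implementation). The inputs are made up for demonstration; the printed value should agree with multilabel_ranking_average_precision on the same inputs:

import torch

# hedged sketch of label ranking average precision (assumes every sample has
# at least one positive label)
def lrap(preds: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
    scores = []
    for p, t in zip(preds, target):
        true_idx = torch.nonzero(t, as_tuple=True)[0]
        per_label = []
        for j in true_idx:
            at_or_above = p >= p[j]                           # labels ranked at or above label j
            rank = at_or_above.sum().float()
            true_at_or_above = (at_or_above & t.bool()).sum().float()
            per_label.append(true_at_or_above / rank)
        scores.append(torch.stack(per_label).mean())          # average over true labels
    return torch.stack(scores).mean()                         # average over samples

preds = torch.tensor([[0.1, 0.9, 0.8], [0.7, 0.2, 0.4]])
target = torch.tensor([[1, 0, 1], [0, 1, 1]])
print(lrap(preds, target))  # tensor(0.5833)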

References

[1] Tsoumakas, G., Katakis, I., & Vlahavas, I. (2010). Mining multi-label data. In Data mining and knowledge discovery handbook (pp. 667-685). Springer US.

Label Ranking Loss

Module Interface

class torchmetrics.classification.MultilabelRankingLoss(num_labels, ignore_index=None, validate_args=True, **kwargs)[source]

Compute the label ranking loss for multilabel data [1].

The score corresponds to the average number of label pairs that are incorrectly ordered given some predictions, weighted by the size of the label set and the number of labels not in the label set. The best score is 0.

As input to forward and update the metric accepts the following input:

  • preds (Tensor): A float tensor of shape (N, C, ...). Preds should be a tensor containing probabilities or logits for each observation. If preds has values outside [0,1] range we consider the input to be logits and will auto apply sigmoid per element.

  • target (Tensor): An int tensor of shape (N, C, ...). Target should be a tensor containing ground truth labels, and therefore only contain {0,1} values (except if ignore_index is specified).

Note

Additional dimension ... will be flattened into the batch dimension.

As output to forward and compute the metric returns the following output:

  • mlrl (Tensor): A tensor containing the multilabel ranking loss.

Parameters:
  • preds – Tensor with predictions

  • target – Tensor with true labels

  • num_labels (int) – Integer specifying the number of labels

  • ignore_index (Optional[int]) – Specifies a target value that is ignored and does not contribute to the metric calculation

  • validate_args (bool) – bool indicating if input arguments and tensors should be validated for correctness. Set to False for faster computations.

Example

>>> from torchmetrics.classification import MultilabelRankingLoss
>>> _ = torch.manual_seed(42)
>>> preds = torch.rand(10, 5)
>>> target = torch.randint(2, (10, 5))
>>> mlrl = MultilabelRankingLoss(num_labels=5)
>>> mlrl(preds, target)
tensor(0.4167)
plot(val=None, ax=None)[source]

Plot a single or multiple values from the metric.

Parameters:
  • val (Union[Tensor, Sequence[Tensor], None]) – Either a single result from calling metric.forward or metric.compute or a list of these results. If no value is provided, will automatically call metric.compute and plot that result.

  • ax (Optional[Axes]) – A matplotlib axis object. If provided, will add the plot to that axis

Return type:

Tuple[Figure, Union[Axes, ndarray]]

Returns:

Figure object and Axes object

Raises:

ModuleNotFoundError – If matplotlib is not installed

>>> from torch import rand, randint
>>> # Example plotting a single value
>>> from torchmetrics.classification import MultilabelRankingLoss
>>> metric = MultilabelRankingLoss(num_labels=3)
>>> metric.update(rand(20, 3), randint(2, (20, 3)))
>>> fig_, ax_ = metric.plot()
_images/label_ranking_loss-1.png
>>> from torch import rand, randint
>>> # Example plotting multiple values
>>> from torchmetrics.classification import MultilabelRankingLoss
>>> metric = MultilabelRankingLoss(num_labels=3)
>>> values = [ ]
>>> for _ in range(10):
...     values.append(metric(rand(20, 3), randint(2, (20, 3))))
>>> fig_, ax_ = metric.plot(values)
_images/label_ranking_loss-2.png

Functional Interface

torchmetrics.functional.classification.multilabel_ranking_loss(preds, target, num_labels, ignore_index=None, validate_args=True)[source]

Compute the label ranking loss for multilabel data [1].

The score corresponds to the average number of label pairs that are incorrectly ordered given some predictions, weighted by the size of the label set and the number of labels not in the label set. The best score is 0. A minimal sketch of this definition is given after the example below.

Accepts the following input tensors:

  • preds (float tensor): (N, C, ...). Preds should be a tensor containing probabilities or logits for each observation. If preds has values outside [0,1] range we consider the input to be logits and will auto apply sigmoid per element.

  • target (int tensor): (N, C, ...). Target should be a tensor containing ground truth labels, and therefore only contain {0,1} values (except if ignore_index is specified).

Additional dimension ... will be flattened into the batch dimension.

Parameters:
  • preds (Tensor) – Tensor with predictions

  • target (Tensor) – Tensor with true labels

  • num_labels (int) – Integer specifying the number of labels

  • ignore_index (Optional[int]) – Specifies a target value that is ignored and does not contribute to the metric calculation

  • validate_args (bool) – bool indicating if input arguments and tensors should be validated for correctness. Set to False for faster computations.

Return type:

Tensor

Example

>>> from torchmetrics.functional.classification import multilabel_ranking_loss
>>> _ = torch.manual_seed(42)
>>> preds = torch.rand(10, 5)
>>> target = torch.randint(2, (10, 5))
>>> multilabel_ranking_loss(preds, target, num_labels=5)
tensor(0.4167)
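The sketch below spells out the definition with explicit loops (illustration only, not the library implementation): per sample, count the (relevant, irrelevant) label pairs where the irrelevant label is scored at least as high, normalise by the number of such pairs, then average over samples. The inputs are made up and tie-free for demonstration:

import torch

# hedged sketch of the label ranking loss (assumes every sample has at least
# one relevant and one irrelevant label, and no tied scores)
def ranking_loss(preds: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
    losses = []
    for p, t in zip(preds, target.bool()):
        rel, irr = p[t], p[~t]
        # pairwise comparison: an irrelevant score >= a relevant score is a misordering
        bad = (irr.unsqueeze(0) >= rel.unsqueeze(1)).float()
        losses.append(bad.mean())
    return torch.stack(losses).mean()

preds = torch.tensor([[0.8, 0.5, 0.2], [0.3, 0.9, 0.6]])
target = torch.tensor([[1, 0, 1], [0, 1, 1]])
print(ranking_loss(preds, target))  # tensor(0.2500)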

References

[1] Tsoumakas, G., Katakis, I., & Vlahavas, I. (2010). Mining multi-label data. In Data mining and knowledge discovery handbook (pp. 667-685). Springer US.

Matthews Correlation Coefficient

Module Interface

class torchmetrics.MatthewsCorrCoef(**kwargs)[source]

Calculate the Matthews correlation coefficient.

This metric measures the general correlation or quality of a classification.

This function is a simple wrapper to get the task specific versions of this metric, which is done by setting the task argument to either 'binary', 'multiclass' or 'multilabel'. See the documentation of BinaryMatthewsCorrCoef, MulticlassMatthewsCorrCoef and MultilabelMatthewsCorrCoef for the specific details of each argument's influence and examples.

Legacy Example:
>>> from torch import tensor
>>> target = tensor([1, 1, 0, 0])
>>> preds = tensor([0, 1, 0, 0])
>>> matthews_corrcoef = MatthewsCorrCoef(task='binary')
>>> matthews_corrcoef(preds, target)
tensor(0.5774)
static __new__(cls, task, threshold=0.5, num_classes=None, num_labels=None, ignore_index=None, validate_args=True, **kwargs)[source]

Initialize task metric.

Return type:

Metric

BinaryMatthewsCorrCoef
class torchmetrics.classification.BinaryMatthewsCorrCoef(threshold=0.5, ignore_index=None, validate_args=True, **kwargs)[source]

Calculate Matthews correlation coefficient for binary tasks.

This metric measures the general correlation or quality of a classification.

As input to forward and update the metric accepts the following input:

  • preds (Tensor): An int or float tensor of shape (N, ...). If preds is a floating point tensor with values outside [0,1] range we consider the input to be logits and will auto apply sigmoid per element. Additionally, we convert to an int tensor with thresholding using the value in threshold.

  • target (Tensor): An int tensor of shape (N, ...)

Note

Additional dimension ... will be flattened into the batch dimension.

As output to forward and compute the metric returns the following output:

  • bmcc (Tensor): A tensor containing the Binary Matthews Correlation Coefficient.

Parameters:
  • threshold (float) – Threshold for transforming probability to binary (0,1) predictions

  • ignore_index (Optional[int]) – Specifies a target value that is ignored and does not contribute to the metric calculation

  • validate_args (bool) – bool indicating if input arguments and tensors should be validated for correctness. Set to False for faster computations.

  • kwargs (Any) – Additional keyword arguments, see Advanced metric settings for more info.

Example (preds is int tensor):
>>> from torch import tensor
>>> from torchmetrics.classification import BinaryMatthewsCorrCoef
>>> target = tensor([1, 1, 0, 0])
>>> preds = tensor([0, 1, 0, 0])
>>> metric = BinaryMatthewsCorrCoef()
>>> metric(preds, target)
tensor(0.5774)
Example (preds is float tensor):
>>> from torchmetrics.classification import BinaryMatthewsCorrCoef
>>> target = tensor([1, 1, 0, 0])
>>> preds = tensor([0.35, 0.85, 0.48, 0.01])
>>> metric = BinaryMatthewsCorrCoef()
>>> metric(preds, target)
tensor(0.5774)
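The 0.5774 above can be checked against the textbook binary MCC formula, (TP·TN - FP·FN) / sqrt((TP+FP)(TP+FN)(TN+FP)(TN+FN)). The snippet below is an arithmetic illustration, not the library's implementation:

import torch

# hedged check of the binary MCC example above (illustration only)
target = torch.tensor([1, 1, 0, 0])
preds = torch.tensor([0, 1, 0, 0])

tp = ((preds == 1) & (target == 1)).sum().float()
tn = ((preds == 0) & (target == 0)).sum().float()
fp = ((preds == 1) & (target == 0)).sum().float()
fn = ((preds == 0) & (target == 1)).sum().float()

mcc = (tp * tn - fp * fn) / torch.sqrt((tp + fp) * (tp + fn) * (tn + fp) * (tn + fn))
print(mcc)  # tensor(0.5774)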
plot(val=None, ax=None)[source]

Plot a single or multiple values from the metric.

Parameters:
  • val (Union[Tensor, Sequence[Tensor], None]) – Either a single result from calling metric.forward or metric.compute or a list of these results. If no value is provided, will automatically call metric.compute and plot that result.

  • ax (Optional[Axes]) – A matplotlib axis object. If provided, will add the plot to that axis

Return type:

Tuple[Figure, Union[Axes, ndarray]]

Returns:

Figure object and Axes object

Raises:

ModuleNotFoundError – If matplotlib is not installed

>>> from torch import rand, randint
>>> # Example plotting a single value
>>> from torchmetrics.classification import BinaryMatthewsCorrCoef
>>> metric = BinaryMatthewsCorrCoef()
>>> metric.update(rand(10), randint(2,(10,)))
>>> fig_, ax_ = metric.plot()
_images/matthews_corr_coef-1.png
>>> from torch import rand, randint
>>> # Example plotting multiple values
>>> from torchmetrics.classification import BinaryMatthewsCorrCoef
>>> metric = BinaryMatthewsCorrCoef()
>>> values = [ ]
>>> for _ in range(10):
...     values.append(metric(rand(10), randint(2,(10,))))
>>> fig_, ax_ = metric.plot(values)
_images/matthews_corr_coef-2.png
MulticlassMatthewsCorrCoef
class torchmetrics.classification.MulticlassMatthewsCorrCoef(num_classes, ignore_index=None, validate_args=True, **kwargs)[source]

Calculate Matthews correlation coefficient for multiclass tasks.

This metric measures the general correlation or quality of a classification.

As input to forward and update the metric accepts the following input:

  • preds (Tensor): An int tensor of shape (N, ...) or float tensor of shape (N, C, ...). If preds is a floating point tensor we apply torch.argmax along the C dimension to automatically convert probabilities/logits into an int tensor.

  • target (Tensor): An int tensor of shape (N, ...)

Note

Additional dimension ... will be flattened into the batch dimension.

As output to forward and compute the metric returns the following output:

  • mcmcc (Tensor): A tensor containing the Multi-class Matthews Correlation Coefficient.

Parameters:
  • num_classes (int) – Integer specifying the number of classes

  • ignore_index (Optional[int]) – Specifies a target value that is ignored and does not contribute to the metric calculation

  • validate_args (bool) – bool indicating if input arguments and tensors should be validated for correctness. Set to False for faster computations.

  • kwargs (Any) – Additional keyword arguments, see Advanced metric settings for more info.

Example (preds is int tensor):
>>> from torch import tensor
>>> from torchmetrics.classification import MulticlassMatthewsCorrCoef
>>> target = tensor([2, 1, 0, 0])
>>> preds = tensor([2, 1, 0, 1])
>>> metric = MulticlassMatthewsCorrCoef(num_classes=3)
>>> metric(preds, target)
tensor(0.7000)
Example (preds is float tensor):
>>> from torchmetrics.classification import MulticlassMatthewsCorrCoef
>>> target = tensor([2, 1, 0, 0])
>>> preds = tensor([[0.16, 0.26, 0.58],
...                 [0.22, 0.61, 0.17],
...                 [0.71, 0.09, 0.20],
...                 [0.05, 0.82, 0.13]])
>>> metric = MulticlassMatthewsCorrCoef(num_classes=3)
>>> metric(preds, target)
tensor(0.7000)
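For illustration, a small sketch of the documented argmax conversion: the floating point preds above reduce to the same integer predictions as in the first example, and therefore to the same score:
>>> preds.argmax(dim=-1)
tensor([2, 1, 0, 1])
>>> metric(preds.argmax(dim=-1), target)
tensor(0.7000)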
plot(val=None, ax=None)[source]

Plot a single or multiple values from the metric.

Parameters:
  • val (Union[Tensor, Sequence[Tensor], None]) – Either a single result from calling metric.forward or metric.compute or a list of these results. If no value is provided, will automatically call metric.compute and plot that result.

  • ax (Optional[Axes]) – A matplotlib axis object. If provided, will add plot to that axis

Return type:

Tuple[Figure, Union[Axes, ndarray]]

Returns:

Figure object and Axes object

Raises:

ModuleNotFoundError – If matplotlib is not installed

>>> from torch import randint
>>> # Example plotting a single value per class
>>> from torchmetrics.classification import MulticlassMatthewsCorrCoef
>>> metric = MulticlassMatthewsCorrCoef(num_classes=3)
>>> metric.update(randint(3, (20,)), randint(3, (20,)))
>>> fig_, ax_ = metric.plot()
_images/matthews_corr_coef-3.png
>>> from torch import randint
>>> # Example plotting multiple values per class
>>> from torchmetrics.classification import MulticlassMatthewsCorrCoef
>>> metric = MulticlassMatthewsCorrCoef(num_classes=3)
>>> values = []
>>> for _ in range(20):
...     values.append(metric(randint(3, (20,)), randint(3, (20,))))
>>> fig_, ax_ = metric.plot(values)
_images/matthews_corr_coef-4.png
MultilabelMatthewsCorrCoef
class torchmetrics.classification.MultilabelMatthewsCorrCoef(num_labels, threshold=0.5, ignore_index=None, validate_args=True, **kwargs)[source]

Calculate Matthews correlation coefficient for multilabel tasks.

This metric measures the general correlation or quality of a classification.

As input to forward and update the metric accepts the following input:

  • preds (Tensor): An int or float tensor of shape (N, C, ...). If preds is a floating point tensor with values outside the [0,1] range, we consider the input to be logits and will auto apply sigmoid per element. Additionally, we convert to an int tensor with thresholding using the value in threshold.

  • target (Tensor): An int tensor of shape (N, C, ...)

Note

Additional dimension ... will be flattened into the batch dimension.

As output to forward and compute the metric returns the following output:

  • mlmcc (Tensor): A tensor containing the Multi-label Matthews Correlation Coefficient.

Parameters:
  • num_labels (int) – Integer specifying the number of labels

  • threshold (float) – Threshold for transforming probability to binary (0,1) predictions

  • ignore_index (Optional[int]) – Specifies a target value that is ignored and does not contribute to the metric calculation

  • validate_args (bool) – bool indicating if input arguments and tensors should be validated for correctness. Set to False for faster computations.

  • kwargs (Any) – Additional keyword arguments, see Advanced metric settings for more info.

Example (preds is int tensor):
>>> from torch import tensor
>>> from torchmetrics.classification import MultilabelMatthewsCorrCoef
>>> target = tensor([[0, 1, 0], [1, 0, 1]])
>>> preds = tensor([[0, 0, 1], [1, 0, 1]])
>>> metric = MultilabelMatthewsCorrCoef(num_labels=3)
>>> metric(preds, target)
tensor(0.3333)
Example (preds is float tensor):
>>> from torchmetrics.classification import MultilabelMatthewsCorrCoef
>>> target = tensor([[0, 1, 0], [1, 0, 1]])
>>> preds = tensor([[0.11, 0.22, 0.84], [0.73, 0.33, 0.92]])
>>> metric = MultilabelMatthewsCorrCoef(num_labels=3)
>>> metric(preds, target)
tensor(0.3333)
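For illustration, a small sketch of the documented thresholding (using the default threshold of 0.5; none of these values lie exactly on the threshold): the float preds above binarize to the same integer predictions as in the first example:
>>> (preds > 0.5).long()
tensor([[0, 0, 1],
        [1, 0, 1]])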
plot(val=None, ax=None)[source]

Plot a single or multiple values from the metric.

Parameters:
  • val (Union[Tensor, Sequence[Tensor], None]) – Either a single result from calling metric.forward or metric.compute or a list of these results. If no value is provided, will automatically call metric.compute and plot that result.

  • ax (Optional[Axes]) – A matplotlib axis object. If provided, will add plot to that axis

Return type:

Tuple[Figure, Union[Axes, ndarray]]

Returns:

Figure object and Axes object

Raises:

ModuleNotFoundError – If matplotlib is not installed

>>> from torch import rand, randint
>>> # Example plotting a single value
>>> from torchmetrics.classification import MultilabelMatthewsCorrCoef
>>> metric = MultilabelMatthewsCorrCoef(num_labels=3)
>>> metric.update(randint(2, (20, 3)), randint(2, (20, 3)))
>>> fig_, ax_ = metric.plot()
_images/matthews_corr_coef-5.png
>>> from torch import rand, randint
>>> # Example plotting multiple values
>>> from torchmetrics.classification import MultilabelMatthewsCorrCoef
>>> metric = MultilabelMatthewsCorrCoef(num_labels=3)
>>> values = [ ]
>>> for _ in range(10):
...     values.append(metric(randint(2, (20, 3)), randint(2, (20, 3))))
>>> fig_, ax_ = metric.plot(values)
_images/matthews_corr_coef-6.png

Functional Interface

matthews_corrcoef
torchmetrics.functional.matthews_corrcoef(preds, target, task, threshold=0.5, num_classes=None, num_labels=None, ignore_index=None, validate_args=True)[source]

Calculate the Matthews correlation coefficient.

This metric measures the general correlation or quality of a classification.

This function is a simple wrapper to get the task specific versions of this metric, which is done by setting the task argument to either 'binary', 'multiclass' or 'multilabel'. See the documentation of binary_matthews_corrcoef(), multiclass_matthews_corrcoef() and multilabel_matthews_corrcoef() for the specific details of each argument's influence and for examples.

Return type:

Tensor

Legacy Example:
>>> from torch import tensor
>>> target = tensor([1, 1, 0, 0])
>>> preds = tensor([0, 1, 0, 0])
>>> matthews_corrcoef(preds, target, task="multiclass", num_classes=2)
tensor(0.5774)
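For comparison, an equivalent sketch that calls the task-specific function the wrapper dispatches to; for binary data both routes give the same value:
>>> from torchmetrics.functional.classification import binary_matthews_corrcoef
>>> matthews_corrcoef(preds, target, task="binary")
tensor(0.5774)
>>> binary_matthews_corrcoef(preds, target)
tensor(0.5774)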
binary_matthews_corrcoef
torchmetrics.functional.classification.binary_matthews_corrcoef(preds, target, threshold=0.5, ignore_index=None, validate_args=True)[source]

Calculate Matthews correlation coefficient for binary tasks.

This metric measures the general correlation or quality of a classification.

Accepts the following input tensors:

  • preds (int or float tensor): (N, ...). If preds is a floating point tensor with values outside the [0,1] range, we consider the input to be logits and will auto apply sigmoid per element. Additionally, we convert to an int tensor with thresholding using the value in threshold.

  • target (int tensor): (N, ...)

Additional dimension ... will be flattened into the batch dimension.

Parameters:
  • preds (Tensor) – Tensor with predictions

  • target (Tensor) – Tensor with true labels

  • threshold (float) – Threshold for transforming probability to binary (0,1) predictions

  • ignore_index (Optional[int]) – Specifies a target value that is ignored and does not contribute to the metric calculation

  • validate_args (bool) – bool indicating if input arguments and tensors should be validated for correctness. Set to False for faster computations.

  • kwargs – Additional keyword arguments, see Advanced metric settings for more info.

Return type:

Tensor

Example (preds is int tensor):
>>> from torch import tensor
>>> from torchmetrics.functional.classification import binary_matthews_corrcoef
>>> target = tensor([1, 1, 0, 0])
>>> preds = tensor([0, 1, 0, 0])
>>> binary_matthews_corrcoef(preds, target)
tensor(0.5774)
Example (preds is float tensor):
>>> from torchmetrics.functional.classification import binary_matthews_corrcoef
>>> target = tensor([1, 1, 0, 0])
>>> preds = tensor([0.35, 0.85, 0.48, 0.01])
>>> binary_matthews_corrcoef(preds, target)
tensor(0.5774)
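An additional sketch of the documented logit handling: since these values fall outside the [0,1] range they are treated as logits, sigmoid maps them to roughly the probabilities used above, and the result is unchanged:
>>> logits = tensor([-0.6, 1.7, -0.1, -4.6])
>>> binary_matthews_corrcoef(logits, target)
tensor(0.5774)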
multiclass_matthews_corrcoef
torchmetrics.functional.classification.multiclass_matthews_corrcoef(preds, target, num_classes, ignore_index=None, validate_args=True)[source]

Calculate Matthews correlation coefficient for multiclass tasks.

This metric measures the general correlation or quality of a classification.

Accepts the following input tensors:

  • preds: (N, ...) (int tensor) or (N, C, ...) (float tensor). If preds is a floating point tensor we apply torch.argmax along the C dimension to automatically convert probabilities/logits into an int tensor.

  • target (int tensor): (N, ...)

Additional dimension ... will be flattened into the batch dimension.

Parameters:
  • preds (Tensor) – Tensor with predictions

  • target (Tensor) – Tensor with true labels

  • num_classes (int) – Integer specifying the number of classes

  • ignore_index (Optional[int]) – Specifies a target value that is ignored and does not contribute to the metric calculation

  • validate_args (bool) – bool indicating if input arguments and tensors should be validated for correctness. Set to False for faster computations.

  • kwargs – Additional keyword arguments, see Advanced metric settings for more info.

Return type:

Tensor

Example (preds is int tensor):
>>> from torch import tensor
>>> from torchmetrics.functional.classification import multiclass_matthews_corrcoef
>>> target = tensor([2, 1, 0, 0])
>>> preds = tensor([2, 1, 0, 1])
>>> multiclass_matthews_corrcoef(preds, target, num_classes=3)
tensor(0.7000)
Example (preds is float tensor):
>>> from torchmetrics.functional.classification import multiclass_matthews_corrcoef
>>> target = tensor([2, 1, 0, 0])
>>> preds = tensor([[0.16, 0.26, 0.58],
...                 [0.22, 0.61, 0.17],
...                 [0.71, 0.09, 0.20],
...                 [0.05, 0.82, 0.13]])
>>> multiclass_matthews_corrcoef(preds, target, num_classes=3)
tensor(0.7000)
multilabel_matthews_corrcoef
torchmetrics.functional.classification.multilabel_matthews_corrcoef(preds, target, num_labels, threshold=0.5, ignore_index=None, validate_args=True)[source]

Calculate Matthews correlation coefficient for multilabel tasks.

This metric measures the general correlation or quality of a classification.

Accepts the following input tensors:

  • preds (int or float tensor): (N, C, ...). If preds is a floating point tensor with values outside the [0,1] range, we consider the input to be logits and will auto apply sigmoid per element. Additionally, we convert to an int tensor with thresholding using the value in threshold.

  • target (int tensor): (N, C, ...)

Additional dimension ... will be flattened into the batch dimension.

Parameters:
  • preds (Tensor) – Tensor with predictions

  • target (Tensor) – Tensor with true labels

  • num_labels (int) – Integer specifying the number of labels

  • threshold (float) – Threshold for transforming probability to binary (0,1) predictions

  • ignore_index (Optional[int]) – Specifies a target value that is ignored and does not contribute to the metric calculation

  • validate_args (bool) – bool indicating if input arguments and tensors should be validated for correctness. Set to False for faster computations.

Return type:

Tensor

Example (preds is int tensor):
>>> from torch import tensor
>>> from torchmetrics.functional.classification import multilabel_matthews_corrcoef
>>> target = tensor([[0, 1, 0], [1, 0, 1]])
>>> preds = tensor([[0, 0, 1], [1, 0, 1]])
>>> multilabel_matthews_corrcoef(preds, target, num_labels=3)
tensor(0.3333)
Example (preds is float tensor):
>>> from torchmetrics.functional.classification import multilabel_matthews_corrcoef
>>> target = tensor([[0, 1, 0], [1, 0, 1]])
>>> preds = tensor([[0.11, 0.22, 0.84], [0.73, 0.33, 0.92]])
>>> multilabel_matthews_corrcoef(preds, target, num_labels=3)
tensor(0.3333)

Precision

Module Interface

class torchmetrics.Precision(**kwargs)[source]

Compute Precision.

\[\text{Precision} = \frac{\text{TP}}{\text{TP} + \text{FP}}\]

Where \(\text{TP}\) and \(\text{FP}\) represent the number of true positives and false positives respectively. The metric is only properly defined when \(\text{TP} + \text{FP} \neq 0\). If this case is encountered for any class/label, the metric for that class/label will be set to 0 and the overall metric may therefore be affected in turn.

This function is a simple wrapper to get the task specific versions of this metric, which is done by setting the task argument to either 'binary', 'multiclass' or 'multilabel'. See the documentation of BinaryPrecision, MulticlassPrecision and MultilabelPrecision for the specific details of each argument's influence and for examples.

Legacy Example:
>>> from torch import tensor
>>> preds  = tensor([2, 0, 2, 1])
>>> target = tensor([1, 1, 2, 0])
>>> precision = Precision(task="multiclass", average='macro', num_classes=3)
>>> precision(preds, target)
tensor(0.1667)
>>> precision = Precision(task="multiclass", average='micro', num_classes=3)
>>> precision(preds, target)
tensor(0.2500)
static __new__(cls, task, threshold=0.5, num_classes=None, num_labels=None, average='micro', multidim_average='global', top_k=1, ignore_index=None, validate_args=True, **kwargs)[source]

Initialize task metric.

Return type:

Metric

BinaryPrecision
class torchmetrics.classification.BinaryPrecision(threshold=0.5, multidim_average='global', ignore_index=None, validate_args=True, **kwargs)[source]

Compute Precision for binary tasks.

\[\text{Precision} = \frac{\text{TP}}{\text{TP} + \text{FP}}\]

Where \(\text{TP}\) and \(\text{FP}\) represent the number of true positives and false positives respectively. The metric is only properly defined when \(\text{TP} + \text{FP} \neq 0\). If this case is encountered, a score of 0 is returned.

As input to forward and update the metric accepts the following input:

  • preds (Tensor): An int or float tensor of shape (N, ...). If preds is a floating point tensor with values outside the [0,1] range, we consider the input to be logits and will auto apply sigmoid per element. Additionally, we convert to an int tensor with thresholding using the value in threshold.

  • target (Tensor): An int tensor of shape (N, ...).

As output to forward and compute the metric returns the following output:

  • bp (Tensor): If multidim_average is set to global, the metric returns a scalar value. If multidim_average is set to samplewise, the metric returns (N,) vector consisting of a scalar value per sample.

Parameters:
  • threshold (float) – Threshold for transforming probability to binary {0,1} predictions

  • multidim_average (Literal['global', 'samplewise']) –

    Defines how additional dimensions ... should be handled. Should be one of the following:

    • global: Additional dimensions are flattened along the batch dimension

    • samplewise: Statistic will be calculated independently for each sample on the N axis. The statistics in this case are calculated over the additional dimensions.

  • ignore_index (Optional[int]) – Specifies a target value that is ignored and does not contribute to the metric calculation

  • validate_args (bool) – bool indicating if input arguments and tensors should be validated for correctness. Set to False for faster computations.

Example (preds is int tensor):
>>> from torch import tensor
>>> from torchmetrics.classification import BinaryPrecision
>>> target = tensor([0, 1, 0, 1, 0, 1])
>>> preds = tensor([0, 0, 1, 1, 0, 1])
>>> metric = BinaryPrecision()
>>> metric(preds, target)
tensor(0.6667)
Example (preds is float tensor):
>>> from torchmetrics.classification import BinaryPrecision
>>> target = tensor([0, 1, 0, 1, 0, 1])
>>> preds = tensor([0.11, 0.22, 0.84, 0.73, 0.33, 0.92])
>>> metric = BinaryPrecision()
>>> metric(preds, target)
tensor(0.6667)
Example (multidim tensors):
>>> from torchmetrics.classification import BinaryPrecision
>>> target = tensor([[[0, 1], [1, 0], [0, 1]], [[1, 1], [0, 0], [1, 0]]])
>>> preds = tensor([[[0.59, 0.91], [0.91, 0.99],  [0.63, 0.04]],
...                 [[0.38, 0.04], [0.86, 0.780], [0.45, 0.37]]])
>>> metric = BinaryPrecision(multidim_average='samplewise')
>>> metric(preds, target)
tensor([0.4000, 0.0000])
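An illustrative sketch of the degenerate case mentioned above: when no positive predictions are made, TP + FP = 0 and a score of 0 is returned:
>>> metric = BinaryPrecision()
>>> metric(tensor([0, 0, 0, 0]), tensor([0, 1, 0, 1]))
tensor(0.)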
plot(val=None, ax=None)[source]

Plot a single or multiple values from the metric.

Parameters:
  • val (Union[Tensor, Sequence[Tensor], None]) – Either a single result from calling metric.forward or metric.compute or a list of these results. If no value is provided, will automatically call metric.compute and plot that result.

  • ax (Optional[Axes]) – A matplotlib axis object. If provided, will add plot to that axis

Return type:

Tuple[Figure, Union[Axes, ndarray]]

Returns:

Figure object and Axes object

Raises:

ModuleNotFoundError – If matplotlib is not installed

>>> from torch import rand, randint
>>> # Example plotting a single value
>>> from torchmetrics.classification import BinaryPrecision
>>> metric = BinaryPrecision()
>>> metric.update(rand(10), randint(2,(10,)))
>>> fig_, ax_ = metric.plot()
_images/precision-1.png
>>> from torch import rand, randint
>>> # Example plotting multiple values
>>> from torchmetrics.classification import BinaryPrecision
>>> metric = BinaryPrecision()
>>> values = [ ]
>>> for _ in range(10):
...     values.append(metric(rand(10), randint(2,(10,))))
>>> fig_, ax_ = metric.plot(values)
_images/precision-2.png
MulticlassPrecision
class torchmetrics.classification.MulticlassPrecision(num_classes, top_k=1, average='macro', multidim_average='global', ignore_index=None, validate_args=True, **kwargs)[source]

Compute Precision for multiclass tasks.

\[\text{Precision} = \frac{\text{TP}}{\text{TP} + \text{FP}}\]

Where \(\text{TP}\) and \(\text{FP}\) represent the number of true positives and false positives respectively. The metric is only properly defined when \(\text{TP} + \text{FP} \neq 0\). If this case is encountered for any class, the metric for that class will be set to 0 and the overall metric may therefore be affected in turn.

As input to forward and update the metric accepts the following input:

  • preds (Tensor): An int tensor of shape (N, ...) or a float tensor of shape (N, C, ...). If preds is a floating point tensor we apply torch.argmax along the C dimension to automatically convert probabilities/logits into an int tensor.

  • target (Tensor): An int tensor of shape (N, ...).

As output to forward and compute the metric returns the following output:

  • mcp (Tensor): The returned shape depends on the average and multidim_average arguments:

    • If multidim_average is set to global:

      • If average='micro'/'macro'/'weighted', the output will be a scalar tensor

      • If average=None/'none', the shape will be (C,)

    • If multidim_average is set to samplewise:

      • If average='micro'/'macro'/'weighted', the shape will be (N,)

      • If average=None/'none', the shape will be (N, C)

Parameters:
  • num_classes (int) – Integer specifying the number of classes

  • average (Optional[Literal['micro', 'macro', 'weighted', 'none']]) –

    Defines the reduction that is applied over labels. Should be one of the following:

    • micro: Sum statistics over all labels

    • macro: Calculate statistics for each label and average them

    • weighted: calculates statistics for each label and computes weighted average using their support

    • "none" or None: calculates statistic for each label and applies no reduction

  • top_k (int) – Number of highest probability or logit score predictions considered to find the correct label. Only works when preds contain probabilities/logits.

  • multidim_average (Literal['global', 'samplewise']) –

    Defines how additional dimensions ... should be handled. Should be one of the following:

    • global: Additional dimensions are flattened along the batch dimension

    • samplewise: Statistic will be calculated independently for each sample on the N axis. The statistics in this case are calculated over the additional dimensions.

  • ignore_index (Optional[int]) – Specifies a target value that is ignored and does not contribute to the metric calculation

  • validate_args (bool) – bool indicating if input arguments and tensors should be validated for correctness. Set to False for faster computations.

Example (preds is int tensor):
>>> from torch import tensor
>>> from torchmetrics.classification import MulticlassPrecision
>>> target = tensor([2, 1, 0, 0])
>>> preds = tensor([2, 1, 0, 1])
>>> metric = MulticlassPrecision(num_classes=3)
>>> metric(preds, target)
tensor(0.8333)
>>> mcp = MulticlassPrecision(num_classes=3, average=None)
>>> mcp(preds, target)
tensor([1.0000, 0.5000, 1.0000])
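For comparison, a sketch of micro averaging on the same inputs: the statistics are pooled over all classes, so the score is the fraction of correct predictions (3 out of 4):
>>> mcp_micro = MulticlassPrecision(num_classes=3, average='micro')
>>> mcp_micro(preds, target)
tensor(0.7500)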
Example (preds is float tensor):
>>> from torchmetrics.classification import MulticlassPrecision
>>> target = tensor([2, 1, 0, 0])
>>> preds = tensor([[0.16, 0.26, 0.58],
...                 [0.22, 0.61, 0.17],
...                 [0.71, 0.09, 0.20],
...                 [0.05, 0.82, 0.13]])
>>> metric = MulticlassPrecision(num_classes=3)
>>> metric(preds, target)
tensor(0.8333)
>>> mcp = MulticlassPrecision(num_classes=3, average=None)
>>> mcp(preds, target)
tensor([1.0000, 0.5000, 1.0000])
Example (multidim tensors):
>>> from torchmetrics.classification import MulticlassPrecision
>>> target = tensor([[[0, 1], [2, 1], [0, 2]], [[1, 1], [2, 0], [1, 2]]])
>>> preds = tensor([[[0, 2], [2, 0], [0, 1]], [[2, 2], [2, 1], [1, 0]]])
>>> metric = MulticlassPrecision(num_classes=3, multidim_average='samplewise')
>>> metric(preds, target)
tensor([0.3889, 0.2778])
>>> mcp = MulticlassPrecision(num_classes=3, multidim_average='samplewise', average=None)
>>> mcp(preds, target)
tensor([[0.6667, 0.0000, 0.5000],
        [0.0000, 0.5000, 0.3333]])
plot(val=None, ax=None)[source]

Plot a single or multiple values from the metric.

Parameters:
  • val (Union[Tensor, Sequence[Tensor], None]) – Either a single result from calling metric.forward or metric.compute or a list of these results. If no value is provided, will automatically call metric.compute and plot that result.

  • ax (Optional[Axes]) – A matplotlib axis object. If provided, will add plot to that axis

Return type:

Tuple[Figure, Union[Axes, ndarray]]

Returns:

Figure object and Axes object

Raises:

ModuleNotFoundError – If matplotlib is not installed

>>> from torch import randint
>>> # Example plotting a single value per class
>>> from torchmetrics.classification import MulticlassPrecision
>>> metric = MulticlassPrecision(num_classes=3, average=None)
>>> metric.update(randint(3, (20,)), randint(3, (20,)))
>>> fig_, ax_ = metric.plot()
_images/precision-3.png
>>> from torch import randint
>>> # Example plotting multiple values per class
>>> from torchmetrics.classification import MulticlassPrecision
>>> metric = MulticlassPrecision(num_classes=3, average=None)
>>> values = []
>>> for _ in range(20):
...     values.append(metric(randint(3, (20,)), randint(3, (20,))))
>>> fig_, ax_ = metric.plot(values)
_images/precision-4.png
MultilabelPrecision
class torchmetrics.classification.MultilabelPrecision(num_labels, threshold=0.5, average='macro', multidim_average='global', ignore_index=None, validate_args=True, **kwargs)[source]

Compute Precision for multilabel tasks.

\[\text{Precision} = \frac{\text{TP}}{\text{TP} + \text{FP}}\]

Where \(\text{TP}\) and \(\text{FP}\) represent the number of true positives and false positives respectively. The metric is only properly defined when \(\text{TP} + \text{FP} \neq 0\). If this case is encountered for any label, the metric for that label will be set to 0 and the overall metric may therefore be affected in turn.

As input to forward and update the metric accepts the following input:

  • preds (Tensor): An int or float tensor of shape (N, C, ...). If preds is a floating point tensor with values outside the [0,1] range, we consider the input to be logits and will auto apply sigmoid per element. Additionally, we convert to an int tensor with thresholding using the value in threshold.

  • target (Tensor): An int tensor of shape (N, C, ...).

As output to forward and compute the metric returns the following output:

  • mlp (Tensor): The returned shape depends on the average and multidim_average arguments:

    • If multidim_average is set to global:

      • If average='micro'/'macro'/'weighted', the output will be a scalar tensor

      • If average=None/'none', the shape will be (C,)

    • If multidim_average is set to samplewise:

      • If average='micro'/'macro'/'weighted', the shape will be (N,)

      • If average=None/'none', the shape will be (N, C)

Parameters:
  • num_labels (int) – Integer specifying the number of labels

  • threshold (float) – Threshold for transforming probability to binary (0,1) predictions

  • average (Optional[Literal['micro', 'macro', 'weighted', 'none']]) –

    Defines the reduction that is applied over labels. Should be one of the following:

    • micro: Sum statistics over all labels

    • macro: Calculate statistics for each label and average them

    • weighted: calculates statistics for each label and computes weighted average using their support

    • "none" or None: calculates statistic for each label and applies no reduction

  • multidim_average (Literal['global', 'samplewise']) –

    Defines how additional dimensions ... should be handled. Should be one of the following:

    • global: Additional dimensions are flattened along the batch dimension

    • samplewise: Statistic will be calculated independently for each sample on the N axis. The statistics in this case are calculated over the additional dimensions.

  • ignore_index (Optional[int]) – Specifies a target value that is ignored and does not contribute to the metric calculation

  • validate_args (bool) – bool indicating if input arguments and tensors should be validated for correctness. Set to False for faster computations.

Example (preds is int tensor):
>>> from torch import tensor
>>> from torchmetrics.classification import MultilabelPrecision
>>> target = tensor([[0, 1, 0], [1, 0, 1]])
>>> preds = tensor([[0, 0, 1], [1, 0, 1]])
>>> metric = MultilabelPrecision(num_labels=3)
>>> metric(preds, target)
tensor(0.5000)
>>> mlp = MultilabelPrecision(num_labels=3, average=None)
>>> mlp(preds, target)
tensor([1.0000, 0.0000, 0.5000])
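For comparison, a sketch of micro averaging on the same inputs: the statistics are pooled over all labels (2 true positives against 1 false positive) rather than averaged per label:
>>> mlp_micro = MultilabelPrecision(num_labels=3, average='micro')
>>> mlp_micro(preds, target)
tensor(0.6667)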
Example (preds is float tensor):
>>> from torchmetrics.classification import MultilabelPrecision
>>> target = tensor([[0, 1, 0], [1, 0, 1]])
>>> preds = tensor([[0.11, 0.22, 0.84], [0.73, 0.33, 0.92]])
>>> metric = MultilabelPrecision(num_labels=3)
>>> metric(preds, target)
tensor(0.5000)
>>> mlp = MultilabelPrecision(num_labels=3, average=None)
>>> mlp(preds, target)
tensor([1.0000, 0.0000, 0.5000])
Example (multidim tensors):
>>> from torchmetrics.classification import MultilabelPrecision
>>> target = tensor([[[0, 1], [1, 0], [0, 1]], [[1, 1], [0, 0], [1, 0]]])
>>> preds = tensor([[[0.59, 0.91], [0.91, 0.99],  [0.63, 0.04]],
...                 [[0.38, 0.04], [0.86, 0.780], [0.45, 0.37]]])
>>> metric = MultilabelPrecision(num_labels=3, multidim_average='samplewise')
>>> metric(preds, target)
tensor([0.3333, 0.0000])
>>> mlp = MultilabelPrecision(num_labels=3, multidim_average='samplewise', average=None)
>>> mlp(preds, target)
tensor([[0.5000, 0.5000, 0.0000],
        [0.0000, 0.0000, 0.0000]])
plot(val=None, ax=None)[source]

Plot a single or multiple values from the metric.

Parameters:
  • val (Union[Tensor, Sequence[Tensor], None]) – Either a single result from calling metric.forward or metric.compute or a list of these results. If no value is provided, will automatically call metric.compute and plot that result.

  • ax (Optional[Axes]) – A matplotlib axis object. If provided, will add plot to that axis

Return type:

Tuple[Figure, Union[Axes, ndarray]]

Returns:

Figure object and Axes object

Raises:

ModuleNotFoundError – If matplotlib is not installed

>>> from torch import rand, randint
>>> # Example plotting a single value
>>> from torchmetrics.classification import MultilabelPrecision
>>> metric = MultilabelPrecision(num_labels=3)
>>> metric.update(randint(2, (20, 3)), randint(2, (20, 3)))
>>> fig_, ax_ = metric.plot()
_images/precision-5.png
>>> from torch import rand, randint
>>> # Example plotting multiple values
>>> from torchmetrics.classification import MultilabelPrecision
>>> metric = MultilabelPrecision(num_labels=3)
>>> values = [ ]
>>> for _ in range(10):
...     values.append(metric(randint(2, (20, 3)), randint(2, (20, 3))))
>>> fig_, ax_ = metric.plot(values)
_images/precision-6.png

Functional Interface

torchmetrics.functional.precision(preds, target, task, threshold=0.5, num_classes=None, num_labels=None, average='micro', multidim_average='global', top_k=1, ignore_index=None, validate_args=True)[source]

Compute Precision.

\[\text{Precision} = \frac{\text{TP}}{\text{TP} + \text{FP}}\]

Where \(\text{TP}\) and \(\text{FP}\) represent the number of true positives and false positives respectively.

This function is a simple wrapper to get the task specific versions of this metric, which is done by setting the task argument to either 'binary', 'multiclass' or 'multilabel'. See the documentation of binary_precision(), multiclass_precision() and multilabel_precision() for the specific details of each argument's influence and for examples.

Return type:

Tensor

Legacy Example:
>>> from torch import tensor
>>> preds  = tensor([2, 0, 2, 1])
>>> target = tensor([1, 1, 2, 0])
>>> precision(preds, target, task="multiclass", average='macro', num_classes=3)
tensor(0.1667)
>>> precision(preds, target, task="multiclass", average='micro', num_classes=3)
tensor(0.2500)
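For the binary case, a sketch of the dispatch described above; the wrapper and the task-specific function return the same value:
>>> from torchmetrics.functional.classification import binary_precision
>>> target = tensor([0, 1, 0, 1, 0, 1])
>>> preds = tensor([0, 0, 1, 1, 0, 1])
>>> precision(preds, target, task="binary")
tensor(0.6667)
>>> binary_precision(preds, target)
tensor(0.6667)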
binary_precision
torchmetrics.functional.classification.binary_precision(preds, target, threshold=0.5, multidim_average='global', ignore_index=None, validate_args=True)[source]

Compute Precision for binary tasks.

\[\text{Precision} = \frac{\text{TP}}{\text{TP} + \text{FP}}\]

Where \(\text{TP}\) and \(\text{FP}\) represent the number of true positives and false positives respectively.

Accepts the following input tensors:

  • preds (int or float tensor): (N, ...). If preds is a floating point tensor with values outside the [0,1] range, we consider the input to be logits and will auto apply sigmoid per element. Additionally, we convert to an int tensor with thresholding using the value in threshold.

  • target (int tensor): (N, ...)

Parameters:
  • preds (Tensor) – Tensor with predictions

  • target (Tensor) – Tensor with true labels

  • threshold (float) – Threshold for transforming probability to binary {0,1} predictions

  • multidim_average (Literal['global', 'samplewise']) –

    Defines how additional dimensions ... should be handled. Should be one of the following:

    • global: Additional dimensions are flattened along the batch dimension

    • samplewise: Statistic will be calculated independently for each sample on the N axis. The statistics in this case are calculated over the additional dimensions.

  • ignore_index (Optional[int]) – Specifies a target value that is ignored and does not contribute to the metric calculation

  • validate_args (bool) – bool indicating if input arguments and tensors should be validated for correctness. Set to False for faster computations.

Return type:

Tensor

Returns:

If multidim_average is set to global, the metric returns a scalar value. If multidim_average is set to samplewise, the metric returns (N,) vector consisting of a scalar value per sample.

Example (preds is int tensor):
>>> from torch import tensor
>>> from torchmetrics.functional.classification import binary_precision
>>> target = tensor([0, 1, 0, 1, 0, 1])
>>> preds = tensor([0, 0, 1, 1, 0, 1])
>>> binary_precision(preds, target)
tensor(0.6667)
Example (preds is float tensor):
>>> from torchmetrics.functional.classification import binary_precision
>>> target = tensor([0, 1, 0, 1, 0, 1])
>>> preds = tensor([0.11, 0.22, 0.84, 0.73, 0.33, 0.92])
>>> binary_precision(preds, target)
tensor(0.6667)
Example (multidim tensors):
>>> from torchmetrics.functional.classification import binary_precision
>>> target = tensor([[[0, 1], [1, 0], [0, 1]], [[1, 1], [0, 0], [1, 0]]])
>>> preds = tensor([[[0.59, 0.91], [0.91, 0.99], [0.63, 0.04]],
...                 [[0.38, 0.04], [0.86, 0.780], [0.45, 0.37]]])
>>> binary_precision(preds, target, multidim_average='samplewise')
tensor([0.4000, 0.0000])
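For comparison, a sketch of the default multidim_average='global' reduction on the same tensors: the extra dimensions are flattened into the batch dimension before the statistics are computed, giving a single pooled score:
>>> binary_precision(preds, target)
tensor(0.2857)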
multiclass_precision
torchmetrics.functional.classification.multiclass_precision(preds, target, num_classes, average='macro', top_k=1, multidim_average='global', ignore_index=None, validate_args=True)[source]

Compute Precision for multiclass tasks.

\[\text{Precision} = \frac{\text{TP}}{\text{TP} + \text{FP}}\]

Where \(\text{TP}\) and \(\text{FP}\) represent the number of true positives and false positives respectively.

Accepts the following input tensors:

  • preds: (N, ...) (int tensor) or (N, C, ...) (float tensor). If preds is a floating point tensor we apply torch.argmax along the C dimension to automatically convert probabilities/logits into an int tensor.

  • target (int tensor): (N, ...)

Parameters:
  • preds (Tensor) – Tensor with predictions

  • target (Tensor) – Tensor with true labels

  • num_classes (int) – Integer specifying the number of classes

  • average (Optional[Literal['micro', 'macro', 'weighted', 'none']]) –

    Defines the reduction that is applied over labels. Should be one of the following:

    • micro: Sum statistics over all labels

    • macro: Calculate statistics for each label and average them

    • weighted: calculates statistics for each label and computes weighted average using their support

    • "none" or None: calculates statistic for each label and applies no reduction

  • top_k (int) – Number of highest probability or logit score predictions considered to find the correct label. Only works when preds contain probabilities/logits.

  • multidim_average (Literal['global', 'samplewise']) –

    Defines how additional dimensions ... should be handled. Should be one of the following:

    • global: Additional dimensions are flattened along the batch dimension

    • samplewise: Statistic will be calculated independently for each sample on the N axis. The statistics in this case are calculated over the additional dimensions.

  • ignore_index (Optional[int]) – Specifies a target value that is ignored and does not contribute to the metric calculation

  • validate_args (bool) – bool indicating if input arguments and tensors should be validated for correctness. Set to False for faster computations.

Returns:

  • If multidim_average is set to global:

    • If average='micro'/'macro'/'weighted', the output will be a scalar tensor

    • If average=None/'none', the shape will be (C,)

  • If multidim_average is set to samplewise:

    • If average='micro'/'macro'/'weighted', the shape will be (N,)

    • If average=None/'none', the shape will be (N, C)

Return type:

The returned shape depends on the average and multidim_average arguments

Example (preds is int tensor):
>>> from torch import tensor
>>> from torchmetrics.functional.classification import multiclass_precision
>>> target = tensor([2, 1, 0, 0])
>>> preds = tensor([2, 1, 0, 1])
>>> multiclass_precision(preds, target, num_classes=3)
tensor(0.8333)
>>> multiclass_precision(preds, target, num_classes=3, average=None)
tensor([1.0000, 0.5000, 1.0000])
Example (preds is float tensor):
>>> from torchmetrics.functional.classification import multiclass_precision
>>> target = tensor([2, 1, 0, 0])
>>> preds = tensor([[0.16, 0.26, 0.58],
...                 [0.22, 0.61, 0.17],
...                 [0.71, 0.09, 0.20],
...                 [0.05, 0.82, 0.13]])
>>> multiclass_precision(preds, target, num_classes=3)
tensor(0.8333)
>>> multiclass_precision(preds, target, num_classes=3, average=None)
tensor([1.0000, 0.5000, 1.0000])
Example (multidim tensors):
>>> from torchmetrics.functional.classification import multiclass_precision
>>> target = tensor([[[0, 1], [2, 1], [0, 2]], [[1, 1], [2, 0], [1, 2]]])
>>> preds = tensor([[[0, 2], [2, 0], [0, 1]], [[2, 2], [2, 1], [1, 0]]])
>>> multiclass_precision(preds, target, num_classes=3, multidim_average='samplewise')
tensor([0.3889, 0.2778])
>>> multiclass_precision(preds, target, num_classes=3, multidim_average='samplewise', average=None)
tensor([[0.6667, 0.0000, 0.5000],
        [0.0000, 0.5000, 0.3333]])
multilabel_precision
torchmetrics.functional.classification.multilabel_precision(preds, target, num_labels, threshold=0.5, average='macro', multidim_average='global', ignore_index=None, validate_args=True)[source]

Compute Precision for multilabel tasks.

\[\text{Precision} = \frac{\text{TP}}{\text{TP} + \text{FP}}\]

Where \(\text{TP}\) and \(\text{FP}\) represent the number of true positives and false positives respectively.

Accepts the following input tensors:

  • preds (int or float tensor): (N, C, ...). If preds is a floating point tensor with values outside the [0,1] range, we consider the input to be logits and will auto apply sigmoid per element. Additionally, we convert to an int tensor with thresholding using the value in threshold.

  • target (int tensor): (N, C, ...)

Parameters:
  • preds (Tensor) – Tensor with predictions

  • target (Tensor) – Tensor with true labels

  • num_labels (int) – Integer specifying the number of labels

  • threshold (float) – Threshold for transforming probability to binary (0,1) predictions

  • average (Optional[Literal['micro', 'macro', 'weighted', 'none']]) –

    Defines the reduction that is applied over labels. Should be one of the following:

    • micro: Sum statistics over all labels

    • macro: Calculate statistics for each label and average them

    • weighted: calculates statistics for each label and computes weighted average using their support

    • "none" or None: calculates statistic for each label and applies no reduction

  • multidim_average (Literal['global', 'samplewise']) –

    Defines how additional dimensions ... should be handled. Should be one of the following:

    • global: Additional dimensions are flattened along the batch dimension

    • samplewise: Statistic will be calculated independently for each sample on the N axis. The statistics in this case are calculated over the additional dimensions.

  • ignore_index (Optional[int]) – Specifies a target value that is ignored and does not contribute to the metric calculation

  • validate_args (bool) – bool indicating if input arguments and tensors should be validated for correctness. Set to False for faster computations.

Returns:

  • If multidim_average is set to global:

    • If average='micro'/'macro'/'weighted', the output will be a scalar tensor

    • If average=None/'none', the shape will be (C,)

  • If multidim_average is set to samplewise:

    • If average='micro'/'macro'/'weighted', the shape will be (N,)

    • If average=None/'none', the shape will be (N, C)

Return type:

The returned shape depends on the average and multidim_average arguments

Example (preds is int tensor):
>>> from torch import tensor
>>> from torchmetrics.functional.classification import multilabel_precision
>>> target = tensor([[0, 1, 0], [1, 0, 1]])
>>> preds = tensor([[0, 0, 1], [1, 0, 1]])
>>> multilabel_precision(preds, target, num_labels=3)
tensor(0.5000)
>>> multilabel_precision(preds, target, num_labels=3, average=None)
tensor([1.0000, 0.0000, 0.5000])
Example (preds is float tensor):
>>> from torchmetrics.functional.classification import multilabel_precision
>>> target = tensor([[0, 1, 0], [1, 0, 1]])
>>> preds = tensor([[0.11, 0.22, 0.84], [0.73, 0.33, 0.92]])
>>> multilabel_precision(preds, target, num_labels=3)
tensor(0.5000)
>>> multilabel_precision(preds, target, num_labels=3, average=None)
tensor([1.0000, 0.0000, 0.5000])
Example (multidim tensors):
>>> from torchmetrics.functional.classification import multilabel_precision
>>> target = tensor([[[0, 1], [1, 0], [0, 1]], [[1, 1], [0, 0], [1, 0]]])
>>> preds = tensor([[[0.59, 0.91], [0.91, 0.99], [0.63, 0.04]],
...                 [[0.38, 0.04], [0.86, 0.780], [0.45, 0.37]]])
>>> multilabel_precision(preds, target, num_labels=3, multidim_average='samplewise')
tensor([0.3333, 0.0000])
>>> multilabel_precision(preds, target, num_labels=3, multidim_average='samplewise', average=None)
tensor([[0.5000, 0.5000, 0.0000],
        [0.0000, 0.0000, 0.0000]])

Precision At Fixed Recall

Module Interface

class torchmetrics.PrecisionAtFixedRecall(**kwargs)[source]

Compute the highest possible precision value given the minimum recall thresholds provided.

This is done by first calculating the precision-recall curve for different thresholds and then finding the precision for a given recall level.

This function is a simple wrapper to get the task specific versions of this metric, which is done by setting the task argument to either 'binary', 'multiclass' or 'multilabel'. See the documentation of BinaryPrecisionAtFixedRecall, MulticlassPrecisionAtFixedRecall and MultilabelPrecisionAtFixedRecall for the specific details of each argument's influence and for examples.

static __new__(cls, task, min_recall, thresholds=None, num_classes=None, num_labels=None, ignore_index=None, validate_args=True, **kwargs)[source]

Initialize task metric.

Return type:

Metric

BinaryPrecisionAtFixedRecall
class torchmetrics.classification.BinaryPrecisionAtFixedRecall(min_recall, thresholds=None, ignore_index=None, validate_args=True, **kwargs)[source]

Compute the highest possible precision value given the minimum recall thresholds provided.

This is done by first calculating the precision-recall curve for different thresholds and then finding the precision for a given recall level.

As input to forward and update the metric accepts the following input:

  • preds (Tensor): A float tensor of shape (N, ...). Preds should be a tensor containing probabilities or logits for each observation. If preds has values outside [0,1] range we consider the input to be logits and will auto apply sigmoid per element.

  • target (Tensor): An int tensor of shape (N, ...). Target should be a tensor containing ground truth labels, and therefore only contain {0,1} values (except if ignore_index is specified). The value 1 always encodes the positive class.

Note

Additional dimension ... will be flattened into the batch dimension.

As output to forward and compute the metric returns the following output:

  • precision (Tensor): A scalar tensor with the maximum precision for the given recall level

  • threshold (Tensor): A scalar tensor with the corresponding threshold level

Note

The implementation supports calculating the metric in both a non-binned but accurate version and a binned version that is less accurate but more memory efficient. Setting the thresholds argument to None will activate the non-binned version that uses memory of size \(\mathcal{O}(n_{samples})\), whereas setting the thresholds argument to either an integer, list or a 1d tensor will use a binned version that uses memory of size \(\mathcal{O}(n_{thresholds})\) (constant memory).

Parameters:
  • min_recall (float) – float value specifying minimum recall threshold.

  • thresholds (Union[int, List[float], Tensor, None]) –

    Can be one of:

    • If set to None, will use a non-binned approach where thresholds are dynamically calculated from all the data. Most accurate but also most memory consuming approach.

    • If set to an int (larger than 1), will use that number of thresholds linearly spaced from 0 to 1 as bins for the calculation.

    • If set to a list of floats, will use the indicated thresholds in the list as bins for the calculation.

    • If set to a 1d Tensor of floats, will use the indicated thresholds in the tensor as bins for the calculation.

  • ignore_index (Optional[int]) – Specifies a target value that is ignored and does not contribute to the metric calculation

  • validate_args (bool) – bool indicating if input arguments and tensors should be validated for correctness. Set to False for faster computations.

  • kwargs (Any) – Additional keyword arguments, see Advanced metric settings for more info.

Example

>>> from torch import tensor
>>> from torchmetrics.classification import BinaryPrecisionAtFixedRecall
>>> preds = tensor([0, 0.5, 0.7, 0.8])
>>> target = tensor([0, 1, 1, 0])
>>> metric = BinaryPrecisionAtFixedRecall(min_recall=0.5, thresholds=None)
>>> metric(preds, target)
(tensor(0.6667), tensor(0.5000))
>>> metric = BinaryPrecisionAtFixedRecall(min_recall=0.5, thresholds=5)
>>> metric(preds, target)
(tensor(0.6667), tensor(0.5000))
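As a manual sketch of what the returned pair means: at the reported threshold of 0.5, the thresholded predictions reach a recall of 1.0 (which satisfies min_recall=0.5), and their precision is the value reported above:
>>> keep = preds >= 0.5
>>> tp = (keep & target.bool()).sum()
>>> fp = (keep & ~target.bool()).sum()
>>> tp / (tp + fp)
tensor(0.6667)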
plot(val=None, ax=None)[source]

Plot a single or multiple values from the metric.

Parameters:
  • val (Union[Tensor, Sequence[Tensor], None]) – Either a single result from calling metric.forward or metric.compute or a list of these results. If no value is provided, will automatically call metric.compute and plot that result.

  • ax (Optional[Axes]) – A matplotlib axis object. If provided, will add plot to that axis

Return type:

Tuple[Figure, Union[Axes, ndarray]]

Returns:

Figure object and Axes object

Raises:

ModuleNotFoundError – If matplotlib is not installed

>>> from torch import rand, randint
>>> # Example plotting a single value
>>> from torchmetrics.classification import BinaryPrecisionAtFixedRecall
>>> metric = BinaryPrecisionAtFixedRecall(min_recall=0.5)
>>> metric.update(rand(10), randint(2,(10,)))
>>> fig_, ax_ = metric.plot()  # the returned plot only shows the maximum precision value by default
_images/precision_at_fixed_recall-1.png
>>> from torch import rand, randint
>>> # Example plotting multiple values
>>> from torchmetrics.classification import BinaryPrecisionAtFixedRecall
>>> metric = BinaryPrecisionAtFixedRecall(min_recall=0.5)
>>> values = [ ]
>>> for _ in range(10):
...     # we index by 0 such that only the maximum precision value is plotted
...     values.append(metric(rand(10), randint(2,(10,)))[0])
>>> fig_, ax_ = metric.plot(values)
_images/precision_at_fixed_recall-2.png
MulticlassPrecisionAtFixedRecall
class torchmetrics.classification.MulticlassPrecisionAtFixedRecall(num_classes, min_recall, thresholds=None, ignore_index=None, validate_args=True, **kwargs)[source]

Compute the highest possible precision value given the minimum recall thresholds provided.

This is done by first calculating the precision-recall curve for different thresholds and then finding the precision for a given recall level.

As input to forward and update the metric accepts the following input:

  • preds (Tensor): A float tensor of shape (N, C, ...). Preds should be a tensor containing probabilities or logits for each observation. If preds has values outside [0,1] range we consider the input to be logits and will auto apply softmax per sample.

  • target (Tensor): An int tensor of shape (N, ...). Target should be a tensor containing ground truth labels, and therefore only contain values in the [0, n_classes-1] range (except if ignore_index is specified).

Note

Additional dimension ... will be flattened into the batch dimension.

As output to forward and compute the metric returns a tuple of either 2 tensors or 2 lists containing:

  • precision (Tensor): A 1d tensor of size (n_classes, ) with the maximum precision for the given recall level per class

  • threshold (Tensor): A 1d tensor of size (n_classes, ) with the corresponding threshold level per class

Note

The implementation supports calculating the metric in both a non-binned but accurate version and a binned version that is less accurate but more memory efficient. Setting the thresholds argument to None will activate the non-binned version that uses memory of size \(\mathcal{O}(n_{samples})\), whereas setting the thresholds argument to either an integer, list or a 1d tensor will use a binned version that uses memory of size \(\mathcal{O}(n_{thresholds} \times n_{classes})\) (constant memory).

Parameters:
  • num_classes (int) – Integer specifying the number of classes

  • min_recall (float) – float value specifying minimum recall threshold.

  • thresholds (Union[int, List[float], Tensor, None]) –

    Can be one of:

    • If set to None, will use a non-binned approach where thresholds are dynamically calculated from all the data. Most accurate but also most memory consuming approach.

    • If set to an int (larger than 1), will use that number of thresholds linearly spaced from 0 to 1 as bins for the calculation.

    • If set to a list of floats, will use the indicated thresholds in the list as bins for the calculation.

    • If set to a 1d Tensor of floats, will use the indicated thresholds in the tensor as bins for the calculation.

  • ignore_index (Optional[int]) – Specifies a target value that is ignored and does not contribute to the metric calculation

  • validate_args (bool) – bool indicating if input arguments and tensors should be validated for correctness. Set to False for faster computations.

  • kwargs (Any) – Additional keyword arguments, see Advanced metric settings for more info.

Example

>>> from torch import tensor
>>> from torchmetrics.classification import MulticlassPrecisionAtFixedRecall
>>> preds = tensor([[0.75, 0.05, 0.05, 0.05, 0.05],
...                 [0.05, 0.75, 0.05, 0.05, 0.05],
...                 [0.05, 0.05, 0.75, 0.05, 0.05],
...                 [0.05, 0.05, 0.05, 0.75, 0.05]])
>>> target = tensor([0, 1, 3, 2])
>>> metric = MulticlassPrecisionAtFixedRecall(num_classes=5, min_recall=0.5, thresholds=None)
>>> metric(preds, target)  
(tensor([1.0000, 1.0000, 0.2500, 0.2500, 0.0000]),
 tensor([7.5000e-01, 7.5000e-01, 5.0000e-02, 5.0000e-02, 1.0000e+06]))
>>> mcrafp = MulticlassPrecisionAtFixedRecall(num_classes=5, min_recall=0.5, thresholds=5)
>>> mcrafp(preds, target)  
(tensor([1.0000, 1.0000, 0.2500, 0.2500, 0.0000]),
 tensor([7.5000e-01, 7.5000e-01, 0.0000e+00, 0.0000e+00, 1.0000e+06]))
plot(val=None, ax=None)[source]

Plot a single or multiple values from the metric.

Parameters:
  • val (Union[Tensor, Sequence[Tensor], None]) – Either a single result from calling metric.forward or metric.compute or a list of these results. If no value is provided, will automatically call metric.compute and plot that result.

  • ax (Optional[Axes]) – A matplotlib axis object. If provided, will add plot to that axis

Return type:

Tuple[Figure, Union[Axes, ndarray]]

Returns:

Figure object and Axes object

Raises:

ModuleNotFoundError – If matplotlib is not installed

>>> from torch import rand, randint
>>> # Example plotting a single value per class
>>> from torchmetrics.classification import MulticlassPrecisionAtFixedRecall
>>> metric = MulticlassPrecisionAtFixedRecall(num_classes=3, min_recall=0.5)
>>> metric.update(rand(20, 3).softmax(dim=-1), randint(3, (20,)))
>>> fig_, ax_ = metric.plot()  # the returned plot only shows the maximum precision value by default
_images/precision_at_fixed_recall-3.png
>>> from torch import rand, randint
>>> # Example plotting multiple values per class
>>> from torchmetrics.classification import MulticlassPrecisionAtFixedRecall
>>> metric = MulticlassPrecisionAtFixedRecall(num_classes=3, min_recall=0.5)
>>> values = []
>>> for _ in range(20):
...     # we index by 0 such that only the maximum precision value is plotted
...     values.append(metric(rand(20, 3).softmax(dim=-1), randint(3, (20,)))[0])
>>> fig_, ax_ = metric.plot(values)
_images/precision_at_fixed_recall-4.png
MultilabelPrecisionAtFixedRecall
class torchmetrics.classification.MultilabelPrecisionAtFixedRecall(num_labels, min_recall, thresholds=None, ignore_index=None, validate_args=True, **kwargs)[source]

Compute the highest possible precision value given the minimum recall thresholds provided.

This is done by first calculating the precision-recall curve for different thresholds and then finding the precision for a given recall level.

As input to forward and update the metric accepts the following input:

  • preds (Tensor): A float tensor of shape (N, C, ...). Preds should be a tensor containing probabilities or logits for each observation. If preds has values outside [0,1] range we consider the input to be logits and will auto apply sigmoid per element.

  • target (Tensor): An int tensor of shape (N, ...). Target should be a tensor containing ground truth labels, and therefore only contain {0,1} values (except if ignore_index is specified). The value 1 always encodes the positive class.

Note

Additional dimension ... will be flattened into the batch dimension.

As output to forward and compute the metric returns a tuple of either 2 tensors or 2 lists containing:

  • precision (Tensor): A 1d tensor of size (n_labels, ) with the maximum precision for the given recall level per label

  • threshold (Tensor): A 1d tensor of size (n_labels, ) with the corresponding threshold level per label

Note

The implementation supports calculating the metric in both a non-binned but accurate version and a binned version that is less accurate but more memory efficient. Setting the thresholds argument to None will activate the non-binned version that uses memory of size \(\mathcal{O}(n_{samples})\), whereas setting the thresholds argument to either an integer, list or a 1d tensor will use a binned version that uses memory of size \(\mathcal{O}(n_{thresholds} \times n_{labels})\) (constant memory).

Parameters:
  • num_labels (int) – Integer specifying the number of labels

  • min_recall (float) – float value specifying minimum recall threshold.

  • thresholds (Union[int, List[float], Tensor, None]) –

    Can be one of:

    • If set to None, will use a non-binned approach where thresholds are dynamically calculated from all the data. Most accurate but also most memory consuming approach.

    • If set to an int (larger than 1), will use that number of thresholds linearly spaced from 0 to 1 as bins for the calculation.

    • If set to an list of floats, will use the indicated thresholds in the list as bins for the calculation

    • If set to an 1d Tensor of floats, will use the indicated thresholds in the tensor as bins for the calculation.

  • ignore_index (Optional[int]) – Specifies a target value that is ignored and does not contribute to the metric calculation

  • validate_args (bool) – bool indicating if input arguments and tensors should be validated for correctness. Set to False for faster computations.

  • kwargs (Any) – Additional keyword arguments, see Advanced metric settings for more info.

Example

>>> from torch import tensor
>>> from torchmetrics.classification import MultilabelPrecisionAtFixedRecall
>>> preds = tensor([[0.75, 0.05, 0.35],
...                 [0.45, 0.75, 0.05],
...                 [0.05, 0.55, 0.75],
...                 [0.05, 0.65, 0.05]])
>>> target = tensor([[1, 0, 1],
...                  [0, 0, 0],
...                  [0, 1, 1],
...                  [1, 1, 1]])
>>> metric = MultilabelPrecisionAtFixedRecall(num_labels=3, min_recall=0.5, thresholds=None)
>>> metric(preds, target)
(tensor([1.0000, 0.6667, 1.0000]), tensor([0.7500, 0.5500, 0.3500]))
>>> mlrafp = MultilabelPrecisionAtFixedRecall(num_labels=3, min_recall=0.5, thresholds=5)
>>> mlrafp(preds, target)
(tensor([1.0000, 0.6667, 1.0000]), tensor([0.7500, 0.5000, 0.2500]))
plot(val=None, ax=None)[source]

Plot a single or multiple values from the metric.

Parameters:
  • val (Union[Tensor, Sequence[Tensor], None]) – Either a single result from calling metric.forward or metric.compute or a list of these results. If no value is provided, will automatically call metric.compute and plot that result.

  • ax (Optional[Axes]) – A matplotlib axis object. If provided, will add plot to that axis

Return type:

Tuple[Figure, Union[Axes, ndarray]]

Returns:

Figure object and Axes object

Raises:

ModuleNotFoundError – If matplotlib is not installed

>>> from torch import rand, randint
>>> # Example plotting a single value
>>> from torchmetrics.classification import MultilabelPrecisionAtFixedRecall
>>> metric = MultilabelPrecisionAtFixedRecall(num_labels=3, min_recall=0.5)
>>> metric.update(rand(20, 3), randint(2, (20, 3)))
>>> fig_, ax_ = metric.plot()  # the returned plot only shows the maximum precision value by default
_images/precision_at_fixed_recall-5.png
>>> from torch import rand, randint
>>> # Example plotting multiple values
>>> from torchmetrics.classification import MultilabelPrecisionAtFixedRecall
>>> metric = MultilabelPrecisionAtFixedRecall(num_labels=3, min_recall=0.5)
>>> values = [ ]
>>> for _ in range(10):
...     # we index by 0 such that only the maximum precision value is plotted
...     values.append(metric(rand(20, 3), randint(2, (20, 3)))[0])
>>> fig_, ax_ = metric.plot(values)
_images/precision_at_fixed_recall-6.png

Functional Interface

binary_precision_at_fixed_recall
torchmetrics.functional.classification.binary_precision_at_fixed_recall(preds, target, min_recall, thresholds=None, ignore_index=None, validate_args=True)[source]

Compute the highest possible precision value given the minimum recall thresholds provided for binary tasks.

This is done by first calculating the precision-recall curve for different thresholds and then finding the precision for a given recall level.

Accepts the following input tensors:

  • preds (float tensor): (N, ...). Preds should be a tensor containing probabilities or logits for each observation. If preds has values outside the [0,1] range, we consider the input to be logits and will auto-apply sigmoid per element.

  • target (int tensor): (N, ...). Target should be a tensor containing ground truth labels, and therefore only contain {0,1} values (except if ignore_index is specified). The value 1 always encodes the positive class.

Additional dimension ... will be flattened into the batch dimension.

The implementation both supports calculating the metric in a non-binned but accurate version and a binned version that is less accurate but more memory efficient. Setting the thresholds argument to None will activate the non-binned version that uses memory of size \(\mathcal{O}(n_{samples})\) whereas setting the thresholds argument to either an integer, list or a 1d tensor will use a binned version that uses memory of size \(\mathcal{O}(n_{thresholds})\) (constant memory).

Parameters:
  • preds (Tensor) – Tensor with predictions

  • target (Tensor) – Tensor with true labels

  • min_recall (float) – float value specifying minimum recall threshold.

  • thresholds (Union[int, List[float], Tensor, None]) –

    Can be one of:

    • If set to None, will use a non-binned approach where thresholds are dynamically calculated from all the data. Most accurate but also most memory consuming approach.

    • If set to an int (larger than 1), will use that number of thresholds linearly spaced from 0 to 1 as bins for the calculation.

    • If set to a list of floats, will use the indicated thresholds in the list as bins for the calculation.

    • If set to a 1d tensor of floats, will use the indicated thresholds in the tensor as bins for the calculation.

  • ignore_index (Optional[int]) – Specifies a target value that is ignored and does not contribute to the metric calculation

  • validate_args (bool) – bool indicating if input arguments and tensors should be validated for correctness. Set to False for faster computations.

Returns:

a tuple of 2 tensors containing:

  • precision: a scalar tensor with the maximum precision for the given recall level

  • threshold: a scalar tensor with the corresponding threshold level

Return type:

(tuple)

Example

>>> from torchmetrics.functional.classification import binary_precision_at_fixed_recall
>>> preds = torch.tensor([0, 0.5, 0.7, 0.8])
>>> target = torch.tensor([0, 1, 1, 0])
>>> binary_precision_at_fixed_recall(preds, target, min_recall=0.5, thresholds=None)
(tensor(0.6667), tensor(0.5000))
>>> binary_precision_at_fixed_recall(preds, target, min_recall=0.5, thresholds=5)
(tensor(0.6667), tensor(0.5000))
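
Conceptually, the value corresponds to scanning the binary precision-recall curve and keeping the highest precision among the operating points whose recall still meets min_recall. A rough sketch of that idea (not the library's actual implementation), reusing the preds and target from the example above:

import torch
from torchmetrics.functional.classification import binary_precision_recall_curve

preds = torch.tensor([0, 0.5, 0.7, 0.8])
target = torch.tensor([0, 1, 1, 0])

precision, recall, thresholds = binary_precision_recall_curve(preds, target, thresholds=None)
# the final (precision=1, recall=0) point has no associated threshold, so drop it
precision, recall = precision[:-1], recall[:-1]
mask = recall >= 0.5                      # operating points that satisfy min_recall
best = precision[mask].argmax()
best_precision, best_threshold = precision[mask][best], thresholds[mask][best]
# best_precision -> tensor(0.6667), best_threshold -> tensor(0.5000), matching the output above
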
multiclass_precision_at_fixed_recall
torchmetrics.functional.classification.multiclass_precision_at_fixed_recall(preds, target, num_classes, min_recall, thresholds=None, ignore_index=None, validate_args=True)[source]

Compute the highest possible precision value given the minimum recall thresholds provided for multiclass tasks.

This is done by first calculating the precision-recall curve for different thresholds and then finding the precision for a given recall level.

Accepts the following input tensors:

  • preds (float tensor): (N, C, ...). Preds should be a tensor containing probabilities or logits for each observation. If preds has values outside the [0,1] range, we consider the input to be logits and will auto-apply softmax per sample.

  • target (int tensor): (N, ...). Target should be a tensor containing ground truth labels, and therefore only contain values in the [0, n_classes-1] range (except if ignore_index is specified).

Additional dimension ... will be flattened into the batch dimension.

The implementation both supports calculating the metric in a non-binned but accurate version and a binned version that is less accurate but more memory efficient. Setting the thresholds argument to None will activate the non-binned version that uses memory of size \(\mathcal{O}(n_{samples})\) whereas setting the thresholds argument to either an integer, list or a 1d tensor will use a binned version that uses memory of size \(\mathcal{O}(n_{thresholds} \times n_{classes})\) (constant memory).

Parameters:
  • preds (Tensor) – Tensor with predictions

  • target (Tensor) – Tensor with true labels

  • num_classes (int) – Integer specifying the number of classes

  • min_recall (float) – float value specifying minimum recall threshold.

  • thresholds (Union[int, List[float], Tensor, None]) –

    Can be one of:

    • If set to None, will use a non-binned approach where thresholds are dynamically calculated from all the data. Most accurate but also most memory consuming approach.

    • If set to an int (larger than 1), will use that number of thresholds linearly spaced from 0 to 1 as bins for the calculation.

    • If set to a list of floats, will use the indicated thresholds in the list as bins for the calculation.

    • If set to a 1d tensor of floats, will use the indicated thresholds in the tensor as bins for the calculation.

  • ignore_index (Optional[int]) – Specifies a target value that is ignored and does not contribute to the metric calculation

  • validate_args (bool) – bool indicating if input arguments and tensors should be validated for correctness. Set to False for faster computations.

Returns:

a tuple of either 2 tensors or 2 lists containing

  • precision: a 1d tensor of size (n_classes, ) with the maximum precision for the given recall level per class

  • thresholds: a 1d tensor of size (n_classes, ) with the corresponding threshold level per class

Return type:

(tuple)

Example

>>> from torchmetrics.functional.classification import multiclass_precision_at_fixed_recall
>>> preds = torch.tensor([[0.75, 0.05, 0.05, 0.05, 0.05],
...                       [0.05, 0.75, 0.05, 0.05, 0.05],
...                       [0.05, 0.05, 0.75, 0.05, 0.05],
...                       [0.05, 0.05, 0.05, 0.75, 0.05]])
>>> target = torch.tensor([0, 1, 3, 2])
>>> multiclass_precision_at_fixed_recall(  
...     preds, target, num_classes=5, min_recall=0.5, thresholds=None)
(tensor([1.0000, 1.0000, 0.2500, 0.2500, 0.0000]),
 tensor([7.5000e-01, 7.5000e-01, 5.0000e-02, 5.0000e-02, 1.0000e+06]))
>>> multiclass_precision_at_fixed_recall(  
...     preds, target, num_classes=5, min_recall=0.5, thresholds=5)
(tensor([1.0000, 1.0000, 0.2500, 0.2500, 0.0000]),
 tensor([7.5000e-01, 7.5000e-01, 0.0000e+00, 0.0000e+00, 1.0000e+06]))
multilabel_precision_at_fixed_recall
torchmetrics.functional.classification.multilabel_precision_at_fixed_recall(preds, target, num_labels, min_recall, thresholds=None, ignore_index=None, validate_args=True)[source]

Compute the highest possible precision value given the minimum recall thresholds provided for multilabel tasks.

This is done by first calculating the precision-recall curve for different thresholds and then finding the precision for a given recall level.

Accepts the following input tensors:

  • preds (float tensor): (N, C, ...). Preds should be a tensor containing probabilities or logits for each observation. If preds has values outside the [0,1] range, we consider the input to be logits and will auto-apply sigmoid per element.

  • target (int tensor): (N, C, ...). Target should be a tensor containing ground truth labels, and therefore only contain {0,1} values (except if ignore_index is specified).

Additional dimension ... will be flattened into the batch dimension.

The implementation both supports calculating the metric in a non-binned but accurate version and a binned version that is less accurate but more memory efficient. Setting the thresholds argument to None will activate the non-binned version that uses memory of size \(\mathcal{O}(n_{samples})\) whereas setting the thresholds argument to either an integer, list or a 1d tensor will use a binned version that uses memory of size \(\mathcal{O}(n_{thresholds} \times n_{labels})\) (constant memory).

Parameters:
  • preds (Tensor) – Tensor with predictions

  • target (Tensor) – Tensor with true labels

  • num_labels (int) – Integer specifying the number of labels

  • min_recall (float) – float value specifying minimum recall threshold.

  • thresholds (Union[int, List[float], Tensor, None]) –

    Can be one of:

    • If set to None, will use a non-binned approach where thresholds are dynamically calculated from all the data. Most accurate but also most memory consuming approach.

    • If set to an int (larger than 1), will use that number of thresholds linearly spaced from 0 to 1 as bins for the calculation.

    • If set to a list of floats, will use the indicated thresholds in the list as bins for the calculation.

    • If set to a 1d tensor of floats, will use the indicated thresholds in the tensor as bins for the calculation.

  • ignore_index (Optional[int]) – Specifies a target value that is ignored and does not contribute to the metric calculation

  • validate_args (bool) – bool indicating if input arguments and tensors should be validated for correctness. Set to False for faster computations.

Returns:

a tuple of either 2 tensors or 2 lists containing

  • precision: a 1d tensor of size (n_labels, ) with the maximum precision for the given recall level per label

  • thresholds: a 1d tensor of size (n_labels, ) with the corresponding threshold level per label

Return type:

(tuple)

Example

>>> from torchmetrics.functional.classification import multilabel_precision_at_fixed_recall
>>> preds = torch.tensor([[0.75, 0.05, 0.35],
...                       [0.45, 0.75, 0.05],
...                       [0.05, 0.55, 0.75],
...                       [0.05, 0.65, 0.05]])
>>> target = torch.tensor([[1, 0, 1],
...                        [0, 0, 0],
...                        [0, 1, 1],
...                        [1, 1, 1]])
>>> multilabel_precision_at_fixed_recall(preds, target, num_labels=3, min_recall=0.5, thresholds=None)
(tensor([1.0000, 0.6667, 1.0000]), tensor([0.7500, 0.5500, 0.3500]))
>>> multilabel_precision_at_fixed_recall(preds, target, num_labels=3, min_recall=0.5, thresholds=5)
(tensor([1.0000, 0.6667, 1.0000]), tensor([0.7500, 0.5000, 0.2500]))

Precision Recall Curve

Module Interface

class torchmetrics.PrecisionRecallCurve(**kwargs)[source]

Compute the precision-recall curve.

The curve consists of multiple pairs of precision and recall values evaluated at different thresholds, such that the tradeoff between the two values can be seen.

This class is a simple wrapper that dispatches to the task-specific version of this metric, selected by setting the task argument to 'binary', 'multiclass' or 'multilabel'. See the documentation of BinaryPrecisionRecallCurve, MulticlassPrecisionRecallCurve and MultilabelPrecisionRecallCurve for the specific details of each argument's influence and for examples.

Legacy Example:
>>> pred = torch.tensor([0, 0.1, 0.8, 0.4])
>>> target = torch.tensor([0, 1, 1, 0])
>>> pr_curve = PrecisionRecallCurve(task="binary")
>>> precision, recall, thresholds = pr_curve(pred, target)
>>> precision
tensor([0.5000, 0.6667, 0.5000, 1.0000, 1.0000])
>>> recall
tensor([1.0000, 1.0000, 0.5000, 0.5000, 0.0000])
>>> thresholds
tensor([0.0000, 0.1000, 0.4000, 0.8000])
>>> pred = torch.tensor([[0.75, 0.05, 0.05, 0.05, 0.05],
...                      [0.05, 0.75, 0.05, 0.05, 0.05],
...                      [0.05, 0.05, 0.75, 0.05, 0.05],
...                      [0.05, 0.05, 0.05, 0.75, 0.05]])
>>> target = torch.tensor([0, 1, 3, 2])
>>> pr_curve = PrecisionRecallCurve(task="multiclass", num_classes=5)
>>> precision, recall, thresholds = pr_curve(pred, target)
>>> precision
[tensor([0.2500, 1.0000, 1.0000]), tensor([0.2500, 1.0000, 1.0000]), tensor([0.2500, 0.0000, 1.0000]),
 tensor([0.2500, 0.0000, 1.0000]), tensor([0., 1.])]
>>> recall
[tensor([1., 1., 0.]), tensor([1., 1., 0.]), tensor([1., 0., 0.]), tensor([1., 0., 0.]), tensor([nan, 0.])]
>>> thresholds
[tensor([0.0500, 0.7500]), tensor([0.0500, 0.7500]), tensor([0.0500, 0.7500]), tensor([0.0500, 0.7500]),
 tensor(0.0500)]
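
The wrapper dispatches in the same way for multilabel inputs; a minimal sketch with randomly generated, purely illustrative data:

import torch
from torchmetrics import PrecisionRecallCurve

preds = torch.rand(10, 3)            # per-label probabilities
target = torch.randint(2, (10, 3))   # per-label ground truth

pr_curve = PrecisionRecallCurve(task="multilabel", num_labels=3)
# with thresholds=None (the default) the outputs are lists with one tensor per label
precision, recall, thresholds = pr_curve(preds, target)
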
static __new__(cls, task, thresholds=None, num_classes=None, num_labels=None, ignore_index=None, validate_args=True, **kwargs)[source]

Initialize task metric.

Return type:

Metric

BinaryPrecisionRecallCurve
class torchmetrics.classification.BinaryPrecisionRecallCurve(thresholds=None, ignore_index=None, validate_args=True, **kwargs)[source]

Compute the precision-recall curve for binary tasks.

The curve consists of multiple pairs of precision and recall values evaluated at different thresholds, such that the tradeoff between the two values can be seen.

As input to forward and update the metric accepts the following input:

  • preds (Tensor): A float tensor of shape (N, ...). Preds should be a tensor containing probabilities or logits for each observation. If preds has values outside the [0,1] range, we consider the input to be logits and will auto-apply sigmoid per element.

  • target (Tensor): An int tensor of shape (N, ...). Target should be a tensor containing ground truth labels, and therefore only contain {0,1} values (except if ignore_index is specified). The value 1 always encodes the positive class.

Note

Additional dimension ... will be flattened into the batch dimension.

As output to forward and compute the metric returns the following output:

  • precision (Tensor): A 1d tensor of size (n_thresholds+1, ) with precision values

  • recall (Tensor): A 1d tensor of size (n_thresholds+1, ) with recall values

  • thresholds (Tensor): A 1d tensor of size (n_thresholds, ) with increasing threshold values

Note

The implementation both supports calculating the metric in a non-binned but accurate version and a binned version that is less accurate but more memory efficient. Setting the thresholds argument to None will activate the non-binned version that uses memory of size \(\mathcal{O}(n_{samples})\) whereas setting the thresholds argument to either an integer, list or a 1d tensor will use a binned version that uses memory of size \(\mathcal{O}(n_{thresholds})\) (constant memory).
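
The constant memory of the binned mode comes from accumulating confusion counts per threshold bin instead of storing every score. The snippet below is a rough, hypothetical sketch of that idea only (it is not the library's implementation, and zero-division edge cases are glossed over); the helper name binned_pr_counts is made up for illustration:

import torch

def binned_pr_counts(preds, target, n_bins=5):
    """Sketch: accumulate TP/FP/FN at n_bins fixed thresholds (constant memory)."""
    thresholds = torch.linspace(0, 1, n_bins)
    # predictions at or above each threshold count as positive: shape (n_bins, n_samples)
    pred_pos = preds.unsqueeze(0) >= thresholds.unsqueeze(1)
    tp = (pred_pos & (target == 1)).sum(dim=1)
    fp = (pred_pos & (target == 0)).sum(dim=1)
    fn = (~pred_pos & (target == 1)).sum(dim=1)
    # memory is O(n_bins), independent of how many samples have been seen
    precision = tp / (tp + fp).clamp(min=1)
    recall = tp / (tp + fn).clamp(min=1)
    return precision, recall, thresholds

precision, recall, thresholds = binned_pr_counts(torch.rand(1000), torch.randint(2, (1000,)))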

Parameters:
  • thresholds (Union[int, List[float], Tensor, None]) –

    Can be one of:

    • If set to None, will use a non-binned approach where thresholds are dynamically calculated from all the data. Most accurate but also most memory consuming approach.

    • If set to an int (larger than 1), will use that number of thresholds linearly spaced from 0 to 1 as bins for the calculation.

    • If set to a list of floats, will use the indicated thresholds in the list as bins for the calculation.

    • If set to a 1d tensor of floats, will use the indicated thresholds in the tensor as bins for the calculation.

  • ignore_index (Optional[int]) – Specifies a target value that is ignored and does not contribute to the metric calculation

  • validate_args (bool) – bool indicating if input arguments and tensors should be validated for correctness. Set to False for faster computations.

  • kwargs (Any) – Additional keyword arguments, see Advanced metric settings for more info.

Example

>>> from torchmetrics.classification import BinaryPrecisionRecallCurve
>>> preds = torch.tensor([0, 0.5, 0.7, 0.8])
>>> target = torch.tensor([0, 1, 1, 0])
>>> bprc = BinaryPrecisionRecallCurve(thresholds=None)
>>> bprc(preds, target)  
(tensor([0.5000, 0.6667, 0.5000, 0.0000, 1.0000]),
 tensor([1.0000, 1.0000, 0.5000, 0.0000, 0.0000]),
 tensor([0.0000, 0.5000, 0.7000, 0.8000]))
>>> bprc = BinaryPrecisionRecallCurve(thresholds=5)
>>> bprc(preds, target)  
(tensor([0.5000, 0.6667, 0.6667, 0.0000, 0.0000, 1.0000]),
 tensor([1., 1., 1., 0., 0., 0.]),
 tensor([0.0000, 0.2500, 0.5000, 0.7500, 1.0000]))
plot(curve=None, score=None, ax=None)[source]

Plot a single curve from the metric.

Parameters:
  • curve (Optional[Tuple[Tensor, Tensor, Tensor]]) – the output of either metric.compute or metric.forward. If no value is provided, will automatically call metric.compute and plot that result.

  • score (Union[Tensor, bool, None]) – Provide an area-under-the-curve score to be displayed on the plot. If True and no curve is provided, will automatically compute the score.

  • ax (Optional[Axes]) – A matplotlib axes object. If provided, will add the plot to that axes.

Return type:

Tuple[Figure, Union[Axes, ndarray]]

Returns:

Figure and Axes object

Raises:

ModuleNotFoundError – If matplotlib is not installed

>>> from torch import rand, randint
>>> from torchmetrics.classification import BinaryPrecisionRecallCurve
>>> preds = rand(20)
>>> target = randint(2, (20,))
>>> metric = BinaryPrecisionRecallCurve()
>>> metric.update(preds, target)
>>> fig_, ax_ = metric.plot(score=True)
_images/precision_recall_curve-1.png
MulticlassPrecisionRecallCurve
class torchmetrics.classification.MulticlassPrecisionRecallCurve(num_classes, thresholds=None, ignore_index=None, validate_args=True, **kwargs)[source]

Compute the precision-recall curve for multiclass tasks.

The curve consists of multiple pairs of precision and recall values evaluated at different thresholds, such that the tradeoff between the two values can be seen.

For multiclass tasks the metric is calculated by iteratively treating each class as the positive class and all other classes as the negative class, which is referred to as the one-vs-rest approach. One-vs-one is currently not supported by this metric.
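
The one-vs-rest treatment can be pictured as computing one binary precision-recall curve per class, using that class's scores as the positive score and a binarized target. A rough sketch of the idea (not the library's actual implementation), with randomly generated data for illustration:

import torch
from torchmetrics.functional.classification import binary_precision_recall_curve

preds = torch.randn(20, 5).softmax(dim=-1)   # (N, C) class probabilities
target = torch.randint(5, (20,))

curves = []
for c in range(5):
    # treat class c as the positive class and every other class as negative
    curves.append(binary_precision_recall_curve(preds[:, c], (target == c).long()))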

As input to forward and update the metric accepts the following input:

  • preds (Tensor): A float tensor of shape (N, C, ...). Preds should be a tensor containing probabilities or logits for each observation. If preds has values outside the [0,1] range, we consider the input to be logits and will auto-apply softmax per sample.

  • target (Tensor): An int tensor of shape (N, ...). Target should be a tensor containing ground truth labels, and therefore only contain values in the [0, n_classes-1] range (except if ignore_index is specified).

Note

Additional dimension ... will be flattened into the batch dimension.

As output to forward and compute the metric returns the following output:

  • precision (Tensor or List): if thresholds=None a list for each class is returned with a 1d tensor of size (n_thresholds+1, ) with precision values (length may differ between classes). If thresholds is set to something else, then a single 2d tensor of size (n_classes, n_thresholds+1) with precision values is returned.

  • recall (Tensor or List): if thresholds=None a list for each class is returned with a 1d tensor of size (n_thresholds+1, ) with recall values (length may differ between classes). If thresholds is set to something else, then a single 2d tensor of size (n_classes, n_thresholds+1) with recall values is returned.

  • thresholds (Tensor or List): if thresholds=None a list for each class is returned with a 1d tensor of size (n_thresholds, ) with increasing threshold values (length may differ between classes). If thresholds is set to something else, then a single 1d tensor of size (n_thresholds, ) is returned with shared threshold values for all classes.

Note

The implementation both supports calculating the metric in a non-binned but accurate version and a binned version that is less accurate but more memory efficient. Setting the thresholds argument to None will activate the non-binned version that uses memory of size \(\mathcal{O}(n_{samples})\) whereas setting the thresholds argument to either an integer, list or a 1d tensor will use a binned version that uses memory of size \(\mathcal{O}(n_{thresholds} \times n_{classes})\) (constant memory).

Parameters:
  • num_classes (int) – Integer specifying the number of classes

  • thresholds (Union[int, List[float], Tensor, None]) –

    Can be one of:

    • If set to None, will use a non-binned approach where thresholds are dynamically calculated from all the data. Most accurate but also most memory consuming approach.

    • If set to an int (larger than 1), will use that number of thresholds linearly spaced from 0 to 1 as bins for the calculation.

    • If set to a list of floats, will use the indicated thresholds in the list as bins for the calculation.

    • If set to a 1d tensor of floats, will use the indicated thresholds in the tensor as bins for the calculation.

  • ignore_index (Optional[int]) – Specifies a target value that is ignored and does not contribute to the metric calculation

  • validate_args (bool) – bool indicating if input arguments and tensors should be validated for correctness. Set to False for faster computations.

  • kwargs (Any) – Additional keyword arguments, see Advanced metric settings for more info.

Example

>>> from torchmetrics.classification import MulticlassPrecisionRecallCurve
>>> preds = torch.tensor([[0.75, 0.05, 0.05, 0.05, 0.05],
...                       [0.05, 0.75, 0.05, 0.05, 0.05],
...                       [0.05, 0.05, 0.75, 0.05, 0.05],
...                       [0.05, 0.05, 0.05, 0.75, 0.05]])
>>> target = torch.tensor([0, 1, 3, 2])
>>> mcprc = MulticlassPrecisionRecallCurve(num_classes=5, thresholds=None)
>>> precision, recall, thresholds = mcprc(preds, target)
>>> precision  
[tensor([0.2500, 1.0000, 1.0000]), tensor([0.2500, 1.0000, 1.0000]), tensor([0.2500, 0.0000, 1.0000]),
 tensor([0.2500, 0.0000, 1.0000]), tensor([0., 1.])]
>>> recall
[tensor([1., 1., 0.]), tensor([1., 1., 0.]), tensor([1., 0., 0.]), tensor([1., 0., 0.]), tensor([nan, 0.])]
>>> thresholds
[tensor([0.0500, 0.7500]), tensor([0.0500, 0.7500]), tensor([0.0500, 0.7500]), tensor([0.0500, 0.7500]),
 tensor(0.0500)]
>>> mcprc = MulticlassPrecisionRecallCurve(num_classes=5, thresholds=5)
>>> mcprc(preds, target)  
(tensor([[0.2500, 1.0000, 1.0000, 1.0000, 0.0000, 1.0000],
         [0.2500, 1.0000, 1.0000, 1.0000, 0.0000, 1.0000],
         [0.2500, 0.0000, 0.0000, 0.0000, 0.0000, 1.0000],
         [0.2500, 0.0000, 0.0000, 0.0000, 0.0000, 1.0000],
         [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 1.0000]]),
 tensor([[1., 1., 1., 1., 0., 0.],
         [1., 1., 1., 1., 0., 0.],
         [1., 0., 0., 0., 0., 0.],
         [1., 0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0., 0.]]),
 tensor([0.0000, 0.2500, 0.5000, 0.7500, 1.0000]))
plot(curve=None, score=None, ax=None)[source]

Plot a single or multiple values from the metric.

Parameters:
  • curve (Union[Tuple[Tensor, Tensor, Tensor], Tuple[List[Tensor], List[Tensor], List[Tensor]], None]) – the output of either metric.compute or metric.forward. If no value is provided, will automatically call metric.compute and plot that result.

  • score (Union[Tensor, bool, None]) – Provide an area-under-the-curve score to be displayed on the plot. If True and no curve is provided, will automatically compute the score.

  • ax (Optional[Axes]) – A matplotlib axes object. If provided, will add the plot to that axes.

Return type:

Tuple[Figure, Union[Axes, ndarray]]

Returns:

Figure and Axes object

Raises:

ModuleNotFoundError – If matplotlib is not installed

>>> from torch import randn, randint
>>> from torchmetrics.classification import MulticlassPrecisionRecallCurve
>>> preds = randn(20, 3).softmax(dim=-1)
>>> target = randint(3, (20,))
>>> metric = MulticlassPrecisionRecallCurve(num_classes=3)
>>> metric.update(preds, target)
>>> fig_, ax_ = metric.plot(score=True)
_images/precision_recall_curve-2.png
MultilabelPrecisionRecallCurve
class torchmetrics.classification.MultilabelPrecisionRecallCurve(num_labels, thresholds=None, ignore_index=None, validate_args=True, **kwargs)[source]

Compute the precision-recall curve for multilabel tasks.

The curve consists of multiple pairs of precision and recall values evaluated at different thresholds, such that the tradeoff between the two values can be seen.

As input to forward and update the metric accepts the following input:

  • preds (Tensor): A float tensor of shape (N, C, ...). Preds should be a tensor containing probabilities or logits for each observation. If preds has values outside the [0,1] range, we consider the input to be logits and will auto-apply sigmoid per element.

  • target (Tensor): An int tensor of shape (N, C, ...). Target should be a tensor containing ground truth labels, and therefore only contain {0,1} values (except if ignore_index is specified).

Note

Additional dimension ... will be flattened into the batch dimension.

As output to forward and compute the metric returns a tuple of either 3 tensors or 3 lists containing:

  • precision (Tensor or List): if thresholds=None a list for each label is returned with a 1d tensor of size (n_thresholds+1, ) with precision values (length may differ between labels). If thresholds is set to something else, then a single 2d tensor of size (n_labels, n_thresholds+1) with precision values is returned.

  • recall (Tensor or List): if thresholds=None a list for each label is returned with a 1d tensor of size (n_thresholds+1, ) with recall values (length may differ between labels). If thresholds is set to something else, then a single 2d tensor of size (n_labels, n_thresholds+1) with recall values is returned.

  • thresholds (Tensor or List): if thresholds=None a list for each label is returned with a 1d tensor of size (n_thresholds, ) with increasing threshold values (length may differ between labels). If thresholds is set to something else, then a single 1d tensor of size (n_thresholds, ) is returned with shared threshold values for all labels.

Note

The implementation both supports calculating the metric in a non-binned but accurate version and a binned version that is less accurate but more memory efficient. Setting the thresholds argument to None will activate the non-binned version that uses memory of size \(\mathcal{O}(n_{samples})\) whereas setting the thresholds argument to either an integer, list or a 1d tensor will use a binned version that uses memory of size \(\mathcal{O}(n_{thresholds} \times n_{labels})\) (constant memory).

Parameters:
  • preds – Tensor with predictions

  • target – Tensor with true labels

  • num_labels (int) – Integer specifying the number of labels

  • thresholds (Union[int, List[float], Tensor, None]) –

    Can be one of:

    • If set to None, will use a non-binned approach where thresholds are dynamically calculated from all the data. Most accurate but also most memory consuming approach.

    • If set to an int (larger than 1), will use that number of thresholds linearly spaced from 0 to 1 as bins for the calculation.

    • If set to a list of floats, will use the indicated thresholds in the list as bins for the calculation.

    • If set to a 1d tensor of floats, will use the indicated thresholds in the tensor as bins for the calculation.

  • ignore_index (Optional[int]) – Specifies a target value that is ignored and does not contribute to the metric calculation

  • validate_args (bool) – bool indicating if input arguments and tensors should be validated for correctness. Set to False for faster computations.

Example

>>> from torchmetrics.classification import MultilabelPrecisionRecallCurve
>>> preds = torch.tensor([[0.75, 0.05, 0.35],
...                       [0.45, 0.75, 0.05],
...                       [0.05, 0.55, 0.75],
...                       [0.05, 0.65, 0.05]])
>>> target = torch.tensor([[1, 0, 1],
...                        [0, 0, 0],
...                        [0, 1, 1],
...                        [1, 1, 1]])
>>> mlprc = MultilabelPrecisionRecallCurve(num_labels=3, thresholds=None)
>>> precision, recall, thresholds = mlprc(preds, target)
>>> precision  
[tensor([0.5000, 0.5000, 1.0000, 1.0000]), tensor([0.5000, 0.6667, 0.5000, 0.0000, 1.0000]),
 tensor([0.7500, 1.0000, 1.0000, 1.0000])]
>>> recall  
[tensor([1.0000, 0.5000, 0.5000, 0.0000]), tensor([1.0000, 1.0000, 0.5000, 0.0000, 0.0000]),
 tensor([1.0000, 0.6667, 0.3333, 0.0000])]
>>> thresholds  
[tensor([0.0500, 0.4500, 0.7500]), tensor([0.0500, 0.5500, 0.6500, 0.7500]), tensor([0.0500, 0.3500, 0.7500])]
>>> mlprc = MultilabelPrecisionRecallCurve(num_labels=3, thresholds=5)
>>> mlprc(preds, target)  
(tensor([[0.5000, 0.5000, 1.0000, 1.0000, 0.0000, 1.0000],
         [0.5000, 0.6667, 0.6667, 0.0000, 0.0000, 1.0000],
         [0.7500, 1.0000, 1.0000, 1.0000, 0.0000, 1.0000]]),
 tensor([[1.0000, 0.5000, 0.5000, 0.5000, 0.0000, 0.0000],
         [1.0000, 1.0000, 1.0000, 0.0000, 0.0000, 0.0000],
         [1.0000, 0.6667, 0.3333, 0.3333, 0.0000, 0.0000]]),
 tensor([0.0000, 0.2500, 0.5000, 0.7500, 1.0000]))
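
When thresholds=None the three outputs are lists with one entry per label, so they can simply be zipped together when inspecting individual labels. A small sketch with randomly generated, purely illustrative data:

import torch
from torchmetrics.classification import MultilabelPrecisionRecallCurve

preds = torch.rand(16, 3)
target = torch.randint(2, (16, 3))

mlprc = MultilabelPrecisionRecallCurve(num_labels=3, thresholds=None)
precision, recall, thresholds = mlprc(preds, target)
# per-label tensors may have different lengths, since thresholds are data dependent
for label, (p, r, t) in enumerate(zip(precision, recall, thresholds)):
    print(f"label {label}: {t.numel()} thresholds, max precision {p.max().item():.4f}")
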
plot(curve=None, score=None, ax=None)[source]

Plot a single or multiple values from the metric.

Parameters:
  • curve (Union[Tuple[Tensor, Tensor, Tensor], Tuple[List[Tensor], List[Tensor], List[Tensor]], None]) – the output of either metric.compute or metric.forward. If no value is provided, will automatically call metric.compute and plot that result.

  • score (Union[Tensor, bool, None]) – Provide an area-under-the-curve score to be displayed on the plot. If True and no curve is provided, will automatically compute the score.

  • ax (Optional[Axes]) – A matplotlib axes object. If provided, will add the plot to that axes.

Return type:

Tuple[Figure, Union[Axes, ndarray]]

Returns:

Figure and Axes object

Raises:

ModuleNotFoundError – If matplotlib is not installed

>>> from torch import rand, randint
>>> from torchmetrics.classification import MultilabelPrecisionRecallCurve
>>> preds = rand(20, 3)
>>> target = randint(2, (20,3))
>>> metric = MultilabelPrecisionRecallCurve(num_labels=3)
>>> metric.update(preds, target)
>>> fig_, ax_ = metric.plot(score=True)
_images/precision_recall_curve-3.png

Functional Interface

torchmetrics.functional.precision_recall_curve(preds, target, task, thresholds=None, num_classes=None, num_labels=None, ignore_index=None, validate_args=True)[source]

Compute the precision-recall curve.

The curve consists of multiple pairs of precision and recall values evaluated at different thresholds, such that the tradeoff between the two values can be seen.

This function is a simple wrapper that dispatches to the task-specific version of this metric, selected by setting the task argument to 'binary', 'multiclass' or 'multilabel'. See the documentation of binary_precision_recall_curve(), multiclass_precision_recall_curve() and multilabel_precision_recall_curve() for the specific details of each argument's influence and for examples.

Return type:

Union[Tuple[Tensor, Tensor, Tensor], Tuple[List[Tensor], List[Tensor], List[Tensor]]]

Legacy Example:
>>> pred = torch.tensor([0, 0.1, 0.8, 0.4])
>>> target = torch.tensor([0, 1, 1, 0])
>>> precision, recall, thresholds = precision_recall_curve(pred, target, task='binary')
>>> precision
tensor([0.5000, 0.6667, 0.5000, 1.0000, 1.0000])
>>> recall
tensor([1.0000, 1.0000, 0.5000, 0.5000, 0.0000])
>>> thresholds
tensor([0.0000, 0.1000, 0.4000, 0.8000])
>>> pred = torch.tensor([[0.75, 0.05, 0.05, 0.05, 0.05],
...                      [0.05, 0.75, 0.05, 0.05, 0.05],
...                      [0.05, 0.05, 0.75, 0.05, 0.05],
...                      [0.05, 0.05, 0.05, 0.75, 0.05]])
>>> target = torch.tensor([0, 1, 3, 2])
>>> precision, recall, thresholds = precision_recall_curve(pred, target, task='multiclass', num_classes=5)
>>> precision
[tensor([0.2500, 1.0000, 1.0000]), tensor([0.2500, 1.0000, 1.0000]), tensor([0.2500, 0.0000, 1.0000]),
 tensor([0.2500, 0.0000, 1.0000]), tensor([0., 1.])]
>>> recall
[tensor([1., 1., 0.]), tensor([1., 1., 0.]), tensor([1., 0., 0.]), tensor([1., 0., 0.]), tensor([nan, 0.])]
>>> thresholds
[tensor([0.0500, 0.7500]), tensor([0.0500, 0.7500]), tensor([0.0500, 0.7500]), tensor([0.0500, 0.7500]),
 tensor([0.0500])]
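
For multilabel inputs the wrapper forwards to multilabel_precision_recall_curve() in the same way; a minimal sketch with randomly generated, purely illustrative data:

import torch
from torchmetrics.functional import precision_recall_curve

preds = torch.rand(10, 3)
target = torch.randint(2, (10, 3))

# with an integer thresholds argument the binned version is used, so precision and
# recall come back as single 2d tensors over a shared 1d grid of threshold values
precision, recall, thresholds = precision_recall_curve(
    preds, target, task="multilabel", num_labels=3, thresholds=5
)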
binary_precision_recall_curve
torchmetrics.functional.classification.binary_precision_recall_curve(preds, target, thresholds=None, ignore_index=None, validate_args=True)[source]

Compute the precision-recall curve for binary tasks.

The curve consists of multiple pairs of precision and recall values evaluated at different thresholds, such that the tradeoff between the two values can be seen.

Accepts the following input tensors:

  • preds (float tensor): (N, ...). Preds should be a tensor containing probabilities or logits for each observation. If preds has values outside the [0,1] range, we consider the input to be logits and will auto-apply sigmoid per element.

  • target (int tensor): (N, ...). Target should be a tensor containing ground truth labels, and therefore only contain {0,1} values (except if ignore_index is specified). The value 1 always encodes the positive class.

Additional dimension ... will be flattened into the batch dimension.

The implementation both supports calculating the metric in a non-binned but accurate version and a binned version that is less accurate but more memory efficient. Setting the thresholds argument to None will activate the non-binned version that uses memory of size \(\mathcal{O}(n_{samples})\) whereas setting the thresholds argument to either an integer, list or a 1d tensor will use a binned version that uses memory of size \(\mathcal{O}(n_{thresholds})\) (constant memory).

Parameters:
  • preds (Tensor) – Tensor with predictions

  • target (Tensor) – Tensor with true labels

  • thresholds (Union[int, List[float], Tensor, None]) –

    Can be one of:

    • If set to None, will use a non-binned approach where thresholds are dynamically calculated from all the data. Most accurate but also most memory consuming approach.

    • If set to an int (larger than 1), will use that number of thresholds linearly spaced from 0 to 1 as bins for the calculation.

    • If set to a list of floats, will use the indicated thresholds in the list as bins for the calculation.

    • If set to a 1d tensor of floats, will use the indicated thresholds in the tensor as bins for the calculation.

  • ignore_index (Optional[int]) – Specifies a target value that is ignored and does not contribute to the metric calculation

  • validate_args (bool) – bool indicating if input arguments and tensors should be validated for correctness. Set to False for faster computations.

Returns:

a tuple of 3 tensors containing:

  • precision: a 1d tensor of size (n_thresholds+1, ) with precision values

  • recall: a 1d tensor of size (n_thresholds+1, ) with recall values

  • thresholds: a 1d tensor of size (n_thresholds, ) with increasing threshold values

Return type:

(tuple)

Example

>>> from torchmetrics.functional.classification import binary_precision_recall_curve
>>> preds = torch.tensor([0, 0.5, 0.7, 0.8])
>>> target = torch.tensor([0, 1, 1, 0])
>>> binary_precision_recall_curve(preds, target, thresholds=None)  
(tensor([0.5000, 0.6667, 0.5000, 0.0000, 1.0000]),
 tensor([1.0000, 1.0000, 0.5000, 0.0000, 0.0000]),
 tensor([0.0000, 0.5000, 0.7000, 0.8000]))
>>> binary_precision_recall_curve(preds, target, thresholds=5)  
(tensor([0.5000, 0.6667, 0.6667, 0.0000, 0.0000, 1.0000]),
 tensor([1., 1., 1., 0., 0., 0.]),
 tensor([0.0000, 0.2500, 0.5000, 0.7500, 1.0000]))
multiclass_precision_recall_curve
torchmetrics.functional.classification.multiclass_precision_recall_curve(preds, target, num_classes, thresholds=None, ignore_index=None, validate_args=True)[source]

Compute the precision-recall curve for multiclass tasks.

The curve consists of multiple pairs of precision and recall values evaluated at different thresholds, such that the tradeoff between the two values can be seen.

Accepts the following input tensors:

  • preds (float tensor): (N, C, ...). Preds should be a tensor containing probabilities or logits for each observation. If preds has values outside the [0,1] range, we consider the input to be logits and will auto-apply softmax per sample.

  • target (int tensor): (N, ...). Target should be a tensor containing ground truth labels, and therefore only contain values in the [0, n_classes-1] range (except if ignore_index is specified).

Additional dimension ... will be flattened into the batch dimension.

The implementation both supports calculating the metric in a non-binned but accurate version and a binned version that is less accurate but more memory efficient. Setting the thresholds argument to None will activate the non-binned version that uses memory of size \(\mathcal{O}(n_{samples})\) whereas setting the thresholds argument to either an integer, list or a 1d tensor will use a binned version that uses memory of size \(\mathcal{O}(n_{thresholds} \times n_{classes})\) (constant memory).

Parameters:
  • preds (Tensor) – Tensor with predictions

  • target (Tensor) – Tensor with true labels

  • num_classes (int) – Integer specifying the number of classes

  • thresholds (Union[int, List[float], Tensor, None]) –

    Can be one of:

    • If set to None, will use a non-binned approach where thresholds are dynamically calculated from all the data. Most accurate but also most memory consuming approach.

    • If set to an int (larger than 1), will use that number of thresholds linearly spaced from 0 to 1 as bins for the calculation.

    • If set to a list of floats, will use the indicated thresholds in the list as bins for the calculation.

    • If set to a 1d tensor of floats, will use the indicated thresholds in the tensor as bins for the calculation.

  • ignore_index (Optional[int]) – Specifies a target value that is ignored and does not contribute to the metric calculation

  • validate_args (bool) – bool indicating if input arguments and tensors should be validated for correctness. Set to False for faster computations.

Returns:

a tuple of either 3 tensors or 3 lists containing

  • precision: if thresholds=None a list for each class is returned with a 1d tensor of size (n_thresholds+1, ) with precision values (length may differ between classes). If thresholds is set to something else, then a single 2d tensor of size (n_classes, n_thresholds+1) with precision values is returned.

  • recall: if thresholds=None a list for each class is returned with a 1d tensor of size (n_thresholds+1, ) with recall values (length may differ between classes). If thresholds is set to something else, then a single 2d tensor of size (n_classes, n_thresholds+1) with recall values is returned.

  • thresholds: if thresholds=None a list for each class is returned with a 1d tensor of size (n_thresholds, ) with increasing threshold values (length may differ between classes). If thresholds is set to something else, then a single 1d tensor of size (n_thresholds, ) is returned with shared threshold values for all classes.

Return type:

(tuple)

Example

>>> from torchmetrics.functional.classification import multiclass_precision_recall_curve
>>> preds = torch.tensor([[0.75, 0.05, 0.05, 0.05, 0.05],
...                       [0.05, 0.75, 0.05, 0.05, 0.05],
...                       [0.05, 0.05, 0.75, 0.05, 0.05],
...                       [0.05, 0.05, 0.05, 0.75, 0.05]])
>>> target = torch.tensor([0, 1, 3, 2])
>>> precision, recall, thresholds = multiclass_precision_recall_curve(
...    preds, target, num_classes=5, thresholds=None
... )
>>> precision  
[tensor([0.2500, 1.0000, 1.0000]), tensor([0.2500, 1.0000, 1.0000]), tensor([0.2500, 0.0000, 1.0000]),
 tensor([0.2500, 0.0000, 1.0000]), tensor([0., 1.])]
>>> recall
[tensor([1., 1., 0.]), tensor([1., 1., 0.]), tensor([1., 0., 0.]), tensor([1., 0., 0.]), tensor([nan, 0.])]
>>> thresholds
[tensor([0.0500, 0.7500]), tensor([0.0500, 0.7500]), tensor([0.0500, 0.7500]), tensor([0.0500, 0.7500]),
 tensor([0.0500])]
>>> multiclass_precision_recall_curve(
...     preds, target, num_classes=5, thresholds=5
... )  
(tensor([[0.2500, 1.0000, 1.0000, 1.0000, 0.0000, 1.0000],
         [0.2500, 1.0000, 1.0000, 1.0000, 0.0000, 1.0000],
         [0.2500, 0.0000, 0.0000, 0.0000, 0.0000, 1.0000],
         [0.2500, 0.0000, 0.0000, 0.0000, 0.0000, 1.0000],
         [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 1.0000]]),
 tensor([[1., 1., 1., 1., 0., 0.],
         [1., 1., 1., 1., 0., 0.],
         [1., 0., 0., 0., 0., 0.],
         [1., 0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0., 0.]]),
 tensor([0.0000, 0.2500, 0.5000, 0.7500, 1.0000]))
multilabel_precision_recall_curve
torchmetrics.functional.classification.multilabel_precision_recall_curve(preds, target, num_labels, thresholds=None, ignore_index=None, validate_args=True)[source]

Compute the precision-recall curve for multilabel tasks.

The curve consists of multiple pairs of precision and recall values evaluated at different thresholds, such that the tradeoff between the two values can be seen.

Accepts the following input tensors:

  • preds (float tensor): (N, C, ...). Preds should be a tensor containing probabilities or logits for each observation. If preds has values outside the [0,1] range, we consider the input to be logits and will auto-apply sigmoid per element.

  • target (int tensor): (N, C, ...). Target should be a tensor containing ground truth labels, and therefore only contain {0,1} values (except if ignore_index is specified).

Additional dimension ... will be flattened into the batch dimension.

The implementation both supports calculating the metric in a non-binned but accurate version and a binned version that is less accurate but more memory efficient. Setting the thresholds argument to None will activate the non-binned version that uses memory of size \(\mathcal{O}(n_{samples})\) whereas setting the thresholds argument to either an integer, list or a 1d tensor will use a binned version that uses memory of size \(\mathcal{O}(n_{thresholds} \times n_{labels})\) (constant memory).

Parameters:
  • preds (Tensor) – Tensor with predictions

  • target (Tensor) – Tensor with true labels

  • num_labels (int) – Integer specifying the number of labels

  • thresholds (Union[int, List[float], Tensor, None]) –

    Can be one of:

    • If set to None, will use a non-binned approach where thresholds are dynamically calculated from all the data. Most accurate but also most memory consuming approach.

    • If set to an int (larger than 1), will use that number of thresholds linearly spaced from 0 to 1 as bins for the calculation.

    • If set to a list of floats, will use the indicated thresholds in the list as bins for the calculation.

    • If set to a 1d tensor of floats, will use the indicated thresholds in the tensor as bins for the calculation.

  • ignore_index (Optional[int]) – Specifies a target value that is ignored and does not contribute to the metric calculation

  • validate_args (bool) – bool indicating if input arguments and tensors should be validated for correctness. Set to False for faster computations.

Returns:

a tuple of either 3 tensors or 3 lists containing

  • precision: if thresholds=None a list for each label is returned with a 1d tensor of size (n_thresholds+1, ) with precision values (length may differ between labels). If thresholds is set to something else, then a single 2d tensor of size (n_labels, n_thresholds+1) with precision values is returned.

  • recall: if thresholds=None a list for each label is returned with a 1d tensor of size (n_thresholds+1, ) with recall values (length may differ between labels). If thresholds is set to something else, then a single 2d tensor of size (n_labels, n_thresholds+1) with recall values is returned.

  • thresholds: if thresholds=None a list for each label is returned with a 1d tensor of size (n_thresholds, ) with increasing threshold values (length may differ between labels). If thresholds is set to something else, then a single 1d tensor of size (n_thresholds, ) is returned with shared threshold values for all labels.

Return type:

(tuple)

Example

>>> from torchmetrics.functional.classification import multilabel_precision_recall_curve
>>> preds = torch.tensor([[0.75, 0.05, 0.35],
...                       [0.45, 0.75, 0.05],
...                       [0.05, 0.55, 0.75],
...                       [0.05, 0.65, 0.05]])
>>> target = torch.tensor([[1, 0, 1],
...                        [0, 0, 0],
...                        [0, 1, 1],
...                        [1, 1, 1]])
>>> precision, recall, thresholds = multilabel_precision_recall_curve(
...    preds, target, num_labels=3, thresholds=None
... )
>>> precision  
[tensor([0.5000, 0.5000, 1.0000, 1.0000]), tensor([0.5000, 0.6667, 0.5000, 0.0000, 1.0000]),
 tensor([0.7500, 1.0000, 1.0000, 1.0000])]
>>> recall  
[tensor([1.0000, 0.5000, 0.5000, 0.0000]), tensor([1.0000, 1.0000, 0.5000, 0.0000, 0.0000]),
 tensor([1.0000, 0.6667, 0.3333, 0.0000])]
>>> thresholds  
[tensor([0.0500, 0.4500, 0.7500]), tensor([0.0500, 0.5500, 0.6500, 0.7500]), tensor([0.0500, 0.3500, 0.7500])]
>>> multilabel_precision_recall_curve(
...     preds, target, num_labels=3, thresholds=5
... )  
(tensor([[0.5000, 0.5000, 1.0000, 1.0000, 0.0000, 1.0000],
         [0.5000, 0.6667, 0.6667, 0.0000, 0.0000, 1.0000],
         [0.7500, 1.0000, 1.0000, 1.0000, 0.0000, 1.0000]]),
 tensor([[1.0000, 0.5000, 0.5000, 0.5000, 0.0000, 0.0000],
         [1.0000, 1.0000, 1.0000, 0.0000, 0.0000, 0.0000],
         [1.0000, 0.6667, 0.3333, 0.3333, 0.0000, 0.0000]]),
 tensor([0.0000, 0.2500, 0.5000, 0.7500, 1.0000]))

Recall

Module Interface

class torchmetrics.Recall(**kwargs)[source]

Compute Recall.

\[\text{Recall} = \frac{\text{TP}}{\text{TP} + \text{FN}}\]

Where \(\text{TP}\) and \(\text{FN}\) represent the number of true positives and false negatives respectively. The metric is only properly defined when \(\text{TP} + \text{FN} \neq 0\). If this case is encountered for any class/label, the metric for that class/label will be set to 0 and the overall metric may therefore be affected in turn.

This class is a simple wrapper that dispatches to the task-specific version of this metric, selected by setting the task argument to 'binary', 'multiclass' or 'multilabel'. See the documentation of BinaryRecall, MulticlassRecall and MultilabelRecall for the specific details of each argument's influence and for examples.

Legacy Example:
>>> from torch import tensor
>>> preds  = tensor([2, 0, 2, 1])
>>> target = tensor([1, 1, 2, 0])
>>> recall = Recall(task="multiclass", average='macro', num_classes=3)
>>> recall(preds, target)
tensor(0.3333)
>>> recall = Recall(task="multiclass", average='micro', num_classes=3)
>>> recall(preds, target)
tensor(0.2500)
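
The same wrapper covers binary and multilabel tasks as well; a minimal binary sketch (values chosen only for illustration):

from torch import tensor
from torchmetrics import Recall

preds = tensor([0.2, 0.8, 0.6, 0.1])
target = tensor([0, 1, 1, 0])

recall = Recall(task="binary", threshold=0.5)
recall(preds, target)  # both positives are recovered, so this yields tensor(1.)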
static __new__(cls, task, threshold=0.5, num_classes=None, num_labels=None, average='micro', multidim_average='global', top_k=1, ignore_index=None, validate_args=True, **kwargs)[source]

Initialize task metric.

Return type:

Metric

BinaryRecall
class torchmetrics.classification.BinaryRecall(threshold=0.5, multidim_average='global', ignore_index=None, validate_args=True, **kwargs)[source]

Compute Recall for binary tasks.

\[\text{Recall} = \frac{\text{TP}}{\text{TP} + \text{FN}}\]

Where \(\text{TP}\) and \(\text{FN}\) represent the number of true positives and false negatives respectively. The metric is only properly defined when \(\text{TP} + \text{FN} \neq 0\). If this case is encountered, a score of 0 is returned.

As input to forward and update the metric accepts the following input:

  • preds (Tensor): An int tensor or float tensor of shape (N, ...). If preds is a floating point tensor with values outside the [0,1] range, we consider the input to be logits and will auto-apply sigmoid per element. Additionally, we convert to an int tensor by thresholding with the value in threshold.

  • target (Tensor): An int tensor of shape (N, ...)

As output to forward and compute the metric returns the following output:

  • br (Tensor): If multidim_average is set to global, the metric returns a scalar value. If multidim_average is set to samplewise, the metric returns an (N,) vector consisting of a scalar value per sample.

Parameters:
  • threshold (float) – Threshold for transforming probability to binary {0,1} predictions

  • multidim_average (Literal['global', 'samplewise']) –

    Defines how additional dimensions ... should be handled. Should be one of the following:

    • global: Additional dimensions are flattened along the batch dimension

    • samplewise: Statistic will be calculated independently for each sample on the N axis. The statistics in this case are calculated over the additional dimensions.

  • ignore_index (Optional[int]) – Specifies a target value that is ignored and does not contribute to the metric calculation

  • validate_args (bool) – bool indicating if input arguments and tensors should be validated for correctness. Set to False for faster computations.

Example (preds is int tensor):
>>> from torch import tensor
>>> from torchmetrics.classification import BinaryRecall
>>> target = tensor([0, 1, 0, 1, 0, 1])
>>> preds = tensor([0, 0, 1, 1, 0, 1])
>>> metric = BinaryRecall()
>>> metric(preds, target)
tensor(0.6667)
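
The value above follows directly from the definition: these tensors contain two true positives and one false negative, so recall is 2/3. A quick manual check:

import torch

target = torch.tensor([0, 1, 0, 1, 0, 1])
preds = torch.tensor([0, 0, 1, 1, 0, 1])
tp = ((preds == 1) & (target == 1)).sum()  # 2 true positives
fn = ((preds == 0) & (target == 1)).sum()  # 1 false negative
tp / (tp + fn)                             # 2 / 3 = 0.6667
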
Example (preds is float tensor):
>>> from torchmetrics.classification import BinaryRecall
>>> target = tensor([0, 1, 0, 1, 0, 1])
>>> preds = tensor([0.11, 0.22, 0.84, 0.73, 0.33, 0.92])
>>> metric = BinaryRecall()
>>> metric(preds, target)
tensor(0.6667)
Example (multidim tensors):
>>> from torchmetrics.classification import BinaryRecall
>>> target = tensor([[[0, 1], [1, 0], [0, 1]], [[1, 1], [0, 0], [1, 0]]])
>>> preds = tensor([[[0.59, 0.91], [0.91, 0.99],  [0.63, 0.04]],
...                 [[0.38, 0.04], [0.86, 0.780], [0.45, 0.37]]])
>>> metric = BinaryRecall(multidim_average='samplewise')
>>> metric(preds, target)
tensor([0.6667, 0.0000])
plot(val=None, ax=None)[source]

Plot a single or multiple values from the metric.

Parameters:
  • val (Union[Tensor, Sequence[Tensor], None]) – Either a single result from calling metric.forward or metric.compute or a list of these results. If no value is provided, will automatically call metric.compute and plot that result.

  • ax (Optional[Axes]) – A matplotlib axes object. If provided, will add the plot to that axes.

Return type:

Tuple[Figure, Union[Axes, ndarray]]

Returns:

Figure object and Axes object

Raises:

ModuleNotFoundError – If matplotlib is not installed

>>> from torch import rand, randint
>>> # Example plotting a single value
>>> from torchmetrics.classification import BinaryRecall
>>> metric = BinaryRecall()
>>> metric.update(rand(10), randint(2,(10,)))
>>> fig_, ax_ = metric.plot()
_images/recall-1.png
>>> from torch import rand, randint
>>> # Example plotting multiple values
>>> from torchmetrics.classification import BinaryRecall
>>> metric = BinaryRecall()
>>> values = [ ]
>>> for _ in range(10):
...     values.append(metric(rand(10), randint(2,(10,))))
>>> fig_, ax_ = metric.plot(values)
_images/recall-2.png
MulticlassRecall
class torchmetrics.classification.MulticlassRecall(num_classes, top_k=1, average='macro', multidim_average='global', ignore_index=None, validate_args=True, **kwargs)[source]

Compute Recall for multiclass tasks.

\[\text{Recall} = \frac{\text{TP}}{\text{TP} + \text{FN}}\]

Where \(\text{TP}\) and \(\text{FN}\) represent the number of true positives and false negatives respectively. The metric is only properly defined when \(\text{TP} + \text{FN} \neq 0\). If this case is encountered for any class, the metric for that class will be set to 0 and the overall metric may therefore be affected in turn.

As input to forward and update the metric accepts the following input:

  • preds (Tensor): An int tensor of shape (N, ...) or float tensor of shape (N, C, ...). If preds is a floating point tensor, we apply torch.argmax along the C dimension to automatically convert probabilities/logits into an int tensor.

  • target (Tensor): An int tensor of shape (N, ...)

As output to forward and compute the metric returns the following output:

  • mcr (Tensor): The returned shape depends on the average and multidim_average arguments:

    • If multidim_average is set to global:

      • If average='micro'/'macro'/'weighted', the output will be a scalar tensor

      • If average=None/'none', the shape will be (C,)

    • If multidim_average is set to samplewise:

      • If average='micro'/'macro'/'weighted', the shape will be (N,)

      • If average=None/'none', the shape will be (N, C)

Parameters:
  • num_classes (int) – Integer specifying the number of classes

  • average (Optional[Literal['micro', 'macro', 'weighted', 'none']]) –

    Defines the reduction that is applied over labels. Should be one of the following:

    • micro: Sum statistics over all labels

    • macro: Calculate statistics for each label and average them

    • weighted: calculates statistics for each label and computes weighted average using their support

    • "none" or None: calculates statistic for each label and applies no reduction

  • top_k (int) – Number of highest probability or logit score predictions considered to find the correct label. Only works when preds contain probabilities/logits.

  • multidim_average (Literal['global', 'samplewise']) –

    Defines how additional dimensions ... should be handled. Should be one of the following:

    • global: Additional dimensions are flattened along the batch dimension

    • samplewise: Statistic will be calculated independently for each sample on the N axis. The statistics in this case are calculated over the additional dimensions.

  • ignore_index (Optional[int]) – Specifies a target value that is ignored and does not contribute to the metric calculation

  • validate_args (bool) – bool indicating if input arguments and tensors should be validated for correctness. Set to False for faster computations.

Example (preds is int tensor):
>>> from torch import tensor
>>> from torchmetrics.classification import MulticlassRecall
>>> target = tensor([2, 1, 0, 0])
>>> preds = tensor([2, 1, 0, 1])
>>> metric = MulticlassRecall(num_classes=3)
>>> metric(preds, target)
tensor(0.8333)
>>> mcr = MulticlassRecall(num_classes=3, average=None)
>>> mcr(preds, target)
tensor([0.5000, 1.0000, 1.0000])
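
For comparison, macro averaging is the mean of the per-class recalls shown above, (0.5 + 1.0 + 1.0) / 3 = 0.8333, while micro averaging pools the statistics first, giving 3 correctly recalled samples out of 4. A small sketch using the functional counterpart multiclass_recall:

from torch import tensor
from torchmetrics.functional.classification import multiclass_recall

target = tensor([2, 1, 0, 0])
preds = tensor([2, 1, 0, 1])
multiclass_recall(preds, target, num_classes=3, average="macro")  # (0.5 + 1 + 1) / 3 = 0.8333
multiclass_recall(preds, target, num_classes=3, average="micro")  # 3 correct of 4 = 0.7500
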
Example (preds is float tensor):
>>> from torchmetrics.classification import MulticlassRecall
>>> target = tensor([2, 1, 0, 0])
>>> preds = tensor([[0.16, 0.26, 0.58],
...                 [0.22, 0.61, 0.17],
...                 [0.71, 0.09, 0.20],
...                 [0.05, 0.82, 0.13]])
>>> metric = MulticlassRecall(num_classes=3)
>>> metric(preds, target)
tensor(0.8333)
>>> mcr = MulticlassRecall(num_classes=3, average=None)
>>> mcr(preds, target)
tensor([0.5000, 1.0000, 1.0000])
Example (multidim tensors):
>>> from torchmetrics.classification import MulticlassRecall
>>> target = tensor([[[0, 1], [2, 1], [0, 2]], [[1, 1], [2, 0], [1, 2]]])
>>> preds = tensor([[[0, 2], [2, 0], [0, 1]], [[2, 2], [2, 1], [1, 0]]])
>>> metric = MulticlassRecall(num_classes=3, multidim_average='samplewise')
>>> metric(preds, target)
tensor([0.5000, 0.2778])
>>> mcr = MulticlassRecall(num_classes=3, multidim_average='samplewise', average=None)
>>> mcr(preds, target)
tensor([[1.0000, 0.0000, 0.5000],
        [0.0000, 0.3333, 0.5000]])
plot(val=None, ax=None)[source]

Plot a single or multiple values from the metric.

Parameters:
  • val (Union[Tensor, Sequence[Tensor], None]) – Either a single result from calling metric.forward or metric.compute or a list of these results. If no value is provided, will automatically call metric.compute and plot that result.

  • ax (Optional[Axes]) – A matplotlib axes object. If provided, will add the plot to that axes.

Return type:

Tuple[Figure, Union[Axes, ndarray]]

Returns:

Figure object and Axes object

Raises:

ModuleNotFoundError – If matplotlib is not installed

>>> from torch import randint
>>> # Example plotting a single value per class
>>> from torchmetrics.classification import MulticlassRecall
>>> metric = MulticlassRecall(num_classes=3, average=None)
>>> metric.update(randint(3, (20,)), randint(3, (20,)))
>>> fig_, ax_ = metric.plot()
_images/recall-3.png
>>> from torch import randint
>>> # Example plotting a multiple values per class
>>> from torchmetrics.classification import MulticlassRecall
>>> metric = MulticlassRecall(num_classes=3, average=None)
>>> values = []
>>> for _ in range(20):
...     values.append(metric(randint(3, (20,)), randint(3, (20,))))
>>> fig_, ax_ = metric.plot(values)
_images/recall-4.png
MultilabelRecall
class torchmetrics.classification.MultilabelRecall(num_labels, threshold=0.5, average='macro', multidim_average='global', ignore_index=None, validate_args=True, **kwargs)[source]

Compute Recall for multilabel tasks.

\[\text{Recall} = \frac{\text{TP}}{\text{TP} + \text{FN}}\]

Where \(\text{TP}\) and \(\text{FN}\) represent the number of true positives and false negatives respectively. The metric is only properly defined when \(\text{TP} + \text{FN} \neq 0\). If this case is encountered for any label, the metric for that label will be set to 0 and the overall metric may therefore be affected in turn.

As input to forward and update the metric accepts the following input:

  • preds (Tensor): An int or float tensor of shape (N, C, ...). If preds is a floating point tensor with values outside the [0,1] range, we consider the input to be logits and will auto-apply sigmoid per element. Additionally, we convert to an int tensor by thresholding with the value in threshold.

  • target (Tensor): An int tensor of shape (N, C, ...)

As output to forward and compute the metric returns the following output:

  • mlr (Tensor): The returned shape depends on the average and multidim_average arguments:

    • If multidim_average is set to global:

      • If average='micro'/'macro'/'weighted', the output will be a scalar tensor

      • If average=None/'none', the shape will be (C,)

    • If multidim_average is set to samplewise:

      • If average='micro'/'macro'/'weighted', the shape will be (N,)

      • If average=None/'none', the shape will be (N, C)

Parameters:
  • num_labels (int) – Integer specifying the number of labels

  • threshold (float) – Threshold for transforming probability to binary {0,1} predictions

  • average (Optional[Literal['micro', 'macro', 'weighted', 'none']]) –

    Defines the reduction that is applied over labels. Should be one of the following:

    • micro: Sum statistics over all labels

    • macro: Calculate statistics for each label and average them

    • weighted: Calculate statistics for each label and compute a weighted average using their support

    • "none" or None: Calculate the statistic for each label and apply no reduction

  • multidim_average (Literal['global', 'samplewise']) –

    Defines how additional dimensions ... should be handled. Should be one of the following:

    • global: Additional dimensions are flattened along the batch dimension

    • samplewise: Statistic will be calculated independently for each sample on the N axis. The statistics in this case are calculated over the additional dimensions.

  • ignore_index (Optional[int]) – Specifies a target value that is ignored and does not contribute to the metric calculation

  • validate_args (bool) – bool indicating if input arguments and tensors should be validated for correctness. Set to False for faster computations.

Example (preds is int tensor):
>>> from torch import tensor
>>> from torchmetrics.classification import MultilabelRecall
>>> target = tensor([[0, 1, 0], [1, 0, 1]])
>>> preds = tensor([[0, 0, 1], [1, 0, 1]])
>>> metric = MultilabelRecall(num_labels=3)
>>> metric(preds, target)
tensor(0.6667)
>>> mlr = MultilabelRecall(num_labels=3, average=None)
>>> mlr(preds, target)
tensor([1., 0., 1.])
Example (preds is float tensor):
>>> from torchmetrics.classification import MultilabelRecall
>>> target = tensor([[0, 1, 0], [1, 0, 1]])
>>> preds = tensor([[0.11, 0.22, 0.84], [0.73, 0.33, 0.92]])
>>> metric = MultilabelRecall(num_labels=3)
>>> metric(preds, target)
tensor(0.6667)
>>> mlr = MultilabelRecall(num_labels=3, average=None)
>>> mlr(preds, target)
tensor([1., 0., 1.])
Example (multidim tensors):
>>> from torchmetrics.classification import MultilabelRecall
>>> target = tensor([[[0, 1], [1, 0], [0, 1]], [[1, 1], [0, 0], [1, 0]]])
>>> preds = tensor([[[0.59, 0.91], [0.91, 0.99], [0.63, 0.04]],
...                 [[0.38, 0.04], [0.86, 0.780], [0.45, 0.37]]])
>>> metric = MultilabelRecall(num_labels=3, multidim_average='samplewise')
>>> metric(preds, target)
tensor([0.6667, 0.0000])
>>> mlr = MultilabelRecall(num_labels=3, multidim_average='samplewise', average=None)
>>> mlr(preds, target)
tensor([[1., 1., 0.],
        [0., 0., 0.]])
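
The relationship between the macro default and the per-label values above can be checked directly; the following is a minimal sketch (plain PyTorch plus the calls already shown), relying on the documented behaviour that macro averages the per-label statistics:

import torch
from torchmetrics.classification import MultilabelRecall

target = torch.tensor([[0, 1, 0], [1, 0, 1]])
preds = torch.tensor([[0, 0, 1], [1, 0, 1]])

# per-label recalls, as in the average=None example above
per_label = MultilabelRecall(num_labels=3, average=None)(preds, target)
# the default average='macro' should match the mean of the per-label values here,
# since every label has at least one positive in this toy data
macro = MultilabelRecall(num_labels=3)(preds, target)
print(per_label.mean(), macro)  # both expected to be tensor(0.6667)
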
plot(val=None, ax=None)[source]

Plot a single or multiple values from the metric.

Parameters:
  • val (Union[Tensor, Sequence[Tensor], None]) – Either a single result from calling metric.forward or metric.compute or a list of these results. If no value is provided, will automatically call metric.compute and plot that result.

  • ax (Optional[Axes]) – A matplotlib axis object. If provided, the plot will be added to that axis.

Return type:

Tuple[Figure, Union[Axes, ndarray]]

Returns:

Figure object and Axes object

Raises:

ModuleNotFoundError – If matplotlib is not installed

>>> from torch import rand, randint
>>> # Example plotting a single value
>>> from torchmetrics.classification import MultilabelRecall
>>> metric = MultilabelRecall(num_labels=3)
>>> metric.update(randint(2, (20, 3)), randint(2, (20, 3)))
>>> fig_, ax_ = metric.plot()
_images/recall-5.png
>>> from torch import rand, randint
>>> # Example plotting multiple values
>>> from torchmetrics.classification import MultilabelRecall
>>> metric = MultilabelRecall(num_labels=3)
>>> values = [ ]
>>> for _ in range(10):
...     values.append(metric(randint(2, (20, 3)), randint(2, (20, 3))))
>>> fig_, ax_ = metric.plot(values)
_images/recall-6.png

Functional Interface

torchmetrics.functional.recall(preds, target, task, threshold=0.5, num_classes=None, num_labels=None, average='micro', multidim_average='global', top_k=1, ignore_index=None, validate_args=True)[source]

Compute Recall. The returned value is a Tensor.

\[\text{Recall} = \frac{\text{TP}}{\text{TP} + \text{FN}}\]

Where \(\text{TP}\) and \(\text{FN}\) represent the number of true positives and false negatives respectively.

This function is a simple wrapper to get the task-specific versions of this metric, which is done by setting the task argument to either 'binary', 'multiclass' or 'multilabel'. See the documentation of binary_recall(), multiclass_recall() and multilabel_recall() for the specific details and influence of each argument, together with examples.

Legacy Example:
>>> from torch import tensor
>>> preds  = tensor([2, 0, 2, 1])
>>> target = tensor([1, 1, 2, 0])
>>> recall(preds, target, task="multiclass", average='macro', num_classes=3)
tensor(0.3333)
>>> recall(preds, target, task="multiclass", average='micro', num_classes=3)
tensor(0.2500)
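
Since the wrapper only dispatches to the task-specific function, the legacy call above should be interchangeable with calling multiclass_recall directly; a minimal sketch of that equivalence (assuming the keyword arguments are forwarded unchanged):

from torch import tensor
from torchmetrics.functional import recall
from torchmetrics.functional.classification import multiclass_recall

preds = tensor([2, 0, 2, 1])
target = tensor([1, 1, 2, 0])

# if the wrapper forwards the arguments unchanged, both calls should yield tensor(0.3333)
print(recall(preds, target, task="multiclass", average='macro', num_classes=3))
print(multiclass_recall(preds, target, num_classes=3, average='macro'))
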
binary_recall
torchmetrics.functional.classification.binary_recall(preds, target, threshold=0.5, multidim_average='global', ignore_index=None, validate_args=True)[source]

Compute Recall for binary tasks.

\[\text{Recall} = \frac{\text{TP}}{\text{TP} + \text{FN}}\]

Where \(\text{TP}\) and \(\text{FN}\) represent the number of true positives and false negatives respectively.

Accepts the following input tensors:

  • preds (int or float tensor): (N, ...). If preds is a floating point tensor with values outside [0,1] range we consider the input to be logits and will auto apply sigmoid per element. Additionally, we convert to an int tensor with thresholding using the value in threshold.

  • target (int tensor): (N, ...)

Parameters:
  • preds (Tensor) – Tensor with predictions

  • target (Tensor) – Tensor with true labels

  • threshold (float) – Threshold for transforming probability to binary {0,1} predictions

  • multidim_average (Literal['global', 'samplewise']) –

    Defines how additional dimensions ... should be handled. Should be one of the following:

    • global: Additional dimensions are flattened along the batch dimension

    • samplewise: Statistic will be calculated independently for each sample on the N axis. The statistics in this case are calculated over the additional dimensions.

  • ignore_index (Optional[int]) – Specifies a target value that is ignored and does not contribute to the metric calculation

  • validate_args (bool) – bool indicating if input arguments and tensors should be validated for correctness. Set to False for faster computations.

Return type:

Tensor

Returns:

If multidim_average is set to global, the metric returns a scalar value. If multidim_average is set to samplewise, the metric returns an (N,) vector consisting of a scalar value per sample.

Example (preds is int tensor):
>>> from torch import tensor
>>> from torchmetrics.functional.classification import binary_recall
>>> target = tensor([0, 1, 0, 1, 0, 1])
>>> preds = tensor([0, 0, 1, 1, 0, 1])
>>> binary_recall(preds, target)
tensor(0.6667)
Example (preds is float tensor):
>>> from torchmetrics.functional.classification import binary_recall
>>> target = tensor([0, 1, 0, 1, 0, 1])
>>> preds = tensor([0.11, 0.22, 0.84, 0.73, 0.33, 0.92])
>>> binary_recall(preds, target)
tensor(0.6667)
Example (multidim tensors):
>>> from torchmetrics.functional.classification import binary_recall
>>> target = tensor([[[0, 1], [1, 0], [0, 1]], [[1, 1], [0, 0], [1, 0]]])
>>> preds = tensor([[[0.59, 0.91], [0.91, 0.99], [0.63, 0.04]],
...                 [[0.38, 0.04], [0.86, 0.780], [0.45, 0.37]]])
>>> binary_recall(preds, target, multidim_average='samplewise')
tensor([0.6667, 0.0000])
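
The samplewise behaviour above can be reproduced by hand: per the multidim_average description, each sample along the N axis is scored independently over its extra dimensions. A minimal sketch of that check (not part of the API, just plain PyTorch):

import torch
from torchmetrics.functional.classification import binary_recall

target = torch.tensor([[[0, 1], [1, 0], [0, 1]], [[1, 1], [0, 0], [1, 0]]])
preds = torch.tensor([[[0.59, 0.91], [0.91, 0.99], [0.63, 0.04]],
                      [[0.38, 0.04], [0.86, 0.78], [0.45, 0.37]]])

# scoring each sample separately should reproduce multidim_average='samplewise'
per_sample = torch.stack([binary_recall(p, t) for p, t in zip(preds, target)])
print(per_sample)  # expected to match tensor([0.6667, 0.0000]) from the example above
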
multiclass_recall
torchmetrics.functional.classification.multiclass_recall(preds, target, num_classes, average='macro', top_k=1, multidim_average='global', ignore_index=None, validate_args=True)[source]

Compute Recall for multiclass tasks.

\[\text{Recall} = \frac{\text{TP}}{\text{TP} + \text{FN}}\]

Where \(\text{TP}\) and \(\text{FN}\) represent the number of true positives and false negatives respectively.

Accepts the following input tensors:

  • preds: (N, ...) (int tensor) or (N, C, ...) (float tensor). If preds is a floating point tensor we apply torch.argmax along the C dimension to automatically convert probabilities/logits into an int tensor.

  • target (int tensor): (N, ...)

Parameters:
  • preds (Tensor) – Tensor with predictions

  • target (Tensor) – Tensor with true labels

  • num_classes (int) – Integer specifying the number of classes

  • average (Optional[Literal['micro', 'macro', 'weighted', 'none']]) –

    Defines the reduction that is applied over labels. Should be one of the following:

    • micro: Sum statistics over all labels

    • macro: Calculate statistics for each label and average them

    • weighted: Calculate statistics for each label and compute a weighted average using their support

    • "none" or None: Calculate the statistic for each label and apply no reduction

  • top_k (int) – Number of highest probability or logit score predictions considered to find the correct label. Only works when preds contain probabilities/logits.

  • multidim_average (Literal['global', 'samplewise']) –

    Defines how additional dimensions ... should be handled. Should be one of the following:

    • global: Additional dimensions are flattened along the batch dimension

    • samplewise: Statistic will be calculated independently for each sample on the N axis. The statistics in this case are calculated over the additional dimensions.

  • ignore_index (Optional[int]) – Specifies a target value that is ignored and does not contribute to the metric calculation

  • validate_args (bool) – bool indicating if input arguments and tensors should be validated for correctness. Set to False for faster computations.

Returns:

  • If multidim_average is set to global:

    • If average='micro'/'macro'/'weighted', the output will be a scalar tensor

    • If average=None/'none', the shape will be (C,)

  • If multidim_average is set to samplewise:

    • If average='micro'/'macro'/'weighted', the shape will be (N,)

    • If average=None/'none', the shape will be (N, C)

Return type:

The returned shape depends on the average and multidim_average arguments

Example (preds is int tensor):
>>> from torch import tensor
>>> from torchmetrics.functional.classification import multiclass_recall
>>> target = tensor([2, 1, 0, 0])
>>> preds = tensor([2, 1, 0, 1])
>>> multiclass_recall(preds, target, num_classes=3)
tensor(0.8333)
>>> multiclass_recall(preds, target, num_classes=3, average=None)
tensor([0.5000, 1.0000, 1.0000])
Example (preds is float tensor):
>>> from torchmetrics.functional.classification import multiclass_recall
>>> target = tensor([2, 1, 0, 0])
>>> preds = tensor([[0.16, 0.26, 0.58],
...                 [0.22, 0.61, 0.17],
...                 [0.71, 0.09, 0.20],
...                 [0.05, 0.82, 0.13]])
>>> multiclass_recall(preds, target, num_classes=3)
tensor(0.8333)
>>> multiclass_recall(preds, target, num_classes=3, average=None)
tensor([0.5000, 1.0000, 1.0000])
Example (multidim tensors):
>>> from torchmetrics.functional.classification import multiclass_recall
>>> target = tensor([[[0, 1], [2, 1], [0, 2]], [[1, 1], [2, 0], [1, 2]]])
>>> preds = tensor([[[0, 2], [2, 0], [0, 1]], [[2, 2], [2, 1], [1, 0]]])
>>> multiclass_recall(preds, target, num_classes=3, multidim_average='samplewise')
tensor([0.5000, 0.2778])
>>> multiclass_recall(preds, target, num_classes=3, multidim_average='samplewise', average=None)
tensor([[1.0000, 0.0000, 0.5000],
        [0.0000, 0.3333, 0.5000]])
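
As described under average, macro simply averages the per-class statistics; a minimal check against the default result using the example tensors above:

import torch
from torchmetrics.functional.classification import multiclass_recall

target = torch.tensor([2, 1, 0, 0])
preds = torch.tensor([2, 1, 0, 1])

per_class = multiclass_recall(preds, target, num_classes=3, average=None)
# averaging the per-class recalls should reproduce the default average='macro' value
print(per_class.mean())  # expected tensor(0.8333)
print(multiclass_recall(preds, target, num_classes=3, average='macro'))
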
multilabel_recall
torchmetrics.functional.classification.multilabel_recall(preds, target, num_labels, threshold=0.5, average='macro', multidim_average='global', ignore_index=None, validate_args=True)[source]

Compute Recall for multilabel tasks.

\[\text{Recall} = \frac{\text{TP}}{\text{TP} + \text{FN}}\]

Where \(\text{TP}\) and \(\text{FN}\) represent the number of true positives and false negatives respectively.

Accepts the following input tensors:

  • preds (int or float tensor): (N, C, ...). If preds is a floating point tensor with values outside [0,1] range we consider the input to be logits and will auto apply sigmoid per element. Additionally, we convert to an int tensor with thresholding using the value in threshold.

  • target (int tensor): (N, C, ...)

Parameters:
  • preds (Tensor) – Tensor with predictions

  • target (Tensor) – Tensor with true labels

  • num_labels (int) – Integer specifying the number of labels

  • threshold (float) – Threshold for transforming probability to binary {0,1} predictions

  • average (Optional[Literal['micro', 'macro', 'weighted', 'none']]) –

    Defines the reduction that is applied over labels. Should be one of the following:

    • micro: Sum statistics over all labels

    • macro: Calculate statistics for each label and average them

    • weighted: Calculate statistics for each label and compute a weighted average using their support

    • "none" or None: Calculate the statistic for each label and apply no reduction

  • multidim_average (Literal['global', 'samplewise']) –

    Defines how additional dimensions ... should be handled. Should be one of the following:

    • global: Additional dimensions are flattened along the batch dimension

    • samplewise: Statistic will be calculated independently for each sample on the N axis. The statistics in this case are calculated over the additional dimensions.

  • ignore_index (Optional[int]) – Specifies a target value that is ignored and does not contribute to the metric calculation

  • validate_args (bool) – bool indicating if input arguments and tensors should be validated for correctness. Set to False for faster computations.

Returns:

  • If multidim_average is set to global:

    • If average='micro'/'macro'/'weighted', the output will be a scalar tensor

    • If average=None/'none', the shape will be (C,)

  • If multidim_average is set to samplewise:

    • If average='micro'/'macro'/'weighted', the shape will be (N,)

    • If average=None/'none', the shape will be (N, C)

Return type:

The returned shape depends on the average and multidim_average arguments

Example (preds is int tensor):
>>> from torch import tensor
>>> from torchmetrics.functional.classification import multilabel_recall
>>> target = tensor([[0, 1, 0], [1, 0, 1]])
>>> preds = tensor([[0, 0, 1], [1, 0, 1]])
>>> multilabel_recall(preds, target, num_labels=3)
tensor(0.6667)
>>> multilabel_recall(preds, target, num_labels=3, average=None)
tensor([1., 0., 1.])
Example (preds is float tensor):
>>> from torchmetrics.functional.classification import multilabel_recall
>>> target = tensor([[0, 1, 0], [1, 0, 1]])
>>> preds = tensor([[0.11, 0.22, 0.84], [0.73, 0.33, 0.92]])
>>> multilabel_recall(preds, target, num_labels=3)
tensor(0.6667)
>>> multilabel_recall(preds, target, num_labels=3, average=None)
tensor([1., 0., 1.])
Example (multidim tensors):
>>> from torchmetrics.functional.classification import multilabel_recall
>>> target = tensor([[[0, 1], [1, 0], [0, 1]], [[1, 1], [0, 0], [1, 0]]])
>>> preds = tensor([[[0.59, 0.91], [0.91, 0.99], [0.63, 0.04]],
...                 [[0.38, 0.04], [0.86, 0.780], [0.45, 0.37]]])
>>> multilabel_recall(preds, target, num_labels=3, multidim_average='samplewise')
tensor([0.6667, 0.0000])
>>> multilabel_recall(preds, target, num_labels=3, multidim_average='samplewise', average=None)
tensor([[1., 1., 0.],
        [0., 0., 0.]])
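
Because micro sums the statistics over all labels, it amounts to treating every (sample, label) entry as one binary prediction; a minimal sketch of that relationship, reusing the tensors from the first example:

import torch
from torchmetrics.functional.classification import binary_recall, multilabel_recall

target = torch.tensor([[0, 1, 0], [1, 0, 1]])
preds = torch.tensor([[0, 0, 1], [1, 0, 1]])

micro = multilabel_recall(preds, target, num_labels=3, average='micro')
# pooling all (sample, label) entries into a single binary problem gives the same number
pooled = binary_recall(preds.flatten(), target.flatten())
print(micro, pooled)  # both expected to be tensor(0.6667)
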

Recall At Fixed Precision

Module Interface

class torchmetrics.RecallAtFixedPrecision(**kwargs)[source]

Compute the highest possible recall value given the minimum precision thresholds provided.

This is done by first calculating the precision-recall curve for different thresholds and then finding the recall for a given precision level.

This function is a simple wrapper to get the task-specific versions of this metric, which is done by setting the task argument to either 'binary', 'multiclass' or 'multilabel'. See the documentation of BinaryRecallAtFixedPrecision, MulticlassRecallAtFixedPrecision and MultilabelRecallAtFixedPrecision for the specific details and influence of each argument, together with examples.

static __new__(cls, task, min_precision, thresholds=None, num_classes=None, num_labels=None, ignore_index=None, validate_args=True, **kwargs)[source]

Initialize task metric.

Return type:

Metric
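
A minimal usage sketch of the wrapper, assuming the keyword arguments are simply forwarded to the task-specific class; for task="binary" the values below follow the BinaryRecallAtFixedPrecision example further down:

from torch import tensor
from torchmetrics import RecallAtFixedPrecision

preds = tensor([0, 0.5, 0.7, 0.8])
target = tensor([0, 1, 1, 0])

# dispatches to BinaryRecallAtFixedPrecision under the hood (per the wrapper description)
metric = RecallAtFixedPrecision(task="binary", min_precision=0.5, thresholds=None)
recall, threshold = metric(preds, target)
print(recall, threshold)  # expected tensor(1.) and tensor(0.5000)
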

BinaryRecallAtFixedPrecision
class torchmetrics.classification.BinaryRecallAtFixedPrecision(min_precision, thresholds=None, ignore_index=None, validate_args=True, **kwargs)[source]

Compute the highest possible recall value given the minimum precision thresholds provided.

This is done by first calculating the precision-recall curve for different thresholds and then finding the recall for a given precision level.

As input to forward and update the metric accepts the following input:

  • preds (Tensor): A float tensor of shape (N, ...). Preds should be a tensor containing probabilities or logits for each observation. If preds has values outside [0,1] range we consider the input to be logits and will auto apply sigmoid per element.

  • target (Tensor): An int tensor of shape (N, ...). Target should be a tensor containing ground truth labels, and therefore only contain {0,1} values (except if ignore_index is specified). The value 1 always encodes the positive class.

Note

Additional dimension ... will be flattened into the batch dimension.

As output to forward and compute the metric returns the following output:

  • recall (Tensor): A scalar tensor with the maximum recall for the given precision level

  • threshold (Tensor): A scalar tensor with the corresponding threshold level

Note

The implementation supports calculating the metric in both a non-binned but accurate version and a binned version that is less accurate but more memory efficient. Setting the thresholds argument to None will activate the non-binned version that uses memory of size \(\mathcal{O}(n_{samples})\), whereas setting the thresholds argument to either an integer, list or a 1d tensor will use a binned version that uses memory of size \(\mathcal{O}(n_{thresholds})\) (constant memory).

Parameters:
  • min_precision (float) – float value specifying minimum precision threshold.

  • thresholds (Union[int, List[float], Tensor, None]) –

    Can be one of:

    • If set to None, will use a non-binned approach where thresholds are dynamically calculated from all the data. Most accurate but also most memory consuming approach.

    • If set to an int (larger than 1), will use that number of thresholds linearly spaced from 0 to 1 as bins for the calculation.

    • If set to a list of floats, will use the indicated thresholds in the list as bins for the calculation.

    • If set to a 1d Tensor of floats, will use the indicated thresholds in the tensor as bins for the calculation.

  • validate_args (bool) – bool indicating if input arguments and tensors should be validated for correctness. Set to False for faster computations.

  • kwargs (Any) – Additional keyword arguments, see Advanced metric settings for more info.

Example

>>> from torch import tensor
>>> from torchmetrics.classification import BinaryRecallAtFixedPrecision
>>> preds = tensor([0, 0.5, 0.7, 0.8])
>>> target = tensor([0, 1, 1, 0])
>>> metric = BinaryRecallAtFixedPrecision(min_precision=0.5, thresholds=None)
>>> metric(preds, target)
(tensor(1.), tensor(0.5000))
>>> metric = BinaryRecallAtFixedPrecision(min_precision=0.5, thresholds=5)
>>> metric(preds, target)
(tensor(1.), tensor(0.5000))
plot(val=None, ax=None)[source]

Plot a single or multiple values from the metric.

Parameters:
  • val (Union[Tensor, Sequence[Tensor], None]) – Either a single result from calling metric.forward or metric.compute or a list of these results. If no value is provided, will automatically call metric.compute and plot that result.

  • ax (Optional[Axes]) – A matplotlib axis object. If provided, the plot will be added to that axis.

Return type:

Tuple[Figure, Union[Axes, ndarray]]

Returns:

Figure object and Axes object

Raises:

ModuleNotFoundError – If matplotlib is not installed

>>> from torch import rand, randint
>>> # Example plotting a single value
>>> from torchmetrics.classification import BinaryRecallAtFixedPrecision
>>> metric = BinaryRecallAtFixedPrecision(min_precision=0.5)
>>> metric.update(rand(10), randint(2,(10,)))
>>> fig_, ax_ = metric.plot()  # the returned plot only shows the maximum recall value by default
_images/recall_at_fixed_precision-1.png
>>> from torch import rand, randint
>>> # Example plotting multiple values
>>> from torchmetrics.classification import BinaryRecallAtFixedPrecision
>>> metric = BinaryRecallAtFixedPrecision(min_precision=0.5)
>>> values = [ ]
>>> for _ in range(10):
...     # we index by 0 such that only the maximum recall value is plotted
...     values.append(metric(rand(10), randint(2,(10,)))[0])
>>> fig_, ax_ = metric.plot(values)
_images/recall_at_fixed_precision-2.png
MulticlassRecallAtFixedPrecision
class torchmetrics.classification.MulticlassRecallAtFixedPrecision(num_classes, min_precision, thresholds=None, ignore_index=None, validate_args=True, **kwargs)[source]

Compute the highest possible recall value given the minimum precision thresholds provided.

This is done by first calculating the precision-recall curve for different thresholds and then finding the recall for a given precision level.

For multiclass the metric is calculated by iteratively treating each class as the positive class and all other classes as the negative, which is referred to as the one-vs-rest approach. One-vs-one is currently not supported by this metric.

As input to forward and update the metric accepts the following input:

  • preds (Tensor): A float tensor of shape (N, C, ...). Preds should be a tensor containing probabilities or logits for each observation. If preds has values outside [0,1] range we consider the input to be logits and will auto apply softmax per sample.

  • target (Tensor): An int tensor of shape (N, ...). Target should be a tensor containing ground truth labels, and therefore only contain values in the [0, n_classes-1] range (except if ignore_index is specified).

Note

Additional dimension ... will be flattened into the batch dimension.

As output to forward and compute the metric returns a tuple of either 2 tensors or 2 lists containing:

  • recall (Tensor): A 1d tensor of size (n_classes, ) with the maximum recall for the given precision level per class

  • threshold (Tensor): A 1d tensor of size (n_classes, ) with the corresponding threshold level per class

Note

The implementation both supports calculating the metric in a non-binned but accurate version and a binned version that is less accurate but more memory efficient. Setting the thresholds argument to None will activate the non-binned version that uses memory of size \(\mathcal{O}(n_{samples})\) whereas setting the thresholds argument to either an integer, list or a 1d tensor will use a binned version that uses memory of size \(\mathcal{O}(n_{thresholds} \times n_{classes})\) (constant memory).

Parameters:
  • num_classes (int) – Integer specifying the number of classes

  • min_precision (float) – float value specifying minimum precision threshold.

  • thresholds (Union[int, List[float], Tensor, None]) –

    Can be one of:

    • If set to None, will use a non-binned approach where thresholds are dynamically calculated from all the data. Most accurate but also most memory consuming approach.

    • If set to an int (larger than 1), will use that number of thresholds linearly spaced from 0 to 1 as bins for the calculation.

    • If set to a list of floats, will use the indicated thresholds in the list as bins for the calculation.

    • If set to a 1d Tensor of floats, will use the indicated thresholds in the tensor as bins for the calculation.

  • validate_args (bool) – bool indicating if input arguments and tensors should be validated for correctness. Set to False for faster computations.

  • kwargs (Any) – Additional keyword arguments, see Advanced metric settings for more info.

Example

>>> from torch import tensor
>>> from torchmetrics.classification import MulticlassRecallAtFixedPrecision
>>> preds = tensor([[0.75, 0.05, 0.05, 0.05, 0.05],
...                 [0.05, 0.75, 0.05, 0.05, 0.05],
...                 [0.05, 0.05, 0.75, 0.05, 0.05],
...                 [0.05, 0.05, 0.05, 0.75, 0.05]])
>>> target = tensor([0, 1, 3, 2])
>>> metric = MulticlassRecallAtFixedPrecision(num_classes=5, min_precision=0.5, thresholds=None)
>>> metric(preds, target)
(tensor([1., 1., 0., 0., 0.]), tensor([7.5000e-01, 7.5000e-01, 1.0000e+06, 1.0000e+06, 1.0000e+06]))
>>> mcrafp = MulticlassRecallAtFixedPrecision(num_classes=5, min_precision=0.5, thresholds=5)
>>> mcrafp(preds, target)
(tensor([1., 1., 0., 0., 0.]), tensor([7.5000e-01, 7.5000e-01, 1.0000e+06, 1.0000e+06, 1.0000e+06]))
plot(val=None, ax=None)[source]

Plot a single or multiple values from the metric.

Parameters:
  • val (Union[Tensor, Sequence[Tensor], None]) – Either a single result from calling metric.forward or metric.compute or a list of these results. If no value is provided, will automatically call metric.compute and plot that result.

  • ax (Optional[Axes]) – A matplotlib axis object. If provided, the plot will be added to that axis.

Return type:

Tuple[Figure, Union[Axes, ndarray]]

Returns:

Figure object and Axes object

Raises:

ModuleNotFoundError – If matplotlib is not installed

>>> from torch import rand, randint
>>> # Example plotting a single value per class
>>> from torchmetrics.classification import MulticlassRecallAtFixedPrecision
>>> metric = MulticlassRecallAtFixedPrecision(num_classes=3, min_precision=0.5)
>>> metric.update(rand(20, 3).softmax(dim=-1), randint(3, (20,)))
>>> fig_, ax_ = metric.plot()  # the returned plot only shows the maximum recall value by default
_images/recall_at_fixed_precision-3.png
>>> from torch import rand, randint
>>> # Example plotting multiple values per class
>>> from torchmetrics.classification import MulticlassRecallAtFixedPrecision
>>> metric = MulticlassRecallAtFixedPrecision(num_classes=3, min_precision=0.5)
>>> values = []
>>> for _ in range(20):
...     # we index by 0 such that only the maximum recall value is plotted
...     values.append(metric(rand(20, 3).softmax(dim=-1), randint(3, (20,)))[0])
>>> fig_, ax_ = metric.plot(values)
_images/recall_at_fixed_precision-4.png
MultilabelRecallAtFixedPrecision
class torchmetrics.classification.MultilabelRecallAtFixedPrecision(num_labels, min_precision, thresholds=None, ignore_index=None, validate_args=True, **kwargs)[source]

Compute the highest possible recall value given the minimum precision thresholds provided.

This is done by first calculating the precision-recall curve for different thresholds and then finding the recall for a given precision level.

As input to forward and update the metric accepts the following input:

  • preds (Tensor): A float tensor of shape (N, C, ...). Preds should be a tensor containing probabilities or logits for each observation. If preds has values outside [0,1] range we consider the input to be logits and will auto apply sigmoid per element.

  • target (Tensor): An int tensor of shape (N, ...). Target should be a tensor containing ground truth labels, and therefore only contain {0,1} values (except if ignore_index is specified). The value 1 always encodes the positive class.

Note

Additional dimension ... will be flattened into the batch dimension.

As output to forward and compute the metric returns a tuple of either 2 tensors or 2 lists containing:

  • recall (Tensor): A 1d tensor of size (n_labels, ) with the maximum recall for the given precision level per label

  • threshold (Tensor): A 1d tensor of size (n_labels, ) with the corresponding threshold level per label

Note

The implementation supports calculating the metric in both a non-binned but accurate version and a binned version that is less accurate but more memory efficient. Setting the thresholds argument to None will activate the non-binned version that uses memory of size \(\mathcal{O}(n_{samples})\), whereas setting the thresholds argument to either an integer, list or a 1d tensor will use a binned version that uses memory of size \(\mathcal{O}(n_{thresholds} \times n_{labels})\) (constant memory).

Parameters:
  • num_labels (int) – Integer specifying the number of labels

  • min_precision (float) – float value specifying minimum precision threshold.

  • thresholds (Union[int, List[float], Tensor, None]) –

    Can be one of:

    • If set to None, will use a non-binned approach where thresholds are dynamically calculated from all the data. Most accurate but also most memory consuming approach.

    • If set to an int (larger than 1), will use that number of thresholds linearly spaced from 0 to 1 as bins for the calculation.

    • If set to a list of floats, will use the indicated thresholds in the list as bins for the calculation.

    • If set to a 1d Tensor of floats, will use the indicated thresholds in the tensor as bins for the calculation.

  • validate_args (bool) – bool indicating if input arguments and tensors should be validated for correctness. Set to False for faster computations.

  • kwargs (Any) – Additional keyword arguments, see Advanced metric settings for more info.

Example

>>> from torch import tensor
>>> from torchmetrics.classification import MultilabelRecallAtFixedPrecision
>>> preds = tensor([[0.75, 0.05, 0.35],
...                 [0.45, 0.75, 0.05],
...                 [0.05, 0.55, 0.75],
...                 [0.05, 0.65, 0.05]])
>>> target = tensor([[1, 0, 1],
...                  [0, 0, 0],
...                  [0, 1, 1],
...                  [1, 1, 1]])
>>> metric = MultilabelRecallAtFixedPrecision(num_labels=3, min_precision=0.5, thresholds=None)
>>> metric(preds, target)
(tensor([1., 1., 1.]), tensor([0.0500, 0.5500, 0.0500]))
>>> mlrafp = MultilabelRecallAtFixedPrecision(num_labels=3, min_precision=0.5, thresholds=5)
>>> mlrafp(preds, target)
(tensor([1., 1., 1.]), tensor([0.0000, 0.5000, 0.0000]))
plot(val=None, ax=None)[source]

Plot a single or multiple values from the metric.

Parameters:
  • val (Union[Tensor, Sequence[Tensor], None]) – Either a single result from calling metric.forward or metric.compute or a list of these results. If no value is provided, will automatically call metric.compute and plot that result.

  • ax (Optional[Axes]) – A matplotlib axis object. If provided, the plot will be added to that axis.

Return type:

Tuple[Figure, Union[Axes, ndarray]]

Returns:

Figure object and Axes object

Raises:

ModuleNotFoundError – If matplotlib is not installed

>>> from torch import rand, randint
>>> # Example plotting a single value
>>> from torchmetrics.classification import MultilabelRecallAtFixedPrecision
>>> metric = MultilabelRecallAtFixedPrecision(num_labels=3, min_precision=0.5)
>>> metric.update(rand(20, 3), randint(2, (20, 3)))
>>> fig_, ax_ = metric.plot()  # the returned plot only shows the maximum recall value by default
_images/recall_at_fixed_precision-5.png
>>> from torch import rand, randint
>>> # Example plotting multiple values
>>> from torchmetrics.classification import MultilabelRecallAtFixedPrecision
>>> metric = MultilabelRecallAtFixedPrecision(num_labels=3, min_precision=0.5)
>>> values = [ ]
>>> for _ in range(10):
...     # we index by 0 such that only the maximum recall value is plotted
...     values.append(metric(rand(20, 3), randint(2, (20, 3)))[0])
>>> fig_, ax_ = metric.plot(values)
_images/recall_at_fixed_precision-6.png

Functional Interface

binary_recall_at_fixed_precision
torchmetrics.functional.classification.binary_recall_at_fixed_precision(preds, target, min_precision, thresholds=None, ignore_index=None, validate_args=True)[source]

Compute the highest possible recall value given the minimum precision thresholds provided for binary tasks.

This is done by first calculating the precision-recall curve for different thresholds and then finding the recall for a given precision level.

Accepts the following input tensors:

  • preds (float tensor): (N, ...). Preds should be a tensor containing probabilities or logits for each observation. If preds has values outside [0,1] range we consider the input to be logits and will auto apply sigmoid per element.

  • target (int tensor): (N, ...). Target should be a tensor containing ground truth labels, and therefore only contain {0,1} values (except if ignore_index is specified). The value 1 always encodes the positive class.

Additional dimension ... will be flattened into the batch dimension.

The implementation supports calculating the metric in both a non-binned but accurate version and a binned version that is less accurate but more memory efficient. Setting the thresholds argument to None will activate the non-binned version that uses memory of size \(\mathcal{O}(n_{samples})\), whereas setting the thresholds argument to either an integer, list or a 1d tensor will use a binned version that uses memory of size \(\mathcal{O}(n_{thresholds})\) (constant memory).

Parameters:
  • preds (Tensor) – Tensor with predictions

  • target (Tensor) – Tensor with true labels

  • min_precision (float) – float value specifying minimum precision threshold.

  • thresholds (Union[int, List[float], Tensor, None]) –

    Can be one of:

    • If set to None, will use a non-binned approach where thresholds are dynamically calculated from all the data. Most accurate but also most memory consuming approach.

    • If set to an int (larger than 1), will use that number of thresholds linearly spaced from 0 to 1 as bins for the calculation.

    • If set to a list of floats, will use the indicated thresholds in the list as bins for the calculation.

    • If set to a 1d Tensor of floats, will use the indicated thresholds in the tensor as bins for the calculation.

  • ignore_index (Optional[int]) – Specifies a target value that is ignored and does not contribute to the metric calculation

  • validate_args (bool) – bool indicating if input arguments and tensors should be validated for correctness. Set to False for faster computations.

Returns:

a tuple of 2 tensors containing:

  • recall: a scalar tensor with the maximum recall for the given precision level

  • threshold: a scalar tensor with the corresponding threshold level

Return type:

(tuple)

Example

>>> import torch
>>> from torchmetrics.functional.classification import binary_recall_at_fixed_precision
>>> preds = torch.tensor([0, 0.5, 0.7, 0.8])
>>> target = torch.tensor([0, 1, 1, 0])
>>> binary_recall_at_fixed_precision(preds, target, min_precision=0.5, thresholds=None)
(tensor(1.), tensor(0.5000))
>>> binary_recall_at_fixed_precision(preds, target, min_precision=0.5, thresholds=5)
(tensor(1.), tensor(0.5000))
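
As described above, the result comes from the precision-recall curve: among the operating points whose precision meets min_precision, the highest recall is returned. A rough sketch of that idea (the built-in's exact threshold selection and tie-breaking may differ):

import torch
from torchmetrics.functional.classification import binary_precision_recall_curve

preds = torch.tensor([0, 0.5, 0.7, 0.8])
target = torch.tensor([0, 1, 1, 0])

precision, recall, thresholds = binary_precision_recall_curve(preds, target, thresholds=None)
meets_precision = precision >= 0.5       # operating points with sufficient precision
print(recall[meets_precision].max())     # expected tensor(1.), as in the example above
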
multiclass_recall_at_fixed_precision
torchmetrics.functional.classification.multiclass_recall_at_fixed_precision(preds, target, num_classes, min_precision, thresholds=None, ignore_index=None, validate_args=True)[source]

Compute the highest possible recall value given the minimum precision thresholds provided for multiclass tasks.

This is done by first calculating the precision-recall curve for different thresholds and then finding the recall for a given precision level.

Accepts the following input tensors:

  • preds (float tensor): (N, C, ...). Preds should be a tensor containing probabilities or logits for each observation. If preds has values outside [0,1] range we consider the input to be logits and will auto apply softmax per sample.

  • target (int tensor): (N, ...). Target should be a tensor containing ground truth labels, and therefore only contain values in the [0, n_classes-1] range (except if ignore_index is specified).

Additional dimension ... will be flattened into the batch dimension.

The implementation supports calculating the metric in both a non-binned but accurate version and a binned version that is less accurate but more memory efficient. Setting the thresholds argument to None will activate the non-binned version that uses memory of size \(\mathcal{O}(n_{samples})\), whereas setting the thresholds argument to either an integer, list or a 1d tensor will use a binned version that uses memory of size \(\mathcal{O}(n_{thresholds} \times n_{classes})\) (constant memory).

Parameters:
  • preds (Tensor) – Tensor with predictions

  • target (Tensor) – Tensor with true labels

  • num_classes (int) – Integer specifying the number of classes

  • min_precision (float) – float value specifying minimum precision threshold.

  • thresholds (Union[int, List[float], Tensor, None]) –

    Can be one of:

    • If set to None, will use a non-binned approach where thresholds are dynamically calculated from all the data. Most accurate but also most memory consuming approach.

    • If set to an int (larger than 1), will use that number of thresholds linearly spaced from 0 to 1 as bins for the calculation.

    • If set to a list of floats, will use the indicated thresholds in the list as bins for the calculation.

    • If set to a 1d Tensor of floats, will use the indicated thresholds in the tensor as bins for the calculation.

  • ignore_index (Optional[int]) – Specifies a target value that is ignored and does not contribute to the metric calculation

  • validate_args (bool) – bool indicating if input arguments and tensors should be validated for correctness. Set to False for faster computations.

Returns:

a tuple of either 2 tensors or 2 lists containing

  • recall: a 1d tensor of size (n_classes, ) with the maximum recall for the given precision level per class

  • thresholds: a 1d tensor of size (n_classes, ) with the corresponding threshold level per class

Return type:

(tuple)

Example

>>> import torch
>>> from torchmetrics.functional.classification import multiclass_recall_at_fixed_precision
>>> preds = torch.tensor([[0.75, 0.05, 0.05, 0.05, 0.05],
...                       [0.05, 0.75, 0.05, 0.05, 0.05],
...                       [0.05, 0.05, 0.75, 0.05, 0.05],
...                       [0.05, 0.05, 0.05, 0.75, 0.05]])
>>> target = torch.tensor([0, 1, 3, 2])
>>> multiclass_recall_at_fixed_precision(preds, target, num_classes=5, min_precision=0.5, thresholds=None)
(tensor([1., 1., 0., 0., 0.]), tensor([7.5000e-01, 7.5000e-01, 1.0000e+06, 1.0000e+06, 1.0000e+06]))
>>> multiclass_recall_at_fixed_precision(preds, target, num_classes=5, min_precision=0.5, thresholds=5)
(tensor([1., 1., 0., 0., 0.]), tensor([7.5000e-01, 7.5000e-01, 1.0000e+06, 1.0000e+06, 1.0000e+06]))
multilabel_recall_at_fixed_precision
torchmetrics.functional.classification.multilabel_recall_at_fixed_precision(preds, target, num_labels, min_precision, thresholds=None, ignore_index=None, validate_args=True)[source]

Compute the highest possible recall value given the minimum precision thresholds provided for multilabel tasks.

This is done by first calculating the precision-recall curve for different thresholds and then finding the recall for a given precision level.

Accepts the following input tensors:

  • preds (float tensor): (N, C, ...). Preds should be a tensor containing probabilities or logits for each observation. If preds has values outside [0,1] range we consider the input to be logits and will auto apply sigmoid per element.

  • target (int tensor): (N, C, ...). Target should be a tensor containing ground truth labels, and therefore only contain {0,1} values (except if ignore_index is specified).

Additional dimension ... will be flattened into the batch dimension.

The implementation supports calculating the metric in both a non-binned but accurate version and a binned version that is less accurate but more memory efficient. Setting the thresholds argument to None will activate the non-binned version that uses memory of size \(\mathcal{O}(n_{samples})\), whereas setting the thresholds argument to either an integer, list or a 1d tensor will use a binned version that uses memory of size \(\mathcal{O}(n_{thresholds} \times n_{labels})\) (constant memory).

Parameters:
  • preds (Tensor) – Tensor with predictions

  • target (Tensor) – Tensor with true labels

  • num_labels (int) – Integer specifying the number of labels

  • min_precision (float) – float value specifying minimum precision threshold.

  • thresholds (Union[int, List[float], Tensor, None]) –

    Can be one of:

    • If set to None, will use a non-binned approach where thresholds are dynamically calculated from all the data. Most accurate but also most memory consuming approach.

    • If set to an int (larger than 1), will use that number of thresholds linearly spaced from 0 to 1 as bins for the calculation.

    • If set to a list of floats, will use the indicated thresholds in the list as bins for the calculation.

    • If set to a 1d Tensor of floats, will use the indicated thresholds in the tensor as bins for the calculation.

  • ignore_index (Optional[int]) – Specifies a target value that is ignored and does not contribute to the metric calculation

  • validate_args (bool) – bool indicating if input arguments and tensors should be validated for correctness. Set to False for faster computations.

Returns:

a tuple of either 2 tensors or 2 lists containing

  • recall: a 1d tensor of size (n_labels, ) with the maximum recall for the given precision level per label

  • thresholds: a 1d tensor of size (n_labels, ) with the corresponding threshold level per label

Return type:

(tuple)

Example

>>> import torch
>>> from torchmetrics.functional.classification import multilabel_recall_at_fixed_precision
>>> preds = torch.tensor([[0.75, 0.05, 0.35],
...                       [0.45, 0.75, 0.05],
...                       [0.05, 0.55, 0.75],
...                       [0.05, 0.65, 0.05]])
>>> target = torch.tensor([[1, 0, 1],
...                        [0, 0, 0],
...                        [0, 1, 1],
...                        [1, 1, 1]])
>>> multilabel_recall_at_fixed_precision(preds, target, num_labels=3, min_precision=0.5, thresholds=None)
(tensor([1., 1., 1.]), tensor([0.0500, 0.5500, 0.0500]))
>>> multilabel_recall_at_fixed_precision(preds, target, num_labels=3, min_precision=0.5, thresholds=5)
(tensor([1., 1., 1.]), tensor([0.0000, 0.5000, 0.0000]))

ROC

Module Interface

class torchmetrics.ROC(**kwargs)[source]

Compute the Receiver Operating Characteristic (ROC).

The curve consists of multiple pairs of true positive rate (TPR) and false positive rate (FPR) values evaluated at different thresholds, such that the tradeoff between the two values can be seen.

This function is a simple wrapper to get the task-specific versions of this metric, which is done by setting the task argument to either 'binary', 'multiclass' or 'multilabel'. See the documentation of BinaryROC, MulticlassROC and MultilabelROC for the specific details and influence of each argument, together with examples.

Legacy Example:
>>> from torch import tensor
>>> pred = tensor([0.0, 1.0, 2.0, 3.0])
>>> target = tensor([0, 1, 1, 1])
>>> roc = ROC(task="binary")
>>> fpr, tpr, thresholds = roc(pred, target)
>>> fpr
tensor([0., 0., 0., 0., 1.])
>>> tpr
tensor([0.0000, 0.3333, 0.6667, 1.0000, 1.0000])
>>> thresholds
tensor([1.0000, 0.9526, 0.8808, 0.7311, 0.5000])
>>> pred = tensor([[0.75, 0.05, 0.05, 0.05],
...                [0.05, 0.75, 0.05, 0.05],
...                [0.05, 0.05, 0.75, 0.05],
...                [0.05, 0.05, 0.05, 0.75]])
>>> target = tensor([0, 1, 3, 2])
>>> roc = ROC(task="multiclass", num_classes=4)
>>> fpr, tpr, thresholds = roc(pred, target)
>>> fpr
[tensor([0., 0., 1.]), tensor([0., 0., 1.]), tensor([0.0000, 0.3333, 1.0000]), tensor([0.0000, 0.3333, 1.0000])]
>>> tpr
[tensor([0., 1., 1.]), tensor([0., 1., 1.]), tensor([0., 0., 1.]), tensor([0., 0., 1.])]
>>> thresholds  
[tensor([1.0000, 0.7500, 0.0500]),
 tensor([1.0000, 0.7500, 0.0500]),
 tensor([1.0000, 0.7500, 0.0500]),
 tensor([1.0000, 0.7500, 0.0500])]
>>> pred = tensor([[0.8191, 0.3680, 0.1138],
...                [0.3584, 0.7576, 0.1183],
...                [0.2286, 0.3468, 0.1338],
...                [0.8603, 0.0745, 0.1837]])
>>> target = tensor([[1, 1, 0], [0, 1, 0], [0, 0, 0], [0, 1, 1]])
>>> roc = ROC(task='multilabel', num_labels=3)
>>> fpr, tpr, thresholds = roc(pred, target)
>>> fpr
[tensor([0.0000, 0.3333, 0.3333, 0.6667, 1.0000]),
 tensor([0., 0., 0., 1., 1.]),
 tensor([0.0000, 0.0000, 0.3333, 0.6667, 1.0000])]
>>> tpr
[tensor([0., 0., 1., 1., 1.]),
 tensor([0.0000, 0.3333, 0.6667, 0.6667, 1.0000]),
 tensor([0., 1., 1., 1., 1.])]
>>> thresholds
[tensor([1.0000, 0.8603, 0.8191, 0.3584, 0.2286]),
 tensor([1.0000, 0.7576, 0.3680, 0.3468, 0.0745]),
 tensor([1.0000, 0.1837, 0.1338, 0.1183, 0.1138])]
static __new__(cls, task, thresholds=None, num_classes=None, num_labels=None, ignore_index=None, validate_args=True, **kwargs)[source]

Initialize task metric.

Return type:

Metric

BinaryROC
class torchmetrics.classification.BinaryROC(thresholds=None, ignore_index=None, validate_args=True, **kwargs)[source]

Compute the Receiver Operating Characteristic (ROC) for binary tasks.

The curve consists of multiple pairs of true positive rate (TPR) and false positive rate (FPR) values evaluated at different thresholds, such that the tradeoff between the two values can be seen.

As input to forward and update the metric accepts the following input:

  • preds (Tensor): A float tensor of shape (N, ...). Preds should be a tensor containing probabilities or logits for each observation. If preds has values outside [0,1] range we consider the input to be logits and will auto apply sigmoid per element.

  • target (Tensor): An int tensor of shape (N, ...). Target should be a tensor containing ground truth labels, and therefore only contain {0,1} values (except if ignore_index is specified). The value 1 always encodes the positive class.

Note

Additional dimension ... will be flattened into the batch dimension.

As output to forward and compute the metric returns a tuple of 3 tensors containing:

  • fpr (Tensor): A 1d tensor of size (n_thresholds+1, ) with false positive rate values

  • tpr (Tensor): A 1d tensor of size (n_thresholds+1, ) with true positive rate values

  • thresholds (Tensor): A 1d tensor of size (n_thresholds, ) with decreasing threshold values

Note

The implementation supports calculating the metric in both a non-binned but accurate version and a binned version that is less accurate but more memory efficient. Setting the thresholds argument to None will activate the non-binned version that uses memory of size \(\mathcal{O}(n_{samples})\), whereas setting the thresholds argument to either an integer, list or a 1d tensor will use a binned version that uses memory of size \(\mathcal{O}(n_{thresholds})\) (constant memory).

Note

The returned thresholds will be in reversed order to ensure that they correspond to both fpr and tpr, which are sorted in reversed order during their calculation so that they are monotonically increasing.

Parameters:
  • thresholds (Union[int, List[float], Tensor, None]) –

    Can be one of:

    • If set to None, will use a non-binned approach where thresholds are dynamically calculated from all the data. Most accurate but also most memory consuming approach.

    • If set to an int (larger than 1), will use that number of thresholds linearly spaced from 0 to 1 as bins for the calculation.

    • If set to a list of floats, will use the indicated thresholds in the list as bins for the calculation.

    • If set to a 1d tensor of floats, will use the indicated thresholds in the tensor as bins for the calculation.

  • validate_args (bool) – bool indicating if input arguments and tensors should be validated for correctness. Set to False for faster computations.

  • kwargs (Any) – Additional keyword arguments, see Advanced metric settings for more info.

Example

>>> from torch import tensor
>>> from torchmetrics.classification import BinaryROC
>>> preds = tensor([0, 0.5, 0.7, 0.8])
>>> target = tensor([0, 1, 1, 0])
>>> metric = BinaryROC(thresholds=None)
>>> metric(preds, target)  
(tensor([0.0000, 0.5000, 0.5000, 0.5000, 1.0000]),
 tensor([0.0000, 0.0000, 0.5000, 1.0000, 1.0000]),
 tensor([1.0000, 0.8000, 0.7000, 0.5000, 0.0000]))
>>> broc = BinaryROC(thresholds=5)
>>> broc(preds, target)  
(tensor([0.0000, 0.5000, 0.5000, 0.5000, 1.0000]),
 tensor([0., 0., 1., 1., 1.]),
 tensor([1.0000, 0.7500, 0.5000, 0.2500, 0.0000]))
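
Since the returned fpr/tpr pairs trace the whole curve, derived quantities can be computed from them directly; for instance, a minimal sketch of the area under the curve via the trapezoidal rule (torch.trapz is plain PyTorch, not part of this metric's API):

import torch
from torchmetrics.classification import BinaryROC

preds = torch.tensor([0, 0.5, 0.7, 0.8])
target = torch.tensor([0, 1, 1, 0])

fpr, tpr, thresholds = BinaryROC(thresholds=None)(preds, target)
# fpr is monotonically increasing, so integrating tpr over fpr gives the ROC-AUC
print(torch.trapz(tpr, fpr))  # expected tensor(0.5000) for this toy data
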
plot(curve=None, score=None, ax=None)[source]

Plot a single or multiple values from the metric.

Parameters:
  • curve (Optional[Tuple[Tensor, Tensor, Tensor]]) – the output of either metric.compute or metric.forward. If no value is provided, will automatically call metric.compute and plot that result.

  • score (Union[Tensor, bool, None]) – Provide an area-under-the-curve score to be displayed on the plot. If True and no curve is provided, will automatically compute the score.

  • ax (Optional[Axes]) – A matplotlib axis object. If provided, the plot will be added to that axis.

Return type:

Tuple[Figure, Union[Axes, ndarray]]

Returns:

Figure and Axes object

Raises:

ModuleNotFoundError – If matplotlib is not installed

>>> from torch import rand, randint
>>> from torchmetrics.classification import BinaryROC
>>> preds = rand(20)
>>> target = randint(2, (20,))
>>> metric = BinaryROC()
>>> metric.update(preds, target)
>>> fig_, ax_ = metric.plot(score=True)
_images/roc-1.png
MulticlassROC
class torchmetrics.classification.MulticlassROC(num_classes, thresholds=None, ignore_index=None, validate_args=True, **kwargs)[source]

Compute the Receiver Operating Characteristic (ROC) for multiclass tasks.

The curve consists of multiple pairs of true positive rate (TPR) and false positive rate (FPR) values evaluated at different thresholds, such that the tradeoff between the two values can be seen.

For multiclass the metric is calculated by iteratively treating each class as the positive class and all other classes as the negative, which is referred to as the one-vs-rest approach. One-vs-one is currently not supported by this metric.

As input to forward and update the metric accepts the following input:

  • preds (Tensor): A float tensor of shape (N, C, ...). Preds should be a tensor containing probabilities or logits for each observation. If preds has values outside [0,1] range we consider the input to be logits and will auto apply softmax per sample.

  • target (Tensor): An int tensor of shape (N, ...). Target should be a tensor containing ground truth labels, and therefore only contain values in the [0, n_classes-1] range (except if ignore_index is specified).

Note

Additional dimension ... will be flattened into the batch dimension.

As output to forward and compute the metric returns a tuple of either 3 tensors or 3 lists containing

  • fpr (Tensor): if thresholds=None a list for each class is returned with a 1d tensor of size (n_thresholds+1, ) with false positive rate values (length may differ between classes). If thresholds is set to something else, then a single 2d tensor of size (n_classes, n_thresholds+1) with false positive rate values is returned.

  • tpr (Tensor): if thresholds=None a list for each class is returned with a 1d tensor of size (n_thresholds+1, ) with true positive rate values (length may differ between classes). If thresholds is set to something else, then a single 2d tensor of size (n_classes, n_thresholds+1) with true positive rate values is returned.

  • thresholds (Tensor): if thresholds=None a list for each class is returned with a 1d tensor of size (n_thresholds, ) with decreasing threshold values (length may differ between classes). If thresholds is set to something else, then a single 1d tensor of size (n_thresholds, ) is returned with shared threshold values for all classes.

Note

The implementation supports calculating the metric in both a non-binned but accurate version and a binned version that is less accurate but more memory efficient. Setting the thresholds argument to None will activate the non-binned version that uses memory of size \(\mathcal{O}(n_{samples})\), whereas setting the thresholds argument to either an integer, list or a 1d tensor will use a binned version that uses memory of size \(\mathcal{O}(n_{thresholds} \times n_{classes})\) (constant memory).

Note

Note that the returned thresholds will be in reversed order to ensure that they correspond to both fpr and tpr, which are sorted in reversed order during their calculation so that they are monotonically increasing.

Parameters:
  • num_classes (int) – Integer specifying the number of classes

  • thresholds (Union[int, List[float], Tensor, None]) –

    Can be one of:

    • If set to None, will use a non-binned approach where thresholds are dynamically calculated from all the data. Most accurate but also most memory consuming approach.

    • If set to an int (larger than 1), will use that number of thresholds linearly spaced from 0 to 1 as bins for the calculation.

    • If set to a list of floats, will use the indicated thresholds in the list as bins for the calculation.

    • If set to a 1d tensor of floats, will use the indicated thresholds in the tensor as bins for the calculation.

  • validate_args (bool) – bool indicating if input arguments and tensors should be validated for correctness. Set to False for faster computations.

  • kwargs (Any) – Additional keyword arguments, see Advanced metric settings for more info.

Example

>>> from torch import tensor
>>> from torchmetrics.classification import MulticlassROC
>>> preds = tensor([[0.75, 0.05, 0.05, 0.05, 0.05],
...                 [0.05, 0.75, 0.05, 0.05, 0.05],
...                 [0.05, 0.05, 0.75, 0.05, 0.05],
...                 [0.05, 0.05, 0.05, 0.75, 0.05]])
>>> target = tensor([0, 1, 3, 2])
>>> metric = MulticlassROC(num_classes=5, thresholds=None)
>>> fpr, tpr, thresholds = metric(preds, target)
>>> fpr  
[tensor([0., 0., 1.]), tensor([0., 0., 1.]), tensor([0.0000, 0.3333, 1.0000]),
 tensor([0.0000, 0.3333, 1.0000]), tensor([0., 1.])]
>>> tpr
[tensor([0., 1., 1.]), tensor([0., 1., 1.]), tensor([0., 0., 1.]), tensor([0., 0., 1.]), tensor([0., 0.])]
>>> thresholds  
[tensor([1.0000, 0.7500, 0.0500]), tensor([1.0000, 0.7500, 0.0500]),
 tensor([1.0000, 0.7500, 0.0500]), tensor([1.0000, 0.7500, 0.0500]), tensor([1.0000, 0.0500])]
>>> mcroc = MulticlassROC(num_classes=5, thresholds=5)
>>> mcroc(preds, target)  
(tensor([[0.0000, 0.0000, 0.0000, 0.0000, 1.0000],
         [0.0000, 0.0000, 0.0000, 0.0000, 1.0000],
         [0.0000, 0.3333, 0.3333, 0.3333, 1.0000],
         [0.0000, 0.3333, 0.3333, 0.3333, 1.0000],
         [0.0000, 0.0000, 0.0000, 0.0000, 1.0000]]),
 tensor([[0., 1., 1., 1., 1.],
         [0., 1., 1., 1., 1.],
         [0., 0., 0., 0., 1.],
         [0., 0., 0., 0., 1.],
         [0., 0., 0., 0., 0.]]),
 tensor([1.0000, 0.7500, 0.5000, 0.2500, 0.0000]))
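
With thresholds=None the outputs are per-class lists whose lengths may differ, so any post-processing has to iterate class by class; a minimal sketch computing a per-class area under the curve with the trapezoidal rule (plain PyTorch, not part of this metric's API):

import torch
from torchmetrics.classification import MulticlassROC

preds = torch.tensor([[0.75, 0.05, 0.05, 0.05, 0.05],
                      [0.05, 0.75, 0.05, 0.05, 0.05],
                      [0.05, 0.05, 0.75, 0.05, 0.05],
                      [0.05, 0.05, 0.05, 0.75, 0.05]])
target = torch.tensor([0, 1, 3, 2])

fpr, tpr, thresholds = MulticlassROC(num_classes=5, thresholds=None)(preds, target)
for class_idx, (class_fpr, class_tpr) in enumerate(zip(fpr, tpr)):
    # each entry is a 1d tensor for one class (one-vs-rest)
    print(class_idx, torch.trapz(class_tpr, class_fpr))
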
plot(curve=None, score=None, ax=None)[source]

Plot a single or multiple values from the metric.

Parameters:
  • curve (Union[Tuple[Tensor, Tensor, Tensor], Tuple[List[Tensor], List[Tensor], List[Tensor]], None]) – the output of either metric.compute or metric.forward. If no value is provided, will automatically call metric.compute and plot that result.

  • score (Union[Tensor, bool, None]) – Provide an area-under-the-curve score to be displayed on the plot. If True and no curve is provided, will automatically compute the score.

  • ax (Optional[Axes]) – A matplotlib axis object. If provided, the plot will be added to that axis.

Return type:

Tuple[Figure, Union[Axes, ndarray]]

Returns:

Figure and Axes object

Raises:

ModuleNotFoundError – If matplotlib is not installed

>>> from torch import randn, randint
>>> from torchmetrics.classification import MulticlassROC
>>> preds = randn(20, 3).softmax(dim=-1)
>>> target = randint(3, (20,))
>>> metric = MulticlassROC(num_classes=3)
>>> metric.update(preds, target)
>>> fig_, ax_ = metric.plot(score=True)
_images/roc-2.png
MultilabelROC
class torchmetrics.classification.MultilabelROC(num_labels, thresholds=None, ignore_index=None, validate_args=True, **kwargs)[source]

Compute the Receiver Operating Characteristic (ROC) for multilabel tasks.

The curve consists of multiple pairs of true positive rate (TPR) and false positive rate (FPR) values evaluated at different thresholds, such that the tradeoff between the two values can be seen.

As input to forward and update the metric accepts the following input:

  • preds (Tensor): A float tensor of shape (N, C, ...). Preds should be a tensor containing probabilities or logits for each observation. If preds has values outside [0,1] range we consider the input to be logits and will auto apply sigmoid per element.

  • target (Tensor): An int tensor of shape (N, C, ...). Target should be a tensor containing ground truth labels, and therefore only contain {0,1} values (except if ignore_index is specified).

Note

Additional dimension ... will be flattened into the batch dimension.

As output to forward and compute the metric returns a tuple of either 3 tensors or 3 lists containing

  • fpr (Tensor): if thresholds=None a list for each label is returned with a 1d tensor of size (n_thresholds+1, ) with false positive rate values (length may differ between labels). If thresholds is set to something else, then a single 2d tensor of size (n_labels, n_thresholds+1) with false positive rate values is returned.

  • tpr (Tensor): if thresholds=None a list for each label is returned with a 1d tensor of size (n_thresholds+1, ) with true positive rate values (length may differ between labels). If thresholds is set to something else, then a single 2d tensor of size (n_labels, n_thresholds+1) with true positive rate values is returned.

  • thresholds (Tensor): if thresholds=None a list for each label is returned with a 1d tensor of size (n_thresholds, ) with decreasing threshold values (length may differ between labels). If thresholds is set to something else, then a single 1d tensor of size (n_thresholds, ) is returned with shared threshold values for all labels.

Note

The implementation supports calculating the metric both in a non-binned but accurate version and in a binned version that is less accurate but more memory efficient. Setting the thresholds argument to None will activate the non-binned version that uses memory of size \(\mathcal{O}(n_{samples})\) whereas setting the thresholds argument to either an integer, list or a 1d tensor will use a binned version that uses memory of size \(\mathcal{O}(n_{thresholds} \times n_{labels})\) (constant memory).

Note

The outputted thresholds will be in reversed order to ensure that they correspond to both fpr and tpr, which are sorted in reversed order during their calculation, such that they are monotonically increasing.

Parameters:
  • num_labels (int) – Integer specifying the number of labels

  • thresholds (Union[int, List[float], Tensor, None]) –

    Can be one of:

    • If set to None, will use a non-binned approach where thresholds are dynamically calculated from all the data. Most accurate but also most memory consuming approach.

    • If set to an int (larger than 1), will use that number of thresholds linearly spaced from 0 to 1 as bins for the calculation.

    • If set to a list of floats, will use the indicated thresholds in the list as bins for the calculation.

    • If set to a 1d tensor of floats, will use the indicated thresholds in the tensor as bins for the calculation.

  • validate_args (bool) – bool indicating if input arguments and tensors should be validated for correctness. Set to False for faster computations.

  • kwargs (Any) – Additional keyword arguments, see Advanced metric settings for more info.

Example

>>> from torch import tensor
>>> from torchmetrics.classification import MultilabelROC
>>> preds = tensor([[0.75, 0.05, 0.35],
...                 [0.45, 0.75, 0.05],
...                 [0.05, 0.55, 0.75],
...                 [0.05, 0.65, 0.05]])
>>> target = tensor([[1, 0, 1],
...                  [0, 0, 0],
...                  [0, 1, 1],
...                  [1, 1, 1]])
>>> metric = MultilabelROC(num_labels=3, thresholds=None)
>>> fpr, tpr, thresholds = metric(preds, target)
>>> fpr  
[tensor([0.0000, 0.0000, 0.5000, 1.0000]),
 tensor([0.0000, 0.5000, 0.5000, 0.5000, 1.0000]),
 tensor([0., 0., 0., 1.])]
>>> tpr  
[tensor([0.0000, 0.5000, 0.5000, 1.0000]),
 tensor([0.0000, 0.0000, 0.5000, 1.0000, 1.0000]),
 tensor([0.0000, 0.3333, 0.6667, 1.0000])]
>>> thresholds  
[tensor([1.0000, 0.7500, 0.4500, 0.0500]),
 tensor([1.0000, 0.7500, 0.6500, 0.5500, 0.0500]),
 tensor([1.0000, 0.7500, 0.3500, 0.0500])]
>>> mlroc = MultilabelROC(num_labels=3, thresholds=5)
>>> mlroc(preds, target)  
(tensor([[0.0000, 0.0000, 0.0000, 0.5000, 1.0000],
         [0.0000, 0.5000, 0.5000, 0.5000, 1.0000],
         [0.0000, 0.0000, 0.0000, 0.0000, 1.0000]]),
 tensor([[0.0000, 0.5000, 0.5000, 0.5000, 1.0000],
         [0.0000, 0.0000, 1.0000, 1.0000, 1.0000],
         [0.0000, 0.3333, 0.3333, 0.6667, 1.0000]]),
 tensor([1.0000, 0.7500, 0.5000, 0.2500, 0.0000]))
plot(curve=None, score=None, ax=None)[source]

Plot a single or multiple values from the metric.

Parameters:
  • curve (Union[Tuple[Tensor, Tensor, Tensor], Tuple[List[Tensor], List[Tensor], List[Tensor]], None]) – the output of either metric.compute or metric.forward. If no value is provided, will automatically call metric.compute and plot that result.

  • score (Union[Tensor, bool, None]) – Provide an area-under-the-curve score to be displayed on the plot. If True and no curve is provided, will automatically compute the score.

  • ax (Optional[Axes]) – A matplotlib axis object. If provided, will add the plot to that axis.

Return type:

Tuple[Figure, Union[Axes, ndarray]]

Returns:

Figure and Axes object

Raises:

ModuleNotFoundError – If matplotlib is not installed

>>> from torch import rand, randint
>>> from torchmetrics.classification import MultilabelROC
>>> preds = rand(20, 3)
>>> target = randint(2, (20,3))
>>> metric = MultilabelROC(num_labels=3)
>>> metric.update(preds, target)
>>> fig_, ax_ = metric.plot(score=True)
_images/roc-3.png

Functional Interface

torchmetrics.functional.roc(preds, target, task, thresholds=None, num_classes=None, num_labels=None, ignore_index=None, validate_args=True)[source]

Compute the Receiver Operating Characteristic (ROC).

The curve consists of multiple pairs of true positive rate (TPR) and false positive rate (FPR) values evaluated at different thresholds, such that the tradeoff between the two values can be seen.

This function is a simple wrapper to get the task specific versions of this metric, which is done by setting the task argument to either 'binary', 'multiclass' or 'multilabel'. See the documentation of binary_roc(), multiclass_roc() and multilabel_roc() for the specific details of each argument's influence and examples.

Return type:

Union[Tuple[Tensor, Tensor, Tensor], Tuple[List[Tensor], List[Tensor], List[Tensor]]]

Legacy Example:
>>> import torch
>>> pred = torch.tensor([0.0, 1.0, 2.0, 3.0])
>>> target = torch.tensor([0, 1, 1, 1])
>>> fpr, tpr, thresholds = roc(pred, target, task='binary')
>>> fpr
tensor([0., 0., 0., 0., 1.])
>>> tpr
tensor([0.0000, 0.3333, 0.6667, 1.0000, 1.0000])
>>> thresholds
tensor([1.0000, 0.9526, 0.8808, 0.7311, 0.5000])
>>> pred = torch.tensor([[0.75, 0.05, 0.05, 0.05],
...                      [0.05, 0.75, 0.05, 0.05],
...                      [0.05, 0.05, 0.75, 0.05],
...                      [0.05, 0.05, 0.05, 0.75]])
>>> target = torch.tensor([0, 1, 3, 2])
>>> fpr, tpr, thresholds = roc(pred, target, task='multiclass', num_classes=4)
>>> fpr
[tensor([0., 0., 1.]), tensor([0., 0., 1.]), tensor([0.0000, 0.3333, 1.0000]), tensor([0.0000, 0.3333, 1.0000])]
>>> tpr
[tensor([0., 1., 1.]), tensor([0., 1., 1.]), tensor([0., 0., 1.]), tensor([0., 0., 1.])]
>>> thresholds
[tensor([1.0000, 0.7500, 0.0500]),
 tensor([1.0000, 0.7500, 0.0500]),
 tensor([1.0000, 0.7500, 0.0500]),
 tensor([1.0000, 0.7500, 0.0500])]
>>> pred = torch.tensor([[0.8191, 0.3680, 0.1138],
...                      [0.3584, 0.7576, 0.1183],
...                      [0.2286, 0.3468, 0.1338],
...                      [0.8603, 0.0745, 0.1837]])
>>> target = torch.tensor([[1, 1, 0], [0, 1, 0], [0, 0, 0], [0, 1, 1]])
>>> fpr, tpr, thresholds = roc(pred, target, task='multilabel', num_labels=3)
>>> fpr
[tensor([0.0000, 0.3333, 0.3333, 0.6667, 1.0000]),
 tensor([0., 0., 0., 1., 1.]),
 tensor([0.0000, 0.0000, 0.3333, 0.6667, 1.0000])]
>>> tpr
[tensor([0., 0., 1., 1., 1.]), tensor([0.0000, 0.3333, 0.6667, 0.6667, 1.0000]), tensor([0., 1., 1., 1., 1.])]
>>> thresholds
[tensor([1.0000, 0.8603, 0.8191, 0.3584, 0.2286]),
 tensor([1.0000, 0.7576, 0.3680, 0.3468, 0.0745]),
 tensor([1.0000, 0.1837, 0.1338, 0.1183, 0.1138])]
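
This wrapper simply dispatches to the task-specific functions documented below. As an illustrative check (a sketch, not one of the official examples), the 'binary' call above is equivalent to calling binary_roc() directly:

>>> import torch
>>> from torchmetrics.functional import roc
>>> from torchmetrics.functional.classification import binary_roc
>>> pred = torch.tensor([0.0, 1.0, 2.0, 3.0])
>>> target = torch.tensor([0, 1, 1, 1])
>>> # the wrapper with task='binary' and the task-specific function give the same curve
>>> wrapped = roc(pred, target, task='binary')
>>> direct = binary_roc(pred, target)
>>> all(torch.allclose(w, d) for w, d in zip(wrapped, direct))
True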
binary_roc
torchmetrics.functional.classification.binary_roc(preds, target, thresholds=None, ignore_index=None, validate_args=True)[source]

Compute the Receiver Operating Characteristic (ROC) for binary tasks.

The curve consists of multiple pairs of true positive rate (TPR) and false positive rate (FPR) values evaluated at different thresholds, such that the tradeoff between the two values can be seen.

Accepts the following input tensors:

  • preds (float tensor): (N, ...). Preds should be a tensor containing probabilities or logits for each observation. If preds has values outside the [0,1] range, we consider the input to be logits and will auto apply sigmoid per element.

  • target (int tensor): (N, ...). Target should be a tensor containing ground truth labels, and therefore only contain {0,1} values (except if ignore_index is specified). The value 1 always encodes the positive class.

Additional dimension ... will be flattened into the batch dimension.

The implementation supports calculating the metric both in a non-binned but accurate version and in a binned version that is less accurate but more memory efficient. Setting the thresholds argument to None will activate the non-binned version that uses memory of size \(\mathcal{O}(n_{samples})\) whereas setting the thresholds argument to either an integer, list or a 1d tensor will use a binned version that uses memory of size \(\mathcal{O}(n_{thresholds})\) (constant memory).

Note that outputted thresholds will be in reversed order to ensure that they correspond to both fpr and tpr, which are sorted in reversed order during their calculation, such that they are monotonically increasing.

Parameters:
  • preds (Tensor) – Tensor with predictions

  • target (Tensor) – Tensor with true labels

  • thresholds (Union[int, List[float], Tensor, None]) –

    Can be one of:

    • If set to None, will use a non-binned approach where thresholds are dynamically calculated from all the data. Most accurate but also most memory consuming approach.

    • If set to an int (larger than 1), will use that number of thresholds linearly spaced from 0 to 1 as bins for the calculation.

    • If set to a list of floats, will use the indicated thresholds in the list as bins for the calculation.

    • If set to a 1d tensor of floats, will use the indicated thresholds in the tensor as bins for the calculation.

  • ignore_index (Optional[int]) – Specifies a target value that is ignored and does not contribute to the metric calculation

  • validate_args (bool) – bool indicating if input arguments and tensors should be validated for correctness. Set to False for faster computations.

Returns:

a tuple of 3 tensors containing:

  • fpr: a 1d tensor of size (n_thresholds+1, ) with false positive rate values

  • tpr: a 1d tensor of size (n_thresholds+1, ) with true positive rate values

  • thresholds: a 1d tensor of size (n_thresholds, ) with decreasing threshold values

Return type:

(tuple)

Example

>>> import torch
>>> from torchmetrics.functional.classification import binary_roc
>>> preds = torch.tensor([0, 0.5, 0.7, 0.8])
>>> target = torch.tensor([0, 1, 1, 0])
>>> binary_roc(preds, target, thresholds=None)  
(tensor([0.0000, 0.5000, 0.5000, 0.5000, 1.0000]),
 tensor([0.0000, 0.0000, 0.5000, 1.0000, 1.0000]),
 tensor([1.0000, 0.8000, 0.7000, 0.5000, 0.0000]))
>>> binary_roc(preds, target, thresholds=5)  
(tensor([0.0000, 0.5000, 0.5000, 0.5000, 1.0000]),
 tensor([0., 0., 1., 1., 1.]),
 tensor([1.0000, 0.7500, 0.5000, 0.2500, 0.0000]))
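
The thresholds argument accepts several equivalent forms for the binned version. As an illustrative sketch (not an official example), passing an integer is the same as passing the corresponding evenly spaced list of floats:

>>> import torch
>>> from torchmetrics.functional.classification import binary_roc
>>> preds = torch.tensor([0, 0.5, 0.7, 0.8])
>>> target = torch.tensor([0, 1, 1, 0])
>>> # thresholds=5 corresponds to 5 linearly spaced bins between 0 and 1
>>> from_int = binary_roc(preds, target, thresholds=5)
>>> from_list = binary_roc(preds, target, thresholds=[0.0, 0.25, 0.5, 0.75, 1.0])
>>> all(torch.allclose(a, b) for a, b in zip(from_int, from_list))
True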
multiclass_roc
torchmetrics.functional.classification.multiclass_roc(preds, target, num_classes, thresholds=None, ignore_index=None, validate_args=True)[source]

Compute the Receiver Operating Characteristic (ROC) for multiclass tasks.

The curve consists of multiple pairs of true positive rate (TPR) and false positive rate (FPR) values evaluated at different thresholds, such that the tradeoff between the two values can be seen.

Accepts the following input tensors:

  • preds (float tensor): (N, C, ...). Preds should be a tensor containing probabilities or logits for each observation. If preds has values outside the [0,1] range, we consider the input to be logits and will auto apply softmax per sample.

  • target (int tensor): (N, ...). Target should be a tensor containing ground truth labels, and therefore only contain values in the [0, n_classes-1] range (except if ignore_index is specified).

Additional dimension ... will be flattened into the batch dimension.

The implementation supports calculating the metric both in a non-binned but accurate version and in a binned version that is less accurate but more memory efficient. Setting the thresholds argument to None will activate the non-binned version that uses memory of size \(\mathcal{O}(n_{samples})\) whereas setting the thresholds argument to either an integer, list or a 1d tensor will use a binned version that uses memory of size \(\mathcal{O}(n_{thresholds} \times n_{classes})\) (constant memory).

Note that outputted thresholds will be in reversed order to ensure that they correspond to both fpr and tpr, which are sorted in reversed order during their calculation, such that they are monotonically increasing.

Parameters:
  • preds (Tensor) – Tensor with predictions

  • target (Tensor) – Tensor with true labels

  • num_classes (int) – Integer specifying the number of classes

  • thresholds (Union[int, List[float], Tensor, None]) –

    Can be one of:

    • If set to None, will use a non-binned approach where thresholds are dynamically calculated from all the data. Most accurate but also most memory consuming approach.

    • If set to an int (larger than 1), will use that number of thresholds linearly spaced from 0 to 1 as bins for the calculation.

    • If set to a list of floats, will use the indicated thresholds in the list as bins for the calculation.

    • If set to a 1d tensor of floats, will use the indicated thresholds in the tensor as bins for the calculation.

  • ignore_index (Optional[int]) – Specifies a target value that is ignored and does not contribute to the metric calculation

  • validate_args (bool) – bool indicating if input arguments and tensors should be validated for correctness. Set to False for faster computations.

Returns:

a tuple of either 3 tensors or 3 lists containing

  • fpr: if thresholds=None a list for each class is returned with a 1d tensor of size (n_thresholds+1, ) with false positive rate values (length may differ between classes). If thresholds is set to something else, then a single 2d tensor of size (n_classes, n_thresholds+1) with false positive rate values is returned.

  • tpr: if thresholds=None a list for each class is returned with a 1d tensor of size (n_thresholds+1, ) with true positive rate values (length may differ between classes). If thresholds is set to something else, then a single 2d tensor of size (n_classes, n_thresholds+1) with true positive rate values is returned.

  • thresholds: if thresholds=None a list for each class is returned with a 1d tensor of size (n_thresholds, ) with decreasing threshold values (length may differ between classes). If thresholds is set to something else, then a single 1d tensor of size (n_thresholds, ) is returned with shared threshold values for all classes.

Return type:

(tuple)

Example

>>> import torch
>>> from torchmetrics.functional.classification import multiclass_roc
>>> preds = torch.tensor([[0.75, 0.05, 0.05, 0.05, 0.05],
...                       [0.05, 0.75, 0.05, 0.05, 0.05],
...                       [0.05, 0.05, 0.75, 0.05, 0.05],
...                       [0.05, 0.05, 0.05, 0.75, 0.05]])
>>> target = torch.tensor([0, 1, 3, 2])
>>> fpr, tpr, thresholds = multiclass_roc(
...    preds, target, num_classes=5, thresholds=None
... )
>>> fpr  
[tensor([0., 0., 1.]), tensor([0., 0., 1.]), tensor([0.0000, 0.3333, 1.0000]),
 tensor([0.0000, 0.3333, 1.0000]), tensor([0., 1.])]
>>> tpr
[tensor([0., 1., 1.]), tensor([0., 1., 1.]), tensor([0., 0., 1.]), tensor([0., 0., 1.]), tensor([0., 0.])]
>>> thresholds  
[tensor([1.0000, 0.7500, 0.0500]), tensor([1.0000, 0.7500, 0.0500]),
 tensor([1.0000, 0.7500, 0.0500]), tensor([1.0000, 0.7500, 0.0500]), tensor([1.0000, 0.0500])]
>>> multiclass_roc(
...     preds, target, num_classes=5, thresholds=5
... )  
(tensor([[0.0000, 0.0000, 0.0000, 0.0000, 1.0000],
         [0.0000, 0.0000, 0.0000, 0.0000, 1.0000],
         [0.0000, 0.3333, 0.3333, 0.3333, 1.0000],
         [0.0000, 0.3333, 0.3333, 0.3333, 1.0000],
         [0.0000, 0.0000, 0.0000, 0.0000, 1.0000]]),
 tensor([[0., 1., 1., 1., 1.],
         [0., 1., 1., 1., 1.],
         [0., 0., 0., 0., 1.],
         [0., 0., 0., 0., 1.],
         [0., 0., 0., 0., 0.]]),
 tensor([1.0000, 0.7500, 0.5000, 0.2500, 0.0000]))
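
Since the curves are computed in a one-vs-rest fashion, the curve for a given class matches a binary ROC computed on that class' score column against a binarized target. A small illustrative check (a sketch, not an official example):

>>> import torch
>>> from torchmetrics.functional.classification import binary_roc, multiclass_roc
>>> preds = torch.tensor([[0.75, 0.05, 0.05, 0.05, 0.05],
...                       [0.05, 0.75, 0.05, 0.05, 0.05],
...                       [0.05, 0.05, 0.75, 0.05, 0.05],
...                       [0.05, 0.05, 0.05, 0.75, 0.05]])
>>> target = torch.tensor([0, 1, 3, 2])
>>> mc_fpr, mc_tpr, _ = multiclass_roc(preds, target, num_classes=5, thresholds=None)
>>> # class 0 treated as the positive class, everything else as negative
>>> b_fpr, b_tpr, _ = binary_roc(preds[:, 0], (target == 0).long())
>>> torch.allclose(mc_fpr[0], b_fpr) and torch.allclose(mc_tpr[0], b_tpr)
True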
multilabel_roc
torchmetrics.functional.classification.multilabel_roc(preds, target, num_labels, thresholds=None, ignore_index=None, validate_args=True)[source]

Compute the Receiver Operating Characteristic (ROC) for multilabel tasks.

The curve consists of multiple pairs of true positive rate (TPR) and false positive rate (FPR) values evaluated at different thresholds, such that the tradeoff between the two values can be seen.

Accepts the following input tensors:

  • preds (float tensor): (N, C, ...). Preds should be a tensor containing probabilities or logits for each observation. If preds has values outside the [0,1] range, we consider the input to be logits and will auto apply sigmoid per element.

  • target (int tensor): (N, C, ...). Target should be a tensor containing ground truth labels, and therefore only contain {0,1} values (except if ignore_index is specified).

Additional dimension ... will be flattened into the batch dimension.

The implementation supports calculating the metric both in a non-binned but accurate version and in a binned version that is less accurate but more memory efficient. Setting the thresholds argument to None will activate the non-binned version that uses memory of size \(\mathcal{O}(n_{samples})\) whereas setting the thresholds argument to either an integer, list or a 1d tensor will use a binned version that uses memory of size \(\mathcal{O}(n_{thresholds} \times n_{labels})\) (constant memory).

Note that outputted thresholds will be in reversed order to ensure that they correspond to both fpr and tpr, which are sorted in reversed order during their calculation, such that they are monotonically increasing.

Parameters:
  • preds (Tensor) – Tensor with predictions

  • target (Tensor) – Tensor with true labels

  • num_labels (int) – Integer specifying the number of labels

  • thresholds (Union[int, List[float], Tensor, None]) –

    Can be one of:

    • If set to None, will use a non-binned approach where thresholds are dynamically calculated from all the data. Most accurate but also most memory consuming approach.

    • If set to an int (larger than 1), will use that number of thresholds linearly spaced from 0 to 1 as bins for the calculation.

    • If set to a list of floats, will use the indicated thresholds in the list as bins for the calculation.

    • If set to a 1d tensor of floats, will use the indicated thresholds in the tensor as bins for the calculation.

  • ignore_index (Optional[int]) – Specifies a target value that is ignored and does not contribute to the metric calculation

  • validate_args (bool) – bool indicating if input arguments and tensors should be validated for correctness. Set to False for faster computations.

Returns:

a tuple of either 3 tensors or 3 lists containing

  • fpr: if thresholds=None a list for each label is returned with a 1d tensor of size (n_thresholds+1, ) with false positive rate values (length may differ between labels). If thresholds is set to something else, then a single 2d tensor of size (n_labels, n_thresholds+1) with false positive rate values is returned.

  • tpr: if thresholds=None a list for each label is returned with a 1d tensor of size (n_thresholds+1, ) with true positive rate values (length may differ between labels). If thresholds is set to something else, then a single 2d tensor of size (n_labels, n_thresholds+1) with true positive rate values is returned.

  • thresholds: if thresholds=None a list for each label is returned with a 1d tensor of size (n_thresholds, ) with decreasing threshold values (length may differ between labels). If thresholds is set to something else, then a single 1d tensor of size (n_thresholds, ) is returned with shared threshold values for all labels.

Return type:

(tuple)

Example

>>> import torch
>>> from torchmetrics.functional.classification import multilabel_roc
>>> preds = torch.tensor([[0.75, 0.05, 0.35],
...                       [0.45, 0.75, 0.05],
...                       [0.05, 0.55, 0.75],
...                       [0.05, 0.65, 0.05]])
>>> target = torch.tensor([[1, 0, 1],
...                        [0, 0, 0],
...                        [0, 1, 1],
...                        [1, 1, 1]])
>>> fpr, tpr, thresholds = multilabel_roc(
...    preds, target, num_labels=3, thresholds=None
... )
>>> fpr  
[tensor([0.0000, 0.0000, 0.5000, 1.0000]),
 tensor([0.0000, 0.5000, 0.5000, 0.5000, 1.0000]),
 tensor([0., 0., 0., 1.])]
>>> tpr  
[tensor([0.0000, 0.5000, 0.5000, 1.0000]),
 tensor([0.0000, 0.0000, 0.5000, 1.0000, 1.0000]),
 tensor([0.0000, 0.3333, 0.6667, 1.0000])]
>>> thresholds  
[tensor([1.0000, 0.7500, 0.4500, 0.0500]),
 tensor([1.0000, 0.7500, 0.6500, 0.5500, 0.0500]),
 tensor([1.0000, 0.7500, 0.3500, 0.0500])]
>>> multilabel_roc(
...     preds, target, num_labels=3, thresholds=5
... )  
(tensor([[0.0000, 0.0000, 0.0000, 0.5000, 1.0000],
         [0.0000, 0.5000, 0.5000, 0.5000, 1.0000],
         [0.0000, 0.0000, 0.0000, 0.0000, 1.0000]]),
 tensor([[0.0000, 0.5000, 0.5000, 0.5000, 1.0000],
         [0.0000, 0.0000, 1.0000, 1.0000, 1.0000],
         [0.0000, 0.3333, 0.3333, 0.6667, 1.0000]]),
 tensor([1.0000, 0.7500, 0.5000, 0.2500, 0.0000]))
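
Each label is treated as an independent binary problem, so the curve for a given label matches binary_roc applied to that label's column of preds and target. A small illustrative check (a sketch, not an official example):

>>> import torch
>>> from torchmetrics.functional.classification import binary_roc, multilabel_roc
>>> preds = torch.tensor([[0.75, 0.05, 0.35],
...                       [0.45, 0.75, 0.05],
...                       [0.05, 0.55, 0.75],
...                       [0.05, 0.65, 0.05]])
>>> target = torch.tensor([[1, 0, 1],
...                        [0, 0, 0],
...                        [0, 1, 1],
...                        [1, 1, 1]])
>>> ml_fpr, ml_tpr, _ = multilabel_roc(preds, target, num_labels=3, thresholds=None)
>>> # label 0 on its own is just a binary problem
>>> b_fpr, b_tpr, _ = binary_roc(preds[:, 0], target[:, 0])
>>> torch.allclose(ml_fpr[0], b_fpr) and torch.allclose(ml_tpr[0], b_tpr)
True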

Specificity

Module Interface

class torchmetrics.Specificity(**kwargs)[source]

Compute Specificity.

\[\text{Specificity} = \frac{\text{TN}}{\text{TN} + \text{FP}}\]

Where \(\text{TN}\) and \(\text{FP}\) represent the number of true negatives and false positives respectively. The metric is only properly defined when \(\text{TN} + \text{FP} \neq 0\). If this case is encountered for any class/label, the metric for that class/label will be set to 0 and the overall metric may therefore be affected in turn.

This function is a simple wrapper to get the task specific versions of this metric, which is done by setting the task argument to either 'binary', 'multiclass' or 'multilabel'. See the documentation of BinarySpecificity, MulticlassSpecificity and MultilabelSpecificity for the specific details of each argument's influence and examples.

Legacy Example:
>>> from torch import tensor
>>> preds  = tensor([2, 0, 2, 1])
>>> target = tensor([1, 1, 2, 0])
>>> specificity = Specificity(task="multiclass", average='macro', num_classes=3)
>>> specificity(preds, target)
tensor(0.6111)
>>> specificity = Specificity(task="multiclass", average='micro', num_classes=3)
>>> specificity(preds, target)
tensor(0.6250)
static __new__(cls, task, threshold=0.5, num_classes=None, num_labels=None, average='micro', multidim_average='global', top_k=1, ignore_index=None, validate_args=True, **kwargs)[source]

Initialize task metric.

Return type:

Metric
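
Since __new__ dispatches on the task argument, the wrapper instantiates the corresponding task-specific metric. A minimal illustrative check (a sketch, assuming only the dispatch behaviour described above):

>>> from torchmetrics import Specificity
>>> from torchmetrics.classification import BinarySpecificity, MulticlassSpecificity
>>> # the wrapper returns an instance of the task-specific class
>>> isinstance(Specificity(task="binary"), BinarySpecificity)
True
>>> isinstance(Specificity(task="multiclass", num_classes=3), MulticlassSpecificity)
True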

BinarySpecificity
class torchmetrics.classification.BinarySpecificity(threshold=0.5, multidim_average='global', ignore_index=None, validate_args=True, **kwargs)[source]

Compute Specificity for binary tasks.

\[\text{Specificity} = \frac{\text{TN}}{\text{TN} + \text{FP}}\]

Where \(\text{TN}\) and \(\text{FP}\) represent the number of true negatives and false positives respectively. The metric is only properly defined when \(\text{TN} + \text{FP} \neq 0\). If this case is encountered, a score of 0 is returned.

As input to forward and update the metric accepts the following input:

  • preds (Tensor): An int or float tensor of shape (N, ...). If preds is a floating point tensor with values outside the [0,1] range, we consider the input to be logits and will auto apply sigmoid per element. Additionally, we convert to int tensor with thresholding using the value in threshold.

  • target (Tensor): An int tensor of shape (N, ...)

As output to forward and compute the metric returns the following output:

  • bs (Tensor): If multidim_average is set to global, the metric returns a scalar value. If multidim_average is set to samplewise, the metric returns an (N,) vector consisting of a scalar value per sample.

Parameters:
  • threshold (float) – Threshold for transforming probability to binary {0,1} predictions

  • multidim_average (Literal['global', 'samplewise']) –

    Defines how additional dimensions ... should be handled. Should be one of the following:

    • global: Additional dimensions are flattened along the batch dimension

    • samplewise: Statistic will be calculated independently for each sample on the N axis. The statistics in this case are calculated over the additional dimensions.

  • ignore_index (Optional[int]) – Specifies a target value that is ignored and does not contribute to the metric calculation

  • validate_args (bool) – bool indicating if input arguments and tensors should be validated for correctness. Set to False for faster computations.

Example (preds is int tensor):
>>> from torch import tensor
>>> from torchmetrics.classification import BinarySpecificity
>>> target = tensor([0, 1, 0, 1, 0, 1])
>>> preds = tensor([0, 0, 1, 1, 0, 1])
>>> metric = BinarySpecificity()
>>> metric(preds, target)
tensor(0.6667)
Example (preds is float tensor):
>>> from torchmetrics.classification import BinarySpecificity
>>> target = tensor([0, 1, 0, 1, 0, 1])
>>> preds = tensor([0.11, 0.22, 0.84, 0.73, 0.33, 0.92])
>>> metric = BinarySpecificity()
>>> metric(preds, target)
tensor(0.6667)
Example (multidim tensors):
>>> from torchmetrics.classification import BinarySpecificity
>>> target = tensor([[[0, 1], [1, 0], [0, 1]], [[1, 1], [0, 0], [1, 0]]])
>>> preds = tensor([[[0.59, 0.91], [0.91, 0.99],  [0.63, 0.04]],
...                 [[0.38, 0.04], [0.86, 0.780], [0.45, 0.37]]])
>>> metric = BinarySpecificity(multidim_average='samplewise')
>>> metric(preds, target)
tensor([0.0000, 0.3333])
plot(val=None, ax=None)[source]

Plot a single or multiple values from the metric.

Parameters:
  • val (Union[Tensor, Sequence[Tensor], None]) – Either a single result from calling metric.forward or metric.compute or a list of these results. If no value is provided, will automatically call metric.compute and plot that result.

  • ax (Optional[Axes]) – A matplotlib axis object. If provided, will add the plot to that axis.

Return type:

Tuple[Figure, Union[Axes, ndarray]]

Returns:

Figure object and Axes object

Raises:

ModuleNotFoundError – If matplotlib is not installed

>>> from torch import rand, randint
>>> # Example plotting a single value
>>> from torchmetrics.classification import BinarySpecificity
>>> metric = BinarySpecificity()
>>> metric.update(rand(10), randint(2,(10,)))
>>> fig_, ax_ = metric.plot()
_images/specificity-1.png
>>> from torch import rand, randint
>>> # Example plotting multiple values
>>> from torchmetrics.classification import BinarySpecificity
>>> metric = BinarySpecificity()
>>> values = [ ]
>>> for _ in range(10):
...     values.append(metric(rand(10), randint(2,(10,))))
>>> fig_, ax_ = metric.plot(values)
_images/specificity-2.png
MulticlassSpecificity
class torchmetrics.classification.MulticlassSpecificity(num_classes, top_k=1, average='macro', multidim_average='global', ignore_index=None, validate_args=True, **kwargs)[source]

Compute Specificity for multiclass tasks.

\[\text{Specificity} = \frac{\text{TN}}{\text{TN} + \text{FP}}\]

Where \(\text{TN}\) and \(\text{FP}\) represent the number of true negatives and false positives respectively. The metric is only properly defined when \(\text{TN} + \text{FP} \neq 0\). If this case is encountered for any class, the metric for that class will be set to 0 and the overall metric may therefore be affected in turn.

As input to forward and update the metric accepts the following input:

  • preds (Tensor): An int tensor of shape (N, ...) or float tensor of shape (N, C, ..). If preds is a floating point tensor, we apply torch.argmax along the C dimension to automatically convert probabilities/logits into an int tensor.

  • target (Tensor): An int tensor of shape (N, ...)

As output to forward and compute the metric returns the following output:

  • mcs (Tensor): The returned shape depends on the average and multidim_average arguments:

    • If multidim_average is set to global:

      • If average='micro'/'macro'/'weighted', the output will be a scalar tensor

      • If average=None/'none', the shape will be (C,)

    • If multidim_average is set to samplewise:

      • If average='micro'/'macro'/'weighted', the shape will be (N,)

      • If average=None/'none', the shape will be (N, C)

Parameters:
  • num_classes (int) – Integer specifying the number of classes

  • average (Optional[Literal['micro', 'macro', 'weighted', 'none']]) –

    Defines the reduction that is applied over labels. Should be one of the following:

    • micro: Sum statistics over all labels

    • macro: Calculate statistics for each label and average them

    • weighted: calculates statistics for each label and computes weighted average using their support

    • "none" or None: calculates statistic for each label and applies no reduction

  • top_k (int) – Number of highest probability or logit score predictions considered to find the correct label. Only works when preds contain probabilities/logits.

  • multidim_average (Literal['global', 'samplewise']) –

    Defines how additional dimensions ... should be handled. Should be one of the following:

    • global: Additional dimensions are flattened along the batch dimension

    • samplewise: Statistic will be calculated independently for each sample on the N axis. The statistics in this case are calculated over the additional dimensions.

  • ignore_index (Optional[int]) – Specifies a target value that is ignored and does not contribute to the metric calculation

  • validate_args (bool) – bool indicating if input arguments and tensors should be validated for correctness. Set to False for faster computations.

Example (preds is int tensor):
>>> from torch import tensor
>>> from torchmetrics.classification import MulticlassSpecificity
>>> target = tensor([2, 1, 0, 0])
>>> preds = tensor([2, 1, 0, 1])
>>> metric = MulticlassSpecificity(num_classes=3)
>>> metric(preds, target)
tensor(0.8889)
>>> mcs = MulticlassSpecificity(num_classes=3, average=None)
>>> mcs(preds, target)
tensor([1.0000, 0.6667, 1.0000])
Example (preds is float tensor):
>>> from torchmetrics.classification import MulticlassSpecificity
>>> target = tensor([2, 1, 0, 0])
>>> preds = tensor([[0.16, 0.26, 0.58],
...                 [0.22, 0.61, 0.17],
...                 [0.71, 0.09, 0.20],
...                 [0.05, 0.82, 0.13]])
>>> metric = MulticlassSpecificity(num_classes=3)
>>> metric(preds, target)
tensor(0.8889)
>>> mcs = MulticlassSpecificity(num_classes=3, average=None)
>>> mcs(preds, target)
tensor([1.0000, 0.6667, 1.0000])
Example (multidim tensors):
>>> from torchmetrics.classification import MulticlassSpecificity
>>> target = tensor([[[0, 1], [2, 1], [0, 2]], [[1, 1], [2, 0], [1, 2]]])
>>> preds = tensor([[[0, 2], [2, 0], [0, 1]], [[2, 2], [2, 1], [1, 0]]])
>>> metric = MulticlassSpecificity(num_classes=3, multidim_average='samplewise')
>>> metric(preds, target)
tensor([0.7500, 0.6556])
>>> mcs = MulticlassSpecificity(num_classes=3, multidim_average='samplewise', average=None)
>>> mcs(preds, target)
tensor([[0.7500, 0.7500, 0.7500],
        [0.8000, 0.6667, 0.5000]])
plot(val=None, ax=None)[source]

Plot a single or multiple values from the metric.

Parameters:
  • val (Union[Tensor, Sequence[Tensor], None]) – Either a single result from calling metric.forward or metric.compute or a list of these results. If no value is provided, will automatically call metric.compute and plot that result.

  • ax (Optional[Axes]) – A matplotlib axis object. If provided, will add the plot to that axis.

Return type:

Tuple[Figure, Union[Axes, ndarray]]

Returns:

Figure object and Axes object

Raises:

ModuleNotFoundError – If matplotlib is not installed

>>> from torch import randint
>>> # Example plotting a single value per class
>>> from torchmetrics.classification import MulticlassSpecificity
>>> metric = MulticlassSpecificity(num_classes=3, average=None)
>>> metric.update(randint(3, (20,)), randint(3, (20,)))
>>> fig_, ax_ = metric.plot()
_images/specificity-3.png
>>> from torch import randint
>>> # Example plotting multiple values per class
>>> from torchmetrics.classification import MulticlassSpecificity
>>> metric = MulticlassSpecificity(num_classes=3, average=None)
>>> values = []
>>> for _ in range(20):
...     values.append(metric(randint(3, (20,)), randint(3, (20,))))
>>> fig_, ax_ = metric.plot(values)
_images/specificity-4.png
MultilabelSpecificity
class torchmetrics.classification.MultilabelSpecificity(num_labels, threshold=0.5, average='macro', multidim_average='global', ignore_index=None, validate_args=True, **kwargs)[source]

Compute Specificity for multilabel tasks.

\[\text{Specificity} = \frac{\text{TN}}{\text{TN} + \text{FP}}\]

Where \(\text{TN}\) and \(\text{FP}\) represent the number of true negatives and false positives respectively. The metric is only properly defined when \(\text{TN} + \text{FP} \neq 0\). If this case is encountered for any label, the metric for that label will be set to 0 and the overall metric may therefore be affected in turn.

As input to forward and update the metric accepts the following input:

  • preds (Tensor): An int or float tensor of shape (N, C, ...). If preds is a floating point tensor with values outside the [0,1] range, we consider the input to be logits and will auto apply sigmoid per element. Additionally, we convert to int tensor with thresholding using the value in threshold.

  • target (Tensor): An int tensor of shape (N, C, ...)

As output to forward and compute the metric returns the following output:

  • mls (Tensor): The returned shape depends on the average and multidim_average arguments:

    • If multidim_average is set to global

      • If average='micro'/'macro'/'weighted', the output will be a scalar tensor

      • If average=None/'none', the shape will be (C,)

    • If multidim_average is set to samplewise

      • If average='micro'/'macro'/'weighted', the shape will be (N,)

      • If average=None/'none', the shape will be (N, C)

Parameters:
  • num_labels (int) – Integer specifying the number of labels

  • threshold (float) – Threshold for transforming probability to binary {0,1} predictions

  • average (Optional[Literal['micro', 'macro', 'weighted', 'none']]) –

    Defines the reduction that is applied over labels. Should be one of the following:

    • micro: Sum statistics over all labels

    • macro: Calculate statistics for each label and average them

    • weighted: calculates statistics for each label and computes weighted average using their support

    • "none" or None: calculates statistic for each label and applies no reduction

  • multidim_average (Literal['global', 'samplewise']) –

    Defines how additional dimensions ... should be handled. Should be one of the following:

    • global: Additional dimensions are flattened along the batch dimension

    • samplewise: Statistic will be calculated independently for each sample on the N axis. The statistics in this case are calculated over the additional dimensions.

  • ignore_index (Optional[int]) – Specifies a target value that is ignored and does not contribute to the metric calculation

  • validate_args (bool) – bool indicating if input arguments and tensors should be validated for correctness. Set to False for faster computations.

Example (preds is int tensor):
>>> from torch import tensor
>>> from torchmetrics.classification import MultilabelSpecificity
>>> target = tensor([[0, 1, 0], [1, 0, 1]])
>>> preds = tensor([[0, 0, 1], [1, 0, 1]])
>>> metric = MultilabelSpecificity(num_labels=3)
>>> metric(preds, target)
tensor(0.6667)
>>> mls = MultilabelSpecificity(num_labels=3, average=None)
>>> mls(preds, target)
tensor([1., 1., 0.])
Example (preds is float tensor):
>>> from torchmetrics.classification import MultilabelSpecificity
>>> target = tensor([[0, 1, 0], [1, 0, 1]])
>>> preds = tensor([[0.11, 0.22, 0.84], [0.73, 0.33, 0.92]])
>>> metric = MultilabelSpecificity(num_labels=3)
>>> metric(preds, target)
tensor(0.6667)
>>> mls = MultilabelSpecificity(num_labels=3, average=None)
>>> mls(preds, target)
tensor([1., 1., 0.])
Example (multidim tensors):
>>> from torchmetrics.classification import MultilabelSpecificity
>>> target = tensor([[[0, 1], [1, 0], [0, 1]], [[1, 1], [0, 0], [1, 0]]])
>>> preds = tensor([[[0.59, 0.91], [0.91, 0.99],  [0.63, 0.04]],
...                 [[0.38, 0.04], [0.86, 0.780], [0.45, 0.37]]])
>>> metric = MultilabelSpecificity(num_labels=3, multidim_average='samplewise')
>>> metric(preds, target)
tensor([0.0000, 0.3333])
>>> mls = MultilabelSpecificity(num_labels=3, multidim_average='samplewise', average=None)
>>> mls(preds, target)
tensor([[0., 0., 0.],
        [0., 0., 1.]])
plot(val=None, ax=None)[source]

Plot a single or multiple values from the metric.

Parameters:
  • val (Union[Tensor, Sequence[Tensor], None]) – Either a single result from calling metric.forward or metric.compute or a list of these results. If no value is provided, will automatically call metric.compute and plot that result.

  • ax (Optional[Axes]) – A matplotlib axis object. If provided, will add the plot to that axis.

Return type:

Tuple[Figure, Union[Axes, ndarray]]

Returns:

Figure object and Axes object

Raises:

ModuleNotFoundError – If matplotlib is not installed

>>> from torch import rand, randint
>>> # Example plotting a single value
>>> from torchmetrics.classification import MultilabelSpecificity
>>> metric = MultilabelSpecificity(num_labels=3)
>>> metric.update(randint(2, (20, 3)), randint(2, (20, 3)))
>>> fig_, ax_ = metric.plot()
_images/specificity-5.png
>>> from torch import rand, randint
>>> # Example plotting multiple values
>>> from torchmetrics.classification import MultilabelSpecificity
>>> metric = MultilabelSpecificity(num_labels=3)
>>> values = [ ]
>>> for _ in range(10):
...     values.append(metric(randint(2, (20, 3)), randint(2, (20, 3))))
>>> fig_, ax_ = metric.plot(values)
_images/specificity-6.png

Functional Interface

torchmetrics.functional.specificity(preds, target, task, threshold=0.5, num_classes=None, num_labels=None, average='micro', multidim_average='global', top_k=1, ignore_index=None, validate_args=True)[source]

Compute Specificity.

\[\text{Specificity} = \frac{\text{TN}}{\text{TN} + \text{FP}}\]

Where \(\text{TN}\) and \(\text{FP}\) represent the number of true negatives and false positives respectively.

This function is a simple wrapper to get the task specific versions of this metric, which is done by setting the task argument to either 'binary', 'multiclass' or 'multilabel'. See the documentation of binary_specificity(), multiclass_specificity() and multilabel_specificity() for the specific details of each argument's influence and examples.

Return type:

Tensor

Legacy Example:
>>> from torch import tensor
>>> preds  = tensor([2, 0, 2, 1])
>>> target = tensor([1, 1, 2, 0])
>>> specificity(preds, target, task="multiclass", average='macro', num_classes=3)
tensor(0.6111)
>>> specificity(preds, target, task="multiclass", average='micro', num_classes=3)
tensor(0.6250)
binary_specificity
torchmetrics.functional.classification.binary_specificity(preds, target, threshold=0.5, multidim_average='global', ignore_index=None, validate_args=True)[source]

Compute Specificity for binary tasks.

\[\text{Specificity} = \frac{\text{TN}}{\text{TN} + \text{FP}}\]

Where \(\text{TN}\) and \(\text{FP}\) represent the number of true negatives and false positives respectively.

Accepts the following input tensors:

  • preds (int or float tensor): (N, ...). If preds is a floating point tensor with values outside the [0,1] range, we consider the input to be logits and will auto apply sigmoid per element. Additionally, we convert to int tensor with thresholding using the value in threshold.

  • target (int tensor): (N, ...)

Parameters:
  • preds (Tensor) – Tensor with predictions

  • target (Tensor) – Tensor with true labels

  • threshold (float) – Threshold for transforming probability to binary {0,1} predictions

  • multidim_average (Literal['global', 'samplewise']) –

    Defines how additional dimensions ... should be handled. Should be one of the following:

    • global: Additional dimensions are flattened along the batch dimension

    • samplewise: Statistic will be calculated independently for each sample on the N axis. The statistics in this case are calculated over the additional dimensions.

  • ignore_index (Optional[int]) – Specifies a target value that is ignored and does not contribute to the metric calculation

  • validate_args (bool) – bool indicating if input arguments and tensors should be validated for correctness. Set to False for faster computations.

Return type:

Tensor

Returns:

If multidim_average is set to global, the metric returns a scalar value. If multidim_average is set to samplewise, the metric returns an (N,) vector consisting of a scalar value per sample.

Example (preds is int tensor):
>>> from torch import tensor
>>> from torchmetrics.functional.classification import binary_specificity
>>> target = tensor([0, 1, 0, 1, 0, 1])
>>> preds = tensor([0, 0, 1, 1, 0, 1])
>>> binary_specificity(preds, target)
tensor(0.6667)
Example (preds is float tensor):
>>> from torchmetrics.functional.classification import binary_specificity
>>> target = tensor([0, 1, 0, 1, 0, 1])
>>> preds = tensor([0.11, 0.22, 0.84, 0.73, 0.33, 0.92])
>>> binary_specificity(preds, target)
tensor(0.6667)
Example (multidim tensors):
>>> from torchmetrics.functional.classification import binary_specificity
>>> target = tensor([[[0, 1], [1, 0], [0, 1]], [[1, 1], [0, 0], [1, 0]]])
>>> preds = tensor([[[0.59, 0.91], [0.91, 0.99], [0.63, 0.04]],
...                 [[0.38, 0.04], [0.86, 0.780], [0.45, 0.37]]])
>>> binary_specificity(preds, target, multidim_average='samplewise')
tensor([0.0000, 0.3333])
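
For intuition, the score in the first example can be reproduced by hand from the definition above. This is an illustrative sketch only; binary_specificity additionally handles thresholding of float inputs, ignore_index and argument validation:

>>> from torch import tensor
>>> target = tensor([0, 1, 0, 1, 0, 1])
>>> preds = tensor([0, 0, 1, 1, 0, 1])
>>> # specificity = TN / (TN + FP), computed over the negative targets only
>>> tn = ((preds == 0) & (target == 0)).sum()
>>> fp = ((preds == 1) & (target == 0)).sum()
>>> tn / (tn + fp)
tensor(0.6667)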
multiclass_specificity
torchmetrics.functional.classification.multiclass_specificity(preds, target, num_classes, average='macro', top_k=1, multidim_average='global', ignore_index=None, validate_args=True)[source]

Compute Specificity for multiclass tasks.

\[\text{Specificity} = \frac{\text{TN}}{\text{TN} + \text{FP}}\]

Where \(\text{TN}\) and \(\text{FP}\) represent the number of true negatives and false positives respectively.

Accepts the following input tensors:

  • preds: (N, ...) (int tensor) or (N, C, ..) (float tensor). If preds is a floating point tensor, we apply torch.argmax along the C dimension to automatically convert probabilities/logits into an int tensor.

  • target (int tensor): (N, ...)

Parameters:
  • preds (Tensor) – Tensor with predictions

  • target (Tensor) – Tensor with true labels

  • num_classes (int) – Integer specifying the number of classes

  • average (Optional[Literal['micro', 'macro', 'weighted', 'none']]) –

    Defines the reduction that is applied over labels. Should be one of the following:

    • micro: Sum statistics over all labels

    • macro: Calculate statistics for each label and average them

    • weighted: calculates statistics for each label and computes weighted average using their support

    • "none" or None: calculates statistic for each label and applies no reduction

  • top_k (int) – Number of highest probability or logit score predictions considered to find the correct label. Only works when preds contain probabilities/logits.

  • multidim_average (Literal['global', 'samplewise']) –

    Defines how additional dimensions ... should be handled. Should be one of the following:

    • global: Additional dimensions are flattened along the batch dimension

    • samplewise: Statistic will be calculated independently for each sample on the N axis. The statistics in this case are calculated over the additional dimensions.

  • ignore_index (Optional[int]) – Specifies a target value that is ignored and does not contribute to the metric calculation

  • validate_args (bool) – bool indicating if input arguments and tensors should be validated for correctness. Set to False for faster computations.

Returns:

  • If multidim_average is set to global:

    • If average='micro'/'macro'/'weighted', the output will be a scalar tensor

    • If average=None/'none', the shape will be (C,)

  • If multidim_average is set to samplewise:

    • If average='micro'/'macro'/'weighted', the shape will be (N,)

    • If average=None/'none', the shape will be (N, C)

Return type:

The returned shape depends on the average and multidim_average arguments

Example (preds is int tensor):
>>> from torch import tensor
>>> from torchmetrics.functional.classification import multiclass_specificity
>>> target = tensor([2, 1, 0, 0])
>>> preds = tensor([2, 1, 0, 1])
>>> multiclass_specificity(preds, target, num_classes=3)
tensor(0.8889)
>>> multiclass_specificity(preds, target, num_classes=3, average=None)
tensor([1.0000, 0.6667, 1.0000])
Example (preds is float tensor):
>>> from torchmetrics.functional.classification import multiclass_specificity
>>> target = tensor([2, 1, 0, 0])
>>> preds = tensor([[0.16, 0.26, 0.58],
...                 [0.22, 0.61, 0.17],
...                 [0.71, 0.09, 0.20],
...                 [0.05, 0.82, 0.13]])
>>> multiclass_specificity(preds, target, num_classes=3)
tensor(0.8889)
>>> multiclass_specificity(preds, target, num_classes=3, average=None)
tensor([1.0000, 0.6667, 1.0000])
Example (multidim tensors):
>>> from torchmetrics.functional.classification import multiclass_specificity
>>> target = tensor([[[0, 1], [2, 1], [0, 2]], [[1, 1], [2, 0], [1, 2]]])
>>> preds = tensor([[[0, 2], [2, 0], [0, 1]], [[2, 2], [2, 1], [1, 0]]])
>>> multiclass_specificity(preds, target, num_classes=3, multidim_average='samplewise')
tensor([0.7500, 0.6556])
>>> multiclass_specificity(preds, target, num_classes=3, multidim_average='samplewise', average=None)
tensor([[0.7500, 0.7500, 0.7500],
        [0.8000, 0.6667, 0.5000]])
multilabel_specificity
torchmetrics.functional.classification.multilabel_specificity(preds, target, num_labels, threshold=0.5, average='macro', multidim_average='global', ignore_index=None, validate_args=True)[source]

Compute Specificity for multilabel tasks.

\[\text{Specificity} = \frac{\text{TN}}{\text{TN} + \text{FP}}\]

Where \(\text{TN}\) and \(\text{FP}\) represent the number of true negatives and false positives respectively.

Accepts the following input tensors:

  • preds (int or float tensor): (N, C, ...). If preds is a floating point tensor with values outside the [0,1] range, we consider the input to be logits and will auto apply sigmoid per element. Additionally, we convert to int tensor with thresholding using the value in threshold.

  • target (int tensor): (N, C, ...)

Parameters:
  • preds (Tensor) – Tensor with predictions

  • target (Tensor) – Tensor with true labels

  • num_labels (int) – Integer specifying the number of labels

  • threshold (float) – Threshold for transforming probability to binary {0,1} predictions

  • average (Optional[Literal['micro', 'macro', 'weighted', 'none']]) –

    Defines the reduction that is applied over labels. Should be one of the following:

    • micro: Sum statistics over all labels

    • macro: Calculate statistics for each label and average them

    • weighted: calculates statistics for each label and computes weighted average using their support

    • "none" or None: calculates statistic for each label and applies no reduction

  • multidim_average (Literal['global', 'samplewise']) –

    Defines how additional dimensions ... should be handled. Should be one of the following:

    • global: Additional dimensions are flattened along the batch dimension

    • samplewise: Statistic will be calculated independently for each sample on the N axis. The statistics in this case are calculated over the additional dimensions.

  • ignore_index (Optional[int]) – Specifies a target value that is ignored and does not contribute to the metric calculation

  • validate_args (bool) – bool indicating if input arguments and tensors should be validated for correctness. Set to False for faster computations.

Returns:

  • If multidim_average is set to global:

    • If average='micro'/'macro'/'weighted', the output will be a scalar tensor

    • If average=None/'none', the shape will be (C,)

  • If multidim_average is set to samplewise:

    • If average='micro'/'macro'/'weighted', the shape will be (N,)

    • If average=None/'none', the shape will be (N, C)

Return type:

The returned shape depends on the average and multidim_average arguments

Example (preds is int tensor):
>>> from torch import tensor
>>> from torchmetrics.functional.classification import multilabel_specificity
>>> target = tensor([[0, 1, 0], [1, 0, 1]])
>>> preds = tensor([[0, 0, 1], [1, 0, 1]])
>>> multilabel_specificity(preds, target, num_labels=3)
tensor(0.6667)
>>> multilabel_specificity(preds, target, num_labels=3, average=None)
tensor([1., 1., 0.])
Example (preds is float tensor):
>>> from torchmetrics.functional.classification import multilabel_specificity
>>> target = tensor([[0, 1, 0], [1, 0, 1]])
>>> preds = tensor([[0.11, 0.22, 0.84], [0.73, 0.33, 0.92]])
>>> multilabel_specificity(preds, target, num_labels=3)
tensor(0.6667)
>>> multilabel_specificity(preds, target, num_labels=3, average=None)
tensor([1., 1., 0.])
Example (multidim tensors):
>>> from torchmetrics.functional.classification import multilabel_specificity
>>> target = tensor([[[0, 1], [1, 0], [0, 1]], [[1, 1], [0, 0], [1, 0]]])
>>> preds = tensor([[[0.59, 0.91], [0.91, 0.99], [0.63, 0.04]],
...                 [[0.38, 0.04], [0.86, 0.780], [0.45, 0.37]]])
>>> multilabel_specificity(preds, target, num_labels=3, multidim_average='samplewise')
tensor([0.0000, 0.3333])
>>> multilabel_specificity(preds, target, num_labels=3, multidim_average='samplewise', average=None)
tensor([[0., 0., 0.],
        [0., 0., 1.]])

Specificity At Sensitivity

Module Interface

class torchmetrics.SpecificityAtSensitivity(**kwargs)[source]

Compute the highest possible specificity value given the minimum sensitivity threshold provided.

This is done by first calculating the Receiver Operating Characteristic (ROC) curve for different thresholds and then finding the specificity for a given sensitivity level.

This function is a simple wrapper to get the task specific versions of this metric, which is done by setting the task argument to either 'binary', 'multiclass' or 'multilabel'. See the documentation of BinarySpecificityAtSensitivity, MulticlassSpecificityAtSensitivity and MultilabelSpecificityAtSensitivity for the specific details of each argument's influence and examples.

static __new__(cls, task, min_sensitivity, thresholds=None, num_classes=None, num_labels=None, ignore_index=None, validate_args=True, **kwargs)[source]

Initialize task metric.

Return type:

Metric

BinarySpecificityAtSensitivity
class torchmetrics.classification.BinarySpecificityAtSensitivity(min_sensitivity, thresholds=None, ignore_index=None, validate_args=True, **kwargs)[source]

Compute the highest possible specificity value given the minimum sensitivity threshold provided.

This is done by first calculating the Receiver Operating Characteristic (ROC) curve for different thresholds and then finding the specificity for a given sensitivity level.

Accepts the following input tensors:

  • preds (float tensor): (N, ...). Preds should be a tensor containing probabilities or logits for each observation. If preds has values outside the [0,1] range, we consider the input to be logits and will auto apply sigmoid per element.

  • target (int tensor): (N, ...). Target should be a tensor containing ground truth labels, and therefore only contain {0,1} values (except if ignore_index is specified).

Additional dimension ... will be flattened into the batch dimension.

The implementation supports calculating the metric both in a non-binned but accurate version and in a binned version that is less accurate but more memory efficient. Setting the thresholds argument to None will activate the non-binned version that uses memory of size \(\mathcal{O}(n_{samples})\) whereas setting the thresholds argument to either an integer, list or a 1d tensor will use a binned version that uses memory of size \(\mathcal{O}(n_{thresholds})\) (constant memory).

Parameters:
  • min_sensitivity (float) – float value specifying minimum sensitivity threshold.

  • thresholds (Union[int, List[float], Tensor, None]) –

    Can be one of:

    • If set to None, will use a non-binned approach where thresholds are dynamically calculated from all the data. Most accurate but also most memory consuming approach.

    • If set to an int (larger than 1), will use that number of thresholds linearly spaced from 0 to 1 as bins for the calculation.

    • If set to a list of floats, will use the indicated thresholds in the list as bins for the calculation.

    • If set to a 1d tensor of floats, will use the indicated thresholds in the tensor as bins for the calculation.

  • validate_args (bool) – bool indicating if input arguments and tensors should be validated for correctness. Set to False for faster computations.

  • kwargs (Any) – Additional keyword arguments, see Advanced metric settings for more info.

Returns:

a tuple of 2 tensors containing:

  • specificity: a scalar tensor with the maximum specificity for the given sensitivity level

  • threshold: a scalar tensor with the corresponding threshold level

Return type:

(tuple)

Example

>>> from torchmetrics.classification import BinarySpecificityAtSensitivity
>>> from torch import tensor
>>> preds = tensor([0, 0.5, 0.4, 0.1])
>>> target = tensor([0, 1, 1, 1])
>>> metric = BinarySpecificityAtSensitivity(min_sensitivity=0.5, thresholds=None)
>>> metric(preds, target)
(tensor(1.), tensor(0.4000))
>>> metric = BinarySpecificityAtSensitivity(min_sensitivity=0.5, thresholds=5)
>>> metric(preds, target)
(tensor(1.), tensor(0.2500))
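
For intuition, the non-binned result above can be reproduced from the ROC curve itself: specificity equals 1 - fpr, and the metric reports the best specificity among curve points whose sensitivity (tpr) meets the minimum. This is an illustrative sketch, not the internal implementation:

>>> from torch import tensor
>>> from torchmetrics.functional.classification import binary_roc
>>> preds = tensor([0, 0.5, 0.4, 0.1])
>>> target = tensor([0, 1, 1, 1])
>>> fpr, tpr, thresholds = binary_roc(preds, target, thresholds=None)
>>> specificity = 1 - fpr
>>> # best specificity among points with sensitivity of at least 0.5
>>> specificity[tpr >= 0.5].max()
tensor(1.)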
MulticlassSpecificityAtSensitivity
class torchmetrics.classification.MulticlassSpecificityAtSensitivity(num_classes, min_sensitivity, thresholds=None, ignore_index=None, validate_args=True, **kwargs)[source]

Compute the highest possible specificity value given the minimum sensitivity threshold provided.

This is done by first calculating the Receiver Operating Characteristic (ROC) curve for different thresholds and then finding the specificity for a given sensitivity level.

For multiclass the metric is calculated by iteratively treating each class as the positive class and all other classes as the negative, which is referred to as the one-vs-rest approach. One-vs-one is currently not supported by this metric.

Accepts the following input tensors:

  • preds (float tensor): (N, C, ...). Preds should be a tensor containing probabilities or logits for each observation. If preds has values outside the [0,1] range, we consider the input to be logits and will auto apply softmax per sample.

  • target (int tensor): (N, ...). Target should be a tensor containing ground truth labels, and therefore only contain values in the [0, n_classes-1] range (except if ignore_index is specified).

Additional dimension ... will be flattened into the batch dimension.

The implementation supports calculating the metric both in a non-binned but accurate version and in a binned version that is less accurate but more memory efficient. Setting the thresholds argument to None will activate the non-binned version that uses memory of size \(\mathcal{O}(n_{samples})\) whereas setting the thresholds argument to either an integer, list or a 1d tensor will use a binned version that uses memory of size \(\mathcal{O}(n_{thresholds} \times n_{classes})\) (constant memory).

Parameters:
  • num_classes (int) – Integer specifying the number of classes

  • min_sensitivity (float) – float value specifying minimum sensitivity threshold.

  • thresholds (Union[int, List[float], Tensor, None]) –

    Can be one of:

    • If set to None, will use a non-binned approach where thresholds are dynamically calculated from all the data. Most accurate but also most memory consuming approach.

    • If set to an int (larger than 1), will use that number of thresholds linearly spaced from 0 to 1 as bins for the calculation.

    • If set to a list of floats, will use the indicated thresholds in the list as bins for the calculation.

    • If set to a 1d tensor of floats, will use the indicated thresholds in the tensor as bins for the calculation.

  • validate_args (bool) – bool indicating if input arguments and tensors should be validated for correctness. Set to False for faster computations.

  • kwargs (Any) – Additional keyword arguments, see Advanced metric settings for more info.

Returns:

a tuple of either 2 tensors or 2 lists containing

  • specificity: a 1d tensor of size (n_classes, ) with the maximum specificity for the given sensitivity level per class

  • thresholds: a 1d tensor of size (n_classes, ) with the corresponding threshold level per class

Return type:

(tuple)

Example

>>> from torchmetrics.classification import MulticlassSpecificityAtSensitivity
>>> from torch import tensor
>>> preds = tensor([[0.75, 0.05, 0.05, 0.05, 0.05],
...                 [0.05, 0.75, 0.05, 0.05, 0.05],
...                 [0.05, 0.05, 0.75, 0.05, 0.05],
...                 [0.05, 0.05, 0.05, 0.75, 0.05]])
>>> target = tensor([0, 1, 3, 2])
>>> metric = MulticlassSpecificityAtSensitivity(num_classes=5, min_sensitivity=0.5, thresholds=None)
>>> metric(preds, target)
(tensor([1., 1., 0., 0., 0.]), tensor([7.5000e-01, 7.5000e-01, 5.0000e-02, 5.0000e-02, 1.0000e+06]))
>>> metric = MulticlassSpecificityAtSensitivity(num_classes=5, min_sensitivity=0.5, thresholds=5)
>>> metric(preds, target)
(tensor([1., 1., 0., 0., 0.]), tensor([7.5000e-01, 7.5000e-01, 0.0000e+00, 0.0000e+00, 1.0000e+06]))
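
Since values outside the [0, 1] range are treated as logits and softmax is applied internally, raw model outputs can be passed directly; a small sketch with random logits (the resulting values are therefore not meaningful):

>>> import torch
>>> from torchmetrics.classification import MulticlassSpecificityAtSensitivity
>>> logits = torch.randn(8, 5)                 # unnormalized scores, softmax is applied internally
>>> target = torch.randint(5, (8,))
>>> metric = MulticlassSpecificityAtSensitivity(num_classes=5, min_sensitivity=0.5, thresholds=10)
>>> spec, thr = metric(logits, target)         # both outputs have shape (num_classes,)
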
MultilabelSpecificityAtSensitivity
class torchmetrics.classification.MultilabelSpecificityAtSensitivity(num_labels, min_sensitivity, thresholds=None, ignore_index=None, validate_args=True, **kwargs)[source]

Compute the highest possible specificity value given the minimum sensitivity thresholds provided.

This is done by first calculating the Receiver Operating Characteristic (ROC) curve for different thresholds and then finding the specificity for a given sensitivity level.

Accepts the following input tensors:

  • preds (float tensor): (N, C, ...). Preds should be a tensor containing probabilities or logits for each observation. If preds has values outside [0,1] range we consider the input to be logits and will auto apply sigmoid per element.

  • target (int tensor): (N, C, ...). Target should be a tensor containing ground truth labels, and therefore only contain {0,1} values (except if ignore_index is specified).

Additional dimension ... will be flattened into the batch dimension.

The implementation supports calculating the metric both in an accurate, non-binned version and in a binned version that is less accurate but more memory efficient. Setting the thresholds argument to None will activate the non-binned version, which uses memory of size \(\mathcal{O}(n_{samples})\), whereas setting the thresholds argument to either an integer, a list or a 1d tensor will use the binned version, which uses memory of size \(\mathcal{O}(n_{thresholds} \times n_{labels})\) (constant memory).

Parameters:
  • num_labels (int) – Integer specifying the number of labels

  • min_sensitivity (float) – float value specifying minimum sensitivity threshold.

  • thresholds (Union[int, List[float], Tensor, None]) –

    Can be one of:

    • If set to None, will use a non-binned approach where thresholds are dynamically calculated from all the data. Most accurate but also most memory consuming approach.

    • If set to an int (larger than 1), will use that number of thresholds linearly spaced from 0 to 1 as bins for the calculation.

    • If set to a list of floats, will use the indicated thresholds in the list as bins for the calculation.

    • If set to a 1d tensor of floats, will use the indicated thresholds in the tensor as bins for the calculation.

  • validate_args (bool) – bool indicating if input arguments and tensors should be validated for correctness. Set to False for faster computations.

  • kwargs (Any) – Additional keyword arguments, see Advanced metric settings for more info.

Returns:

a tuple of either 2 tensors or 2 lists containing

  • specificity: a 1d tensor of size (n_labels, ) with the maximum specificity for the given sensitivity level per label

  • thresholds: a 1d tensor of size (n_labels, ) with the corresponding threshold level per label

Return type:

(tuple)

Example

>>> from torchmetrics.classification import MultilabelSpecificityAtSensitivity
>>> from torch import tensor
>>> preds = tensor([[0.75, 0.05, 0.35],
...                 [0.45, 0.75, 0.05],
...                 [0.05, 0.55, 0.75],
...                 [0.05, 0.65, 0.05]])
>>> target = tensor([[1, 0, 1],
...                  [0, 0, 0],
...                  [0, 1, 1],
...                  [1, 1, 1]])
>>> metric = MultilabelSpecificityAtSensitivity(num_labels=3, min_sensitivity=0.5, thresholds=None)
>>> metric(preds, target)
(tensor([1.0000, 0.5000, 1.0000]), tensor([0.7500, 0.6500, 0.3500]))
>>> metric = MultilabelSpecificityAtSensitivity(num_labels=3, min_sensitivity=0.5, thresholds=5)
>>> metric(preds, target)
(tensor([1.0000, 0.5000, 1.0000]), tensor([0.7500, 0.5000, 0.2500]))
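
Like other module metrics, the internal state accumulates over batches via update and is reduced with compute; a hedged sketch of batched usage with random data:

>>> import torch
>>> from torchmetrics.classification import MultilabelSpecificityAtSensitivity
>>> metric = MultilabelSpecificityAtSensitivity(num_labels=3, min_sensitivity=0.5, thresholds=None)
>>> for _ in range(4):                          # simulate four batches
...     preds = torch.rand(8, 3)
...     target = torch.randint(2, (8, 3))
...     metric.update(preds, target)
>>> spec, thr = metric.compute()                # single result over all accumulated batches
>>> metric.reset()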

Functional Interface

binary_specificity_at_sensitivity
torchmetrics.functional.classification.binary_specificity_at_sensitivity(preds, target, min_sensitivity, thresholds=None, ignore_index=None, validate_args=True)[source]

Compute the highest possible specificity value given the minimum sensitivity level provided for binary tasks.

This is done by first calculating the Receiver Operating Characteristic (ROC) curve for different thresholds and then finding the specificity for a given sensitivity level.

Accepts the following input tensors:

  • preds (float tensor): (N, ...). Preds should be a tensor containing probabilities or logits for each observation. If preds has values outside [0,1] range we consider the input to be logits and will auto apply sigmoid per element.

  • target (int tensor): (N, ...). Target should be a tensor containing ground truth labels, and therefore only contain {0,1} values (except if ignore_index is specified).

Additional dimension ... will be flattened into the batch dimension.

The implementation supports calculating the metric both in an accurate, non-binned version and in a binned version that is less accurate but more memory efficient. Setting the thresholds argument to None will activate the non-binned version, which uses memory of size \(\mathcal{O}(n_{samples})\), whereas setting the thresholds argument to either an integer, a list or a 1d tensor will use the binned version, which uses memory of size \(\mathcal{O}(n_{thresholds})\) (constant memory).

Parameters:
  • preds (Tensor) – Tensor with predictions

  • target (Tensor) – Tensor with true labels

  • min_sensitivity (float) – float value specifying minimum sensitivity threshold.

  • thresholds (Union[int, List[float], Tensor, None]) –

    Can be one of:

    • If set to None, will use a non-binned approach where thresholds are dynamically calculated from all the data. Most accurate but also most memory consuming approach.

    • If set to an int (larger than 1), will use that number of thresholds linearly spaced from 0 to 1 as bins for the calculation.

    • If set to a list of floats, will use the indicated thresholds in the list as bins for the calculation.

    • If set to a 1d tensor of floats, will use the indicated thresholds in the tensor as bins for the calculation.

  • ignore_index (Optional[int]) – Specifies a target value that is ignored and does not contribute to the metric calculation

  • validate_args (bool) – bool indicating if input arguments and tensors should be validated for correctness. Set to False for faster computations.

Returns:

a tuple of 2 tensors containing:

  • specificity: a scalar tensor with the maximum specificity for the given sensitivity level

  • threshold: a scalar tensor with the corresponding threshold level

Return type:

(tuple)

Example

>>> import torch
>>> from torchmetrics.functional.classification import binary_specificity_at_sensitivity
>>> preds = torch.tensor([0, 0.5, 0.4, 0.1])
>>> target = torch.tensor([0, 1, 1, 1])
>>> binary_specificity_at_sensitivity(preds, target, min_sensitivity=0.5, thresholds=None)
(tensor(1.), tensor(0.4000))
>>> binary_specificity_at_sensitivity(preds, target, min_sensitivity=0.5, thresholds=5)
(tensor(1.), tensor(0.2500))
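
The ignore_index argument can be used to exclude sentinel target values from the computation; a brief sketch where -1 marks entries to skip (the sentinel value is only an assumption for illustration):

>>> import torch
>>> from torchmetrics.functional.classification import binary_specificity_at_sensitivity
>>> preds = torch.tensor([0.1, 0.6, 0.4, 0.8])
>>> target = torch.tensor([0, 1, -1, 1])        # the -1 entry is ignored
>>> spec, thr = binary_specificity_at_sensitivity(
...     preds, target, min_sensitivity=0.5, thresholds=None, ignore_index=-1)
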
multiclass_specificity_at_sensitivity
torchmetrics.functional.classification.multiclass_specificity_at_sensitivity(preds, target, num_classes, min_sensitivity, thresholds=None, ignore_index=None, validate_args=True)[source]

Compute the highest possible specificity value given the minimum sensitivity level provided for multiclass tasks.

This is done by first calculating the Receiver Operating Characteristic (ROC) curve for different thresholds and then finding the specificity for a given sensitivity level.

Accepts the following input tensors:

  • preds (float tensor): (N, C, ...). Preds should be a tensor containing probabilities or logits for each observation. If preds has values outside [0,1] range we consider the input to be logits and will auto apply softmax per sample.

  • target (int tensor): (N, ...). Target should be a tensor containing ground truth labels, and therefore only contain values in the [0, n_classes-1] range (except if ignore_index is specified).

Additional dimension ... will be flattened into the batch dimension.

The implementation supports calculating the metric both in an accurate, non-binned version and in a binned version that is less accurate but more memory efficient. Setting the thresholds argument to None will activate the non-binned version, which uses memory of size \(\mathcal{O}(n_{samples})\), whereas setting the thresholds argument to either an integer, a list or a 1d tensor will use the binned version, which uses memory of size \(\mathcal{O}(n_{thresholds} \times n_{classes})\) (constant memory).

Parameters:
  • preds (Tensor) – Tensor with predictions

  • target (Tensor) – Tensor with true labels

  • num_classes (int) – Integer specifying the number of classes

  • min_sensitivity (float) – float value specifying minimum sensitivity threshold.

  • thresholds (Union[int, List[float], Tensor, None]) –

    Can be one of:

    • If set to None, will use a non-binned approach where thresholds are dynamically calculated from all the data. Most accurate but also most memory consuming approach.

    • If set to an int (larger than 1), will use that number of thresholds linearly spaced from 0 to 1 as bins for the calculation.

    • If set to a list of floats, will use the indicated thresholds in the list as bins for the calculation.

    • If set to a 1d tensor of floats, will use the indicated thresholds in the tensor as bins for the calculation.

  • ignore_index (Optional[int]) – Specifies a target value that is ignored and does not contribute to the metric calculation

  • validate_args (bool) – bool indicating if input arguments and tensors should be validated for correctness. Set to False for faster computations.

Returns:

a tuple of either 2 tensors or 2 lists containing

  • specificity: a 1d tensor of size (n_classes, ) with the maximum specificity for the given sensitivity level per class

  • thresholds: a 1d tensor of size (n_classes, ) with the corresponding threshold level per class

Return type:

(tuple)

Example

>>> import torch
>>> from torchmetrics.functional.classification import multiclass_specificity_at_sensitivity
>>> preds = torch.tensor([[0.75, 0.05, 0.05, 0.05, 0.05],
...                       [0.05, 0.75, 0.05, 0.05, 0.05],
...                       [0.05, 0.05, 0.75, 0.05, 0.05],
...                       [0.05, 0.05, 0.05, 0.75, 0.05]])
>>> target = torch.tensor([0, 1, 3, 2])
>>> multiclass_specificity_at_sensitivity(preds, target, num_classes=5, min_sensitivity=0.5, thresholds=None)
(tensor([1., 1., 0., 0., 0.]), tensor([7.5000e-01, 7.5000e-01, 5.0000e-02, 5.0000e-02, 1.0000e+06]))
>>> multiclass_specificity_at_sensitivity(preds, target, num_classes=5, min_sensitivity=0.5, thresholds=5)
(tensor([1., 1., 0., 0., 0.]), tensor([7.5000e-01, 7.5000e-01, 0.0000e+00, 0.0000e+00, 1.0000e+06]))
multilabel_specificity_at_sensitivity
torchmetrics.functional.classification.multilabel_specificity_at_sensitivity(preds, target, num_labels, min_sensitivity, thresholds=None, ignore_index=None, validate_args=True)[source]

Compute the highest possible specificity value given the minimum sensitivity level provided for multilabel tasks.

This is done by first calculating the Receiver Operating Characteristic (ROC) curve for different thresholds and then finding the specificity for a given sensitivity level.

Accepts the following input tensors:

  • preds (float tensor): (N, C, ...). Preds should be a tensor containing probabilities or logits for each observation. If preds has values outside [0,1] range we consider the input to be logits and will auto apply sigmoid per element.

  • target (int tensor): (N, C, ...). Target should be a tensor containing ground truth labels, and therefore only contain {0,1} values (except if ignore_index is specified).

Additional dimension ... will be flattened into the batch dimension.

The implementation supports calculating the metric both in an accurate, non-binned version and in a binned version that is less accurate but more memory efficient. Setting the thresholds argument to None will activate the non-binned version, which uses memory of size \(\mathcal{O}(n_{samples})\), whereas setting the thresholds argument to either an integer, a list or a 1d tensor will use the binned version, which uses memory of size \(\mathcal{O}(n_{thresholds} \times n_{labels})\) (constant memory).

Parameters:
  • preds (Tensor) – Tensor with predictions

  • target (Tensor) – Tensor with true labels

  • num_labels (int) – Integer specifying the number of labels

  • min_sensitivity (float) – float value specifying minimum sensitivity threshold.

  • thresholds (Union[int, List[float], Tensor, None]) –

    Can be one of:

    • If set to None, will use a non-binned approach where thresholds are dynamically calculated from all the data. Most accurate but also most memory consuming approach.

    • If set to an int (larger than 1), will use that number of thresholds linearly spaced from 0 to 1 as bins for the calculation.

    • If set to a list of floats, will use the indicated thresholds in the list as bins for the calculation.

    • If set to a 1d tensor of floats, will use the indicated thresholds in the tensor as bins for the calculation.

  • ignore_index (Optional[int]) – Specifies a target value that is ignored and does not contribute to the metric calculation

  • validate_args (bool) – bool indicating if input arguments and tensors should be validated for correctness. Set to False for faster computations.

Returns:

a tuple of either 2 tensors or 2 lists containing

  • specificity: a 1d tensor of size (n_labels, ) with the maximum specificity for the given sensitivity level per label

  • thresholds: a 1d tensor of size (n_labels, ) with the corresponding threshold level per label

Return type:

(tuple)

Example

>>> import torch
>>> from torchmetrics.functional.classification import multilabel_specificity_at_sensitivity
>>> preds = torch.tensor([[0.75, 0.05, 0.35],
...                       [0.45, 0.75, 0.05],
...                       [0.05, 0.55, 0.75],
...                       [0.05, 0.65, 0.05]])
>>> target = torch.tensor([[1, 0, 1],
...                        [0, 0, 0],
...                        [0, 1, 1],
...                        [1, 1, 1]])
>>> multilabel_specificity_at_sensitivity(preds, target, num_labels=3, min_sensitivity=0.5, thresholds=None)
(tensor([1.0000, 0.5000, 1.0000]), tensor([0.7500, 0.6500, 0.3500]))
>>> multilabel_specificity_at_sensitivity(preds, target, num_labels=3, min_sensitivity=0.5, thresholds=5)
(tensor([1.0000, 0.5000, 1.0000]), tensor([0.7500, 0.5000, 0.2500]))

Stat Scores

Module Interface

class torchmetrics.StatScores(**kwargs)[source]

Compute the number of true positives, false positives, true negatives, false negatives and the support.

This class is a simple wrapper that selects the task-specific version of this metric, which is done by setting the task argument to either 'binary', 'multiclass' or 'multilabel'. See the documentation of BinaryStatScores, MulticlassStatScores and MultilabelStatScores for the specific details of how each argument influences the metric and for examples.

Legacy Example:
>>> from torch import tensor
>>> from torchmetrics import StatScores
>>> preds  = tensor([1, 0, 2, 1])
>>> target = tensor([1, 1, 2, 0])
>>> stat_scores = StatScores(task="multiclass", num_classes=3, average='micro')
>>> stat_scores(preds, target)
tensor([2, 2, 6, 2, 4])
>>> stat_scores = StatScores(task="multiclass", num_classes=3, average=None)
>>> stat_scores(preds, target)
tensor([[0, 1, 2, 1, 1],
        [1, 1, 1, 1, 2],
        [1, 0, 3, 0, 1]])
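
The same wrapper dispatches to the binary and multilabel variants when task is set accordingly; a short sketch for the binary case:

>>> from torch import tensor
>>> from torchmetrics import StatScores
>>> preds  = tensor([0, 1, 1, 0])
>>> target = tensor([0, 1, 0, 0])
>>> stat_scores = StatScores(task="binary")
>>> out = stat_scores(preds, target)            # tensor of [tp, fp, tn, fn, support]
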
static __new__(cls, task, threshold=0.5, num_classes=None, num_labels=None, average='micro', multidim_average='global', top_k=1, ignore_index=None, validate_args=True, **kwargs)[source]

Initialize task metric.

Return type:

Metric

BinaryStatScores
class torchmetrics.classification.BinaryStatScores(threshold=0.5, multidim_average='global', ignore_index=None, validate_args=True, **kwargs)[source]

Compute true positives, false positives, true negatives, false negatives and the support for binary tasks.

Related to Type I and Type II errors.

As input to forward and update the metric accepts the following input:

  • preds (Tensor): An int or float tensor of shape (N, ...). If preds is a floating point tensor with values outside [0,1] range we consider the input to be logits and will auto apply sigmoid per element. Additionally, we convert to int tensor with thresholding using the value in threshold.

  • target (Tensor): An int tensor of shape (N, ...)

As output to forward and compute the metric returns the following output:

  • bss (Tensor): A tensor of shape (..., 5), where the last dimension corresponds to [tp, fp, tn, fn, sup] (sup stands for support and equals tp + fn). The shape depends on the multidim_average parameter:

  • If multidim_average is set to global, the shape will be (5,)

  • If multidim_average is set to samplewise, the shape will be (N, 5)

Parameters:
  • threshold (float) – Threshold for transforming probability to binary {0,1} predictions

  • multidim_average (Literal['global', 'samplewise']) –

    Defines how additional dimensions ... should be handled. Should be one of the following:

    • global: Additional dimensions are flattened along the batch dimension

    • samplewise: Statistic will be calculated independently for each sample on the N axis. The statistics in this case are calculated over the additional dimensions.

  • ignore_index (Optional[int]) – Specifies a target value that is ignored and does not contribute to the metric calculation

  • validate_args (bool) – bool indicating if input arguments and tensors should be validated for correctness. Set to False for faster computations.

  • kwargs (Any) – Additional keyword arguments, see Advanced metric settings for more info.

Example (preds is int tensor):
>>> from torch import tensor
>>> from torchmetrics.classification import BinaryStatScores
>>> target = tensor([0, 1, 0, 1, 0, 1])
>>> preds = tensor([0, 0, 1, 1, 0, 1])
>>> metric = BinaryStatScores()
>>> metric(preds, target)
tensor([2, 1, 2, 1, 3])
Example (preds is float tensor):
>>> from torchmetrics.classification import BinaryStatScores
>>> target = tensor([0, 1, 0, 1, 0, 1])
>>> preds = tensor([0.11, 0.22, 0.84, 0.73, 0.33, 0.92])
>>> metric = BinaryStatScores()
>>> metric(preds, target)
tensor([2, 1, 2, 1, 3])
Example (multidim tensors):
>>> from torchmetrics.classification import BinaryStatScores
>>> target = tensor([[[0, 1], [1, 0], [0, 1]], [[1, 1], [0, 0], [1, 0]]])
>>> preds = tensor([[[0.59, 0.91], [0.91, 0.99], [0.63, 0.04]],
...                 [[0.38, 0.04], [0.86, 0.780], [0.45, 0.37]]])
>>> metric = BinaryStatScores(multidim_average='samplewise')
>>> metric(preds, target)
tensor([[2, 3, 0, 1, 3],
        [0, 2, 1, 3, 3]])
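
Because the output is ordered as [tp, fp, tn, fn, sup], common rates can be derived from it directly; a minimal sketch (no guard against division by zero):

>>> from torch import tensor
>>> from torchmetrics.classification import BinaryStatScores
>>> target = tensor([0, 1, 0, 1, 0, 1])
>>> preds = tensor([0, 0, 1, 1, 0, 1])
>>> tp, fp, tn, fn, sup = BinaryStatScores()(preds, target)
>>> precision = tp / (tp + fp)
>>> recall = tp / (tp + fn)                     # equivalently tp / sup
>>> specificity = tn / (tn + fp)
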
MulticlassStatScores
class torchmetrics.classification.MulticlassStatScores(num_classes, top_k=1, average='macro', multidim_average='global', ignore_index=None, validate_args=True, **kwargs)[source]

Compute true positives, false positives, true negatives, false negatives and the support for multiclass tasks.

Related to Type I and Type II errors.

As input to forward and update the metric accepts the following input:

  • preds (Tensor): An int tensor of shape (N, ...) or float tensor of shape (N, C, ...). If preds is a floating point tensor we apply torch.argmax along the C dimension to automatically convert probabilities/logits into an int tensor.

  • target (Tensor): An int tensor of shape (N, ...)

As output to forward and compute the metric returns the following output:

  • mcss (Tensor): A tensor of shape (..., 5), where the last dimension corresponds to [tp, fp, tn, fn, sup] (sup stands for support and equals tp + fn). The shape depends on average and multidim_average parameters:

  • If multidim_average is set to global:

    • If average='micro'/'macro'/'weighted', the shape will be (5,)

    • If average=None/'none', the shape will be (C, 5)

  • If multidim_average is set to samplewise:

    • If average='micro'/'macro'/'weighted', the shape will be (N, 5)

    • If average=None/'none', the shape will be (N, C, 5)

Parameters:
  • num_classes (int) – Integer specifying the number of classes

  • average (Optional[Literal['micro', 'macro', 'weighted', 'none']]) –

    Defines the reduction that is applied over labels. Should be one of the following:

    • micro: Sum statistics over all labels

    • macro: Calculate statistics for each label and average them

    • weighted: calculates statistics for each label and computes weighted average using their support

    • "none" or None: calculates statistic for each label and applies no reduction

  • top_k (int) – Number of highest probability or logit score predictions considered to find the correct label. Only works when preds contain probabilities/logits.

  • multidim_average (Literal['global', 'samplewise']) –

    Defines how additional dimensions ... should be handled. Should be one of the following:

    • global: Additional dimensions are flattened along the batch dimension

    • samplewise: Statistic will be calculated independently for each sample on the N axis. The statistics in this case are calculated over the additional dimensions.

  • ignore_index (Optional[int]) – Specifies a target value that is ignored and does not contribute to the metric calculation

  • validate_args (bool) – bool indicating if input arguments and tensors should be validated for correctness. Set to False for faster computations.

  • kwargs (Any) – Additional keyword arguments, see Advanced metric settings for more info.

Example (preds is int tensor):
>>> from torch import tensor
>>> from torchmetrics.classification import MulticlassStatScores
>>> target = tensor([2, 1, 0, 0])
>>> preds = tensor([2, 1, 0, 1])
>>> metric = MulticlassStatScores(num_classes=3, average='micro')
>>> metric(preds, target)
tensor([3, 1, 7, 1, 4])
>>> mcss = MulticlassStatScores(num_classes=3, average=None)
>>> mcss(preds, target)
tensor([[1, 0, 2, 1, 2],
        [1, 1, 2, 0, 1],
        [1, 0, 3, 0, 1]])
Example (preds is float tensor):
>>> from torchmetrics.classification import MulticlassStatScores
>>> target = tensor([2, 1, 0, 0])
>>> preds = tensor([[0.16, 0.26, 0.58],
...                 [0.22, 0.61, 0.17],
...                 [0.71, 0.09, 0.20],
...                 [0.05, 0.82, 0.13]])
>>> metric = MulticlassStatScores(num_classes=3, average='micro')
>>> metric(preds, target)
tensor([3, 1, 7, 1, 4])
>>> mcss = MulticlassStatScores(num_classes=3, average=None)
>>> mcss(preds, target)
tensor([[1, 0, 2, 1, 2],
        [1, 1, 2, 0, 1],
        [1, 0, 3, 0, 1]])
Example (multidim tensors):
>>> from torchmetrics.classification import MulticlassStatScores
>>> target = tensor([[[0, 1], [2, 1], [0, 2]], [[1, 1], [2, 0], [1, 2]]])
>>> preds = tensor([[[0, 2], [2, 0], [0, 1]], [[2, 2], [2, 1], [1, 0]]])
>>> metric = MulticlassStatScores(num_classes=3, multidim_average="samplewise", average='micro')
>>> metric(preds, target)
tensor([[3, 3, 9, 3, 6],
        [2, 4, 8, 4, 6]])
>>> mcss = MulticlassStatScores(num_classes=3, multidim_average="samplewise", average=None)
>>> mcss(preds, target)
tensor([[[2, 1, 3, 0, 2],
         [0, 1, 3, 2, 2],
         [1, 1, 3, 1, 2]],
        [[0, 1, 4, 1, 1],
         [1, 1, 2, 2, 3],
         [1, 2, 2, 1, 2]]])
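
The top_k argument (only meaningful when preds contains probabilities or logits) counts a prediction as correct if the target class is among the k highest scores; a hedged sketch:

>>> from torch import tensor
>>> from torchmetrics.classification import MulticlassStatScores
>>> target = tensor([2, 1, 0, 0])
>>> preds = tensor([[0.16, 0.26, 0.58],
...                 [0.22, 0.61, 0.17],
...                 [0.71, 0.09, 0.20],
...                 [0.05, 0.82, 0.13]])
>>> metric = MulticlassStatScores(num_classes=3, top_k=2, average='micro')
>>> out = metric(preds, target)                 # generally more true positives than with top_k=1
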
MultilabelStatScores
class torchmetrics.classification.MultilabelStatScores(num_labels, threshold=0.5, average='macro', multidim_average='global', ignore_index=None, validate_args=True, **kwargs)[source]

Compute true positives, false positives, true negatives, false negatives and the support for multilabel tasks.

Related to Type I and Type II errors.

As input to forward and update the metric accepts the following input:

  • preds (Tensor): An int or float tensor of shape (N, C, ...). If preds is a floating point tensor with values outside [0,1] range we consider the input to be logits and will auto apply sigmoid per element. Additionally, we convert to int tensor with thresholding using the value in threshold.

  • target (Tensor): An int tensor of shape (N, C, ...)

As output to forward and compute the metric returns the following output:

  • mlss (Tensor): A tensor of shape (..., 5), where the last dimension corresponds to [tp, fp, tn, fn, sup] (sup stands for support and equals tp + fn). The shape depends on average and multidim_average parameters:

  • If multidim_average is set to global:

    • If average='micro'/'macro'/'weighted', the shape will be (5,)

    • If average=None/'none', the shape will be (C, 5)

  • If multidim_average is set to samplewise:

    • If average='micro'/'macro'/'weighted', the shape will be (N, 5)

    • If average=None/'none', the shape will be (N, C, 5)

Parameters:
  • num_labels (int) – Integer specifying the number of labels

  • threshold (float) – Threshold for transforming probability to binary {0,1} predictions

  • average (Optional[Literal['micro', 'macro', 'weighted', 'none']]) –

    Defines the reduction that is applied over labels. Should be one of the following:

    • micro: Sum statistics over all labels

    • macro: Calculate statistics for each label and average them

    • weighted: calculates statistics for each label and computes weighted average using their support

    • "none" or None: calculates statistic for each label and applies no reduction

  • multidim_average (Literal['global', 'samplewise']) –

    Defines how additional dimensions ... should be handled. Should be one of the following:

    • global: Additional dimensions are flattened along the batch dimension

    • samplewise: Statistic will be calculated independently for each sample on the N axis. The statistics in this case are calculated over the additional dimensions.

  • ignore_index (Optional[int]) – Specifies a target value that is ignored and does not contribute to the metric calculation

  • validate_args (bool) – bool indicating if input arguments and tensors should be validated for correctness. Set to False for faster computations.

  • kwargs (Any) – Additional keyword arguments, see Advanced metric settings for more info.

Example (preds is int tensor):
>>> from torch import tensor
>>> from torchmetrics.classification import MultilabelStatScores
>>> target = tensor([[0, 1, 0], [1, 0, 1]])
>>> preds = tensor([[0, 0, 1], [1, 0, 1]])
>>> metric = MultilabelStatScores(num_labels=3, average='micro')
>>> metric(preds, target)
tensor([2, 1, 2, 1, 3])
>>> mlss = MultilabelStatScores(num_labels=3, average=None)
>>> mlss(preds, target)
tensor([[1, 0, 1, 0, 1],
        [0, 0, 1, 1, 1],
        [1, 1, 0, 0, 1]])
Example (preds is float tensor):
>>> from torchmetrics.classification import MultilabelStatScores
>>> target = tensor([[0, 1, 0], [1, 0, 1]])
>>> preds = tensor([[0.11, 0.22, 0.84], [0.73, 0.33, 0.92]])
>>> metric = MultilabelStatScores(num_labels=3, average='micro')
>>> metric(preds, target)
tensor([2, 1, 2, 1, 3])
>>> mlss = MultilabelStatScores(num_labels=3, average=None)
>>> mlss(preds, target)
tensor([[1, 0, 1, 0, 1],
        [0, 0, 1, 1, 1],
        [1, 1, 0, 0, 1]])
Example (multidim tensors):
>>> from torchmetrics.classification import MultilabelStatScores
>>> target = tensor([[[0, 1], [1, 0], [0, 1]], [[1, 1], [0, 0], [1, 0]]])
>>> preds = tensor([[[0.59, 0.91], [0.91, 0.99], [0.63, 0.04]],
...                 [[0.38, 0.04], [0.86, 0.780], [0.45, 0.37]]])
>>> metric = MultilabelStatScores(num_labels=3, multidim_average='samplewise', average='micro')
>>> metric(preds, target)
tensor([[2, 3, 0, 1, 3],
        [0, 2, 1, 3, 3]])
>>> mlss = MultilabelStatScores(num_labels=3, multidim_average='samplewise', average=None)
>>> mlss(preds, target)
tensor([[[1, 1, 0, 0, 1],
         [1, 1, 0, 0, 1],
         [0, 1, 0, 1, 1]],
        [[0, 0, 0, 2, 2],
         [0, 2, 0, 0, 0],
         [0, 0, 1, 1, 1]]])
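
The threshold argument controls how float predictions are binarized before the statistics are counted; a small sketch with a stricter cut-off (the value 0.8 is arbitrary):

>>> from torch import tensor
>>> from torchmetrics.classification import MultilabelStatScores
>>> target = tensor([[0, 1, 0], [1, 0, 1]])
>>> preds = tensor([[0.11, 0.22, 0.84], [0.73, 0.33, 0.92]])
>>> metric = MultilabelStatScores(num_labels=3, threshold=0.8, average='micro')
>>> out = metric(preds, target)                 # only probabilities above 0.8 count as positive predictions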

Functional Interface

stat_scores
torchmetrics.functional.stat_scores(preds, target, task, threshold=0.5, num_classes=None, num_labels=None, average='micro', multidim_average='global', top_k=1, ignore_index=None, validate_args=True)[source]

Compute the number of true positives, false positives, true negatives, false negatives and the support.

This function is a simple wrapper that selects the task-specific version of this metric, which is done by setting the task argument to either 'binary', 'multiclass' or 'multilabel'. See the documentation of binary_stat_scores(), multiclass_stat_scores() and multilabel_stat_scores() for the specific details of how each argument influences the metric and for examples.

Return type:

Tensor

Legacy Example:
>>> from torch import tensor
>>> from torchmetrics.functional import stat_scores
>>> preds  = tensor([1, 0, 2, 1])
>>> target = tensor([1, 1, 2, 0])
>>> stat_scores(preds, target, task='multiclass', num_classes=3, average='micro')
tensor([2, 2, 6, 2, 4])
>>> stat_scores(preds, target, task='multiclass', num_classes=3, average=None)
tensor([[0, 1, 2, 1, 1],
        [1, 1, 1, 1, 2],
        [1, 0, 3, 0, 1]])
binary_stat_scores
torchmetrics.functional.classification.binary_stat_scores(preds, target, threshold=0.5, multidim_average='global', ignore_index=None, validate_args=True)[source]

Compute the true positives, false positives, true negatives, false negatives and the support for binary tasks.

Related to Type I and Type II errors.

Accepts the following input tensors:

  • preds (int or float tensor): (N, ...). If preds is a floating point tensor with values outside [0,1] range we consider the input to be logits and will auto apply sigmoid per element. Additionally, we convert to int tensor with thresholding using the value in threshold.

  • target (int tensor): (N, ...)

Parameters:
  • preds (Tensor) – Tensor with predictions

  • target (Tensor) – Tensor with true labels

  • threshold (float) – Threshold for transforming probability to binary {0,1} predictions

  • multidim_average (Literal['global', 'samplewise']) –

    Defines how additional dimensions ... should be handled. Should be one of the following:

    • global: Additional dimensions are flattened along the batch dimension

    • samplewise: Statistic will be calculated independently for each sample on the N axis. The statistics in this case are calculated over the additional dimensions.

  • ignore_index (Optional[int]) – Specifies a target value that is ignored and does not contribute to the metric calculation

  • validate_args (bool) – bool indicating if input arguments and tensors should be validated for correctness. Set to False for faster computations.

Return type:

Tensor

Returns:

The metric returns a tensor of shape (..., 5), where the last dimension corresponds to [tp, fp, tn, fn, sup] (sup stands for support and equals tp + fn). The shape depends on the multidim_average parameter:

  • If multidim_average is set to global, the shape will be (5,)

  • If multidim_average is set to samplewise, the shape will be (N, 5)

Example (preds is int tensor):
>>> from torch import tensor
>>> from torchmetrics.functional.classification import binary_stat_scores
>>> target = tensor([0, 1, 0, 1, 0, 1])
>>> preds = tensor([0, 0, 1, 1, 0, 1])
>>> binary_stat_scores(preds, target)
tensor([2, 1, 2, 1, 3])
Example (preds is float tensor):
>>> from torchmetrics.functional.classification import binary_stat_scores
>>> target = tensor([0, 1, 0, 1, 0, 1])
>>> preds = tensor([0.11, 0.22, 0.84, 0.73, 0.33, 0.92])
>>> binary_stat_scores(preds, target)
tensor([2, 1, 2, 1, 3])
Example (multidim tensors):
>>> from torchmetrics.functional.classification import binary_stat_scores
>>> target = tensor([[[0, 1], [1, 0], [0, 1]], [[1, 1], [0, 0], [1, 0]]])
>>> preds = tensor([[[0.59, 0.91], [0.91, 0.99], [0.63, 0.04]],
...                 [[0.38, 0.04], [0.86, 0.780], [0.45, 0.37]]])
>>> binary_stat_scores(preds, target, multidim_average='samplewise')
tensor([[2, 3, 0, 1, 3],
        [0, 2, 1, 3, 3]])
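
The ignore_index argument works the same way in the functional interface; a brief sketch where -1 marks missing labels (the sentinel value is only an assumption):

>>> from torch import tensor
>>> from torchmetrics.functional.classification import binary_stat_scores
>>> target = tensor([0, 1, -1, 1, 0, -1])       # positions with -1 are excluded
>>> preds = tensor([0, 1, 1, 1, 0, 0])
>>> out = binary_stat_scores(preds, target, ignore_index=-1)
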
multiclass_stat_scores
torchmetrics.functional.classification.multiclass_stat_scores(preds, target, num_classes, average='macro', top_k=1, multidim_average='global', ignore_index=None, validate_args=True)[source]

Compute the true positives, false positives, true negatives, false negatives and support for multiclass tasks.

Related to Type I and Type II errors.

Accepts the following input tensors:

  • preds: (N, ...) (int tensor) or (N, C, ...) (float tensor). If preds is a floating point tensor we apply torch.argmax along the C dimension to automatically convert probabilities/logits into an int tensor.

  • target (int tensor): (N, ...)

Parameters:
  • preds (Tensor) – Tensor with predictions

  • target (Tensor) – Tensor with true labels

  • num_classes (int) – Integer specifying the number of classes

  • average (Optional[Literal['micro', 'macro', 'weighted', 'none']]) –

    Defines the reduction that is applied over labels. Should be one of the following:

    • micro: Sum statistics over all labels

    • macro: Calculate statistics for each label and average them

    • weighted: calculates statistics for each label and computes weighted average using their support

    • "none" or None: calculates statistic for each label and applies no reduction

  • top_k (int) – Number of highest probability or logit score predictions considered to find the correct label. Only works when preds contain probabilities/logits.

  • multidim_average (Literal['global', 'samplewise']) –

    Defines how additional dimensions ... should be handled. Should be one of the following:

    • global: Additional dimensions are flattened along the batch dimension

    • samplewise: Statistic will be calculated independently for each sample on the N axis. The statistics in this case are calculated over the additional dimensions.

  • ignore_index (Optional[int]) – Specifies a target value that is ignored and does not contribute to the metric calculation

  • validate_args (bool) – bool indicating if input arguments and tensors should be validated for correctness. Set to False for faster computations.

Return type:

Tensor

Returns:

The metric returns a tensor of shape (..., 5), where the last dimension corresponds to [tp, fp, tn, fn, sup] (sup stands for support and equals tp + fn). The shape depends on average and multidim_average parameters:

  • If multidim_average is set to global:

    • If average='micro'/'macro'/'weighted', the shape will be (5,)

    • If average=None/'none', the shape will be (C, 5)

  • If multidim_average is set to samplewise:

    • If average='micro'/'macro'/'weighted', the shape will be (N, 5)

    • If average=None/'none', the shape will be (N, C, 5)

Example (preds is int tensor):
>>> from torch import tensor
>>> from torchmetrics.functional.classification import multiclass_stat_scores
>>> target = tensor([2, 1, 0, 0])
>>> preds = tensor([2, 1, 0, 1])
>>> multiclass_stat_scores(preds, target, num_classes=3, average='micro')
tensor([3, 1, 7, 1, 4])
>>> multiclass_stat_scores(preds, target, num_classes=3, average=None)
tensor([[1, 0, 2, 1, 2],
        [1, 1, 2, 0, 1],
        [1, 0, 3, 0, 1]])
Example (preds is float tensor):
>>> from torchmetrics.functional.classification import multiclass_stat_scores
>>> target = tensor([2, 1, 0, 0])
>>> preds = tensor([[0.16, 0.26, 0.58],
...                 [0.22, 0.61, 0.17],
...                 [0.71, 0.09, 0.20],
...                 [0.05, 0.82, 0.13]])
>>> multiclass_stat_scores(preds, target, num_classes=3, average='micro')
tensor([3, 1, 7, 1, 4])
>>> multiclass_stat_scores(preds, target, num_classes=3, average=None)
tensor([[1, 0, 2, 1, 2],
        [1, 1, 2, 0, 1],
        [1, 0, 3, 0, 1]])
Example (multidim tensors):
>>> from torchmetrics.functional.classification import multiclass_stat_scores
>>> target = tensor([[[0, 1], [2, 1], [0, 2]], [[1, 1], [2, 0], [1, 2]]])
>>> preds = tensor([[[0, 2], [2, 0], [0, 1]], [[2, 2], [2, 1], [1, 0]]])
>>> multiclass_stat_scores(preds, target, num_classes=3, multidim_average='samplewise', average='micro')
tensor([[3, 3, 9, 3, 6],
        [2, 4, 8, 4, 6]])
>>> multiclass_stat_scores(preds, target, num_classes=3, multidim_average='samplewise', average=None)
tensor([[[2, 1, 3, 0, 2],
         [0, 1, 3, 2, 2],
         [1, 1, 3, 1, 2]],
        [[0, 1, 4, 1, 1],
         [1, 1, 2, 2, 3],
         [1, 2, 2, 1, 2]]])
multilabel_stat_scores
torchmetrics.functional.classification.multilabel_stat_scores(preds, target, num_labels, threshold=0.5, average='macro', multidim_average='global', ignore_index=None, validate_args=True)[source]

Compute the true positives, false positives, true negatives, false negatives and support for multilabel tasks.

Related to Type I and Type II errors.

Accepts the following input tensors:

  • preds (int or float tensor): (N, C, ...). If preds is a floating point tensor with values outside [0,1] range we consider the input to be logits and will auto apply sigmoid per element. Additionally, we convert to int tensor with thresholding using the value in threshold.

  • target (int tensor): (N, C, ...)

Parameters:
  • preds (Tensor) – Tensor with predictions

  • target (Tensor) – Tensor with true labels

  • num_labels (int) – Integer specifying the number of labels

  • threshold (float) – Threshold for transforming probability to binary {0,1} predictions

  • average (Optional[Literal['micro', 'macro', 'weighted', 'none']]) –

    Defines the reduction that is applied over labels. Should be one of the following:

    • micro: Sum statistics over all labels

    • macro: Calculate statistics for each label and average them

    • weighted: calculates statistics for each label and computes weighted average using their support

    • "none" or None: calculates statistic for each label and applies no reduction

  • multidim_average (Literal['global', 'samplewise']) –

    Defines how additional dimensions ... should be handled. Should be one of the following:

    • global: Additional dimensions are flattened along the batch dimension

    • samplewise: Statistic will be calculated independently for each sample on the N axis. The statistics in this case are calculated over the additional dimensions.

  • ignore_index (Optional[int]) – Specifies a target value that is ignored and does not contribute to the metric calculation

  • validate_args (bool) – bool indicating if input arguments and tensors should be validated for correctness. Set to False for faster computations.

Return type:

Tensor

Returns:

The metric returns a tensor of shape (..., 5), where the last dimension corresponds to [tp, fp, tn, fn, sup] (sup stands for support and equals tp + fn). The shape depends on average and multidim_average parameters:

  • If multidim_average is set to global:

    • If average='micro'/'macro'/'weighted', the shape will be (5,)

    • If average=None/'none', the shape will be (C, 5)

  • If multidim_average is set to samplewise:

    • If average='micro'/'macro'/'weighted', the shape will be (N, 5)

    • If average=None/'none', the shape will be (N, C, 5)

Example (preds is int tensor):
>>> from torch import tensor
>>> from torchmetrics.functional.classification import multilabel_stat_scores
>>> target = tensor([[0, 1, 0], [1, 0, 1]])
>>> preds = tensor([[0, 0, 1], [1, 0, 1]])
>>> multilabel_stat_scores(preds, target, num_labels=3, average='micro')
tensor([2, 1, 2, 1, 3])
>>> multilabel_stat_scores(preds, target, num_labels=3, average=None)
tensor([[1, 0, 1, 0, 1],
        [0, 0, 1, 1, 1],
        [1, 1, 0, 0, 1]])
Example (preds is float tensor):
>>> from torchmetrics.functional.classification import multilabel_stat_scores
>>> target = tensor([[0, 1, 0], [1, 0, 1]])
>>> preds = tensor([[0.11, 0.22, 0.84], [0.73, 0.33, 0.92]])
>>> multilabel_stat_scores(preds, target, num_labels=3, average='micro')
tensor([2, 1, 2, 1, 3])
>>> multilabel_stat_scores(preds, target, num_labels=3, average=None)
tensor([[1, 0, 1, 0, 1],
        [0, 0, 1, 1, 1],
        [1, 1, 0, 0, 1]])
Example (multidim tensors):
>>> from torchmetrics.functional.classification import multilabel_stat_scores
>>> target = tensor([[[0, 1], [1, 0], [0, 1]], [[1, 1], [0, 0], [1, 0]]])
>>> preds = tensor([[[0.59, 0.91], [0.91, 0.99], [0.63, 0.04]],
...                 [[0.38, 0.04], [0.86, 0.780], [0.45, 0.37]]])
>>> multilabel_stat_scores(preds, target, num_labels=3, multidim_average='samplewise', average='micro')
tensor([[2, 3, 0, 1, 3],
        [0, 2, 1, 3, 3]])
>>> multilabel_stat_scores(preds, target, num_labels=3, multidim_average='samplewise', average=None)
tensor([[[1, 1, 0, 0, 1],
         [1, 1, 0, 0, 1],
         [0, 1, 0, 1, 1]],
        [[0, 0, 0, 2, 2],
         [0, 2, 0, 0, 0],
         [0, 0, 1, 1, 1]]])

Complete Intersection Over Union (cIoU)

Module Interface

class torchmetrics.detection.ciou.CompleteIntersectionOverUnion(box_format='xyxy', iou_threshold=None, class_metrics=False, respect_labels=True, **kwargs)[source]

Computes Complete Intersection Over Union (CIoU).

As input to forward and update the metric accepts the following input:

  • preds (List): A list consisting of dictionaries each containing the key-values (each dictionary corresponds to a single image). Parameters that should be provided per dict:

    • boxes (Tensor): float tensor of shape (num_boxes, 4) containing num_boxes detection boxes of the format specified in the constructor. By default, this method expects (xmin, ymin, xmax, ymax) in absolute image coordinates.

    • labels (Tensor): integer tensor of shape (num_boxes) containing 0-indexed detection classes for the boxes.

  • target (List): A list consisting of dictionaries each containing the key-values (each dictionary corresponds to a single image). Parameters that should be provided per dict:

    • boxes (Tensor): float tensor of shape (num_boxes, 4) containing num_boxes ground truth boxes of the format specified in the constructor. By default, this method expects (xmin, ymin, xmax, ymax) in absolute image coordinates.

    • labels (Tensor): integer tensor of shape (num_boxes) containing 0-indexed ground truth classes for the boxes.

As output of forward and compute the metric returns the following output:

  • ciou_dict: A dictionary containing the following key-values:

    • ciou: (Tensor) with overall ciou value over all classes and samples.

    • ciou/cl_{cl}: (Tensor), if argument class_metrics=True

Parameters:
  • box_format (str) – Input format of given boxes. Supported formats are ['xyxy', 'xywh', 'cxcywh'].

  • iou_threshold – Optional IoU threshold for evaluation. If set to None the threshold is ignored.

  • class_metrics (bool) – Option to enable per-class metrics for IoU. Has a performance impact.

  • respect_labels (bool) – Ignore values from boxes that do not have the same label as the ground truth box. Otherwise, IoU is computed between all pairs of boxes.

  • kwargs (Any) – Additional keyword arguments, see Advanced metric settings for more info.

Example

>>> import torch
>>> from torchmetrics.detection import CompleteIntersectionOverUnion
>>> preds = [
...    {
...        "boxes": torch.tensor([[296.55, 93.96, 314.97, 152.79], [298.55, 98.96, 314.97, 151.79]]),
...        "scores": torch.tensor([0.236, 0.56]),
...        "labels": torch.tensor([4, 5]),
...    }
... ]
>>> target = [
...    {
...        "boxes": torch.tensor([[300.00, 100.00, 315.00, 150.00]]),
...        "labels": torch.tensor([5]),
...    }
... ]
>>> metric = CompleteIntersectionOverUnion()
>>> metric(preds, target)
{'ciou': tensor(0.8611)}
Raises:

ModuleNotFoundError – If torchvision is not installed with version 0.13.0 or newer.

plot(val=None, ax=None)[source]

Plot a single or multiple values from the metric.

Parameters:
  • val (Union[Tensor, Sequence[Tensor], None]) – Either a single result from calling metric.forward or metric.compute or a list of these results. If no value is provided, will automatically call metric.compute and plot that result.

  • ax (Optional[Axes]) – A matplotlib axis object. If provided, the plot will be added to that axis

Return type:

Tuple[Figure, Union[Axes, ndarray]]

Returns:

Figure object and Axes object

Raises:

ModuleNotFoundError – If matplotlib is not installed

>>> # Example plotting single value
>>> import torch
>>> from torchmetrics.detection import CompleteIntersectionOverUnion
>>> preds = [
...    {
...        "boxes": torch.tensor([[296.55, 93.96, 314.97, 152.79], [298.55, 98.96, 314.97, 151.79]]),
...        "scores": torch.tensor([0.236, 0.56]),
...        "labels": torch.tensor([4, 5]),
...    }
... ]
>>> target = [
...    {
...        "boxes": torch.tensor([[300.00, 100.00, 315.00, 150.00]]),
...        "labels": torch.tensor([5]),
...    }
... ]
>>> metric = CompleteIntersectionOverUnion()
>>> metric.update(preds, target)
>>> fig_, ax_ = metric.plot()
_images/complete_intersection_over_union-1.png
>>> # Example plotting multiple values
>>> import torch
>>> from torchmetrics.detection import CompleteIntersectionOverUnion
>>> preds = [
...    {
...        "boxes": torch.tensor([[296.55, 93.96, 314.97, 152.79], [298.55, 98.96, 314.97, 151.79]]),
...        "scores": torch.tensor([0.236, 0.56]),
...        "labels": torch.tensor([4, 5]),
...    }
... ]
>>> target = lambda : [
...    {
...        "boxes": torch.tensor([[300.00, 100.00, 315.00, 150.00]]) + torch.randint(-10, 10, (1, 4)),
...        "labels": torch.tensor([5]),
...    }
... ]
>>> metric = CompleteIntersectionOverUnion()
>>> vals = []
>>> for _ in range(20):
...     vals.append(metric(preds, target()))
>>> fig_, ax_ = metric.plot(vals)
_images/complete_intersection_over_union-2.png
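
Setting class_metrics=True additionally reports one entry per detected class label in the returned dictionary; a small sketch (the label id 5 is arbitrary):

>>> import torch
>>> from torchmetrics.detection import CompleteIntersectionOverUnion
>>> preds = [{"boxes": torch.tensor([[296.55, 93.96, 314.97, 152.79]]),
...           "scores": torch.tensor([0.56]),
...           "labels": torch.tensor([5])}]
>>> target = [{"boxes": torch.tensor([[300.00, 100.00, 315.00, 150.00]]),
...            "labels": torch.tensor([5])}]
>>> metric = CompleteIntersectionOverUnion(class_metrics=True)
>>> result = metric(preds, target)              # contains 'ciou' plus a 'ciou/cl_5' entry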

Functional Interface

torchmetrics.functional.detection.ciou.complete_intersection_over_union(preds, target, iou_threshold=None, replacement_val=0, aggregate=True)[source]

Compute Complete Intersection over Union (CIOU) between two sets of boxes.

Both sets of boxes are expected to be in (x1, y1, x2, y2) format with 0 <= x1 < x2 and 0 <= y1 < y2.

Parameters:
  • preds (Tensor) – The input tensor containing the predicted bounding boxes.

  • target (Tensor) – The tensor containing the ground truth.

  • iou_threshold (Optional[float]) – Optional IoU thresholds for evaluation. If set to None the threshold is ignored.

  • replacement_val (float) – Value to replace values under the threshold with.

  • aggregate (bool) – Return the average value instead of the full matrix of values

Return type:

Tensor

Example:

By default, the CIoU score is aggregated across all box pairs, e.g. the mean along the diagonal of the IoU matrix:

>>> import torch
>>> from torchmetrics.functional.detection import complete_intersection_over_union
>>> preds = torch.tensor(
...     [
...         [296.55, 93.96, 314.97, 152.79],
...         [328.94, 97.05, 342.49, 122.98],
...         [356.62, 95.47, 372.33, 147.55],
...     ]
... )
>>> target = torch.tensor(
...     [
...         [300.00, 100.00, 315.00, 150.00],
...         [330.00, 100.00, 350.00, 125.00],
...         [350.00, 100.00, 375.00, 150.00],
...     ]
... )
>>> complete_intersection_over_union(preds, target)
tensor(0.5790)
Example:

By setting aggregate=False, the CIoU score for each pair of prediction and target boxes is returned:

>>> import torch
>>> from torchmetrics.functional.detection import complete_intersection_over_union
>>> preds = torch.tensor(
...     [
...         [296.55, 93.96, 314.97, 152.79],
...         [328.94, 97.05, 342.49, 122.98],
...         [356.62, 95.47, 372.33, 147.55],
...     ]
... )
>>> target = torch.tensor(
...     [
...         [300.00, 100.00, 315.00, 150.00],
...         [330.00, 100.00, 350.00, 125.00],
...         [350.00, 100.00, 375.00, 150.00],
...     ]
... )
>>> complete_intersection_over_union(preds, target, aggregate=False)
tensor([[ 0.6883, -0.2072, -0.3352],
        [-0.2217,  0.4881, -0.1913],
        [-0.3971, -0.1543,  0.5606]])
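
The iou_threshold and replacement_val arguments can be combined to mask out low-quality matches in the un-aggregated matrix; a hedged sketch reusing two of the boxes above (the 0.5 threshold is chosen only for illustration):

>>> import torch
>>> from torchmetrics.functional.detection import complete_intersection_over_union
>>> preds = torch.tensor([[296.55, 93.96, 314.97, 152.79],
...                       [328.94, 97.05, 342.49, 122.98]])
>>> target = torch.tensor([[300.00, 100.00, 315.00, 150.00],
...                        [330.00, 100.00, 350.00, 125.00]])
>>> # entries with CIoU below 0.5 are replaced by 0.0
>>> out = complete_intersection_over_union(
...     preds, target, iou_threshold=0.5, replacement_val=0.0, aggregate=False)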

Distance Intersection Over Union (dIoU)

Module Interface

class torchmetrics.detection.diou.DistanceIntersectionOverUnion(box_format='xyxy', iou_threshold=None, class_metrics=False, respect_labels=True, **kwargs)[source]

Computes Distance Intersection Over Union (DIoU).

As input to forward and update the metric accepts the following input:

  • preds (List): A list consisting of dictionaries each containing the key-values (each dictionary corresponds to a single image). Parameters that should be provided per dict:

    • boxes (Tensor): float tensor of shape (num_boxes, 4) containing num_boxes detection boxes of the format specified in the constructor. By default, this method expects (xmin, ymin, xmax, ymax) in absolute image coordinates.

    • labels (Tensor): integer tensor of shape (num_boxes) containing 0-indexed detection classes for the boxes.

  • target (List): A list consisting of dictionaries each containing the key-values (each dictionary corresponds to a single image). Parameters that should be provided per dict:

    • boxes (Tensor): float tensor of shape (num_boxes, 4) containing num_boxes ground truth boxes of the format specified in the constructor. By default, this method expects (xmin, ymin, xmax, ymax) in absolute image coordinates.

    • labels (Tensor): integer tensor of shape (num_boxes) containing 0-indexed ground truth classes for the boxes.

As output of forward and compute the metric returns the following output:

  • diou_dict: A dictionary containing the following key-values:

    • diou: (Tensor) with overall diou value over all classes and samples.

    • diou/cl_{cl}: (Tensor), if argument class_metrics=True

Parameters:
  • box_format (str) – Input format of given boxes. Supported formats are ['xyxy', 'xywh', 'cxcywh'].

  • iou_threshold – Optional IoU threshold for evaluation. If set to None the threshold is ignored.

  • class_metrics (bool) – Option to enable per-class metrics for IoU. Has a performance impact.

  • respect_labels (bool) – Ignore values from boxes that do not have the same label as the ground truth box. Otherwise, IoU is computed between all pairs of boxes.

  • kwargs (Any) – Additional keyword arguments, see Advanced metric settings for more info.

Example

>>> import torch
>>> from torchmetrics.detection import DistanceIntersectionOverUnion
>>> preds = [
...    {
...        "boxes": torch.tensor([[296.55, 93.96, 314.97, 152.79], [298.55, 98.96, 314.97, 151.79]]),
...        "scores": torch.tensor([0.236, 0.56]),
...        "labels": torch.tensor([4, 5]),
...    }
... ]
>>> target = [
...    {
...        "boxes": torch.tensor([[300.00, 100.00, 315.00, 150.00]]),
...        "labels": torch.tensor([5]),
...    }
... ]
>>> metric = DistanceIntersectionOverUnion()
>>> metric(preds, target)
{'diou': tensor(0.8611)}
Raises:

ModuleNotFoundError – If torchvision is not installed with version 0.13.0 or newer.

plot(val=None, ax=None)[source]

Plot a single or multiple values from the metric.

Parameters:
  • val (Union[Tensor, Sequence[Tensor], None]) – Either a single result from calling metric.forward or metric.compute or a list of these results. If no value is provided, will automatically call metric.compute and plot that result.

  • ax (Optional[Axes]) – A matplotlib axis object. If provided, the plot will be added to that axis

Return type:

Tuple[Figure, Union[Axes, ndarray]]

Returns:

Figure object and Axes object

Raises:

ModuleNotFoundError – If matplotlib is not installed

>>> # Example plotting single value
>>> import torch
>>> from torchmetrics.detection import DistanceIntersectionOverUnion
>>> preds = [
...    {
...        "boxes": torch.tensor([[296.55, 93.96, 314.97, 152.79], [298.55, 98.96, 314.97, 151.79]]),
...        "scores": torch.tensor([0.236, 0.56]),
...        "labels": torch.tensor([4, 5]),
...    }
... ]
>>> target = [
...    {
...        "boxes": torch.tensor([[300.00, 100.00, 315.00, 150.00]]),
...        "labels": torch.tensor([5]),
...    }
... ]
>>> metric = DistanceIntersectionOverUnion()
>>> metric.update(preds, target)
>>> fig_, ax_ = metric.plot()
_images/distance_intersection_over_union-1.png
>>> # Example plotting multiple values
>>> import torch
>>> from torchmetrics.detection import DistanceIntersectionOverUnion
>>> preds = [
...    {
...        "boxes": torch.tensor([[296.55, 93.96, 314.97, 152.79], [298.55, 98.96, 314.97, 151.79]]),
...        "scores": torch.tensor([0.236, 0.56]),
...        "labels": torch.tensor([4, 5]),
...    }
... ]
>>> target = lambda : [
...    {
...        "boxes": torch.tensor([[300.00, 100.00, 315.00, 150.00]]) + torch.randint(-10, 10, (1, 4)),
...        "labels": torch.tensor([5]),
...    }
... ]
>>> metric = DistanceIntersectionOverUnion()
>>> vals = []
>>> for _ in range(20):
...     vals.append(metric(preds, target()))
>>> fig_, ax_ = metric.plot(vals)
_images/distance_intersection_over_union-2.png
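
Boxes given in other layouts can be used by setting box_format; a brief sketch with 'xywh' boxes (the coordinates are invented for illustration):

>>> import torch
>>> from torchmetrics.detection import DistanceIntersectionOverUnion
>>> preds = [{"boxes": torch.tensor([[300.0, 100.0, 15.0, 50.0]]),   # (x, y, width, height)
...           "scores": torch.tensor([0.9]),
...           "labels": torch.tensor([0])}]
>>> target = [{"boxes": torch.tensor([[300.0, 100.0, 15.0, 50.0]]),
...            "labels": torch.tensor([0])}]
>>> metric = DistanceIntersectionOverUnion(box_format="xywh")
>>> result = metric(preds, target)              # identical boxes give a diou of 1.0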

Functional Interface

torchmetrics.functional.detection.diou.distance_intersection_over_union(preds, target, iou_threshold=None, replacement_val=0, aggregate=True)[source]

Compute Distance Intersection over Union (DIOU) between two sets of boxes.

Both sets of boxes are expected to be in (x1, y1, x2, y2) format with 0 <= x1 < x2 and 0 <= y1 < y2.

Parameters:
  • preds (Tensor) – The input tensor containing the predicted bounding boxes.

  • target (Tensor) – The tensor containing the ground truth.

  • iou_threshold (Optional[float]) – Optional IoU thresholds for evaluation. If set to None the threshold is ignored.

  • replacement_val (float) – Value to replace values under the threshold with.

  • aggregate (bool) – Return the average value instead of the full matrix of values

Return type:

Tensor

Example:

By default, the DIoU score is aggregated across all box pairs, e.g. the mean along the diagonal of the DIoU matrix:

>>> import torch
>>> from torchmetrics.functional.detection import distance_intersection_over_union
>>> preds = torch.tensor(
...     [
...         [296.55, 93.96, 314.97, 152.79],
...         [328.94, 97.05, 342.49, 122.98],
...         [356.62, 95.47, 372.33, 147.55],
...     ]
... )
>>> target = torch.tensor(
...     [
...         [300.00, 100.00, 315.00, 150.00],
...         [330.00, 100.00, 350.00, 125.00],
...         [350.00, 100.00, 375.00, 150.00],
...     ]
... )
>>> distance_intersection_over_union(preds, target)
tensor(0.5793)
Example:

By setting aggregate=False, the DIoU score for each pair of prediction and target boxes is returned:

>>> import torch
>>> from torchmetrics.functional.detection import distance_intersection_over_union
>>> preds = torch.tensor(
...     [
...         [296.55, 93.96, 314.97, 152.79],
...         [328.94, 97.05, 342.49, 122.98],
...         [356.62, 95.47, 372.33, 147.55],
...     ]
... )
>>> target = torch.tensor(
...     [
...         [300.00, 100.00, 315.00, 150.00],
...         [330.00, 100.00, 350.00, 125.00],
...         [350.00, 100.00, 375.00, 150.00],
...     ]
... )
>>> distance_intersection_over_union(preds, target, aggregate=False)
tensor([[ 0.6883, -0.2043, -0.3351],
        [-0.2214,  0.4886, -0.1913],
        [-0.3971, -0.1510,  0.5609]])

Generalized Intersection Over Union (gIoU)

Module Interface

class torchmetrics.detection.giou.GeneralizedIntersectionOverUnion(box_format='xyxy', iou_threshold=None, class_metrics=False, respect_labels=True, **kwargs)[source]

Compute Generalized Intersection Over Union (GIoU).

As input to forward and update the metric accepts the following input:

  • preds (List): A list consisting of dictionaries each containing the key-values (each dictionary corresponds to a single image). Parameters that should be provided per dict:

    • boxes (Tensor): float tensor of shape (num_boxes, 4) containing num_boxes detection boxes of the format specified in the constructor. By default, this method expects (xmin, ymin, xmax, ymax) in absolute image coordinates.

    • labels (Tensor): integer tensor of shape (num_boxes) containing 0-indexed detection classes for the boxes.

  • target (List): A list consisting of dictionaries each containing the key-values (each dictionary corresponds to a single image). Parameters that should be provided per dict:

    • boxes (Tensor): float tensor of shape (num_boxes, 4) containing num_boxes ground truth boxes of the format specified in the constructor. By default, this method expects (xmin, ymin, xmax, ymax) in absolute image coordinates.

    • labels (Tensor): integer tensor of shape (num_boxes) containing 0-indexed ground truth classes for the boxes.

As output of forward and compute the metric returns the following output:

  • giou_dict: A dictionary containing the following key-values:

    • giou: (Tensor) with overall giou value over all classes and samples.

    • giou/cl_{cl}: (Tensor), if argument class_metrics=True

Parameters:
  • box_format (str) – Input format of given boxes. Supported formats are [`xyxy`, `xywh`, `cxcywh`].

  • iou_threshold (Optional[float]) – Optional IoU threshold for evaluation. If set to None the threshold is ignored.

  • class_metrics (bool) – Option to enable per-class metrics for IoU. Has a performance impact.

  • respect_labels (bool) – If True, ignore box pairs whose labels do not match the ground truth label. Otherwise the IoU is computed between all pairs of boxes.

  • kwargs (Any) – Additional keyword arguments, see Advanced metric settings for more info.

Example

>>> import torch
>>> from torchmetrics.detection import GeneralizedIntersectionOverUnion
>>> preds = [
...    {
...        "boxes": torch.tensor([[296.55, 93.96, 314.97, 152.79], [298.55, 98.96, 314.97, 151.79]]),
...        "scores": torch.tensor([0.236, 0.56]),
...        "labels": torch.tensor([4, 5]),
...    }
... ]
>>> target = [
...    {
...        "boxes": torch.tensor([[300.00, 100.00, 315.00, 150.00]]),
...        "labels": torch.tensor([5]),
...    }
... ]
>>> metric = GeneralizedIntersectionOverUnion()
>>> metric(preds, target)
{'giou': tensor(0.8613)}
Raises:

ModuleNotFoundError – If torchvision is not installed with version 0.8.0 or newer.

plot(val=None, ax=None)[source]

Plot a single or multiple values from the metric.

Parameters:
  • val (Union[Tensor, Sequence[Tensor], None]) – Either a single result from calling metric.forward or metric.compute or a list of these results. If no value is provided, will automatically call metric.compute and plot that result.

  • ax (Optional[Axes]) – A matplotlib axis object. If provided, the plot will be added to that axis.

Return type:

Tuple[Figure, Union[Axes, ndarray]]

Returns:

Figure object and Axes object

Raises:

ModuleNotFoundError – If matplotlib is not installed

>>> # Example plotting single value
>>> import torch
>>> from torchmetrics.detection import GeneralizedIntersectionOverUnion
>>> preds = [
...    {
...        "boxes": torch.tensor([[296.55, 93.96, 314.97, 152.79], [298.55, 98.96, 314.97, 151.79]]),
...        "scores": torch.tensor([0.236, 0.56]),
...        "labels": torch.tensor([4, 5]),
...    }
... ]
>>> target = [
...    {
...        "boxes": torch.tensor([[300.00, 100.00, 315.00, 150.00]]),
...        "labels": torch.tensor([5]),
...    }
... ]
>>> metric = GeneralizedIntersectionOverUnion()
>>> metric.update(preds, target)
>>> fig_, ax_ = metric.plot()
_images/generalized_intersection_over_union-1.png
>>> # Example plotting multiple values
>>> import torch
>>> from torchmetrics.detection import GeneralizedIntersectionOverUnion
>>> preds = [
...    {
...        "boxes": torch.tensor([[296.55, 93.96, 314.97, 152.79], [298.55, 98.96, 314.97, 151.79]]),
...        "scores": torch.tensor([0.236, 0.56]),
...        "labels": torch.tensor([4, 5]),
...    }
... ]
>>> target = lambda : [
...    {
...        "boxes": torch.tensor([[300.00, 100.00, 315.00, 150.00]]) + torch.randint(-10, 10, (1, 4)),
...        "labels": torch.tensor([5]),
...    }
... ]
>>> metric = GeneralizedIntersectionOverUnion()
>>> vals = []
>>> for _ in range(20):
...     vals.append(metric(preds, target()))
>>> fig_, ax_ = metric.plot(vals)
_images/generalized_intersection_over_union-2.png

Functional Interface

torchmetrics.functional.detection.giou.generalized_intersection_over_union(preds, target, iou_threshold=None, replacement_val=0, aggregate=True)[source]

Compute Generalized Intersection over Union (GIOU) between two sets of boxes.

Both sets of boxes are expected to be in (x1, y1, x2, y2) format with 0 <= x1 < x2 and 0 <= y1 < y2.

Parameters:
  • preds (Tensor) – The input tensor containing the predicted bounding boxes.

  • target (Tensor) – The tensor containing the ground truth.

  • iou_threshold (Optional[float]) – Optional IoU thresholds for evaluation. If set to None the threshold is ignored.

  • replacement_val (float) – Value to replace values under the threshold with.

  • aggregate (bool) – Return the average value instead of the full matrix of values

Return type:

Tensor

Example::

By default, GIoU is aggregated across all box pairs, i.e. the mean along the diagonal of the GIoU matrix is returned:

>>> import torch
>>> from torchmetrics.functional.detection import generalized_intersection_over_union
>>> preds = torch.tensor(
...     [
...         [296.55, 93.96, 314.97, 152.79],
...         [328.94, 97.05, 342.49, 122.98],
...         [356.62, 95.47, 372.33, 147.55],
...     ]
... )
>>> target = torch.tensor(
...     [
...         [300.00, 100.00, 315.00, 150.00],
...         [330.00, 100.00, 350.00, 125.00],
...         [350.00, 100.00, 375.00, 150.00],
...     ]
... )
>>> generalized_intersection_over_union(preds, target)
tensor(0.5638)
Example::

By setting aggregate=False the full GIoU matrix is returned:

>>> import torch
>>> from torchmetrics.functional.detection import generalized_intersection_over_union
>>> preds = torch.tensor(
...     [
...         [296.55, 93.96, 314.97, 152.79],
...         [328.94, 97.05, 342.49, 122.98],
...         [356.62, 95.47, 372.33, 147.55],
...     ]
... )
>>> target = torch.tensor(
...     [
...         [300.00, 100.00, 315.00, 150.00],
...         [330.00, 100.00, 350.00, 125.00],
...         [350.00, 100.00, 375.00, 150.00],
...     ]
... )
>>> generalized_intersection_over_union(preds, target, aggregate=False)
tensor([[ 0.6895, -0.4964, -0.4944],
        [-0.5105,  0.4673, -0.3434],
        [-0.6024, -0.4021,  0.5345]])

Intersection Over Union (IoU)

Module Interface

class torchmetrics.detection.iou.IntersectionOverUnion(box_format='xyxy', iou_threshold=None, class_metrics=False, respect_labels=True, **kwargs)[source]

Computes Intersection Over Union (IoU).

As input to forward and update the metric accepts the following input:

  • preds (List): A list consisting of dictionaries each containing the key-values (each dictionary corresponds to a single image). Parameters that should be provided per dict:

    • boxes (Tensor): float tensor of shape (num_boxes, 4) containing num_boxes detection boxes of the format specified in the constructor. By default, this method expects (xmin, ymin, xmax, ymax) in absolute image coordinates.

    • labels (Tensor): integer tensor of shape (num_boxes) containing 0-indexed detection classes for the boxes.

  • target (List): A list consisting of dictionaries each containing the key-values (each dictionary corresponds to a single image). Parameters that should be provided per dict:

    • boxes (Tensor): float tensor of shape (num_boxes, 4) containing num_boxes ground truth boxes of the format specified in the constructor. By default, this method expects (xmin, ymin, xmax, ymax) in absolute image coordinates.

    • labels (Tensor): integer tensor of shape (num_boxes) containing 0-indexed ground truth classes for the boxes.

As output of forward and compute the metric returns the following output:

  • iou_dict: A dictionary containing the following key-values:

    • iou: (Tensor) with overall iou value over all classes and samples.

    • iou/cl_{cl}: (Tensor), if argument class_metrics=True

Parameters:
  • box_format (str) – Input format of given boxes. Supported formats are [`xyxy`, `xywh`, `cxcywh`].

  • iou_threshold (Optional[float]) – Optional IoU threshold for evaluation. If set to None the threshold is ignored.

  • class_metrics (bool) – Option to enable per-class metrics for IoU. Has a performance impact.

  • respect_labels (bool) – If True, ignore box pairs whose labels do not match the ground truth label. Otherwise the IoU is computed between all pairs of boxes (see the additional example below).

  • kwargs (Any) – Additional keyword arguments, see Advanced metric settings for more info.

Example:

>>> import torch
>>> from torchmetrics.detection import IntersectionOverUnion
>>> preds = [
...    {
...        "boxes": torch.tensor([
...             [296.55, 93.96, 314.97, 152.79],
...             [298.55, 98.96, 314.97, 151.79]]),
...        "labels": torch.tensor([4, 5]),
...    }
... ]
>>> target = [
...    {
...        "boxes": torch.tensor([[300.00, 100.00, 315.00, 150.00]]),
...        "labels": torch.tensor([5]),
...    }
... ]
>>> metric = IntersectionOverUnion()
>>> metric(preds, target)
{'iou': tensor(0.8614)}

Example:

The metric can also return the score per class:

>>> import torch
>>> from torchmetrics.detection import IntersectionOverUnion
>>> preds = [
...    {
...        "boxes": torch.tensor([
...             [296.55, 93.96, 314.97, 152.79],
...             [298.55, 98.96, 314.97, 151.79]]),
...        "labels": torch.tensor([4, 5]),
...    }
... ]
>>> target = [
...    {
...        "boxes": torch.tensor([
...               [300.00, 100.00, 315.00, 150.00],
...               [300.00, 100.00, 315.00, 150.00]
...        ]),
...        "labels": torch.tensor([4, 5]),
...    }
... ]
>>> metric = IntersectionOverUnion(class_metrics=True)
>>> metric(preds, target)
{'iou': tensor(0.7756), 'iou/cl_4': tensor(0.6898), 'iou/cl_5': tensor(0.8614)}
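Example:

The respect_labels argument controls whether box pairs with mismatching labels are scored at all. A minimal sketch reusing the preds/target from the previous example; the exact score is omitted since it depends on how the labels are matched:

>>> # hedged sketch: with respect_labels=False the IoU is computed between all
>>> # pairs of boxes, regardless of their labels
>>> metric = IntersectionOverUnion(respect_labels=False)
>>> result = metric(preds, target)
>>> sorted(result.keys())
['iou']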
Raises:

ModuleNotFoundError – If torchvision is not installed with version 0.8.0 or newer.

plot(val=None, ax=None)[source]

Plot a single or multiple values from the metric.

Parameters:
  • val (Union[Tensor, Sequence[Tensor], None]) – Either a single result from calling metric.forward or metric.compute or a list of these results. If no value is provided, will automatically call metric.compute and plot that result.

  • ax (Optional[Axes]) – A matplotlib axis object. If provided, the plot will be added to that axis.

Return type:

Tuple[Figure, Union[Axes, ndarray]]

Returns:

Figure object and Axes object

Raises:

ModuleNotFoundError – If matplotlib is not installed

>>> import torch
>>> from torchmetrics.detection import IntersectionOverUnion
>>> preds = [
...    {
...        "boxes": torch.tensor([[296.55, 93.96, 314.97, 152.79], [298.55, 98.96, 314.97, 151.79]]),
...        "scores": torch.tensor([0.236, 0.56]),
...        "labels": torch.tensor([4, 5]),
...    }
... ]
>>> target = [
...    {
...        "boxes": torch.tensor([[300.00, 100.00, 315.00, 150.00]]),
...        "labels": torch.tensor([5]),
...    }
... ]
>>> metric = IntersectionOverUnion()
>>> metric.update(preds, target)
>>> fig_, ax_ = metric.plot()
_images/intersection_over_union-1.png
>>> # Example plotting multiple values
>>> import torch
>>> from torchmetrics.detection import IntersectionOverUnion
>>> preds = [
...    {
...        "boxes": torch.tensor([[296.55, 93.96, 314.97, 152.79], [298.55, 98.96, 314.97, 151.79]]),
...        "scores": torch.tensor([0.236, 0.56]),
...        "labels": torch.tensor([4, 5]),
...    }
... ]
>>> target = lambda : [
...    {
...        "boxes": torch.tensor([[300.00, 100.00, 315.00, 150.00]]) + torch.randint(-10, 10, (1, 4)),
...        "labels": torch.tensor([5]),
...    }
... ]
>>> metric = IntersectionOverUnion()
>>> vals = []
>>> for _ in range(20):
...     vals.append(metric(preds, target()))
>>> fig_, ax_ = metric.plot(vals)
_images/intersection_over_union-2.png

Functional Interface

torchmetrics.functional.detection.iou.intersection_over_union(preds, target, iou_threshold=None, replacement_val=0, aggregate=True)[source]

Compute Intersection over Union between two sets of boxes.

Both sets of boxes are expected to be in (x1, y1, x2, y2) format with 0 <= x1 < x2 and 0 <= y1 < y2.

Parameters:
  • preds (Tensor) – The input tensor containing the predicted bounding boxes.

  • target (Tensor) – The tensor containing the ground truth.

  • iou_threshold (Optional[float]) – Optional IoU thresholds for evaluation. If set to None the threshold is ignored.

  • replacement_val (float) – Value to replace values under the threshold with.

  • aggregate (bool) – Return the average value instead of the full matrix of values

Return type:

Tensor

Example::

By default, IoU is aggregated across all box pairs, i.e. the mean along the diagonal of the IoU matrix is returned:

>>> import torch
>>> from torchmetrics.functional.detection import intersection_over_union
>>> preds = torch.tensor(
...     [
...         [296.55, 93.96, 314.97, 152.79],
...         [328.94, 97.05, 342.49, 122.98],
...         [356.62, 95.47, 372.33, 147.55],
...     ]
... )
>>> target = torch.tensor(
...     [
...         [300.00, 100.00, 315.00, 150.00],
...         [330.00, 100.00, 350.00, 125.00],
...         [350.00, 100.00, 375.00, 150.00],
...     ]
... )
>>> intersection_over_union(preds, target)
tensor(0.5879)
Example::

By setting aggregate=False the full IoU matrix is returned:

>>> import torch
>>> from torchmetrics.functional.detection import intersection_over_union
>>> preds = torch.tensor(
...     [
...         [296.55, 93.96, 314.97, 152.79],
...         [328.94, 97.05, 342.49, 122.98],
...         [356.62, 95.47, 372.33, 147.55],
...     ]
... )
>>> target = torch.tensor(
...     [
...         [300.00, 100.00, 315.00, 150.00],
...         [330.00, 100.00, 350.00, 125.00],
...         [350.00, 100.00, 375.00, 150.00],
...     ]
... )
>>> intersection_over_union(preds, target, aggregate=False)
tensor([[0.6898, 0.0000, 0.0000],
        [0.0000, 0.5086, 0.0000],
        [0.0000, 0.0000, 0.5654]])

Mean-Average-Precision (mAP)

Module Interface

class torchmetrics.detection.mean_ap.MeanAveragePrecision(box_format='xyxy', iou_type='bbox', iou_thresholds=None, rec_thresholds=None, max_detection_thresholds=None, class_metrics=False, extended_summary=False, average='macro', **kwargs)[source]

Compute the Mean-Average-Precision (mAP) and Mean-Average-Recall (mAR) for object detection predictions.

\[\text{mAP} = \frac{1}{n} \sum_{i=1}^{n} AP_i\]

where \(AP_i\) is the average precision for class \(i\) and \(n\) is the number of classes. The average precision is defined as the area under the precision-recall curve. For object detection the recall and precision are defined based on the intersection over union (IoU) between the predicted bounding boxes and the ground truth bounding boxes, e.g. if two boxes have an IoU > t (with t being some threshold) they are considered a match and therefore a true positive. The precision is then defined as the number of true positives divided by the number of all detected boxes, and the recall is defined as the number of true positives divided by the number of all ground truth boxes.
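As a short worked example of the averaging step, if a detector reaches \(AP_1 = 0.50\) and \(AP_2 = 0.70\) on a two-class dataset, then

\[\text{mAP} = \frac{1}{2}(0.50 + 0.70) = 0.60\]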

As input to forward and update the metric accepts the following input:

  • preds (List): A list consisting of dictionaries each containing the key-values (each dictionary corresponds to a single image). Parameters that should be provided per dict

    • boxes (Tensor): float tensor of shape (num_boxes, 4) containing num_boxes detection boxes of the format specified in the constructor. By default, this method expects (xmin, ymin, xmax, ymax) in absolute image coordinates, but can be changed using the box_format parameter. Only required when iou_type=”bbox”.

    • scores (Tensor): float tensor of shape (num_boxes) containing detection scores for the boxes.

    • labels (Tensor): integer tensor of shape (num_boxes) containing 0-indexed detection classes for the boxes.

    • masks (Tensor): boolean tensor of shape (num_boxes, image_height, image_width) containing boolean masks. Only required when iou_type=”segm”.

  • target (List): A list consisting of dictionaries each containing the key-values (each dictionary corresponds to a single image). Parameters that should be provided per dict:

    • boxes (Tensor): float tensor of shape (num_boxes, 4) containing num_boxes ground truth boxes of the format specified in the constructor. Only required when iou_type=”bbox”. By default, this method expects (xmin, ymin, xmax, ymax) in absolute image coordinates.

    • labels (Tensor): integer tensor of shape (num_boxes) containing 0-indexed ground truth classes for the boxes.

    • masks (Tensor): boolean tensor of shape (num_boxes, image_height, image_width) containing boolean masks. Only required when iou_type=”segm”.

    • iscrowd (Tensor): integer tensor of shape (num_boxes) containing 0/1 values indicating whether the bounding box/masks indicate a crowd of objects. Value is optional, and if not provided it will automatically be set to 0.

    • area (Tensor): float tensor of shape (num_boxes) containing the area of the object. Value is optional, and if not provided will be automatically calculated based on the bounding box/masks provided. Only affects which samples contribute to the map_small, map_medium, map_large values

As output of forward and compute the metric returns the following output:

  • map_dict: A dictionary containing the following key-values:

    • map: (Tensor), global mean average precision

    • map_small: (Tensor), mean average precision for small objects

    • map_medium:(Tensor), mean average precision for medium objects

    • map_large: (Tensor), mean average precision for large objects

    • mar_1: (Tensor), mean average recall for 1 detection per image

    • mar_10: (Tensor), mean average recall for 10 detections per image

    • mar_100: (Tensor), mean average recall for 100 detections per image

    • mar_small: (Tensor), mean average recall for small objects

    • mar_medium: (Tensor), mean average recall for medium objects

    • mar_large: (Tensor), mean average recall for large objects

    • map_50: (Tensor) (-1 if 0.5 not in the list of iou thresholds), mean average precision at IoU=0.50

    • map_75: (Tensor) (-1 if 0.75 not in the list of iou thresholds), mean average precision at IoU=0.75

    • map_per_class: (Tensor) (-1 if class metrics are disabled), mean average precision per observed class

    • mar_100_per_class: (Tensor) (-1 if class metrics are disabled), mean average recall for 100 detections per image per observed class

    • classes (Tensor), list of all observed classes

For an example on how to use this metric check the torchmetrics mAP example.

Note

map score is calculated with @[ IoU=self.iou_thresholds | area=all | max_dets=max_detection_thresholds ]. Caution: If the initialization parameters are changed, dictionary keys for mAR can change as well.

Note

This metric utilizes the official pycocotools implementation as its backend. This means that the metric requires you to have pycocotools installed. In addition we require torchvision version 0.8.0 or newer. Please install with pip install torchmetrics[detection].

Parameters:
  • box_format (Literal['xyxy', 'xywh', 'cxcywh']) –

    Input format of given boxes. Supported formats are:

    • ’xyxy’: boxes are represented via corners, x1, y1 being top left and x2, y2 being bottom right.

    • ’xywh’: boxes are represented via corner, width and height, x1, y1 being top left, w, h being width and height. This is the default format used by pycoco and all input formats will be converted to this.

    • ’cxcywh’: boxes are represented via centre, width and height, cx, cy being center of box, w, h being width and height.

  • iou_type (Union[Literal['bbox', 'segm'], Tuple[str]]) – Type of input (either masks or bounding-boxes) used for computing IOU. Supported IOU types are "bbox" or "segm" or both as a tuple.

  • iou_thresholds (Optional[List[float]]) – IoU thresholds for evaluation. If set to None it corresponds to the stepped range [0.5,...,0.95] with step 0.05. Else provide a list of floats.

  • rec_thresholds (Optional[List[float]]) – Recall thresholds for evaluation. If set to None it corresponds to the stepped range [0,...,1] with step 0.01. Else provide a list of floats.

  • max_detection_thresholds (Optional[List[int]]) – Thresholds on max detections per image. If set to None will use thresholds [1, 10, 100]. Else, please provide a list of ints.

  • class_metrics (bool) – Option to enable per-class metrics for mAP and mAR_100. Has a performance impact that scales linearly with the number of classes in the dataset.

  • extended_summary (bool) –

    Option to enable extended summary with additional metrics including IOU, precision and recall. The output dictionary will contain the following extra key-values:

    • ious: a dictionary containing the IoU values for every image/class combination e.g. ious[(0,0)] would contain the IoU for image 0 and class 0. Each value is a tensor with shape (n,m) where n is the number of detections and m is the number of ground truth boxes for that image/class combination.

    • precision: a tensor of shape (TxRxKxAxM) containing the precision values. Here T is the number of IoU thresholds, R is the number of recall thresholds, K is the number of classes, A is the number of areas and M is the number of max detections per image.

    • recall: a tensor of shape (TxKxAxM) containing the recall values. Here T is the number of IoU thresholds, K is the number of classes, A is the number of areas and M is the number of max detections per image.

  • average (Literal['macro', 'micro']) – Method for averaging scores over labels. Choose between “macro” and “micro”. Default is “macro”

  • kwargs (Any) – Additional keyword arguments, see Advanced metric settings for more info.

Raises:
  • ModuleNotFoundError – If pycocotools is not installed

  • ModuleNotFoundError – If torchvision is not installed or version installed is lower than 0.8.0

  • ValueError – If box_format is not one of "xyxy", "xywh" or "cxcywh"

  • ValueError – If iou_type is not one of "bbox" or "segm"

  • ValueError – If iou_thresholds is neither None nor a list of floats

  • ValueError – If rec_thresholds is neither None nor a list of floats

  • ValueError – If max_detection_thresholds is neither None nor a list of ints

  • ValueError – If class_metrics is not a boolean

Example:

Basic example for when `iou_type="bbox"`. In this case the ``boxes`` key is required in the input dictionaries,
in addition to the ``scores`` and ``labels`` keys.

>>> from torch import tensor
>>> from torchmetrics.detection import MeanAveragePrecision
>>> preds = [
...   dict(
...     boxes=tensor([[258.0, 41.0, 606.0, 285.0]]),
...     scores=tensor([0.536]),
...     labels=tensor([0]),
...   )
... ]
>>> target = [
...   dict(
...     boxes=tensor([[214.0, 41.0, 562.0, 285.0]]),
...     labels=tensor([0]),
...   )
... ]
>>> metric = MeanAveragePrecision(iou_type="bbox")
>>> metric.update(preds, target)
>>> from pprint import pprint
>>> pprint(metric.compute())
{'classes': tensor(0, dtype=torch.int32),
 'map': tensor(0.6000),
 'map_50': tensor(1.),
 'map_75': tensor(1.),
 'map_large': tensor(0.6000),
 'map_medium': tensor(-1.),
 'map_per_class': tensor(-1.),
 'map_small': tensor(-1.),
 'mar_1': tensor(0.6000),
 'mar_10': tensor(0.6000),
 'mar_100': tensor(0.6000),
 'mar_100_per_class': tensor(-1.),
 'mar_large': tensor(0.6000),
 'mar_medium': tensor(-1.),
 'mar_small': tensor(-1.)}

Example:

Basic example for when `iou_type="segm"`. In this case the ``masks`` key is required in the input dictionaries,
in addition to the ``scores`` and ``labels`` keys.

>>> import torch
>>> from torch import tensor
>>> from torchmetrics.detection import MeanAveragePrecision
>>> mask_pred = [
...   [0, 0, 0, 0, 0],
...   [0, 0, 1, 1, 0],
...   [0, 0, 1, 1, 0],
...   [0, 0, 0, 0, 0],
...   [0, 0, 0, 0, 0],
... ]
>>> mask_tgt = [
...   [0, 0, 0, 0, 0],
...   [0, 0, 1, 0, 0],
...   [0, 0, 1, 1, 0],
...   [0, 0, 1, 0, 0],
...   [0, 0, 0, 0, 0],
... ]
>>> preds = [
...   dict(
...     masks=tensor([mask_pred], dtype=torch.bool),
...     scores=tensor([0.536]),
...     labels=tensor([0]),
...   )
... ]
>>> target = [
...   dict(
...     masks=tensor([mask_tgt], dtype=torch.bool),
...     labels=tensor([0]),
...   )
... ]
>>> metric = MeanAveragePrecision(iou_type="segm")
>>> metric.update(preds, target)
>>> from pprint import pprint
>>> pprint(metric.compute())
{'classes': tensor(0, dtype=torch.int32),
 'map': tensor(0.2000),
 'map_50': tensor(1.),
 'map_75': tensor(0.),
 'map_large': tensor(-1.),
 'map_medium': tensor(-1.),
 'map_per_class': tensor(-1.),
 'map_small': tensor(0.2000),
 'mar_1': tensor(0.2000),
 'mar_10': tensor(0.2000),
 'mar_100': tensor(0.2000),
 'mar_100_per_class': tensor(-1.),
 'mar_large': tensor(-1.),
 'mar_medium': tensor(-1.),
 'mar_small': tensor(0.2000)}
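Example:

When extended_summary=True, the output dictionary additionally exposes the raw ious, precision and recall tensors described in the parameter list above. A minimal sketch reusing the bounding boxes from the first example (the tensor values are not reproduced here):

>>> from torch import tensor
>>> from torchmetrics.detection import MeanAveragePrecision
>>> preds = [dict(boxes=tensor([[258.0, 41.0, 606.0, 285.0]]), scores=tensor([0.536]), labels=tensor([0]))]
>>> target = [dict(boxes=tensor([[214.0, 41.0, 562.0, 285.0]]), labels=tensor([0]))]
>>> metric = MeanAveragePrecision(iou_type="bbox", extended_summary=True)
>>> metric.update(preds, target)
>>> result = metric.compute()
>>> precision = result["precision"]  # tensor of shape (T, R, K, A, M), see above
>>> recall = result["recall"]        # tensor of shape (T, K, A, M)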
static coco_to_tm(coco_preds, coco_target, iou_type='bbox')[source]

Utility function for converting .json coco format files to the input format of this metric.

The function accepts a file for the predictions and a file for the target in coco format and converts them to a list of dictionaries containing the boxes, labels and scores in the input format of this metric.

Parameters:
  • coco_preds (str) – Path to the json file containing the predictions in coco format

  • coco_target (str) – Path to the json file containing the targets in coco format

  • iou_type (Union[Literal['bbox', 'segm'], List[str]]) – Type of input, either bbox for bounding boxes or segm for segmentation masks

Return type:

Tuple[List[Dict[str, Tensor]], List[Dict[str, Tensor]]]

Returns:

A tuple containing the predictions and targets in the input format of this metric. Each element of the tuple is a list of dictionaries containing the boxes, labels and scores.

Example

>>> # File formats are defined at https://cocodataset.org/#format-data
>>> # Example files can be found at
>>> # https://github.com/cocodataset/cocoapi/tree/master/results
>>> from torchmetrics.detection import MeanAveragePrecision
>>> preds, target = MeanAveragePrecision.coco_to_tm(
...   "instances_val2014_fakebbox100_results.json.json",
...   "val2014_fake_eval_res.txt.json"
...   iou_type="bbox"
... )  
plot(val=None, ax=None)[source]

Plot a single or multiple values from the metric.

Parameters:
  • val (Union[Dict[str, Tensor], Sequence[Dict[str, Tensor]], None]) – Either a single result from calling metric.forward or metric.compute or a list of these results. If no value is provided, will automatically call metric.compute and plot that result.

  • ax (Optional[Axes]) – A matplotlib axis object. If provided, the plot will be added to that axis.

Return type:

Tuple[Figure, Union[Axes, ndarray]]

Returns:

Figure object and Axes object

Raises:

ModuleNotFoundError – If matplotlib is not installed

>>> from torch import tensor
>>> from torchmetrics.detection.mean_ap import MeanAveragePrecision
>>> preds = [dict(
...     boxes=tensor([[258.0, 41.0, 606.0, 285.0]]),
...     scores=tensor([0.536]),
...     labels=tensor([0]),
... )]
>>> target = [dict(
...     boxes=tensor([[214.0, 41.0, 562.0, 285.0]]),
...     labels=tensor([0]),
... )]
>>> metric = MeanAveragePrecision()
>>> metric.update(preds, target)
>>> fig_, ax_ = metric.plot()
_images/mean_average_precision-1.png
>>> # Example plotting multiple values
>>> import torch
>>> from torchmetrics.detection.mean_ap import MeanAveragePrecision
>>> preds = lambda: [dict(
...     boxes=torch.tensor([[258.0, 41.0, 606.0, 285.0]]) + torch.randint(10, (1,4)),
...     scores=torch.tensor([0.536]) + 0.1*torch.rand(1),
...     labels=torch.tensor([0]),
... )]
>>> target = [dict(
...     boxes=torch.tensor([[214.0, 41.0, 562.0, 285.0]]),
...     labels=torch.tensor([0]),
... )]
>>> metric = MeanAveragePrecision()
>>> vals = []
>>> for _ in range(20):
...     vals.append(metric(preds(), target))
>>> fig_, ax_ = metric.plot(vals)
_images/mean_average_precision-2.png
tm_to_coco(name='tm_map_input')[source]

Utility function for converting the input for this metric to coco format and saving it to a json file.

This function should be used after calling .update(…) or .forward(…) on all data that should be written to the file, as the input is then internally cached. The function then converts the information to coco format and writes it to json files.

Parameters:

name (str) – Name of the output file, which will be appended with “_preds.json” and “_target.json”

Return type:

None

Example

>>> from torch import tensor
>>> from torchmetrics.detection import MeanAveragePrecision
>>> preds = [
...   dict(
...     boxes=tensor([[258.0, 41.0, 606.0, 285.0]]),
...     scores=tensor([0.536]),
...     labels=tensor([0]),
...   )
... ]
>>> target = [
...   dict(
...     boxes=tensor([[214.0, 41.0, 562.0, 285.0]]),
...     labels=tensor([0]),
...   )
... ]
>>> metric = MeanAveragePrecision()
>>> metric.update(preds, target)
>>> metric.tm_to_coco("tm_map_input")  

Modified Panoptic Quality

Module Interface

class torchmetrics.detection.ModifiedPanopticQuality(things, stuffs, allow_unknown_preds_category=False, **kwargs)[source]

Compute Modified Panoptic Quality for panoptic segmentations.

The metric was introduced in the Seamless Scene Segmentation paper, and is an adaptation of the original Panoptic Quality where the metric for a stuff class is computed as

\[PQ^{\dagger}_c = \frac{IOU_c}{|S_c|}\]

where \(IOU_c\) is the sum of the intersection over union of all matching segments for a given class, and \(|S_c|\) is the overall number of segments in the ground truth for that class.
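For example, if a stuff class has two ground truth segments and the matching predicted segments overlap them with IoUs of 0.6 and 0.8, then

\[PQ^{\dagger}_c = \frac{0.6 + 0.8}{2} = 0.7\]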

Parameters:
  • things (Collection[int]) – Set of category_id for countable things.

  • stuffs (Collection[int]) – Set of category_id for uncountable stuffs.

  • allow_unknown_preds_category (bool) – Boolean flag to specify if unknown categories in the predictions are to be ignored in the metric computation or raise an exception when found.

Raises:
  • ValueError – If things, stuffs have at least one common category_id.

  • TypeError – If things, stuffs contain non-integer category_id.

Example

>>> from torch import tensor
>>> from torchmetrics.detection import ModifiedPanopticQuality
>>> preds = tensor([[[0, 0], [0, 1], [6, 0], [7, 0], [0, 2], [1, 0]]])
>>> target = tensor([[[0, 1], [0, 0], [6, 0], [7, 0], [6, 0], [255, 0]]])
>>> pq_modified = ModifiedPanopticQuality(things = {0, 1}, stuffs = {6, 7})
>>> pq_modified(preds, target)
tensor(0.7667, dtype=torch.float64)
plot(val=None, ax=None)[source]

Plot a single or multiple values from the metric.

Parameters:
  • val (Union[Tensor, Sequence[Tensor], None]) – Either a single result from calling metric.forward or metric.compute or a list of these results. If no value is provided, will automatically call metric.compute and plot that result.

  • ax (Optional[Axes]) – A matplotlib axis object. If provided, the plot will be added to that axis.

Return type:

Tuple[Figure, Union[Axes, ndarray]]

Returns:

Figure object and Axes object

Raises:

ModuleNotFoundError – If matplotlib is not installed

>>> from torch import tensor
>>> from torchmetrics.detection import ModifiedPanopticQuality
>>> preds = tensor([[[[6, 0], [0, 0], [6, 0], [6, 0]],
...                  [[0, 0], [0, 0], [6, 0], [0, 1]],
...                  [[0, 0], [0, 0], [6, 0], [0, 1]],
...                  [[0, 0], [7, 0], [6, 0], [1, 0]],
...                  [[0, 0], [7, 0], [7, 0], [7, 0]]]])
>>> target = tensor([[[[6, 0], [0, 1], [6, 0], [0, 1]],
...                   [[0, 1], [0, 1], [6, 0], [0, 1]],
...                   [[0, 1], [0, 1], [6, 0], [1, 0]],
...                   [[0, 1], [7, 0], [1, 0], [1, 0]],
...                   [[0, 1], [7, 0], [7, 0], [7, 0]]]])
>>> metric = ModifiedPanopticQuality(things = {0, 1}, stuffs = {6, 7})
>>> metric.update(preds, target)
>>> fig_, ax_ = metric.plot()
_images/modified_panoptic_quality-1.png
>>> # Example plotting multiple values
>>> from torch import tensor
>>> from torchmetrics.detection import ModifiedPanopticQuality
>>> preds = tensor([[[[6, 0], [0, 0], [6, 0], [6, 0]],
...                  [[0, 0], [0, 0], [6, 0], [0, 1]],
...                  [[0, 0], [0, 0], [6, 0], [0, 1]],
...                  [[0, 0], [7, 0], [6, 0], [1, 0]],
...                  [[0, 0], [7, 0], [7, 0], [7, 0]]]])
>>> target = tensor([[[[6, 0], [0, 1], [6, 0], [0, 1]],
...                   [[0, 1], [0, 1], [6, 0], [0, 1]],
...                   [[0, 1], [0, 1], [6, 0], [1, 0]],
...                   [[0, 1], [7, 0], [1, 0], [1, 0]],
...                   [[0, 1], [7, 0], [7, 0], [7, 0]]]])
>>> metric = ModifiedPanopticQuality(things = {0, 1}, stuffs = {6, 7})
>>> vals = []
>>> for _ in range(20):
...     vals.append(metric(preds, target))
>>> fig_, ax_ = metric.plot(vals)
_images/modified_panoptic_quality-2.png

Functional Interface

torchmetrics.functional.detection.modified_panoptic_quality(preds, target, things, stuffs, allow_unknown_preds_category=False)[source]

Compute Modified Panoptic Quality for panoptic segmentations.

The metric was introduced in the Seamless Scene Segmentation paper, and is an adaptation of the original Panoptic Quality where the metric for a stuff class is computed as

\[PQ^{\dagger}_c = \frac{IOU_c}{|S_c|}\]

where \(IOU_c\) is the sum of the intersection over union of all matching segments for a given class, and \(|S_c|\) is the overall number of segments in the ground truth for that class.

Parameters:
  • preds (Tensor) – torch tensor with panoptic detection of shape [height, width, 2] containing the pair (category_id, instance_id) for each pixel of the image. If the category_id refers to a stuff class, the instance_id is ignored.

  • target (Tensor) – torch tensor with ground truth of shape [height, width, 2] containing the pair (category_id, instance_id) for each pixel of the image. If the category_id refers to a stuff class, the instance_id is ignored.

  • things (Collection[int]) – Set of category_id for countable things.

  • stuffs (Collection[int]) – Set of category_id for uncountable stuffs.

  • allow_unknown_preds_category (bool) – Boolean flag to specify if unknown categories in the predictions are to be ignored in the metric computation or raise an exception when found.

Raises:
  • ValueError – If things, stuffs have at least one common category_id.

  • TypeError – If things, stuffs contain non-integer category_id.

  • TypeError – If preds or target is not an torch.Tensor.

  • ValueError – If preds or target has different shape.

  • ValueError – If preds has less than 3 dimensions.

  • ValueError – If the final dimension of preds has size != 2.

Return type:

Tensor

Example

>>> from torch import tensor
>>> preds = tensor([[[0, 0], [0, 1], [6, 0], [7, 0], [0, 2], [1, 0]]])
>>> target = tensor([[[0, 1], [0, 0], [6, 0], [7, 0], [6, 0], [255, 0]]])
>>> modified_panoptic_quality(preds, target, things = {0, 1}, stuffs = {6, 7})
tensor(0.7667, dtype=torch.float64)

Panoptic Quality

Module Interface

class torchmetrics.detection.PanopticQuality(things, stuffs, allow_unknown_preds_category=False, **kwargs)[source]

Compute the Panoptic Quality for panoptic segmentations.

\[PQ = \frac{IOU}{TP + 0.5 FP + 0.5 FN}\]

where IOU, TP, FP and FN are respectively the sum of the intersection over union for true positives, the number of true positives, false positives and false negatives. This metric is inspired by the PQ implementation of panopticapi, a standard implementation for the PQ metric for panoptic segmentation.
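As a short worked example, with an IoU sum of 1.5 over two true positives, one false positive and one false negative, the score becomes

\[PQ = \frac{1.5}{2 + 0.5 \cdot 1 + 0.5 \cdot 1} = 0.5\]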

Parameters:
  • things (Collection[int]) – Set of category_id for countable things.

  • stuffs (Collection[int]) – Set of category_id for uncountable stuffs.

  • allow_unknown_preds_category (bool) – Boolean flag to specify if unknown categories in the predictions are to be ignored in the metric computation or raise an exception when found.

Raises:
  • ValueError – If things, stuffs have at least one common category_id.

  • TypeError – If things, stuffs contain non-integer category_id.

Example

>>> from torch import tensor
>>> from torchmetrics.detection import PanopticQuality
>>> preds = tensor([[[[6, 0], [0, 0], [6, 0], [6, 0]],
...                  [[0, 0], [0, 0], [6, 0], [0, 1]],
...                  [[0, 0], [0, 0], [6, 0], [0, 1]],
...                  [[0, 0], [7, 0], [6, 0], [1, 0]],
...                  [[0, 0], [7, 0], [7, 0], [7, 0]]]])
>>> target = tensor([[[[6, 0], [0, 1], [6, 0], [0, 1]],
...                   [[0, 1], [0, 1], [6, 0], [0, 1]],
...                   [[0, 1], [0, 1], [6, 0], [1, 0]],
...                   [[0, 1], [7, 0], [1, 0], [1, 0]],
...                   [[0, 1], [7, 0], [7, 0], [7, 0]]]])
>>> panoptic_quality = PanopticQuality(things = {0, 1}, stuffs = {6, 7})
>>> panoptic_quality(preds, target)
tensor(0.5463, dtype=torch.float64)
plot(val=None, ax=None)[source]

Plot a single or multiple values from the metric.

Parameters:
  • val (Union[Tensor, Sequence[Tensor], None]) – Either a single result from calling metric.forward or metric.compute or a list of these results. If no value is provided, will automatically call metric.compute and plot that result.

  • ax (Optional[Axes]) – A matplotlib axis object. If provided, the plot will be added to that axis.

Return type:

Tuple[Figure, Union[Axes, ndarray]]

Returns:

Figure object and Axes object

Raises:

ModuleNotFoundError – If matplotlib is not installed

>>> from torch import tensor
>>> from torchmetrics.detection import PanopticQuality
>>> preds = tensor([[[[6, 0], [0, 0], [6, 0], [6, 0]],
...                  [[0, 0], [0, 0], [6, 0], [0, 1]],
...                  [[0, 0], [0, 0], [6, 0], [0, 1]],
...                  [[0, 0], [7, 0], [6, 0], [1, 0]],
...                  [[0, 0], [7, 0], [7, 0], [7, 0]]]])
>>> target = tensor([[[[6, 0], [0, 1], [6, 0], [0, 1]],
...                   [[0, 1], [0, 1], [6, 0], [0, 1]],
...                   [[0, 1], [0, 1], [6, 0], [1, 0]],
...                   [[0, 1], [7, 0], [1, 0], [1, 0]],
...                   [[0, 1], [7, 0], [7, 0], [7, 0]]]])
>>> metric = PanopticQuality(things = {0, 1}, stuffs = {6, 7})
>>> metric.update(preds, target)
>>> fig_, ax_ = metric.plot()
_images/panoptic_quality-1.png
>>> # Example plotting multiple values
>>> from torch import tensor
>>> from torchmetrics.detection import PanopticQuality
>>> preds = tensor([[[[6, 0], [0, 0], [6, 0], [6, 0]],
...                  [[0, 0], [0, 0], [6, 0], [0, 1]],
...                  [[0, 0], [0, 0], [6, 0], [0, 1]],
...                  [[0, 0], [7, 0], [6, 0], [1, 0]],
...                  [[0, 0], [7, 0], [7, 0], [7, 0]]]])
>>> target = tensor([[[[6, 0], [0, 1], [6, 0], [0, 1]],
...                   [[0, 1], [0, 1], [6, 0], [0, 1]],
...                   [[0, 1], [0, 1], [6, 0], [1, 0]],
...                   [[0, 1], [7, 0], [1, 0], [1, 0]],
...                   [[0, 1], [7, 0], [7, 0], [7, 0]]]])
>>> metric = PanopticQuality(things = {0, 1}, stuffs = {6, 7})
>>> vals = []
>>> for _ in range(20):
...     vals.append(metric(preds, target))
>>> fig_, ax_ = metric.plot(vals)
_images/panoptic_quality-2.png

Functional Interface

torchmetrics.functional.detection.panoptic_quality(preds, target, things, stuffs, allow_unknown_preds_category=False)[source]

Compute Panoptic Quality for panoptic segmentations.

\[PQ = \frac{IOU}{TP + 0.5 FP + 0.5 FN}\]

where IOU, TP, FP and FN are respectively the sum of the intersection over union for true positives, the number of true positives, false positives and false negatives. This metric is inspired by the PQ implementation of panopticapi, a standard implementation for the PQ metric for panoptic segmentation.

Parameters:
  • preds (Tensor) – torch tensor with panoptic detection of shape [height, width, 2] containing the pair (category_id, instance_id) for each pixel of the image. If the category_id refers to a stuff class, the instance_id is ignored.

  • target (Tensor) – torch tensor with ground truth of shape [height, width, 2] containing the pair (category_id, instance_id) for each pixel of the image. If the category_id refers to a stuff class, the instance_id is ignored.

  • things (Collection[int]) – Set of category_id for countable things.

  • stuffs (Collection[int]) – Set of category_id for uncountable stuffs.

  • allow_unknown_preds_category (bool) – Boolean flag to specify if unknown categories in the predictions are to be ignored in the metric computation or raise an exception when found.

Raises:
  • ValueError – If things, stuffs have at least one common category_id.

  • TypeError – If things, stuffs contain non-integer category_id.

  • TypeError – If preds or target is not an torch.Tensor.

  • ValueError – If preds or target has different shape.

  • ValueError – If preds has less than 3 dimensions.

  • ValueError – If the final dimension of preds has size != 2.

Return type:

Tensor

Example

>>> from torch import tensor
>>> preds = tensor([[[[6, 0], [0, 0], [6, 0], [6, 0]],
...                  [[0, 0], [0, 0], [6, 0], [0, 1]],
...                  [[0, 0], [0, 0], [6, 0], [0, 1]],
...                  [[0, 0], [7, 0], [6, 0], [1, 0]],
...                  [[0, 0], [7, 0], [7, 0], [7, 0]]]])
>>> target = tensor([[[[6, 0], [0, 1], [6, 0], [0, 1]],
...                   [[0, 1], [0, 1], [6, 0], [0, 1]],
...                   [[0, 1], [0, 1], [6, 0], [1, 0]],
...                   [[0, 1], [7, 0], [1, 0], [1, 0]],
...                   [[0, 1], [7, 0], [7, 0], [7, 0]]]])
>>> panoptic_quality(preds, target, things = {0, 1}, stuffs = {6, 7})
tensor(0.5463, dtype=torch.float64)

Error Relative Global Dim. Synthesis (ERGAS)

Module Interface

class torchmetrics.image.ErrorRelativeGlobalDimensionlessSynthesis(ratio=4, reduction='elementwise_mean', **kwargs)[source]

Calculate Relative dimensionless global error synthesis (ERGAS).

This metric is used to calculate the accuracy of pan-sharpened images, considering the normalized average error of each band of the result image.

As input to forward and update the metric accepts the following input

  • preds (Tensor): Predictions from model

  • target (Tensor): Ground truth values

As output of forward and compute the metric returns the following output

  • ergas (Tensor): if reduction!='none' returns a float scalar tensor with the average ERGAS value over samples, otherwise returns a tensor of shape (N,) with ERGAS values per sample

Parameters:
  • ratio (Union[int, float]) – ratio of high resolution to low resolution

  • reduction (Literal['elementwise_mean', 'sum', 'none', None]) –

    a method to reduce metric score over labels.

    • 'elementwise_mean': takes the mean (default)

    • 'sum': takes the sum

    • 'none' or None: no reduction will be applied

  • kwargs (Any) – Additional keyword arguments, see Advanced metric settings for more info.

Example

>>> import torch
>>> from torchmetrics.image import ErrorRelativeGlobalDimensionlessSynthesis
>>> preds = torch.rand([16, 1, 16, 16], generator=torch.manual_seed(42))
>>> target = preds * 0.75
>>> ergas = ErrorRelativeGlobalDimensionlessSynthesis()
>>> torch.round(ergas(preds, target))
tensor(154.)
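With reduction='none' the metric instead returns one ERGAS value per sample. A minimal sketch reusing the preds/target from above (only the shape is shown, since the per-sample values are not listed here):

>>> # hedged sketch: per-sample ERGAS scores of shape (N,)
>>> ergas_per_sample = ErrorRelativeGlobalDimensionlessSynthesis(reduction='none')
>>> ergas_per_sample(preds, target).shape
torch.Size([16])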
plot(val=None, ax=None)[source]

Plot a single or multiple values from the metric.

Parameters:
  • val (Union[Tensor, Sequence[Tensor], None]) – Either a single result from calling metric.forward or metric.compute or a list of these results. If no value is provided, will automatically call metric.compute and plot that result.

  • ax (Optional[Axes]) – A matplotlib axis object. If provided, the plot will be added to that axis.

Return type:

Tuple[Figure, Union[Axes, ndarray]]

Returns:

Figure and Axes object

Raises:

ModuleNotFoundError – If matplotlib is not installed

>>> # Example plotting a single value
>>> import torch
>>> from torchmetrics.image import ErrorRelativeGlobalDimensionlessSynthesis
>>> preds = torch.rand([16, 1, 16, 16], generator=torch.manual_seed(42))
>>> target = preds * 0.75
>>> metric = ErrorRelativeGlobalDimensionlessSynthesis()
>>> metric.update(preds, target)
>>> fig_, ax_ = metric.plot()
_images/error_relative_global_dimensionless_synthesis-1.png
>>> # Example plotting multiple values
>>> import torch
>>> from torchmetrics.image import ErrorRelativeGlobalDimensionlessSynthesis
>>> preds = torch.rand([16, 1, 16, 16], generator=torch.manual_seed(42))
>>> target = preds * 0.75
>>> metric = ErrorRelativeGlobalDimensionlessSynthesis()
>>> values = [ ]
>>> for _ in range(10):
...     values.append(metric(preds, target))
>>> fig_, ax_ = metric.plot(values)
_images/error_relative_global_dimensionless_synthesis-2.png

Functional Interface

torchmetrics.functional.image.error_relative_global_dimensionless_synthesis(preds, target, ratio=4, reduction='elementwise_mean')[source]

Erreur Relative Globale Adimensionnelle de Synthèse.

Parameters:
  • preds (Tensor) – estimated image

  • target (Tensor) – ground truth image

  • ratio (Union[int, float]) – ratio of high resolution to low resolution

  • reduction (Literal['elementwise_mean', 'sum', 'none', None]) –

    a method to reduce metric score over labels.

    • 'elementwise_mean': takes the mean (default)

    • 'sum': takes the sum

    • 'none' or None: no reduction will be applied

Return type:

Tensor

Returns:

Tensor with RelativeG score

Raises:
  • TypeError – If preds and target don’t have the same data type.

  • ValueError – If preds and target don’t have BxCxHxW shape.

Example

>>> import torch
>>> from torchmetrics.functional.image import error_relative_global_dimensionless_synthesis
>>> gen = torch.manual_seed(42)
>>> preds = torch.rand([16, 1, 16, 16], generator=gen)
>>> target = preds * 0.75
>>> ergds = error_relative_global_dimensionless_synthesis(preds, target)
>>> torch.round(ergds)
tensor(154.)

References

[1] Qian Du; Nicholas H. Younan; Roger King; Vijay P. Shah, “On the Performance Evaluation of Pan-Sharpening Techniques” in IEEE Geoscience and Remote Sensing Letters, vol. 4, no. 4, pp. 518-522, 15 October 2007, doi: 10.1109/LGRS.2007.896328.

Frechet Inception Distance (FID)

Module Interface

class torchmetrics.image.fid.FrechetInceptionDistance(feature=2048, reset_real_features=True, normalize=False, **kwargs)[source]

Calculate the Fréchet inception distance (FID), which is used to assess the quality of generated images.

\[FID = \|\mu - \mu_w\|^2 + tr(\Sigma + \Sigma_w - 2(\Sigma \Sigma_w)^{\frac{1}{2}})\]

where \(\mathcal{N}(\mu, \Sigma)\) is the multivariate normal distribution estimated from Inception v3 (fid ref1) features calculated on real life images and \(\mathcal{N}(\mu_w, \Sigma_w)\) is the multivariate normal distribution estimated from Inception v3 features calculated on generated (fake) images. The metric was originally proposed in fid ref1.

Using the default feature extraction (Inception v3 using the original weights from fid ref2), the input is expected to be mini-batches of 3-channel RGB images of shape (3xHxW). If argument normalize is True, images are expected to be of dtype float and have values in the [0,1] range; if normalize is set to False, images are expected to have dtype uint8 and take values in the [0, 255] range. All images will be resized to 299 x 299, which is the size of the original training data. The boolean flag real determines whether the images should update the statistics of the real distribution or the fake distribution.

This metric is known to be unstable in its calculations, and for best results we recommend computing it using torch.float64 (the default is torch.float32), which can be set using the .set_dtype method of the metric.

Note

using this metrics requires you to have torch 1.9 or higher installed

Note

using this metric with the default feature extractor requires that torch-fidelity is installed. Either install as pip install torchmetrics[image] or pip install torch-fidelity

As input to forward and update the metric accepts the following input

  • imgs (Tensor): tensor with images fed to the feature extractor

  • real (bool): bool indicating if imgs belong to the real or the fake distribution

As output of forward and compute the metric returns the following output

  • fid (Tensor): float scalar tensor with mean FID value over samples

Parameters:
  • feature (Union[int, Module]) –

    Either an integer or nn.Module:

    • an integer will indicate the inceptionv3 feature layer to choose. Can be one of the following: 64, 192, 768, 2048

    • an nn.Module for using a custom feature extractor. Expects that its forward method returns an (N,d) matrix where N is the batch size and d is the feature size.

  • reset_real_features (bool) – Whether to also reset the real features. Since in many cases the real dataset does not change, the features can be cached to avoid recomputing them, which is costly. Set this to False if your dataset does not change.

  • kwargs (Any) – Additional keyword arguments, see Advanced metric settings for more info.

Raises:
  • ValueError – If torch version is lower than 1.9

  • ModuleNotFoundError – If feature is set to an int (default settings) and torch-fidelity is not installed

  • ValueError – If feature is set to an int not in [64, 192, 768, 2048]

  • TypeError – If feature is not an str, int or torch.nn.Module

  • ValueError – If reset_real_features is not an bool

Example

>>> import torch
>>> _ = torch.manual_seed(123)
>>> from torchmetrics.image.fid import FrechetInceptionDistance
>>> fid = FrechetInceptionDistance(feature=64)
>>> # generate two slightly overlapping image intensity distributions
>>> imgs_dist1 = torch.randint(0, 200, (100, 3, 299, 299), dtype=torch.uint8)
>>> imgs_dist2 = torch.randint(100, 255, (100, 3, 299, 299), dtype=torch.uint8)
>>> fid.update(imgs_dist1, real=True)
>>> fid.update(imgs_dist2, real=False)
>>> fid.compute()
tensor(12.7202)
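If the generated images are float tensors in the [0, 1] range, set normalize=True as described above. A minimal sketch with hypothetical random images (no expected score is shown):

>>> # hedged sketch: float images in [0, 1] require normalize=True
>>> import torch
>>> from torchmetrics.image.fid import FrechetInceptionDistance
>>> fid = FrechetInceptionDistance(feature=64, normalize=True)
>>> real_imgs = torch.rand(10, 3, 299, 299)
>>> fake_imgs = torch.rand(10, 3, 299, 299)
>>> fid.update(real_imgs, real=True)
>>> fid.update(fake_imgs, real=False)
>>> score = fid.compute()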
plot(val=None, ax=None)[source]

Plot a single or multiple values from the metric.

Parameters:
  • val (Union[Tensor, Sequence[Tensor], None]) – Either a single result from calling metric.forward or metric.compute or a list of these results. If no value is provided, will automatically call metric.compute and plot that result.

  • ax (Optional[Axes]) – A matplotlib axis object. If provided, the plot will be added to that axis.

Return type:

Tuple[Figure, Union[Axes, ndarray]]

Returns:

Figure and Axes object

Raises:

ModuleNotFoundError – If matplotlib is not installed

>>> # Example plotting a single value
>>> import torch
>>> from torchmetrics.image.fid import FrechetInceptionDistance
>>> imgs_dist1 = torch.randint(0, 200, (100, 3, 299, 299), dtype=torch.uint8)
>>> imgs_dist2 = torch.randint(100, 255, (100, 3, 299, 299), dtype=torch.uint8)
>>> metric = FrechetInceptionDistance(feature=64)
>>> metric.update(imgs_dist1, real=True)
>>> metric.update(imgs_dist2, real=False)
>>> fig_, ax_ = metric.plot()
_images/frechet_inception_distance-1.png
>>> # Example plotting multiple values
>>> import torch
>>> from torchmetrics.image.fid import FrechetInceptionDistance
>>> imgs_dist1 = lambda: torch.randint(0, 200, (100, 3, 299, 299), dtype=torch.uint8)
>>> imgs_dist2 = lambda: torch.randint(100, 255, (100, 3, 299, 299), dtype=torch.uint8)
>>> metric = FrechetInceptionDistance(feature=64)
>>> values = [ ]
>>> for _ in range(3):
...     metric.update(imgs_dist1(), real=True)
...     metric.update(imgs_dist2(), real=False)
...     values.append(metric.compute())
...     metric.reset()
>>> fig_, ax_ = metric.plot(values)
_images/frechet_inception_distance-2.png
reset()[source]

Reset metric states.

Return type:

None

set_dtype(dst_type)[source]

Transfer all metric state to the specified dtype. Special version of the standard type method.

Parameters:

dst_type (Union[str, dtype]) – the desired type as torch.dtype or string

Return type:

Metric
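Example

As noted above, the metric is more numerically stable when its states are kept in double precision. A minimal sketch of switching to torch.float64 with set_dtype:

>>> # hedged sketch: keep the internal FID states in double precision
>>> import torch
>>> from torchmetrics.image.fid import FrechetInceptionDistance
>>> metric = FrechetInceptionDistance(feature=64).set_dtype(torch.float64)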

Image Gradients

Functional Interface

torchmetrics.functional.image.image_gradients(img)[source]

Compute the gradients of a given image using finite differences.

Parameters:

img (Tensor) – An (N, C, H, W) input tensor where C is the number of image channels

Return type:

Tuple[Tensor, Tensor]

Returns:

Tuple of (dy, dx) with each gradient of shape [N, C, H, W]

Example

>>> import torch
>>> from torchmetrics.functional.image import image_gradients
>>> image = torch.arange(0, 1*1*5*5, dtype=torch.float32)
>>> image = torch.reshape(image, (1, 1, 5, 5))
>>> dy, dx = image_gradients(image)
>>> dy[0, 0, :, :]
tensor([[5., 5., 5., 5., 5.],
        [5., 5., 5., 5., 5.],
        [5., 5., 5., 5., 5.],
        [5., 5., 5., 5., 5.],
        [0., 0., 0., 0., 0.]])
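The horizontal gradient dx follows the same convention along the width dimension; only its shape is shown here, since the values mirror the dy example above:

>>> # hedged sketch: dx has the same (N, C, H, W) shape as the input image
>>> dx.shape
torch.Size([1, 1, 5, 5])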

Note

The implementation follows the 1-step finite difference method as used by the TF implementation. The values are organized such that the gradient of [I(x+1, y) - I(x, y)] is at the (x, y) location.

Inception Score

Module Interface

class torchmetrics.image.inception.InceptionScore(feature='logits_unbiased', splits=10, normalize=False, **kwargs)[source]

Calculate the Inception Score (IS), which is used to assess how realistic generated images are.

\[IS = exp(\mathbb{E}_x KL(p(y | x ) || p(y)))\]

where \(KL(p(y | x) || p(y))\) is the KL divergence between the conditional distribution \(p(y|x)\) and the marginal distribution \(p(y)\). Both the conditional and marginal distributions are calculated from features extracted from the images. The score is calculated on random splits of the images such that both a mean and standard deviation of the score are returned. The metric was originally proposed in inception ref1.

Using the default feature extraction (Inception v3 using the original weights from inception ref2), the input is expected to be mini-batches of 3-channel RGB images of shape (3xHxW). If argument normalize is True images are expected to be dtype float and have values in the [0,1] range, else if normalize is set to False images are expected to have dtype uint8 and take values in the [0, 255] range. All images will be resized to 299 x 299 which is the size of the original training data.

Note

using this metric with the default feature extractor requires that torch-fidelity is installed. Either install as pip install torchmetrics[image] or pip install torch-fidelity

As input to forward and update the metric accepts the following input

  • imgs (Tensor): tensor with images fed to the feature extractor

As output of forward and compute the metric returns the following output

  • inception (Tuple): a tuple of two float scalar tensors with the mean and standard deviation of the inception score over the splits

Parameters:
  • feature (Union[str, int, Module]) –

    Either an str, integer or nn.Module:

    • an str or integer will indicate the inceptionv3 feature layer to choose. Can be one of the following: ‘logits_unbiased’, 64, 192, 768, 2048

    • an nn.Module for using a custom feature extractor. Expects that its forward method returns an (N,d) matrix where N is the batch size and d is the feature size.

  • splits (int) – integer determining how many splits the inception score calculation should be split among

  • kwargs (Any) – Additional keyword arguments, see Advanced metric settings for more info.

Raises:
  • ValueError – If feature is set to an str or int and torch-fidelity is not installed

  • ValueError – If feature is set to an str or int and not one of ('logits_unbiased', 64, 192, 768, 2048)

  • TypeError – If feature is not an str, int or torch.nn.Module

Example

>>> import torch
>>> _ = torch.manual_seed(123)
>>> from torchmetrics.image.inception import InceptionScore
>>> inception = InceptionScore()
>>> # generate some images
>>> imgs = torch.randint(0, 255, (100, 3, 299, 299), dtype=torch.uint8)
>>> inception.update(imgs)
>>> inception.compute()
(tensor(1.0544), tensor(0.0117))
plot(val=None, ax=None)[source]

Plot a single or multiple values from the metric.

Parameters:
  • val (Union[Tensor, Sequence[Tensor], None]) – Either a single result from calling metric.forward or metric.compute or a list of these results. If no value is provided, will automatically call metric.compute and plot that result.

  • ax (Optional[Axes]) – A matplotlib axis object. If provided, the plot will be added to that axis.

Return type:

Tuple[Figure, Union[Axes, ndarray]]

Returns:

Figure and Axes object

Raises:

ModuleNotFoundError – If matplotlib is not installed

>>> # Example plotting a single value
>>> import torch
>>> from torchmetrics.image.inception import InceptionScore
>>> metric = InceptionScore()
>>> metric.update(torch.randint(0, 255, (50, 3, 299, 299), dtype=torch.uint8))
>>> fig_, ax_ = metric.plot()  # the returned plot only shows the mean value by default
_images/inception_score-1.png
>>> # Example plotting multiple values
>>> import torch
>>> from torchmetrics.image.inception import InceptionScore
>>> metric = InceptionScore()
>>> values = [ ]
>>> for _ in range(3):
...     # we index by 0 such that only the mean value is plotted
...     values.append(metric(torch.randint(0, 255, (50, 3, 299, 299), dtype=torch.uint8))[0])
>>> fig_, ax_ = metric.plot(values)

Kernel Inception Distance

Module Interface

class torchmetrics.image.kid.KernelInceptionDistance(feature=2048, subsets=100, subset_size=1000, degree=3, gamma=None, coef=1.0, reset_real_features=True, normalize=False, **kwargs)[source]

Calculate Kernel Inception Distance (KID) which is used to assess the quality of generated images.

\[KID = MMD(f_{real}, f_{fake})^2\]

where \(MMD\) is the maximum mean discrepancy and \(f_{real}, f_{fake}\) are extracted features from real and fake images, see kid ref1 for more details. In particular, calculating the MMD requires the evaluation of a polynomial kernel function \(k\)

\[k(x,y) = (\gamma * x^T y + coef)^{degree}\]

which controls the distance between two features. In practice the MMD is calculated over a number of subsets to be able to obtain both a mean and standard deviation of KID.
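
For intuition, the quantities above can be sketched directly in PyTorch. This is an illustrative, biased MMD\(^2\) estimate, not the metric's own subset-based estimator; here gamma is taken as the reciprocal of the feature dimension for the sketch:

>>> import torch
>>> def polynomial_kernel(x, y, degree=3, gamma=None, coef=1.0):
...     # k(x, y) = (gamma * <x, y> + coef) ** degree, evaluated for every pair of rows
...     gamma = gamma if gamma is not None else 1.0 / x.shape[1]
...     return (gamma * x @ y.T + coef) ** degree
>>> f_real, f_fake = torch.randn(50, 2048), torch.randn(50, 2048)
>>> # a simple biased MMD^2 estimate assembled from the three kernel blocks
>>> mmd2 = (polynomial_kernel(f_real, f_real).mean()
...         + polynomial_kernel(f_fake, f_fake).mean()
...         - 2 * polynomial_kernel(f_real, f_fake).mean())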

Using the default feature extraction (Inception v3 using the original weights from kid ref2), the input is expected to be mini-batches of 3-channel RGB images of shape (3xHxW). If argument normalize is True images are expected to be dtype float and have values in the [0,1] range, else if normalize is set to False images are expected to have dtype uint8 and take values in the [0, 255] range. All images will be resized to 299 x 299 which is the size of the original training data. The boolean flag real determines if the images should update the statistics of the real distribution or the fake distribution.

Note

using this metric with the default feature extractor requires that torch-fidelity is installed. Either install as pip install torchmetrics[image] or pip install torch-fidelity

As input to forward and update the metric accepts the following input

  • imgs (Tensor): tensor with images fed to the feature extractor of shape (N,C,H,W)

  • real (bool): bool indicating if imgs belong to the real or the fake distribution

As output of forward and compute the metric returns the following output

  • kid_mean (Tensor): float scalar tensor with mean value over subsets

  • kid_std (Tensor): float scalar tensor with standard deviation value over subsets

Parameters:
  • feature (Union[str, int, Module]) –

    Either a str, an integer or an nn.Module:

    • a str or integer will indicate the inceptionv3 feature layer to choose. Can be one of the following: ‘logits_unbiased’, 64, 192, 768, 2048

    • an nn.Module for using a custom feature extractor. Expects that its forward method returns an (N,d) matrix where N is the batch size and d is the feature size.

  • subsets (int) – Number of subsets to calculate the mean and standard deviation scores over

  • subset_size (int) – Number of randomly picked samples in each subset

  • degree (int) – Degree of the polynomial kernel function

  • gamma (Optional[float]) – Scale-length of polynomial kernel. If set to None will be automatically set to the feature size

  • coef (float) – Bias term in the polynomial kernel.

  • reset_real_features (bool) – Whether to also reset the real features. Since in many cases the real dataset does not change, the features can be cached to avoid recomputing them, which is costly. Set this to False if your dataset does not change.

  • kwargs (Any) – Additional keyword arguments, see Advanced metric settings for more info.

Raises:
  • ValueError – If feature is set to an int (default settings) and torch-fidelity is not installed

  • ValueError – If feature is set to an int not in (64, 192, 768, 2048)

  • ValueError – If subsets is not an integer larger than 0

  • ValueError – If subset_size is not an integer larger than 0

  • ValueError – If degree is not an integer larger than 0

  • ValueError – If gamma is neither None nor a float larger than 0

  • ValueError – If coef is not a float larger than 0

  • ValueError – If reset_real_features is not a bool

Example

>>> import torch
>>> _ = torch.manual_seed(123)
>>> from torchmetrics.image.kid import KernelInceptionDistance
>>> kid = KernelInceptionDistance(subset_size=50)
>>> # generate two slightly overlapping image intensity distributions
>>> imgs_dist1 = torch.randint(0, 200, (100, 3, 299, 299), dtype=torch.uint8)
>>> imgs_dist2 = torch.randint(100, 255, (100, 3, 299, 299), dtype=torch.uint8)
>>> kid.update(imgs_dist1, real=True)
>>> kid.update(imgs_dist2, real=False)
>>> kid.compute()
(tensor(0.0337), tensor(0.0023))
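
When the set of real images is fixed, reset_real_features=False lets the metric keep the cached real features across evaluations and only reset the fake statistics; a minimal sketch of that pattern:

>>> import torch
>>> from torchmetrics.image.kid import KernelInceptionDistance
>>> kid = KernelInceptionDistance(subset_size=50, reset_real_features=False)
>>> real_imgs = torch.randint(0, 255, (100, 3, 299, 299), dtype=torch.uint8)
>>> kid.update(real_imgs, real=True)  # real features are computed once and kept
>>> for _ in range(2):  # e.g. once per evaluation epoch
...     fake_imgs = torch.randint(0, 255, (100, 3, 299, 299), dtype=torch.uint8)
...     kid.update(fake_imgs, real=False)
...     kid_mean, kid_std = kid.compute()
...     kid.reset()  # clears only the fake features because reset_real_features=False
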
plot(val=None, ax=None)[source]

Plot a single or multiple values from the metric.

Parameters:
  • val (Union[Tensor, Sequence[Tensor], None]) – Either a single result from calling metric.forward or metric.compute or a list of these results. If no value is provided, will automatically call metric.compute and plot that result.

  • ax (Optional[Axes]) – A matplotlib axis object. If provided, the plot will be added to that axis

Return type:

Tuple[Figure, Union[Axes, ndarray]]

Returns:

Figure and Axes object

Raises:

ModuleNotFoundError – If matplotlib is not installed

>>> # Example plotting a single value
>>> import torch
>>> from torchmetrics.image.kid import KernelInceptionDistance
>>> imgs_dist1 = torch.randint(0, 200, (30, 3, 299, 299), dtype=torch.uint8)
>>> imgs_dist2 = torch.randint(100, 255, (30, 3, 299, 299), dtype=torch.uint8)
>>> metric = KernelInceptionDistance(subsets=3, subset_size=20)
>>> metric.update(imgs_dist1, real=True)
>>> metric.update(imgs_dist2, real=False)
>>> fig_, ax_ = metric.plot()
>>> # Example plotting multiple values
>>> import torch
>>> from torchmetrics.image.kid import KernelInceptionDistance
>>> imgs_dist1 = lambda: torch.randint(0, 200, (30, 3, 299, 299), dtype=torch.uint8)
>>> imgs_dist2 = lambda: torch.randint(100, 255, (30, 3, 299, 299), dtype=torch.uint8)
>>> metric = KernelInceptionDistance(subsets=3, subset_size=20)
>>> values = [ ]
>>> for _ in range(3):
...     metric.update(imgs_dist1(), real=True)
...     metric.update(imgs_dist2(), real=False)
...     values.append(metric.compute()[0])
...     metric.reset()
>>> fig_, ax_ = metric.plot(values)
reset()[source]

Reset metric states.

Return type:

None

Learned Perceptual Image Patch Similarity (LPIPS)

Module Interface

class torchmetrics.image.lpip.LearnedPerceptualImagePatchSimilarity(net_type='alex', reduction='mean', normalize=False, **kwargs)[source]

The Learned Perceptual Image Patch Similarity (LPIPS_) calculates the perceptual similarity between two images.

LPIPS essentially computes the similarity between the activations of two image patches for some pre-defined network. This measure has been shown to match human perception well. A low LPIPS score means that the image patches are perceptually similar.

Both input image patches are expected to have shape (N, 3, H, W). The minimum size of H, W depends on the chosen backbone (see net_type arg).

Note

using this metric requires that the lpips package is installed. Either install as pip install torchmetrics[image] or pip install lpips

Note

this metric is not scriptable when using torch<1.8. Please update your PyTorch installation if this is an issue.

As input to forward and update the metric accepts the following input

  • img1 (Tensor): tensor with images of shape (N, 3, H, W)

  • img2 (Tensor): tensor with images of shape (N, 3, H, W)

As output of forward and compute the metric returns the following output

  • lpips (Tensor): returns float scalar tensor with average LPIPS value over samples

Parameters:
  • net_type (Literal['vgg', 'alex', 'squeeze']) – str indicating backbone network type to use. Choose between ‘alex’, ‘vgg’ or ‘squeeze’

  • reduction (Literal['sum', 'mean']) – str indicating how to reduce over the batch dimension. Choose between ‘sum’ or ‘mean’.

  • normalize (bool) – by default this is False meaning that the input is expected to be in the [-1,1] range. If set to True will instead expect input to be in the [0,1] range.

  • kwargs (Any) – Additional keyword arguments, see Advanced metric settings for more info.

Example

>>> import torch
>>> _ = torch.manual_seed(123)
>>> from torchmetrics.image.lpip import LearnedPerceptualImagePatchSimilarity
>>> lpips = LearnedPerceptualImagePatchSimilarity(net_type='squeeze')
>>> # LPIPS needs the images to be in the [-1, 1] range.
>>> img1 = (torch.rand(10, 3, 100, 100) * 2) - 1
>>> img2 = (torch.rand(10, 3, 100, 100) * 2) - 1
>>> lpips(img1, img2)
tensor(0.1046, grad_fn=<SqueezeBackward0>)
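
If your images are already in the [0, 1] range, the normalize flag described above can be used instead of rescaling them to [-1, 1]; a minimal sketch:

>>> import torch
>>> from torchmetrics.image.lpip import LearnedPerceptualImagePatchSimilarity
>>> lpips = LearnedPerceptualImagePatchSimilarity(net_type='squeeze', normalize=True)
>>> img1 = torch.rand(10, 3, 100, 100)  # values in [0, 1]
>>> img2 = torch.rand(10, 3, 100, 100)
>>> score = lpips(img1, img2)  # scalar tensor, since reduction='mean' by default
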
plot(val=None, ax=None)[source]

Plot a single or multiple values from the metric.

Parameters:
  • val (Union[Tensor, Sequence[Tensor], None]) – Either a single result from calling metric.forward or metric.compute or a list of these results. If no value is provided, will automatically call metric.compute and plot that result.

  • ax (Optional[Axes]) – A matplotlib axis object. If provided, the plot will be added to that axis

Return type:

Tuple[Figure, Union[Axes, ndarray]]

Returns:

Figure and Axes object

Raises:

ModuleNotFoundError – If matplotlib is not installed

>>> # Example plotting a single value
>>> import torch
>>> from torchmetrics.image.lpip import LearnedPerceptualImagePatchSimilarity
>>> metric = LearnedPerceptualImagePatchSimilarity(net_type='squeeze')
>>> metric.update(torch.rand(10, 3, 100, 100), torch.rand(10, 3, 100, 100))
>>> fig_, ax_ = metric.plot()
>>> # Example plotting multiple values
>>> import torch
>>> from torchmetrics.image.lpip import LearnedPerceptualImagePatchSimilarity
>>> metric = LearnedPerceptualImagePatchSimilarity(net_type='squeeze')
>>> values = [ ]
>>> for _ in range(3):
...     values.append(metric(torch.rand(10, 3, 100, 100), torch.rand(10, 3, 100, 100)))
>>> fig_, ax_ = metric.plot(values)

Functional Interface

torchmetrics.functional.image.learned_perceptual_image_patch_similarity(img1, img2, net_type='alex', reduction='mean', normalize=False)[source]

The Learned Perceptual Image Patch Similarity (LPIPS_) calculates the perceptual similarity between two images.

LPIPS essentially computes the similarity between the activations of two image patches for some pre-defined network. This measure has been shown to match human perception well. A low LPIPS score means that the image patches are perceptually similar.

Both input image patches are expected to have shape (N, 3, H, W). The minimum size of H, W depends on the chosen backbone (see net_type arg).

Parameters:
  • img1 (Tensor) – first set of images

  • img2 (Tensor) – second set of images

  • net_type (Literal['alex', 'vgg', 'squeeze']) – str indicating backbone network type to use. Choose between ‘alex’, ‘vgg’ or ‘squeeze’

  • reduction (Literal['sum', 'mean']) – str indicating how to reduce over the batch dimension. Choose between ‘sum’ or ‘mean’.

  • normalize (bool) – by default this is False meaning that the input is expected to be in the [-1,1] range. If set to True will instead expect input to be in the [0,1] range.

Return type:

Tensor

Example

>>> import torch
>>> _ = torch.manual_seed(123)
>>> from torchmetrics.functional.image.lpips import learned_perceptual_image_patch_similarity
>>> img1 = (torch.rand(10, 3, 100, 100) * 2) - 1
>>> img2 = (torch.rand(10, 3, 100, 100) * 2) - 1
>>> learned_perceptual_image_patch_similarity(img1, img2, net_type='squeeze')
tensor(0.1008, grad_fn=<DivBackward0>)

Memorization-Informed Frechet Inception Distance (MiFID)

Module Interface

class torchmetrics.image.mifid.MemorizationInformedFrechetInceptionDistance(feature=2048, reset_real_features=True, normalize=False, cosine_distance_eps=0.1, **kwargs)[source]

Calculate Memorization-Informed Frechet Inception Distance (MIFID).

MIFID is an improved variation of the Frechet Inception Distance (FID) that penalizes memorization of the training set by the generator. It is calculated as

\[MIFID = \frac{FID(F_{real}, F_{fake})}{M(F_{real}, F_{fake})}\]

where \(FID\) is the normal FID score and \(M\) is the memorization penalty. The memorization penalty essentially corresponds to the average minimum cosine distance between the features of the real and fake distribution.
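
For intuition only, one possible reading of the memorization penalty described above (the average minimum cosine distance between fake and real features, with distances above cosine_distance_eps treated as 1) could be sketched as follows; refer to the MIFID paper and the metric's source for the exact definition:

>>> import torch
>>> import torch.nn.functional as F
>>> def memorization_penalty(feat_fake, feat_real, eps=0.1):
...     # cosine distance between every fake feature and every real feature
...     dist = 1 - F.normalize(feat_fake, dim=1) @ F.normalize(feat_real, dim=1).T
...     min_dist = dist.min(dim=1).values  # closest real feature per fake sample
...     # distances above the threshold are set to 1 and thus effectively ignored
...     min_dist = torch.where(min_dist < eps, min_dist, torch.ones_like(min_dist))
...     return min_dist.mean()
>>> penalty = memorization_penalty(torch.randn(100, 64), torch.randn(100, 64))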

Using the default feature extraction (Inception v3 using the original weights from fid ref2), the input is expected to be mini-batches of 3-channel RGB images of shape (3 x H x W). If argument normalize is True images are expected to be dtype float and have values in the [0, 1] range, else if normalize is set to False images are expected to have dtype uint8 and take values in the [0, 255] range. All images will be resized to 299 x 299 which is the size of the original training data. The boolean flag real determines if the images should update the statistics of the real distribution or the fake distribution.

Note

using this metric requires that scipy is installed. Either install as pip install torchmetrics[image] or pip install scipy

Note

using this metric with the default feature extractor requires that torch-fidelity is installed. Either install as pip install torchmetrics[image] or pip install torch-fidelity

As input to forward and update the metric accepts the following input

  • imgs (Tensor): tensor with images fed to the feature extractor

  • real (bool): bool indicating if imgs belong to the real or the fake distribution

As output of forward and compute the metric returns the following output

  • mifid (Tensor): float scalar tensor with mean MIFID value over samples

Parameters:
  • feature (Union[int, Module]) –

    Either an integer or nn.Module:

    • an integer will indicate the inceptionv3 feature layer to choose. Can be one of the following: 64, 192, 768, 2048

    • an nn.Module for using a custom feature extractor. Expects that its forward method returns an (N,d) matrix where N is the batch size and d is the feature size.

  • reset_real_features (bool) – Whether to also reset the real features. Since in many cases the real dataset does not change, the features can be cached to avoid recomputing them, which is costly. Set this to False if your dataset does not change.

  • cosine_distance_eps (float) – Epsilon value for the cosine distance. If the cosine distance is larger than this value it is set to 1 and thus ignored in the MIFID calculation.

  • kwargs (Any) – Additional keyword arguments, see Advanced metric settings for more info.

Raises:
  • RuntimeError – If the installed torch version is less than 1.10

  • ValueError – If feature is set to an int and torch-fidelity is not installed

  • ValueError – If feature is set to an int not in [64, 192, 768, 2048]

  • TypeError – If feature is not an int or torch.nn.Module

  • ValueError – If reset_real_features is not a bool

Example::
>>> import torch
>>> _ = torch.manual_seed(42)
>>> from torchmetrics.image.mifid import MemorizationInformedFrechetInceptionDistance
>>> mifid = MemorizationInformedFrechetInceptionDistance(feature=64)
>>> # generate two slightly overlapping image intensity distributions
>>> imgs_dist1 = torch.randint(0, 200, (100, 3, 299, 299), dtype=torch.uint8)
>>> imgs_dist2 = torch.randint(100, 255, (100, 3, 299, 299), dtype=torch.uint8)
>>> mifid.update(imgs_dist1, real=True)
>>> mifid.update(imgs_dist2, real=False)
>>> mifid.compute()
tensor(3003.3691)
plot(val=None, ax=None)[source]

Plot a single or multiple values from the metric.

Parameters:
  • val (Union[Tensor, Sequence[Tensor], None]) – Either a single result from calling metric.forward or metric.compute or a list of these results. If no value is provided, will automatically call metric.compute and plot that result.

  • ax (Optional[Axes]) – A matplotlib axis object. If provided, the plot will be added to that axis

Return type:

Tuple[Figure, Union[Axes, ndarray]]

Returns:

Figure and Axes object

Raises:

ModuleNotFoundError – If matplotlib is not installed

>>> # Example plotting a single value
>>> import torch
>>> from torchmetrics.image.mifid import MemorizationInformedFrechetInceptionDistance
>>> imgs_dist1 = torch.randint(0, 200, (100, 3, 299, 299), dtype=torch.uint8)
>>> imgs_dist2 = torch.randint(100, 255, (100, 3, 299, 299), dtype=torch.uint8)
>>> metric = MemorizationInformedFrechetInceptionDistance(feature=64)
>>> metric.update(imgs_dist1, real=True)
>>> metric.update(imgs_dist2, real=False)
>>> fig_, ax_ = metric.plot()
>>> # Example plotting multiple values
>>> import torch
>>> from torchmetrics.image.mifid import MemorizationInformedFrechetInceptionDistance
>>> imgs_dist1 = lambda: torch.randint(0, 200, (100, 3, 299, 299), dtype=torch.uint8)
>>> imgs_dist2 = lambda: torch.randint(100, 255, (100, 3, 299, 299), dtype=torch.uint8)
>>> metric = MemorizationInformedFrechetInceptionDistance(feature=64)
>>> values = [ ]
>>> for _ in range(3):
...     metric.update(imgs_dist1(), real=True)
...     metric.update(imgs_dist2(), real=False)
...     values.append(metric.compute())
...     metric.reset()
>>> fig_, ax_ = metric.plot(values)
reset()[source]

Reset metric states.

Return type:

None

Multi-Scale SSIM

Module Interface

class torchmetrics.image.MultiScaleStructuralSimilarityIndexMeasure(gaussian_kernel=True, kernel_size=11, sigma=1.5, reduction='elementwise_mean', data_range=None, k1=0.01, k2=0.03, betas=(0.0448, 0.2856, 0.3001, 0.2363, 0.1333), normalize='relu', **kwargs)[source]

Compute MultiScaleSSIM, Multi-scale Structural Similarity Index Measure.

This metric is a generalization of Structural Similarity Index Measure by incorporating image details at different resolution scales.

As input to forward and update the metric accepts the following input

  • preds (Tensor): Predictions from model

  • target (Tensor): Ground truth values

As output of forward and compute the metric returns the following output

  • msssim (Tensor): if reduction!='none' returns float scalar tensor with average MSSSIM value over samples else returns tensor of shape (N,) with MSSSIM values per sample

Parameters:
  • gaussian_kernel (bool) – If True (default), a gaussian kernel is used, if false a uniform kernel is used

  • kernel_size (Union[int, Sequence[int]]) – size of the gaussian kernel

  • sigma (Union[float, Sequence[float]]) – Standard deviation of the gaussian kernel

  • reduction (Literal['elementwise_mean', 'sum', 'none', None]) –

    a method to reduce metric score over labels.

    • 'elementwise_mean': takes the mean

    • 'sum': takes the sum

    • 'none' or None: no reduction will be applied

  • data_range (Union[float, Tuple[float, float], None]) – the range of the data. If None, it is determined from the data (max - min). If a tuple is provided then the range is calculated as the difference and input is clamped between the values. The data_range must be given when dim is not None.

  • k1 (float) – Parameter of structural similarity index measure.

  • k2 (float) – Parameter of structural similarity index measure.

  • betas (Tuple[float, ...]) – Exponent parameters for individual similarities and contrast sensitivities returned by different image resolutions.

  • normalize (Literal['relu', 'simple', None]) – When MultiScaleStructuralSimilarityIndexMeasure loss is used for training, it is desirable to use normalization to improve the training stability. This normalize argument is out of scope of the original implementation [1], and it is adapted from https://github.com/jorge-pessoa/pytorch-msssim instead.

  • kwargs (Any) – Additional keyword arguments, see Advanced metric settings for more info.

Returns:

Tensor with Multi-Scale SSIM score

Raises:
  • ValueError – If kernel_size is not an int or a Sequence of ints with size 2 or 3.

  • ValueError – If betas is not a tuple of floats with length 2.

  • ValueError – If normalize is neither None, 'relu' nor 'simple'.

Example

>>> from torchmetrics.image import MultiScaleStructuralSimilarityIndexMeasure
>>> import torch
>>> gen = torch.manual_seed(42)
>>> preds = torch.rand([3, 3, 256, 256], generator=torch.manual_seed(42))
>>> target = preds * 0.75
>>> ms_ssim = MultiScaleStructuralSimilarityIndexMeasure(data_range=1.0)
>>> ms_ssim(preds, target)
tensor(0.9627)
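
If per-image scores are needed instead of the batch average, the reduction argument described above can be set to 'none'; a minimal sketch:

>>> import torch
>>> from torchmetrics.image import MultiScaleStructuralSimilarityIndexMeasure
>>> preds = torch.rand([3, 3, 256, 256], generator=torch.manual_seed(42))
>>> target = preds * 0.75
>>> ms_ssim = MultiScaleStructuralSimilarityIndexMeasure(data_range=1.0, reduction='none')
>>> per_image = ms_ssim(preds, target)  # tensor of shape (N,), one score per image
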
plot(val=None, ax=None)[source]

Plot a single or multiple values from the metric.

Parameters:
  • val (Union[Tensor, Sequence[Tensor], None]) – Either a single result from calling metric.forward or metric.compute or a list of these results. If no value is provided, will automatically call metric.compute and plot that result.

  • ax (Optional[Axes]) – A matplotlib axis object. If provided, the plot will be added to that axis

Return type:

Tuple[Figure, Union[Axes, ndarray]]

Returns:

Figure and Axes object

Raises:

ModuleNotFoundError – If matplotlib is not installed

>>> # Example plotting a single value
>>> from torchmetrics.image import MultiScaleStructuralSimilarityIndexMeasure
>>> import torch
>>> preds = torch.rand([3, 3, 256, 256], generator=torch.manual_seed(42))
>>> target = preds * 0.75
>>> metric = MultiScaleStructuralSimilarityIndexMeasure(data_range=1.0)
>>> metric.update(preds, target)
>>> fig_, ax_ = metric.plot()
>>> # Example plotting multiple values
>>> from torchmetrics.image import MultiScaleStructuralSimilarityIndexMeasure
>>> import torch
>>> preds = torch.rand([3, 3, 256, 256], generator=torch.manual_seed(42))
>>> target = preds * 0.75
>>> metric = MultiScaleStructuralSimilarityIndexMeasure(data_range=1.0)
>>> values = [ ]
>>> for _ in range(10):
...     values.append(metric(preds, target))
>>> fig_, ax_ = metric.plot(values)

Functional Interface

torchmetrics.functional.image.multiscale_structural_similarity_index_measure(preds, target, gaussian_kernel=True, sigma=1.5, kernel_size=11, reduction='elementwise_mean', data_range=None, k1=0.01, k2=0.03, betas=(0.0448, 0.2856, 0.3001, 0.2363, 0.1333), normalize='relu')[source]

Compute MultiScaleSSIM, Multi-scale Structural Similarity Index Measure.

This metric is a generalization of Structural Similarity Index Measure by incorporating image details at different resolution scales.

Parameters:
  • preds (Tensor) – Predictions from model of shape [N, C, H, W]

  • target (Tensor) – Ground truth values of shape [N, C, H, W]

  • gaussian_kernel (bool) – If true, a gaussian kernel is used, if false a uniform kernel is used

  • sigma (Union[float, Sequence[float]]) – Standard deviation of the gaussian kernel

  • kernel_size (Union[int, Sequence[int]]) – size of the gaussian kernel

  • reduction (Literal['elementwise_mean', 'sum', 'none', None]) –

    a method to reduce metric score over labels.

    • 'elementwise_mean': takes the mean

    • 'sum': takes the sum

    • 'none' or None: no reduction will be applied

  • data_range (Union[float, Tuple[float, float], None]) – the range of the data. If None, it is determined from the data (max - min). If a tuple is provided then the range is calculated as the difference and input is clamped between the values.

  • k1 (float) – Parameter of structural similarity index measure.

  • k2 (float) – Parameter of structural similarity index measure.

  • betas (Tuple[float, ...]) – Exponent parameters for individual similarities and contrast sensitivities returned by different image resolutions.

  • normalize (Optional[Literal['relu', 'simple']]) – When MultiScaleSSIM loss is used for training, it is desirable to use normalization to improve the training stability. This normalize argument is out of scope of the original implementation [1], and it is adapted from https://github.com/jorge-pessoa/pytorch-msssim instead.

Return type:

Tensor

Returns:

Tensor with Multi-Scale SSIM score

Raises:
  • TypeError – If preds and target don’t have the same data type.

  • ValueError – If preds and target don’t have BxCxHxW shape.

  • ValueError – If the length of kernel_size or sigma is not 2.

  • ValueError – If one of the elements of kernel_size is not an odd positive number.

  • ValueError – If one of the elements of sigma is not a positive number.

Example

>>> import torch
>>> from torchmetrics.functional.image import multiscale_structural_similarity_index_measure
>>> gen = torch.manual_seed(42)
>>> preds = torch.rand([3, 3, 256, 256], generator=gen)
>>> target = preds * 0.75
>>> multiscale_structural_similarity_index_measure(preds, target, data_range=1.0)
tensor(0.9627)

References

[1] Multi-Scale Structural Similarity For Image Quality Assessment by Zhou Wang, Eero P. Simoncelli and Alan C. Bovik MultiScaleSSIM

Peak Signal-to-Noise Ratio (PSNR)

Module Interface

class torchmetrics.image.PeakSignalNoiseRatio(data_range=None, base=10.0, reduction='elementwise_mean', dim=None, **kwargs)[source]

Compute Peak Signal-to-Noise Ratio (PSNR).

\[\text{PSNR}(I, J) = 10 * \log_{10} \left(\frac{\max(I)^2}{\text{MSE}(I, J)}\right)\]

Where \(\text{MSE}\) denotes the mean-squared-error function.
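
To connect the formula to the example below, PSNR can also be worked out by hand for a small tensor (a sketch of the definition with the data range supplied explicitly; the metric itself can infer it from the data):

>>> import torch
>>> preds = torch.tensor([[0.0, 1.0], [2.0, 3.0]])
>>> target = torch.tensor([[3.0, 2.0], [1.0, 0.0]])
>>> mse = torch.mean((preds - target) ** 2)         # 5.0
>>> data_range = 3.0                                # supplied explicitly here
>>> psnr = 10 * torch.log10(data_range ** 2 / mse)  # ~2.55 dB, matching the example below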

As input to forward and update the metric accepts the following input

  • preds (Tensor): Predictions from model of shape (N,C,H,W)

  • target (Tensor): Ground truth values of shape (N,C,H,W)

As output of forward and compute the metric returns the following output

  • psnr (Tensor): if reduction!='none' returns float scalar tensor with average PSNR value over samples else returns tensor of shape (N,) with PSNR values per sample

Parameters:
  • data_range (Union[float, Tuple[float, float], None]) – the range of the data. If None, it is determined from the data (max - min). If a tuple is provided then the range is calculated as the difference and input is clamped between the values. The data_range must be given when dim is not None.

  • base (float) – a base of a logarithm to use.

  • reduction (Literal['elementwise_mean', 'sum', 'none', None]) –

    a method to reduce metric score over labels.

    • 'elementwise_mean': takes the mean (default)

    • 'sum': takes the sum

    • 'none' or None: no reduction will be applied

  • dim (Union[int, Tuple[int, ...], None]) – Dimensions to reduce PSNR scores over, provided as either an integer or a list of integers. Default is None meaning scores will be reduced across all dimensions and all batches.

  • kwargs (Any) – Additional keyword arguments, see Advanced metric settings for more info.

Raises:

ValueError – If dim is not None and data_range is not given.

Example

>>> import torch
>>> from torchmetrics.image import PeakSignalNoiseRatio
>>> psnr = PeakSignalNoiseRatio()
>>> preds = torch.tensor([[0.0, 1.0], [2.0, 3.0]])
>>> target = torch.tensor([[3.0, 2.0], [1.0, 0.0]])
>>> psnr(preds, target)
tensor(2.5527)
plot(val=None, ax=None)[source]

Plot a single or multiple values from the metric.

Parameters:
  • val (Union[Tensor, Sequence[Tensor], None]) – Either a single result from calling metric.forward or metric.compute or a list of these results. If no value is provided, will automatically call metric.compute and plot that result.

  • ax (Optional[Axes]) – A matplotlib axis object. If provided, the plot will be added to that axis

Return type:

Tuple[Figure, Union[Axes, ndarray]]

Returns:

Figure and Axes object

Raises:

ModuleNotFoundError – If matplotlib is not installed

>>> # Example plotting a single value
>>> import torch
>>> from torchmetrics.image import PeakSignalNoiseRatio
>>> metric = PeakSignalNoiseRatio()
>>> preds = torch.tensor([[0.0, 1.0], [2.0, 3.0]])
>>> target = torch.tensor([[3.0, 2.0], [1.0, 0.0]])
>>> metric.update(preds, target)
>>> fig_, ax_ = metric.plot()
>>> # Example plotting multiple values
>>> import torch
>>> from torchmetrics.image import PeakSignalNoiseRatio
>>> metric = PeakSignalNoiseRatio()
>>> preds = torch.tensor([[0.0, 1.0], [2.0, 3.0]])
>>> target = torch.tensor([[3.0, 2.0], [1.0, 0.0]])
>>> values = [ ]
>>> for _ in range(10):
...     values.append(metric(preds, target))
>>> fig_, ax_ = metric.plot(values)

Functional Interface

torchmetrics.functional.image.peak_signal_noise_ratio(preds, target, data_range=None, base=10.0, reduction='elementwise_mean', dim=None)[source]

Compute the peak signal-to-noise ratio.

Parameters:
  • preds (Tensor) – estimated signal

  • target (Tensor) – ground truth signal

  • data_range (Union[float, Tuple[float, float], None]) – the range of the data. If None, it is determined from the data (max - min). If a tuple is provided then the range is calculated as the difference and input is clamped between the values. The data_range must be given when dim is not None.

  • base (float) – a base of a logarithm to use

  • reduction (Literal['elementwise_mean', 'sum', 'none', None]) –

    a method to reduce metric score over labels.

    • 'elementwise_mean': takes the mean (default)

    • 'sum': takes the sum

    • 'none' or None: no reduction will be applied

  • dim (Union[int, Tuple[int, ...], None]) – Dimensions to reduce PSNR scores over provided as either an integer or a list of integers. Default is None meaning scores will be reduced across all dimensions.

Return type:

Tensor

Returns:

Tensor with PSNR score

Raises:

ValueError – If dim is not None and data_range is not provided.

Example

>>> import torch
>>> from torchmetrics.functional.image import peak_signal_noise_ratio
>>> pred = torch.tensor([[0.0, 1.0], [2.0, 3.0]])
>>> target = torch.tensor([[3.0, 2.0], [1.0, 0.0]])
>>> peak_signal_noise_ratio(pred, target)
tensor(2.5527)

Note

Half precision is only supported on GPU for this metric

Peak Signal To Noise Ratio With Blocked Effect

Module Interface

class torchmetrics.image.PeakSignalNoiseRatioWithBlockedEffect(block_size=8, **kwargs)[source]

Computes Peak Signal to Noise Ratio With Blocked Effect (PSNRB).

\[\text{PSNRB}(I, J) = 10 * \log_{10} \left(\frac{\max(I)^2}{\text{MSE}(I, J)-\text{B}(I, J)}\right)\]

Where \(\text{MSE}\) denotes the mean-squared-error function. This metric is a modified version of PSNR that better supports evaluation of images with blocking artifacts, which often occur in compressed images.

Note

Metric only supports grayscale images. If you have RGB images, please convert them to grayscale first.
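
Since the metric expects single-channel input, RGB images can be converted to grayscale first, for example with a standard luminance weighting (a hedged sketch; any reasonable conversion works):

>>> import torch
>>> from torchmetrics.image import PeakSignalNoiseRatioWithBlockedEffect
>>> rgb_preds = torch.rand(2, 3, 32, 32)
>>> rgb_target = torch.rand(2, 3, 32, 32)
>>> weights = torch.tensor([0.2989, 0.5870, 0.1140]).view(1, 3, 1, 1)  # BT.601 luminance weights
>>> gray_preds = (rgb_preds * weights).sum(dim=1, keepdim=True)    # shape (N, 1, H, W)
>>> gray_target = (rgb_target * weights).sum(dim=1, keepdim=True)
>>> metric = PeakSignalNoiseRatioWithBlockedEffect()
>>> score = metric(gray_preds, gray_target)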

As input to forward and update the metric accepts the following input

  • preds (Tensor): Predictions from model of shape (N,1,H,W)

  • target (Tensor): Ground truth values of shape (N,1,H,W)

As output of forward and compute the metric returns the following output

  • psnrb (Tensor): float scalar tensor with aggregated PSNRB value

Parameters:
  • block_size (int) – integer indicating the block size

  • kwargs (Any) – Additional keyword arguments, see Advanced metric settings for more info.

Example

>>> import torch
>>> from torchmetrics.image import PeakSignalNoiseRatioWithBlockedEffect
>>> metric = PeakSignalNoiseRatioWithBlockedEffect()
>>> _ = torch.manual_seed(42)
>>> preds = torch.rand(2, 1, 10, 10)
>>> target = torch.rand(2, 1, 10, 10)
>>> metric(preds, target)
tensor(7.2893)
plot(val=None, ax=None)[source]

Plot a single or multiple values from the metric.

Parameters:
  • val (Union[Tensor, Sequence[Tensor], None]) – Either a single result from calling metric.forward or metric.compute or a list of these results. If no value is provided, will automatically call metric.compute and plot that result.

  • ax (Optional[Axes]) – A matplotlib axis object. If provided, the plot will be added to that axis

Return type:

Tuple[Figure, Union[Axes, ndarray]]

Returns:

Figure and Axes object

Raises:

ModuleNotFoundError – If matplotlib is not installed

>>> # Example plotting a single value
>>> import torch
>>> from torchmetrics.image import PeakSignalNoiseRatioWithBlockedEffect
>>> metric = PeakSignalNoiseRatioWithBlockedEffect()
>>> metric.update(torch.rand(2, 1, 10, 10), torch.rand(2, 1, 10, 10))
>>> fig_, ax_ = metric.plot()
>>> # Example plotting multiple values
>>> import torch
>>> from torchmetrics.image import PeakSignalNoiseRatioWithBlockedEffect
>>> metric = PeakSignalNoiseRatioWithBlockedEffect()
>>> values = [ ]
>>> for _ in range(10):
...     values.append(metric(torch.rand(2, 1, 10, 10), torch.rand(2, 1, 10, 10)))
>>> fig_, ax_ = metric.plot(values)

Functional Interface

torchmetrics.functional.image.peak_signal_noise_ratio_with_blocked_effect(preds, target, block_size=8)[source]

Computes Peak Signal to Noise Ratio With Blocked Effect (PSNRB).

\[\text{PSNRB}(I, J) = 10 * \log_{10} \left(\frac{\max(I)^2}{\text{MSE}(I, J)-\text{B}(I, J)}\right)\]

Where \(\text{MSE}\) denotes the mean-squared-error function.

Parameters:
  • preds (Tensor) – estimated signal

  • target (Tensor) – ground truth signal

  • block_size (int) – integer indicating the block size

Return type:

Tensor

Returns:

Tensor with PSNRB score

Example

>>> import torch
>>> from torchmetrics.functional.image import peak_signal_noise_ratio_with_blocked_effect
>>> _ = torch.manual_seed(42)
>>> preds = torch.rand(1, 1, 28, 28)
>>> target = torch.rand(1, 1, 28, 28)
>>> peak_signal_noise_ratio_with_blocked_effect(preds, target)
tensor(7.8402)

Perceptual Path Length (PPL)

Module Interface

class torchmetrics.image.perceptual_path_length.PerceptualPathLength(num_samples=10000, conditional=False, batch_size=128, interpolation_method='lerp', epsilon=0.0001, resize=64, lower_discard=0.01, upper_discard=0.99, sim_net='vgg', **kwargs)[source]

Computes the perceptual path length (PPL) of a generator model.

The perceptual path length can be used to measure the consistency of interpolation in latent-space models. It is defined as

\[PPL = \mathbb{E}\left[\frac{1}{\epsilon^2} D(G(I(z_1, z_2, t)), G(I(z_1, z_2, t+\epsilon)))\right]\]

where \(G\) is the generator, \(I\) is the interpolation function, \(D\) is a similarity metric, \(z_1\) and \(z_2\) are two sets of latent points, and \(t\) is a parameter between 0 and 1. The metric thus works by interpolating between two sets of latent points, and measuring the similarity between the generated images. The expectation is approximated by sampling \(z_1\) and \(z_2\) from the generator, and averaging the calculated distances. The similarity metric \(D\) is by default the LPIPS metric, but can be changed by setting the sim_net argument.

The provided generator model must have a sample method with signature sample(num_samples: int) -> Tensor where the returned tensor has shape (num_samples, z_size). If the generator is conditional, it must also have a num_classes attribute. The forward method of the generator must have signature forward(z: Tensor) -> Tensor if conditional=False, and forward(z: Tensor, labels: Tensor) -> Tensor if conditional=True. The returned tensor should have shape (num_samples, C, H, W) and be scaled to the range [0, 255].
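
For the conditional case, a generator matching the interface described above could look like the following minimal sketch (the class name, latent size and image size are illustrative only):

>>> import torch
>>> class DummyConditionalGenerator(torch.nn.Module):
...     def __init__(self, z_size, num_classes):
...         super().__init__()
...         self.z_size = z_size
...         self.num_classes = num_classes  # required attribute when conditional=True
...         self.model = torch.nn.Sequential(torch.nn.Linear(z_size + num_classes, 3 * 64 * 64), torch.nn.Sigmoid())
...     def forward(self, z, labels):
...         # concatenate one-hot labels to the latent code and scale the output to [0, 255]
...         one_hot = torch.nn.functional.one_hot(labels, self.num_classes).float()
...         return 255 * self.model(torch.cat([z, one_hot], dim=1)).reshape(-1, 3, 64, 64)
...     def sample(self, num_samples):
...         return torch.randn(num_samples, self.z_size)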

Note

using this metric with the default feature extractor requires that torchvision is installed. Either install as pip install torchmetrics[image] or pip install torchvision

As input to forward and update the metric accepts the following input

  • generator (Module): Generator model, with specific requirements. See above.

As output of forward and compute the metric returns the following output

  • ppl_mean (Tensor): float scalar tensor with mean PPL value over distances

  • ppl_std (Tensor): float scalar tensor with std PPL value over distances

  • ppl_raw (Tensor): float tensor with the raw PPL distances

Parameters:
  • num_samples (int) – Number of samples to use for the PPL computation.

  • conditional (bool) – Whether the generator is conditional or not (i.e. whether it takes labels as input).

  • batch_size (int) – Batch size to use for the PPL computation.

  • interpolation_method (Literal['lerp', 'slerp_any', 'slerp_unit']) – Interpolation method to use. Choose from ‘lerp’, ‘slerp_any’, ‘slerp_unit’.

  • epsilon (float) – Spacing between the points on the path between latent points.

  • resize (Optional[int]) – Resize images to this size before computing the similarity between generated images.

  • lower_discard (Optional[float]) – Lower quantile to discard from the distances, before computing the mean and standard deviation.

  • upper_discard (Optional[float]) – Upper quantile to discard from the distances, before computing the mean and standard deviation.

  • sim_net (Union[Module, Literal['alex', 'vgg', 'squeeze']]) – Similarity network to use. Can be a nn.Module or one of ‘alex’, ‘vgg’, ‘squeeze’, where the three latter options correspond to the pretrained networks from the LPIPS paper.

  • kwargs (Any) – Additional keyword arguments, see Advanced metric settings for more info.

Raises:
  • ModuleNotFoundError – If torch-fidelity is not installed.

  • ValueError – If num_samples is not a positive integer.

  • ValueError – If conditional is not a boolean.

  • ValueError – If batch_size is not a positive integer.

  • ValueError – If interpolation_method is not one of ‘lerp’, ‘slerp_any’, ‘slerp_unit’.

  • ValueError – If epsilon is not a positive float.

  • ValueError – If resize is not a positive integer.

  • ValueError – If lower_discard is not a float between 0 and 1 or None.

  • ValueError – If upper_discard is not a float between 0 and 1 or None.

Example::
>>> from torchmetrics.image import PerceptualPathLength
>>> import torch
>>> _ = torch.manual_seed(42)
>>> class DummyGenerator(torch.nn.Module):
...    def __init__(self, z_size) -> None:
...       super().__init__()
...       self.z_size = z_size
...       self.model = torch.nn.Sequential(torch.nn.Linear(z_size, 3*128*128), torch.nn.Sigmoid())
...    def forward(self, z):
...       return 255 * (self.model(z).reshape(-1, 3, 128, 128) + 1)
...    def sample(self, num_samples):
...      return torch.randn(num_samples, self.z_size)
>>> generator = DummyGenerator(2)
>>> ppl = PerceptualPathLength(num_samples=10)
>>> ppl(generator)  
(tensor(0.2371),
tensor(0.1763),
tensor([0.3502, 0.1362, 0.2535, 0.0902, 0.1784, 0.0769, 0.5871, 0.0691, 0.3921]))
class torchmetrics.image.perceptual_path_length.GeneratorType(*args, **kwargs)[source]

Basic interface for a generator model.

Users can inherit from this class and implement their own generator model. The requirements are that the sample method is implemented and that the num_classes attribute is present when the metric is used with conditional=True.

sample(num_samples)[source]

Sample from the generator.

Parameters:

num_samples (int) – Number of samples to generate.

Return type:

Tensor

property num_classes: int

Return the number of classes for conditional generation.

Functional Interface

torchmetrics.functional.image.perceptual_path_length.perceptual_path_length(generator, num_samples=10000, conditional=False, batch_size=64, interpolation_method='lerp', epsilon=0.0001, resize=64, lower_discard=0.01, upper_discard=0.99, sim_net='vgg', device='cpu')[source]

Computes the perceptual path length (PPL) of a generator model.

The perceptual path length can be used to measure the consistency of interpolation in latent-space models. It is defined as

\[PPL = \mathbb{E}\left[\frac{1}{\epsilon^2} D(G(I(z_1, z_2, t)), G(I(z_1, z_2, t+\epsilon)))\right]\]

where \(G\) is the generator, \(I\) is the interpolation function, \(D\) is a similarity metric, \(z_1\) and \(z_2\) are two sets of latent points, and \(t\) is a parameter between 0 and 1. The metric thus works by interpolating between two sets of latent points, and measuring the similarity between the generated images. The expectation is approximated by sampling \(z_1\) and \(z_2\) from the generator, and averaging the calculated distances. The similarity metric \(D\) is by default the LPIPS metric, but can be changed by setting the sim_net argument.

The provided generator model must have a sample method with signature sample(num_samples: int) -> Tensor where the returned tensor has shape (num_samples, z_size). If the generator is conditional, it must also have a num_classes attribute. The forward method of the generator must have signature forward(z: Tensor) -> Tensor if conditional=False, and forward(z: Tensor, labels: Tensor) -> Tensor if conditional=True. The returned tensor should have shape (num_samples, C, H, W) and be scaled to the range [0, 255].

Parameters:
  • generator (GeneratorType) – Generator model, with specific requirements. See above.

  • num_samples (int) – Number of samples to use for the PPL computation.

  • conditional (bool) – Whether the generator is conditional or not (i.e. whether it takes labels as input).

  • batch_size (int) – Batch size to use for the PPL computation.

  • interpolation_method (Literal['lerp', 'slerp_any', 'slerp_unit']) – Interpolation method to use. Choose from ‘lerp’, ‘slerp_any’, ‘slerp_unit’.

  • epsilon (float) – Spacing between the points on the path between latent points.

  • resize (Optional[int]) – Resize images to this size before computing the similarity between generated images.

  • lower_discard (Optional[float]) – Lower quantile to discard from the distances, before computing the mean and standard deviation.

  • upper_discard (Optional[float]) – Upper quantile to discard from the distances, before computing the mean and standard deviation.

  • sim_net (Union[Module, Literal['alex', 'vgg', 'squeeze']]) – Similarity network to use. Can be a nn.Module or one of ‘alex’, ‘vgg’, ‘squeeze’, where the three latter options correspond to the pretrained networks from the LPIPS paper.

  • device (Union[str, device]) – Device to use for the computation.

Return type:

Tuple[Tensor, Tensor, Tensor]

Returns:

A tuple containing the mean, standard deviation and all distances.

Example::
>>> from torchmetrics.functional.image import perceptual_path_length
>>> import torch
>>> _ = torch.manual_seed(42)
>>> class DummyGenerator(torch.nn.Module):
...    def __init__(self, z_size) -> None:
...       super().__init__()
...       self.z_size = z_size
...       self.model = torch.nn.Sequential(torch.nn.Linear(z_size, 3*128*128), torch.nn.Sigmoid())
...    def forward(self, z):
...       return 255 * (self.model(z).reshape(-1, 3, 128, 128) + 1)
...    def sample(self, num_samples):
...      return torch.randn(num_samples, self.z_size)
>>> generator = DummyGenerator(2)
>>> perceptual_path_length(generator, num_samples=10)  
(tensor(0.1945),
tensor(0.1222),
tensor([0.0990, 0.4173, 0.1628, 0.3573, 0.1875, 0.0335, 0.1095, 0.1887, 0.1953]))

Relative Average Spectral Error (RASE)

Module Interface

class torchmetrics.image.RelativeAverageSpectralError(window_size=8, **kwargs)[source]

Computes Relative Average Spectral Error (RASE) (RelativeAverageSpectralError).

As input to forward and update the metric accepts the following input

  • preds (Tensor): Predictions from model of shape (N,C,H,W)

  • target (Tensor): Ground truth values of shape (N,C,H,W)

As output of forward and compute the metric returns the following output

  • rase (Tensor): returns float scalar tensor with average RASE value over samples

Parameters:
  • window_size (int) – Sliding window used for rmse calculation

  • kwargs (Any) – Additional keyword arguments, see Advanced metric settings for more info.

Returns:

Relative Average Spectral Error (RASE)

Example

>>> import torch
>>> from torchmetrics.image import RelativeAverageSpectralError
>>> g = torch.manual_seed(22)
>>> preds = torch.rand(4, 3, 16, 16)
>>> target = torch.rand(4, 3, 16, 16)
>>> rase = RelativeAverageSpectralError()
>>> rase(preds, target)
tensor(5114.6641)
Raises:

ValueError – If window_size is not a positive integer.

plot(val=None, ax=None)[source]

Plot a single or multiple values from the metric.

Parameters:
  • val (Union[Tensor, Sequence[Tensor], None]) – Either a single result from calling metric.forward or metric.compute or a list of these results. If no value is provided, will automatically call metric.compute and plot that result.

  • ax (Optional[Axes]) – A matplotlib axis object. If provided, the plot will be added to that axis

Return type:

Tuple[Figure, Union[Axes, ndarray]]

Returns:

Figure and Axes object

Raises:

ModuleNotFoundError – If matplotlib is not installed

>>> # Example plotting a single value
>>> import torch
>>> from torchmetrics.image import RelativeAverageSpectralError
>>> metric = RelativeAverageSpectralError()
>>> metric.update(torch.rand(4, 3, 16, 16), torch.rand(4, 3, 16, 16))
>>> fig_, ax_ = metric.plot()
>>> # Example plotting multiple values
>>> import torch
>>> _ = torch.manual_seed(42)
>>> from torchmetrics.image import RelativeAverageSpectralError
>>> metric = RelativeAverageSpectralError()
>>> values = [ ]
>>> for _ in range(10):
...     values.append(metric(torch.rand(4, 3, 16, 16), torch.rand(4, 3, 16, 16)))
>>> fig_, ax_ = metric.plot(values)

Functional Interface

torchmetrics.functional.image.relative_average_spectral_error(preds, target, window_size=8)[source]

Compute Relative Average Spectral Error (RASE) (RelativeAverageSpectralError).

Parameters:
  • preds (Tensor) – Deformed image

  • target (Tensor) – Ground truth image

  • window_size (int) – Sliding window used for rmse calculation

Return type:

Tensor

Returns:

Relative Average Spectral Error (RASE)

Example

>>> import torch
>>> from torchmetrics.functional.image import relative_average_spectral_error
>>> g = torch.manual_seed(22)
>>> preds = torch.rand(4, 3, 16, 16)
>>> target = torch.rand(4, 3, 16, 16)
>>> relative_average_spectral_error(preds, target)
tensor(5114.6641)
Raises:

ValueError – If window_size is not a positive integer.

Root Mean Squared Error Using Sliding Window

Module Interface

class torchmetrics.image.RootMeanSquaredErrorUsingSlidingWindow(window_size=8, **kwargs)[source]

Computes Root Mean Squared Error (RMSE) using sliding window.

As input to forward and update the metric accepts the following input

  • preds (Tensor): Predictions from model of shape (N,C,H,W)

  • target (Tensor): Ground truth values of shape (N,C,H,W)

As output of forward and compute the metric returns the following output

  • rmse_sw (Tensor): returns float scalar tensor with average RMSE-SW value over samples

Parameters:
  • window_size (int) – Sliding window used for rmse calculation

  • kwargs (Any) – Additional keyword arguments, see Advanced metric settings for more info.

Example

>>> import torch
>>> from torchmetrics.image import RootMeanSquaredErrorUsingSlidingWindow
>>> g = torch.manual_seed(22)
>>> preds = torch.rand(4, 3, 16, 16)
>>> target = torch.rand(4, 3, 16, 16)
>>> rmse_sw = RootMeanSquaredErrorUsingSlidingWindow()
>>> rmse_sw(preds, target)
tensor(0.3999)
Raises:

ValueError – If window_size is not a positive integer.

plot(val=None, ax=None)[source]

Plot a single or multiple values from the metric.

Parameters:
  • val (Union[Tensor, Sequence[Tensor], None]) – Either a single result from calling metric.forward or metric.compute or a list of these results. If no value is provided, will automatically call metric.compute and plot that result.

  • ax (Optional[Axes]) – A matplotlib axis object. If provided, the plot will be added to that axis

Return type:

Tuple[Figure, Union[Axes, ndarray]]

Returns:

Figure and Axes object

Raises:

ModuleNotFoundError – If matplotlib is not installed

>>> # Example plotting a single value
>>> import torch
>>> from torchmetrics.image import RootMeanSquaredErrorUsingSlidingWindow
>>> metric = RootMeanSquaredErrorUsingSlidingWindow()
>>> metric.update(torch.rand(4, 3, 16, 16), torch.rand(4, 3, 16, 16))
>>> fig_, ax_ = metric.plot()
>>> # Example plotting multiple values
>>> import torch
>>> from torchmetrics.image import RootMeanSquaredErrorUsingSlidingWindow
>>> metric = RootMeanSquaredErrorUsingSlidingWindow()
>>> values = [ ]
>>> for _ in range(10):
...     values.append(metric(torch.rand(4, 3, 16, 16), torch.rand(4, 3, 16, 16)))
>>> fig_, ax_ = metric.plot(values)

Functional Interface

torchmetrics.functional.image.root_mean_squared_error_using_sliding_window(preds, target, window_size=8, return_rmse_map=False)[source]

Compute Root Mean Squared Error (RMSE) using sliding window.

Parameters:
  • preds (Tensor) – Deformed image

  • target (Tensor) – Ground truth image

  • window_size (int) – Sliding window used for rmse calculation

  • return_rmse_map (bool) – An indication whether the full rmse reduced image should be returned.

Return type:

Union[Tensor, None, Tuple[Optional[Tensor], Tensor]]

Returns:

RMSE using sliding window (Optionally) RMSE map

Example

>>> import torch
>>> from torchmetrics.functional.image import root_mean_squared_error_using_sliding_window
>>> g = torch.manual_seed(22)
>>> preds = torch.rand(4, 3, 16, 16)
>>> target = torch.rand(4, 3, 16, 16)
>>> root_mean_squared_error_using_sliding_window(preds, target)
tensor(0.3999)
Raises:

ValueError – If window_size is not a positive integer.

Spectral Angle Mapper

Module Interface

class torchmetrics.image.SpectralAngleMapper(reduction='elementwise_mean', **kwargs)[source]

Spectral Angle Mapper determines the spectral similarity between image spectra and reference spectra.

It works by calculating the angle between the spectra, where small angles indicate high similarity and large angles indicate low similarity.
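
As an illustration of this definition (not the metric's actual implementation), the per-pixel spectral angle can be computed as the arccosine of the normalized dot product over the channel dimension:

>>> import torch
>>> def spectral_angle(preds, target, eps=1e-8):
...     # angle in radians between the channel (spectral) vectors at every pixel
...     dot = (preds * target).sum(dim=1)
...     denom = preds.norm(dim=1) * target.norm(dim=1)
...     return torch.acos((dot / (denom + eps)).clamp(-1, 1))
>>> preds = torch.rand(16, 3, 16, 16)
>>> target = torch.rand(16, 3, 16, 16)
>>> angles = spectral_angle(preds, target)  # shape (N, H, W)
>>> mean_angle = angles.mean()              # roughly what reduction='elementwise_mean' reports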

As input to forward and update the metric accepts the following input

  • preds (Tensor): Predictions from model of shape (N,C,H,W)

  • target (Tensor): Ground truth values of shape (N,C,H,W)

As output of forward and compute the metric returns the following output

  • sam (Tensor): if reduction!='none' returns float scalar tensor with average SAM value over samples else returns tensor of shape (N,) with SAM values per sample

Parameters:
  • reduction (Literal['elementwise_mean', 'sum', 'none']) –

    a method to reduce metric score over labels.

    • 'elementwise_mean': takes the mean (default)

    • 'sum': takes the sum

    • 'none' or None: no reduction will be applied

  • kwargs (Any) – Additional keyword arguments, see Advanced metric settings for more info.

Returns:

Tensor with SpectralAngleMapper score

Example

>>> import torch
>>> from torchmetrics.image import SpectralAngleMapper
>>> gen = torch.manual_seed(42)
>>> preds = torch.rand([16, 3, 16, 16], generator=gen)
>>> target = torch.rand([16, 3, 16, 16], generator=gen)
>>> sam = SpectralAngleMapper()
>>> sam(preds, target)
tensor(0.5914)
plot(val=None, ax=None)[source]

Plot a single or multiple values from the metric.

Parameters:
  • val (Union[Tensor, Sequence[Tensor], None]) – Either a single result from calling metric.forward or metric.compute or a list of these results. If no value is provided, will automatically call metric.compute and plot that result.

  • ax (Optional[Axes]) – A matplotlib axis object. If provided, the plot will be added to that axis

Return type:

Tuple[Figure, Union[Axes, ndarray]]

Returns:

Figure and Axes object

Raises:

ModuleNotFoundError – If matplotlib is not installed

>>> # Example plotting single value
>>> import torch
>>> from torchmetrics.image import SpectralAngleMapper
>>> gen = torch.manual_seed(42)
>>> preds = torch.rand([16, 3, 16, 16], generator=gen)
>>> target = torch.rand([16, 3, 16, 16], generator=gen)
>>> metric = SpectralAngleMapper()
>>> metric.update(preds, target)
>>> fig_, ax_ = metric.plot()
>>> # Example plotting multiple values
>>> import torch
>>> from torchmetrics.image import SpectralAngleMapper
>>> gen = torch.manual_seed(42)
>>> preds = torch.rand([16, 3, 16, 16], generator=gen)
>>> target = torch.rand([16, 3, 16, 16], generator=gen)
>>> metric = SpectralAngleMapper()
>>> values = [ ]
>>> for _ in range(10):
...     values.append(metric(preds, target))
>>> fig_, ax_ = metric.plot(values)

Functional Interface

torchmetrics.functional.image.spectral_angle_mapper(preds, target, reduction='elementwise_mean')[source]

Compute the Spectral Angle Mapper.

Parameters:
  • preds (Tensor) – estimated image

  • target (Tensor) – ground truth image

  • reduction (Literal['elementwise_mean', 'sum', 'none', None]) –

    a method to reduce metric score over labels.

    • 'elementwise_mean': takes the mean (default)

    • 'sum': takes the sum

    • 'none' or None: no reduction will be applied

Return type:

Tensor

Returns:

Tensor with Spectral Angle Mapper score

Raises:
  • TypeError – If preds and target don’t have the same data type.

  • ValueError – If preds and target don’t have BxCxHxW shape.

Example

>>> import torch
>>> from torchmetrics.functional.image import spectral_angle_mapper
>>> gen = torch.manual_seed(42)
>>> preds = torch.rand([16, 3, 16, 16], generator=gen)
>>> target = torch.rand([16, 3, 16, 16], generator=gen)
>>> spectral_angle_mapper(preds, target)
tensor(0.5914)

References

[1] Roberta H. Yuhas, Alexander F. H. Goetz and Joe W. Boardman, “Discrimination among semi-arid landscape endmembers using the Spectral Angle Mapper (SAM) algorithm” in PL, Summaries of the Third Annual JPL Airborne Geoscience Workshop, vol. 1, June 1, 1992.

Spectral Distortion Index

Module Interface

class torchmetrics.image.SpectralDistortionIndex(p=1, reduction='elementwise_mean', **kwargs)[source]

Compute Spectral Distortion Index (SpectralDistortionIndex) also known as D_lambda.

The metric is used to compare the spectral distortion between two images.

As input to forward and update the metric accepts the following input

  • preds (Tensor): Low resolution multispectral image of shape (N,C,H,W)

  • target (Tensor): High resolution fused image of shape (N,C,H,W)

As output of forward and compute the metric returns the following output

  • sdi (Tensor): if reduction!='none' returns float scalar tensor with average SDI value over samples else returns tensor of shape (N,) with SDI values per sample

Parameters:
  • p (int) – order of the norm applied to the spectral difference terms; larger values emphasize large spectral differences

  • reduction (Literal['elementwise_mean', 'sum', 'none']) –

    a method to reduce metric score over labels.

    • 'elementwise_mean': takes the mean (default)

    • 'sum': takes the sum

    • 'none': no reduction will be applied

  • kwargs (Any) – Additional keyword arguments, see Advanced metric settings for more info.

Example

>>> import torch
>>> _ = torch.manual_seed(42)
>>> from torchmetrics.image import SpectralDistortionIndex
>>> preds = torch.rand([16, 3, 16, 16])
>>> target = torch.rand([16, 3, 16, 16])
>>> sdi = SpectralDistortionIndex()
>>> sdi(preds, target)
tensor(0.0234)
plot(val=None, ax=None)[source]

Plot a single or multiple values from the metric.

Parameters:
  • val (Union[Tensor, Sequence[Tensor], None]) – Either a single result from calling metric.forward or metric.compute or a list of these results. If no value is provided, will automatically call metric.compute and plot that result.

  • ax (Optional[Axes]) – A matplotlib axis object. If provided, the plot will be added to that axis

Return type:

Tuple[Figure, Union[Axes, ndarray]]

Returns:

Figure and Axes object

Raises:

ModuleNotFoundError – If matplotlib is not installed

>>> # Example plotting a single value
>>> import torch
>>> _ = torch.manual_seed(42)
>>> from torchmetrics.image import SpectralDistortionIndex
>>> preds = torch.rand([16, 3, 16, 16])
>>> target = torch.rand([16, 3, 16, 16])
>>> metric = SpectralDistortionIndex()
>>> metric.update(preds, target)
>>> fig_, ax_ = metric.plot()
>>> # Example plotting multiple values
>>> import torch
>>> _ = torch.manual_seed(42)
>>> from torchmetrics.image import SpectralDistortionIndex
>>> preds = torch.rand([16, 3, 16, 16])
>>> target = torch.rand([16, 3, 16, 16])
>>> metric = SpectralDistortionIndex()
>>> values = [ ]
>>> for _ in range(10):
...     values.append(metric(preds, target))
>>> fig_, ax_ = metric.plot(values)

Functional Interface

torchmetrics.functional.image.spectral_distortion_index(preds, target, p=1, reduction='elementwise_mean')[source]

Calculate Spectral Distortion Index (SpectralDistortionIndex) also known as D_lambda.

Metric is used to compare the spectral distortion between two images.

Parameters:
  • preds (Tensor) – Low resolution multispectral image

  • target (Tensor) – High resolution fused image

  • p (int) – order of the norm applied to the spectral difference terms; larger values emphasize large spectral differences

  • reduction (Literal['elementwise_mean', 'sum', 'none']) –

    a method to reduce metric score over labels.

    • 'elementwise_mean': takes the mean (default)

    • 'sum': takes the sum

    • 'none': no reduction will be applied

Return type:

Tensor

Returns:

Tensor with SpectralDistortionIndex score

Raises:
  • TypeError – If preds and target don’t have the same data type.

  • ValueError – If preds and target don’t have BxCxHxW shape.

  • ValueError – If p is not a positive integer.

Example

>>> import torch
>>> from torchmetrics.functional.image import spectral_distortion_index
>>> _ = torch.manual_seed(42)
>>> preds = torch.rand([16, 3, 16, 16])
>>> target = torch.rand([16, 3, 16, 16])
>>> spectral_distortion_index(preds, target)
tensor(0.0234)

Structural Similarity Index Measure (SSIM)

Module Interface

class torchmetrics.image.StructuralSimilarityIndexMeasure(gaussian_kernel=True, sigma=1.5, kernel_size=11, reduction='elementwise_mean', data_range=None, k1=0.01, k2=0.03, return_full_image=False, return_contrast_sensitivity=False, **kwargs)[source]

Compute Structural Similarity Index Measure (SSIM).

As input to forward and update the metric accepts the following input

  • preds (Tensor): Predictions from model

  • target (Tensor): Ground truth values

As output of forward and compute the metric returns the following output

  • ssim (Tensor): if reduction!='none', returns a float scalar tensor with the average SSIM value over samples; otherwise returns a tensor of shape (N,) with SSIM values per sample

Parameters:
  • preds – estimated image

  • target – ground truth image

  • gaussian_kernel (bool) – If True (default), a gaussian kernel is used, if False a uniform kernel is used

  • sigma (Union[float, Sequence[float]]) – Standard deviation of the gaussian kernel, anisotropic kernels are possible. Ignored if a uniform kernel is used

  • kernel_size (Union[int, Sequence[int]]) – the size of the uniform kernel, anisotropic kernels are possible. Ignored if a Gaussian kernel is used

  • reduction (Literal['elementwise_mean', 'sum', 'none', None]) –

    a method to reduce metric score over individual batch scores

    • 'elementwise_mean': takes the mean

    • 'sum': takes the sum

    • 'none' or None: no reduction will be applied

  • data_range (Union[float, Tuple[float, float], None]) – the range of the data. If None, it is determined from the data (max - min). If a tuple is provided then the range is calculated as the difference and input is clamped between the values.

  • k1 (float) – Parameter of SSIM.

  • k2 (float) – Parameter of SSIM.

  • return_full_image (bool) – If true, the full ssim image is returned as a second argument. Mutually exclusive with return_contrast_sensitivity

  • return_contrast_sensitivity (bool) – If true, the contrast sensitivity term is returned as a second argument. The luminance term can be obtained with luminance=ssim/contrast. Mutually exclusive with return_full_image

  • kwargs (Any) – Additional keyword arguments, see Advanced metric settings for more info.

Example

>>> import torch
>>> from torchmetrics.image import StructuralSimilarityIndexMeasure
>>> preds = torch.rand([3, 3, 256, 256])
>>> target = preds * 0.75
>>> ssim = StructuralSimilarityIndexMeasure(data_range=1.0)
>>> ssim(preds, target)
tensor(0.9219)
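
The return_full_image flag documented above changes the return value to a pair; the snippet below is an unofficial sketch of unpacking it (the exact shape of the returned SSIM image is not specified here).

>>> import torch
>>> from torchmetrics.image import StructuralSimilarityIndexMeasure
>>> preds = torch.rand([3, 3, 256, 256])
>>> target = preds * 0.75
>>> ssim_full = StructuralSimilarityIndexMeasure(data_range=1.0, return_full_image=True)
>>> score, ssim_image = ssim_full(preds, target)  # per the description above, the full SSIM image is the second element
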
plot(val=None, ax=None)[source]

Plot a single or multiple values from the metric.

Parameters:
  • val (Union[Tensor, Sequence[Tensor], None]) – Either a single result from calling metric.forward or metric.compute or a list of these results. If no value is provided, will automatically call metric.compute and plot that result.

  • ax (Optional[Axes]) – A matplotlib axis object. If provided, the plot will be added to that axis

Return type:

Tuple[Figure, Union[Axes, ndarray]]

Returns:

Figure and Axes object

Raises:

ModuleNotFoundError – If matplotlib is not installed

>>> # Example plotting a single value
>>> import torch
>>> from torchmetrics.image import StructuralSimilarityIndexMeasure
>>> preds = torch.rand([3, 3, 256, 256])
>>> target = preds * 0.75
>>> metric = StructuralSimilarityIndexMeasure(data_range=1.0)
>>> metric.update(preds, target)
>>> fig_, ax_ = metric.plot()
_images/structural_similarity-1.png
>>> # Example plotting multiple values
>>> import torch
>>> from torchmetrics.image import StructuralSimilarityIndexMeasure
>>> preds = torch.rand([3, 3, 256, 256])
>>> target = preds * 0.75
>>> metric = StructuralSimilarityIndexMeasure(data_range=1.0)
>>> values = [ ]
>>> for _ in range(10):
...     values.append(metric(preds, target))
>>> fig_, ax_ = metric.plot(values)
_images/structural_similarity-2.png

Functional Interface

torchmetrics.functional.image.structural_similarity_index_measure(preds, target, gaussian_kernel=True, sigma=1.5, kernel_size=11, reduction='elementwise_mean', data_range=None, k1=0.01, k2=0.03, return_full_image=False, return_contrast_sensitivity=False)[source]

Compute Structural Similarity Index Measure (SSIM).

Parameters:
  • preds (Tensor) – estimated image

  • target (Tensor) – ground truth image

  • gaussian_kernel (bool) – If true (default), a gaussian kernel is used, if false a uniform kernel is used

  • sigma (Union[float, Sequence[float]]) – Standard deviation of the gaussian kernel, anisotropic kernels are possible. Ignored if a uniform kernel is used

  • kernel_size (Union[int, Sequence[int]]) – the size of the uniform kernel, anisotropic kernels are possible. Ignored if a Gaussian kernel is used

  • reduction (Literal['elementwise_mean', 'sum', 'none', None]) –

    a method to reduce metric score over labels.

    • 'elementwise_mean': takes the mean

    • 'sum': takes the sum

    • 'none' or None: no reduction will be applied

  • data_range (Union[float, Tuple[float, float], None]) – the range of the data. If None, it is determined from the data (max - min). If a tuple is provided then the range is calculated as the difference and input is clamped between the values.

  • k1 (float) – Parameter of SSIM.

  • k2 (float) – Parameter of SSIM.

  • return_full_image (bool) – If true, the full ssim image is returned as a second argument. Mutually exclusive with return_contrast_sensitivity

  • return_contrast_sensitivity (bool) – If true, the contrast sensitivity term is returned as a second argument. The luminance term can be obtained with luminance=ssim/contrast. Mutually exclusive with return_full_image

Return type:

Union[Tensor, Tuple[Tensor, Tensor]]

Returns:

Tensor with SSIM score

Raises:
  • TypeError – If preds and target don’t have the same data type.

  • ValueError – If preds and target don’t have BxCxHxW shape.

  • ValueError – If the length of kernel_size or sigma is not 2.

  • ValueError – If one of the elements of kernel_size is not an odd positive number.

  • ValueError – If one of the elements of sigma is not a positive number.

Example

>>> from torchmetrics.functional.image import structural_similarity_index_measure
>>> preds = torch.rand([3, 3, 256, 256])
>>> target = preds * 0.75
>>> structural_similarity_index_measure(preds, target)
tensor(0.9219)
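
The data_range parameter also accepts a (min, max) tuple, as noted above; the following unofficial sketch assumes inputs on a 0-255 scale (values chosen purely for illustration).

>>> import torch
>>> from torchmetrics.functional.image import structural_similarity_index_measure
>>> preds_255 = torch.randint(0, 256, (3, 3, 64, 64)).float()
>>> target_255 = (preds_255 * 0.9).clamp(0, 255)
>>> # inputs are clamped to [0.0, 255.0] and the range 255.0 - 0.0 is used internally
>>> score = structural_similarity_index_measure(preds_255, target_255, data_range=(0.0, 255.0))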

Total Variation (TV)

Module Interface

class torchmetrics.image.TotalVariation(reduction='sum', **kwargs)[source]

Compute Total Variation loss (TV).

As input to forward and update the metric accepts the following input

  • img (Tensor): A tensor of shape (N, C, H, W) consisting of images

As output of forward and compute the metric returns the following output

  • tv (Tensor): if reduction!='none', returns a float scalar tensor with the average TV value over samples; otherwise returns a tensor of shape (N,) with TV values per sample

Parameters:
  • reduction (Optional[Literal['mean', 'sum', 'none']]) –

    a method to reduce metric score over samples

    • 'mean': takes the mean over samples

    • 'sum': takes the sum over samples

    • None or 'none': return the score per sample

  • kwargs (Any) – Additional keyword arguments, see Advanced metric settings for more info.

Raises:

ValueError – If reduction is not one of 'sum', 'mean', 'none' or None

Example

>>> import torch
>>> from torchmetrics.image import TotalVariation
>>> _ = torch.manual_seed(42)
>>> tv = TotalVariation()
>>> img = torch.rand(5, 3, 28, 28)
>>> tv(img)
tensor(7546.8018)
plot(val=None, ax=None)[source]

Plot a single or multiple values from the metric.

Parameters:
  • val (Union[Tensor, Sequence[Tensor], None]) – Either a single result from calling metric.forward or metric.compute or a list of these results. If no value is provided, will automatically call metric.compute and plot that result.

  • ax (Optional[Axes]) – A matplotlib axis object. If provided, the plot will be added to that axis

Return type:

Tuple[Figure, Union[Axes, ndarray]]

Returns:

Figure and Axes object

Raises:

ModuleNotFoundError – If matplotlib is not installed

>>> # Example plotting a single value
>>> import torch
>>> from torchmetrics.image import TotalVariation
>>> metric = TotalVariation()
>>> metric.update(torch.rand(5, 3, 28, 28))
>>> fig_, ax_ = metric.plot()
_images/total_variation-1.png
>>> # Example plotting multiple values
>>> import torch
>>> from torchmetrics.image import TotalVariation
>>> metric = TotalVariation()
>>> values = [ ]
>>> for _ in range(10):
...     values.append(metric(torch.rand(5, 3, 28, 28)))
>>> fig_, ax_ = metric.plot(values)
_images/total_variation-2.png

Functional Interface

torchmetrics.functional.image.total_variation(img, reduction='sum')[source]

Compute total variation loss.

Parameters:
  • img (Tensor) – A Tensor of shape (N, C, H, W) consisting of images

  • reduction (Optional[Literal['mean', 'sum', 'none']]) –

    a method to reduce metric score over samples.

    • 'mean': takes the mean over samples

    • 'sum': takes the sum over samples

    • None or 'none': return the score per sample

Return type:

Tensor

Returns:

A loss scalar value containing the total variation

Raises:
  • ValueError – If reduction is not one of 'sum', 'mean', 'none' or None

  • RuntimeError – If img is not 4D tensor

Example

>>> import torch
>>> from torchmetrics.functional.image import total_variation
>>> _ = torch.manual_seed(42)
>>> img = torch.rand(5, 3, 28, 28)
>>> total_variation(img)
tensor(7546.8018)
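
The reduction argument works as described above; the unofficial sketch below returns one total variation value per image instead of the summed score.

>>> import torch
>>> from torchmetrics.functional.image import total_variation
>>> img = torch.rand(5, 3, 28, 28)
>>> per_image = total_variation(img, reduction='none')  # expected shape: (5,), one value per image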

Universal Image Quality Index

Module Interface

class torchmetrics.image.UniversalImageQualityIndex(kernel_size=(11, 11), sigma=(1.5, 1.5), reduction='elementwise_mean', **kwargs)[source]

Compute Universal Image Quality Index (UniversalImageQualityIndex).

As input to forward and update the metric accepts the following input

  • preds (Tensor): Predictions from model of shape (N,C,H,W)

  • target (Tensor): Ground truth values of shape (N,C,H,W)

As output of forward and compute the metric returns the following output

  • uiqi (Tensor): if reduction!='none', returns a float scalar tensor with the average UIQI value over samples; otherwise returns a tensor of shape (N,) with UIQI values per sample

Parameters:
  • kernel_size (Sequence[int]) – size of the gaussian kernel

  • sigma (Sequence[float]) – Standard deviation of the gaussian kernel

  • reduction (Literal['elementwise_mean', 'sum', 'none', None]) –

    a method to reduce metric score over labels.

    • 'elementwise_mean': takes the mean (default)

    • 'sum': takes the sum

    • 'none' or None: no reduction will be applied

  • kwargs (Any) – Additional keyword arguments, see Advanced metric settings for more info.

Returns:

Tensor with UniversalImageQualityIndex score

Example

>>> import torch
>>> from torchmetrics.image import UniversalImageQualityIndex
>>> preds = torch.rand([16, 1, 16, 16])
>>> target = preds * 0.75
>>> uqi = UniversalImageQualityIndex()
>>> uqi(preds, target)
tensor(0.9216)
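
An unofficial sketch of the reduction='none' option; the per-sample shape in the comment follows the output description above.

>>> import torch
>>> from torchmetrics.image import UniversalImageQualityIndex
>>> preds = torch.rand([16, 1, 16, 16])
>>> target = preds * 0.75
>>> uqi_per_sample = UniversalImageQualityIndex(reduction='none')
>>> per_sample = uqi_per_sample(preds, target)  # per-sample UIQI values, shape (16,) per the description above
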
plot(val=None, ax=None)[source]

Plot a single or multiple values from the metric.

Parameters:
  • val (Union[Tensor, Sequence[Tensor], None]) – Either a single result from calling metric.forward or metric.compute or a list of these results. If no value is provided, will automatically call metric.compute and plot that result.

  • ax (Optional[Axes]) – A matplotlib axis object. If provided, the plot will be added to that axis

Return type:

Tuple[Figure, Union[Axes, ndarray]]

Returns:

Figure and Axes object

Raises:

ModuleNotFoundError – If matplotlib is not installed

>>> # Example plotting a single value
>>> import torch
>>> from torchmetrics.image import UniversalImageQualityIndex
>>> preds = torch.rand([16, 1, 16, 16])
>>> target = preds * 0.75
>>> metric = UniversalImageQualityIndex()
>>> metric.update(preds, target)
>>> fig_, ax_ = metric.plot()
_images/universal_image_quality_index-1.png
>>> # Example plotting multiple values
>>> import torch
>>> from torchmetrics.image import UniversalImageQualityIndex
>>> preds = torch.rand([16, 1, 16, 16])
>>> target = preds * 0.75
>>> metric = UniversalImageQualityIndex()
>>> values = [ ]
>>> for _ in range(10):
...     values.append(metric(preds, target))
>>> fig_, ax_ = metric.plot(values)
_images/universal_image_quality_index-2.png

Functional Interface

torchmetrics.functional.image.universal_image_quality_index(preds, target, kernel_size=(11, 11), sigma=(1.5, 1.5), reduction='elementwise_mean')[source]

Universal Image Quality Index.

Parameters:
  • preds (Tensor) – estimated image

  • target (Tensor) – ground truth image

  • kernel_size (Sequence[int]) – size of the gaussian kernel

  • sigma (Sequence[float]) – Standard deviation of the gaussian kernel

  • reduction (Optional[Literal['elementwise_mean', 'sum', 'none']]) –

    a method to reduce metric score over labels.

    • 'elementwise_mean': takes the mean (default)

    • 'sum': takes the sum

    • 'none' or None: no reduction will be applied

Return type:

Tensor

Returns:

Tensor with UniversalImageQualityIndex score

Raises:
  • TypeError – If preds and target don’t have the same data type.

  • ValueError – If preds and target don’t have BxCxHxW shape.

  • ValueError – If the length of kernel_size or sigma is not 2.

  • ValueError – If one of the elements of kernel_size is not an odd positive number.

  • ValueError – If one of the elements of sigma is not a positive number.

Example

>>> from torchmetrics.functional.image import universal_image_quality_index
>>> preds = torch.rand([16, 1, 16, 16])
>>> target = preds * 0.75
>>> universal_image_quality_index(preds, target)
tensor(0.9216)

References

[1] Zhou Wang and A. C. Bovik, “A universal image quality index,” in IEEE Signal Processing Letters, vol. 9, no. 3, pp. 81-84, March 2002, doi: 10.1109/97.995823.

[2] Zhou Wang, A. C. Bovik, H. R. Sheikh and E. P. Simoncelli, “Image quality assessment: from error visibility to structural similarity,” in IEEE Transactions on Image Processing, vol. 13, no. 4, pp. 600-612, April 2004, doi: 10.1109/TIP.2003.819861.

Visual Information Fidelity (VIF)

Module Interface

class torchmetrics.image.VisualInformationFidelity(sigma_n_sq=2.0, **kwargs)[source]

Compute Pixel Based Visual Information Fidelity (VIF).

As input to forward and update the metric accepts the following input

  • preds (Tensor): Predictions from model of shape (N,C,H,W) with H,W ≥ 41

  • target (Tensor): Ground truth values of shape (N,C,H,W) with H,W ≥ 41

As output of forward and compute the metric returns the following output

  • vif-p (Tensor): Tensor with vif-p score

Parameters:
  • sigma_n_sq (float) – variance of the visual noise

  • kwargs (Any) – Additional keyword arguments, see Advanced metric settings for more info.

Example

>>> import torch
>>> _ = torch.manual_seed(42)
>>> from torchmetrics.image import VisualInformationFidelity
>>> preds = torch.randn([32, 3, 41, 41])
>>> target = torch.randn([32, 3, 41, 41])
>>> vif = VisualInformationFidelity()
>>> vif(preds, target)
tensor(0.0032)

Functional Interface

torchmetrics.functional.image.visual_information_fidelity(preds, target, sigma_n_sq=2.0)[source]

Compute Pixel Based Visual Information Fidelity (VIF).

Parameters:
  • preds (Tensor) – predicted images of shape (N,C,H,W). (H, W) has to be at least (41, 41).

  • target (Tensor) – ground truth images of shape (N,C,H,W). (H, W) has to be at least (41, 41)

  • sigma_n_sq (float) – variance of the visual noise

Return type:

Tensor

Returns:

Tensor with vif-p score

Raises:

ValueError – If data_range is neither a tuple nor a float
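
The functional form mirrors the module example above; a minimal, unofficial sketch:

>>> import torch
>>> _ = torch.manual_seed(42)
>>> from torchmetrics.functional.image import visual_information_fidelity
>>> preds = torch.randn([32, 3, 41, 41])   # H and W must be at least 41
>>> target = torch.randn([32, 3, 41, 41])
>>> score = visual_information_fidelity(preds, target, sigma_n_sq=2.0)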

CLIP Image Quality Assessment (CLIP-IQA)

Module Interface

class torchmetrics.multimodal.CLIPImageQualityAssessment(model_name_or_path='clip_iqa', data_range=1.0, prompts=('quality',), **kwargs)[source]

Calculates CLIP-IQA, which can be used to measure the visual content of images.

The metric is based on the CLIP model, which is a neural network trained on a variety of (image, text) pairs to be able to generate a vector representation of the image and the text that is similar if the image and text are semantically similar.

The metric works by calculating the cosine similarity between user-provided images and pre-defined prompts. The prompts always come in pairs of “positive” and “negative”, such as “Good photo.” and “Bad photo.”. By calculating the similarity between the image embedding and both the “positive” and “negative” prompt, the metric can determine which prompt the image is more similar to. The metric then returns the probability that the image is more similar to the first prompt than the second prompt.

Built-in prompts are:
  • quality: “Good photo.” vs “Bad photo.”

  • brightness: “Bright photo.” vs “Dark photo.”

  • noisiness: “Clean photo.” vs “Noisy photo.”

  • colorfullness: “Colorful photo.” vs “Dull photo.”

  • sharpness: “Sharp photo.” vs “Blurry photo.”

  • contrast: “High contrast photo.” vs “Low contrast photo.”

  • complexity: “Complex photo.” vs “Simple photo.”

  • natural: “Natural photo.” vs “Synthetic photo.”

  • happy: “Happy photo.” vs “Sad photo.”

  • scary: “Scary photo.” vs “Peaceful photo.”

  • new: “New photo.” vs “Old photo.”

  • warm: “Warm photo.” vs “Cold photo.”

  • real: “Real photo.” vs “Abstract photo.”

  • beutiful: “Beautiful photo.” vs “Ugly photo.”

  • lonely: “Lonely photo.” vs “Sociable photo.”

  • relaxing: “Relaxing photo.” vs “Stressful photo.”

As input to forward and update the metric accepts the following input

  • images (Tensor): tensor with images fed to the feature extractor with shape (N,C,H,W)

As output of forward and compute the metric returns the following output

  • clip_iqa (Tensor or dict of tensors): tensor with the CLIP-IQA score. If a single prompt is provided, a single tensor with shape (N,) is returned. If a list of prompts is provided, a dict of tensors is returned with the prompt as key and the tensor with shape (N,) as value.

Parameters:
  • model_name_or_path (Literal['clip_iqa', 'openai/clip-vit-base-patch16', 'openai/clip-vit-base-patch32', 'openai/clip-vit-large-patch14-336', 'openai/clip-vit-large-patch14']) –

    string indicating the version of the CLIP model to use. Available models are:

    • ”clip_iqa”, model corresponding to the CLIP-IQA paper.

    • ”openai/clip-vit-base-patch16”

    • ”openai/clip-vit-base-patch32”

    • ”openai/clip-vit-large-patch14-336”

    • ”openai/clip-vit-large-patch14”

  • data_range (Union[int, float]) – The maximum value of the input tensor. For example, if the input images are in range [0, 255], data_range should be 255. The images are normalized by this value.

  • prompts (Tuple[Union[str, Tuple[str, str]]]) – A string, tuple of strings or nested tuple of strings. If a single string is provided, it must be one of the available prompts (see above). Else the input is expected to be a tuple, where each element can be one of two things: either a string or a tuple of strings. If a string is provided, it must be one of the available prompts (see above). If a tuple is provided, it must be of length 2 and the first string must be a positive prompt and the second string must be a negative prompt.

  • kwargs (Any) – Additional keyword arguments, see Advanced metric settings for more info.

Note

If using the default clip_iqa model, the package piq must be installed. Either install with pip install piq or pip install torchmetrics[image].

Raises:
  • ModuleNotFoundError – If transformers package is not installed or version is lower than 4.10.0

  • ValueError – If prompts is a tuple and it is not of length 2

  • ValueError – If prompts is a string and it is not one of the available prompts

  • ValueError – If prompts is a list of strings and not all strings are one of the available prompts

Example::

Single prompt:

>>> from torchmetrics.multimodal import CLIPImageQualityAssessment
>>> import torch
>>> _ = torch.manual_seed(42)
>>> imgs = torch.randint(255, (2, 3, 224, 224)).float()
>>> metric = CLIPImageQualityAssessment()
>>> metric(imgs)
tensor([0.8894, 0.8902])
Example::

Multiple prompts:

>>> from torchmetrics.multimodal import CLIPImageQualityAssessment
>>> import torch
>>> _ = torch.manual_seed(42)
>>> imgs = torch.randint(255, (2, 3, 224, 224)).float()
>>> metric = CLIPImageQualityAssessment(prompts=("quality", "brightness"))
>>> metric(imgs)
{'quality': tensor([0.8894, 0.8902]), 'brightness': tensor([0.5507, 0.5208])}
Example::

Custom prompts. Must always be a tuple of length 2, with a positive and negative prompt.

>>> from torchmetrics.multimodal import CLIPImageQualityAssessment
>>> import torch
>>> _ = torch.manual_seed(42)
>>> imgs = torch.randint(255, (2, 3, 224, 224)).float()
>>> metric = CLIPImageQualityAssessment(prompts=(("Super good photo.", "Super bad photo."), "brightness"))
>>> metric(imgs)
{'user_defined_0': tensor([0.9652, 0.9629]), 'brightness': tensor([0.5507, 0.5208])}
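
If the input images are on a 0-255 scale, the data_range argument documented above should be set accordingly. The sketch below is unofficial and assumes the same transformers/piq dependencies as the examples above.

>>> import torch
>>> from torchmetrics.multimodal import CLIPImageQualityAssessment
>>> imgs_255 = torch.randint(255, (2, 3, 224, 224)).float()
>>> metric = CLIPImageQualityAssessment(data_range=255)
>>> scores = metric(imgs_255)  # images are normalized by 255 before being passed to the CLIP model
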
plot(val=None, ax=None)[source]

Plot a single or multiple values from the metric.

Parameters:
  • val (Union[Tensor, Sequence[Tensor], None]) – Either a single result from calling metric.forward or metric.compute or a list of these results. If no value is provided, will automatically call metric.compute and plot that result.

  • ax (Optional[Axes]) – A matplotlib axis object. If provided, the plot will be added to that axis

Return type:

Tuple[Figure, Union[Axes, ndarray]]

Returns:

Figure and Axes object

Raises:

ModuleNotFoundError – If matplotlib is not installed

>>> # Example plotting a single value
>>> import torch
>>> from torchmetrics.multimodal.clip_iqa import CLIPImageQualityAssessment
>>> metric = CLIPImageQualityAssessment()
>>> metric.update(torch.rand(1, 3, 224, 224))
>>> fig_, ax_ = metric.plot()
_images/clip_iqa-1.png
>>> # Example plotting multiple values
>>> import torch
>>> from torchmetrics.multimodal.clip_iqa import CLIPImageQualityAssessment
>>> metric = CLIPImageQualityAssessment()
>>> values = [ ]
>>> for _ in range(10):
...     values.append(metric(torch.rand(1, 3, 224, 224)))
>>> fig_, ax_ = metric.plot(values)
_images/clip_iqa-2.png

Functional Interface

torchmetrics.functional.multimodal.clip_image_quality_assessment(images, model_name_or_path='clip_iqa', data_range=1.0, prompts=('quality',))[source]

Calculates CLIP-IQA, which can be used to measure the visual content of images.

The metric is based on the CLIP model, which is a neural network trained on a variety of (image, text) pairs to be able to generate a vector representation of the image and the text that is similar if the image and text are semantically similar.

The metric works by calculating the cosine similarity between user-provided images and pre-defined prompts. The prompts always come in pairs of “positive” and “negative”, such as “Good photo.” and “Bad photo.”. By calculating the similarity between image embeddings and both the “positive” and “negative” prompt, the metric can determine which prompt the image is more similar to. The metric then returns the probability that the image is more similar to the first prompt than the second prompt.

Built-in prompts are:
  • quality: “Good photo.” vs “Bad photo.”

  • brightness: “Bright photo.” vs “Dark photo.”

  • noisiness: “Clean photo.” vs “Noisy photo.”

  • colorfullness: “Colorful photo.” vs “Dull photo.”

  • sharpness: “Sharp photo.” vs “Blurry photo.”

  • contrast: “High contrast photo.” vs “Low contrast photo.”

  • complexity: “Complex photo.” vs “Simple photo.”

  • natural: “Natural photo.” vs “Synthetic photo.”

  • happy: “Happy photo.” vs “Sad photo.”

  • scary: “Scary photo.” vs “Peaceful photo.”

  • new: “New photo.” vs “Old photo.”

  • warm: “Warm photo.” vs “Cold photo.”

  • real: “Real photo.” vs “Abstract photo.”

  • beutiful: “Beautiful photo.” vs “Ugly photo.”

  • lonely: “Lonely photo.” vs “Sociable photo.”

  • relaxing: “Relaxing photo.” vs “Stressful photo.”

Parameters:
  • images (Tensor) – Either a single [N, C, H, W] tensor or a list of [C, H, W] tensors

  • model_name_or_path (Literal['clip_iqa', 'openai/clip-vit-base-patch16', 'openai/clip-vit-base-patch32', 'openai/clip-vit-large-patch14-336', 'openai/clip-vit-large-patch14']) – string indicating the version of the CLIP model to use. By default this argument is set to clip_iqa, which corresponds to the model used in the original paper. Other available models are “openai/clip-vit-base-patch16”, “openai/clip-vit-base-patch32”, “openai/clip-vit-large-patch14-336” and “openai/clip-vit-large-patch14”

  • data_range (Union[int, float]) – The maximum value of the input tensor. For example, if the input images are in range [0, 255], data_range should be 255. The images are normalized by this value.

  • prompts (Tuple[Union[str, Tuple[str, str]]]) – A string, tuple of strings or nested tuple of strings. If a single string is provided, it must be one of the available prompts (see above). Else the input is expected to be a tuple, where each element can be one of two things: either a string or a tuple of strings. If a string is provided, it must be one of the available prompts (see above). If a tuple is provided, it must be of length 2 and the first string must be a positive prompt and the second string must be a negative prompt.

Note

If using the default clip_iqa model, the package piq must be installed. Either install with pip install piq or pip install torchmetrics[multimodal].

Return type:

Union[Tensor, Dict[str, Tensor]]

Returns:

A tensor of shape (N,) if a single prompt is provided. If a list of prompts is provided, a dictionary with the prompts as keys and tensors of shape (N,) as values.

Raises:
  • ModuleNotFoundError – If transformers package is not installed or version is lower than 4.10.0

  • ValueError – If not all images have format [C, H, W]

  • ValueError – If prompts is a tuple and it is not of length 2

  • ValueError – If prompts is a string and it is not one of the available prompts

  • ValueError – If prompts is a list of strings and not all strings are one of the available prompts

Example::

Single prompt:

>>> from torchmetrics.functional.multimodal import clip_image_quality_assessment
>>> import torch
>>> _ = torch.manual_seed(42)
>>> imgs = torch.randint(255, (2, 3, 224, 224)).float()
>>> clip_image_quality_assessment(imgs, prompts=("quality",))
tensor([0.8894, 0.8902])
Example::

Multiple prompts:

>>> from torchmetrics.functional.multimodal import clip_image_quality_assessment
>>> import torch
>>> _ = torch.manual_seed(42)
>>> imgs = torch.randint(255, (2, 3, 224, 224)).float()
>>> clip_image_quality_assessment(imgs, prompts=("quality", "brightness"))
{'quality': tensor([0.8894, 0.8902]), 'brightness': tensor([0.5507, 0.5208])}
Example::

Custom prompts. Must always be a tuple of length 2, with a positive and negative prompt.

>>> from torchmetrics.functional.multimodal import clip_image_quality_assessment
>>> import torch
>>> _ = torch.manual_seed(42)
>>> imgs = torch.randint(255, (2, 3, 224, 224)).float()
>>> clip_image_quality_assessment(imgs, prompts=(("Super good photo.", "Super bad photo."), "brightness"))
{'user_defined_0': tensor([0.9652, 0.9629]), 'brightness': tensor([0.5507, 0.5208])}

CLIP Score

Module Interface

class torchmetrics.multimodal.clip_score.CLIPScore(model_name_or_path='openai/clip-vit-large-patch14', **kwargs)[source]

Calculates CLIP Score which is a text-to-image similarity metric.

CLIP Score is a reference free metric that can be used to evaluate the correlation between a generated caption for an image and the actual content of the image. It has been found to be highly correlated with human judgement. The metric is defined as:

\[\text{CLIPScore}(I, C) = \max(100 \cdot \cos(E_I, E_C), 0)\]

which corresponds to the cosine similarity between the visual CLIP embedding \(E_I\) for an image \(I\) and the textual CLIP embedding \(E_C\) for a caption \(C\). The score is bounded between 0 and 100, and the closer to 100 the better.

Note

Metric is not scriptable

As input to forward and update the metric accepts the following input

  • images (Tensor or list of tensors): tensor with images fed to the feature extractor. If a single tensor, it should have shape (N, C, H, W). If a list of tensors, each tensor should have shape (C, H, W). C is the number of channels, H and W are the height and width of the image.

  • text (str or list of str): text to compare with the images, one for each image.

As output of forward and compute the metric returns the following output

  • clip_score (Tensor): float scalar tensor with mean CLIP score over samples

Parameters:
  • model_name_or_path (Literal['openai/clip-vit-base-patch16', 'openai/clip-vit-base-patch32', 'openai/clip-vit-large-patch14-336', 'openai/clip-vit-large-patch14']) –

    string indicating the version of the CLIP model to use. Available models are:

    • ”openai/clip-vit-base-patch16”

    • ”openai/clip-vit-base-patch32”

    • ”openai/clip-vit-large-patch14-336”

    • ”openai/clip-vit-large-patch14”

  • kwargs (Any) – Additional keyword arguments, see Advanced metric settings for more info.

Raises:

ModuleNotFoundError – If transformers package is not installed or version is lower than 4.10.0

Example

>>> import torch
>>> from torchmetrics.multimodal.clip_score import CLIPScore
>>> metric = CLIPScore(model_name_or_path="openai/clip-vit-base-patch16")
>>> score = metric(torch.randint(255, (3, 224, 224), generator=torch.manual_seed(42)), "a photo of a cat")
>>> score.detach()
tensor(24.4255)
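
As described under update() further down, images and text can also be given as lists of equal length; an unofficial sketch accumulating two samples before computing the mean score:

>>> import torch
>>> from torchmetrics.multimodal.clip_score import CLIPScore
>>> metric = CLIPScore(model_name_or_path="openai/clip-vit-base-patch16")
>>> imgs = [torch.randint(255, (3, 224, 224)) for _ in range(2)]
>>> captions = ["a photo of a cat", "a photo of a dog"]
>>> metric.update(imgs, captions)
>>> mean_score = metric.compute()  # float scalar tensor with the mean CLIP score over both samples
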
compute()[source]

Compute accumulated clip score.

Return type:

Tensor

plot(val=None, ax=None)[source]

Plot a single or multiple values from the metric.

Parameters:
  • val (Union[Tensor, Sequence[Tensor], None]) – Either a single result from calling metric.forward or metric.compute or a list of these results. If no value is provided, will automatically call metric.compute and plot that result.

  • ax (Optional[Axes]) – A matplotlib axis object. If provided, the plot will be added to that axis

Return type:

Tuple[Figure, Union[Axes, ndarray]]

Returns:

Figure and Axes object

Raises:

ModuleNotFoundError – If matplotlib is not installed

>>> # Example plotting a single value
>>> import torch
>>> from torchmetrics.multimodal.clip_score import CLIPScore
>>> metric = CLIPScore(model_name_or_path="openai/clip-vit-base-patch16")
>>> metric.update(torch.randint(255, (3, 224, 224)), "a photo of a cat")
>>> fig_, ax_ = metric.plot()
_images/clip_score-1.png
>>> # Example plotting multiple values
>>> import torch
>>> from torchmetrics.multimodal.clip_score import CLIPScore
>>> metric = CLIPScore(model_name_or_path="openai/clip-vit-base-patch16")
>>> values = [ ]
>>> for _ in range(10):
...     values.append(metric(torch.randint(255, (3, 224, 224)), "a photo of a cat"))
>>> fig_, ax_ = metric.plot(values)
_images/clip_score-2.png
update(images, text)[source]

Update CLIP score on a batch of images and text.

Parameters:
  • images (Union[Tensor, List[Tensor]]) – Either a single [N, C, H, W] tensor or a list of [C, H, W] tensors

  • text (Union[str, List[str]]) – Either a single caption or a list of captions

Raises:
  • ValueError – If not all images have format [C, H, W]

  • ValueError – If the number of images and captions do not match

Return type:

None

Functional Interface

torchmetrics.functional.multimodal.clip_score.clip_score(images, text, model_name_or_path='openai/clip-vit-large-patch14')[source]

Calculate CLIP Score which is a text-to-image similarity metric.

CLIP Score is a reference free metric that can be used to evaluate the correlation between a generated caption for an image and the actual content of the image. It has been found to be highly correlated with human judgement. The metric is defined as:

\[\text{CLIPScore}(I, C) = \max(100 \cdot \cos(E_I, E_C), 0)\]

which corresponds to the cosine similarity between the visual CLIP embedding \(E_I\) for an image \(I\) and the textual CLIP embedding \(E_C\) for a caption \(C\). The score is bounded between 0 and 100, and the closer to 100 the better.

Note

Metric is not scriptable

Parameters:
  • images (Union[Tensor, List[Tensor]]) – Either a single [N, C, H, W] tensor or a list of [C, H, W] tensors

  • text (Union[str, List[str]]) – Either a single caption or a list of captions

  • model_name_or_path (Literal['openai/clip-vit-base-patch16', 'openai/clip-vit-base-patch32', 'openai/clip-vit-large-patch14-336', 'openai/clip-vit-large-patch14']) – string indicating the version of the CLIP model to use. Available models are “openai/clip-vit-base-patch16”, “openai/clip-vit-base-patch32”, “openai/clip-vit-large-patch14-336” and “openai/clip-vit-large-patch14”,

Raises:
  • ModuleNotFoundError – If transformers package is not installed or version is lower than 4.10.0

  • ValueError – If not all images have format [C, H, W]

  • ValueError – If the number of images and captions do not match

Return type:

Tensor

Example

>>> import torch
>>> _ = torch.manual_seed(42)
>>> from torchmetrics.functional.multimodal import clip_score
>>> score = clip_score(torch.randint(255, (3, 224, 224)), "a photo of a cat", "openai/clip-vit-base-patch16")
>>> score.detach()
tensor(24.4255)

Cramer’s V

Module Interface

class torchmetrics.nominal.CramersV(num_classes, bias_correction=True, nan_strategy='replace', nan_replace_value=0.0, **kwargs)[source]

Compute Cramer’s V statistic measuring the association between two categorical (nominal) data series.

\[V = \sqrt{\frac{\chi^2 / n}{\min(r - 1, k - 1)}}\]

where

\[\chi^2 = \sum_{i,j} \frac{\left(n_{ij} - \frac{n_{i.} n_{.j}}{n}\right)^2}{\frac{n_{i.} n_{.j}}{n}}\]

where \(n_{ij}\) denotes the number of times the values \((A_i, B_j)\) are observed, with \(A_i\) and \(B_j\) representing the values in preds and target, respectively. Cramer’s V is a symmetric coefficient, i.e. \(V(preds, target) = V(target, preds)\), so the order of the input arguments does not matter. The output value lies in [0, 1], with 1 meaning perfect association.

As input to forward and update the metric accepts the following input:

  • preds (Tensor): Either 1D or 2D tensor of categorical (nominal) data from the first data series with shape (batch_size,) or (batch_size, num_classes), respectively.

  • target (Tensor): Either 1D or 2D tensor of categorical (nominal) data from the second data series with shape (batch_size,) or (batch_size, num_classes), respectively.

As output of forward and compute the metric returns the following output:

  • cramers_v (Tensor): Scalar tensor containing the Cramer’s V statistic.

Parameters:
  • num_classes (int) – Integer specifying the number of classes

  • bias_correction (bool) – Indication of whether to use bias correction.

  • nan_strategy (Literal['replace', 'drop']) – Indication of whether to replace or drop NaN values

  • nan_replace_value (Union[int, float, None]) – Value to replace NaNs when nan_strategy = 'replace'

  • kwargs (Any) – Additional keyword arguments, see Advanced metric settings for more info.

Raises:
  • ValueError – If nan_strategy is not one of ‘replace’ and ‘drop’

  • ValueError – If nan_strategy is equal to ‘replace’ and nan_replace_value is not an int or float

Example:

>>> from torchmetrics.nominal import CramersV
>>> _ = torch.manual_seed(42)
>>> preds = torch.randint(0, 4, (100,))
>>> target = torch.round(preds + torch.randn(100)).clamp(0, 4)
>>> cramers_v = CramersV(num_classes=5)
>>> cramers_v(preds, target)
tensor(0.5284)
plot(val=None, ax=None)[source]

Plot a single or multiple values from the metric.

Parameters:
  • val (Union[Tensor, Sequence[Tensor], None]) – Either a single result from calling metric.forward or metric.compute or a list of these results. If no value is provided, will automatically call metric.compute and plot that result.

  • ax (Optional[Axes]) – A matplotlib axis object. If provided, the plot will be added to that axis

Return type:

Tuple[Figure, Union[Axes, ndarray]]

Returns:

Figure and Axes object

Raises:

ModuleNotFoundError – If matplotlib is not installed

>>> # Example plotting a single value
>>> import torch
>>> from torchmetrics.nominal import CramersV
>>> metric = CramersV(num_classes=5)
>>> metric.update(torch.randint(0, 4, (100,)), torch.randint(0, 4, (100,)))
>>> fig_, ax_ = metric.plot()
_images/cramers_v-1.png
>>> # Example plotting multiple values
>>> import torch
>>> from torchmetrics.nominal import CramersV
>>> metric = CramersV(num_classes=5)
>>> values = [ ]
>>> for _ in range(10):
...     values.append(metric(torch.randint(0, 4, (100,)), torch.randint(0, 4, (100,))))
>>> fig_, ax_ = metric.plot(values)
_images/cramers_v-2.png

Functional Interface

torchmetrics.functional.nominal.cramers_v(preds, target, bias_correction=True, nan_strategy='replace', nan_replace_value=0.0)[source]

Compute Cramer’s V statistic measuring the association between two categorical (nominal) data series.

\[V = \sqrt{\frac{\chi^2 / n}{\min(r - 1, k - 1)}}\]

where

\[\chi^2 = \sum_{i,j} \frac{\left(n_{ij} - \frac{n_{i.} n_{.j}}{n}\right)^2}{\frac{n_{i.} n_{.j}}{n}}\]

where \(n_{ij}\) denotes the number of times the values \((A_i, B_j)\) are observed, with \(A_i\) and \(B_j\) representing the values in preds and target, respectively.

Cramer’s V is a symmetric coefficient, i.e. \(V(preds, target) = V(target, preds)\).

The output value lies in [0, 1], with 1 meaning perfect association.

Parameters:
  • preds (Tensor) – 1D or 2D tensor of categorical (nominal) data - 1D shape: (batch_size,) - 2D shape: (batch_size, num_classes)

  • target (Tensor) – 1D or 2D tensor of categorical (nominal) data - 1D shape: (batch_size,) - 2D shape: (batch_size, num_classes)

  • bias_correction (bool) – Indication of whether to use bias correction.

  • nan_strategy (Literal['replace', 'drop']) – Indication of whether to replace or drop NaN values

  • nan_replace_value (Union[int, float, None]) – Value to replace NaNs when nan_strategy = 'replace'

Return type:

Tensor

Returns:

Cramer’s V statistic

Example

>>> from torchmetrics.functional.nominal import cramers_v
>>> _ = torch.manual_seed(42)
>>> preds = torch.randint(0, 4, (100,))
>>> target = torch.round(preds + torch.randn(100)).clamp(0, 4)
>>> cramers_v(preds, target)
tensor(0.5284)
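
The bias_correction flag can be turned off to obtain the uncorrected statistic; a minimal, unofficial sketch using the same data as the example above (the uncorrected value is typically at least as large as the corrected one):

>>> import torch
>>> from torchmetrics.functional.nominal import cramers_v
>>> _ = torch.manual_seed(42)
>>> preds = torch.randint(0, 4, (100,))
>>> target = torch.round(preds + torch.randn(100)).clamp(0, 4)
>>> v_uncorrected = cramers_v(preds, target, bias_correction=False)
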
cramers_v_matrix
torchmetrics.functional.nominal.cramers_v_matrix(matrix, bias_correction=True, nan_strategy='replace', nan_replace_value=0.0)[source]

Compute Cramer’s V statistic between a set of multiple variables.

This can serve as a convenient tool to compute Cramer’s V statistic for analyses of correlation between categorical variables in your dataset.

Parameters:
  • matrix (Tensor) – A tensor of categorical (nominal) data, where: - rows represent a number of data points - columns represent a number of categorical (nominal) features

  • bias_correction (bool) – Indication of whether to use bias correction.

  • nan_strategy (Literal['replace', 'drop']) – Indication of whether to replace or drop NaN values

  • nan_replace_value (Union[int, float, None]) – Value to replace NaNs when nan_strategy = 'replace'

Return type:

Tensor

Returns:

Cramer’s V statistic for a dataset of categorical variables

Example

>>> from torchmetrics.functional.nominal import cramers_v_matrix
>>> _ = torch.manual_seed(42)
>>> matrix = torch.randint(0, 4, (200, 5))
>>> cramers_v_matrix(matrix)
tensor([[1.0000, 0.0637, 0.0000, 0.0542, 0.1337],
        [0.0637, 1.0000, 0.0000, 0.0000, 0.0000],
        [0.0000, 0.0000, 1.0000, 0.0000, 0.0649],
        [0.0542, 0.0000, 0.0000, 1.0000, 0.1100],
        [0.1337, 0.0000, 0.0649, 0.1100, 1.0000]])

Fleiss Kappa

Module Interface

class torchmetrics.nominal.FleissKappa(mode='counts', **kwargs)[source]

Calculates Fleiss kappa, a statistical measure of inter-rater agreement.

\[\kappa = \frac{\bar{p} - \bar{p_e}}{1 - \bar{p_e}}\]

where \(\bar{p}\) is the mean of the agreement probability over all raters and \(\bar{p_e}\) is the mean agreement probability over all raters if they were randomly assigned. If the raters are in complete agreement, a score of 1 is returned; if there is no agreement among the raters (other than what would be expected by chance), a score smaller than 0 is returned.

As input to forward and update the metric accepts the following input:

  • ratings (Tensor): Ratings of shape [n_samples, n_categories] or [n_samples, n_categories, n_raters] depending on mode. If mode is counts, ratings must be an integer tensor containing the number of raters that chose each category. If mode is probs, ratings must be a floating point tensor containing the probability/logits that each rater chose each category.

As output of forward and compute the metric returns the following output:

  • fleiss_k (Tensor): A float scalar tensor with the calculated Fleiss’ kappa score.

Parameters:
  • mode (Literal['counts', 'probs']) – Whether ratings will be provided as counts or probabilities.

  • kwargs (Any) – Additional keyword arguments, see Advanced metric settings for more info.

Example

>>> # Ratings are provided as counts
>>> import torch
>>> from torchmetrics.nominal import FleissKappa
>>> _ = torch.manual_seed(42)
>>> ratings = torch.randint(0, 10, size=(100, 5)).long()  # 100 samples, 5 categories, 10 raters
>>> metric = FleissKappa(mode='counts')
>>> metric(ratings)
tensor(0.0089)

Example

>>> # Ratings are provided as probabilities
>>> import torch
>>> from torchmetrics.nominal import FleissKappa
>>> _ = torch.manual_seed(42)
>>> ratings = torch.randn(100, 5, 10).softmax(dim=1)  # 100 samples, 5 categories, 10 raters
>>> metric = FleissKappa(mode='probs')
>>> metric(ratings)
tensor(-0.0105)
plot(val=None, ax=None)[source]

Plot a single or multiple values from the metric.

Parameters:
  • val (Union[Tensor, Sequence[Tensor], None]) – Either a single result from calling metric.forward or metric.compute or a list of these results. If no value is provided, will automatically call metric.compute and plot that result.

  • ax (Optional[Axes]) – A matplotlib axis object. If provided, the plot will be added to that axis

Return type:

Tuple[Figure, Union[Axes, ndarray]]

Returns:

Figure and Axes object

Raises:

ModuleNotFoundError – If matplotlib is not installed

>>> # Example plotting a single value
>>> import torch
>>> from torchmetrics.nominal import FleissKappa
>>> metric = FleissKappa(mode="probs")
>>> metric.update(torch.randn(100, 5, 10).softmax(dim=1))
>>> fig_, ax_ = metric.plot()
_images/fleiss_kappa-1.png
>>> # Example plotting multiple values
>>> import torch
>>> from torchmetrics.nominal import FleissKappa
>>> metric = FleissKappa(mode="probs")
>>> values = [ ]
>>> for _ in range(10):
...     values.append(metric(torch.randn(100, 5, 10).softmax(dim=1)))
>>> fig_, ax_ = metric.plot(values)
_images/fleiss_kappa-2.png

Functional Interface

torchmetrics.functional.nominal.fleiss_kappa(ratings, mode='counts')[source]

Calculates Fleiss kappa, a statistical measure of inter-rater agreement.

\[\kappa = \frac{\bar{p} - \bar{p_e}}{1 - \bar{p_e}}\]

where \(\bar{p}\) is the mean of the agreement probability over all raters and \(\bar{p_e}\) is the mean agreement probability over all raters if they were randomly assigned. If the raters are in complete agreement, a score of 1 is returned; if there is no agreement among the raters (other than what would be expected by chance), a score smaller than 0 is returned.

Parameters:
  • ratings (Tensor) – Ratings of shape [n_samples, n_categories] or [n_samples, n_categories, n_raters] depending on mode. If mode is counts, ratings must be an integer tensor containing the number of raters that chose each category. If mode is probs, ratings must be a floating point tensor containing the probability/logits that each rater chose each category.

  • mode (Literal['counts', 'probs']) – Whether ratings will be provided as counts or probabilities.

Return type:

Tensor

Example

>>> # Ratings are provided as counts
>>> import torch
>>> from torchmetrics.functional.nominal import fleiss_kappa
>>> _ = torch.manual_seed(42)
>>> ratings = torch.randint(0, 10, size=(100, 5)).long()  # 100 samples, 5 categories, 10 raters
>>> fleiss_kappa(ratings)
tensor(0.0089)

Example

>>> # Ratings are provided as probabilities
>>> import torch
>>> from torchmetrics.functional.nominal import fleiss_kappa
>>> _ = torch.manual_seed(42)
>>> ratings = torch.randn(100, 5, 10).softmax(dim=1)  # 100 samples, 5 categories, 10 raters
>>> fleiss_kappa(ratings, mode='probs')
tensor(-0.0105)
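
If the raw data is a matrix of category labels per rater rather than counts, it can be converted to the counts format described above before calling the metric; the conversion below is an unofficial sketch and not part of the TorchMetrics API.

>>> import torch
>>> from torchmetrics.functional.nominal import fleiss_kappa
>>> _ = torch.manual_seed(42)
>>> n_samples, n_raters, n_categories = 100, 10, 5
>>> labels = torch.randint(n_categories, (n_samples, n_raters))  # category chosen by each rater
>>> counts = torch.nn.functional.one_hot(labels, n_categories).sum(dim=1)  # shape [n_samples, n_categories]
>>> kappa = fleiss_kappa(counts, mode='counts')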

Pearson’s Contingency Coefficient

Module Interface

class torchmetrics.nominal.PearsonsContingencyCoefficient(num_classes, nan_strategy='replace', nan_replace_value=0.0, **kwargs)[source]

Compute Pearson’s Contingency Coefficient statistic.

This metric measures the association between two categorical (nominal) data series.

\[Pearson = \sqrt{\frac{\chi^2 / n}{1 + \chi^2 / n}}\]

where

\[\chi^2 = \sum_{i,j} \frac{\left(n_{ij} - \frac{n_{i.} n_{.j}}{n}\right)^2}{\frac{n_{i.} n_{.j}}{n}}\]

where \(n_{ij}\) denotes the number of times the values \((A_i, B_j)\) are observed, with \(A_i\) and \(B_j\) representing the values in preds and target, respectively. Pearson’s Contingency Coefficient is a symmetric coefficient, i.e. \(Pearson(preds, target) = Pearson(target, preds)\), so the order of the input arguments does not matter. The output value lies in [0, 1], with 1 meaning perfect association.

As input to forward and update the metric accepts the following input:

  • preds (Tensor): Either 1D or 2D tensor of categorical (nominal) data from the first data series with shape (batch_size,) or (batch_size, num_classes), respectively.

  • target (Tensor): Either 1D or 2D tensor of categorical (nominal) data from the second data series with shape (batch_size,) or (batch_size, num_classes), respectively.

As output of forward and compute the metric returns the following output:

  • pearsons_cc (Tensor): Scalar tensor containing the Pearsons Contingency Coefficient statistic.

Parameters:
  • num_classes (int) – Integer specifying the number of classes

  • nan_strategy (Literal['replace', 'drop']) – Indication of whether to replace or drop NaN values

  • nan_replace_value (Union[int, float, None]) – Value to replace NaNs when nan_strategy = 'replace'

  • kwargs (Any) – Additional keyword arguments, see Advanced metric settings for more info.

Raises:
  • ValueError – If nan_strategy is not one of ‘replace’ and ‘drop’

  • ValueError – If nan_strategy is equal to ‘replace’ and nan_replace_value is not an int or float

Example:

>>> from torchmetrics.nominal import PearsonsContingencyCoefficient
>>> _ = torch.manual_seed(42)
>>> preds = torch.randint(0, 4, (100,))
>>> target = torch.round(preds + torch.randn(100)).clamp(0, 4)
>>> pearsons_contingency_coefficient = PearsonsContingencyCoefficient(num_classes=5)
>>> pearsons_contingency_coefficient(preds, target)
tensor(0.6948)
plot(val=None, ax=None)[source]

Plot a single or multiple values from the metric.

Parameters:
  • val (Union[Tensor, Sequence[Tensor], None]) – Either a single result from calling metric.forward or metric.compute or a list of these results. If no value is provided, will automatically call metric.compute and plot that result.

  • ax (Optional[Axes]) – A matplotlib axis object. If provided, the plot will be added to that axis

Return type:

Tuple[Figure, Union[Axes, ndarray]]

Returns:

Figure and Axes object

Raises:

ModuleNotFoundError – If matplotlib is not installed

>>> # Example plotting a single value
>>> import torch
>>> from torchmetrics.nominal import PearsonsContingencyCoefficient
>>> metric = PearsonsContingencyCoefficient(num_classes=5)
>>> metric.update(torch.randint(0, 4, (100,)), torch.randint(0, 4, (100,)))
>>> fig_, ax_ = metric.plot()
_images/pearsons_contingency_coefficient-1.png
>>> # Example plotting multiple values
>>> import torch
>>> from torchmetrics.nominal import PearsonsContingencyCoefficient
>>> metric = PearsonsContingencyCoefficient(num_classes=5)
>>> values = [ ]
>>> for _ in range(10):
...     values.append(metric(torch.randint(0, 4, (100,)), torch.randint(0, 4, (100,))))
>>> fig_, ax_ = metric.plot(values)
_images/pearsons_contingency_coefficient-2.png

Functional Interface

torchmetrics.functional.nominal.pearsons_contingency_coefficient(preds, target, nan_strategy='replace', nan_replace_value=0.0)[source]

Compute Pearson’s Contingency Coefficient for measuring the association between two categorical data series.

\[Pearson = \sqrt{\frac{\chi^2 / n}{1 + \chi^2 / n}}\]

where

\[\chi^2 = \sum_{i,j} \frac{\left(n_{ij} - \frac{n_{i.} n_{.j}}{n}\right)^2}{\frac{n_{i.} n_{.j}}{n}}\]

where \(n_{ij}\) denotes the number of times the values \((A_i, B_j)\) are observed, with \(A_i\) and \(B_j\) representing the values in preds and target, respectively.

Pearson’s Contingency Coefficient is a symmetric coefficient, i.e. \(Pearson(preds, target) = Pearson(target, preds)\).

The output value lies in [0, 1], with 1 meaning perfect association.

Parameters:
  • preds (Tensor) –

    1D or 2D tensor of categorical (nominal) data:

    • 1D shape: (batch_size,)

    • 2D shape: (batch_size, num_classes)

  • target (Tensor) –

    1D or 2D tensor of categorical (nominal) data:

    • 1D shape: (batch_size,)

    • 2D shape: (batch_size, num_classes)

  • nan_strategy (Literal['replace', 'drop']) – Indication of whether to replace or drop NaN values

  • nan_replace_value (Union[int, float, None]) – Value to replace NaNs when nan_strategy = 'replace'

Return type:

Tensor

Returns:

Pearson’s Contingency Coefficient

Example

>>> from torchmetrics.functional.nominal import pearsons_contingency_coefficient
>>> _ = torch.manual_seed(42)
>>> preds = torch.randint(0, 4, (100,))
>>> target = torch.round(preds + torch.randn(100)).clamp(0, 4)
>>> pearsons_contingency_coefficient(preds, target)
tensor(0.6948)
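
The nan_strategy arguments documented above handle missing observations in float-typed nominal data; the sketch below is unofficial and assumes missing values are encoded as NaN directly in the tensor.

>>> import torch
>>> from torchmetrics.functional.nominal import pearsons_contingency_coefficient
>>> _ = torch.manual_seed(42)
>>> preds = torch.randint(0, 4, (100,)).float()
>>> target = torch.round(preds + torch.randn(100)).clamp(0, 4)
>>> preds[:5] = float('nan')  # simulate missing observations
>>> coeff_drop = pearsons_contingency_coefficient(preds, target, nan_strategy='drop')
>>> coeff_replace = pearsons_contingency_coefficient(preds, target, nan_strategy='replace', nan_replace_value=0.0)
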
pearsons_contingency_coefficient_matrix
torchmetrics.functional.nominal.pearsons_contingency_coefficient_matrix(matrix, nan_strategy='replace', nan_replace_value=0.0)[source]

Compute Pearson’s Contingency Coefficient statistic between a set of multiple variables.

This can serve as a convenient tool to compute Pearson’s Contingency Coefficient for analyses of correlation between categorical variables in your dataset.

Parameters:
  • matrix (Tensor) –

    A tensor of categorical (nominal) data, where:

    • rows represent a number of data points

    • columns represent a number of categorical (nominal) features

  • nan_strategy (Literal['replace', 'drop']) – Indication of whether to replace or drop NaN values

  • nan_replace_value (Union[int, float, None]) – Value to replace NaNs when nan_strategy = 'replace'

Return type:

Tensor

Returns:

Pearson’s Contingency Coefficient statistic for a dataset of categorical variables

Example

>>> from torchmetrics.functional.nominal import pearsons_contingency_coefficient_matrix
>>> _ = torch.manual_seed(42)
>>> matrix = torch.randint(0, 4, (200, 5))
>>> pearsons_contingency_coefficient_matrix(matrix)
tensor([[1.0000, 0.2326, 0.1959, 0.2262, 0.2989],
        [0.2326, 1.0000, 0.1386, 0.1895, 0.1329],
        [0.1959, 0.1386, 1.0000, 0.1840, 0.2335],
        [0.2262, 0.1895, 0.1840, 1.0000, 0.2737],
        [0.2989, 0.1329, 0.2335, 0.2737, 1.0000]])

Theil’s U

Module Interface

class torchmetrics.nominal.TheilsU(num_classes, nan_strategy='replace', nan_replace_value=0.0, **kwargs)[source]

Compute Theil’s U statistic measuring the association between two categorical (nominal) data series.

\[U(X|Y) = \frac{H(X) - H(X|Y)}{H(X)}\]

where \(H(X)\) is the entropy of variable \(X\) and \(H(X|Y)\) is the conditional entropy of \(X\) given \(Y\). It is also known as the Uncertainty Coefficient. Theil’s U is an asymmetric coefficient, i.e. \(TheilsU(preds, target) \neq TheilsU(target, preds)\), so the order of the inputs matters. The output value lies in [0, 1], where 0 means y has no information about x and 1 means y has complete information about x.

As input to forward and update the metric accepts the following input:

  • preds (Tensor): Either 1D or 2D tensor of categorical (nominal) data from the first data series (called X in the above definition) with shape (batch_size,) or (batch_size, num_classes), respectively.

  • target (Tensor): Either 1D or 2D tensor of categorical (nominal) data from the second data series (called Y in the above definition) with shape (batch_size,) or (batch_size, num_classes), respectively.

As output of forward and compute the metric returns the following output:

  • theils_u (Tensor): Scalar tensor containing the Theil’s U statistic.

Parameters:
  • num_classes (int) – Integer specifying the number of classes

  • nan_strategy (Literal['replace', 'drop']) – Indication of whether to replace or drop NaN values

  • nan_replace_value (Union[int, float, None]) – Value to replace NaNs when nan_strategy = 'replace'

  • kwargs (Any) – Additional keyword arguments, see Advanced metric settings for more info.

Example:

>>> from torchmetrics.nominal import TheilsU
>>> _ = torch.manual_seed(42)
>>> preds = torch.randint(10, (10,))
>>> target = torch.randint(10, (10,))
>>> metric = TheilsU(num_classes=10)
>>> metric(preds, target)
tensor(0.8530)
plot(val=None, ax=None)[source]

Plot a single or multiple values from the metric.

Parameters:
  • val (Union[Tensor, Sequence[Tensor], None]) – Either a single result from calling metric.forward or metric.compute or a list of these results. If no value is provided, will automatically call metric.compute and plot that result.

  • ax (Optional[Axes]) – A matplotlib axis object. If provided, the plot will be added to that axis

Return type:

Tuple[Figure, Union[Axes, ndarray]]

Returns:

Figure and Axes object

Raises:

ModuleNotFoundError – If matplotlib is not installed

>>> # Example plotting a single value
>>> import torch
>>> from torchmetrics.nominal import TheilsU
>>> metric = TheilsU(num_classes=10)
>>> metric.update(torch.randint(10, (10,)), torch.randint(10, (10,)))
>>> fig_, ax_ = metric.plot()
_images/theils_u-1.png
>>> # Example plotting multiple values
>>> import torch
>>> from torchmetrics.nominal import TheilsU
>>> metric = TheilsU(num_classes=10)
>>> values = [ ]
>>> for _ in range(10):
...     values.append(metric(torch.randint(10, (10,)), torch.randint(10, (10,))))
>>> fig_, ax_ = metric.plot(values)
_images/theils_u-2.png

Functional Interface

torchmetrics.functional.nominal.theils_u(preds, target, nan_strategy='replace', nan_replace_value=0.0)[source]

Compute Theils Uncertainty coefficient statistic measuring the association between two nominal data series.

\[U(X|Y) = \frac{H(X) - H(X|Y)}{H(X)}\]

where \(H(X)\) is entropy of variable \(X\) while \(H(X|Y)\) is the conditional entropy of \(X\) given \(Y\).

Theil’s U is an asymmetric coefficient, i.e. \(TheilsU(preds, target) \neq TheilsU(target, preds)\).

The output value lies in [0, 1]. A value of 0 means y has no information about x, while a value of 1 means y has complete information about x.

Parameters:
  • preds (Tensor) – 1D or 2D tensor of categorical (nominal) data - 1D shape: (batch_size,) - 2D shape: (batch_size, num_classes)

  • target (Tensor) – 1D or 2D tensor of categorical (nominal) data - 1D shape: (batch_size,) - 2D shape: (batch_size, num_classes)

  • nan_strategy (Literal['replace', 'drop']) – Indication of whether to replace or drop NaN values

  • nan_replace_value (Union[int, float, None]) – Value to replace NaNs when nan_strategy = 'replace'

Return type:

Tensor

Returns:

Tensor containing Theil’s U statistic

Example

>>> from torchmetrics.functional.nominal import theils_u
>>> _ = torch.manual_seed(42)
>>> preds = torch.randint(10, (10,))
>>> target = torch.randint(10, (10,))
>>> theils_u(preds, target)
tensor(0.8530)
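
Since Theil's U is asymmetric, swapping the arguments generally yields a different value; a minimal, unofficial sketch:

>>> import torch
>>> from torchmetrics.functional.nominal import theils_u
>>> _ = torch.manual_seed(42)
>>> preds = torch.randint(10, (10,))
>>> target = torch.randint(10, (10,))
>>> u_xy = theils_u(preds, target)
>>> u_yx = theils_u(target, preds)  # generally not equal to u_xy
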
theils_u_matrix
torchmetrics.functional.nominal.theils_u_matrix(matrix, nan_strategy='replace', nan_replace_value=0.0)[source]

Compute Theil’s U statistic between a set of multiple variables.

This can serve as a convenient tool to compute Theil’s U statistic for analyses of correlation between categorical variables in your dataset.

Parameters:
  • matrix (Tensor) – A tensor of categorical (nominal) data, where: - rows represent a number of data points - columns represent a number of categorical (nominal) features

  • nan_strategy (Literal['replace', 'drop']) – Indication of whether to replace or drop NaN values

  • nan_replace_value (Union[int, float, None]) – Value to replace NaNs when nan_strategy = 'replace'

Return type:

Tensor

Returns:

Theil’s U statistic for a dataset of categorical variables

Example

>>> from torchmetrics.functional.nominal import theils_u_matrix
>>> _ = torch.manual_seed(42)
>>> matrix = torch.randint(0, 4, (200, 5))
>>> theils_u_matrix(matrix)
tensor([[1.0000, 0.0202, 0.0142, 0.0196, 0.0353],
        [0.0202, 1.0000, 0.0070, 0.0136, 0.0065],
        [0.0143, 0.0070, 1.0000, 0.0125, 0.0206],
        [0.0198, 0.0137, 0.0125, 1.0000, 0.0312],
        [0.0352, 0.0065, 0.0204, 0.0308, 1.0000]])

Tschuprow’s T

Module Interface

class torchmetrics.nominal.TschuprowsT(num_classes, bias_correction=True, nan_strategy='replace', nan_replace_value=0.0, **kwargs)[source]

Compute Tschuprow’s T statistic measuring the association between two categorical (nominal) data series.

\[T = \sqrt{\frac{\chi^2 / n}{\sqrt{(r - 1) * (k - 1)}}}\]

where

\[\chi^2 = \sum_{i,j} \frac{\left(n_{ij} - \frac{n_{i.} n_{.j}}{n}\right)^2}{\frac{n_{i.} n_{.j}}{n}}\]

where \(n_{ij}\) denotes the number of times the values \((A_i, B_j)\) are observed, with \(A_i, B_j\) representing frequencies of values in preds and target, respectively. Tschuprow's T is a symmetric coefficient, i.e. \(T(preds, target) = T(target, preds)\), so the order of input arguments does not matter. The output values lie in [0, 1], with 1 meaning perfect association.

As input to forward and update the metric accepts the following input:

  • preds (Tensor): Either 1D or 2D tensor of categorical (nominal) data from the first data series with shape (batch_size,) or (batch_size, num_classes), respectively.

  • target (Tensor): Either 1D or 2D tensor of categorical (nominal) data from the second data series with shape (batch_size,) or (batch_size, num_classes), respectively.

As output of forward and compute the metric returns the following output:

  • tschuprows_t (Tensor): Scalar tensor containing the Tschuprow’s T statistic.

Parameters:
  • num_classes (int) – Integer specifying the number of classes

  • bias_correction (bool) – Indication of whether to use bias correction.

  • nan_strategy (Literal['replace', 'drop']) – Indication of whether to replace or drop NaN values

  • nan_replace_value (Union[int, float, None]) – Value to replace NaN values with when nan_strategy = 'replace'

  • kwargs (Any) – Additional keyword arguments, see Advanced metric settings for more info.

Raises:
  • ValueError – If nan_strategy is not one of ‘replace’ and ‘drop’

  • ValueError – If nan_strategy is equal to ‘replace’ and nan_replace_value is not an int or float

Example:

>>> from torchmetrics.nominal import TschuprowsT
>>> _ = torch.manual_seed(42)
>>> preds = torch.randint(0, 4, (100,))
>>> target = torch.round(preds + torch.randn(100)).clamp(0, 4)
>>> tschuprows_t = TschuprowsT(num_classes=5)
>>> tschuprows_t(preds, target)
tensor(0.4930)
plot(val=None, ax=None)[source]

Plot a single or multiple values from the metric.

Parameters:
  • val (Union[Tensor, Sequence[Tensor], None]) – Either a single result from calling metric.forward or metric.compute or a list of these results. If no value is provided, will automatically call metric.compute and plot that result.

  • ax (Optional[Axes]) – A matplotlib axis object. If provided, the plot will be added to that axis

Return type:

Tuple[Figure, Union[Axes, ndarray]]

Returns:

Figure and Axes object

Raises:

ModuleNotFoundError – If matplotlib is not installed

>>> # Example plotting a single value
>>> import torch
>>> from torchmetrics.nominal import TschuprowsT
>>> metric = TschuprowsT(num_classes=5)
>>> metric.update(torch.randint(0, 4, (100,)), torch.randint(0, 4, (100,)))
>>> fig_, ax_ = metric.plot()
_images/tschuprows_t-1.png
>>> # Example plotting multiple values
>>> import torch
>>> from torchmetrics.nominal import TschuprowsT
>>> metric = TschuprowsT(num_classes=5)
>>> values = [ ]
>>> for _ in range(10):
...     values.append(metric(torch.randint(0, 4, (100,)), torch.randint(0, 4, (100,))))
>>> fig_, ax_ = metric.plot(values)
_images/tschuprows_t-2.png

Functional Interface

torchmetrics.functional.nominal.tschuprows_t(preds, target, bias_correction=True, nan_strategy='replace', nan_replace_value=0.0)[source]

Compute Tschuprow’s T statistic measuring the association between two categorical (nominal) data series.

\[T = \sqrt{\frac{\chi^2 / n}{\sqrt{(r - 1) * (k - 1)}}}\]

where

\[\chi^2 = \sum_{i,j} \frac{\left(n_{ij} - \frac{n_{i.} n_{.j}}{n}\right)^2}{\frac{n_{i.} n_{.j}}{n}}\]

where \(n_{ij}\) denotes the number of times the values \((A_i, B_j)\) are observed, with \(A_i, B_j\) representing frequencies of values in preds and target, respectively.

Tschuprow’s T is a symmetric coefficient, i.e. \(T(preds, target) = T(target, preds)\).

The output values lie in [0, 1], with 1 meaning perfect association.

Parameters:
  • preds (Tensor) –

    1D or 2D tensor of categorical (nominal) data:

    • 1D shape: (batch_size,)

    • 2D shape: (batch_size, num_classes)

  • target (Tensor) –

    1D or 2D tensor of categorical (nominal) data:

    • 1D shape: (batch_size,)

    • 2D shape: (batch_size, num_classes)

  • bias_correction (bool) – Indication of whether to use bias correction.

  • nan_strategy (Literal['replace', 'drop']) – Indication of whether to replace or drop NaN values

  • nan_replace_value (Union[int, float, None]) – Value to replace NaN values with when nan_strategy = 'replace'

Return type:

Tensor

Returns:

Tschuprow’s T statistic

Example

>>> from torchmetrics.functional.nominal import tschuprows_t
>>> _ = torch.manual_seed(42)
>>> preds = torch.randint(0, 4, (100,))
>>> target = torch.round(preds + torch.randn(100)).clamp(0, 4)
>>> tschuprows_t(preds, target)
tensor(0.4930)
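Tschuprow's T, by contrast, is symmetric, so the order of preds and target does not matter. A short sketch illustrating this with the functional interface above:

import torch
from torchmetrics.functional.nominal import tschuprows_t

torch.manual_seed(42)
preds = torch.randint(0, 4, (100,))
target = torch.round(preds + torch.randn(100)).clamp(0, 4)

t_forward = tschuprows_t(preds, target)
t_swapped = tschuprows_t(target, preds)   # same value as t_forward (symmetry)
print(t_forward, t_swapped)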
tschuprows_t_matrix
torchmetrics.functional.nominal.tschuprows_t_matrix(matrix, bias_correction=True, nan_strategy='replace', nan_replace_value=0.0)[source]

Compute Tschuprow’s T statistic between a set of multiple variables.

This can serve as a convenient tool to compute Tschuprow’s T statistic for analyses of correlation between categorical variables in your dataset.

Parameters:
  • matrix (Tensor) –

    A tensor of categorical (nominal) data, where:

    • rows represent a number of data points

    • columns represent a number of categorical (nominal) features

  • bias_correction (bool) – Indication of whether to use bias correction.

  • nan_strategy (Literal['replace', 'drop']) – Indication of whether to replace or drop NaN values

  • nan_replace_value (Union[int, float, None]) – Value to replace NaN values with when nan_strategy = 'replace'

Return type:

Tensor

Returns:

Tschuprow’s T statistic for a dataset of categorical variables

Example

>>> from torchmetrics.functional.nominal import tschuprows_t_matrix
>>> _ = torch.manual_seed(42)
>>> matrix = torch.randint(0, 4, (200, 5))
>>> tschuprows_t_matrix(matrix)
tensor([[1.0000, 0.0637, 0.0000, 0.0542, 0.1337],
        [0.0637, 1.0000, 0.0000, 0.0000, 0.0000],
        [0.0000, 0.0000, 1.0000, 0.0000, 0.0649],
        [0.0542, 0.0000, 0.0000, 1.0000, 0.1100],
        [0.1337, 0.0000, 0.0649, 0.1100, 1.0000]])

Cosine Similarity

Functional Interface

torchmetrics.functional.pairwise_cosine_similarity(x, y=None, reduction=None, zero_diagonal=None)[source]

Calculate pairwise cosine similarity.

\[s_{cos}(x,y) = \frac{<x,y>}{||x|| \cdot ||y||} = \frac{\sum_{d=1}^D x_d \cdot y_d}{\sqrt{\sum_{d=1}^D x_d^2} \cdot \sqrt{\sum_{d=1}^D y_d^2}}\]

If both \(x\) and \(y\) are passed in, the calculation will be performed pairwise between the rows of \(x\) and \(y\). If only \(x\) is passed in, the calculation will be performed between the rows of \(x\).

Parameters:
  • x (Tensor) – Tensor with shape [N, d]

  • y (Optional[Tensor]) – Tensor with shape [M, d], optional

  • reduction (Optional[Literal['mean', 'sum', 'none', None]]) – reduction to apply along the last dimension. Choose between ‘mean’, ‘sum’ (applied along column dimension) or ‘none’, None for no reduction

  • zero_diagonal (Optional[bool]) – if the diagonal of the distance matrix should be set to 0. If only \(x\) is given, this defaults to True; if \(y\) is also given, it defaults to False

Return type:

Tensor

Returns:

A [N,N] matrix of distances if only x is given, else a [N,M] matrix

Example

>>> import torch
>>> from torchmetrics.functional.pairwise import pairwise_cosine_similarity
>>> x = torch.tensor([[2, 3], [3, 5], [5, 8]], dtype=torch.float32)
>>> y = torch.tensor([[1, 0], [2, 1]], dtype=torch.float32)
>>> pairwise_cosine_similarity(x, y)
tensor([[0.5547, 0.8682],
        [0.5145, 0.8437],
        [0.5300, 0.8533]])
>>> pairwise_cosine_similarity(x)
tensor([[0.0000, 0.9989, 0.9996],
        [0.9989, 0.0000, 0.9998],
        [0.9996, 0.9998, 0.0000]])
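As a sanity check on the definition above, the same pairwise matrix can be reproduced by L2-normalizing the rows and taking a plain matrix product. A small sketch (plain PyTorch, not part of the TorchMetrics API):

import torch
from torchmetrics.functional.pairwise import pairwise_cosine_similarity

x = torch.tensor([[2, 3], [3, 5], [5, 8]], dtype=torch.float32)
y = torch.tensor([[1, 0], [2, 1]], dtype=torch.float32)

x_unit = x / x.norm(dim=1, keepdim=True)  # normalize each row to unit length
y_unit = y / y.norm(dim=1, keepdim=True)
manual = x_unit @ y_unit.T                # same values as pairwise_cosine_similarity(x, y)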

Euclidean Distance

Functional Interface

torchmetrics.functional.pairwise_euclidean_distance(x, y=None, reduction=None, zero_diagonal=None)[source]

Calculate pairwise euclidean distances.

\[d_{euc}(x,y) = ||x - y||_2 = \sqrt{\sum_{d=1}^D (x_d - y_d)^2}\]

If both \(x\) and \(y\) are passed in, the calculation will be performed pairwise between the rows of \(x\) and \(y\). If only \(x\) is passed in, the calculation will be performed between the rows of \(x\).

Parameters:
  • x (Tensor) – Tensor with shape [N, d]

  • y (Optional[Tensor]) – Tensor with shape [M, d], optional

  • reduction (Optional[Literal['mean', 'sum', 'none', None]]) – reduction to apply along the last dimension. Choose between ‘mean’, ‘sum’ (applied along column dimension) or ‘none’, None for no reduction

  • zero_diagonal (Optional[bool]) – if the diagonal of the distance matrix should be set to 0. If only x is given, this defaults to True; if y is also given, it defaults to False

Return type:

Tensor

Returns:

A [N,N] matrix of distances if only x is given, else a [N,M] matrix

Example

>>> import torch
>>> from torchmetrics.functional.pairwise import pairwise_euclidean_distance
>>> x = torch.tensor([[2, 3], [3, 5], [5, 8]], dtype=torch.float32)
>>> y = torch.tensor([[1, 0], [2, 1]], dtype=torch.float32)
>>> pairwise_euclidean_distance(x, y)
tensor([[3.1623, 2.0000],
        [5.3852, 4.1231],
        [8.9443, 7.6158]])
>>> pairwise_euclidean_distance(x)
tensor([[0.0000, 2.2361, 5.8310],
        [2.2361, 0.0000, 3.6056],
        [5.8310, 3.6056, 0.0000]])
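For a quick cross-check, the same pairwise distances (up to the zero_diagonal handling) can be obtained with torch.cdist. A short sketch:

import torch
from torchmetrics.functional.pairwise import pairwise_euclidean_distance

x = torch.tensor([[2, 3], [3, 5], [5, 8]], dtype=torch.float32)
y = torch.tensor([[1, 0], [2, 1]], dtype=torch.float32)

tm_result = pairwise_euclidean_distance(x, y)
reference = torch.cdist(x, y, p=2)  # should match tm_result element-wise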

Linear Similarity

Functional Interface

torchmetrics.functional.pairwise_linear_similarity(x, y=None, reduction=None, zero_diagonal=None)[source]

Calculate pairwise linear similarity.

\[s_{lin}(x,y) = <x,y> = \sum_{d=1}^D x_d \cdot y_d\]

If both \(x\) and \(y\) are passed in, the calculation will be performed pairwise between the rows of \(x\) and \(y\). If only \(x\) is passed in, the calculation will be performed between the rows of \(x\).

Parameters:
  • x (Tensor) – Tensor with shape [N, d]

  • y (Optional[Tensor]) – Tensor with shape [M, d], optional

  • reduction (Optional[Literal['mean', 'sum', 'none', None]]) – reduction to apply along the last dimension. Choose between ‘mean’, ‘sum’ (applied along column dimension) or ‘none’, None for no reduction

  • zero_diagonal (Optional[bool]) – if the diagonal of the distance matrix should be set to 0. If only x is given, this defaults to True; if y is also given, it defaults to False

Return type:

Tensor

Returns:

A [N,N] matrix of distances if only x is given, else a [N,M] matrix

Example

>>> import torch
>>> from torchmetrics.functional.pairwise import pairwise_linear_similarity
>>> x = torch.tensor([[2, 3], [3, 5], [5, 8]], dtype=torch.float32)
>>> y = torch.tensor([[1, 0], [2, 1]], dtype=torch.float32)
>>> pairwise_linear_similarity(x, y)
tensor([[ 2.,  7.],
        [ 3., 11.],
        [ 5., 18.]])
>>> pairwise_linear_similarity(x)
tensor([[ 0., 21., 34.],
        [21.,  0., 55.],
        [34., 55.,  0.]])
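Since the linear similarity is simply an inner product between rows, a plain matrix product reproduces the [N, M] case. A minimal sketch:

import torch
from torchmetrics.functional.pairwise import pairwise_linear_similarity

x = torch.tensor([[2, 3], [3, 5], [5, 8]], dtype=torch.float32)
y = torch.tensor([[1, 0], [2, 1]], dtype=torch.float32)

manual = x @ y.T  # same values as pairwise_linear_similarity(x, y)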

Manhattan Distance

Functional Interface

torchmetrics.functional.pairwise_manhattan_distance(x, y=None, reduction=None, zero_diagonal=None)[source]

Calculate pairwise manhattan distance.

\[d_{man}(x,y) = ||x-y||_1 = \sum_{d=1}^D |x_d - y_d|\]

If both \(x\) and \(y\) are passed in, the calculation will be performed pairwise between the rows of \(x\) and \(y\). If only \(x\) is passed in, the calculation will be performed between the rows of \(x\).

Parameters:
  • x (Tensor) – Tensor with shape [N, d]

  • y (Optional[Tensor]) – Tensor with shape [M, d], optional

  • reduction (Optional[Literal['mean', 'sum', 'none', None]]) – reduction to apply along the last dimension. Choose between ‘mean’, ‘sum’ (applied along column dimension) or ‘none’, None for no reduction

  • zero_diagonal (Optional[bool]) – if the diagonal of the distance matrix should be set to 0. If only x is given, this defaults to True; if y is also given, it defaults to False

Return type:

Tensor

Returns:

A [N,N] matrix of distances if only x is given, else a [N,M] matrix

Example

>>> import torch
>>> from torchmetrics.functional.pairwise import pairwise_manhattan_distance
>>> x = torch.tensor([[2, 3], [3, 5], [5, 8]], dtype=torch.float32)
>>> y = torch.tensor([[1, 0], [2, 1]], dtype=torch.float32)
>>> pairwise_manhattan_distance(x, y)
tensor([[ 4.,  2.],
        [ 7.,  5.],
        [12., 10.]])
>>> pairwise_manhattan_distance(x)
tensor([[0., 3., 8.],
        [3., 0., 5.],
        [8., 5., 0.]])
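torch.cdist with p=1 computes the same pairwise L1 distances, which makes for an easy cross-check. A short sketch:

import torch
from torchmetrics.functional.pairwise import pairwise_manhattan_distance

x = torch.tensor([[2, 3], [3, 5], [5, 8]], dtype=torch.float32)
y = torch.tensor([[1, 0], [2, 1]], dtype=torch.float32)

tm_result = pairwise_manhattan_distance(x, y)
reference = torch.cdist(x, y, p=1)  # should match tm_result element-wise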

Minkowski Distance

Functional Interface

torchmetrics.functional.pairwise_minkowski_distance(x, y=None, exponent=2, reduction=None, zero_diagonal=None)[source]

Calculate pairwise minkowski distances.

\[d_{minkowski}(x,y,p) = ||x - y||_p = \sqrt[p]{\sum_{d=1}^D (x_d - y_d)^p}\]

If both \(x\) and \(y\) are passed in, the calculation will be performed pairwise between the rows of \(x\) and \(y\). If only \(x\) is passed in, the calculation will be performed between the rows of \(x\).

Parameters:
  • x (Tensor) – Tensor with shape [N, d]

  • y (Optional[Tensor]) – Tensor with shape [M, d], optional

  • exponent (Union[int, float]) – int or float larger than 1, exponent to which the difference between preds and target is to be raised

  • reduction (Optional[Literal['mean', 'sum', 'none', None]]) – reduction to apply along the last dimension. Choose between ‘mean’, ‘sum’ (applied along column dimension) or ‘none’, None for no reduction

  • zero_diagonal (Optional[bool]) – if the diagonal of the distance matrix should be set to 0. If only x is given, this defaults to True; if y is also given, it defaults to False

Return type:

Tensor

Returns:

A [N,N] matrix of distances if only x is given, else a [N,M] matrix

Example

>>> import torch
>>> from torchmetrics.functional.pairwise import pairwise_minkowski_distance
>>> x = torch.tensor([[2, 3], [3, 5], [5, 8]], dtype=torch.float32)
>>> y = torch.tensor([[1, 0], [2, 1]], dtype=torch.float32)
>>> pairwise_minkowski_distance(x, y, exponent=4)
tensor([[3.0092, 2.0000],
        [5.0317, 4.0039],
        [8.1222, 7.0583]])
>>> pairwise_minkowski_distance(x, exponent=4)
tensor([[0.0000, 2.0305, 5.1547],
        [2.0305, 0.0000, 3.1383],
        [5.1547, 3.1383, 0.0000]])
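The same pairwise matrix can be cross-checked against torch.cdist with the matching exponent. A short sketch (exponent=4, as in the example above):

import torch
from torchmetrics.functional.pairwise import pairwise_minkowski_distance

x = torch.tensor([[2, 3], [3, 5], [5, 8]], dtype=torch.float32)
y = torch.tensor([[1, 0], [2, 1]], dtype=torch.float32)

tm_result = pairwise_minkowski_distance(x, y, exponent=4)
reference = torch.cdist(x, y, p=4)  # should match tm_result element-wise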

Concordance Corr. Coef.

Module Interface

class torchmetrics.ConcordanceCorrCoef(num_outputs=1, **kwargs)[source]

Compute concordance correlation coefficient that measures the agreement between two variables.

\[\rho_c = \frac{2 \rho \sigma_x \sigma_y}{\sigma_x^2 + \sigma_y^2 + (\mu_x - \mu_y)^2}\]

where \(\mu_x, \mu_y\) are the means of the two variables, \(\sigma_x^2, \sigma_y^2\) are the corresponding variances, and \(\rho\) is the Pearson correlation coefficient between the two variables.

As input to forward and update the metric accepts the following input:

  • preds (Tensor): either single output float tensor with shape (N,) or multioutput float tensor of shape (N,d)

  • target (Tensor): either single output float tensor with shape (N,) or multioutput float tensor of shape (N,d)

As output of forward and compute the metric returns the following output:

  • concordance (Tensor): A scalar float tensor with the concordance coefficient(s) for non-multioutput input or a float tensor with shape (d,) for multioutput input

Parameters:
  • num_outputs (int) – Number of outputs in multioutput setting

  • kwargs (Any) – Additional keyword arguments, see Advanced metric settings for more info.

Example (single output regression):
>>> from torchmetrics.regression import ConcordanceCorrCoef
>>> from torch import tensor
>>> target = tensor([3, -0.5, 2, 7])
>>> preds = tensor([2.5, 0.0, 2, 8])
>>> concordance = ConcordanceCorrCoef()
>>> concordance(preds, target)
tensor(0.9777)
Example (multi output regression):
>>> from torchmetrics.regression import ConcordanceCorrCoef
>>> target = tensor([[3, -0.5], [2, 7]])
>>> preds = tensor([[2.5, 0.0], [2, 8]])
>>> concordance = ConcordanceCorrCoef(num_outputs=2)
>>> concordance(preds, target)
tensor([0.7273, 0.9887])
plot(val=None, ax=None)[source]

Plot a single or multiple values from the metric.

Parameters:
  • val (Union[Tensor, Sequence[Tensor], None]) – Either a single result from calling metric.forward or metric.compute or a list of these results. If no value is provided, will automatically call metric.compute and plot that result.

  • ax (Optional[Axes]) – A matplotlib axis object. If provided, the plot will be added to that axis

Return type:

Tuple[Figure, Union[Axes, ndarray]]

Returns:

Figure and Axes object

Raises:

ModuleNotFoundError – If matplotlib is not installed

>>> from torch import randn
>>> # Example plotting a single value
>>> from torchmetrics.regression import ConcordanceCorrCoef
>>> metric = ConcordanceCorrCoef()
>>> metric.update(randn(10,), randn(10,))
>>> fig_, ax_ = metric.plot()
_images/concordance_corr_coef-1.png
>>> from torch import randn
>>> # Example plotting multiple values
>>> from torchmetrics.regression import ConcordanceCorrCoef
>>> metric = ConcordanceCorrCoef()
>>> values = []
>>> for _ in range(10):
...     values.append(metric(randn(10,), randn(10,)))
>>> fig, ax = metric.plot(values)
_images/concordance_corr_coef-2.png

Functional Interface

torchmetrics.functional.concordance_corrcoef(preds, target)[source]

Compute concordance correlation coefficient that measures the agreement between two variables.

\[\rho_c = \frac{2 \rho \sigma_x \sigma_y}{\sigma_x^2 + \sigma_y^2 + (\mu_x - \mu_y)^2}\]

where \(\mu_x, \mu_y\) are the means of the two variables, \(\sigma_x^2, \sigma_y^2\) are the corresponding variances, and \(\rho\) is the Pearson correlation coefficient between the two variables.

Parameters:
  • preds (Tensor) – estimated scores

  • target (Tensor) – ground truth scores

Return type:

Tensor

Example (single output regression):
>>> from torchmetrics.functional.regression import concordance_corrcoef
>>> target = torch.tensor([3, -0.5, 2, 7])
>>> preds = torch.tensor([2.5, 0.0, 2, 8])
>>> concordance_corrcoef(preds, target)
tensor([0.9777])
Example (multi output regression):
>>> from torchmetrics.functional.regression import concordance_corrcoef
>>> target = torch.tensor([[3, -0.5], [2, 7]])
>>> preds = torch.tensor([[2.5, 0.0], [2, 8]])
>>> concordance_corrcoef(preds, target)
tensor([0.7273, 0.9887])
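The coefficient can also be computed directly from the formula above. A manual sketch, assuming unbiased (sample) estimates for the variances and covariance, which should roughly reproduce the single-output example:

import torch

target = torch.tensor([3.0, -0.5, 2.0, 7.0])
preds = torch.tensor([2.5, 0.0, 2.0, 8.0])

mu_x, mu_y = preds.mean(), target.mean()
var_x, var_y = preds.var(), target.var()                      # unbiased sample variances
cov_xy = ((preds - mu_x) * (target - mu_y)).sum() / (preds.numel() - 1)
ccc = 2 * cov_xy / (var_x + var_y + (mu_x - mu_y) ** 2)       # close to 0.9777 for this data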

Cosine Similarity

Module Interface

class torchmetrics.CosineSimilarity(reduction='sum', **kwargs)[source]

Compute the Cosine Similarity.

\[cos_{sim}(x,y) = \frac{x \cdot y}{||x|| \cdot ||y||} = \frac{\sum_{i=1}^n x_i y_i}{\sqrt{\sum_{i=1}^n x_i^2}\sqrt{\sum_{i=1}^n y_i^2}}\]

where \(y\) is a tensor of target values, and \(x\) is a tensor of predictions.

As input to forward and update the metric accepts the following input:

  • preds (Tensor): Predicted float tensor with shape (N,d)

  • target (Tensor): Ground truth float tensor with shape (N,d)

As output of forward and compute the metric returns the following output:

  • cosine_similarity (Tensor): A float tensor with the cosine similarity

Parameters:
  • reduction (Literal['mean', 'sum', 'none', None]) – how to reduce over the batch dimension using ‘sum’, ‘mean’ or ‘none’ (taking the individual scores)

  • kwargs (Any) – Additional keyword arguments, see Advanced metric settings for more info.

Example

>>> from torch import tensor
>>> from torchmetrics.regression import CosineSimilarity
>>> target = tensor([[0, 1], [1, 1]])
>>> preds = tensor([[0, 1], [0, 1]])
>>> cosine_similarity = CosineSimilarity(reduction = 'mean')
>>> cosine_similarity(preds, target)
tensor(0.8536)
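The reduction argument only changes how the per-sample similarities are aggregated. A short sketch contrasting the options for the example above:

from torch import tensor
from torchmetrics.regression import CosineSimilarity

target = tensor([[0, 1], [1, 1]])
preds = tensor([[0, 1], [0, 1]])

per_sample = CosineSimilarity(reduction='none')(preds, target)  # one score per row, here roughly [1.0, 0.7071]
summed = CosineSimilarity(reduction='sum')(preds, target)       # sum of the per-row scores
averaged = CosineSimilarity(reduction='mean')(preds, target)    # 0.8536, as in the example above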
plot(val=None, ax=None)[source]

Plot a single or multiple values from the metric.

Parameters:
  • val (Union[Tensor, Sequence[Tensor], None]) – Either a single result from calling metric.forward or metric.compute or a list of these results. If no value is provided, will automatically call metric.compute and plot that result.

  • ax (Optional[Axes]) – A matplotlib axis object. If provided, the plot will be added to that axis

Return type:

Tuple[Figure, Union[Axes, ndarray]]

Returns:

Figure and Axes object

Raises:

ModuleNotFoundError – If matplotlib is not installed

>>> from torch import randn
>>> # Example plotting a single value
>>> from torchmetrics.regression import CosineSimilarity
>>> metric = CosineSimilarity()
>>> metric.update(randn(10,), randn(10,))
>>> fig_, ax_ = metric.plot()
_images/cosine_similarity-1.png
>>> from torch import randn
>>> # Example plotting multiple values
>>> from torchmetrics.regression import CosineSimilarity
>>> metric = CosineSimilarity()
>>> values = []
>>> for _ in range(10):
...     values.append(metric(randn(10,), randn(10,)))
>>> fig, ax = metric.plot(values)
_images/cosine_similarity-2.png

Functional Interface

torchmetrics.functional.cosine_similarity(preds, target, reduction='sum')[source]

Compute the Cosine Similarity.

\[cos_{sim}(x,y) = \frac{x \cdot y}{||x|| \cdot ||y||} = \frac{\sum_{i=1}^n x_i y_i}{\sqrt{\sum_{i=1}^n x_i^2}\sqrt{\sum_{i=1}^n y_i^2}}\]

where \(y\) is a tensor of target values, and \(x\) is a tensor of predictions.

Parameters:
  • preds (Tensor) – Predicted tensor with shape (N,d)

  • target (Tensor) – Ground truth tensor with shape (N,d)

  • reduction (Optional[str]) – The method of reducing along the batch dimension using sum, mean or taking the individual scores

Return type:

Tensor

Example

>>> from torchmetrics.functional.regression import cosine_similarity
>>> target = torch.tensor([[1, 2, 3, 4],
...                        [1, 2, 3, 4]])
>>> preds = torch.tensor([[1, 2, 3, 4],
...                       [-1, -2, -3, -4]])
>>> cosine_similarity(preds, target, 'none')
tensor([ 1.0000, -1.0000])

Explained Variance

Module Interface

class torchmetrics.ExplainedVariance(multioutput='uniform_average', **kwargs)[source]

Compute explained variance.

\[\text{ExplainedVariance} = 1 - \frac{\text{Var}(y - \hat{y})}{\text{Var}(y)}\]

Where \(y\) is a tensor of target values, and \(\hat{y}\) is a tensor of predictions.

As input to forward and update the metric accepts the following input:

  • preds (Tensor): Predictions from model in float tensor with shape (N,) or (N, ...) (multioutput)

  • target (Tensor): Ground truth values in long tensor with shape (N,) or (N, ...) (multioutput)

As output of forward and compute the metric returns the following output:

  • explained_variance (Tensor): A tensor with the explained variance(s)

In the case of multioutput, the variances will by default be uniformly averaged over the additional dimensions. Please see the multioutput argument for changing this behavior.

Parameters:
  • multioutput (Literal['raw_values', 'uniform_average', 'variance_weighted']) –

    Defines aggregation in the case of multiple output scores. Can be one of the following strings (default is 'uniform_average'):

    • 'raw_values' returns full set of scores

    • 'uniform_average' scores are uniformly averaged

    • 'variance_weighted' scores are weighted by their individual variances

  • kwargs (Any) – Additional keyword arguments, see Advanced metric settings for more info.

Raises:

ValueError – If multioutput is not one of "raw_values", "uniform_average" or "variance_weighted".

Example

>>> from torch import tensor
>>> from torchmetrics.regression import ExplainedVariance
>>> target = tensor([3, -0.5, 2, 7])
>>> preds = tensor([2.5, 0.0, 2, 8])
>>> explained_variance = ExplainedVariance()
>>> explained_variance(preds, target)
tensor(0.9572)
>>> target = tensor([[0.5, 1], [-1, 1], [7, -6]])
>>> preds = tensor([[0, 2], [-1, 2], [8, -5]])
>>> explained_variance = ExplainedVariance(multioutput='raw_values')
>>> explained_variance(preds, target)
tensor([0.9677, 1.0000])
plot(val=None, ax=None)[source]

Plot a single or multiple values from the metric.

Parameters:
  • val (Union[Tensor, Sequence[Tensor], None]) – Either a single result from calling metric.forward or metric.compute or a list of these results. If no value is provided, will automatically call metric.compute and plot that result.

  • ax (Optional[Axes]) – A matplotlib axis object. If provided, the plot will be added to that axis

Return type:

Tuple[Figure, Union[Axes, ndarray]]

Returns:

Figure and Axes object

Raises:

ModuleNotFoundError – If matplotlib is not installed

>>> from torch import randn
>>> # Example plotting a single value
>>> from torchmetrics.regression import ExplainedVariance
>>> metric = ExplainedVariance()
>>> metric.update(randn(10,), randn(10,))
>>> fig_, ax_ = metric.plot()
_images/explained_variance-1.png
>>> from torch import randn
>>> # Example plotting multiple values
>>> from torchmetrics.regression import ExplainedVariance
>>> metric = ExplainedVariance()
>>> values = []
>>> for _ in range(10):
...     values.append(metric(randn(10,), randn(10,)))
>>> fig, ax = metric.plot(values)
_images/explained_variance-2.png

Functional Interface

torchmetrics.functional.explained_variance(preds, target, multioutput='uniform_average')[source]

Compute explained variance.

Parameters:
  • preds (Tensor) – estimated labels

  • target (Tensor) – ground truth labels

  • multioutput (Literal['raw_values', 'uniform_average', 'variance_weighted']) –

    Defines aggregation in the case of multiple output scores. Can be one of the following strings:

    • 'raw_values' returns full set of scores

    • 'uniform_average' scores are uniformly averaged

    • 'variance_weighted' scores are weighted by their individual variances

Return type:

Union[Tensor, Sequence[Tensor]]

Example

>>> from torchmetrics.functional.regression import explained_variance
>>> target = torch.tensor([3, -0.5, 2, 7])
>>> preds = torch.tensor([2.5, 0.0, 2, 8])
>>> explained_variance(preds, target)
tensor(0.9572)
>>> target = torch.tensor([[0.5, 1], [-1, 1], [7, -6]])
>>> preds = torch.tensor([[0, 2], [-1, 2], [8, -5]])
>>> explained_variance(preds, target, multioutput='raw_values')
tensor([0.9677, 1.0000])
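Since the metric is defined purely in terms of variances, it is easy to cross-check with plain torch.var (the bias correction cancels in the ratio). A manual sketch matching the single-output example:

import torch

target = torch.tensor([3.0, -0.5, 2.0, 7.0])
preds = torch.tensor([2.5, 0.0, 2.0, 8.0])

ev = 1 - torch.var(target - preds) / torch.var(target)  # close to 0.9572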

Kendall Rank Corr. Coef.

Module Interface

class torchmetrics.KendallRankCorrCoef(variant='b', t_test=False, alternative='two-sided', num_outputs=1, **kwargs)[source]

Compute Kendall Rank Correlation Coefficient.

\[tau_a = \frac{C - D}{C + D}\]

where \(C\) represents concordant pairs, \(D\) stands for discordant pairs.

\[tau_b = \frac{C - D}{\sqrt{(C + D + T_{preds}) * (C + D + T_{target})}}\]

where \(C\) represents concordant pairs, \(D\) stands for discordant pairs and \(T\) represents a total number of ties.

\[tau_c = 2 * \frac{C - D}{n^2 * \frac{m - 1}{m}}\]

where \(C\) represents concordant pairs, \(D\) stands for discordant pairs, \(n\) is the total number of observations and \(m\) is the minimum number of unique values in the preds and target sequences.

Definitions according to The Treatment of Ties in Ranking Problems.

As input to forward and update the metric accepts the following input:

  • preds (Tensor): Sequence of data in float tensor of either shape (N,) or (N,d)

  • target (Tensor): Sequence of data in float tensor of either shape (N,) or (N,d)

As output of forward and compute the metric returns the following output:

  • kendall (Tensor): A tensor with the correlation tau statistic and, if the t-test is enabled, the p-value of the corresponding statistical test.

Parameters:
  • variant (Literal['a', 'b', 'c']) – Indication of which variant of Kendall’s tau to be used

  • t_test (bool) – Indication whether to run t-test

  • alternative (Optional[Literal['two-sided', 'less', 'greater']]) – Alternative hypothesis for t-test. Possible values: - ‘two-sided’: the rank correlation is nonzero - ‘less’: the rank correlation is negative (less than zero) - ‘greater’: the rank correlation is positive (greater than zero)

  • num_outputs (int) – Number of outputs in multioutput setting

  • kwargs (Any) – Additional keyword arguments, see Advanced metric settings for more info.

Raises:
Example (single output regression):
>>> from torch import tensor
>>> from torchmetrics.regression import KendallRankCorrCoef
>>> preds = tensor([2.5, 0.0, 2, 8])
>>> target = tensor([3, -0.5, 2, 1])
>>> kendall = KendallRankCorrCoef()
>>> kendall(preds, target)
tensor(0.3333)
Example (multi output regression):
>>> from torchmetrics.regression import KendallRankCorrCoef
>>> preds = tensor([[2.5, 0.0], [2, 8]])
>>> target = tensor([[3, -0.5], [2, 1]])
>>> kendall = KendallRankCorrCoef(num_outputs=2)
>>> kendall(preds, target)
tensor([1., 1.])
Example (single output regression with t-test):
>>> from torchmetrics.regression import KendallRankCorrCoef
>>> preds = tensor([2.5, 0.0, 2, 8])
>>> target = tensor([3, -0.5, 2, 1])
>>> kendall = KendallRankCorrCoef(t_test=True, alternative='two-sided')
>>> kendall(preds, target)
(tensor(0.3333), tensor(0.4969))
Example (multi output regression with t-test):
>>> from torchmetrics.regression import KendallRankCorrCoef
>>> preds = tensor([[2.5, 0.0], [2, 8]])
>>> target = tensor([[3, -0.5], [2, 1]])
>>> kendall = KendallRankCorrCoef(t_test=True, alternative='two-sided', num_outputs=2)
>>> kendall(preds, target)
(tensor([1., 1.]), tensor([nan, nan]))
plot(val=None, ax=None)[source]

Plot a single or multiple values from the metric.

Parameters:
  • val (Union[Tensor, Sequence[Tensor], None]) – Either a single result from calling metric.forward or metric.compute or a list of these results. If no value is provided, will automatically call metric.compute and plot that result.

  • ax (Optional[Axes]) – A matplotlib axis object. If provided, the plot will be added to that axis

Return type:

Tuple[Figure, Union[Axes, ndarray]]

Returns:

Figure and Axes object

Raises:

ModuleNotFoundError – If matplotlib is not installed

>>> from torch import randn
>>> # Example plotting a single value
>>> from torchmetrics.regression import KendallRankCorrCoef
>>> metric = KendallRankCorrCoef()
>>> metric.update(randn(10,), randn(10,))
>>> fig_, ax_ = metric.plot()
_images/kendall_rank_corr_coef-1.png
>>> from torch import randn
>>> # Example plotting multiple values
>>> from torchmetrics.regression import KendallRankCorrCoef
>>> metric = KendallRankCorrCoef()
>>> values = []
>>> for _ in range(10):
...     values.append(metric(randn(10,), randn(10,)))
>>> fig, ax = metric.plot(values)
_images/kendall_rank_corr_coef-2.png

Functional Interface

torchmetrics.functional.kendall_rank_corrcoef(preds, target, variant='b', t_test=False, alternative='two-sided')[source]

Compute Kendall Rank Correlation Coefficient.

\[tau_a = \frac{C - D}{C + D}\]

where \(C\) represents concordant pairs, \(D\) stands for discordant pairs.

\[tau_b = \frac{C - D}{\sqrt{(C + D + T_{preds}) * (C + D + T_{target})}}\]

where \(C\) represents concordant pairs, \(D\) stands for discordant pairs and \(T\) represents a total number of ties.

\[tau_c = 2 * \frac{C - D}{n^2 * \frac{m - 1}{m}}\]

where \(C\) represents concordant pairs, \(D\) stands for discordant pairs, \(n\) is the total number of observations and \(m\) is the minimum number of unique values in the preds and target sequences.

Definitions according to The Treatment of Ties in Ranking Problems.

Parameters:
  • preds (Tensor) – Sequence of data of either shape (N,) or (N,d)

  • target (Tensor) – Sequence of data of either shape (N,) or (N,d)

  • variant (Literal['a', 'b', 'c']) – Indication of which variant of Kendall’s tau to be used

  • t_test (bool) – Indication whether to run t-test

  • alternative (Optional[Literal['two-sided', 'less', 'greater']]) – Alternative hypothesis for t-test. Possible values: - ‘two-sided’: the rank correlation is nonzero - ‘less’: the rank correlation is negative (less than zero) - ‘greater’: the rank correlation is positive (greater than zero)

Return type:

Union[Tensor, Tuple[Tensor, Tensor]]

Returns:

Correlation tau statistic and, optionally, the p-value of the corresponding statistical test (asymptotic)

Raises:
Example (single output regression):
>>> from torchmetrics.functional.regression import kendall_rank_corrcoef
>>> preds = torch.tensor([2.5, 0.0, 2, 8])
>>> target = torch.tensor([3, -0.5, 2, 1])
>>> kendall_rank_corrcoef(preds, target)
tensor(0.3333)
Example (multi output regression):
>>> from torchmetrics.functional.regression import kendall_rank_corrcoef
>>> preds = torch.tensor([[2.5, 0.0], [2, 8]])
>>> target = torch.tensor([[3, -0.5], [2, 1]])
>>> kendall_rank_corrcoef(preds, target)
tensor([1., 1.])
Example (single output regression with t-test):
>>> from torchmetrics.functional.regression import kendall_rank_corrcoef
>>> preds = torch.tensor([2.5, 0.0, 2, 8])
>>> target = torch.tensor([3, -0.5, 2, 1])
>>> kendall_rank_corrcoef(preds, target, t_test=True, alternative='two-sided')
(tensor(0.3333), tensor(0.4969))
Example (multi output regression with t-test):
>>> from torchmetrics.functional.regression import kendall_rank_corrcoef
>>> preds = torch.tensor([[2.5, 0.0], [2, 8]])
>>> target = torch.tensor([[3, -0.5], [2, 1]])
>>> kendall_rank_corrcoef(preds, target, t_test=True, alternative='two-sided')
(tensor([1., 1.]), tensor([nan, nan]))

KL Divergence

Module Interface

class torchmetrics.KLDivergence(log_prob=False, reduction='mean', **kwargs)[source]

Compute the KL divergence.

\[D_{KL}(P||Q) = \sum_{x\in\mathcal{X}} P(x) \log\frac{P(x)}{Q(x)}\]

Where \(P\) and \(Q\) are probability distributions; \(P\) usually represents a distribution over data and \(Q\) is often a prior or an approximation of \(P\). It should be noted that the KL divergence is a non-symmetric metric, i.e. \(D_{KL}(P||Q) \neq D_{KL}(Q||P)\).

As input to forward and update the metric accepts the following input:

  • p (Tensor): a data distribution with shape (N, d)

  • q (Tensor): prior or approximate distribution with shape (N, d)

As output of forward and compute the metric returns the following output:

  • kl_divergence (Tensor): A tensor with the KL divergence

Parameters:
  • log_prob (bool) – bool indicating if input is log-probabilities or probabilities. If given as probabilities, will normalize to make sure the distributions sum to 1.

  • reduction (Literal['mean', 'sum', 'none', None]) –

    Determines how to reduce over the N/batch dimension:

    • 'mean' [default]: Averages score across samples

    • 'sum': Sum score across samples

    • 'none' or None: Returns score per sample

  • kwargs (Any) – Additional keyword arguments, see Advanced metric settings for more info.

Raises:
  • TypeError – If log_prob is not an bool.

  • ValueError – If reduction is not one of 'mean', 'sum', 'none' or None.

Note

Half precision is only supported on GPU for this metric

Example

>>> from torch import tensor
>>> from torchmetrics.regression import KLDivergence
>>> p = tensor([[0.36, 0.48, 0.16]])
>>> q = tensor([[1/3, 1/3, 1/3]])
>>> kl_divergence = KLDivergence()
>>> kl_divergence(p, q)
tensor(0.0853)
plot(val=None, ax=None)[source]

Plot a single or multiple values from the metric.

Parameters:
  • val (Union[Tensor, Sequence[Tensor], None]) – Either a single result from calling metric.forward or metric.compute or a list of these results. If no value is provided, will automatically call metric.compute and plot that result.

  • ax (Optional[Axes]) – A matplotlib axis object. If provided, the plot will be added to that axis

Return type:

Tuple[Figure, Union[Axes, ndarray]]

Returns:

Figure and Axes object

Raises:

ModuleNotFoundError – If matplotlib is not installed

>>> from torch import randn
>>> # Example plotting a single value
>>> from torchmetrics.regression import KLDivergence
>>> metric = KLDivergence()
>>> metric.update(randn(10,3).softmax(dim=-1), randn(10,3).softmax(dim=-1))
>>> fig_, ax_ = metric.plot()
_images/kl_divergence-1.png
>>> from torch import randn
>>> # Example plotting multiple values
>>> from torchmetrics.regression import KLDivergence
>>> metric = KLDivergence()
>>> values = []
>>> for _ in range(10):
...     values.append(metric(randn(10,3).softmax(dim=-1), randn(10,3).softmax(dim=-1)))
>>> fig, ax = metric.plot(values)
_images/kl_divergence-2.png

Functional Interface

torchmetrics.functional.kl_divergence(p, q, log_prob=False, reduction='mean')[source]

Compute KL divergence.

\[D_{KL}(P||Q) = \sum_{x\in\mathcal{X}} P(x) \log\frac{P(x)}{Q(x)}\]

Where \(P\) and \(Q\) are probability distributions; \(P\) usually represents a distribution over data and \(Q\) is often a prior or an approximation of \(P\). It should be noted that the KL divergence is a non-symmetric metric, i.e. \(D_{KL}(P||Q) \neq D_{KL}(Q||P)\).

Parameters:
  • p (Tensor) – data distribution with shape [N, d]

  • q (Tensor) – prior or approximate distribution with shape [N, d]

  • log_prob (bool) – bool indicating if input is log-probabilities or probabilities. If given as probabilities, will normalize to make sure the distributions sum to 1

  • reduction (Literal['mean', 'sum', 'none', None]) –

    Determines how to reduce over the N/batch dimension:

    • 'mean' [default]: Averages score across samples

    • 'sum': Sum score across samples

    • 'none' or None: Returns score per sample

Return type:

Tensor

Example

>>> from torch import tensor
>>> from torchmetrics.functional.regression import kl_divergence
>>> p = tensor([[0.36, 0.48, 0.16]])
>>> q = tensor([[1/3, 1/3, 1/3]])
>>> kl_divergence(p, q)
tensor(0.0853)
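The definition above translates directly into a one-liner on probability tensors. A manual sketch reproducing the example value:

import torch

p = torch.tensor([[0.36, 0.48, 0.16]])
q = torch.tensor([[1/3, 1/3, 1/3]])

kl = (p * (p / q).log()).sum(dim=-1)  # close to 0.0853; averaged over the batch for reduction='mean'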

Log Cosh Error

Module Interface

class torchmetrics.LogCoshError(num_outputs=1, **kwargs)[source]

Compute the LogCosh Error.

\[\text{LogCoshError} = \log\left(\frac{\exp(\hat{y} - y) + \exp(-(\hat{y} - y))}{2}\right)\]

Where \(y\) is a tensor of target values, and \(\hat{y}\) is a tensor of predictions.

As input to forward and update the metric accepts the following input:

  • preds (Tensor): Estimated labels with shape (batch_size,) or (batch_size, num_outputs)

  • target (Tensor): Ground truth labels with shape (batch_size,) or (batch_size, num_outputs)

As output of forward and compute the metric returns the following output:

  • log_cosh_error (Tensor): A tensor with the log cosh error

Parameters:
  • num_outputs (int) – Number of outputs in multioutput setting

  • kwargs (Any) – Additional keyword arguments, see Advanced metric settings for more info.

Example (single output regression)::
>>> from torchmetrics.regression import LogCoshError
>>> preds = torch.tensor([3.0, 5.0, 2.5, 7.0])
>>> target = torch.tensor([2.5, 5.0, 4.0, 8.0])
>>> log_cosh_error = LogCoshError()
>>> log_cosh_error(preds, target)
tensor(0.3523)
Example (multi output regression)::
>>> from torchmetrics.regression import LogCoshError
>>> preds = torch.tensor([[3.0, 5.0, 1.2], [-2.1, 2.5, 7.0]])
>>> target = torch.tensor([[2.5, 5.0, 1.3], [0.3, 4.0, 8.0]])
>>> log_cosh_error = LogCoshError(num_outputs=3)
>>> log_cosh_error(preds, target)
tensor([0.9176, 0.4277, 0.2194])
plot(val=None, ax=None)[source]

Plot a single or multiple values from the metric.

Parameters:
  • val (Union[Tensor, Sequence[Tensor], None]) – Either a single result from calling metric.forward or metric.compute or a list of these results. If no value is provided, will automatically call metric.compute and plot that result.

  • ax (Optional[Axes]) – A matplotlib axis object. If provided, the plot will be added to that axis

Return type:

Tuple[Figure, Union[Axes, ndarray]]

Returns:

Figure and Axes object

Raises:

ModuleNotFoundError – If matplotlib is not installed

>>> from torch import randn
>>> # Example plotting a single value
>>> from torchmetrics.regression import LogCoshError
>>> metric = LogCoshError()
>>> metric.update(randn(10,), randn(10,))
>>> fig_, ax_ = metric.plot()
_images/log_cosh_error-1.png
>>> from torch import randn
>>> # Example plotting multiple values
>>> from torchmetrics.regression import LogCoshError
>>> metric = LogCoshError()
>>> values = []
>>> for _ in range(10):
...     values.append(metric(randn(10,), randn(10,)))
>>> fig, ax = metric.plot(values)
_images/log_cosh_error-2.png

Functional Interface

torchmetrics.functional.log_cosh_error(preds, target)[source]

Compute the LogCosh Error.

\[\text{LogCoshError} = \log\left(\frac{\exp(\hat{y} - y) + \exp(-(\hat{y} - y))}{2}\right)\]

Where \(y\) is a tensor of target values, and \(\hat{y}\) is a tensor of predictions.

Parameters:
  • preds (Tensor) – estimated labels with shape (batch_size,) or (batch_size, num_outputs)

  • target (Tensor) – ground truth labels with shape (batch_size,) or (batch_size, num_outputs)

Return type:

Tensor

Returns:

Tensor with LogCosh error

Example (single output regression)::
>>> from torchmetrics.functional.regression import log_cosh_error
>>> preds = torch.tensor([3.0, 5.0, 2.5, 7.0])
>>> target = torch.tensor([2.5, 5.0, 4.0, 8.0])
>>> log_cosh_error(preds, target)
tensor(0.3523)
Example (multi output regression)::
>>> from torchmetrics.functional.regression import log_cosh_error
>>> preds = torch.tensor([[3.0, 5.0, 1.2], [-2.1, 2.5, 7.0]])
>>> target = torch.tensor([[2.5, 5.0, 1.3], [0.3, 4.0, 8.0]])
>>> log_cosh_error(preds, target)
tensor([0.9176, 0.4277, 0.2194])
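Because \(\log((e^{d} + e^{-d}) / 2) = \log(\cosh(d))\), the metric can be cross-checked with torch.cosh. A manual sketch matching the single-output example:

import torch

preds = torch.tensor([3.0, 5.0, 2.5, 7.0])
target = torch.tensor([2.5, 5.0, 4.0, 8.0])

manual = torch.log(torch.cosh(preds - target)).mean()  # close to 0.3523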

Mean Absolute Error (MAE)

Module Interface

class torchmetrics.MeanAbsoluteError(**kwargs)[source]

Compute Mean Absolute Error (MAE).

\[\text{MAE} = \frac{1}{N}\sum_i^N | y_i - \hat{y_i} |\]

Where \(y\) is a tensor of target values, and \(\hat{y}\) is a tensor of predictions.

As input to forward and update the metric accepts the following input:

  • preds (Tensor): Predictions from model

  • target (Tensor): Ground truth values

As output of forward and compute the metric returns the following output:

  • mean_absolute_error (Tensor): A tensor with the mean absolute error over the state

Parameters:

kwargs (Any) – Additional keyword arguments, see Advanced metric settings for more info.

Example

>>> from torch import tensor
>>> from torchmetrics.regression import MeanAbsoluteError
>>> target = tensor([3.0, -0.5, 2.0, 7.0])
>>> preds = tensor([2.5, 0.0, 2.0, 8.0])
>>> mean_absolute_error = MeanAbsoluteError()
>>> mean_absolute_error(preds, target)
tensor(0.5000)
plot(val=None, ax=None)[source]

Plot a single or multiple values from the metric.

Parameters:
  • val (Union[Tensor, Sequence[Tensor], None]) – Either a single result from calling metric.forward or metric.compute or a list of these results. If no value is provided, will automatically call metric.compute and plot that result.

  • ax (Optional[Axes]) – A matplotlib axis object. If provided, the plot will be added to that axis

Return type:

Tuple[Figure, Union[Axes, ndarray]]

Returns:

Figure and Axes object

Raises:

ModuleNotFoundError – If matplotlib is not installed

>>> from torch import randn
>>> # Example plotting a single value
>>> from torchmetrics.regression import MeanAbsoluteError
>>> metric = MeanAbsoluteError()
>>> metric.update(randn(10,), randn(10,))
>>> fig_, ax_ = metric.plot()
_images/mean_absolute_error-1.png
>>> from torch import randn
>>> # Example plotting multiple values
>>> from torchmetrics.regression import MeanAbsoluteError
>>> metric = MeanAbsoluteError()
>>> values = []
>>> for _ in range(10):
...     values.append(metric(randn(10,), randn(10,)))
>>> fig, ax = metric.plot(values)
_images/mean_absolute_error-2.png

Functional Interface

torchmetrics.functional.mean_absolute_error(preds, target)[source]

Compute mean absolute error.

Parameters:
  • preds (Tensor) – estimated labels

  • target (Tensor) – ground truth labels

Return type:

Tensor

Returns:

Tensor with MAE

Example

>>> from torchmetrics.functional.regression import mean_absolute_error
>>> x = torch.tensor([0., 1, 2, 3])
>>> y = torch.tensor([0., 1, 2, 2])
>>> mean_absolute_error(x, y)
tensor(0.2500)
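The same value follows directly from the definition; a one-line manual cross-check:

import torch

x = torch.tensor([0., 1, 2, 3])
y = torch.tensor([0., 1, 2, 2])

manual = (x - y).abs().mean()  # 0.2500, matching mean_absolute_error(x, y)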

Mean Absolute Percentage Error (MAPE)

Module Interface

class torchmetrics.MeanAbsolutePercentageError(**kwargs)[source]

Compute Mean Absolute Percentage Error (MAPE).

\[\text{MAPE} = \frac{1}{n}\sum_{i=1}^n\frac{| y_i - \hat{y_i} |}{\max(\epsilon, | y_i |)}\]

Where \(y\) is a tensor of target values, and \(\hat{y}\) is a tensor of predictions.

As input to forward and update the metric accepts the following input:

  • preds (Tensor): Predictions from model

  • target (Tensor): Ground truth values

As output of forward and compute the metric returns the following output:

  • mean_abs_percentage_error (Tensor): A tensor with the mean absolute percentage error over state

Parameters:

kwargs (Any) – Additional keyword arguments, see Advanced metric settings for more info.

Note

MAPE output is a non-negative floating point number and the best value is 0.0. Note that bad predictions can lead to arbitrarily large values, especially when some target values are close to 0. This MAPE implementation returns a very large number instead of inf.

Example

>>> from torch import tensor
>>> from torchmetrics.regression import MeanAbsolutePercentageError
>>> target = tensor([1, 10, 1e6])
>>> preds = tensor([0.9, 15, 1.2e6])
>>> mean_abs_percentage_error = MeanAbsolutePercentageError()
>>> mean_abs_percentage_error(preds, target)
tensor(0.2667)
plot(val=None, ax=None)[source]

Plot a single or multiple values from the metric.

Parameters:
  • val (Union[Tensor, Sequence[Tensor], None]) – Either a single result from calling metric.forward or metric.compute or a list of these results. If no value is provided, will automatically call metric.compute and plot that result.

  • ax (Optional[Axes]) – A matplotlib axis object. If provided, the plot will be added to that axis

Return type:

Tuple[Figure, Union[Axes, ndarray]]

Returns:

Figure and Axes object

Raises:

ModuleNotFoundError – If matplotlib is not installed

>>> from torch import randn
>>> # Example plotting a single value
>>> from torchmetrics.regression import MeanAbsolutePercentageError
>>> metric = MeanAbsolutePercentageError()
>>> metric.update(randn(10,), randn(10,))
>>> fig_, ax_ = metric.plot()
_images/mean_absolute_percentage_error-1.png
>>> from torch import randn
>>> # Example plotting multiple values
>>> from torchmetrics.regression import MeanAbsolutePercentageError
>>> metric = MeanAbsolutePercentageError()
>>> values = []
>>> for _ in range(10):
...     values.append(metric(randn(10,), randn(10,)))
>>> fig, ax = metric.plot(values)
_images/mean_absolute_percentage_error-2.png

Functional Interface

torchmetrics.functional.mean_absolute_percentage_error(preds, target)[source]

Compute mean absolute percentage error.

Parameters:
  • preds (Tensor) – estimated labels

  • target (Tensor) – ground truth labels

Return type:

Tensor

Returns:

Tensor with MAPE

Note

The epsilon value is taken from scikit-learn’s implementation of MAPE.

Example

>>> from torchmetrics.functional.regression import mean_absolute_percentage_error
>>> target = torch.tensor([1, 10, 1e6])
>>> preds = torch.tensor([0.9, 15, 1.2e6])
>>> mean_absolute_percentage_error(preds, target)
tensor(0.2667)
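Ignoring the epsilon clamp in the denominator (which only matters for targets close to 0), the definition reduces to a short expression. A manual sketch matching the example:

import torch

target = torch.tensor([1, 10, 1e6])
preds = torch.tensor([0.9, 15, 1.2e6])

manual = ((preds - target).abs() / target.abs()).mean()  # close to 0.2667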

Mean Squared Error (MSE)

Module Interface

class torchmetrics.MeanSquaredError(squared=True, num_outputs=1, **kwargs)[source]

Compute mean squared error (MSE).

\[\text{MSE} = \frac{1}{N}\sum_i^N(y_i - \hat{y_i})^2\]

Where \(y\) is a tensor of target values, and \(\hat{y}\) is a tensor of predictions.

As input to forward and update the metric accepts the following input:

  • preds (Tensor): Predictions from model

  • target (Tensor): Ground truth values

As output of forward and compute the metric returns the following output:

  • mean_squared_error (Tensor): A tensor with the mean squared error

Parameters:
  • squared (bool) – If True returns MSE value, if False returns RMSE value.

  • num_outputs (int) – Number of outputs in multioutput setting

  • kwargs (Any) – Additional keyword arguments, see Advanced metric settings for more info.

Example::

Single output mse computation:

>>> from torch import tensor
>>> from torchmetrics.regression import MeanSquaredError
>>> target = tensor([2.5, 5.0, 4.0, 8.0])
>>> preds = tensor([3.0, 5.0, 2.5, 7.0])
>>> mean_squared_error = MeanSquaredError()
>>> mean_squared_error(preds, target)
tensor(0.8750)
Example::

Multioutput mse computation:

>>> from torch import tensor
>>> from torchmetrics.regression import MeanSquaredError
>>> target = tensor([[0.0, 0.0, 0.0], [0.0, 0.0, 0.0]])
>>> preds = tensor([[1.0, 2.0, 3.0], [1.0, 2.0, 3.0]])
>>> mean_squared_error = MeanSquaredError(num_outputs=3)
>>> mean_squared_error(preds, target)
tensor([1., 4., 9.])
plot(val=None, ax=None)[source]

Plot a single or multiple values from the metric.

Parameters:
  • val (Union[Tensor, Sequence[Tensor], None]) – Either a single result from calling metric.forward or metric.compute or a list of these results. If no value is provided, will automatically call metric.compute and plot that result.

  • ax (Optional[Axes]) – A matplotlib axis object. If provided, the plot will be added to that axis

Return type:

Tuple[Figure, Union[Axes, ndarray]]

Returns:

Figure and Axes object

Raises:

ModuleNotFoundError – If matplotlib is not installed

>>> from torch import randn
>>> # Example plotting a single value
>>> from torchmetrics.regression import MeanSquaredError
>>> metric = MeanSquaredError()
>>> metric.update(randn(10,), randn(10,))
>>> fig_, ax_ = metric.plot()
_images/mean_squared_error-1.png
>>> from torch import randn
>>> # Example plotting multiple values
>>> from torchmetrics.regression import MeanSquaredError
>>> metric = MeanSquaredError()
>>> values = []
>>> for _ in range(10):
...     values.append(metric(randn(10,), randn(10,)))
>>> fig, ax = metric.plot(values)
_images/mean_squared_error-2.png

Functional Interface

torchmetrics.functional.mean_squared_error(preds, target, squared=True, num_outputs=1)[source]

Compute mean squared error.

Parameters:
  • preds (Tensor) – estimated labels

  • target (Tensor) – ground truth labels

  • squared (bool) – returns RMSE value if set to False

  • num_outputs (int) – Number of outputs in multioutput setting

Return type:

Tensor

Returns:

Tensor with MSE

Example

>>> from torchmetrics.functional.regression import mean_squared_error
>>> x = torch.tensor([0., 1, 2, 3])
>>> y = torch.tensor([0., 1, 2, 2])
>>> mean_squared_error(x, y)
tensor(0.2500)
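A quick manual cross-check of both the MSE and the RMSE variant (squared=False):

import torch
from torchmetrics.functional.regression import mean_squared_error

x = torch.tensor([0., 1, 2, 3])
y = torch.tensor([0., 1, 2, 2])

manual_mse = (x - y).pow(2).mean()              # 0.2500
rmse = mean_squared_error(x, y, squared=False)  # square root of the MSE, i.e. 0.5000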

Mean Squared Log Error (MSLE)

Module Interface

class torchmetrics.MeanSquaredLogError(**kwargs)[source]

Compute mean squared logarithmic error (MSLE).

\[\text{MSLE} = \frac{1}{N}\sum_i^N (\log_e(1 + y_i) - \log_e(1 + \hat{y_i}))^2\]

Where \(y\) is a tensor of target values, and \(\hat{y}\) is a tensor of predictions.

As input to forward and update the metric accepts the following input:

  • preds (Tensor): Predictions from model

  • target (Tensor): Ground truth values

As output of forward and compute the metric returns the following output:

  • mean_squared_log_error (Tensor): A tensor with the mean squared log error

Parameters:

kwargs (Any) – Additional keyword arguments, see Advanced metric settings for more info.

Example

>>> from torch import tensor
>>> from torchmetrics.regression import MeanSquaredLogError
>>> target = tensor([2.5, 5, 4, 8])
>>> preds = tensor([3, 5, 2.5, 7])
>>> mean_squared_log_error = MeanSquaredLogError()
>>> mean_squared_log_error(preds, target)
tensor(0.0397)

Note

Half precision is only supported on GPU for this metric

plot(val=None, ax=None)[source]

Plot a single or multiple values from the metric.

Parameters:
  • val (Union[Tensor, Sequence[Tensor], None]) – Either a single result from calling metric.forward or metric.compute or a list of these results. If no value is provided, will automatically call metric.compute and plot that result.

  • ax (Optional[Axes]) – A matplotlib axis object. If provided, the plot will be added to that axis

Return type:

Tuple[Figure, Union[Axes, ndarray]]

Returns:

Figure and Axes object

Raises:

ModuleNotFoundError – If matplotlib is not installed

>>> from torch import randn
>>> # Example plotting a single value
>>> from torchmetrics.regression import MeanSquaredLogError
>>> metric = MeanSquaredLogError()
>>> metric.update(randn(10,), randn(10,))
>>> fig_, ax_ = metric.plot()
_images/mean_squared_log_error-1.png
>>> from torch import randn
>>> # Example plotting multiple values
>>> from torchmetrics.regression import MeanSquaredLogError
>>> metric = MeanSquaredLogError()
>>> values = []
>>> for _ in range(10):
...     values.append(metric(randn(10,), randn(10,)))
>>> fig, ax = metric.plot(values)
_images/mean_squared_log_error-2.png

Functional Interface

torchmetrics.functional.mean_squared_log_error(preds, target)[source]

Compute mean squared log error.

Parameters:
  • preds (Tensor) – estimated labels

  • target (Tensor) – ground truth labels

Return type:

Tensor

Returns:

Tensor with RMSLE

Example

>>> from torchmetrics.functional.regression import mean_squared_log_error
>>> x = torch.tensor([0., 1, 2, 3])
>>> y = torch.tensor([0., 1, 2, 2])
>>> mean_squared_log_error(x, y)
tensor(0.0207)
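The metric is simply the MSE computed on log1p-transformed inputs, which gives an easy cross-check. A manual sketch matching the example:

import torch

x = torch.tensor([0., 1, 2, 3])   # preds
y = torch.tensor([0., 1, 2, 2])   # target
manual = (torch.log1p(x) - torch.log1p(y)).pow(2).mean()  # close to 0.0207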

Note

Half precision is only supported on GPU for this metric

Minkowski Distance

Module Interface

class torchmetrics.MinkowskiDistance(p, **kwargs)[source]

Compute Minkowski Distance.

\[d_{\text{Minkowski}} = \left(\sum_{i}^N | y_i - \hat{y_i} |^p\right)^{\frac{1}{p}}\]

where \(y\) is a tensor of target values, \(\hat{y}\) is a tensor of predictions, and \(p\) is a non-negative integer or floating-point number.

This metric can be seen as a generalized version of the standard Euclidean distance, which corresponds to the Minkowski distance with p=2.

Parameters:
  • p (float) – int or float larger than 1, exponent to which the difference between preds and target is to be raised

  • kwargs (Any) – Additional keyword arguments, see Advanced metric settings for more info.

Example

>>> from torch import tensor
>>> from torchmetrics.regression import MinkowskiDistance
>>> target = tensor([1.0, 2.8, 3.5, 4.5])
>>> preds = tensor([6.1, 2.11, 3.1, 5.6])
>>> minkowski_distance = MinkowskiDistance(3)
>>> minkowski_distance(preds, target)
tensor(5.1220)
plot(val=None, ax=None)[source]

Plot a single or multiple values from the metric.

Parameters:
  • val (Union[Tensor, Sequence[Tensor], None]) – Either a single result from calling metric.forward or metric.compute or a list of these results. If no value is provided, will automatically call metric.compute and plot that result.

  • ax (Optional[Axes]) – A matplotlib axis object. If provided, the plot will be added to that axis

Return type:

Tuple[Figure, Union[Axes, ndarray]]

Returns:

Figure and Axes object

Raises:

ModuleNotFoundError – If matplotlib is not installed

>>> from torch import randn
>>> # Example plotting a single value
>>> from torchmetrics.regression import MinkowskiDistance
>>> metric = MinkowskiDistance(p=3)
>>> metric.update(randn(10,), randn(10,))
>>> fig_, ax_ = metric.plot()
_images/minkowski_distance-1.png
>>> from torch import randn
>>> # Example plotting multiple values
>>> from torchmetrics.regression import MinkowskiDistance
>>> metric = MinkowskiDistance(p=3)
>>> values = []
>>> for _ in range(10):
...     values.append(metric(randn(10,), randn(10,)))
>>> fig, ax = metric.plot(values)
_images/minkowski_distance-2.png

Functional Interface

torchmetrics.functional.minkowski_distance(preds, targets, p)[source]

Compute the Minkowski distance.

\[d_{\text{Minkowski}} = \left(\sum_{i}^N | y_i - \hat{y_i} |^p\right)^{\frac{1}{p}}\]

This metric can be seen as a generalized version of the standard Euclidean distance, which corresponds to the Minkowski distance with p=2.

Parameters:
  • preds (Tensor) – estimated labels of type Tensor

  • targets (Tensor) – ground truth labels of type Tensor

  • p (float) – int or float larger than 1, exponent to which the difference between preds and target is to be raised

Return type:

Tensor

Returns:

Tensor with the Minkowski distance

Example

>>> from torchmetrics.functional.regression import minkowski_distance
>>> x = torch.tensor([1.0, 2.8, 3.5, 4.5])
>>> y = torch.tensor([6.1, 2.11, 3.1, 5.6])
>>> minkowski_distance(x, y, p=3)
tensor(5.1220)
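A manual evaluation of the formula above gives the same value. A short sketch:

import torch

x = torch.tensor([1.0, 2.8, 3.5, 4.5])
y = torch.tensor([6.1, 2.11, 3.1, 5.6])
manual = (x - y).abs().pow(3).sum().pow(1 / 3)  # close to 5.1220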

Pearson Corr. Coef.

Module Interface

class torchmetrics.PearsonCorrCoef(num_outputs=1, **kwargs)[source]

Compute Pearson Correlation Coefficient.

\[P_{corr}(x,y) = \frac{cov(x,y)}{\sigma_x \sigma_y}\]

Where \(y\) is a tensor of target values, and \(x\) is a tensor of predictions.

As input to forward and update the metric accepts the following input:

  • preds (Tensor): either single output float tensor with shape (N,) or multioutput float tensor of shape (N,d)

  • target (Tensor): either single output tensor with shape (N,) or multioutput tensor of shape (N,d)

As output of forward and compute the metric returns the following output:

  • pearson (Tensor): A tensor with the Pearson Correlation Coefficient

Parameters:
  • num_outputs (int) – Number of outputs in multioutput setting

  • kwargs (Any) – Additional keyword arguments, see Advanced metric settings for more info.

Example (single output regression):
>>> from torchmetrics.regression import PearsonCorrCoef
>>> target = torch.tensor([3, -0.5, 2, 7])
>>> preds = torch.tensor([2.5, 0.0, 2, 8])
>>> pearson = PearsonCorrCoef()
>>> pearson(preds, target)
tensor(0.9849)
Example (multi output regression):
>>> from torchmetrics.regression import PearsonCorrCoef
>>> target = torch.tensor([[3, -0.5], [2, 7]])
>>> preds = torch.tensor([[2.5, 0.0], [2, 8]])
>>> pearson = PearsonCorrCoef(num_outputs=2)
>>> pearson(preds, target)
tensor([1., 1.])
plot(val=None, ax=None)[source]

Plot a single or multiple values from the metric.

Parameters:
  • val (Union[Tensor, Sequence[Tensor], None]) – Either a single result from calling metric.forward or metric.compute or a list of these results. If no value is provided, will automatically call metric.compute and plot that result.

  • ax (Optional[Axes]) – A matplotlib axis object. If provided, the plot will be added to that axis

Return type:

Tuple[Figure, Union[Axes, ndarray]]

Returns:

Figure and Axes object

Raises:

ModuleNotFoundError – If matplotlib is not installed

>>> from torch import randn
>>> # Example plotting a single value
>>> from torchmetrics.regression import PearsonCorrCoef
>>> metric = PearsonCorrCoef()
>>> metric.update(randn(10,), randn(10,))
>>> fig_, ax_ = metric.plot()
_images/pearson_corr_coef-1.png
>>> from torch import randn
>>> # Example plotting multiple values
>>> from torchmetrics.regression import PearsonCorrCoef
>>> metric = PearsonCorrCoef()
>>> values = []
>>> for _ in range(10):
...     values.append(metric(randn(10,), randn(10,)))
>>> fig, ax = metric.plot(values)
_images/pearson_corr_coef-2.png

Functional Interface

torchmetrics.functional.pearson_corrcoef(preds, target)[source]

Compute the Pearson correlation coefficient.

Parameters:
  • preds (Tensor) – estimated scores

  • target (Tensor) – ground truth scores

Return type:

Tensor

Example (single output regression):
>>> from torchmetrics.functional.regression import pearson_corrcoef
>>> target = torch.tensor([3, -0.5, 2, 7])
>>> preds = torch.tensor([2.5, 0.0, 2, 8])
>>> pearson_corrcoef(preds, target)
tensor(0.9849)
Example (multi output regression):
>>> from torchmetrics.functional.regression import pearson_corrcoef
>>> target = torch.tensor([[3, -0.5], [2, 7]])
>>> preds = torch.tensor([[2.5, 0.0], [2, 8]])
>>> pearson_corrcoef(preds, target)
tensor([1., 1.])
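
For intuition, the sketch below (an illustrative check, not part of the documented API) reproduces the single-output result directly from the definition \(cov(x,y) / (\sigma_x \sigma_y)\); the normalization constants cancel, so the biased/unbiased choice does not matter:

import torch
from torchmetrics.functional.regression import pearson_corrcoef

target = torch.tensor([3.0, -0.5, 2.0, 7.0])
preds = torch.tensor([2.5, 0.0, 2.0, 8.0])

# centre both variables and apply the definition directly
x = preds - preds.mean()
y = target - target.mean()
manual = (x * y).sum() / (x.square().sum().sqrt() * y.square().sum().sqrt())

print(pearson_corrcoef(preds, target))  # tensor(0.9849)
print(manual)                           # same value, up to floating point precision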

R2 Score

Module Interface

class torchmetrics.R2Score(num_outputs=1, adjusted=0, multioutput='uniform_average', **kwargs)[source]

Compute R2 score, also known as the coefficient of determination.

\[R^2 = 1 - \frac{SS_{res}}{SS_{tot}}\]

where \(SS_{res}=\sum_i (y_i - f(x_i))^2\) is the residual sum of squares, and \(SS_{tot}=\sum_i (y_i - \bar{y})^2\) is the total sum of squares. The adjusted R2 score can also be calculated, given by

\[R^2_{adj} = 1 - \frac{(1-R^2)(n-1)}{n-k-1}\]

where the parameter \(k\) (the number of independent regressors) should be provided as the adjusted argument. The score is only properly defined when \(SS_{tot}\neq 0\); for (near-)constant targets \(SS_{tot}\) can be zero, in which case a score of 0 is returned. The score is at most 1, where 1 corresponds to the predictions exactly matching the targets.

As input to forward and update the metric accepts the following input:

  • preds (Tensor): Predictions from model in float tensor with shape (N,) or (N, M) (multioutput)

  • target (Tensor): Ground truth values in float tensor with shape (N,) or (N, M) (multioutput)

As output of forward and compute the metric returns the following output:

  • r2score (Tensor): A tensor with the r2 score(s)

In the multioutput case, the variances are by default uniformly averaged over the additional dimensions. See the multioutput argument for changing this behavior.

Parameters:
  • num_outputs (int) – Number of outputs in multioutput setting

  • adjusted (int) – number of independent regressors for calculating adjusted r2 score.

  • multioutput (str) –

    Defines aggregation in the case of multiple output scores. Can be one of the following strings:

    • 'raw_values' returns full set of scores

    • 'uniform_average' scores are uniformly averaged

    • 'variance_weighted' scores are weighted by their individual variances

  • kwargs (Any) – Additional keyword arguments, see Advanced metric settings for more info.

Raises:
  • ValueError – If adjusted parameter is not an integer larger or equal to 0.

  • ValueError – If multioutput is not one of "raw_values", "uniform_average" or "variance_weighted".

Example

>>> from torchmetrics.regression import R2Score
>>> target = torch.tensor([3, -0.5, 2, 7])
>>> preds = torch.tensor([2.5, 0.0, 2, 8])
>>> r2score = R2Score()
>>> r2score(preds, target)
tensor(0.9486)
>>> target = torch.tensor([[0.5, 1], [-1, 1], [7, -6]])
>>> preds = torch.tensor([[0, 2], [-1, 2], [8, -5]])
>>> r2score = R2Score(num_outputs=2, multioutput='raw_values')
>>> r2score(preds, target)
tensor([0.9654, 0.9082])
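
The adjusted variant described above is selected through the adjusted argument; a minimal sketch (the expected adjusted value is computed from the formula \(R^2_{adj} = 1 - (1-R^2)(n-1)/(n-k-1)\) and shown only for illustration):

import torch
from torchmetrics.regression import R2Score

target = torch.tensor([3.0, -0.5, 2.0, 7.0])
preds = torch.tensor([2.5, 0.0, 2.0, 8.0])

# plain R^2, as in the example above
print(R2Score()(preds, target))  # tensor(0.9486)

# adjusted R^2 with k=1 independent regressor:
# 1 - (1 - 0.9486) * (4 - 1) / (4 - 1 - 1) ~= 0.9229
print(R2Score(adjusted=1)(preds, target))
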
plot(val=None, ax=None)[source]

Plot a single or multiple values from the metric.

Parameters:
  • val (Union[Tensor, Sequence[Tensor], None]) – Either a single result from calling metric.forward or metric.compute or a list of these results. If no value is provided, will automatically call metric.compute and plot that result.

  • ax (Optional[Axes]) – A matplotlib axis object. If provided, the plot will be added to that axis

Return type:

Tuple[Figure, Union[Axes, ndarray]]

Returns:

Figure and Axes object

Raises:

ModuleNotFoundError – If matplotlib is not installed

>>> from torch import randn
>>> # Example plotting a single value
>>> from torchmetrics.regression import R2Score
>>> metric = R2Score()
>>> metric.update(randn(10,), randn(10,))
>>> fig_, ax_ = metric.plot()
_images/r2_score-1.png
>>> from torch import randn
>>> # Example plotting multiple values
>>> from torchmetrics.regression import R2Score
>>> metric = R2Score()
>>> values = []
>>> for _ in range(10):
...     values.append(metric(randn(10,), randn(10,)))
>>> fig, ax = metric.plot(values)
_images/r2_score-2.png

Functional Interface

torchmetrics.functional.r2_score(preds, target, adjusted=0, multioutput='uniform_average')[source]

Compute R2 score, also known as the coefficient of determination.

\[R^2 = 1 - \frac{SS_{res}}{SS_{tot}}\]

where \(SS_{res}=\sum_i (y_i - f(x_i))^2\) is the residual sum of squares, and \(SS_{tot}=\sum_i (y_i - \bar{y})^2\) is the total sum of squares. The adjusted R2 score can also be calculated, given by

\[R^2_{adj} = 1 - \frac{(1-R^2)(n-1)}{n-k-1}\]

where the parameter \(k\) (the number of independent regressors) should be provided as the adjusted argument.

Parameters:
  • preds (Tensor) – estimated labels

  • target (Tensor) – ground truth labels

  • adjusted (int) – number of independent regressors for calculating adjusted r2 score.

  • multioutput (str) –

    Defines aggregation in the case of multiple output scores. Can be one of the following strings:

    • 'raw_values' returns full set of scores

    • 'uniform_average' scores are uniformly averaged

    • 'variance_weighted' scores are weighted by their individual variances

Raises:
  • ValueError – If preds and target are not 1D or 2D tensors.

  • ValueError – If len(preds) is less than 2, since at least 2 samples are needed to calculate the r2 score.

  • ValueError – If multioutput is not one of raw_values, uniform_average or variance_weighted.

  • ValueError – If adjusted is not an integer greater than or equal to 0.

Return type:

Tensor

Example

>>> from torchmetrics.functional.regression import r2_score
>>> target = torch.tensor([3, -0.5, 2, 7])
>>> preds = torch.tensor([2.5, 0.0, 2, 8])
>>> r2_score(preds, target)
tensor(0.9486)
>>> target = torch.tensor([[0.5, 1], [-1, 1], [7, -6]])
>>> preds = torch.tensor([[0, 2], [-1, 2], [8, -5]])
>>> r2_score(preds, target, multioutput='raw_values')
tensor([0.9654, 0.9082])

Relative Squared Error (RSE)

Module Interface

class torchmetrics.RelativeSquaredError(num_outputs=1, squared=True, **kwargs)[source]

Computes the relative squared error (RSE).

\[\text{RSE} = \frac{\sum_i^N(y_i - \hat{y_i})^2}{\sum_i^N(y_i - \overline{y})^2}\]

Where \(y\) is a tensor of target values with mean \(\overline{y}\), and \(\hat{y}\) is a tensor of predictions.

If num_outputs > 1, the returned value is averaged over all the outputs.

As input to forward and update the metric accepts the following input:

  • preds (Tensor): Predictions from model in float tensor with shape (N,) or (N, M) (multioutput)

  • target (Tensor): Ground truth values in float tensor with shape (N,) or (N, M) (multioutput)

As output of forward and compute the metric returns the following output:

  • rse (Tensor): A tensor with the RSE score(s)

Parameters:
  • num_outputs (int) – Number of outputs in multioutput setting

  • squared (bool) – If True returns the RSE value, if False returns the root relative squared error (RRSE) value.

  • kwargs (Any) – Additional keyword arguments, see Advanced metric settings for more info.

Example

>>> from torchmetrics.regression import RelativeSquaredError
>>> target = torch.tensor([3, -0.5, 2, 7])
>>> preds = torch.tensor([2.5, 0.0, 2, 8])
>>> relative_squared_error = RelativeSquaredError()
>>> relative_squared_error(preds, target)
tensor(0.0514)
plot(val=None, ax=None)[source]

Plot a single or multiple values from the metric.

Parameters:
  • val (Union[Tensor, Sequence[Tensor], None]) – Either a single result from calling metric.forward or metric.compute or a list of these results. If no value is provided, will automatically call metric.compute and plot that result.

  • ax (Optional[Axes]) – A matplotlib axis object. If provided, the plot will be added to that axis

Return type:

Tuple[Figure, Union[Axes, ndarray]]

Returns:

Figure and Axes object

Raises:

ModuleNotFoundError – If matplotlib is not installed

>>> from torch import randn
>>> # Example plotting a single value
>>> from torchmetrics.regression import RelativeSquaredError
>>> metric = RelativeSquaredError()
>>> metric.update(randn(10,), randn(10,))
>>> fig_, ax_ = metric.plot()
_images/rse-1.png
>>> from torch import randn
>>> # Example plotting multiple values
>>> from torchmetrics.regression import RelativeSquaredError
>>> metric = RelativeSquaredError()
>>> values = []
>>> for _ in range(10):
...     values.append(metric(randn(10,), randn(10,)))
>>> fig, ax = metric.plot(values)
_images/rse-2.png

Functional Interface

torchmetrics.functional.relative_squared_error(preds, target, squared=True)[source]

Computes the relative squared error (RSE).

\[\text{RSE} = \frac{\sum_i^N(y_i - \hat{y_i})^2}{\sum_i^N(y_i - \overline{y})^2}\]

Where \(y\) is a tensor of target values with mean \(\overline{y}\), and \(\hat{y}\) is a tensor of predictions.

If preds and targets are 2D tensors, the RSE is averaged over the second dim.

Parameters:
  • preds (Tensor) – estimated labels

  • target (Tensor) – ground truth labels

  • squared (bool) – returns RRSE value if set to False

Return type:

Tensor

Returns:

Tensor with RSE

Example

>>> from torchmetrics.functional.regression import relative_squared_error
>>> target = torch.tensor([3, -0.5, 2, 7])
>>> preds = torch.tensor([2.5, 0.0, 2, 8])
>>> relative_squared_error(preds, target)
tensor(0.0514)
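
Since RSE is defined as \(SS_{res}/SS_{tot}\), it is the complement of the unadjusted R2 score from the previous section; a short illustrative check:

import torch
from torchmetrics.functional.regression import r2_score, relative_squared_error

target = torch.tensor([3.0, -0.5, 2.0, 7.0])
preds = torch.tensor([2.5, 0.0, 2.0, 8.0])

# RSE = SS_res / SS_tot = 1 - R^2 for the single-output, unadjusted case
print(relative_squared_error(preds, target))  # tensor(0.0514)
print(1 - r2_score(preds, target))            # same value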

Spearman Corr. Coef.

Module Interface

class torchmetrics.SpearmanCorrCoef(num_outputs=1, **kwargs)[source]

Compute Spearman's rank correlation coefficient.

\[r_s = \frac{cov(rg_x, rg_y)}{\sigma_{rg_x} \sigma_{rg_y}}\]

where \(rg_x\) and \(rg_y\) are the ranks associated with the variables \(x\) and \(y\). Spearman's correlation coefficient corresponds to the standard Pearson correlation coefficient calculated on the rank variables.

As input to forward and update the metric accepts the following input:

  • preds (Tensor): Predictions from model in float tensor with shape (N,d)

  • target (Tensor): Ground truth values in float tensor with shape (N,d)

As output of forward and compute the metric returns the following output:

  • spearman (Tensor): A tensor with the spearman correlation(s)

Parameters:
  • num_outputs (int) – Number of outputs in multioutput setting

  • kwargs (Any) – Additional keyword arguments, see Advanced metric settings for more info.

Example (single output regression):
>>> from torch import tensor
>>> from torchmetrics.regression import SpearmanCorrCoef
>>> target = tensor([3, -0.5, 2, 7])
>>> preds = tensor([2.5, 0.0, 2, 8])
>>> spearman = SpearmanCorrCoef()
>>> spearman(preds, target)
tensor(1.0000)
Example (multi output regression):
>>> from torchmetrics.regression import SpearmanCorrCoef
>>> target = tensor([[3, -0.5], [2, 7]])
>>> preds = tensor([[2.5, 0.0], [2, 8]])
>>> spearman = SpearmanCorrCoef(num_outputs=2)
>>> spearman(preds, target)
tensor([1.0000, 1.0000])
plot(val=None, ax=None)[source]

Plot a single or multiple values from the metric.

Parameters:
  • val (Union[Tensor, Sequence[Tensor], None]) – Either a single result from calling metric.forward or metric.compute or a list of these results. If no value is provided, will automatically call metric.compute and plot that result.

  • ax (Optional[Axes]) – A matplotlib axis object. If provided, the plot will be added to that axis

Return type:

Tuple[Figure, Union[Axes, ndarray]]

Returns:

Figure and Axes object

Raises:

ModuleNotFoundError – If matplotlib is not installed

>>> from torch import randn
>>> # Example plotting a single value
>>> from torchmetrics.regression import SpearmanCorrCoef
>>> metric = SpearmanCorrCoef()
>>> metric.update(randn(10,), randn(10,))
>>> fig_, ax_ = metric.plot()
_images/spearman_corr_coef-1.png
>>> from torch import randn
>>> # Example plotting multiple values
>>> from torchmetrics.regression import SpearmanCorrCoef
>>> metric = SpearmanCorrCoef()
>>> values = []
>>> for _ in range(10):
...     values.append(metric(randn(10,), randn(10,)))
>>> fig, ax = metric.plot(values)
_images/spearman_corr_coef-2.png

Functional Interface

torchmetrics.functional.spearman_corrcoef(preds, target)[source]

Compute Spearman's rank correlation coefficient.

\[r_s = \frac{cov(rg_x, rg_y)}{\sigma_{rg_x} \sigma_{rg_y}}\]

where \(rg_x\) and \(rg_y\) are the ranks associated with the variables \(x\) and \(y\). Spearman's correlation coefficient corresponds to the standard Pearson correlation coefficient calculated on the rank variables.

Parameters:
  • preds (Tensor) – estimated scores

  • target (Tensor) – ground truth scores

Return type:

Tensor

Example (single output regression):
>>> from torchmetrics.functional.regression import spearman_corrcoef
>>> target = torch.tensor([3, -0.5, 2, 7])
>>> preds = torch.tensor([2.5, 0.0, 2, 8])
>>> spearman_corrcoef(preds, target)
tensor(1.0000)
Example (multi output regression):
>>> from torchmetrics.functional.regression import spearman_corrcoef
>>> target = torch.tensor([[3, -0.5], [2, 7]])
>>> preds = torch.tensor([[2.5, 0.0], [2, 8]])
>>> spearman_corrcoef(preds, target)
tensor([1.0000, 1.0000])
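
As stated above, Spearman's coefficient is the Pearson coefficient on rank-transformed data; the sketch below illustrates this for data without ties (the double argsort rank trick used here is only valid when all values are distinct):

import torch
from torchmetrics.functional.regression import pearson_corrcoef, spearman_corrcoef

target = torch.tensor([3.0, -0.5, 2.0, 7.0])
preds = torch.tensor([2.5, 0.0, 2.0, 8.0])

def ranks(t: torch.Tensor) -> torch.Tensor:
    # 0-based ranks; correct only when there are no ties
    return t.argsort().argsort().float()

print(spearman_corrcoef(preds, target))               # tensor(1.0000)
print(pearson_corrcoef(ranks(preds), ranks(target)))  # same value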

Symmetric Mean Absolute Percentage Error (SMAPE)

Module Interface

class torchmetrics.SymmetricMeanAbsolutePercentageError(**kwargs)[source]

Compute symmetric mean absolute percentage error (SMAPE).

\[\text{SMAPE} = \frac{2}{n}\sum_1^n\frac{| y_i - \hat{y_i} |}{\max(| y_i | + | \hat{y_i} |, \epsilon)}\]

Where \(y\) is a tensor of target values, and \(\hat{y}\) is a tensor of predictions.

As input to forward and update the metric accepts the following input:

  • preds (Tensor): Predictions from model

  • target (Tensor): Ground truth values

As output of forward and compute the metric returns the following output:

  • smape (Tensor): A tensor with the non-negative floating point SMAPE value

Parameters:

kwargs (Any) – Additional keyword arguments, see Advanced metric settings for more info.

Example

>>> from torch import tensor
>>> from torchmetrics.regression import SymmetricMeanAbsolutePercentageError
>>> target = tensor([1, 10, 1e6])
>>> preds = tensor([0.9, 15, 1.2e6])
>>> smape = SymmetricMeanAbsolutePercentageError()
>>> smape(preds, target)
tensor(0.2290)
plot(val=None, ax=None)[source]

Plot a single or multiple values from the metric.

Parameters:
  • val (Union[Tensor, Sequence[Tensor], None]) – Either a single result from calling metric.forward or metric.compute or a list of these results. If no value is provided, will automatically call metric.compute and plot that result.

  • ax (Optional[Axes]) – A matplotlib axis object. If provided, the plot will be added to that axis

Return type:

Tuple[Figure, Union[Axes, ndarray]]

Returns:

Figure and Axes object

Raises:

ModuleNotFoundError – If matplotlib is not installed

>>> from torch import randn
>>> # Example plotting a single value
>>> from torchmetrics.regression import SymmetricMeanAbsolutePercentageError
>>> metric = SymmetricMeanAbsolutePercentageError()
>>> metric.update(randn(10,), randn(10,))
>>> fig_, ax_ = metric.plot()
_images/symmetric_mean_absolute_percentage_error-1.png
>>> from torch import randn
>>> # Example plotting multiple values
>>> from torchmetrics.regression import SymmetricMeanAbsolutePercentageError
>>> metric = SymmetricMeanAbsolutePercentageError()
>>> values = []
>>> for _ in range(10):
...     values.append(metric(randn(10,), randn(10,)))
>>> fig, ax = metric.plot(values)
_images/symmetric_mean_absolute_percentage_error-2.png

Functional Interface

torchmetrics.functional.symmetric_mean_absolute_percentage_error(preds, target)[source]

Compute symmetric mean absolute percentage error (SMAPE).

\[\text{SMAPE} = \frac{2}{n}\sum_1^n\frac{| y_i - \hat{y_i} |}{max(| y_i | + | \hat{y_i} |, \epsilon)}\]

Where \(y\) is a tensor of target values, and \(\hat{y}\) is a tensor of predictions.

Parameters:
  • preds (Tensor) – estimated labels

  • target (Tensor) – ground truth labels

Return type:

Tensor

Returns:

Tensor with SMAPE.

Example

>>> from torchmetrics.functional.regression import symmetric_mean_absolute_percentage_error
>>> target = torch.tensor([1, 10, 1e6])
>>> preds = torch.tensor([0.9, 15, 1.2e6])
>>> symmetric_mean_absolute_percentage_error(preds, target)
tensor(0.2290)
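
A direct evaluation of the SMAPE formula above reproduces the example value; the epsilon used here is only a guard against zero denominators and is an illustrative assumption, the library may use a slightly different internal value:

import torch
from torchmetrics.functional.regression import symmetric_mean_absolute_percentage_error

target = torch.tensor([1.0, 10.0, 1e6])
preds = torch.tensor([0.9, 15.0, 1.2e6])

eps = 1e-8  # guard against division by zero; illustrative choice
manual = 2 * ((preds - target).abs() / (target.abs() + preds.abs()).clamp(min=eps)).mean()

print(symmetric_mean_absolute_percentage_error(preds, target))  # tensor(0.2290)
print(manual)                                                   # same value here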

Tweedie Deviance Score

Module Interface

class torchmetrics.TweedieDevianceScore(power=0.0, **kwargs)[source]

Compute the Tweedie Deviance Score.

\[\begin{split}deviance\_score(\hat{y},y) = \begin{cases} (\hat{y} - y)^2, & \text{for }p=0\\ 2 * (y * log(\frac{y}{\hat{y}}) + \hat{y} - y), & \text{for }p=1\\ 2 * (log(\frac{\hat{y}}{y}) + \frac{y}{\hat{y}} - 1), & \text{for }p=2\\ 2 * (\frac{(max(y,0))^{2 - p}}{(1 - p)(2 - p)} - \frac{y(\hat{y})^{1 - p}}{1 - p} + \frac{( \hat{y})^{2 - p}}{2 - p}), & \text{otherwise} \end{cases}\end{split}\]

where \(y\) is a tensor of targets values, \(\hat{y}\) is a tensor of predictions, and \(p\) is the power.

As input to forward and update the metric accepts the following input:

  • preds (Tensor): Predicted float tensor with shape (N,...)

  • target (Tensor): Ground truth float tensor with shape (N,...)

As output of forward and compute the metric returns the following output:

  • deviance_score (Tensor): A tensor with the deviance score

Parameters:
  • power (float) –

    • power < 0 : Extreme stable distribution. (Requires: preds > 0.)

    • power = 0 : Normal distribution. (Requires: targets and preds can be any real numbers.)

    • power = 1 : Poisson distribution. (Requires: targets >= 0 and preds > 0.)

    • 1 < p < 2 : Compound Poisson distribution. (Requires: targets >= 0 and preds > 0.)

    • power = 2 : Gamma distribution. (Requires: targets > 0 and preds > 0.)

    • power = 3 : Inverse Gaussian distribution. (Requires: targets > 0 and preds > 0.)

    • otherwise : Positive stable distribution. (Requires: targets > 0 and preds > 0.)

  • kwargs (Any) – Additional keyword arguments, see Advanced metric settings for more info.

Example

>>> from torchmetrics.regression import TweedieDevianceScore
>>> targets = torch.tensor([1.0, 2.0, 3.0, 4.0])
>>> preds = torch.tensor([4.0, 3.0, 2.0, 1.0])
>>> deviance_score = TweedieDevianceScore(power=2)
>>> deviance_score(preds, targets)
tensor(1.2083)
plot(val=None, ax=None)[source]

Plot a single or multiple values from the metric.

Parameters:
  • val (Union[Tensor, Sequence[Tensor], None]) – Either a single result from calling metric.forward or metric.compute or a list of these results. If no value is provided, will automatically call metric.compute and plot that result.

  • ax (Optional[Axes]) – A matplotlib axis object. If provided, the plot will be added to that axis

Return type:

Tuple[Figure, Union[Axes, ndarray]]

Returns:

Figure and Axes object

Raises:

ModuleNotFoundError – If matplotlib is not installed

>>> from torch import randn
>>> # Example plotting a single value
>>> from torchmetrics.regression import TweedieDevianceScore
>>> metric = TweedieDevianceScore()
>>> metric.update(randn(10,), randn(10,))
>>> fig_, ax_ = metric.plot()
_images/tweedie_deviance_score-1.png
>>> from torch import randn
>>> # Example plotting multiple values
>>> from torchmetrics.regression import TweedieDevianceScore
>>> metric = TweedieDevianceScore()
>>> values = []
>>> for _ in range(10):
...     values.append(metric(randn(10,), randn(10,)))
>>> fig, ax = metric.plot(values)
_images/tweedie_deviance_score-2.png

Functional Interface

torchmetrics.functional.tweedie_deviance_score(preds, targets, power=0.0)[source]

Compute the Tweedie Deviance Score.

\[\begin{split}deviance\_score(\hat{y},y) = \begin{cases} (\hat{y} - y)^2, & \text{for }p=0\\ 2 * (y * log(\frac{y}{\hat{y}}) + \hat{y} - y), & \text{for }p=1\\ 2 * (log(\frac{\hat{y}}{y}) + \frac{y}{\hat{y}} - 1), & \text{for }p=2\\ 2 * (\frac{(max(y,0))^{2 - p}}{(1 - p)(2 - p)} - \frac{y(\hat{y})^{1 - p}}{1 - p} + \frac{( \hat{y})^{2 - p}}{2 - p}), & \text{otherwise} \end{cases}\end{split}\]

where \(y\) is a tensor of targets values, \(\hat{y}\) is a tensor of predictions, and \(p\) is the power.

Parameters:
  • preds (Tensor) – Predicted tensor with shape (N,...)

  • targets (Tensor) – Ground truth tensor with shape (N,...)

  • power (float) –

    • power < 0 : Extreme stable distribution. (Requires: preds > 0.)

    • power = 0 : Normal distribution. (Requires: targets and preds can be any real numbers.)

    • power = 1 : Poisson distribution. (Requires: targets >= 0 and preds > 0.)

    • 1 < p < 2 : Compound Poisson distribution. (Requires: targets >= 0 and preds > 0.)

    • power = 2 : Gamma distribution. (Requires: targets > 0 and preds > 0.)

    • power = 3 : Inverse Gaussian distribution. (Requires: targets > 0 and preds > 0.)

    • otherwise : Positive stable distribution. (Requires: targets > 0 and preds > 0.)

Return type:

Tensor

Example

>>> from torchmetrics.functional.regression import tweedie_deviance_score
>>> targets = torch.tensor([1.0, 2.0, 3.0, 4.0])
>>> preds = torch.tensor([4.0, 3.0, 2.0, 1.0])
>>> tweedie_deviance_score(preds, targets, power=2)
tensor(1.2083)
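
For power=0 the deviance above reduces to the squared error, so (assuming the score is averaged over observations, consistent with the class example) the metric coincides with the mean squared error; an illustrative check:

import torch
from torchmetrics.functional.regression import tweedie_deviance_score

targets = torch.tensor([1.0, 2.0, 3.0, 4.0])
preds = torch.tensor([4.0, 3.0, 2.0, 1.0])

# power=0 branch of the formula: (y_hat - y)^2, averaged over the observations
print(tweedie_deviance_score(preds, targets, power=0))
print(torch.mean((preds - targets) ** 2))  # expected to give the same value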

Weighted MAPE

Module Interface

class torchmetrics.WeightedMeanAbsolutePercentageError(**kwargs)[source]

Compute weighted mean absolute percentage error (WMAPE).

The output of the WMAPE metric is a non-negative floating point value, where the optimal value is 0. It is computed as:

\[\text{WMAPE} = \frac{\sum_{t=1}^n | y_t - \hat{y}_t | }{\sum_{t=1}^n |y_t| }\]

Where \(y\) is a tensor of target values, and \(\hat{y}\) is a tensor of predictions.

As input to forward and update the metric accepts the following input:

  • preds (Tensor): Predictions from model

  • target (Tensor): Ground truth float tensor with shape (N,d)

As output of forward and compute the metric returns the following output:

  • wmape (Tensor): A tensor with the non-negative floating point WMAPE value

Parameters:

kwargs (Any) – Additional keyword arguments, see Advanced metric settings for more info.

Example

>>> import torch
>>> from torchmetrics.regression import WeightedMeanAbsolutePercentageError
>>> _ = torch.manual_seed(42)
>>> preds = torch.randn(20,)
>>> target = torch.randn(20,)
>>> wmape = WeightedMeanAbsolutePercentageError()
>>> wmape(preds, target)
tensor(1.3967)
plot(val=None, ax=None)[source]

Plot a single or multiple values from the metric.

Parameters:
  • val (Union[Tensor, Sequence[Tensor], None]) – Either a single result from calling metric.forward or metric.compute or a list of these results. If no value is provided, will automatically call metric.compute and plot that result.

  • ax (Optional[Axes]) – A matplotlib axis object. If provided, the plot will be added to that axis

Return type:

Tuple[Figure, Union[Axes, ndarray]]

Returns:

Figure and Axes object

Raises:

ModuleNotFoundError – If matplotlib is not installed

>>> from torch import randn
>>> # Example plotting a single value
>>> from torchmetrics.regression import WeightedMeanAbsolutePercentageError
>>> metric = WeightedMeanAbsolutePercentageError()
>>> metric.update(randn(10,), randn(10,))
>>> fig_, ax_ = metric.plot()
_images/weighted_mean_absolute_percentage_error-1.png
>>> from torch import randn
>>> # Example plotting multiple values
>>> from torchmetrics.regression import WeightedMeanAbsolutePercentageError
>>> metric = WeightedMeanAbsolutePercentageError()
>>> values = []
>>> for _ in range(10):
...     values.append(metric(randn(10,), randn(10,)))
>>> fig, ax = metric.plot(values)
_images/weighted_mean_absolute_percentage_error-2.png

Functional Interface

torchmetrics.functional.weighted_mean_absolute_percentage_error(preds, target)[source]

Compute weighted mean absolute percentage error (WMAPE).

The output of the WMAPE metric is a non-negative floating point value, where the optimal value is 0. It is computed as:

\[\text{WMAPE} = \frac{\sum_{t=1}^n | y_t - \hat{y}_t | }{\sum_{t=1}^n |y_t| }\]

Where \(y\) is a tensor of target values, and \(\hat{y}\) is a tensor of predictions.

Parameters:
  • preds (Tensor) – estimated labels

  • target (Tensor) – ground truth labels

Return type:

Tensor

Returns:

Tensor with WMAPE.

Example

>>> import torch
>>> from torchmetrics.functional.regression import weighted_mean_absolute_percentage_error
>>> _ = torch.manual_seed(42)
>>> preds = torch.randn(20,)
>>> target = torch.randn(20,)
>>> weighted_mean_absolute_percentage_error(preds, target)
tensor(1.3967)
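
The value can be reproduced directly from the formula above; a short illustrative check:

import torch
from torchmetrics.functional.regression import weighted_mean_absolute_percentage_error

torch.manual_seed(42)
preds = torch.randn(20)
target = torch.randn(20)

# WMAPE = sum(|y - y_hat|) / sum(|y|)
manual = (target - preds).abs().sum() / target.abs().sum()

print(weighted_mean_absolute_percentage_error(preds, target))  # tensor(1.3967)
print(manual)                                                  # same value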

Retrieval Fall-Out

Module Interface

class torchmetrics.retrieval.RetrievalFallOut(empty_target_action='pos', ignore_index=None, top_k=None, **kwargs)[source]

Compute Fall-out.

Works with binary target data. Accepts float predictions from a model output.

As input to forward and update the metric accepts the following input:

  • preds (Tensor): A float tensor of shape (N, ...)

  • target (Tensor): A long or bool tensor of shape (N, ...)

  • indexes (Tensor): A long tensor of shape (N, ...) which indicate to which query a prediction belongs

As output of forward and compute the metric returns the following output:

  • fo@k (Tensor): A tensor with the computed metric

All indexes, preds and target must have the same dimension and will be flattened at the beginning, so that, for example, a tensor of shape (N, M) is treated as (N * M, ). Predictions will first be grouped by indexes and the metric will then be computed as the mean of the per-query values.

Parameters:
  • empty_target_action (str) –

    Specify what to do with queries that do not have at least a negative target. Choose from:

    • 'neg': those queries count as 0.0

    • 'pos': those queries count as 1.0 (default)

    • 'skip': skip those queries; if all queries are skipped, 0.0 is returned

    • 'error': raise a ValueError

  • ignore_index (Optional[int]) – Ignore predictions where the target is equal to this number.

  • top_k (Optional[int]) – Consider only the top k elements for each query (default: None, which considers them all)

  • kwargs (Any) – Additional keyword arguments, see Advanced metric settings for more info.

Raises:
  • ValueError – If empty_target_action is not one of error, skip, neg or pos.

  • ValueError – If ignore_index is neither None nor an integer.

  • ValueError – If top_k is neither None nor an integer greater than 0.

Example

>>> from torch import tensor
>>> from torchmetrics.retrieval import RetrievalFallOut
>>> indexes = tensor([0, 0, 0, 1, 1, 1, 1])
>>> preds = tensor([0.2, 0.3, 0.5, 0.1, 0.3, 0.5, 0.2])
>>> target = tensor([False, False, True, False, True, False, True])
>>> fo = RetrievalFallOut(top_k=2)
>>> fo(preds, target, indexes=indexes)
tensor(0.5000)
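
To make the role of indexes concrete, the sketch below (illustrative only, using the functional interface) reproduces the module result by splitting the flattened tensors into queries and averaging the per-query fall-out:

import torch
from torchmetrics.functional.retrieval import retrieval_fall_out

indexes = torch.tensor([0, 0, 0, 1, 1, 1, 1])
preds = torch.tensor([0.2, 0.3, 0.5, 0.1, 0.3, 0.5, 0.2])
target = torch.tensor([False, False, True, False, True, False, True])

# group predictions by query id and average the per-query metric values
per_query = [
    retrieval_fall_out(preds[indexes == q], target[indexes == q], top_k=2)
    for q in indexes.unique()
]
print(torch.stack(per_query).mean())  # tensor(0.5000), matching the module example
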
plot(val=None, ax=None)[source]

Plot a single or multiple values from the metric.

Parameters:
  • val (Union[Tensor, Sequence[Tensor], None]) – Either a single result from calling metric.forward or metric.compute or a list of these results. If no value is provided, will automatically call metric.compute and plot that result.

  • ax (Optional[Axes]) – A matplotlib axis object. If provided, the plot will be added to that axis

Return type:

Tuple[Figure, Union[Axes, ndarray]]

Returns:

Figure and Axes object

Raises:

ModuleNotFoundError – If matplotlib is not installed

>>> import torch
>>> from torchmetrics.retrieval import RetrievalFallOut
>>> # Example plotting a single value
>>> metric = RetrievalFallOut()
>>> metric.update(torch.rand(10,), torch.randint(2, (10,)), indexes=torch.randint(2,(10,)))
>>> fig_, ax_ = metric.plot()
_images/fall_out-1.png
>>> import torch
>>> from torchmetrics.retrieval import RetrievalFallOut
>>> # Example plotting multiple values
>>> metric = RetrievalFallOut()
>>> values = []
>>> for _ in range(10):
...     values.append(metric(torch.rand(10,), torch.randint(2, (10,)), indexes=torch.randint(2,(10,))))
>>> fig, ax = metric.plot(values)
_images/fall_out-2.png

Functional Interface

torchmetrics.functional.retrieval.retrieval_fall_out(preds, target, top_k=None)[source]

Compute the Fall-out for information retrieval, as explained in IR Fall-out.

Fall-out is the fraction of non-relevant documents retrieved among all the non-relevant documents.

preds and target should be of the same shape and live on the same device. If no target is True, 0 is returned. target must be either bool or integers and preds must be float, otherwise an error is raised. If you want to measure Fall-out@K, top_k must be a positive integer.

Parameters:
  • preds (Tensor) – estimated probabilities of each document to be relevant.

  • target (Tensor) – ground truth about each document being relevant or not.

  • top_k (Optional[int]) – consider only the top k elements (default: None, which considers them all)

Return type:

Tensor

Returns:

A single-value tensor with the fall-out (at top_k) of the predictions preds w.r.t. the labels target

Raises:

ValueError – If top_k is neither None nor an integer greater than 0

Example

>>> from torch import tensor
>>> from torchmetrics.functional import retrieval_fall_out
>>> preds = tensor([0.2, 0.3, 0.5])
>>> target = tensor([True, False, True])
>>> retrieval_fall_out(preds, target, top_k=2)
tensor(1.)

Retrieval Hit Rate

Module Interface

class torchmetrics.retrieval.RetrievalHitRate(empty_target_action='neg', ignore_index=None, top_k=None, **kwargs)[source]

Compute IR HitRate.

Works with binary target data. Accepts float predictions from a model output.

As input to forward and update the metric accepts the following input:

  • preds (Tensor): A float tensor of shape (N, ...)

  • target (Tensor): A long or bool tensor of shape (N, ...)

  • indexes (Tensor): A long tensor of shape (N, ...) which indicate to which query a prediction belongs

As output of forward and compute the metric returns the following output:

  • hr@k (Tensor): A single-value tensor with the hit rate (at top_k) of the predictions preds w.r.t. the labels target

All indexes, preds and target must have the same dimension and will be flattened at the beginning, so that, for example, a tensor of shape (N, M) is treated as (N * M, ). Predictions will first be grouped by indexes and the metric will then be computed as the mean of the per-query values.

Parameters:
  • empty_target_action (str) –

    Specify what to do with queries that do not have at least a positive target. Choose from:

    • 'neg': those queries count as 0.0 (default)

    • 'pos': those queries count as 1.0

    • 'skip': skip those queries; if all queries are skipped, 0.0 is returned

    • 'error': raise a ValueError

  • ignore_index (Optional[int]) – Ignore predictions where the target is equal to this number.

  • top_k (Optional[int]) – Consider only the top k elements for each query (default: None, which considers them all)

  • kwargs (Any) – Additional keyword arguments, see Advanced metric settings for more info.

Raises:
  • ValueError – If empty_target_action is not one of error, skip, neg or pos.

  • ValueError – If ignore_index is neither None nor an integer.

  • ValueError – If top_k is neither None nor an integer greater than 0.

Example

>>> from torch import tensor
>>> from torchmetrics.retrieval import RetrievalHitRate
>>> indexes = tensor([0, 0, 0, 1, 1, 1, 1])
>>> preds = tensor([0.2, 0.3, 0.5, 0.1, 0.3, 0.5, 0.2])
>>> target = tensor([True, False, False, False, True, False, True])
>>> hr2 = RetrievalHitRate(top_k=2)
>>> hr2(preds, target, indexes=indexes)
tensor(0.5000)
plot(val=None, ax=None)[source]

Plot a single or multiple values from the metric.

Parameters:
  • val (Union[Tensor, Sequence[Tensor], None]) – Either a single result from calling metric.forward or metric.compute or a list of these results. If no value is provided, will automatically call metric.compute and plot that result.

  • ax (Optional[Axes]) – A matplotlib axis object. If provided, the plot will be added to that axis

Return type:

Tuple[Figure, Union[Axes, ndarray]]

Returns:

Figure and Axes object

Raises:

ModuleNotFoundError – If matplotlib is not installed

>>> import torch
>>> from torchmetrics.retrieval import RetrievalHitRate
>>> # Example plotting a single value
>>> metric = RetrievalHitRate()
>>> metric.update(torch.rand(10,), torch.randint(2, (10,)), indexes=torch.randint(2,(10,)))
>>> fig_, ax_ = metric.plot()
_images/hit_rate-1.png
>>> import torch
>>> from torchmetrics.retrieval import RetrievalHitRate
>>> # Example plotting multiple values
>>> metric = RetrievalHitRate()
>>> values = []
>>> for _ in range(10):
...     values.append(metric(torch.rand(10,), torch.randint(2, (10,)), indexes=torch.randint(2,(10,))))
>>> fig, ax = metric.plot(values)
_images/hit_rate-2.png

Functional Interface

torchmetrics.functional.retrieval.retrieval_hit_rate(preds, target, top_k=None)[source]

Compute the hit rate for information retrieval.

The hit rate is 1.0 if there is at least one relevant document among all the top k retrieved documents.

preds and target should be of the same shape and live on the same device. If no target is True, 0 is returned. target must be either bool or integers and preds must be float, otherwise an error is raised. If you want to measure HitRate@K, top_k must be a positive integer.

Parameters:
  • preds (Tensor) – estimated probabilities of each document to be relevant.

  • target (Tensor) – ground truth about each document being relevant or not.

  • top_k (Optional[int]) – consider only the top k elements (default: None, which considers them all)

Return type:

Tensor

Returns:

A single-value tensor with the hit rate (at top_k) of the predictions preds w.r.t. the labels target.

Raises:

ValueError – If top_k is neither None nor an integer greater than 0

Example

>>> from torch import tensor
>>> from torchmetrics.functional.retrieval import retrieval_hit_rate
>>> preds = tensor([0.2, 0.3, 0.5])
>>> target = tensor([True, False, True])
>>> retrieval_hit_rate(preds, target, top_k=2)
tensor(1.)
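
The hit rate can also be read off directly from the top-k ranking; a small illustrative check of the example above:

import torch
from torchmetrics.functional.retrieval import retrieval_hit_rate

preds = torch.tensor([0.2, 0.3, 0.5])
target = torch.tensor([True, False, True])

# hit rate is 1.0 iff at least one relevant document appears among the top k
in_top2 = target[preds.topk(2).indices]
print(in_top2.any().float())                       # tensor(1.)
print(retrieval_hit_rate(preds, target, top_k=2))  # same value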

Retrieval Mean Average Precision (MAP)

Module Interface

class torchmetrics.retrieval.RetrievalMAP(empty_target_action='neg', ignore_index=None, top_k=None, **kwargs)[source]

Compute Mean Average Precision.

Works with binary target data. Accepts float predictions from a model output.

As input to forward and update the metric accepts the following input:

  • preds (Tensor): A float tensor of shape (N, ...)

  • target (Tensor): A long or bool tensor of shape (N, ...)

  • indexes (Tensor): A long tensor of shape (N, ...) which indicate to which query a prediction belongs

As output of forward and compute the metric returns the following output:

  • map@k (Tensor): A single-value tensor with the mean average precision (MAP) of the predictions preds w.r.t. the labels target.

All indexes, preds and target must have the same dimension and will be flattened at the beginning, so that, for example, a tensor of shape (N, M) is treated as (N * M, ). Predictions will first be grouped by indexes and the metric will then be computed as the mean of the per-query values.

Parameters:
  • empty_target_action (str) –

    Specify what to do with queries that do not have at least a positive target. Choose from:

    • 'neg': those queries count as 0.0 (default)

    • 'pos': those queries count as 1.0

    • 'skip': skip those queries; if all queries are skipped, 0.0 is returned

    • 'error': raise a ValueError

  • ignore_index (Optional[int]) – Ignore predictions where the target is equal to this number.

  • top_k (Optional[int]) – Consider only the top k elements for each query (default: None, which considers them all)

  • kwargs (Any) – Additional keyword arguments, see Advanced metric settings for more info.

Raises:
  • ValueError – If empty_target_action is not one of error, skip, neg or pos.

  • ValueError – If ignore_index is neither None nor an integer.

  • ValueError – If top_k is neither None nor an integer greater than 0.

Example

>>> from torch import tensor
>>> from torchmetrics.retrieval import RetrievalMAP
>>> indexes = tensor([0, 0, 0, 1, 1, 1, 1])
>>> preds = tensor([0.2, 0.3, 0.5, 0.1, 0.3, 0.5, 0.2])
>>> target = tensor([False, False, True, False, True, False, True])
>>> rmap = RetrievalMAP()
>>> rmap(preds, target, indexes=indexes)
tensor(0.7917)
plot(val=None, ax=None)[source]

Plot a single or multiple values from the metric.

Parameters:
  • val (Union[Tensor, Sequence[Tensor], None]) – Either a single result from calling metric.forward or metric.compute or a list of these results. If no value is provided, will automatically call metric.compute and plot that result.

  • ax (Optional[Axes]) – A matplotlib axis object. If provided, the plot will be added to that axis

Return type:

Tuple[Figure, Union[Axes, ndarray]]

Returns:

Figure and Axes object

Raises:

ModuleNotFoundError – If matplotlib is not installed

>>> import torch
>>> from torchmetrics.retrieval import RetrievalMAP
>>> # Example plotting a single value
>>> metric = RetrievalMAP()
>>> metric.update(torch.rand(10,), torch.randint(2, (10,)), indexes=torch.randint(2,(10,)))
>>> fig_, ax_ = metric.plot()
_images/map-1.png
>>> import torch
>>> from torchmetrics.retrieval import RetrievalMAP
>>> # Example plotting multiple values
>>> metric = RetrievalMAP()
>>> values = []
>>> for _ in range(10):
...     values.append(metric(torch.rand(10,), torch.randint(2, (10,)), indexes=torch.randint(2,(10,))))
>>> fig, ax = metric.plot(values)
_images/map-2.png

Functional Interface

torchmetrics.functional.retrieval.retrieval_average_precision(preds, target, top_k=None)[source]

Compute average precision (for information retrieval), as explained in IR Average precision.

preds and target should be of the same shape and live on the same device. If no target is True, 0 is returned. target must be either bool or integers and preds must be float, otherwise an error is raised.

Parameters:
  • preds (Tensor) – estimated probabilities of each document to be relevant.

  • target (Tensor) – ground truth about each document being relevant or not.

  • top_k (Optional[int]) – consider only the top k elements (default: None, which considers them all)

Return type:

Tensor

Returns:

a single-value tensor with the average precision (AP) of the predictions preds w.r.t. the labels target.

Raises:

ValueError – If top_k is neither None nor an integer greater than 0.

Example

>>> from torch import tensor
>>> from torchmetrics.functional.retrieval import retrieval_average_precision
>>> preds = tensor([0.2, 0.3, 0.5])
>>> target = tensor([True, False, True])
>>> retrieval_average_precision(preds, target)
tensor(0.8333)
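
Average precision can be computed by hand as the mean of precision@k over the ranks at which relevant documents occur; an illustrative sketch reproducing the example value:

import torch
from torchmetrics.functional.retrieval import retrieval_average_precision

preds = torch.tensor([0.2, 0.3, 0.5])
target = torch.tensor([True, False, True])

# order the relevance labels by decreasing score, then average precision@k
# over the ranks k where a relevant document is found
ordered = target[preds.argsort(descending=True)].float()
ranks = torch.arange(1, len(ordered) + 1, dtype=torch.float)
precision_at_k = ordered.cumsum(dim=0) / ranks

print(precision_at_k[ordered.bool()].mean())       # tensor(0.8333)
print(retrieval_average_precision(preds, target))  # same value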

Retrieval Mean Reciprocal Rank (MRR)

Module Interface

class torchmetrics.retrieval.RetrievalMRR(empty_target_action='neg', ignore_index=None, top_k=None, **kwargs)[source]

Compute Mean Reciprocal Rank.

Works with binary target data. Accepts float predictions from a model output.

As input to forward and update the metric accepts the following input:

  • preds (Tensor): A float tensor of shape (N, ...)

  • target (Tensor): A long or bool tensor of shape (N, ...)

  • indexes (Tensor): A long tensor of shape (N, ...) which indicate to which query a prediction belongs

As output of forward and compute the metric returns the following output:

  • mrr@k (Tensor): A single-value tensor with the mean reciprocal rank (MRR) of the predictions preds w.r.t. the labels target.

All indexes, preds and target must have the same dimension and will be flattened at the beginning, so that, for example, a tensor of shape (N, M) is treated as (N * M, ). Predictions will first be grouped by indexes and the metric will then be computed as the mean of the per-query values.

Parameters:
  • empty_target_action (str) –

    Specify what to do with queries that do not have at least a positive target. Choose from:

    • 'neg': those queries count as 0.0 (default)

    • 'pos': those queries count as 1.0

    • 'skip': skip those queries; if all queries are skipped, 0.0 is returned

    • 'error': raise a ValueError

  • ignore_index (Optional[int]) – Ignore predictions where the target is equal to this number.

  • top_k (Optional[int]) – Consider only the top k elements for each query (default: None, which considers them all)

  • kwargs (Any) – Additional keyword arguments, see Advanced metric settings for more info.

Raises:
  • ValueError – If empty_target_action is not one of error, skip, neg or pos.

  • ValueError – If ignore_index is neither None nor an integer.

  • ValueError – If top_k is neither None nor an integer greater than 0.

Example

>>> from torch import tensor
>>> from torchmetrics.retrieval import RetrievalMRR
>>> indexes = tensor([0, 0, 0, 1, 1, 1, 1])
>>> preds = tensor([0.2, 0.3, 0.5, 0.1, 0.3, 0.5, 0.2])
>>> target = tensor([False, False, True, False, True, False, True])
>>> mrr = RetrievalMRR()
>>> mrr(preds, target, indexes=indexes)
tensor(0.7500)
plot(val=None, ax=None)[source]

Plot a single or multiple values from the metric.

Parameters:
  • val (Union[Tensor, Sequence[Tensor], None]) – Either a single result from calling metric.forward or metric.compute or a list of these results. If no value is provided, will automatically call metric.compute and plot that result.

  • ax (Optional[Axes]) – A matplotlib axis object. If provided, the plot will be added to that axis

Return type:

Tuple[Figure, Union[Axes, ndarray]]

Returns:

Figure and Axes object

Raises:

ModuleNotFoundError – If matplotlib is not installed

>>> import torch
>>> from torchmetrics.retrieval import RetrievalMRR
>>> # Example plotting a single value
>>> metric = RetrievalMRR()
>>> metric.update(torch.rand(10,), torch.randint(2, (10,)), indexes=torch.randint(2,(10,)))
>>> fig_, ax_ = metric.plot()
_images/mrr-1.png
>>> import torch
>>> from torchmetrics.retrieval import RetrievalMRR
>>> # Example plotting multiple values
>>> metric = RetrievalMRR()
>>> values = []
>>> for _ in range(10):
...     values.append(metric(torch.rand(10,), torch.randint(2, (10,)), indexes=torch.randint(2,(10,))))
>>> fig, ax = metric.plot(values)
_images/mrr-2.png

Functional Interface

torchmetrics.functional.retrieval.retrieval_reciprocal_rank(preds, target, top_k=None)[source]

Compute reciprocal rank (for information retrieval). See Mean Reciprocal Rank.

preds and target should be of the same shape and live on the same device. If no target is True, 0 is returned. target must be either bool or integers and preds must be float, otherwise an error is raised.

Parameters:
  • preds (Tensor) – estimated probabilities of each document to be relevant.

  • target (Tensor) – ground truth about each document being relevant or not.

  • top_k (Optional[int]) – consider only the top k elements (default: None, which considers them all)

Return type:

Tensor

Returns:

a single-value tensor with the reciprocal rank (RR) of the predictions preds wrt the labels target.

Raises:

ValueError – If top_k is neither None nor an integer greater than 0.

Example

>>> from torchmetrics.functional.retrieval import retrieval_reciprocal_rank
>>> preds = torch.tensor([0.2, 0.3, 0.5])
>>> target = torch.tensor([False, True, False])
>>> retrieval_reciprocal_rank(preds, target)
tensor(0.5000)
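
Reciprocal rank is simply one over the rank of the first relevant document; an illustrative check of the example above:

import torch
from torchmetrics.functional.retrieval import retrieval_reciprocal_rank

preds = torch.tensor([0.2, 0.3, 0.5])
target = torch.tensor([False, True, False])

# the relevant document has the 2nd highest score, so RR = 1 / 2
ordered = target[preds.argsort(descending=True)]
first_relevant_rank = ordered.nonzero().min() + 1
print(1.0 / first_relevant_rank)                 # tensor(0.5000)
print(retrieval_reciprocal_rank(preds, target))  # same value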

Retrieval Normalized DCG

Module Interface

class torchmetrics.retrieval.RetrievalNormalizedDCG(empty_target_action='neg', ignore_index=None, top_k=None, **kwargs)[source]

Compute Normalized Discounted Cumulative Gain.

Works with binary or positive integer target data. Accepts float predictions from a model output.

As input to forward and update the metric accepts the following input:

  • preds (Tensor): A float tensor of shape (N, ...)

  • target (Tensor): A long or bool tensor of shape (N, ...)

  • indexes (Tensor): A long tensor of shape (N, ...) which indicate to which query a prediction belongs

As output of forward and compute the metric returns the following output:

  • ndcg@k (Tensor): A single-value tensor with the nDCG of the predictions preds w.r.t. the labels target

All indexes, preds and target must have the same dimension and will be flattened at the beginning, so that, for example, a tensor of shape (N, M) is treated as (N * M, ). Predictions will first be grouped by indexes and the metric will then be computed as the mean of the per-query values.

Parameters:
  • empty_target_action (str) –

    Specify what to do with queries that do not have at least a positive target. Choose from:

    • 'neg': those queries count as 0.0 (default)

    • 'pos': those queries count as 1.0

    • 'skip': skip those queries; if all queries are skipped, 0.0 is returned

    • 'error': raise a ValueError

  • ignore_index (Optional[int]) – Ignore predictions where the target is equal to this number.

  • top_k (Optional[int]) – Consider only the top k elements for each query (default: None, which considers them all)

  • kwargs (Any) – Additional keyword arguments, see Advanced metric settings for more info.

Raises:
  • ValueError – If empty_target_action is not one of error, skip, neg or pos.

  • ValueError – If ignore_index is neither None nor an integer.

  • ValueError – If top_k is neither None nor an integer greater than 0.

Example

>>> from torch import tensor
>>> from torchmetrics.retrieval import RetrievalNormalizedDCG
>>> indexes = tensor([0, 0, 0, 1, 1, 1, 1])
>>> preds = tensor([0.2, 0.3, 0.5, 0.1, 0.3, 0.5, 0.2])
>>> target = tensor([False, False, True, False, True, False, True])
>>> ndcg = RetrievalNormalizedDCG()
>>> ndcg(preds, target, indexes=indexes)
tensor(0.8467)
plot(val=None, ax=None)[source]

Plot a single or multiple values from the metric.

Parameters:
  • val (Union[Tensor, Sequence[Tensor], None]) – Either a single result from calling metric.forward or metric.compute or a list of these results. If no value is provided, will automatically call metric.compute and plot that result.

  • ax (Optional[Axes]) – A matplotlib axis object. If provided, the plot will be added to that axis

Return type:

Tuple[Figure, Union[Axes, ndarray]]

Returns:

Figure and Axes object

Raises:

ModuleNotFoundError – If matplotlib is not installed

>>> import torch
>>> from torchmetrics.retrieval import RetrievalNormalizedDCG
>>> # Example plotting a single value
>>> metric = RetrievalNormalizedDCG()
>>> metric.update(torch.rand(10,), torch.randint(2, (10,)), indexes=torch.randint(2,(10,)))
>>> fig_, ax_ = metric.plot()
_images/normalized_dcg-1.png
>>> import torch
>>> from torchmetrics.retrieval import RetrievalNormalizedDCG
>>> # Example plotting multiple values
>>> metric = RetrievalNormalizedDCG()
>>> values = []
>>> for _ in range(10):
...     values.append(metric(torch.rand(10,), torch.randint(2, (10,)), indexes=torch.randint(2,(10,))))
>>> fig, ax = metric.plot(values)
_images/normalized_dcg-2.png

Functional Interface

torchmetrics.functional.retrieval.retrieval_normalized_dcg(preds, target, top_k=None)[source]

Compute Normalized Discounted Cumulative Gain (for information retrieval).

preds and target should be of the same shape and live on the same device. target must be either bool or integers and preds must be float, otherwise an error is raised.

Parameters:
  • preds (Tensor) – estimated probabilities of each document to be relevant.

  • target (Tensor) – ground truth about each document relevance.

  • top_k (Optional[int]) – consider only the top k elements (default: None, which considers them all)

Return type:

Tensor

Returns:

A single-value tensor with the nDCG of the predictions preds w.r.t. the labels target.

Raises:

ValueError – If top_k is neither None nor an integer greater than 0

Example

>>> from torchmetrics.functional.retrieval import retrieval_normalized_dcg
>>> preds = torch.tensor([.1, .2, .3, 4, 70])
>>> target = torch.tensor([10, 0, 0, 1, 5])
>>> retrieval_normalized_dcg(preds, target)
tensor(0.6957)
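
The nDCG formula is not spelled out above; the sketch below reconstructs it under the assumption of a linear gain and a log2 positional discount, which reproduces the example value (this is an illustrative re-implementation, not the library code):

import torch
from torchmetrics.functional.retrieval import retrieval_normalized_dcg

preds = torch.tensor([0.1, 0.2, 0.3, 4.0, 70.0])
target = torch.tensor([10, 0, 0, 1, 5])

def dcg(relevance: torch.Tensor) -> torch.Tensor:
    # linear gain, discounted by log2(position + 1) with 1-based positions
    positions = torch.arange(len(relevance), dtype=torch.float)
    return (relevance / torch.log2(positions + 2.0)).sum()

ordered = target[preds.argsort(descending=True)].float()
ideal = target.sort(descending=True).values.float()

print(dcg(ordered) / dcg(ideal))                # ~0.6957
print(retrieval_normalized_dcg(preds, target))  # tensor(0.6957)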

Retrieval Precision

Module Interface

class torchmetrics.retrieval.RetrievalPrecision(empty_target_action='neg', ignore_index=None, top_k=None, adaptive_k=False, **kwargs)[source]

Compute IR Precision.

Works with binary target data. Accepts float predictions from a model output.

As input to forward and update the metric accepts the following input:

  • preds (Tensor): A float tensor of shape (N, ...)

  • target (Tensor): A long or bool tensor of shape (N, ...)

  • indexes (Tensor): A long tensor of shape (N, ...) which indicate to which query a prediction belongs

As output of forward and compute the metric returns the following output:

  • p@k (Tensor): A single-value tensor with the precision (at top_k) of the predictions preds w.r.t. the labels target

All indexes, preds and target must have the same dimension and will be flattened at the beginning, so that, for example, a tensor of shape (N, M) is treated as (N * M, ). Predictions will first be grouped by indexes and the metric will then be computed as the mean of the per-query values.

Parameters:
  • empty_target_action (str) –

    Specify what to do with queries that do not have at least a positive target. Choose from:

    • 'neg': those queries count as 0.0 (default)

    • 'pos': those queries count as 1.0

    • 'skip': skip those queries; if all queries are skipped, 0.0 is returned

    • 'error': raise a ValueError

  • ignore_index (Optional[int]) – Ignore predictions where the target is equal to this number.

  • top_k (Optional[int]) – Consider only the top k elements for each query (default: None, which considers them all)

  • adaptive_k (bool) – Adjust top_k to min(k, number of documents) for each query

  • kwargs (Any) – Additional keyword arguments, see Advanced metric settings for more info.

Raises:
  • ValueError – If empty_target_action is not one of error, skip, neg or pos.

  • ValueError – If ignore_index is neither None nor an integer.

  • ValueError – If top_k is neither None nor an integer greater than 0.

  • ValueError – If adaptive_k is not a boolean.

Example

>>> from torch import tensor
>>> from torchmetrics.retrieval import RetrievalPrecision
>>> indexes = tensor([0, 0, 0, 1, 1, 1, 1])
>>> preds = tensor([0.2, 0.3, 0.5, 0.1, 0.3, 0.5, 0.2])
>>> target = tensor([False, False, True, False, True, False, True])
>>> p2 = RetrievalPrecision(top_k=2)
>>> p2(preds, target, indexes=indexes)
tensor(0.5000)
plot(val=None, ax=None)[source]

Plot a single or multiple values from the metric.

Parameters:
  • val (Union[Tensor, Sequence[Tensor], None]) – Either a single result from calling metric.forward or metric.compute or a list of these results. If no value is provided, will automatically call metric.compute and plot that result.

  • ax (Optional[Axes]) – A matplotlib axis object. If provided, the plot will be added to that axis

Return type:

Tuple[Figure, Union[Axes, ndarray]]

Returns:

Figure and Axes object

Raises:

ModuleNotFoundError – If matplotlib is not installed

>>> import torch
>>> from torchmetrics.retrieval import RetrievalPrecision
>>> # Example plotting a single value
>>> metric = RetrievalPrecision()
>>> metric.update(torch.rand(10,), torch.randint(2, (10,)), indexes=torch.randint(2,(10,)))
>>> fig_, ax_ = metric.plot()
_images/precision-11.png
>>> import torch
>>> from torchmetrics.retrieval import RetrievalPrecision
>>> # Example plotting multiple values
>>> metric = RetrievalPrecision()
>>> values = []
>>> for _ in range(10):
...     values.append(metric(torch.rand(10,), torch.randint(2, (10,)), indexes=torch.randint(2,(10,))))
>>> fig, ax = metric.plot(values)
_images/precision-21.png

Functional Interface

torchmetrics.functional.retrieval.retrieval_precision(preds, target, top_k=None, adaptive_k=False)[source]

Compute the precision metric for information retrieval.

Precision is the fraction of relevant documents among all the retrieved documents.

preds and target should be of the same shape and live on the same device. If no target is True, 0 is returned. target must be either bool or integers and preds must be float, otherwise an error is raised. If you want to measure Precision@K, top_k must be a positive integer.

Parameters:
  • preds (Tensor) – estimated probabilities of each document to be relevant.

  • target (Tensor) – ground truth about each document being relevant or not.

  • top_k (Optional[int]) – consider only the top k elements (default: None, which considers them all)

  • adaptive_k (bool) – adjust k to min(k, number of documents) for each query

Return type:

Tensor

Returns:

A single-value tensor with the precision (at top_k) of the predictions preds w.r.t. the labels target.

Raises:
  • ValueError – If top_k is neither None nor an integer greater than 0.

  • ValueError – If adaptive_k is not a boolean.

Example

>>> from torch import tensor
>>> from torchmetrics.functional.retrieval import retrieval_precision
>>> preds = tensor([0.2, 0.3, 0.5])
>>> target = tensor([True, False, True])
>>> retrieval_precision(preds, target, top_k=2)
tensor(0.5000)

Precision Recall Curve

Module Interface

class torchmetrics.retrieval.RetrievalPrecisionRecallCurve(max_k=None, adaptive_k=False, empty_target_action='neg', ignore_index=None, **kwargs)[source]

Compute precision-recall pairs for different k (from 1 to max_k).

In a ranked retrieval context, appropriate sets of retrieved documents are naturally given by the top k retrieved documents. Recall is the fraction of relevant documents retrieved among all the relevant documents. Precision is the fraction of relevant documents among all the retrieved documents. For each such set, precision and recall values can be plotted to give a recall-precision curve.

As input to forward and update the metric accepts the following input:

  • preds (Tensor): A float tensor of shape (N, ...)

  • target (Tensor): A long or bool tensor of shape (N, ...)

  • indexes (Tensor): A long tensor of shape (N, ...) which indicate to which query a prediction belongs

As output of forward and compute the metric returns the following output:

  • precisions (Tensor): A tensor with the fraction of relevant documents among all the retrieved documents.

  • recalls (Tensor): A tensor with the fraction of relevant documents retrieved among all the relevant documents

  • top_k (Tensor): A tensor with k from 1 to max_k

All indexes, preds and target must have the same dimension and will be flattened at the beginning, so that, for example, a tensor of shape (N, M) is treated as (N * M, ). Predictions will first be grouped by indexes and the metric will then be computed as the mean of the per-query values.

Parameters:
  • max_k (Optional[int]) – Calculate recall and precision for all possible top k from 1 to max_k (default: None, which considers all possible top k)

  • adaptive_k (bool) – adjust k to min(k, number of documents) for each query

  • empty_target_action (str) –

    Specify what to do with queries that do not have at least a positive target. Choose from:

    • 'neg': those queries count as 0.0 (default)

    • 'pos': those queries count as 1.0

    • 'skip': skip those queries; if all queries are skipped, 0.0 is returned

    • 'error': raise a ValueError

  • ignore_index (Optional[int]) – Ignore predictions where the target is equal to this number.

  • kwargs (Any) – Additional keyword arguments, see Advanced metric settings for more info.

Raises:
  • ValueError – If empty_target_action is not one of error, skip, neg or pos.

  • ValueError – If ignore_index is neither None nor an integer.

  • ValueError – If max_k is neither None nor an integer greater than 0.

Example

>>> from torch import tensor
>>> from torchmetrics.retrieval import RetrievalPrecisionRecallCurve
>>> indexes = tensor([0, 0, 0, 0, 1, 1, 1])
>>> preds = tensor([0.4, 0.01, 0.5, 0.6, 0.2, 0.3, 0.5])
>>> target = tensor([True, False, False, True, True, False, True])
>>> r = RetrievalPrecisionRecallCurve(max_k=4)
>>> precisions, recalls, top_k = r(preds, target, indexes=indexes)
>>> precisions
tensor([1.0000, 0.5000, 0.6667, 0.5000])
>>> recalls
tensor([0.5000, 0.5000, 1.0000, 1.0000])
>>> top_k
tensor([1, 2, 3, 4])
plot(curve=None, ax=None)[source]

Plot a single or multiple values from the metric.

Parameters:
  • curve (Optional[Tuple[Tensor, Tensor, Tensor]]) – the output of either metric.compute or metric.forward. If no value is provided, will automatically call metric.compute and plot that result.

  • ax (Optional[Axes]) – A matplotlib axis object. If provided, the plot will be added to that axis

Return type:

Tuple[Figure, Union[Axes, ndarray]]

Returns:

Figure and Axes object

Raises:

ModuleNotFoundError – If matplotlib is not installed

>>> import torch
>>> from torchmetrics.retrieval import RetrievalPrecisionRecallCurve
>>> # Example plotting a single value
>>> metric = RetrievalPrecisionRecallCurve()
>>> metric.update(torch.rand(10,), torch.randint(2, (10,)), indexes=torch.randint(2,(10,)))
>>> fig_, ax_ = metric.plot()
_images/precision_recall_curve-11.png

Functional Interface

torchmetrics.functional.retrieval.retrieval_precision_recall_curve(preds, target, max_k=None, adaptive_k=False)[source]

Compute precision-recall pairs for different k (from 1 to max_k).

In a ranked retrieval context, appropriate sets of retrieved documents are naturally given by the top k retrieved documents.

Recall is the fraction of relevant documents retrieved among all the relevant documents. Precision is the fraction of relevant documents among all the retrieved documents.

For each such set, precision and recall values can be plotted to give a recall-precision curve.

preds and target should be of the same shape and live on the same device. If no target is True, 0 is returned. target must be either bool or integers and preds must be float, otherwise an error is raised.

Parameters:
  • preds (Tensor) – estimated probabilities of each document to be relevant.

  • target (Tensor) – ground truth about each document being relevant or not.

  • max_k (Optional[int]) – Calculate recall and precision for all possible top k from 1 to max_k (default: None, which considers all possible top k)

  • adaptive_k (bool) – adjust max_k to min(max_k, number of documents) for each query

Return type:

Tuple[Tensor, Tensor, Tensor]

Returns:

Tensor with the precision values for each k (at top_k) from 1 to max_k, tensor with the recall values for each k (at top_k) from 1 to max_k, and tensor with all possible values of k.

Raises:
  • ValueError – If max_k is neither None nor an integer larger than 0.

  • ValueError – If adaptive_k is not a boolean.

Example

>>> from torch import tensor
>>> from torchmetrics.functional import retrieval_precision_recall_curve
>>> preds = tensor([0.2, 0.3, 0.5])
>>> target = tensor([True, False, True])
>>> precisions, recalls, top_k = retrieval_precision_recall_curve(preds, target, max_k=2)
>>> precisions
tensor([1.0000, 0.5000])
>>> recalls
tensor([0.5000, 0.5000])
>>> top_k
tensor([1, 2])
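
For intuition, the curve above is simply retrieval precision and recall evaluated at every cut-off k. A minimal sketch (assuming a single query and the default adaptive_k=False) that reproduces the precision values by calling retrieval_precision once per k:

import torch
from torchmetrics.functional import retrieval_precision, retrieval_precision_recall_curve

preds = torch.tensor([0.2, 0.3, 0.5])
target = torch.tensor([True, False, True])

precisions, recalls, top_k = retrieval_precision_recall_curve(preds, target, max_k=2)
# evaluate precision@k for every k returned by the curve
manual_precisions = torch.stack(
    [retrieval_precision(preds, target, top_k=int(k)) for k in top_k]
)
# manual_precisions is expected to match `precisions` above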

Retrieval R-Precision

Module Interface

class torchmetrics.retrieval.RetrievalRPrecision(empty_target_action='neg', ignore_index=None, **kwargs)[source]

Compute IR R-Precision.

Works with binary target data. Accepts float predictions from a model output.

As input to forward and update the metric accepts the following input:

  • preds (Tensor): A float tensor of shape (N, ...)

  • target (Tensor): A long or bool tensor of shape (N, ...)

  • indexes (Tensor): A long tensor of shape (N, ...) which indicate to which query a prediction belongs

As output to forward and compute the metric returns the following output:

  • rp (Tensor): A single-value tensor with the r-precision of the predictions preds w.r.t. the labels target.

All indexes, preds and target must have the same dimension and will be flattened at the beginning, so that, for example, a tensor of shape (N, M) is treated as (N * M, ). Predictions will first be grouped by indexes; the metric is then computed for each query and averaged over all queries.

Parameters:
  • empty_target_action (str) –

    Specify what to do with queries that do not have at least one positive target. Choose from:

    • 'neg': those queries count as 0.0 (default)

    • 'pos': those queries count as 1.0

    • 'skip': skip those queries; if all queries are skipped, 0.0 is returned

    • 'error': raise a ValueError

  • ignore_index (Optional[int]) – Ignore predictions where the target is equal to this number.

  • kwargs (Any) – Additional keyword arguments, see Advanced metric settings for more info.

Raises:
  • ValueError – If empty_target_action is not one of error, skip, neg or pos.

  • ValueError – If ignore_index is neither None nor an integer.

Example

>>> from torch import tensor
>>> from torchmetrics.retrieval import RetrievalRPrecision
>>> indexes = tensor([0, 0, 0, 1, 1, 1, 1])
>>> preds = tensor([0.2, 0.3, 0.5, 0.1, 0.3, 0.5, 0.2])
>>> target = tensor([False, False, True, False, True, False, True])
>>> p2 = RetrievalRPrecision()
>>> p2(preds, target, indexes=indexes)
tensor(0.7500)
plot(val=None, ax=None)[source]

Plot a single or multiple values from the metric.

Parameters:
  • val (Union[Tensor, Sequence[Tensor], None]) – Either a single result from calling metric.forward or metric.compute or a list of these results. If no value is provided, will automatically call metric.compute and plot that result.

  • ax (Optional[Axes]) – A matplotlib axis object. If provided, the plot will be added to that axis

Return type:

Tuple[Figure, Union[Axes, ndarray]]

Returns:

Figure and Axes object

Raises:

ModuleNotFoundError – If matplotlib is not installed

>>> import torch
>>> from torchmetrics.retrieval import RetrievalRPrecision
>>> # Example plotting a single value
>>> metric = RetrievalRPrecision()
>>> metric.update(torch.rand(10,), torch.randint(2, (10,)), indexes=torch.randint(2,(10,)))
>>> fig_, ax_ = metric.plot()
_images/r_precision-1.png
>>> import torch
>>> from torchmetrics.retrieval import RetrievalRPrecision
>>> # Example plotting multiple values
>>> metric = RetrievalRPrecision()
>>> values = []
>>> for _ in range(10):
...     values.append(metric(torch.rand(10,), torch.randint(2, (10,)), indexes=torch.randint(2,(10,))))
>>> fig, ax = metric.plot(values)
_images/r_precision-2.png

Functional Interface

torchmetrics.functional.retrieval.retrieval_r_precision(preds, target)[source]

Compute the r-precision metric for information retrieval.

R-Precision is the fraction of relevant documents among all the top k retrieved documents where k is equal to the total number of relevant documents.

preds and target should be of the same shape and live on the same device. If no target is True, 0 is returned. target must be either bool or integer and preds must be float, otherwise an error is raised.

Parameters:
  • preds (Tensor) – estimated probabilities of each document to be relevant.

  • target (Tensor) – ground truth about each document being relevant or not.

Return type:

Tensor

Returns:

A single-value tensor with the r-precision of the predictions preds w.r.t. the labels target.

Example

>>> preds = tensor([0.2, 0.3, 0.5])
>>> target = tensor([True, False, True])
>>> retrieval_r_precision(preds, target)
tensor(0.5000)
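
As a minimal sketch of the definition above (a single query, with k set to the number of relevant documents), the same value can be reproduced by hand:

import torch
from torchmetrics.functional import retrieval_r_precision

preds = torch.tensor([0.2, 0.3, 0.5])
target = torch.tensor([True, False, True])

# k is the total number of relevant documents for this query
k = int(target.sum())
# fraction of relevant documents among the k highest-scored ones
manual_rp = target[preds.topk(k).indices].float().mean()
# manual_rp is expected to agree with retrieval_r_precision(preds, target)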

Retrieval Recall

Module Interface

class torchmetrics.retrieval.RetrievalRecall(empty_target_action='neg', ignore_index=None, top_k=None, **kwargs)[source]

Compute IR Recall.

Works with binary target data. Accepts float predictions from a model output.

As input to forward and update the metric accepts the following input:

  • preds (Tensor): A float tensor of shape (N, ...)

  • target (Tensor): A long or bool tensor of shape (N, ...)

  • indexes (Tensor): A long tensor of shape (N, ...) which indicate to which query a prediction belongs

As output to forward and compute the metric returns the following output:

  • r@k (Tensor): A single-value tensor with the recall (at top_k) of the predictions preds w.r.t. the labels target

All indexes, preds and target must have the same dimension and will be flattened at the beginning, so that, for example, a tensor of shape (N, M) is treated as (N * M, ). Predictions will first be grouped by indexes; the metric is then computed for each query and averaged over all queries.

Parameters:
  • empty_target_action (str) –

    Specify what to do with queries that do not have at least one positive target. Choose from:

    • 'neg': those queries count as 0.0 (default)

    • 'pos': those queries count as 1.0

    • 'skip': skip those queries; if all queries are skipped, 0.0 is returned

    • 'error': raise a ValueError

  • ignore_index (Optional[int]) – Ignore predictions where the target is equal to this number.

  • top_k (Optional[int]) – Consider only the top k elements for each query (default: None, which considers them all)

  • kwargs (Any) – Additional keyword arguments, see Advanced metric settings for more info.

Raises:
  • ValueError – If empty_target_action is not one of error, skip, neg or pos.

  • ValueError – If ignore_index is neither None nor an integer.

  • ValueError – If top_k is neither None nor an integer greater than 0.

Example

>>> from torch import tensor
>>> from torchmetrics.retrieval import RetrievalRecall
>>> indexes = tensor([0, 0, 0, 1, 1, 1, 1])
>>> preds = tensor([0.2, 0.3, 0.5, 0.1, 0.3, 0.5, 0.2])
>>> target = tensor([False, False, True, False, True, False, True])
>>> r2 = RetrievalRecall(top_k=2)
>>> r2(preds, target, indexes=indexes)
tensor(0.7500)
plot(val=None, ax=None)[source]

Plot a single or multiple values from the metric.

Parameters:
  • val (Union[Tensor, Sequence[Tensor], None]) – Either a single result from calling metric.forward or metric.compute or a list of these results. If no value is provided, will automatically call metric.compute and plot that result.

  • ax (Optional[Axes]) – A matplotlib axis object. If provided, the plot will be added to that axis

Return type:

Tuple[Figure, Union[Axes, ndarray]]

Returns:

Figure and Axes object

Raises:

ModuleNotFoundError – If matplotlib is not installed

>>> import torch
>>> from torchmetrics.retrieval import RetrievalRecall
>>> # Example plotting a single value
>>> metric = RetrievalRecall()
>>> metric.update(torch.rand(10,), torch.randint(2, (10,)), indexes=torch.randint(2,(10,)))
>>> fig_, ax_ = metric.plot()
_images/recall-11.png
>>> import torch
>>> from torchmetrics.retrieval import RetrievalRecall
>>> # Example plotting multiple values
>>> metric = RetrievalRecall()
>>> values = []
>>> for _ in range(10):
...     values.append(metric(torch.rand(10,), torch.randint(2, (10,)), indexes=torch.randint(2,(10,))))
>>> fig, ax = metric.plot(values)
_images/recall-21.png

Functional Interface

torchmetrics.functional.retrieval.retrieval_recall(preds, target, top_k=None)[source]

Compute the recall metric for information retrieval.

Recall is the fraction of relevant documents retrieved among all the relevant documents.

preds and target should be of the same shape and live on the same device. If no target is True, 0 is returned. target must be either bool or integers and preds must be float, otherwise an error is raised. If you want to measure Recall@K, top_k must be a positive integer.

Parameters:
  • preds (Tensor) – estimated probabilities of each document to be relevant.

  • target (Tensor) – ground truth about each document being relevant or not.

  • top_k (Optional[int]) – consider only the top k elements (default: None, which considers them all)

Return type:

Tensor

Returns:

A single-value tensor with the recall (at top_k) of the predictions preds w.r.t. the labels target.

Raises:

ValueError – If top_k is neither None nor an integer larger than 0.

Example

>>> from torchmetrics.functional import retrieval_recall
>>> preds = tensor([0.2, 0.3, 0.5])
>>> target = tensor([True, False, True])
>>> retrieval_recall(preds, target, top_k=2)
tensor(0.5000)
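
A minimal sketch of what Recall@k computes for a single query: the number of relevant documents found in the top k results, divided by the total number of relevant documents (using top_k=2 as in the example above):

import torch
from torchmetrics.functional import retrieval_recall

preds = torch.tensor([0.2, 0.3, 0.5])
target = torch.tensor([True, False, True])

k = 2
# relevant documents retrieved in the top-k results, over all relevant documents
manual_recall = target[preds.topk(k).indices].float().sum() / target.sum()
# manual_recall is expected to agree with retrieval_recall(preds, target, top_k=k)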

BERT Score

Module Interface

class torchmetrics.text.bert.BERTScore(model_name_or_path=None, num_layers=None, all_layers=False, model=None, user_tokenizer=None, user_forward_fn=None, verbose=False, idf=False, device=None, max_length=512, batch_size=64, num_threads=0, return_hash=False, lang='en', rescale_with_baseline=False, baseline_path=None, baseline_url=None, **kwargs)[source]

Bert_score Evaluating Text Generation for measuring text similarity.

BERTScore leverages the pre-trained contextual embeddings from BERT and matches words in candidate and reference sentences by cosine similarity. It has been shown to correlate with human judgment on sentence-level and system-level evaluation. Moreover, BERTScore computes precision, recall, and F1 measure, which can be useful for evaluating different language generation tasks. This implementation follows the original implementation from BERT_score.

As input to forward and update the metric accepts the following input:

  • preds (List): An iterable of predicted sentences

  • target (List): An iterable of reference sentences

As output of forward and compute the metric returns the following output:

  • score (Dict): A dictionary containing the keys precision, recall and f1 with corresponding values

Parameters:
  • preds – An iterable of predicted sentences.

  • target – An iterable of target sentences.

  • model_name_or_path (Optional[str]) – A name or a model path used to load transformers pretrained model.

  • num_layers (Optional[int]) – A layer of representation to use.

  • all_layers (bool) – An indication of whether the representation from all model’s layers should be used. If all_layers=True, the argument num_layers is ignored.

  • model (Optional[Module]) – A user’s own model. Must be an instance of torch.nn.Module.

  • user_tokenizer (Optional[Any]) – A user’s own tokenizer used with the user’s own model. This must be an instance with the __call__ method. This method must take an iterable of sentences (List[str]) and must return a python dictionary containing “input_ids” and “attention_mask” represented by Tensor. It is up to the user’s model whether “input_ids” is a Tensor of input ids or embedding vectors. This tokenizer must prepend an equivalent of [CLS] token and append an equivalent of [SEP] token as transformers tokenizer does.

  • user_forward_fn (Optional[Callable[[Module, Dict[str, Tensor]], Tensor]]) – A user’s own forward function used in combination with user_model. This function must take user_model and a python dictionary containing "input_ids" and "attention_mask" represented by Tensor as an input and return the model’s output represented by a single Tensor.

  • verbose (bool) – An indication of whether a progress bar should be displayed during the embeddings’ calculation.

  • idf (bool) – An indication of whether normalization using inverse document frequencies should be used.

  • device (Union[str, device, None]) – A device to be used for calculation.

  • max_length (int) – A maximum length of input sequences. Sequences longer than max_length are to be trimmed.

  • batch_size (int) – A batch size used for model processing.

  • num_threads (int) – A number of threads to use for a dataloader.

  • return_hash (bool) – An indication of whether the corresponding hash_code should be returned.

  • lang (str) – A language of input sentences.

  • rescale_with_baseline (bool) – An indication of whether bertscore should be rescaled with a pre-computed baseline. When a pretrained model from transformers is used, the corresponding baseline is downloaded from the original bert-score package (BERT_score) if available. In other cases, please specify a path to the baseline csv/tsv file, which must follow the formatting of the files from BERT_score.

  • baseline_path (Optional[str]) – A path to the user’s own local csv/tsv file with the baseline scale.

  • baseline_url (Optional[str]) – A url path to the user’s own csv/tsv file with the baseline scale.

  • kwargs (Any) – Additional keyword arguments, see Advanced metric settings for more info.

Example

>>> from pprint import pprint
>>> from torchmetrics.text.bert import BERTScore
>>> preds = ["hello there", "general kenobi"]
>>> target = ["hello there", "master kenobi"]
>>> bertscore = BERTScore()
>>> pprint(bertscore(preds, target))
{'f1': tensor([1.0000, 0.9961]), 'precision': tensor([1.0000, 0.9961]), 'recall': tensor([1.0000, 0.9961])}
plot(val=None, ax=None)[source]

Plot a single or multiple values from the metric.

Parameters:
  • val (Union[Tensor, Sequence[Tensor], None]) – Either a single result from calling metric.forward or metric.compute or a list of these results. If no value is provided, will automatically call metric.compute and plot that result.

  • ax (Optional[Axes]) – A matplotlib axis object. If provided, the plot will be added to that axis

Return type:

Tuple[Figure, Union[Axes, ndarray]]

Returns:

Figure and Axes object

Raises:

ModuleNotFoundError – If matplotlib is not installed

>>> # Example plotting a single value
>>> from torchmetrics.text.bert import BERTScore
>>> preds = ["hello there", "general kenobi"]
>>> target = ["hello there", "master kenobi"]
>>> metric = BERTScore()
>>> metric.update(preds, target)
>>> fig_, ax_ = metric.plot()
_images/bert_score-1.png
>>> # Example plotting multiple values
>>> from torch import tensor
>>> from torchmetrics.text.bert import BERTScore
>>> preds = ["hello there", "general kenobi"]
>>> target = ["hello there", "master kenobi"]
>>> metric = BERTScore()
>>> values = []
>>> for _ in range(10):
...     val = metric(preds, target)
...     val = {k: tensor(v).mean() for k,v in val.items()}  # convert into single value per key
...     values.append(val)
>>> fig_, ax_ = metric.plot(values)
_images/bert_score-2.png

Functional Interface

torchmetrics.functional.text.bert.bert_score(preds, target, model_name_or_path=None, num_layers=None, all_layers=False, model=None, user_tokenizer=None, user_forward_fn=None, verbose=False, idf=False, device=None, max_length=512, batch_size=64, num_threads=0, return_hash=False, lang='en', rescale_with_baseline=False, baseline_path=None, baseline_url=None)[source]

Bert_score Evaluating Text Generation for text similarity matching.

This metric leverages the pre-trained contextual embeddings from BERT and matches words in candidate and reference sentences by cosine similarity. It has been shown to correlate with human judgment on sentence-level and system-level evaluation. Moreover, BERTScore computes precision, recall, and F1 measure, which can be useful for evaluating different language generation tasks.

This implementation follows the original implementation from BERT_score.

Parameters:
  • preds (Union[str, Sequence[str], Dict[str, Tensor]]) – Either an iterable of predicted sentences or a Dict[input_ids, attention_mask].

  • target (Union[str, Sequence[str], Dict[str, Tensor]]) – Either an iterable of target sentences or a Dict[input_ids, attention_mask].

  • model_name_or_path (Optional[str]) – A name or a model path used to load transformers pretrained model.

  • num_layers (Optional[int]) – A layer of representation to use.

  • all_layers (bool) – An indication of whether the representation from all model’s layers should be used. If all_layers = True, the argument num_layers is ignored.

  • model (Optional[Module]) – A user’s own model.

  • user_tokenizer (Optional[Any]) – A user’s own tokenizer used with the user’s own model. This must be an instance with the __call__ method. This method must take an iterable of sentences (List[str]) and must return a python dictionary containing "input_ids" and "attention_mask" represented by Tensor. It is up to the user’s model whether "input_ids" is a Tensor of input ids or embedding vectors. This tokenizer must prepend an equivalent of [CLS] token and append an equivalent of [SEP] token as transformers tokenizer does.

  • user_forward_fn (Optional[Callable[[Module, Dict[str, Tensor]], Tensor]]) – A user’s own forward function used in combination with user_model. This function must take user_model and a python dictionary containing "input_ids" and "attention_mask" represented by Tensor as an input and return the model’s output represented by a single Tensor.

  • verbose (bool) – An indication of whether a progress bar should be displayed during the embeddings’ calculation.

  • idf (bool) – An indication of whether normalization using inverse document frequencies should be used.

  • device (Union[str, device, None]) – A device to be used for calculation.

  • max_length (int) – A maximum length of input sequences. Sequences longer than max_length are to be trimmed.

  • batch_size (int) – A batch size used for model processing.

  • num_threads (int) – A number of threads to use for a dataloader.

  • return_hash (bool) – An indication of whether the corresponding hash_code should be returned.

  • lang (str) – A language of input sentences. It is used when the scores are rescaled with a baseline.

  • rescale_with_baseline (bool) – An indication of whether bertscore should be rescaled with a pre-computed baseline. When a pretrained model from transformers is used, the corresponding baseline is downloaded from the original bert-score package (BERT_score) if available. In other cases, please specify a path to the baseline csv/tsv file, which must follow the formatting of the files from BERT_score.

  • baseline_path (Optional[str]) – A path to the user’s own local csv/tsv file with the baseline scale.

  • baseline_url (Optional[str]) – A url path to the user’s own csv/tsv file with the baseline scale.

Return type:

Dict[str, Union[Tensor, List[float], str]]

Returns:

Python dictionary containing the keys precision, recall and f1 with corresponding values.

Example

>>> from pprint import pprint
>>> from torchmetrics.functional.text.bert import bert_score
>>> preds = ["hello there", "general kenobi"]
>>> target = ["hello there", "master kenobi"]
>>> pprint(bert_score(preds, target))
{'f1': tensor([1.0000, 0.9961]), 'precision': tensor([1.0000, 0.9961]), 'recall': tensor([1.0000, 0.9961])}
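
By default a fairly large checkpoint is downloaded. For quicker experiments, a smaller checkpoint can be selected explicitly via model_name_or_path; the sketch below assumes the "distilbert-base-uncased" checkpoint (an arbitrary choice, not part of the example above) can be downloaded from the HuggingFace hub:

from torchmetrics.functional.text.bert import bert_score

preds = ["hello there", "general kenobi"]
target = ["hello there", "master kenobi"]

# explicitly choose a (smaller, hypothetical choice of) checkpoint and disable idf weighting
score = bert_score(preds, target, model_name_or_path="distilbert-base-uncased", idf=False)
print(score["precision"], score["recall"], score["f1"])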

BLEU Score

Module Interface

class torchmetrics.text.BLEUScore(n_gram=4, smooth=False, weights=None, **kwargs)[source]

Calculate BLEU score of machine translated text with one or more references.

As input to forward and update the metric accepts the following input:

  • preds (Sequence): An iterable of machine translated corpus

  • target (Sequence): An iterable of iterables of reference corpus

As output of forward and compute the metric returns the following output:

  • bleu (Tensor): A tensor with the BLEU Score

Parameters:
  • n_gram (int) – Gram value ranging from 1 to 4

  • smooth (bool) – Whether to apply smoothing – see [2]

  • weights (Optional[Sequence[float]]) – Weights used for unigrams, bigrams, etc. to calculate BLEU score. If not provided, uniform weights are used.

  • kwargs (Any) – Additional keyword arguments, see Advanced metric settings for more info.

Raises:

ValueError – If weights is not None and its length is not equal to n_gram.

Example

>>> from torchmetrics.text import BLEUScore
>>> preds = ['the cat is on the mat']
>>> target = [['there is a cat on the mat', 'a cat is on the mat']]
>>> bleu = BLEUScore()
>>> bleu(preds, target)
tensor(0.7598)
plot(val=None, ax=None)[source]

Plot a single or multiple values from the metric.

Parameters:
  • val (Union[Tensor, Sequence[Tensor], None]) – Either a single result from calling metric.forward or metric.compute or a list of these results. If no value is provided, will automatically call metric.compute and plot that result.

  • ax (Optional[Axes]) – A matplotlib axis object. If provided, the plot will be added to that axis

Return type:

Tuple[Figure, Union[Axes, ndarray]]

Returns:

Figure and Axes object

Raises:

ModuleNotFoundError – If matplotlib is not installed

>>> # Example plotting a single value
>>> from torchmetrics.text import BLEUScore
>>> metric = BLEUScore()
>>> preds = ['the cat is on the mat']
>>> target = [['there is a cat on the mat', 'a cat is on the mat']]
>>> metric.update(preds, target)
>>> fig_, ax_ = metric.plot()
_images/bleu_score-1.png
>>> # Example plotting multiple values
>>> from torchmetrics.text import BLEUScore
>>> metric = BLEUScore()
>>> preds = ['the cat is on the mat']
>>> target = [['there is a cat on the mat', 'a cat is on the mat']]
>>> values = [ ]
>>> for _ in range(10):
...     values.append(metric(preds, target))
>>> fig_, ax_ = metric.plot(values)
_images/bleu_score-2.png

Functional Interface

torchmetrics.functional.text.bleu_score(preds, target, n_gram=4, smooth=False, weights=None)[source]

Calculate BLEU score of machine translated text with one or more references.

Parameters:
  • preds (Union[str, Sequence[str]]) – An iterable of machine translated corpus

  • target (Sequence[Union[str, Sequence[str]]]) – An iterable of iterables of reference corpus

  • n_gram (int) – Gram value ranging from 1 to 4

  • smooth (bool) – Whether to apply smoothing - see [2]

  • weights (Optional[Sequence[float]]) – Weights used for unigrams, bigrams, etc. to calculate BLEU score. If not provided, uniform weights are used.

Return type:

Tensor

Returns:

Tensor with BLEU Score

Raises:
  • ValueError – If preds and target corpus have different lengths.

  • ValueError – If weights is not None and its length is not equal to n_gram.

Example

>>> from torchmetrics.functional.text import bleu_score
>>> preds = ['the cat is on the mat']
>>> target = [['there is a cat on the mat', 'a cat is on the mat']]
>>> bleu_score(preds, target)
tensor(0.7598)
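
A minimal sketch of the n_gram and weights parameters: restricting the score to unigrams and bigrams with explicit (uniform) weights:

from torchmetrics.functional.text import bleu_score

preds = ['the cat is on the mat']
target = [['there is a cat on the mat', 'a cat is on the mat']]

# BLEU-2: only unigram and bigram precisions, each weighted 0.5
bleu2 = bleu_score(preds, target, n_gram=2, weights=[0.5, 0.5])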

References

[1] BLEU: a Method for Automatic Evaluation of Machine Translation by Papineni, Kishore, Salim Roukos, Todd Ward, and Wei-Jing Zhu BLEU

[2] Automatic Evaluation of Machine Translation Quality Using Longest Common Subsequence and Skip-Bigram Statistics by Chin-Yew Lin and Franz Josef Och Machine Translation Evolution

Char Error Rate

Module Interface

class torchmetrics.text.CharErrorRate(**kwargs)[source]

Character Error Rate (CER) is a metric of the performance of an automatic speech recognition (ASR) system.

This value indicates the percentage of characters that were incorrectly predicted. The lower the value, the better the performance of the ASR system with a CharErrorRate of 0 being a perfect score. Character error rate can then be computed as:

\[CharErrorRate = \frac{S + D + I}{N} = \frac{S + D + I}{S + D + C}\]
where:
  • \(S\) is the number of substitutions,

  • \(D\) is the number of deletions,

  • \(I\) is the number of insertions,

  • \(C\) is the number of correct characters,

  • \(N\) is the number of characters in the reference (N=S+D+C).

Compute CharErrorRate score of transcribed segments against references.

As input to forward and update the metric accepts the following input:

  • preds (str): Transcription(s) to score as a string or list of strings

  • target (str): Reference(s) for each speech input as a string or list of strings

As output of forward and compute the metric returns the following output:

  • cer (Tensor): A tensor with the Character Error Rate score

Parameters:

kwargs (Any) – Additional keyword arguments, see Advanced metric settings for more info.

Examples

>>> from torchmetrics.text import CharErrorRate
>>> preds = ["this is the prediction", "there is an other sample"]
>>> target = ["this is the reference", "there is another one"]
>>> cer = CharErrorRate()
>>> cer(preds, target)
tensor(0.3415)
plot(val=None, ax=None)[source]

Plot a single or multiple values from the metric.

Parameters:
  • val (Union[Tensor, Sequence[Tensor], None]) – Either a single result from calling metric.forward or metric.compute or a list of these results. If no value is provided, will automatically call metric.compute and plot that result.

  • ax (Optional[Axes]) – A matplotlib axis object. If provided, the plot will be added to that axis

Return type:

Tuple[Figure, Union[Axes, ndarray]]

Returns:

Figure and Axes object

Raises:

ModuleNotFoundError – If matplotlib is not installed

>>> # Example plotting a single value
>>> from torchmetrics.text import CharErrorRate
>>> metric = CharErrorRate()
>>> preds = ["this is the prediction", "there is an other sample"]
>>> target = ["this is the reference", "there is another one"]
>>> metric.update(preds, target)
>>> fig_, ax_ = metric.plot()
_images/char_error_rate-1.png
>>> # Example plotting multiple values
>>> from torchmetrics.text import CharErrorRate
>>> metric = CharErrorRate()
>>> preds = ["this is the prediction", "there is an other sample"]
>>> target = ["this is the reference", "there is another one"]
>>> values = [ ]
>>> for _ in range(10):
...     values.append(metric(preds, target))
>>> fig_, ax_ = metric.plot(values)
_images/char_error_rate-2.png

Functional Interface

torchmetrics.functional.text.char_error_rate(preds, target)[source]

Compute Character Error Rate used to evaluate the performance of an automatic speech recognition system.

This value indicates the percentage of characters that were incorrectly predicted. The lower the value, the better the performance of the ASR system with a CER of 0 being a perfect score.

Parameters:
  • preds (Union[str, List[str]]) – Transcription(s) to score as a string or list of strings

  • target (Union[str, List[str]]) – Reference(s) for each speech input as a string or list of strings

Return type:

Tensor

Returns:

Character error rate score

Examples

>>> preds = ["this is the prediction", "there is an other sample"]
>>> target = ["this is the reference", "there is another one"]
>>> char_error_rate(preds=preds, target=target)
tensor(0.3415)
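
Following the formula above, CER is the total number of character-level edits divided by the total number of reference characters. A minimal sketch (assuming the default substitution cost of 1) using the edit_distance function documented later in this section:

from torchmetrics.functional.text import char_error_rate, edit_distance

preds = ["this is the prediction", "there is an other sample"]
target = ["this is the reference", "there is another one"]

cer = char_error_rate(preds, target)
# total character edits over the batch, normalized by the total reference length
total_edits = edit_distance(preds, target, reduction="sum")
manual_cer = total_edits / sum(len(t) for t in target)
# manual_cer is expected to agree with `cer` up to implementation details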

ChrF Score

Module Interface

class torchmetrics.text.CHRFScore(n_char_order=6, n_word_order=2, beta=2.0, lowercase=False, whitespace=False, return_sentence_level_score=False, **kwargs)[source]

Calculate chrF score of machine translated text with one or more references.

This implementation supports both chrF score computation introduced in chrF score and chrF++ score introduced in chrF++ score. This implementation follows the implementations from https://github.com/m-popovic/chrF and https://github.com/mjpost/sacrebleu/blob/master/sacrebleu/metrics/chrf.py.

As input to forward and update the metric accepts the following input:

  • preds (Sequence): An iterable of hypothesis corpus

  • target (Sequence): An iterable of iterables of reference corpus

As output of forward and compute the metric returns the following output:

  • chrf (Tensor): If return_sentence_level_score=True return a list of sentence-level chrF/chrF++ scores, else return a corpus-level chrF/chrF++ score

Parameters:
  • n_char_order (int) – A character n-gram order. If n_char_order=6, the metric refers to the official chrF/chrF++.

  • n_word_order (int) – A word n-gram order. If n_word_order=2, the metric refers to the official chrF++. If n_word_order=0, the metric is equivalent to the original chrF.

  • beta (float) – A parameter determining the importance of recall w.r.t. precision. If beta=1, their importance is equal.

  • lowercase (bool) – An indication of whether to enable case-insensitivity.

  • whitespace (bool) – An indication of whether to keep whitespaces during n-gram extraction.

  • return_sentence_level_score (bool) – An indication of whether a sentence-level chrF/chrF++ score should be returned.

  • kwargs (Any) – Additional keyword arguments, see Advanced metric settings for more info.

Raises:
  • ValueError – If n_char_order is not an integer greater than or equal to 1.

  • ValueError – If n_word_order is not an integer greater than or equal to 0.

  • ValueError – If beta is smaller than 0.

Example

>>> from torchmetrics.text import CHRFScore
>>> preds = ['the cat is on the mat']
>>> target = [['there is a cat on the mat', 'a cat is on the mat']]
>>> chrf = CHRFScore()
>>> chrf(preds, target)
tensor(0.8640)
plot(val=None, ax=None)[source]

Plot a single or multiple values from the metric.

Parameters:
  • val (Union[Tensor, Sequence[Tensor], None]) – Either a single result from calling metric.forward or metric.compute or a list of these results. If no value is provided, will automatically call metric.compute and plot that result.

  • ax (Optional[Axes]) – A matplotlib axis object. If provided, the plot will be added to that axis

Return type:

Tuple[Figure, Union[Axes, ndarray]]

Returns:

Figure and Axes object

Raises:

ModuleNotFoundError – If matplotlib is not installed

>>> # Example plotting a single value
>>> from torchmetrics.text import CHRFScore
>>> metric = CHRFScore()
>>> preds = ['the cat is on the mat']
>>> target = [['there is a cat on the mat', 'a cat is on the mat']]
>>> metric.update(preds, target)
>>> fig_, ax_ = metric.plot()
_images/chrf_score-1.png
>>> # Example plotting multiple values
>>> from torchmetrics.text import CHRFScore
>>> metric = CHRFScore()
>>> preds = ['the cat is on the mat']
>>> target = [['there is a cat on the mat', 'a cat is on the mat']]
>>> values = [ ]
>>> for _ in range(10):
...     values.append(metric(preds, target))
>>> fig_, ax_ = metric.plot(values)
_images/chrf_score-2.png

Functional Interface

torchmetrics.functional.text.chrf_score(preds, target, n_char_order=6, n_word_order=2, beta=2.0, lowercase=False, whitespace=False, return_sentence_level_score=False)[source]

Calculate chrF score of machine translated text with one or more references.

This implementation supports both chrF score computation introduced in [1] and chrF++ score introduced in [2]. This implementation follows the implementations from https://github.com/m-popovic/chrF and https://github.com/mjpost/sacrebleu/blob/master/sacrebleu/metrics/chrf.py.

Parameters:
  • preds (Union[str, Sequence[str]]) – An iterable of hypothesis corpus.

  • target (Sequence[Union[str, Sequence[str]]]) – An iterable of iterables of reference corpus.

  • n_char_order (int) – A character n-gram order. If n_char_order=6, the metric refers to the official chrF/chrF++.

  • n_word_order (int) – A word n-gram order. If n_word_order=2, the metric refers to the official chrF++. If n_word_order=0, the metric is equivalent to the original chrF.

  • beta (float) – A parameter determining the importance of recall w.r.t. precision. If beta=1, their importance is equal.

  • lowercase (bool) – An indication of whether to enable case-insensitivity.

  • whitespace (bool) – An indication of whether to keep whitespaces during character n-gram extraction.

  • return_sentence_level_score (bool) – An indication of whether a sentence-level chrF/chrF++ score should be returned.

Return type:

Union[Tensor, Tuple[Tensor, Tensor]]

Returns:

A corpus-level chrF/chrF++ score. (Optionally) A list of sentence-level chrF/chrF++ scores if return_sentence_level_score=True.

Raises:
  • ValueError – If n_char_order is not an integer greater than or equal to 1.

  • ValueError – If n_word_order is not an integer greater than or equal to 0.

  • ValueError – If beta is smaller than 0.

Example

>>> from torchmetrics.functional.text import chrf_score
>>> preds = ['the cat is on the mat']
>>> target = [['there is a cat on the mat', 'a cat is on the mat']]
>>> chrf_score(preds, target)
tensor(0.8640)
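
A minimal sketch of the n_word_order and return_sentence_level_score parameters: computing plain chrF (no word n-grams) and requesting per-sentence scores alongside the corpus score:

from torchmetrics.functional.text import chrf_score

preds = ['the cat is on the mat']
target = [['there is a cat on the mat', 'a cat is on the mat']]

# plain chrF instead of the default chrF++ (no word n-grams)
plain_chrf = chrf_score(preds, target, n_word_order=0)
# corpus-level score together with sentence-level scores
corpus, per_sentence = chrf_score(preds, target, return_sentence_level_score=True)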

References

[1] chrF: character n-gram F-score for automatic MT evaluation by Maja Popović chrF score

[2] chrF++: words helping character n-grams by Maja Popović chrF++ score

Edit Distance

Module Interface

class torchmetrics.text.EditDistance(substitution_cost=1, reduction='mean', **kwargs)[source]

Calculates the Levenshtein edit distance between two sequences.

The edit distance is the number of characters that need to be substituted, inserted, or deleted, to transform the predicted text into the reference text. The lower the distance, the more accurate the model is considered to be.

Implementation is similar to nltk.edit_distance.

As input to forward and update the metric accepts the following input:

  • preds (Sequence): An iterable of hypothesis corpus

  • target (Sequence): An iterable of iterables of reference corpus

As output of forward and compute the metric returns the following output:

  • eed (Tensor): A tensor with the edit distance score. If reduction is set to 'none' or None, this has shape (N, ), where N is the batch size. Otherwise, this is a scalar.

Parameters:
  • substitution_cost (int) – The cost of substituting one character for another.

  • reduction (Optional[Literal['mean', 'sum', 'none']]) –

    a method to reduce metric score over samples.

    • 'mean': takes the mean over samples

    • 'sum': takes the sum over samples

    • None or 'none': return the score per sample

  • kwargs (Any) – Additional keyword arguments, see Advanced metric settings for more info.

Example::

Basic example with two strings. Going from “rain” -> “sain” -> “shin” -> “shine” takes 3 edits:

>>> from torchmetrics.text import EditDistance
>>> metric = EditDistance()
>>> metric(["rain"], ["shine"])
tensor(3.)
Example::

Basic example with two strings and substitution cost of 2. Going from “rain” -> “sain” -> “shin” -> “shine” takes 3 edits, where two of them are substitutions:

>>> from torchmetrics.text import EditDistance
>>> metric = EditDistance(substitution_cost=2)
>>> metric(["rain"], ["shine"])
tensor(5.)
Example::

Multiple strings example:

>>> from torchmetrics.text import EditDistance
>>> metric = EditDistance(reduction=None)
>>> metric(["rain", "lnaguaeg"], ["shine", "language"])
tensor([3, 4], dtype=torch.int32)
>>> metric = EditDistance(reduction="mean")
>>> metric(["rain", "lnaguaeg"], ["shine", "language"])
tensor(3.5000)
plot(val=None, ax=None)[source]

Plot a single or multiple values from the metric.

Parameters:
  • val (Union[Tensor, Sequence[Tensor], None]) – Either a single result from calling metric.forward or metric.compute or a list of these results. If no value is provided, will automatically call metric.compute and plot that result.

  • ax (Optional[Axes]) – A matplotlib axis object. If provided, the plot will be added to that axis

Return type:

Tuple[Figure, Union[Axes, ndarray]]

Returns:

Figure and Axes object

Raises:

ModuleNotFoundError – If matplotlib is not installed

>>> # Example plotting a single value
>>> from torchmetrics.text import EditDistance
>>> metric = EditDistance()
>>> preds = ["this is the prediction", "there is an other sample"]
>>> target = ["this is the reference", "there is another one"]
>>> metric.update(preds, target)
>>> fig_, ax_ = metric.plot()
_images/edit-1.png
>>> # Example plotting multiple values
>>> from torchmetrics.text import EditDistance
>>> metric = EditDistance()
>>> preds = ["this is the prediction", "there is an other sample"]
>>> target = ["this is the reference", "there is another one"]
>>> values = [ ]
>>> for _ in range(10):
...     values.append(metric(preds, target))
>>> fig_, ax_ = metric.plot(values)
_images/edit-2.png

Functional Interface

torchmetrics.functional.text.edit_distance(preds, target, substitution_cost=1, reduction='mean')[source]

Calculates the Levenshtein edit distance between two sequences.

The edit distance is the number of characters that need to be substituted, inserted, or deleted, to transform the predicted text into the reference text. The lower the distance, the more accurate the model is considered to be.

Implementation is similar to nltk.edit_distance.

Parameters:
  • preds (Union[str, Sequence[str]]) – An iterable of predicted texts (strings).

  • target (Union[str, Sequence[str]]) – An iterable of reference texts (strings).

  • substitution_cost (int) – The cost of substituting one character for another.

  • reduction (Optional[Literal['mean', 'sum', 'none']]) –

    a method to reduce metric score over samples.

    • 'mean': takes the mean over samples

    • 'sum': takes the sum over samples

    • None or 'none': return the score per sample

Raises:
  • ValueError – If preds and target do not have the same length.

  • ValueError – If preds or target contain non-string values.

Return type:

Tensor

Example::

Basic example with two strings. Going from “rain” -> “sain” -> “shin” -> “shine” takes 3 edits:

>>> from torchmetrics.functional.text import edit_distance
>>> edit_distance(["rain"], ["shine"])
tensor(3.)
Example::

Basic example with two strings and substitution cost of 2. Going from “rain” -> “sain” -> “shin” -> “shine” takes 3 edits, where two of them are substitutions:

>>> from torchmetrics.functional.text import edit_distance
>>> edit_distance(["rain"], ["shine"], substitution_cost=2)
tensor(5.)
Example::

Multiple strings example:

>>> from torchmetrics.functional.text import edit_distance
>>> edit_distance(["rain", "lnaguaeg"], ["shine", "language"], reduction=None)
tensor([3, 4], dtype=torch.int32)
>>> edit_distance(["rain", "lnaguaeg"], ["shine", "language"], reduction="mean")
tensor(3.5000)
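
A short sketch of the remaining reduction option: 'sum' returns the total number of edits over the batch, which for the example above would be the sum of the per-sample distances shown (3 + 4):

from torchmetrics.functional.text import edit_distance

# total number of edits across the batch
total = edit_distance(["rain", "lnaguaeg"], ["shine", "language"], reduction="sum")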

Extended Edit Distance

Module Interface

class torchmetrics.text.ExtendedEditDistance(language='en', return_sentence_level_score=False, alpha=2.0, rho=0.3, deletion=0.2, insertion=1.0, **kwargs)[source]

Compute extended edit distance score (ExtendedEditDistance) for strings or list of strings.

The metric utilises the Levenshtein distance and extends it by adding a jump operation.

As input to forward and update the metric accepts the following input:

  • preds (Sequence): An iterable of hypothesis corpus

  • target (Sequence): An iterable of iterables of reference corpus

As output of forward and compute the metric returns the following output:

  • eed (Tensor): A tensor with the extended edit distance score

Parameters:
  • language (Literal['en', 'ja']) – Language used in sentences. Only supports English (en) and Japanese (ja) for now.

  • return_sentence_level_score (bool) – An indication of whether sentence-level EED score is to be returned

  • alpha (float) – optimal jump penalty, penalty for jumps between characters

  • rho (float) – coverage cost, penalty for repetition of characters

  • deletion (float) – penalty for deletion of character

  • insertion (float) – penalty for insertion or substitution of character

  • kwargs (Any) – Additional keyword arguments, see Advanced metric settings for more info.

Example

>>> from torchmetrics.text import ExtendedEditDistance
>>> preds = ["this is the prediction", "here is an other sample"]
>>> target = ["this is the reference", "here is another one"]
>>> eed = ExtendedEditDistance()
>>> eed(preds=preds, target=target)
tensor(0.3078)
plot(val=None, ax=None)[source]

Plot a single or multiple values from the metric.

Parameters:
  • val (Union[Tensor, Sequence[Tensor], None]) – Either a single result from calling metric.forward or metric.compute or a list of these results. If no value is provided, will automatically call metric.compute and plot that result.

  • ax (Optional[Axes]) – A matplotlib axis object. If provided, the plot will be added to that axis

Return type:

Tuple[Figure, Union[Axes, ndarray]]

Returns:

Figure and Axes object

Raises:

ModuleNotFoundError – If matplotlib is not installed

>>> # Example plotting a single value
>>> from torchmetrics.text import ExtendedEditDistance
>>> metric = ExtendedEditDistance()
>>> preds = ["this is the prediction", "there is an other sample"]
>>> target = ["this is the reference", "there is another one"]
>>> metric.update(preds, target)
>>> fig_, ax_ = metric.plot()
_images/extended_edit_distance-1.png
>>> # Example plotting multiple values
>>> from torchmetrics.text import ExtendedEditDistance
>>> metric = ExtendedEditDistance()
>>> preds = ["this is the prediction", "there is an other sample"]
>>> target = ["this is the reference", "there is another one"]
>>> values = [ ]
>>> for _ in range(10):
...     values.append(metric(preds, target))
>>> fig_, ax_ = metric.plot(values)
_images/extended_edit_distance-2.png

Functional Interface

torchmetrics.functional.text.extended_edit_distance(preds, target, language='en', return_sentence_level_score=False, alpha=2.0, rho=0.3, deletion=0.2, insertion=1.0)[source]

Compute extended edit distance score (ExtendedEditDistance) [1] for strings or list of strings.

The metric utilises the Levenshtein distance and extends it by adding a jump operation.

Parameters:
  • preds (Union[str, Sequence[str]]) – An iterable of hypothesis corpus.

  • target (Sequence[Union[str, Sequence[str]]]) – An iterable of iterables of reference corpus.

  • language (Literal['en', 'ja']) – Language used in sentences. Only supports English (en) and Japanese (ja) for now. Defaults to en

  • return_sentence_level_score (bool) – An indication of whether sentence-level EED score is to be returned.

  • alpha (float) – optimal jump penalty, penalty for jumps between characters

  • rho (float) – coverage cost, penalty for repetition of characters

  • deletion (float) – penalty for deletion of character

  • insertion (float) – penalty for insertion or substitution of character

Return type:

Union[Tensor, Tuple[Tensor, Tensor]]

Returns:

Extended edit distance score as a tensor

Example

>>> from torchmetrics.functional.text import extended_edit_distance
>>> preds = ["this is the prediction", "here is an other sample"]
>>> target = ["this is the reference", "here is another one"]
>>> extended_edit_distance(preds=preds, target=target)
tensor(0.3078)
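
A minimal sketch of return_sentence_level_score: when enabled, the function returns the corpus-level score together with per-sentence scores, as described in the return type above:

from torchmetrics.functional.text import extended_edit_distance

preds = ["this is the prediction", "here is an other sample"]
target = ["this is the reference", "here is another one"]

# corpus-level score plus one score per sentence pair
corpus_score, sentence_scores = extended_edit_distance(
    preds, target, return_sentence_level_score=True
)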

References

[1] P. Stanchev, W. Wang, and H. Ney, “EED: Extended Edit Distance Measure for Machine Translation”, submitted to WMT 2019. ExtendedEditDistance

InfoLM

Module Interface

class torchmetrics.text.infolm.InfoLM(model_name_or_path='bert-base-uncased', temperature=0.25, information_measure='kl_divergence', idf=True, alpha=None, beta=None, device=None, max_length=None, batch_size=64, num_threads=0, verbose=True, return_sentence_level_score=False, **kwargs)[source]

Calculate InfoLM.

InfoLM measures a distance/divergence between the predicted and reference sentence discrete distributions using one of several information measures (KL divergence, alpha, beta, AB and Rényi divergences, L1/L2/L-infinity distances, or the Fisher-Rao distance; see the information_measure argument).

InfoLM is a family of untrained embedding-based metrics which addresses some famous flaws of standard string-based metrics thanks to the usage of pre-trained masked language models. This family of metrics is mainly designed for summarization and data-to-text tasks.

The implementation of this metric is fully based on the HuggingFace transformers package.

As input to forward and update the metric accepts the following input:

  • preds (Sequence): An iterable of hypothesis corpus

  • target (Sequence): An iterable of reference corpus

As output of forward and compute the metric returns the following output:

  • infolm (Tensor): If return_sentence_level_score=True return a tuple with a tensor with the corpus-level InfoLM score and a list of sentence-level InfoLM scores, else return a corpus-level InfoLM score

Parameters:
  • model_name_or_path (Union[str, PathLike]) – A name or a model path used to load transformers pretrained model. By default the “bert-base-uncased” model is used.

  • temperature (float) – A temperature for calibrating language modelling. For more information, please reference InfoLM paper.

  • information_measure (Literal['kl_divergence', 'alpha_divergence', 'beta_divergence', 'ab_divergence', 'renyi_divergence', 'l1_distance', 'l2_distance', 'l_infinity_distance', 'fisher_rao_distance']) – A name of information measure to be used. Please use one of: [‘kl_divergence’, ‘alpha_divergence’, ‘beta_divergence’, ‘ab_divergence’, ‘renyi_divergence’, ‘l1_distance’, ‘l2_distance’, ‘l_infinity_distance’, ‘fisher_rao_distance’]

  • idf (bool) – An indication of whether normalization using inverse document frequencies should be used.

  • alpha (Optional[float]) – Alpha parameter of the divergence used for alpha, AB and Rényi divergence measures.

  • beta (Optional[float]) – Beta parameter of the divergence used for beta and AB divergence measures.

  • device (Union[str, device, None]) – A device to be used for calculation.

  • max_length (Optional[int]) – A maximum length of input sequences. Sequences longer than max_length are to be trimmed.

  • batch_size (int) – A batch size used for model processing.

  • num_threads (int) – A number of threads to use for a dataloader.

  • verbose (bool) – An indication of whether a progress bar should be displayed during the embeddings calculation.

  • return_sentence_level_score (bool) – An indication of whether a sentence-level InfoLM score should be returned.

Example

>>> from torchmetrics.text.infolm import InfoLM
>>> preds = ['he read the book because he was interested in world history']
>>> target = ['he was interested in world history because he read the book']
>>> infolm = InfoLM('google/bert_uncased_L-2_H-128_A-2', idf=False)
>>> infolm(preds, target)
tensor(-0.1784)
plot(val=None, ax=None)[source]

Plot a single or multiple values from the metric.

Parameters:
  • val (Union[Tensor, Sequence[Tensor], None]) – Either a single result from calling metric.forward or metric.compute or a list of these results. If no value is provided, will automatically call metric.compute and plot that result.

  • ax (Optional[Axes]) – A matplotlib axis object. If provided, the plot will be added to that axis

Return type:

Tuple[Figure, Union[Axes, ndarray]]

Returns:

Figure and Axes object

Raises:

ModuleNotFoundError – If matplotlib is not installed

>>> # Example plotting a single value
>>> from torchmetrics.text.infolm import InfoLM
>>> metric = InfoLM('google/bert_uncased_L-2_H-128_A-2', idf=False)
>>> preds = ['he read the book because he was interested in world history']
>>> target = ['he was interested in world history because he read the book']
>>> metric.update(preds, target)
>>> fig_, ax_ = metric.plot()
_images/infolm-1.png
>>> # Example plotting multiple values
>>> from torchmetrics.text.infolm import InfoLM
>>> metric = InfoLM('google/bert_uncased_L-2_H-128_A-2', idf=False)
>>> preds = ["this is the prediction", "there is an other sample"]
>>> target = ["this is the reference", "there is another one"]
>>> values = [ ]
>>> for _ in range(10):
...     values.append(metric(preds, target))
>>> fig_, ax_ = metric.plot(values)
_images/infolm-2.png

Functional Interface

torchmetrics.functional.text.infolm.infolm(preds, target, model_name_or_path='bert-base-uncased', temperature=0.25, information_measure='kl_divergence', idf=True, alpha=None, beta=None, device=None, max_length=None, batch_size=64, num_threads=0, verbose=True, return_sentence_level_score=False)[source]

Calculate InfoLM [1].

InfoLM corresponds to a distance/divergence between the predicted and reference sentence discrete distributions, using one of several information measures (see the information_measure argument).

InfoLM is a family of untrained embedding-based metrics which addresses some famous flaws of standard string-based metrics thanks to the usage of pre-trained masked language models. This family of metrics is mainly designed for summarization and data-to-text tasks.

If you want to use IDF scaling over the whole dataset, please use the class metric.

The implementation of this metric is fully based on the HuggingFace transformers package.

Parameters:
  • preds (Union[str, Sequence[str]]) – An iterable of hypothesis corpus.

  • target (Union[str, Sequence[str]]) – An iterable of reference corpus.

  • model_name_or_path (Union[str, PathLike]) – A name or a model path used to load transformers pretrained model.

  • temperature (float) – A temperature for calibrating language modelling. For more information, please reference InfoLM paper.

  • information_measure (Literal['kl_divergence', 'alpha_divergence', 'beta_divergence', 'ab_divergence', 'renyi_divergence', 'l1_distance', 'l2_distance', 'l_infinity_distance', 'fisher_rao_distance']) – A name of information measure to be used. Please use one of: [‘kl_divergence’, ‘alpha_divergence’, ‘beta_divergence’, ‘ab_divergence’, ‘renyi_divergence’, ‘l1_distance’, ‘l2_distance’, ‘l_infinity_distance’, ‘fisher_rao_distance’]

  • idf (bool) – An indication of whether normalization using inverse document frequencies should be used.

  • alpha (Optional[float]) – Alpha parameter of the divergence used for alpha, AB and Rényi divergence measures.

  • beta (Optional[float]) – Beta parameter of the divergence used for beta and AB divergence measures.

  • device (Union[str, device, None]) – A device to be used for calculation.

  • max_length (Optional[int]) – A maximum length of input sequences. Sequences longer than max_length are to be trimmed.

  • batch_size (int) – A batch size used for model processing.

  • num_threads (int) – A number of threads to use for a dataloader.

  • verbose (bool) – An indication of whether a progress bar should be displayed during the embeddings calculation.

  • return_sentence_level_score (bool) – An indication of whether a sentence-level InfoLM score should be returned.

Return type:

Union[Tensor, Tuple[Tensor, Tensor]]

Returns:

A corpus-level InfoLM score. (Optionally) A list of sentence-level InfoLM scores if return_sentence_level_score=True.

Example

>>> from torchmetrics.functional.text.infolm import infolm
>>> preds = ['he read the book because he was interested in world history']
>>> target = ['he was interested in world history because he read the book']
>>> infolm(preds, target, model_name_or_path='google/bert_uncased_L-2_H-128_A-2', idf=False)
tensor(-0.1784)
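
A minimal sketch showing a different entry from the documented information_measure options, using the same small checkpoint as the example above:

from torchmetrics.functional.text.infolm import infolm

preds = ['he read the book because he was interested in world history']
target = ['he was interested in world history because he read the book']

# swap the default KL divergence for the L1 distance (one of the documented options)
score = infolm(
    preds,
    target,
    model_name_or_path='google/bert_uncased_L-2_H-128_A-2',
    information_measure='l1_distance',
    idf=False,
)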

References

[1] InfoLM: A New Metric to Evaluate Summarization & Data2Text Generation by Pierre Colombo, Chloé Clavel and Pablo Piantanida InfoLM

Match Error Rate

Module Interface

class torchmetrics.text.MatchErrorRate(**kwargs)[source]

Match Error Rate (MER) is a common metric of the performance of an automatic speech recognition system.

This value indicates the percentage of words that were incorrectly predicted and inserted. The lower the value, the better the performance of the ASR system with a MatchErrorRate of 0 being a perfect score. Match error rate can then be computed as:

\[mer = \frac{S + D + I}{N + I} = \frac{S + D + I}{S + D + C + I}\]
where:
  • \(S\) is the number of substitutions,

  • \(D\) is the number of deletions,

  • \(I\) is the number of insertions,

  • \(C\) is the number of correct words,

  • \(N\) is the number of words in the reference (\(N=S+D+C\)).

As input to forward and update the metric accepts the following input:

  • preds (List): Transcription(s) to score as a string or list of strings

  • target (List): Reference(s) for each speech input as a string or list of strings

As output of forward and compute the metric returns the following output:

  • mer (Tensor): A tensor with the match error rate

Parameters:

kwargs (Any) – Additional keyword arguments, see Advanced metric settings for more info.

Examples

>>> from torchmetrics.text import MatchErrorRate
>>> preds = ["this is the prediction", "there is an other sample"]
>>> target = ["this is the reference", "there is another one"]
>>> mer = MatchErrorRate()
>>> mer(preds, target)
tensor(0.4444)
plot(val=None, ax=None)[source]

Plot a single or multiple values from the metric.

Parameters:
  • val (Union[Tensor, Sequence[Tensor], None]) – Either a single result from calling metric.forward or metric.compute or a list of these results. If no value is provided, will automatically call metric.compute and plot that result.

  • ax (Optional[Axes]) – A matplotlib axis object. If provided, the plot will be added to that axis

Return type:

Tuple[Figure, Union[Axes, ndarray]]

Returns:

Figure and Axes object

Raises:

ModuleNotFoundError – If matplotlib is not installed

>>> # Example plotting a single value
>>> from torchmetrics.text import MatchErrorRate
>>> metric = MatchErrorRate()
>>> preds = ["this is the prediction", "there is an other sample"]
>>> target = ["this is the reference", "there is another one"]
>>> metric.update(preds, target)
>>> fig_, ax_ = metric.plot()
_images/match_error_rate-1.png
>>> # Example plotting multiple values
>>> from torchmetrics.text import MatchErrorRate
>>> metric = MatchErrorRate()
>>> preds = ["this is the prediction", "there is an other sample"]
>>> target = ["this is the reference", "there is another one"]
>>> values = [ ]
>>> for _ in range(10):
...     values.append(metric(preds, target))
>>> fig_, ax_ = metric.plot(values)
_images/match_error_rate-2.png

Functional Interface

torchmetrics.functional.text.match_error_rate(preds, target)[source]

Match error rate is a metric of the performance of an automatic speech recognition system.

This value indicates the percentage of words that were incorrectly predicted and inserted. The lower the value, the better the performance of the ASR system with a MatchErrorRate of 0 being a perfect score.

Parameters:
  • preds (Union[str, List[str]]) – Transcription(s) to score as a string or list of strings

  • target (Union[str, List[str]]) – Reference(s) for each speech input as a string or list of strings

Return type:

Tensor

Returns:

Match error rate score

Examples

>>> preds = ["this is the prediction", "there is an other sample"]
>>> target = ["this is the reference", "there is another one"]
>>> match_error_rate(preds=preds, target=target)
tensor(0.4444)

Perplexity

Module Interface

class torchmetrics.text.perplexity.Perplexity(ignore_index=None, **kwargs)[source]

Perplexity measures how well a language model predicts a text sample.

It’s calculated as the average number of bits per word a model needs to represent the sample.

As input to forward and update the metric accepts the following input:

  • preds (Tensor): Logits or an unnormalized score assigned to each token in a sequence with shape [batch_size, seq_len, vocab_size], which is the output of a language model. Scores will be normalized internally using softmax.

  • target (Tensor): Ground truth values with a shape [batch_size, seq_len]

As output of forward and compute the metric returns the following output:

  • perp (Tensor): A tensor with the perplexity score

Parameters:
  • ignore_index (Optional[int]) – Integer specifying a target class to ignore. If given, this class index does not contribute to the returned score.

  • kwargs (Dict[str, Any]) – Additional keyword arguments, see Advanced metric settings for more info.

Examples

>>> from torchmetrics.text import Perplexity
>>> import torch
>>> gen = torch.manual_seed(42)
>>> preds = torch.rand(2, 8, 5, generator=gen)
>>> target = torch.randint(5, (2, 8), generator=gen)
>>> target[0, 6:] = -100
>>> perp = Perplexity(ignore_index=-100)
>>> perp(preds, target)
tensor(5.8540)
plot(val=None, ax=None)[source]

Plot a single or multiple values from the metric.

Parameters:
  • val (Union[Tensor, Sequence[Tensor], None]) – Either a single result from calling metric.forward or metric.compute or a list of these results. If no value is provided, will automatically call metric.compute and plot that result.

  • ax (Optional[Axes]) – A matplotlib axis object. If provided, will add plot to that axis

Return type:

Tuple[Figure, Union[Axes, ndarray]]

Returns:

Figure and Axes object

Raises:

ModuleNotFoundError – If matplotlib is not installed

>>> # Example plotting a single value
>>> import torch
>>> from torchmetrics.text import Perplexity
>>> metric = Perplexity()
>>> metric.update(torch.rand(2, 8, 5), torch.randint(5, (2, 8)))
>>> fig_, ax_ = metric.plot()
_images/perplexity-1.png
>>> # Example plotting multiple values
>>> import torch
>>> from torchmetrics.text import Perplexity
>>> metric = Perplexity()
>>> values = [ ]
>>> for _ in range(10):
...     values.append(metric(torch.rand(2, 8, 5), torch.randint(5, (2, 8))))
>>> fig_, ax_ = metric.plot(values)
_images/perplexity-2.png

Functional Interface

torchmetrics.functional.text.perplexity.perplexity(preds, target, ignore_index=None)[source]

Perplexity measures how well a language model predicts a text sample.

This metric is calculated as the exponential of the average negative log-likelihood per token; lower values indicate a better model.

Parameters:
  • preds (Tensor) – Logits or an unnormalized score assigned to each token in a sequence with shape [batch_size, seq_len, vocab_size], which is the output of a language model. Scores will be normalized internally using softmax.

  • target (Tensor) – Ground truth values with a shape [batch_size, seq_len].

  • ignore_index (Optional[int]) – Integer specifying a target class to ignore. If given, this class index does not contribute to the returned score.

Return type:

Tensor

Returns:

Perplexity value

Examples

>>> import torch
>>> from torchmetrics.functional.text import perplexity
>>> gen = torch.manual_seed(42)
>>> preds = torch.rand(2, 8, 5, generator=gen)
>>> target = torch.randint(5, (2, 8), generator=gen)
>>> target[0, 6:] = -100
>>> perplexity(preds, target, ignore_index=-100)
tensor(5.8540)
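
As a sanity check, the same value can be reproduced (up to floating-point rounding) with plain PyTorch cross-entropy. This is a minimal sketch using the same data as the example above; it is not part of the official API:

>>> import torch
>>> import torch.nn.functional as F
>>> gen = torch.manual_seed(42)
>>> preds = torch.rand(2, 8, 5, generator=gen)
>>> target = torch.randint(5, (2, 8), generator=gen)
>>> target[0, 6:] = -100
>>> # mean negative log-likelihood over the non-ignored tokens ...
>>> nll = F.cross_entropy(preds.transpose(1, 2), target, ignore_index=-100)
>>> # ... exponentiated gives the perplexity, closely matching tensor(5.8540) above
>>> manual_perplexity = torch.exp(nll)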

ROUGE Score

Module Interface

class torchmetrics.text.rouge.ROUGEScore(use_stemmer=False, normalizer=None, tokenizer=None, accumulate='best', rouge_keys=('rouge1', 'rouge2', 'rougeL', 'rougeLsum'), **kwargs)[source]

Calculate Rouge Score, used for automatic summarization.

This implementation should imitate the behaviour of the rouge-score package (Python ROUGE Implementation).

As input to forward and update the metric accepts the following input:

  • preds (Sequence): An iterable of predicted sentences or a single predicted sentence

  • target (Sequence): An iterable of target sentences or an iterable of iterables of target sentences or a single target sentence

As output of forward and compute the metric returns the following output:

  • rouge (Dict): A dictionary of tensor rouge scores for each input str rouge key

Parameters:
  • use_stemmer (bool) – Use Porter stemmer to strip word suffixes to improve matching.

  • normalizer (Optional[Callable[[str], str]]) – A user’s own normalizer function. If this is None, replacing any non-alphanumeric characters with spaces is the default. This function must take a str and return a str.

  • tokenizer (Optional[Callable[[str], Sequence[str]]]) – A user’s own tokenizer function. If this is None, splitting by spaces is the default. This function must take a str and return a Sequence[str] (a short sketch with a custom normalizer and tokenizer follows the example below).

  • accumulate (Literal['avg', 'best']) –

    Useful in case of multi-reference rouge score.

    • avg takes the avg of all references with respect to predictions

    • best takes the best fmeasure score obtained between prediction and multiple corresponding references.

  • rouge_keys (Union[str, Tuple[str, ...]]) – A list of rouge types to calculate. Keys that are allowed are rougeL, rougeLsum, and rouge1 through rouge9.

  • kwargs (Any) – Additional keyword arguments, see Advanced metric settings for more info.

Example

>>> from torchmetrics.text.rouge import ROUGEScore
>>> preds = "My name is John"
>>> target = "Is your name John"
>>> rouge = ROUGEScore()
>>> from pprint import pprint
>>> pprint(rouge(preds, target))
{'rouge1_fmeasure': tensor(0.7500),
 'rouge1_precision': tensor(0.7500),
 'rouge1_recall': tensor(0.7500),
 'rouge2_fmeasure': tensor(0.),
 'rouge2_precision': tensor(0.),
 'rouge2_recall': tensor(0.),
 'rougeL_fmeasure': tensor(0.5000),
 'rougeL_precision': tensor(0.5000),
 'rougeL_recall': tensor(0.5000),
 'rougeLsum_fmeasure': tensor(0.5000),
 'rougeLsum_precision': tensor(0.5000),
 'rougeLsum_recall': tensor(0.5000)}
Raises:
  • ValueError – If the python package nltk is not installed.

  • ValueError – If any of the rouge_keys does not belong to the allowed set of keys.
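
A minimal sketch (not part of the examples above) of plugging in a custom normalizer and tokenizer; both are plain Python callables and the lambdas below are only illustrative:

>>> from torchmetrics.text.rouge import ROUGEScore
>>> my_normalizer = lambda text: text.lower()      # str -> str
>>> my_tokenizer = lambda text: text.split()       # str -> Sequence[str]
>>> rouge_custom = ROUGEScore(normalizer=my_normalizer, tokenizer=my_tokenizer, rouge_keys="rouge1")
>>> scores = rouge_custom("My name is John", "Is your name John")  # dict with rouge1_* keys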

plot(val=None, ax=None)[source]

Plot a single or multiple values from the metric.

Parameters:
  • val (Union[Tensor, Sequence[Tensor], None]) – Either a single result from calling metric.forward or metric.compute or a list of these results. If no value is provided, will automatically call metric.compute and plot that result.

  • ax (Optional[Axes]) – A matplotlib axis object. If provided, will add plot to that axis

Return type:

Tuple[Figure, Union[Axes, ndarray]]

Returns:

Figure and Axes object

Raises:

ModuleNotFoundError – If matplotlib is not installed

>>> # Example plotting a single value
>>> from torchmetrics.text.rouge import ROUGEScore
>>> metric = ROUGEScore()
>>> preds = "My name is John"
>>> target = "Is your name John"
>>> metric.update(preds, target)
>>> fig_, ax_ = metric.plot()
_images/rouge_score-1.png
>>> # Example plotting multiple values
>>> from torchmetrics.text.rouge import ROUGEScore
>>> metric = ROUGEScore()
>>> preds = "My name is John"
>>> target = "Is your name John"
>>> values = [ ]
>>> for _ in range(10):
...     values.append(metric(preds, target))
>>> fig_, ax_ = metric.plot(values)
_images/rouge_score-2.png

Functional Interface

torchmetrics.functional.text.rouge.rouge_score(preds, target, accumulate='best', use_stemmer=False, normalizer=None, tokenizer=None, rouge_keys=('rouge1', 'rouge2', 'rougeL', 'rougeLsum'))[source]

Calculate Rouge Score, used for automatic summarization.

Parameters:
  • preds (Union[str, Sequence[str]]) – An iterable of predicted sentences or a single predicted sentence.

  • target (Union[str, Sequence[str], Sequence[Sequence[str]]]) – An iterable of iterables of target sentences or an iterable of target sentences or a single target sentence.

  • accumulate (Literal['avg', 'best']) –

    Useful in case of multi-reference rouge score (see the sketch after this parameter list).

    • avg takes the avg of all references with respect to predictions

    • best takes the best fmeasure score obtained between prediction and multiple corresponding references.

  • use_stemmer (bool) – Use Porter stemmer to strip word suffixes to improve matching.

  • normalizer (Optional[Callable[[str], str]]) – A user’s own normalizer function. If this is None, replacing any non-alpha-numeric characters with spaces is default. This function must take a str and return a str.

  • tokenizer (Optional[Callable[[str], Sequence[str]]]) – A user’s own tokenizer function. If this is None, splitting by spaces is the default. This function must take a str and return a Sequence[str].

  • rouge_keys (Union[str, Tuple[str, ...]]) – A list of rouge types to calculate. Keys that are allowed are rougeL, rougeLsum, and rouge1 through rouge9.
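
The sketch below illustrates the two accumulate modes for a multi-reference target; it is a hedged example with made-up sentences, and the resulting score dictionaries are intentionally not shown:

>>> from torchmetrics.functional.text.rouge import rouge_score
>>> preds = ["My name is John"]
>>> multi_ref_target = [["Is your name John", "My name is actually John"]]
>>> scores_best = rouge_score(preds, multi_ref_target, accumulate="best")
>>> scores_avg = rouge_score(preds, multi_ref_target, accumulate="avg")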

Return type:

Dict[str, Tensor]

Returns:

Python dictionary of rouge scores for each input rouge key.

Example

>>> from torchmetrics.functional.text.rouge import rouge_score
>>> preds = "My name is John"
>>> target = "Is your name John"
>>> from pprint import pprint
>>> pprint(rouge_score(preds, target))
{'rouge1_fmeasure': tensor(0.7500),
 'rouge1_precision': tensor(0.7500),
 'rouge1_recall': tensor(0.7500),
 'rouge2_fmeasure': tensor(0.),
 'rouge2_precision': tensor(0.),
 'rouge2_recall': tensor(0.),
 'rougeL_fmeasure': tensor(0.5000),
 'rougeL_precision': tensor(0.5000),
 'rougeL_recall': tensor(0.5000),
 'rougeLsum_fmeasure': tensor(0.5000),
 'rougeLsum_precision': tensor(0.5000),
 'rougeLsum_recall': tensor(0.5000)}
Raises:
  • ModuleNotFoundError – If the python package nltk is not installed.

  • ValueError – If any of the rouge_keys does not belong to the allowed set of keys.

References

[1] ROUGE: A Package for Automatic Evaluation of Summaries by Chin-Yew Lin. https://aclanthology.org/W04-1013/

Sacre BLEU Score

Module Interface

class torchmetrics.text.SacreBLEUScore(n_gram=4, smooth=False, tokenize='13a', lowercase=False, weights=None, **kwargs)[source]

Calculate BLEU score of machine translated text with one or more references.

This implementation follows the behaviour of SacreBLEU. The SacreBLEU implementation differs from the NLTK BLEU implementation in tokenization techniques.

As input to forward and update the metric accepts the following input:

  • preds (Sequence): An iterable of machine translated corpus

  • target (Sequence): An iterable of iterables of reference corpus

As output of forward and compute the metric returns the following output:

  • sacre_bleu (Tensor): A tensor with the SacreBLEU Score

Parameters:
  • n_gram (int) – Gram value, ranging from 1 to 4

  • smooth (bool) – Whether to apply smoothing, see SacreBLEU

  • tokenize (Literal['none', '13a', 'zh', 'intl', 'char']) – Tokenization technique to be used. Supported tokenization: ['none', '13a', 'zh', 'intl', 'char']

  • lowercase (bool) – If True, BLEU score over lowercased text is calculated.

  • kwargs (Any) – Additional keyword arguments, see Advanced metric settings for more info.

  • weights (Optional[Sequence[float]]) – Weights used for unigrams, bigrams, etc. to calculate BLEU score. If not provided, uniform weights are used.

Raises:
  • ValueError – If tokenize not one of ‘none’, ‘13a’, ‘zh’, ‘intl’ or ‘char’

  • ValueError – If tokenize is set to ‘intl’ and regex is not installed

  • ValueError – If the list of weights is provided and its length is not equal to n_gram.

Example

>>> from torchmetrics.text import SacreBLEUScore
>>> preds = ['the cat is on the mat']
>>> target = [['there is a cat on the mat', 'a cat is on the mat']]
>>> sacre_bleu = SacreBLEUScore()
>>> sacre_bleu(preds, target)
tensor(0.7598)

Additional References:

  • Automatic Evaluation of Machine Translation Quality Using Longest Common Subsequence and Skip-Bigram Statistics by Chin-Yew Lin and Franz Josef Och Machine Translation Evolution

plot(val=None, ax=None)[source]

Plot a single or multiple values from the metric.

Parameters:
  • val (Union[Tensor, Sequence[Tensor], None]) – Either a single result from calling metric.forward or metric.compute or a list of these results. If no value is provided, will automatically call metric.compute and plot that result.

  • ax (Optional[Axes]) – A matplotlib axis object. If provided, will add plot to that axis

Return type:

Tuple[Figure, Union[Axes, ndarray]]

Returns:

Figure and Axes object

Raises:

ModuleNotFoundError – If matplotlib is not installed

>>> # Example plotting a single value
>>> from torchmetrics.text import SacreBLEUScore
>>> metric = SacreBLEUScore()
>>> preds = ['the cat is on the mat']
>>> target = [['there is a cat on the mat', 'a cat is on the mat']]
>>> metric.update(preds, target)
>>> fig_, ax_ = metric.plot()
_images/sacre_bleu_score-1.png
>>> # Example plotting multiple values
>>> from torchmetrics.text import SacreBLEUScore
>>> metric = SacreBLEUScore()
>>> preds = ['the cat is on the mat']
>>> target = [['there is a cat on the mat', 'a cat is on the mat']]
>>> values = [ ]
>>> for _ in range(10):
...     values.append(metric(preds, target))
>>> fig_, ax_ = metric.plot(values)
_images/sacre_bleu_score-2.png

Functional Interface

torchmetrics.functional.text.sacre_bleu_score(preds, target, n_gram=4, smooth=False, tokenize='13a', lowercase=False, weights=None)[source]

Calculate BLEU score [1] of machine translated text with one or more references.

This implementation follows the behaviour of SacreBLEU [2] implementation from https://github.com/mjpost/sacrebleu.

Parameters:
  • preds (Sequence[str]) – An iterable of machine translated corpus

  • target (Sequence[Sequence[str]]) – An iterable of iterables of reference corpus

  • n_gram (int) – Gram value, ranging from 1 to 4

  • smooth (bool) – Whether to apply smoothing - see [2]

  • tokenize (Literal['none', '13a', 'zh', 'intl', 'char']) – Tokenization technique to be used. Supported tokenization: [‘none’, ‘13a’, ‘zh’, ‘intl’, ‘char’]

  • lowercase (bool) – If True, BLEU score over lowercased text is calculated.

  • weights (Optional[Sequence[float]]) – Weights used for unigrams, bigrams, etc. to calculate BLEU score. If not provided, uniform weights are used.

Return type:

Tensor

Returns:

Tensor with BLEU Score

Raises:
  • ValueError – If preds and target corpus have different lengths.

  • ValueError – If the list of weights is provided and its length is not equal to n_gram.

Example

>>> from torchmetrics.functional.text import sacre_bleu_score
>>> preds = ['the cat is on the mat']
>>> target = [['there is a cat on the mat', 'a cat is on the mat']]
>>> sacre_bleu_score(preds, target)
tensor(0.7598)
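
Building on the example above, a hedged sketch of varying a few of the parameters (custom n-gram weights, international tokenization, lowercasing); the resulting scores are intentionally not shown:

>>> score_bigram = sacre_bleu_score(preds, target, n_gram=2, weights=[0.75, 0.25])
>>> score_intl = sacre_bleu_score(preds, target, tokenize="intl", lowercase=True)  # requires the regex package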

References

[1] BLEU: a Method for Automatic Evaluation of Machine Translation by Papineni, Kishore, Salim Roukos, Todd Ward, and Wei-Jing Zhu BLEU

[2] A Call for Clarity in Reporting BLEU Scores by Matt Post.

[3] Automatic Evaluation of Machine Translation Quality Using Longest Common Subsequence and Skip-Bigram Statistics by Chin-Yew Lin and Franz Josef Och Machine Translation Evolution

SQuAD

Module Interface

class torchmetrics.text.SQuAD(**kwargs)[source]

Calculate SQuAD Metric which is a metric for evaluating question answering models.

This metric corresponds to the scoring script for version 1 of the Stanford Question Answering Dataset (SQuAD).

As input to forward and update the metric accepts the following input:

  • preds (Dict): A dictionary or a list of dictionaries that map id and prediction_text to the respective values

    Example prediction:

    {"prediction_text": "TorchMetrics is awesome", "id": "123"}
    
  • target (Dict): A dictionary or a list of dictionaries that contain the answers and id in the SQuAD format.

    Example target:

    {
        'answers': [{'answer_start': [1], 'text': ['This is a test answer']}],
        'id': '1',
    }
    

    Reference SQuAD Format:

    {
        'answers': {'answer_start': [1], 'text': ['This is a test text']},
        'context': 'This is a test context.',
        'id': '1',
        'question': 'Is this a test?',
        'title': 'train test'
    }
    

As output of forward and compute the metric returns the following output:

  • squad (Dict): A dictionary containing the F1 score (key: “f1”) and the Exact Match score (key: “exact_match”) for the batch.

Parameters:

kwargs (Any) – Additional keyword arguments, see Advanced metric settings for more info.

Example

>>> from torchmetrics.text import SQuAD
>>> preds = [{"prediction_text": "1976", "id": "56e10a3be3433e1400422b22"}]
>>> target = [{"answers": {"answer_start": [97], "text": ["1976"]}, "id": "56e10a3be3433e1400422b22"}]
>>> squad = SQuAD()
>>> squad(preds, target)
{'exact_match': tensor(100.), 'f1': tensor(100.)}
plot(val=None, ax=None)[source]

Plot a single or multiple values from the metric.

Parameters:
  • val (Union[Tensor, Sequence[Tensor], None]) – Either a single result from calling metric.forward or metric.compute or a list of these results. If no value is provided, will automatically call metric.compute and plot that result.

  • ax (Optional[Axes]) – A matplotlib axis object. If provided, will add plot to that axis

Return type:

Tuple[Figure, Union[Axes, ndarray]]

Returns:

Figure and Axes object

Raises:

ModuleNotFoundError – If matplotlib is not installed

>>> # Example plotting a single value
>>> from torchmetrics.text import SQuAD
>>> metric = SQuAD()
>>> preds = [{"prediction_text": "1976", "id": "56e10a3be3433e1400422b22"}]
>>> target = [{"answers": {"answer_start": [97], "text": ["1976"]}, "id": "56e10a3be3433e1400422b22"}]
>>> metric.update(preds, target)
>>> fig_, ax_ = metric.plot()
_images/squad-1.png
>>> # Example plotting multiple values
>>> from torchmetrics.text import SQuAD
>>> metric = SQuAD()
>>> preds = [{"prediction_text": "1976", "id": "56e10a3be3433e1400422b22"}]
>>> target = [{"answers": {"answer_start": [97], "text": ["1976"]}, "id": "56e10a3be3433e1400422b22"}]
>>> values = [ ]
>>> for _ in range(10):
...     values.append(metric(preds, target))
>>> fig_, ax_ = metric.plot(values)
_images/squad-2.png

Functional Interface

torchmetrics.functional.text.squad(preds, target)[source]

Calculate SQuAD Metric.

Parameters:
  • preds (Union[Dict[str, str], List[Dict[str, str]]]) –

    A dictionary or a list of dictionaries that map id and prediction_text to the respective values.

    Example prediction:

    {"prediction_text": "TorchMetrics is awesome", "id": "123"}
    

  • target (Union[Dict[str, Union[str, Dict[str, Union[List[str], List[int]]]]], List[Dict[str, Union[str, Dict[str, Union[List[str], List[int]]]]]]]) –

    A dictionary or a list of dictionaries that contain the answers and id in the SQuAD format.

    Example target:

    {
        'answers': [{'answer_start': [1], 'text': ['This is a test answer']}],
        'id': '1',
    }
    

    Reference SQuAD Format:

    {
        'answers': {'answer_start': [1], 'text': ['This is a test text']},
        'context': 'This is a test context.',
        'id': '1',
        'question': 'Is this a test?',
        'title': 'train test'
    }
    

Return type:

Dict[str, Tensor]

Returns:

Dictionary containing the F1 score, Exact match score for the batch.

Example

>>> from torchmetrics.functional.text.squad import squad
>>> preds = [{"prediction_text": "1976", "id": "56e10a3be3433e1400422b22"}]
>>> target = [{"answers": {"answer_start": [97], "text": ["1976"]},"id": "56e10a3be3433e1400422b22"}]
>>> squad(preds, target)
{'exact_match': tensor(100.), 'f1': tensor(100.)}
Raises:

KeyError – If the required keys are missing in either predictions or targets.

References

[1] SQuAD: 100,000+ Questions for Machine Comprehension of Text by Pranav Rajpurkar, Jian Zhang, Konstantin Lopyrev, Percy Liang SQuAD Metric.
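
A hedged sketch (the ids, offsets and texts below are illustrative only) of converting raw model answers and gold annotations into the dictionary format that squad expects:

>>> from torchmetrics.functional.text.squad import squad
>>> raw_answers = {"56e10a3be3433e1400422b22": "1976"}
>>> gold = {"56e10a3be3433e1400422b22": {"answer_start": [97], "text": ["1976"]}}
>>> preds = [{"id": idx, "prediction_text": answer} for idx, answer in raw_answers.items()]
>>> target = [{"id": idx, "answers": answers} for idx, answers in gold.items()]
>>> squad(preds, target)
{'exact_match': tensor(100.), 'f1': tensor(100.)}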

Translation Edit Rate (TER)

Module Interface

class torchmetrics.text.TranslationEditRate(normalize=False, no_punctuation=False, lowercase=True, asian_support=False, return_sentence_level_score=False, **kwargs)[source]

Calculate Translation edit rate (TER) of machine translated text with one or more references.

This implementation follows the one from SacreBLEU (TER), which is a near-exact reimplementation of the Tercom algorithm and produces identical results on all “sane” outputs.

As input to forward and update the metric accepts the following input:

  • preds (Sequence): An iterable of hypothesis corpus

  • target (Sequence): An iterable of iterables of reference corpus

As output of forward and compute the metric returns the following output:

  • ter (Tensor): A corpus-level translation edit rate; if return_sentence_level_score=True, a list of sentence-level translation edit rates is also returned

Parameters:
  • normalize (bool) – An indication of whether general tokenization should be applied.

  • no_punctuation (bool) – An indication of whether punctuation should be removed from the sentences.

  • lowercase (bool) – An indication of whether to enable case-insensitivity.

  • asian_support (bool) – An indication of whether Asian characters should be processed.

  • return_sentence_level_score (bool) – An indication of whether a sentence-level TER should be returned.

  • kwargs (Any) – Additional keyword arguments, see Advanced metric settings for more info.

Example

>>> from torchmetrics.text import TranslationEditRate
>>> preds = ['the cat is on the mat']
>>> target = [['there is a cat on the mat', 'a cat is on the mat']]
>>> ter = TranslationEditRate()
>>> ter(preds, target)
tensor(0.1538)
plot(val=None, ax=None)[source]

Plot a single or multiple values from the metric.

Parameters:
  • val (Union[Tensor, Sequence[Tensor], None]) – Either a single result from calling metric.forward or metric.compute or a list of these results. If no value is provided, will automatically call metric.compute and plot that result.

  • ax (Optional[Axes]) – A matplotlib axis object. If provided, will add plot to that axis

Return type:

Tuple[Figure, Union[Axes, ndarray]]

Returns:

Figure and Axes object

Raises:

ModuleNotFoundError – If matplotlib is not installed

>>> # Example plotting a single value
>>> from torchmetrics.text import TranslationEditRate
>>> metric = TranslationEditRate()
>>> preds = ['the cat is on the mat']
>>> target = [['there is a cat on the mat', 'a cat is on the mat']]
>>> metric.update(preds, target)
>>> fig_, ax_ = metric.plot()
_images/translation_edit_rate-1.png
>>> # Example plotting multiple values
>>> from torchmetrics.text import TranslationEditRate
>>> metric = TranslationEditRate()
>>> preds = ['the cat is on the mat']
>>> target = [['there is a cat on the mat', 'a cat is on the mat']]
>>> values = [ ]
>>> for _ in range(10):
...     values.append(metric(preds, target))
>>> fig_, ax_ = metric.plot(values)
_images/translation_edit_rate-2.png

Functional Interface

torchmetrics.functional.text.translation_edit_rate(preds, target, normalize=False, no_punctuation=False, lowercase=True, asian_support=False, return_sentence_level_score=False)[source]

Calculate Translation edit rate (TER) of machine translated text with one or more references.

This implementation follows the implementation from https://github.com/mjpost/sacrebleu/blob/master/sacrebleu/metrics/ter.py. The sacrebleu implementation is a near-exact reimplementation of the Tercom algorithm and produces identical results on all “sane” outputs.

Parameters:
  • preds (Union[str, Sequence[str]]) – An iterable of hypothesis corpus.

  • target (Sequence[Union[str, Sequence[str]]]) – An iterable of iterables of reference corpus.

  • normalize (bool) – An indication of whether general tokenization should be applied.

  • no_punctuation (bool) – An indication of whether punctuation should be removed from the sentences.

  • lowercase (bool) – An indication of whether to enable case-insensitivity.

  • asian_support (bool) – An indication of whether Asian characters should be processed.

  • return_sentence_level_score (bool) – An indication of whether a sentence-level TER should be returned.

Return type:

Union[Tensor, Tuple[Tensor, List[Tensor]]]

Returns:

A corpus-level translation edit rate (TER). (Optionally) A list of sentence-level translation_edit_rate (TER) if return_sentence_level_score=True.

Example

>>> from torchmetrics.functional.text import translation_edit_rate
>>> preds = ['the cat is on the mat']
>>> target = [['there is a cat on the mat', 'a cat is on the mat']]
>>> translation_edit_rate(preds, target)
tensor(0.1538)
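
A hedged sketch, reusing preds and target from the example above, of also requesting sentence-level scores (the returned values are not shown here):

>>> corpus_ter, sentence_ter = translation_edit_rate(
...     preds, target, return_sentence_level_score=True
... )
>>> # corpus_ter is a scalar tensor, sentence_ter is a list with one tensor per hypothesis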

References

[1] A Study of Translation Edit Rate with Targeted Human Annotation by Mathew Snover, Bonnie Dorr, Richard Schwartz, Linnea Micciulla and John Makhoul TER

Word Error Rate

Module Interface

class torchmetrics.text.WordErrorRate(**kwargs)[source]

Word error rate (WordErrorRate) is a common metric of the performance of an automatic speech recognition system.

This value indicates the percentage of words that were incorrectly predicted. The lower the value, the better the performance of the ASR system with a WER of 0 being a perfect score. Word error rate can then be computed as:

\[WER = \frac{S + D + I}{N} = \frac{S + D + I}{S + D + C}\]

where:

  • \(S\) is the number of substitutions,

  • \(D\) is the number of deletions,

  • \(I\) is the number of insertions,

  • \(C\) is the number of correct words,

  • \(N\) is the number of words in the reference (\(N=S+D+C\)).
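
For the two sentences in the example below, one possible word-level alignment gives \(S=3\), \(D=0\), \(I=1\) and \(C=5\) (so \(N=8\); the counts are illustrative, not produced by the library), hence

\[WER = \frac{S + D + I}{N} = \frac{3 + 0 + 1}{8} = 0.5\]

which matches the tensor(0.5000) returned by the metric.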

Compute WER score of transcribed segments against references.

As input to forward and update the metric accepts the following input:

  • preds (List): Transcription(s) to score as a string or list of strings

  • target (List): Reference(s) for each speech input as a string or list of strings

As output of forward and compute the metric returns the following output:

  • wer (Tensor): A tensor with the Word Error Rate score

Parameters:

kwargs (Any) – Additional keyword arguments, see Advanced metric settings for more info.

Examples

>>> from torchmetrics.text import WordErrorRate
>>> preds = ["this is the prediction", "there is an other sample"]
>>> target = ["this is the reference", "there is another one"]
>>> wer = WordErrorRate()
>>> wer(preds, target)
tensor(0.5000)
plot(val=None, ax=None)[source]

Plot a single or multiple values from the metric.

Parameters:
  • val (Union[Tensor, Sequence[Tensor], None]) – Either a single result from calling metric.forward or metric.compute or a list of these results. If no value is provided, will automatically call metric.compute and plot that result.

  • ax (Optional[Axes]) – A matplotlib axis object. If provided, will add plot to that axis

Return type:

Tuple[Figure, Union[Axes, ndarray]]

Returns:

Figure and Axes object

Raises:

ModuleNotFoundError – If matplotlib is not installed

>>> # Example plotting a single value
>>> from torchmetrics.text import WordErrorRate
>>> metric = WordErrorRate()
>>> preds = ["this is the prediction", "there is an other sample"]
>>> target = ["this is the reference", "there is another one"]
>>> metric.update(preds, target)
>>> fig_, ax_ = metric.plot()
_images/word_error_rate-1.png
>>> # Example plotting multiple values
>>> from torchmetrics.text import WordErrorRate
>>> metric = WordErrorRate()
>>> preds = ["this is the prediction", "there is an other sample"]
>>> target = ["this is the reference", "there is another one"]
>>> values = [ ]
>>> for _ in range(10):
...     values.append(metric(preds, target))
>>> fig_, ax_ = metric.plot(values)
_images/word_error_rate-2.png

Functional Interface

torchmetrics.functional.text.word_error_rate(preds, target)[source]

Word error rate (WordErrorRate) is a common metric of the performance of an automatic speech recognition system.

This value indicates the percentage of words that were incorrectly predicted. The lower the value, the better the performance of the ASR system with a WER of 0 being a perfect score.

Parameters:
  • preds (Union[str, List[str]]) – Transcription(s) to score as a string or list of strings

  • target (Union[str, List[str]]) – Reference(s) for each speech input as a string or list of strings

Return type:

Tensor

Returns:

Word error rate score

Examples

>>> preds = ["this is the prediction", "there is an other sample"]
>>> target = ["this is the reference", "there is another one"]
>>> word_error_rate(preds=preds, target=target)
tensor(0.5000)

Word Info. Lost

Module Interface

class torchmetrics.text.WordInfoLost(**kwargs)[source]

Word Information Lost (WIL) is a metric of the performance of an automatic speech recognition system.

This value indicates the percentage of words that were incorrectly predicted between a set of ground-truth sentences and a set of hypothesis sentences. The lower the value, the better the performance of the ASR system with a WordInfoLost of 0 being a perfect score. Word Information Lost rate can then be computed as:

\[wil = 1 - \frac{C}{N} \cdot \frac{C}{P}\]

where:

  • \(C\) is the number of correct words,

  • \(N\) is the number of words in the reference

  • \(P\) is the number of words in the prediction
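
For the two sentences in the example below, one possible word-level alignment gives \(C=5\), \(N=8\) and \(P=9\) (the counts are illustrative, not produced by the library), so

\[wil = 1 - \frac{5}{8} \cdot \frac{5}{9} = 1 - \frac{25}{72} \approx 0.6528\]

matching the value returned by the metric.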

As input to forward and update the metric accepts the following input:

  • preds (List): Transcription(s) to score as a string or list of strings

  • target (List): Reference(s) for each speech input as a string or list of strings

As output of forward and compute the metric returns the following output:

  • wil (Tensor): A tensor with the Word Information Lost score

Parameters:

kwargs (Any) – Additional keyword arguments, see Advanced metric settings for more info.

Examples

>>> from torchmetrics.text import WordInfoLost
>>> preds = ["this is the prediction", "there is an other sample"]
>>> target = ["this is the reference", "there is another one"]
>>> wil = WordInfoLost()
>>> wil(preds, target)
tensor(0.6528)
plot(val=None, ax=None)[source]

Plot a single or multiple values from the metric.

Parameters:
  • val (Union[Tensor, Sequence[Tensor], None]) – Either a single result from calling metric.forward or metric.compute or a list of these results. If no value is provided, will automatically call metric.compute and plot that result.

  • ax (Optional[Axes]) – A matplotlib axis object. If provided, will add plot to that axis

Return type:

Tuple[Figure, Union[Axes, ndarray]]

Returns:

Figure and Axes object

Raises:

ModuleNotFoundError – If matplotlib is not installed

>>> # Example plotting a single value
>>> from torchmetrics.text import WordInfoLost
>>> metric = WordInfoLost()
>>> preds = ["this is the prediction", "there is an other sample"]
>>> target = ["this is the reference", "there is another one"]
>>> metric.update(preds, target)
>>> fig_, ax_ = metric.plot()
_images/word_info_lost-1.png
>>> # Example plotting multiple values
>>> from torchmetrics.text import WordInfoLost
>>> metric = WordInfoLost()
>>> preds = ["this is the prediction", "there is an other sample"]
>>> target = ["this is the reference", "there is another one"]
>>> values = [ ]
>>> for _ in range(10):
...     values.append(metric(preds, target))
>>> fig_, ax_ = metric.plot(values)
_images/word_info_lost-2.png

Functional Interface

torchmetrics.functional.text.word_information_lost(preds, target)[source]

Word Information Lost rate is a metric of the performance of an automatic speech recognition system.

This value indicates the percentage of words that were incorrectly predicted between a set of ground-truth sentences and a set of hypothesis sentences. The lower the value, the better the performance of the ASR system, with a Word Information Lost rate of 0 being a perfect score.

Parameters:
  • preds (Union[str, List[str]]) – Transcription(s) to score as a string or list of strings

  • target (Union[str, List[str]]) – Reference(s) for each speech input as a string or list of strings

Return type:

Tensor

Returns:

Word Information Lost rate

Examples

>>> from torchmetrics.functional.text import word_information_lost
>>> preds = ["this is the prediction", "there is an other sample"]
>>> target = ["this is the reference", "there is another one"]
>>> word_information_lost(preds, target)
tensor(0.6528)

Word Info. Preserved

Module Interface

class torchmetrics.text.WordInfoPreserved(**kwargs)[source]

Word Information Preserved (WIP) is a metric of the performance of an automatic speech recognition system.

This value indicates the percentage of words that were correctly predicted between a set of ground-truth sentences and a set of hypothesis sentences. The higher the value, the better the performance of the ASR system, with a WordInfoPreserved of 1 being a perfect score. Word Information Preserved rate can then be computed as:

\[wip = \frac{C}{N} \cdot \frac{C}{P}\]

where:

  • \(C\) is the number of correct words,

  • \(N\) is the number of words in the reference

  • \(P\) is the number of words in the prediction
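
For the two sentences in the example below, one possible word-level alignment gives \(C=5\), \(N=8\) and \(P=9\) (the counts are illustrative, not produced by the library), so

\[wip = \frac{5}{8} \cdot \frac{5}{9} = \frac{25}{72} \approx 0.3472\]

matching the value returned by the metric.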

As input to forward and update the metric accepts the following input:

  • preds (List): Transcription(s) to score as a string or list of strings

  • target (List): Reference(s) for each speech input as a string or list of strings

As output of forward and compute the metric returns the following output:

  • wip (Tensor): A tensor with the Word Information Preserved score

Parameters:

kwargs (Any) – Additional keyword arguments, see Advanced metric settings for more info.

Examples

>>> from torchmetrics.text import WordInfoPreserved
>>> preds = ["this is the prediction", "there is an other sample"]
>>> target = ["this is the reference", "there is another one"]
>>> wip = WordInfoPreserved()
>>> wip(preds, target)
tensor(0.3472)
plot(val=None, ax=None)[source]

Plot a single or multiple values from the metric.

Parameters:
  • val (Union[Tensor, Sequence[Tensor], None]) – Either a single result from calling metric.forward or metric.compute or a list of these results. If no value is provided, will automatically call metric.compute and plot that result.

  • ax (Optional[Axes]) – A matplotlib axis object. If provided, will add plot to that axis

Return type:

Tuple[Figure, Union[Axes, ndarray]]

Returns:

Figure and Axes object

Raises:

ModuleNotFoundError – If matplotlib is not installed

>>> # Example plotting a single value
>>> from torchmetrics.text import WordInfoPreserved
>>> metric = WordInfoPreserved()
>>> preds = ["this is the prediction", "there is an other sample"]
>>> target = ["this is the reference", "there is another one"]
>>> metric.update(preds, target)
>>> fig_, ax_ = metric.plot()
_images/word_info_preserved-1.png
>>> # Example plotting multiple values
>>> from torchmetrics.text import WordInfoPreserved
>>> metric = WordInfoPreserved()
>>> preds = ["this is the prediction", "there is an other sample"]
>>> target = ["this is the reference", "there is another one"]
>>> values = [ ]
>>> for _ in range(10):
...     values.append(metric(preds, target))
>>> fig_, ax_ = metric.plot(values)
_images/word_info_preserved-2.png

Functional Interface

torchmetrics.functional.text.word_information_preserved(preds, target)[source]

Word Information Preserved rate is a metric of the performance of an automatic speech recognition system.

This value indicates the percentage of words that were correctly predicted between a set of ground-truth sentences and a set of hypothesis sentences. The higher the value, the better the performance of the ASR system, with a Word Information Preserved rate of 1 being a perfect score.

Parameters:
  • preds (Union[str, List[str]]) – Transcription(s) to score as a string or list of strings

  • target (Union[str, List[str]]) – Reference(s) for each speech input as a string or list of strings

Return type:

Tensor

Returns:

Word Information preserved rate

Examples

>>> from torchmetrics.functional.text import word_information_preserved
>>> preds = ["this is the prediction", "there is an other sample"]
>>> target = ["this is the reference", "there is another one"]
>>> word_information_preserved(preds, target)
tensor(0.3472)

Bootstrapper

Module Interface

class torchmetrics.wrappers.BootStrapper(base_metric, num_bootstraps=10, mean=True, std=True, quantile=None, raw=False, sampling_strategy='poisson', **kwargs)[source]

Use this wrapper to turn a Metric into a bootstrapped version of itself.

This can automate the process of getting confidence intervals for metric values. This wrapper class basically keeps multiple copies of the same base metric in memory and whenever update or forward is called, all input tensors are resampled (with replacement) along the first dimension.

Parameters:
  • base_metric (Metric) – base metric class to wrap

  • num_bootstraps (int) – number of copies to make of the base metric for bootstrapping

  • mean (bool) – if True return the mean of the bootstraps

  • std (bool) – if True return the standard deviation of the bootstraps

  • quantile (Union[float, Tensor, None]) – if given, returns the quantile of the bootstraps. Can only be used with pytorch version 1.6 or higher

  • raw (bool) – if True, return all bootstrapped values

  • sampling_strategy (str) – Determines how to produce bootstrapped samplings. Either 'poisson' or 'multinomial'. If 'poisson' is chosen, the number of times each sample will be included in the bootstrap will be given by \(n\sim Poisson(\lambda=1)\), which approximates the true bootstrap distribution when the number of samples is large. If 'multinomial' is chosen, we will apply true bootstrapping at the batch level to approximate bootstrapping over the whole dataset.

  • kwargs (Any) – Additional keyword arguments, see Advanced metric settings for more info.

Example::
>>> import torch
>>> from pprint import pprint
>>> from torchmetrics.wrappers import BootStrapper
>>> from torchmetrics.classification import MulticlassAccuracy
>>> _ = torch.manual_seed(123)
>>> base_metric = MulticlassAccuracy(num_classes=5, average='micro')
>>> bootstrap = BootStrapper(base_metric, num_bootstraps=20)
>>> bootstrap.update(torch.randint(5, (20,)), torch.randint(5, (20,)))
>>> output = bootstrap.compute()
>>> pprint(output)
{'mean': tensor(0.2205), 'std': tensor(0.0859)}
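
A hedged sketch (reusing the imports and random data pattern from the example above) of also requesting quantiles and the raw bootstrap values; the resulting tensors are not shown:

>>> bootstrap_full = BootStrapper(
...     MulticlassAccuracy(num_classes=5, average='micro'),
...     num_bootstraps=20, quantile=torch.tensor([0.05, 0.95]), raw=True,
... )
>>> bootstrap_full.update(torch.randint(5, (20,)), torch.randint(5, (20,)))
>>> full_output = bootstrap_full.compute()  # dict with 'mean', 'std', 'quantile' and 'raw' entries
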
compute()[source]

Compute the bootstrapped metric values.

Always returns a dict of tensors, which can contain the following keys: mean, std, quantile and raw depending on how the class was initialized.

Return type:

Dict[str, Tensor]

forward(*args, **kwargs)[source]

Use the original forward method of the base metric class.

Return type:

Any

plot(val=None, ax=None)[source]

Plot a single or multiple values from the metric.

Parameters:
  • val (Union[Tensor, Sequence[Tensor], None]) – Either a single result from calling metric.forward or metric.compute or a list of these results. If no value is provided, will automatically call metric.compute and plot that result.

  • ax (Optional[Axes]) – A matplotlib axis object. If provided, will add plot to that axis

Return type:

Tuple[Figure, Union[Axes, ndarray]]

Returns:

Figure and Axes object

Raises:

ModuleNotFoundError – If matplotlib is not installed

>>> # Example plotting a single value
>>> import torch
>>> from torchmetrics.wrappers import BootStrapper
>>> from torchmetrics.regression import MeanSquaredError
>>> metric = BootStrapper(MeanSquaredError(), num_bootstraps=20)
>>> metric.update(torch.randn(100,), torch.randn(100,))
>>> fig_, ax_ = metric.plot()
_images/bootstrapper-1.png
>>> # Example plotting multiple values
>>> import torch
>>> from torchmetrics.wrappers import BootStrapper
>>> from torchmetrics.regression import MeanSquaredError
>>> metric = BootStrapper(MeanSquaredError(), num_bootstraps=20)
>>> values = [ ]
>>> for _ in range(3):
...     values.append(metric(torch.randn(100,), torch.randn(100,)))
>>> fig_, ax_ = metric.plot(values)
_images/bootstrapper-2.png
update(*args, **kwargs)[source]

Update the state of the base metric.

Any tensor passed in will be bootstrapped along dimension 0.

Return type:

None

Classwise Wrapper

Module Interface

class torchmetrics.wrappers.ClasswiseWrapper(metric, labels=None, prefix=None, postfix=None)[source]

Wrapper metric for altering the output of classification metrics.

This metric works together with classification metrics that returns multiple values (one value per class) such that label information can be automatically included in the output.

Parameters:
  • metric (Metric) – base metric that should be wrapped. It is assumed that the metric outputs a single tensor that is split along the first dimension.

  • labels (Optional[List[str]]) – list of strings indicating the different classes.

  • prefix (Optional[str]) – string that is prepended to the metric names.

  • postfix (Optional[str]) – string that is appended to the metric names.

Example::

Basic example where the output of a metric is unwrapped into a dictionary with the class index as keys:

>>> import torch
>>> _ = torch.manual_seed(42)
>>> from torchmetrics.wrappers import ClasswiseWrapper
>>> from torchmetrics.classification import MulticlassAccuracy
>>> metric = ClasswiseWrapper(MulticlassAccuracy(num_classes=3, average=None))
>>> preds = torch.randn(10, 3).softmax(dim=-1)
>>> target = torch.randint(3, (10,))
>>> metric(preds, target)  
{'multiclassaccuracy_0': tensor(0.5000),
'multiclassaccuracy_1': tensor(0.7500),
'multiclassaccuracy_2': tensor(0.)}
Example::

Using custom name via prefix and postfix:

>>> import torch
>>> _ = torch.manual_seed(42)
>>> from torchmetrics.wrappers import ClasswiseWrapper
>>> from torchmetrics.classification import MulticlassAccuracy
>>> metric_pre = ClasswiseWrapper(MulticlassAccuracy(num_classes=3, average=None), prefix="acc-")
>>> metric_post = ClasswiseWrapper(MulticlassAccuracy(num_classes=3, average=None), postfix="-acc")
>>> preds = torch.randn(10, 3).softmax(dim=-1)
>>> target = torch.randint(3, (10,))
>>> metric_pre(preds, target)  
{'acc-0': tensor(0.5000),
 'acc-1': tensor(0.7500),
 'acc-2': tensor(0.)}
>>> metric_post(preds, target)  
{'0-acc': tensor(0.5000),
 '1-acc': tensor(0.7500),
 '2-acc': tensor(0.)}
Example::

Providing labels as a list of strings:

>>> from torchmetrics.wrappers import ClasswiseWrapper
>>> from torchmetrics.classification import MulticlassAccuracy
>>> metric = ClasswiseWrapper(
...    MulticlassAccuracy(num_classes=3, average=None),
...    labels=["horse", "fish", "dog"]
... )
>>> preds = torch.randn(10, 3).softmax(dim=-1)
>>> target = torch.randint(3, (10,))
>>> metric(preds, target)  
{'multiclassaccuracy_horse': tensor(0.3333),
'multiclassaccuracy_fish': tensor(0.6667),
'multiclassaccuracy_dog': tensor(0.)}
Example::

ClasswiseWrapper can also be used in combination with MetricCollection. In this case, everything will be flattened into a single dictionary:

>>> from torchmetrics import MetricCollection
>>> from torchmetrics.wrappers import ClasswiseWrapper
>>> from torchmetrics.classification import MulticlassAccuracy, MulticlassRecall
>>> labels = ["horse", "fish", "dog"]
>>> metric = MetricCollection(
...     {'multiclassaccuracy': ClasswiseWrapper(MulticlassAccuracy(num_classes=3, average=None), labels),
...     'multiclassrecall': ClasswiseWrapper(MulticlassRecall(num_classes=3, average=None), labels)}
... )
>>> preds = torch.randn(10, 3).softmax(dim=-1)
>>> target = torch.randint(3, (10,))
>>> metric(preds, target)  
{'multiclassaccuracy_horse': tensor(0.),
 'multiclassaccuracy_fish': tensor(0.3333),
 'multiclassaccuracy_dog': tensor(0.4000),
 'multiclassrecall_horse': tensor(0.),
 'multiclassrecall_fish': tensor(0.3333),
 'multiclassrecall_dog': tensor(0.4000)}
compute()[source]

Compute metric.

Return type:

Dict[str, Tensor]

forward(*args, **kwargs)[source]

Calculate on batch and accumulate to global state.

Return type:

Any

plot(val=None, ax=None)[source]

Plot a single or multiple values from the metric.

Parameters:
  • val (Union[Tensor, Sequence[Tensor], None]) – Either a single result from calling metric.forward or metric.compute or a list of these results. If no value is provided, will automatically call metric.compute and plot that result.

  • ax (Optional[Axes]) – A matplotlib axis object. If provided, will add plot to that axis

Return type:

Tuple[Figure, Union[Axes, ndarray]]

Returns:

Figure and Axes object

Raises:

ModuleNotFoundError – If matplotlib is not installed

>>> # Example plotting a single value
>>> import torch
>>> from torchmetrics.wrappers import ClasswiseWrapper
>>> from torchmetrics.classification import MulticlassAccuracy
>>> metric = ClasswiseWrapper(MulticlassAccuracy(num_classes=3, average=None))
>>> metric.update(torch.randint(3, (20,)), torch.randint(3, (20,)))
>>> fig_, ax_ = metric.plot()
_images/classwise_wrapper-1.png
>>> # Example plotting multiple values
>>> import torch
>>> from torchmetrics.wrappers import ClasswiseWrapper
>>> from torchmetrics.classification import MulticlassAccuracy
>>> metric = ClasswiseWrapper(MulticlassAccuracy(num_classes=3, average=None))
>>> values = [ ]
>>> for _ in range(3):
...     values.append(metric(torch.randint(3, (20,)), torch.randint(3, (20,))))
>>> fig_, ax_ = metric.plot(values)
_images/classwise_wrapper-2.png
reset()[source]

Reset metric.

Return type:

None

update(*args, **kwargs)[source]

Update state.

Return type:

None

Metric Tracker

Module Interface

class torchmetrics.wrappers.MetricTracker(metric, maximize=True)[source]

A wrapper class that can help keep track of a metric or metric collection over time.

The wrapper implements the standard .update(), .compute() and .reset() methods, which just call the corresponding method of the currently tracked metric. However, the following additional methods are provided:

  • MetricTracker.n_steps: number of metrics being tracked

  • MetricTracker.increment(): initialize a new metric for being tracked

  • MetricTracker.compute_all(): get the metric value for all steps

  • MetricTracker.best_metric(): returns the best value

Out of the box, this wrapper class fully supports tracking a single Metric, a MetricCollection, or another metric wrapper wrapped around a metric. However, multiple layers of nesting, such as using a Metric inside a MetricWrapper inside a MetricCollection, are not fully supported; in particular, the .best_metric method cannot automatically compute the best metric and index for such nested structures.

Parameters:
  • metric (Union[Metric, MetricCollection]) – instance of a torchmetrics.Metric or torchmetrics.MetricCollection to keep track of at each timestep.

  • maximize (Union[bool, List[bool]]) – either single bool or list of bool indicating if higher metric values are better (True) or lower is better (False).

Example (single metric):
>>> import torch
>>> from torchmetrics.wrappers import MetricTracker
>>> from torchmetrics.classification import MulticlassAccuracy
>>> _ = torch.manual_seed(42)
>>> tracker = MetricTracker(MulticlassAccuracy(num_classes=10, average='micro'))
>>> for epoch in range(5):
...     tracker.increment()
...     for batch_idx in range(5):
...         preds, target = torch.randint(10, (100,)), torch.randint(10, (100,))
...         tracker.update(preds, target)
...     print(f"current acc={tracker.compute()}")
current acc=0.1120000034570694
current acc=0.08799999952316284
current acc=0.12600000202655792
current acc=0.07999999821186066
current acc=0.10199999809265137
>>> best_acc, which_epoch = tracker.best_metric(return_step=True)
>>> best_acc  
0.1260...
>>> which_epoch
2
>>> tracker.compute_all()
tensor([0.1120, 0.0880, 0.1260, 0.0800, 0.1020])
Example (multiple metrics using MetricCollection):
>>> import torch
>>> from torchmetrics.wrappers import MetricTracker
>>> from torchmetrics import MetricCollection
>>> from torchmetrics.regression import MeanSquaredError, ExplainedVariance
>>> _ = torch.manual_seed(42)
>>> tracker = MetricTracker(MetricCollection([MeanSquaredError(), ExplainedVariance()]), maximize=[False, True])
>>> for epoch in range(5):
...     tracker.increment()
...     for batch_idx in range(5):
...         preds, target = torch.randn(100), torch.randn(100)
...         tracker.update(preds, target)
...     print(f"current stats={tracker.compute()}")  
current stats={'MeanSquaredError': tensor(1.8218), 'ExplainedVariance': tensor(-0.8969)}
current stats={'MeanSquaredError': tensor(2.0268), 'ExplainedVariance': tensor(-1.0206)}
current stats={'MeanSquaredError': tensor(1.9491), 'ExplainedVariance': tensor(-0.8298)}
current stats={'MeanSquaredError': tensor(1.9800), 'ExplainedVariance': tensor(-0.9199)}
current stats={'MeanSquaredError': tensor(2.2481), 'ExplainedVariance': tensor(-1.1622)}
>>> from pprint import pprint
>>> best_res, which_epoch = tracker.best_metric(return_step=True)
>>> pprint(best_res)  
{'ExplainedVariance': -0.829...,
 'MeanSquaredError': 1.821...}
>>> which_epoch
{'MeanSquaredError': 0, 'ExplainedVariance': 2}
>>> pprint(tracker.compute_all())
{'ExplainedVariance': tensor([-0.8969, -1.0206, -0.8298, -0.9199, -1.1622]),
 'MeanSquaredError': tensor([1.8218, 2.0268, 1.9491, 1.9800, 2.2481])}
best_metric(return_step=False)[source]

Return the best metric value out of all tracked.

Parameters:

return_step (bool) – If True will also return the step with the highest metric value.

Return type:

Union[None, float, Tuple[float, int], Tuple[None, None], Dict[str, Optional[float]], Tuple[Dict[str, Optional[float]], Dict[str, Optional[int]]]]

Returns:

Either a single value or a tuple, depending on the value of return_step and the object being tracked.

  • If a single metric is being tracked and return_step=False then a single tensor will be returned

  • If a single metric is being tracked and return_step=True then a 2-element tuple will be returned, where the first value is the optimal value and the second value is the corresponding optimal step

  • If a metric collection is being tracked and return_step=False then a single dict will be returned, where keys correspond to the different values of the collection and the values are the optimal metric values

  • If a metric collection is being tracked and return_step=True then a 2-element tuple will be returned where each element is a dict, with keys corresponding to the different values of the collection; the values of the first dict are the optimal values and the values of the second dict are the optimal steps

In addition, the value in all cases may be None if the underlying metric does not have a properly defined way of being optimal, or in the case where a nested structure of metrics is being tracked.

compute_all()[source]

Compute the metric value for all tracked metrics.

Return type:

Any

Returns:

By default, will try stacking the results from all increments into a single tensor if the tracked base object is a single metric. If a metric collection is provided, a dict of stacked tensors will be returned. If the stacking process fails, a list of the computed results will be returned.

Raises:

ValueError – If self.increment has not been called before this method is called.

forward(*args, **kwargs)[source]

Call forward of the current metric being tracked.

Return type:

None

increment()[source]

Create a new instance of the input metric that will be updated next.

Return type:

None

plot(val=None, ax=None)[source]

Plot a single or multiple values from the metric.

Parameters:
  • val (Union[Tensor, Sequence[Tensor], None]) – Either a single result from calling metric.forward or metric.compute or a list of these results. If no value is provided, will automatically call metric.compute and plot that result.

  • ax (Optional[Axes]) – A matplotlib axis object. If provided, will add plot to that axis

Return type:

Tuple[Figure, Union[Axes, ndarray]]

Returns:

Figure and Axes object

Raises:

ModuleNotFoundError – If matplotlib is not installed

>>> # Example plotting a single value
>>> import torch
>>> from torchmetrics.wrappers import MetricTracker
>>> from torchmetrics.classification import BinaryAccuracy
>>> tracker = MetricTracker(BinaryAccuracy())
>>> for epoch in range(5):
...     tracker.increment()
...     for batch_idx in range(5):
...         tracker.update(torch.randint(2, (10,)), torch.randint(2, (10,)))
>>> fig_, ax_ = tracker.plot()  # plot all epochs
_images/metric_tracker-1.png
reset()[source]

Reset the current metric being tracked.

Return type:

None

reset_all()[source]

Reset all metrics being tracked.

Return type:

None

property n_steps: int

Returns the number of times the tracker has been incremented.

Min / Max

Module Interface

class torchmetrics.wrappers.MinMaxMetric(base_metric, **kwargs)[source]

Wrapper metric that tracks both the minimum and maximum of a scalar/tensor across an experiment.

The min/max value will be updated each time .compute is called.

Parameters:
  • base_metric (Metric) – The metric whose maximum and minimum values you want to keep track of.

  • kwargs (Any) – Additional keyword arguments, see Advanced metric settings for more info.

Raises:

ValueError – If the base_metric argument is not a subclass instance of torchmetrics.Metric

Example::
>>> import torch
>>> from torchmetrics.wrappers import MinMaxMetric
>>> from torchmetrics.classification import BinaryAccuracy
>>> from pprint import pprint
>>> base_metric = BinaryAccuracy()
>>> minmax_metric = MinMaxMetric(base_metric)
>>> preds_1 = torch.Tensor([[0.1, 0.9], [0.2, 0.8]])
>>> preds_2 = torch.Tensor([[0.9, 0.1], [0.2, 0.8]])
>>> labels = torch.Tensor([[0, 1], [0, 1]]).long()
>>> pprint(minmax_metric(preds_1, labels))
{'max': tensor(1.), 'min': tensor(1.), 'raw': tensor(1.)}
>>> pprint(minmax_metric.compute())
{'max': tensor(1.), 'min': tensor(1.), 'raw': tensor(1.)}
>>> minmax_metric.update(preds_2, labels)
>>> pprint(minmax_metric.compute())
{'max': tensor(1.), 'min': tensor(0.7500), 'raw': tensor(0.7500)}
compute()[source]

Compute the underlying metric as well as max and min values for this metric.

Returns a dictionary that consists of the computed value (raw), as well as the minimum (min) and maximum (max) values.

Return type:

Dict[str, Tensor]

forward(*args, **kwargs)[source]

Use the original forward method of the base metric class.

Return type:

Any

plot(val=None, ax=None)[source]

Plot a single or multiple values from the metric.

Parameters:
  • val (Union[Tensor, Sequence[Tensor], None]) – Either a single result from calling metric.forward or metric.compute or a list of these results. If no value is provided, will automatically call metric.compute and plot that result.

  • ax (Optional[Axes]) – A matplotlib axis object. If provided, will add plot to that axis

Return type:

Tuple[Figure, Union[Axes, ndarray]]

Returns:

Figure and Axes object

Raises:

ModuleNotFoundError – If matplotlib is not installed

>>> # Example plotting a single value
>>> import torch
>>> from torchmetrics.wrappers import MinMaxMetric
>>> from torchmetrics.classification import BinaryAccuracy
>>> metric = MinMaxMetric(BinaryAccuracy())
>>> metric.update(torch.randint(2, (20,)), torch.randint(2, (20,)))
>>> fig_, ax_ = metric.plot()
_images/min_max-1.png
>>> # Example plotting multiple values
>>> import torch
>>> from torchmetrics.wrappers import MinMaxMetric
>>> from torchmetrics.classification import BinaryAccuracy
>>> metric = MinMaxMetric(BinaryAccuracy())
>>> values = [ ]
>>> for _ in range(3):
...     values.append(metric(torch.randint(2, (20,)), torch.randint(2, (20,))))
>>> fig_, ax_ = metric.plot(values)
_images/min_max-2.png
reset()[source]

Set max_val and min_val to the initialization bounds and reset the base metric.

Return type:

None

update(*args, **kwargs)[source]

Update the underlying metric.

Return type:

None

Multi-output Wrapper

Module Interface

class torchmetrics.wrappers.MultioutputWrapper(base_metric, num_outputs, output_dim=-1, remove_nans=True, squeeze_outputs=True)[source]

Wrap a base metric to enable it to support multiple outputs.

Several torchmetrics metrics, such as SpearmanCorrCoef, lack support for multioutput mode. This class wraps such metrics to support computing one metric per output. Unlike specific torchmetric metrics, it doesn’t support any aggregation across outputs. This means if you set num_outputs to 2, .compute() will return a Tensor of dimension (2, ...) where ... represents the dimensions the metric returns when not wrapped.

In addition to enabling multioutput support for metrics that lack it, this class also supports, albeit in a crude fashion, dealing with missing labels (or other data). When remove_nans is passed, the class will remove the intersection of NaN containing “rows” upon each update for each output. For example, suppose a user uses MultioutputWrapper to wrap torchmetrics.regression.R2Score with 2 outputs, one of which occasionally has missing labels. On every update, the wrapper will then drop, for each output, any row that contains a NaN in either the predictions or the labels before updating the corresponding underlying metric (see the sketch after the example below).

Parameters:
  • base_metric (Metric) – Metric being wrapped.

  • num_outputs (int) – Expected dimensionality of the output dimension. This parameter is used to determine the number of distinct metrics we need to track.

  • output_dim (int) – Dimension on which output is expected. Note that while this provides some flexibility, the output dimension must be the same for all inputs to update. This applies even for metrics such as Accuracy where the labels can have a different number of dimensions than the predictions. This can be worked around if the output dimension can be set to -1 for both, even if -1 corresponds to different dimensions in different inputs.

  • remove_nans (bool) – Whether to remove the intersection of rows containing NaNs from the values passed through to each underlying metric. Proper operation requires all tensors passed to update to have dimension (N, ...) where N represents the length of the batch or dataset being passed in.

  • squeeze_outputs (bool) – If True, will squeeze the 1-item dimensions left after index_select is applied. This is unnecessary but harmless for metrics such as R2Score, and useful for certain classification metrics that can’t handle the additional 1-item dimensions.

Example

>>> # Mimic R2Score in `multioutput`, `raw_values` mode:
>>> import torch
>>> from torchmetrics.wrappers import MultioutputWrapper
>>> from torchmetrics.regression import R2Score
>>> target = torch.tensor([[0.5, 1], [-1, 1], [7, -6]])
>>> preds = torch.tensor([[0, 2], [-1, 2], [8, -5]])
>>> r2score = MultioutputWrapper(R2Score(), 2)
>>> r2score(preds, target)
tensor([0.9654, 0.9082])
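The snippet below is a minimal, illustrative sketch of the per-output NaN handling described above; the tensors and the single NaN entry are invented for demonstration, and the exact scores are not shown since they simply follow from these inputs.

>>> # Illustrative sketch: per-output NaN handling with remove_nans=True
>>> import torch
>>> from torchmetrics.wrappers import MultioutputWrapper
>>> from torchmetrics.regression import R2Score
>>> target = torch.tensor([[0.5, float("nan")], [-1.0, 1.0], [7.0, -6.0]])
>>> preds = torch.tensor([[0.0, 2.0], [-1.0, 2.0], [8.0, -5.0]])
>>> r2score = MultioutputWrapper(R2Score(), 2, remove_nans=True)
>>> r2score.update(preds, target)  # the NaN row is dropped only for the second output
>>> scores = r2score.compute()     # one R2 score per output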
compute()[source]

Compute metrics.

Return type:

Tensor

forward(*args, **kwargs)[source]

Call underlying forward methods and aggregate the results if they’re non-null.

We override this method to ensure that state variables get copied over on the underlying metrics.

Return type:

Any

plot(val=None, ax=None)[source]

Plot a single or multiple values from the metric.

Parameters:
  • val (Union[Tensor, Sequence[Tensor], None]) – Either a single result from calling metric.forward or metric.compute or a list of these results. If no value is provided, will automatically call metric.compute and plot that result.

  • ax (Optional[Axes]) – A matplotlib axis object. If provided, the plot will be added to that axis

Return type:

Tuple[Figure, Union[Axes, ndarray]]

Returns:

Figure and Axes object

Raises:

ModuleNotFoundError – If matplotlib is not installed

>>> # Example plotting a single value
>>> import torch
>>> from torchmetrics.wrappers import MultioutputWrapper
>>> from torchmetrics.regression import R2Score
>>> metric = MultioutputWrapper(R2Score(), 2)
>>> metric.update(torch.randn(20, 2), torch.randn(20, 2))
>>> fig_, ax_ = metric.plot()
_images/multi_output_wrapper-1.png
>>> # Example plotting multiple values
>>> import torch
>>> from torchmetrics.wrappers import MultioutputWrapper
>>> from torchmetrics.regression import R2Score
>>> metric = MultioutputWrapper(R2Score(), 2)
>>> values = [ ]
>>> for _ in range(3):
...     values.append(metric(torch.randn(20, 2), torch.randn(20, 2)))
>>> fig_, ax_ = metric.plot(values)
_images/multi_output_wrapper-2.png
reset()[source]

Reset all underlying metrics.

Return type:

None

update(*args, **kwargs)[source]

Update each underlying metric with the corresponding output.

Return type:

None

Multi-task Wrapper

Module Interface

class torchmetrics.wrappers.MultitaskWrapper(task_metrics)[source]

Wrapper class for computing different metrics on different tasks in the context of multitask learning.

In multitask learning the different tasks require different metrics to be evaluated. This wrapper allows for easy evaluation in such cases by supporting multiple predictions and targets through a dictionary. Note that only metrics where the signature of update follows the standard preds, target are supported.

Parameters:

task_metrics (Dict[str, Union[Metric, MetricCollection]]) – Dictionary associating each task to a Metric or a MetricCollection. The keys of the dictionary represent the names of the tasks, and the values represent the metrics to use for each task.

Raises:
  • TypeError – If argument task_metrics is not a dictionary

  • TypeError – If not all values in the task_metrics dictionary are instances of Metric or MetricCollection

Example (with a single metric per task):
>>> import torch
>>> from torchmetrics.wrappers import MultitaskWrapper
>>> from torchmetrics.regression import MeanSquaredError
>>> from torchmetrics.classification import BinaryAccuracy
>>>
>>> classification_target = torch.tensor([0, 1, 0])
>>> regression_target = torch.tensor([2.5, 5.0, 4.0])
>>> targets = {"Classification": classification_target, "Regression": regression_target}
>>>
>>> classification_preds = torch.tensor([0, 0, 1])
>>> regression_preds = torch.tensor([3.0, 5.0, 2.5])
>>> preds = {"Classification": classification_preds, "Regression": regression_preds}
>>>
>>> metrics = MultitaskWrapper({
...     "Classification": BinaryAccuracy(),
...     "Regression": MeanSquaredError()
... })
>>> metrics.update(preds, targets)
>>> metrics.compute()
{'Classification': tensor(0.3333), 'Regression': tensor(0.8333)}
Example (with several metrics per task):
>>> import torch
>>> from torchmetrics import MetricCollection
>>> from torchmetrics.wrappers import MultitaskWrapper
>>> from torchmetrics.regression import MeanSquaredError, MeanAbsoluteError
>>> from torchmetrics.classification import BinaryAccuracy, BinaryF1Score
>>>
>>> classification_target = torch.tensor([0, 1, 0])
>>> regression_target = torch.tensor([2.5, 5.0, 4.0])
>>> targets = {"Classification": classification_target, "Regression": regression_target}
>>>
>>> classification_preds = torch.tensor([0, 0, 1])
>>> regression_preds = torch.tensor([3.0, 5.0, 2.5])
>>> preds = {"Classification": classification_preds, "Regression": regression_preds}
>>>
>>> metrics = MultitaskWrapper({
...     "Classification": MetricCollection(BinaryAccuracy(), BinaryF1Score()),
...     "Regression": MetricCollection(MeanSquaredError(), MeanAbsoluteError())
... })
>>> metrics.update(preds, targets)
>>> metrics.compute()
{'Classification': {'BinaryAccuracy': tensor(0.3333), 'BinaryF1Score': tensor(0.)},
 'Regression': {'MeanSquaredError': tensor(0.8333), 'MeanAbsoluteError': tensor(0.6667)}}
compute()[source]

Compute metrics for all tasks.

Return type:

Dict[str, Any]

forward(task_preds, task_targets)[source]

Call underlying forward methods for all tasks and return the result as a dictionary.

Return type:

Dict[str, Any]

plot(val=None, axes=None)[source]

Plot a single or multiple values from the metric.

All tasks’ results are plotted on individual axes.

Parameters:
  • val (Union[Dict, Sequence[Dict], None]) – Either a single result from calling metric.forward or metric.compute or a list of these results. If no value is provided, will automatically call metric.compute and plot that result.

  • axes (Optional[Sequence[Axes]]) – Sequence of matplotlib axis objects. If provided, will add the plots to the provided axis objects. If not provided, will create them.

Return type:

Sequence[Tuple[Figure, Union[Axes, ndarray]]]

Returns:

Sequence of tuples with Figure and Axes object for each task.

>>> # Example plotting a single value
>>> import torch
>>> from torchmetrics.wrappers import MultitaskWrapper
>>> from torchmetrics.regression import MeanSquaredError
>>> from torchmetrics.classification import BinaryAccuracy
>>>
>>> classification_target = torch.tensor([0, 1, 0])
>>> regression_target = torch.tensor([2.5, 5.0, 4.0])
>>> targets = {"Classification": classification_target, "Regression": regression_target}
>>>
>>> classification_preds = torch.tensor([0, 0, 1])
>>> regression_preds = torch.tensor([3.0, 5.0, 2.5])
>>> preds = {"Classification": classification_preds, "Regression": regression_preds}
>>>
>>> metrics = MultitaskWrapper({
...     "Classification": BinaryAccuracy(),
...     "Regression": MeanSquaredError()
... })
>>> metrics.update(preds, targets)
>>> value = metrics.compute()
>>> fig_, ax_ = metrics.plot(value)
_images/multi_task_wrapper-1_00.png
_images/multi_task_wrapper-1_01.png
>>> # Example plotting multiple values
>>> import torch
>>> from torchmetrics.wrappers import MultitaskWrapper
>>> from torchmetrics.regression import MeanSquaredError
>>> from torchmetrics.classification import BinaryAccuracy
>>>
>>> classification_target = torch.tensor([0, 1, 0])
>>> regression_target = torch.tensor([2.5, 5.0, 4.0])
>>> targets = {"Classification": classification_target, "Regression": regression_target}
>>>
>>> classification_preds = torch.tensor([0, 0, 1])
>>> regression_preds = torch.tensor([3.0, 5.0, 2.5])
>>> preds = {"Classification": classification_preds, "Regression": regression_preds}
>>>
>>> metrics = MultitaskWrapper({
...     "Classification": BinaryAccuracy(),
...     "Regression": MeanSquaredError()
... })
>>> values = []
>>> for _ in range(10):
...     values.append(metrics(preds, targets))
>>> fig_, ax_ = metrics.plot(values)
_images/multi_task_wrapper-2_00.png
_images/multi_task_wrapper-2_01.png
reset()[source]

Reset all underlying metrics.

Return type:

None

update(task_preds, task_targets)[source]

Update each task’s metric with its corresponding pred and target.

Parameters:
  • task_preds (Dict[str, Tensor]) – Dictionary associating each task to a Tensor of pred.

  • task_targets (Dict[str, Tensor]) – Dictionary associating each task to a Tensor of target.

Return type:

None

Running

Module Interface

class torchmetrics.wrappers.Running(base_metric, window=5)[source]

Running wrapper for metrics.

Using this wrapper allows for calculating metrics over a running window of values, instead of the whole history of values. This is beneficial when you want to get a better estimate of the metric during training and don’t want to wait for the whole training to finish to get epoch level estimates.

The running window is defined by the window argument. The window has a fixed size, and this wrapper will store a duplicate of the underlying metric state for each value in the window. Thus, memory usage will increase linearly with the window size. Use accordingly. Also note that this wrapper only works with metrics that have full_state_update set to False.

Importantly, the wrapper does not alter the value of the forward method of the underlying metric. Thus, forward will still return the value on the current batch. To get the running value call compute instead.

Parameters:
  • base_metric (Metric) – The metric to wrap.

  • window (int) – The size of the running window.

Example

# Single metric
>>> from torch import tensor
>>> from torchmetrics.wrappers import Running
>>> from torchmetrics.aggregation import SumMetric
>>> metric = Running(SumMetric(), window=3)
>>> for i in range(6):
...     current_val = metric(tensor([i]))
...     running_val = metric.compute()
...     total_val = tensor(sum(list(range(i+1))))  # value we would get from compute without running
...     print(f"{current_val=}, {running_val=}, {total_val=}")
current_val=tensor(0.), running_val=tensor(0.), total_val=tensor(0)
current_val=tensor(1.), running_val=tensor(1.), total_val=tensor(1)
current_val=tensor(2.), running_val=tensor(3.), total_val=tensor(3)
current_val=tensor(3.), running_val=tensor(6.), total_val=tensor(6)
current_val=tensor(4.), running_val=tensor(9.), total_val=tensor(10)
current_val=tensor(5.), running_val=tensor(12.), total_val=tensor(15)

Example

# Metric collection
>>> from torch import tensor
>>> from torchmetrics.wrappers import Running
>>> from torchmetrics import MetricCollection
>>> from torchmetrics.aggregation import SumMetric, MeanMetric
>>> # note that running is input to collection, not the other way
>>> metric = MetricCollection({"sum": Running(SumMetric(), 3), "mean": Running(MeanMetric(), 3)})
>>> for i in range(6):
...     current_val = metric(tensor([i]))
...     running_val = metric.compute()
...     print(f"{current_val=}, {running_val=}")
current_val={'mean': tensor(0.), 'sum': tensor(0.)}, running_val={'mean': tensor(0.), 'sum': tensor(0.)}
current_val={'mean': tensor(1.), 'sum': tensor(1.)}, running_val={'mean': tensor(0.5000), 'sum': tensor(1.)}
current_val={'mean': tensor(2.), 'sum': tensor(2.)}, running_val={'mean': tensor(1.), 'sum': tensor(3.)}
current_val={'mean': tensor(3.), 'sum': tensor(3.)}, running_val={'mean': tensor(2.), 'sum': tensor(6.)}
current_val={'mean': tensor(4.), 'sum': tensor(4.)}, running_val={'mean': tensor(3.), 'sum': tensor(9.)}
current_val={'mean': tensor(5.), 'sum': tensor(5.)}, running_val={'mean': tensor(4.), 'sum': tensor(12.)}

forward(*args, **kwargs)[source]

Forward input to the underlying metric and save state afterwards.

Return type:

Any

plot(val=None, ax=None)[source]

Plot a single or multiple values from the metric.

Parameters:
  • val (Union[Tensor, Sequence[Tensor], None]) – Either a single result from calling metric.forward or metric.compute or a list of these results. If no value is provided, will automatically call metric.compute and plot that result.

  • ax (Optional[Axes]) – A matplotlib axis object. If provided, the plot will be added to that axis

Return type:

Tuple[Figure, Union[Axes, ndarray]]

Returns:

Figure and Axes object

Raises:

ModuleNotFoundError – If matplotlib is not installed

>>> # Example plotting a single value
>>> import torch
>>> from torchmetrics.wrappers import Running
>>> from torchmetrics.aggregation import SumMetric
>>> metric = Running(SumMetric(), 2)
>>> metric.update(torch.randn(20, 2))
>>> fig_, ax_ = metric.plot()
_images/running-1.png
>>> # Example plotting multiple values
>>> import torch
>>> from torchmetrics.wrappers import Running
>>> from torchmetrics.aggregation import SumMetric
>>> metric = Running(SumMetric(), 2)
>>> values = [ ]
>>> for _ in range(3):
...     values.append(metric(torch.randn(20, 2)))
>>> fig_, ax_ = metric.plot(values)
_images/running-2.png
reset()[source]

Reset metric.

Return type:

None

torchmetrics.Metric

The base Metric class is an abstract base class that is used as the building block for all other Module metrics.

class torchmetrics.Metric(**kwargs)[source]

Base class for all metrics present in the Metrics API.

This class is inherited by all metrics and implements the following functionality:

  1. Handles the transfer of metric states to the correct device.

  2. Handles the synchronization of metric states across processes.

The three core methods of the base class are:

  • add_state()

  • forward()

  • reset()

These should almost never be overwritten by child classes. Instead, the following methods should be overwritten:

  • update()

  • compute()

Parameters:

kwargs (Any) –

additional keyword arguments, see Advanced metric settings for more info.

  • compute_on_cpu: If metric state should be stored on CPU during computations. Only works for list states.

  • dist_sync_on_step: If metric state should synchronize on forward(). Default is False

  • process_group: The process group on which the synchronization is called. Default is the world.

  • dist_sync_fn: Function that performs the allgather operation on the metric state. Default is a custom implementation that calls torch.distributed.all_gather internally.

  • distributed_available_fn: Function that checks if the distributed backend is available. Defaults to a check of torch.distributed.is_available() and torch.distributed.is_initialized().

  • sync_on_compute: If metric state should synchronize when compute is called. Default is True

  • compute_with_cache: If results from compute should be cached. Default is False
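For illustration, these settings are simply forwarded through the constructor of any metric subclass; the sketch below uses BinaryAccuracy as an arbitrary example.

>>> from torchmetrics.classification import BinaryAccuracy
>>> # sketch: tweak base-class behaviour via keyword arguments
>>> metric = BinaryAccuracy(sync_on_compute=False, compute_with_cache=True)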

add_state(name, default, dist_reduce_fx=None, persistent=False)[source]

Add metric state variable. Only used by subclasses.

Metric state variables are either a Tensor or an empty list, which can be appended to by the metric. Each state variable must have a unique name associated with it. State variables are accessible as attributes of the metric, i.e., if name is "my_state" then its value can be accessed from an instance as metric.my_state. Metric states behave like buffers and parameters of a Module in that they are also updated when .to() is called. Unlike parameters and buffers, metric states are not by default saved in the module's state_dict.

Parameters:
  • name (str) – The name of the state variable. The variable will then be accessible at self.name.

  • default (Union[list, Tensor]) – Default value of the state; can either be a Tensor or an empty list. The state will be reset to this value when self.reset() is called.

  • dist_reduce_fx (Optional) – Function to reduce state across multiple processes in distributed mode. If value is "sum", "mean", "cat", "min" or "max" we will use torch.sum, torch.mean, torch.cat, torch.min and torch.max respectively, each with argument dim=0. Note that the "cat" reduction only makes sense if the state is a list, and not a tensor. The user can also pass a custom function in this parameter.

  • persistent (Optional) – whether the state will be saved as part of the modules state_dict. Default is False.

Return type:

None

Note

Setting dist_reduce_fx to None will return the metric state synchronized across different processes. However, there won’t be any reduction function applied to the synchronized metric state.

The metric states would be synced as follows

  • If the metric state is a Tensor, the synced value will be a Tensor stacked across the process dimension. The original Tensor state retains its dimensions, so the synchronized output will be of shape (num_process, ...).

  • If the metric state is a list, the synced value will be a list containing the combined elements from all processes.

Note

When passing a custom function to dist_reduce_fx, expect the synchronized metric state to follow the format discussed in the above note.

Raises:
  • ValueError – If default is not a tensor or an empty list.

  • ValueError – If dist_reduce_fx is not callable or one of "mean", "sum", "cat", "min", "max" or None.
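For context, a minimal sketch of a custom metric built on add_state(), update() and compute() might look as follows; the class name and the accuracy logic are illustrative and not part of the library.

>>> import torch
>>> from torchmetrics import Metric
>>>
>>> class MyAccuracy(Metric):
...     def __init__(self, **kwargs):
...         super().__init__(**kwargs)
...         # states are reset by reset() and reduced across processes with dist_reduce_fx
...         self.add_state("correct", default=torch.tensor(0), dist_reduce_fx="sum")
...         self.add_state("total", default=torch.tensor(0), dist_reduce_fx="sum")
...     def update(self, preds, target):
...         self.correct += (preds == target).sum()
...         self.total += target.numel()
...     def compute(self):
...         return self.correct.float() / self.total
>>>
>>> metric = MyAccuracy()
>>> metric.update(torch.tensor([0, 1, 1]), torch.tensor([0, 1, 0]))
>>> metric.compute()
tensor(0.6667)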

clone()[source]

Make a copy of the metric.

Return type:

Metric

abstract compute()[source]

Override this method to compute the final metric value.

This method will automatically synchronize state variables when running in distributed backend.

Return type:

Any

double()[source]

Override default and prevent dtype casting.

Please use Metric.set_dtype() instead.

Return type:

Metric

float()[source]

Override default and prevent dtype casting.

Please use Metric.set_dtype() instead.

Return type:

Metric

forward(*args, **kwargs)[source]

Aggregate and evaluate batch input directly.

Serves the dual purpose of computing the metric on the current batch of inputs and adding the batch statistics to the overall accumulating metric state. Input arguments are the same as for the corresponding update method. The returned output is the same as the output of compute.

Parameters:
  • args (Any) – Any arguments as required by the metric update method.

  • kwargs (Any) – Any keyword arguments as required by the metric update method.

Return type:

Any

Returns:

The output of the compute method evaluated on the current batch.

Raises:

TorchMetricsUserError – If the metric is already synced and forward is called again.
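As an illustrative sketch (BinaryAccuracy is used only as an example metric), forward returns the value for the current batch, while a subsequent compute returns the value accumulated over all batches seen so far:

>>> import torch
>>> from torchmetrics.classification import BinaryAccuracy
>>> metric = BinaryAccuracy()
>>> batch_acc = metric(torch.tensor([1, 0, 1]), torch.tensor([1, 1, 1]))  # metric evaluated on this batch only
>>> epoch_acc = metric.compute()  # metric accumulated over all batches processed so far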

half()[source]

Override default and prevent dtype casting.

Please use Metric.set_dtype() instead.

Return type:

Metric

persistent(mode=False)[source]

Change, after initialization, whether metric states should be saved in the module's state_dict.

Return type:

None

plot(*_, **__)[source]

Override this method to plot the metric value.

Return type:

Any

reset()[source]

Reset metric state variables to their default value.

Return type:

None

set_dtype(dst_type)[source]

Transfer all metric states to a specific dtype. Special version of the standard type method.

Parameters:

dst_type (Union[str, dtype]) – the desired type as string or dtype object

Return type:

Metric
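A short, illustrative usage sketch (any metric could stand in for BinaryAccuracy):

>>> import torch
>>> from torchmetrics.classification import BinaryAccuracy
>>> metric = BinaryAccuracy().set_dtype(torch.float64)  # all metric states are now kept in double precision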

state_dict(destination=None, prefix='', keep_vars=False)[source]

Get the current state of the metric as a dictionary.

Parameters:
  • destination (Optional[Dict[str, Any]]) – Optional dictionary; if provided, the state of the module will be updated into the dict and the same object is returned. Otherwise, an OrderedDict will be created and returned.

  • prefix (str) – optional string, a prefix added to parameter and buffer names to compose the keys in state_dict.

  • keep_vars (bool) – by default the Tensors returned in the state dict are detached from autograd. If set to True, detaching will not be performed.

Return type:

Dict[str, Any]

sync(dist_sync_fn=None, process_group=None, should_sync=True, distributed_available=None)[source]

Sync function for manually controlling when metrics states should be synced across processes.

Parameters:
  • dist_sync_fn (Optional[Callable]) – Function to be used to perform states synchronization

  • process_group (Optional[Any]) – Specify the process group on which synchronization is called. default: None (which selects the entire world)

  • should_sync (bool) – Whether to apply state synchronization. This only has an impact when running in a distributed setting.

  • distributed_available (Optional[Callable]) – Function to determine if we are running inside a distributed setting

Raises:

TorchMetricsUserError – If the metric is already synced and sync is called again.

Return type:

None

sync_context(dist_sync_fn=None, process_group=None, should_sync=True, should_unsync=True, distributed_available=None)[source]

Context manager to synchronize states.

This context manager is used in distributed settings and makes sure that the local cache of states is restored after yielding the synchronized state.

Parameters:
  • dist_sync_fn (Optional[Callable]) – Function to be used to perform states synchronization

  • process_group (Optional[Any]) – Specify the process group on which synchronization is called. default: None (which selects the entire world)

  • should_sync (bool) – Whether to apply state synchronization. This only has an impact when running in a distributed setting.

  • should_unsync (bool) – Whether to restore the cache state so that the metrics can continue to be accumulated.

  • distributed_available (Optional[Callable]) – Function to determine if we are running inside a distributed setting

Return type:

Generator
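The sketch below illustrates the intended usage; in a single-process run the synchronization is effectively a no-op, so it is safe to run outside a distributed setting. SumMetric is used only as a convenient example.

>>> import torch
>>> from torchmetrics.aggregation import SumMetric
>>> metric = SumMetric()
>>> metric.update(torch.tensor(2.0))
>>> with metric.sync_context(should_sync=True):
...     synced_state = metric.metric_state  # states as gathered across processes
>>> # after the context exits, the local (unsynced) states are restored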

type(dst_type)[source]

Override default and prevent dtype casting.

Please use Metric.set_dtype() instead.

Return type:

Metric

unsync(should_unsync=True)[source]

Unsync function for manually controlling when metrics states should be reverted back to their local states.

Parameters:

should_unsync (bool) – Whether to perform unsync

Return type:

None

abstract update(*_, **__)[source]

Override this method to update the state variables of your metric class.

Return type:

None

property device: device

Return the device of the metric.

property metric_state: Dict[str, Union[List[Tensor], Tensor]]

Get the current state of the metric.

property update_called: bool

Return True if update or forward has been called since initialization or the last reset.

property update_count: int

Get the number of times update and/or forward has been called since initialization or last reset.
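A brief, illustrative look at these properties (SumMetric is used only as an example):

>>> import torch
>>> from torchmetrics.aggregation import SumMetric
>>> metric = SumMetric()
>>> metric.update(torch.tensor(1.0))
>>> metric.update(torch.tensor(2.0))
>>> metric.update_called
True
>>> metric.update_count
2
>>> state = metric.metric_state  # dict mapping state names to their current tensors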

torchmetrics.utilities.data

select_topk

torchmetrics.utilities.data.select_topk(prob_tensor, topk=1, dim=1)[source]

Convert a probability tensor to binary by selecting the top-k highest entries.

Parameters:
  • prob_tensor (Tensor) – dense tensor of shape [..., C, ...], where C is in the position defined by the dim argument

  • topk (int) – number of the highest entries to turn into 1s

  • dim (int) – dimension on which to compare entries

Return type:

Tensor

Returns:

A binary tensor of the same shape as the input tensor of type torch.int32

Example

>>> import torch
>>> from torchmetrics.utilities.data import select_topk
>>> x = torch.tensor([[1.1, 2.0, 3.0], [2.0, 1.0, 0.5]])
>>> select_topk(x, topk=2)
tensor([[0, 1, 1],
        [1, 1, 0]], dtype=torch.int32)

to_categorical

torchmetrics.utilities.data.to_categorical(x, argmax_dim=1)[source]

Convert a tensor of probabilities to a dense label tensor.

Parameters:
  • x (Tensor) – probabilities to get the categorical label [N, d1, d2, …]

  • argmax_dim (int) – dimension to apply

Return type:

Tensor

Returns:

A tensor with categorical labels [N, d2, …]

Example

>>> import torch
>>> from torchmetrics.utilities.data import to_categorical
>>> x = torch.tensor([[0.2, 0.5], [0.9, 0.1]])
>>> to_categorical(x)
tensor([1, 0])

to_onehot

torchmetrics.utilities.data.to_onehot(label_tensor, num_classes=None)[source]

Convert a dense label tensor to one-hot format.

Parameters:
  • label_tensor (Tensor) – dense label tensor, with shape [N, d1, d2, …]

  • num_classes (Optional[int]) – number of classes C

Return type:

Tensor

Returns:

A sparse label tensor with shape [N, C, d1, d2, …]

Example

>>> import torch
>>> from torchmetrics.utilities.data import to_onehot
>>> x = torch.tensor([1, 2, 3])
>>> to_onehot(x)
tensor([[0, 1, 0, 0],
        [0, 0, 1, 0],
        [0, 0, 0, 1]])

torchmetrics.utilities.exceptions

TorchMetricsUserError

class torchmetrics.utilities.exceptions.TorchMetricsUserError[source]

Error used to inform users of a wrong combination of Metric API calls.

TorchMetricsUserWarning

class torchmetrics.utilities.exceptions.TorchMetricsUserWarning[source]

Warning used to inform users of specific warnings due to the torchmetrics API.

TorchMetrics Governance

This document describes governance processes we follow in developing TorchMetrics.

Persons of Interest

Leads
Core Maintainers
Alumni

Releases

We release a new minor version (e.g., 0.5.0) every few months and bugfix releases if needed. The minor versions contain new features, API changes, deprecations, removals, potential backward-incompatible changes and also all previous bugfixes included in any bugfix release. With every release, we publish a changelog where we list additions, removals, changed functionality and fixes.

Project Management and Decision Making

The decision about what goes into a release is governed by the staff contributors and leaders of TorchMetrics development. Whenever possible, discussion happens publicly on GitHub and includes the whole community. When a consensus is reached, staff and core contributors assign milestones and labels to the issue and/or pull request and start tracking the development. It is possible that priorities change over time.

Commits to the project are exclusively to be added by pull requests on GitHub and anyone in the community is welcome to review them. However, reviews submitted by code owners have higher weight and it is necessary to get the approval of code owners before a pull request can be merged. Additional requirements may apply case by case.

API Evolution

TorchMetrics development is driven by research and best practices in the rapidly developing field of AI and machine learning. Change is inevitable, and when it happens, the TorchMetrics team is committed to minimizing user friction and maximizing ease of transition from one version to the next. We take backward compatibility and reproducibility very seriously.

For API removal, renaming or other forms of backward-incompatible changes, the procedure is:

  1. A deprecation process is initiated at version X, producing warning messages at runtime and in the documentation.

  2. Calls to the deprecated API remain unchanged in their function during the deprecation phase.

  3. One minor version in the future, at version X+1, the breaking change takes effect.

The “X+1” rule is a recommendation and not a strict requirement. Longer deprecation cycles may apply for some cases.

Contributor Covenant Code of Conduct

Our Pledge

In the interest of fostering an open and welcoming environment, we as contributors and maintainers pledge to making participation in our project and our community a harassment-free experience for everyone, regardless of age, body size, disability, ethnicity, sex characteristics, gender identity and expression, level of experience, education, socio-economic status, nationality, personal appearance, race, religion, or sexual identity and orientation.

Our Standards

Examples of behavior that contributes to creating a positive environment include:

  • Using welcoming and inclusive language

  • Being respectful of differing viewpoints and experiences

  • Gracefully accepting constructive criticism

  • Focusing on what is best for the community

  • Showing empathy towards other community members

Examples of unacceptable behavior by participants include:

  • The use of sexualized language or imagery and unwelcome sexual attention or advances

  • Trolling, insulting/derogatory comments, and personal or political attacks

  • Public or private harassment

  • Publishing others’ private information, such as a physical or electronic address, without explicit permission

  • Other conduct which could reasonably be considered inappropriate in a professional setting

Our Responsibilities

Project maintainers are responsible for clarifying the standards of acceptable behavior and are expected to take appropriate and fair corrective action in response to any instances of unacceptable behavior.

Project maintainers have the right and responsibility to remove, edit, or reject comments, commits, code, wiki edits, issues, and other contributions that are not aligned to this Code of Conduct, or to ban temporarily or permanently any contributor for other behaviors that they deem inappropriate, threatening, offensive, or harmful.

Scope

This Code of Conduct applies both within project spaces and in public spaces when an individual is representing the project or its community. Examples of representing a project or community include using an official project e-mail address, posting via an official social media account, or acting as an appointed representative at an online or offline event. Representation of a project may be further defined and clarified by project maintainers.

Enforcement

Instances of abusive, harassing, or otherwise unacceptable behavior may be reported by contacting the project team at waf2107@columbia.edu. All complaints will be reviewed and investigated and will result in a response that is deemed necessary and appropriate to the circumstances. The project team is obligated to maintain confidentiality with regard to the reporter of an incident. Further details of specific enforcement policies may be posted separately.

Project maintainers who do not follow or enforce the Code of Conduct in good faith may face temporary or permanent repercussions as determined by other members of the project’s leadership.

Attribution

This Code of Conduct is adapted from the Contributor Covenant, version 1.4, available at https://www.contributor-covenant.org/version/1/4/code-of-conduct.html

For answers to common questions about this code of conduct, see https://www.contributor-covenant.org/faq

Contributing

Welcome to the Torchmetrics community! We’re building the largest collection of native PyTorch metrics, with the goal of reducing boilerplate and increasing reproducibility.

Contribution Types

We are always looking for help implementing new features or fixing bugs.

Bug Fixes:
  1. If you find a bug please submit a github issue.

    • Make sure the title explains the issue.

    • Describe your setup, what you are trying to do, expected vs. actual behaviour. Please add configs and code samples.

    • Add details on how to reproduce the issue - a minimal test case is always best, and Colab is also great. Note that the sample code should be minimal and, if needed, use publicly available data.

  2. Try to fix it or recommend a solution. We highly recommend using a test-driven approach:

    • Convert your minimal code example to a unit/integration test with assert on expected results.

    • Start by debugging the issue… You can run just this particular test in your IDE and draft a fix.

    • Verify that your test case fails on the master branch and only passes with the fix applied.

  3. Submit a PR!

Note, even if you do not find the solution, sending a PR with a test covering the issue is a valid contribution and we can help you or finish it with you :]

New Features:
  1. Submit a GitHub issue - describe the motivation for such a feature (adding a use case or an example is helpful).

  2. Let’s discuss to determine the feature scope.

  3. Submit a PR! We recommend a test-driven approach to adding new features as well:

    • Write a test for the functionality you want to add.

    • Write the functional code until the test passes.

  4. Add/update the relevant tests!

  • This PR is a good example for adding a new metric

Test cases:

Want to keep Torchmetrics healthy? Love seeing those green tests? So do we! How do we keep it that way? We write tests! We value test contributions even more than new features. One of the core values of torchmetrics is that our users can trust our metric implementations. We can only guarantee this if our metrics are well tested.


Guidelines

Development scripts

The following development commands can be executed from the project root (Unix only):

  • make clean cleans repo from temp/generated files

  • make docs builds documentation under docs/build/html

  • make test runs all project’s tests with coverage

Original code

All added or edited code shall be the original work of the particular contributor. If you use some third-party implementation, all such blocks/functions/modules shall be properly referenced and, if possible, also agreed to by the code’s author. For example - This code is inspired from http://.... In case you are adding new dependencies, make sure that they are compatible with the actual Torchmetrics license (i.e. dependencies should be at least as permissive as the Torchmetrics license).

Coding Style
  1. Use f-strings for output formatting (except for logging, where we stay with lazy logging.info("Hello %s!", name)); see the short illustration after this list.

  2. You can use pre-commit to make sure your code style is correct.
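For illustration, a minimal sketch of the two styles side by side (names and messages are arbitrary):

import logging

name = "developer"
print(f"Hello {name}!")          # f-strings for general output formatting
logging.info("Hello %s!", name)  # lazy %-formatting is kept for logging calls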

Documentation

We are using Sphinx with the Napoleon extension. Moreover, we follow the Google docstring style together with the type convention.

See the following short example of a sample function taking one positional argument and one optional argument:

from typing import Optional


def my_func(param_a: int, param_b: Optional[float] = None) -> str:
    """Sample function.

    Args:
        param_a: first parameter
        param_b: second parameter

    Return:
        sum of both numbers

    Example:
        Sample doctest example...
        >>> my_func(1, 2)
        '3'

    .. note:: If you want to add something.
    """
    p = param_b if param_b else 0
    return str(param_a + p)

When updating the docs, make sure to build them locally first and visually inspect the HTML files in a browser for formatting errors. In certain cases, a missing blank line or a wrong indent can lead to a broken layout. Run this command

make docs

and open docs/build/html/index.html in your browser.

Notes:

  • You need to have LaTeX installed for rendering math equations. You can for example install TeXLive by doing one of the following:

    • on Ubuntu (Linux) run apt-get install texlive or otherwise follow the instructions on the TeXLive website

    • use the RTD docker image

  • because of the class meta used with PL, you need to use Python 3.7 or higher

When you send a PR the continuous integration will run tests and build the docs.

Testing

Local: Testing your work locally will help you speed up the process since it allows you to focus on particular (failing) test cases. To set up a local development environment, install both local and test dependencies:

python -m pip install -r requirements/test.txt
python -m pip install pre-commit

You can run the full test-case in your terminal via this make script:

make test
# or natively
python -m pytest torchmetrics tests

Note: if your computer does not have a multi-GPU setup or a TPU, these tests are skipped.

GitHub Actions: For convenience, you can also rely on your own GitHub Actions builds, which will be triggered with each commit. This is useful if you do not test against all required dependency versions locally.

Changelog

All notable changes to this project will be documented in this file.

The format is based on Keep a Changelog, and this project adheres to Semantic Versioning.

Note: we move fast, but we still preserve backward compatibility one feature release (0.1 version) back.

[1.1.2] - 2023-09-11

[1.1.2] - Fixed
  • Fixed tie breaking in ndcg metric (#2031)

  • Fixed bug in BootStrapper when very few samples were evaluated that could lead to crash (#2052)

  • Fixed bug when creating multiple plots that lead to not all plots being shown (#2060)

  • Fixed performance issues in RecallAtFixedPrecision for large batch sizes (#2042)

  • Fixed bug related to MetricCollection used with custom metrics that have prefix/postfix attributes (#2070)

[1.1.1] - 2023-08-29

[1.1.1] - Added
  • Added average argument to MeanAveragePrecision (#2018)

[1.1.1] - Fixed
  • Fixed bug in PearsonCorrCoef when updated on single samples at a time (#2019)

  • Fixed support for pixelwise MSE (#2017)

  • Fixed bug in MetricCollection when used with multiple metrics that return dicts with same keys (#2027)

  • Fixed bug in detection intersection metrics when class_metrics=True resulting in wrong values (#1924)

  • Fixed missing attributes higher_is_better, is_differentiable for some metrics (#2028)

[1.1.0] - 2023-08-22

[1.1.0] - Added
  • Added source aggregated signal-to-distortion ratio (SA-SDR) metric (#1882)

  • Added VisualInformationFidelity to image package (#1830)

  • Added EditDistance to text package (#1906)

  • Added top_k argument to RetrievalMRR in retrieval package (#1961)

  • Added support for evaluating "segm" and "bbox" detection in MeanAveragePrecision at the same time (#1928)

  • Added PerceptualPathLength to image package (#1939)

  • Added support for multioutput evaluation in MeanSquaredError (#1937)

  • Added argument extended_summary to MeanAveragePrecision such that precision, recall, iou can be easily returned (#1983)

  • Added warning to ClipScore if long captions are detected and truncated (#2001)

  • Added CLIPImageQualityAssessment to multimodal package (#1931)

  • Added new property metric_state to all metrics for users to investigate currently stored tensors in memory (#2006)

[1.0.3] - 2023-08-08

[1.0.3] - Added
  • Added warning to MeanAveragePrecision if too many detections are observed (#1978)

[1.0.3] - Fixed
  • Fixed support for int input when multidim_average="samplewise" in classification metrics (#1977)

  • Fixed x/y labels when plotting confusion matrices (#1976)

  • Fixed IOU compute in cuda (#1982)

[1.0.2] - 2023-08-02

[1.0.2] - Added
  • Added warning to PearsonCorrCoef if input has a very small variance for its given dtype (#1926)

[1.0.2] - Changed
  • Changed all non-task specific classification metrics to be true subtypes of Metric (#1963)

[1.0.2] - Fixed
  • Fixed bug in CalibrationError where calculations for double precision input were performed in float precision (#1919)

  • Fixed bug related to the prefix/postfix arguments in MetricCollection and ClasswiseWrapper being duplicated (#1918)

  • Fixed missing AUC score when plotting classification metrics that support the score argument (#1948)

[1.0.1] - 2023-07-13

[1.0.1] - Fixed
  • Fixes corner case when using MetricCollection together with aggregation metrics (#1896)

  • Fixed the use of max_fpr in AUROC metric when only one class is present (#1895)

  • Fixed bug related to empty predictions for IntersectionOverUnion metric (#1892)

  • Fixed bug related to MeanMetric and broadcasting of weights when Nans are present (#1898)

  • Fixed bug related to expected input format of pycoco in MeanAveragePrecision (#1913)

[1.0.0] - 2023-07-04

[1.0.0] - Added
  • Added prefix and postfix arguments to ClasswiseWrapper (#1866)

  • Added speech-to-reverberation modulation energy ratio (SRMR) metric (#1792, #1872)

  • Added new global arg compute_with_cache to control caching behaviour after compute method (#1754)

  • Added ComplexScaleInvariantSignalNoiseRatio for audio package (#1785)

  • Added Running wrapper for calculating running statistics (#1752)

  • Added RelativeAverageSpectralError and RootMeanSquaredErrorUsingSlidingWindow to image package (#816)

  • Added support for SpecificityAtSensitivity Metric (#1432)

  • Added support for plotting of metrics through .plot() method ( #1328, #1481, #1480, #1490, #1581, #1585, #1593, #1600, #1605, #1610, #1609, #1621, #1624, #1623, #1638, #1631, #1650, #1639, #1660, #1682, #1786, )

  • Added support for plotting of audio metrics through .plot() method (#1434)

  • Added classes to output from MAP metric (#1419)

  • Added Binary group fairness metrics to classification package (#1404)

  • Added MinkowskiDistance to regression package (#1362)

  • Added pairwise_minkowski_distance to pairwise package (#1362)

  • Added new detection metric PanopticQuality ( #929, #1527, )

  • Added PSNRB metric (#1421)

  • Added ClassificationTask Enum and use in metrics (#1479)

  • Added ignore_index option to exact_match metric (#1540)

  • Add parameter top_k to RetrievalMAP (#1501)

  • Added support for deterministic evaluation on GPU for metrics that use the torch.cumsum operator (#1499)

  • Added support for plotting of aggregation metrics through .plot() method (#1485)

  • Added support for python 3.11 (#1612)

  • Added support for auto clamping of input for metrics that use the data_range argument (#1606)

  • Added ModifiedPanopticQuality metric to detection package (#1627)

  • Added PrecisionAtFixedRecall metric to classification package (#1683)

  • Added multiple metrics to detection package (#1284)

    • IntersectionOverUnion

    • GeneralizedIntersectionOverUnion

    • CompleteIntersectionOverUnion

    • DistanceIntersectionOverUnion

  • Added MultitaskWrapper to wrapper package (#1762)

  • Added RelativeSquaredError metric to regression package (#1765)

  • Added MemorizationInformedFrechetInceptionDistance metric to image package (#1580)

[1.0.0] - Changed
  • Changed permutation_invariant_training to allow using a 'permutation-wise' metric function (#1794)

  • Changed update_count and update_called from private to public methods (#1370)

  • Raise exception for invalid kwargs in Metric base class (#1427)

  • Extend EnumStr raising ValueError for invalid value (#1479)

  • Improve speed and memory consumption of binned PrecisionRecallCurve with large number of samples (#1493)

  • Changed __iter__ method from raising NotImplementedError to TypeError by setting to None (#1538)

  • FID metric will now raise an error if too few samples are provided (#1655)

  • Allowed FID with torch.float64 (#1628)

  • Changed LPIPS implementation to no more rely on third-party package (#1575)

  • Changed FID matrix square root calculation from scipy to torch (#1708)

  • Changed calculation in PearsonCorrCoef to be more robust in certain cases (#1729)

  • Changed MeanAveragePrecision to pycocotools backend (#1832)

[1.0.0] - Deprecated
[1.0.0] - Removed
  • Support for python 3.7 (#1640)

[1.0.0] - Fixed
  • Fixed support in MetricTracker for MultioutputWrapper and nested structures (#1608)

  • Fixed restrictive check in PearsonCorrCoef (#1649)

  • Fixed integration with jsonargparse and LightningCLI (#1651)

  • Fixed corner case in calibration error for zero confidence input (#1648)

  • Fix precision-recall curve based computations for float target (#1642)

  • Fixed missing kwarg squeeze in MultiOutputWrapper (#1675)

  • Fixed padding removal for 3d input in MSSSIM (#1674)

  • Fixed max_det_threshold in MAP detection (#1712)

  • Fixed states being saved in metrics that use register_buffer (#1728)

  • Fixed states not being correctly synced and device transferred in MeanAveragePrecision for iou_type="segm" (#1763)

  • Fixed use of prefix and postfix in nested MetricCollection (#1773)

  • Fixed ax plotting logging in MetricCollection (#1783)

  • Fixed lookup for punkt sources being downloaded in RougeScore (#1789)

  • Fixed integration with lightning for CompositionalMetric (#1761)

  • Fixed several bugs in SpectralDistortionIndex metric (#1808)

  • Fixed bug for corner cases in MatthewsCorrCoef ( #1812, #1863 )

  • Fixed support for half precision in PearsonCorrCoef (#1819)

  • Fixed number of bugs related to average="macro" in classification metrics (#1821)

  • Fixed off-by-one issue when ignore_index = num_classes + 1 in Multiclass-jaccard (#1860)

[0.11.4] - 2023-03-10

[0.11.4] - Fixed
  • Fixed evaluation of R2Score with near constant target (#1576)

  • Fixed dtype conversion when metric is submodule (#1583)

  • Fixed bug related to top_k>1 and ignore_index!=None in StatScores based metrics (#1589)

  • Fixed corner case for PearsonCorrCoef when running in ddp mode but only on single device (#1587)

  • Fixed overflow error for specific cases in MAP when big areas are calculated (#1607)

[0.11.3] - 2023-02-28

[0.11.3] - Fixed
  • Fixed classification metrics for byte input (#1521)

  • Fixed the use of ignore_index in MulticlassJaccardIndex (#1386)

[0.11.2] - 2023-02-21

[0.11.2] - Fixed
  • Fixed compatibility between XLA in _bincount function (#1471)

  • Fixed type hints in methods belonging to MetricTracker wrapper (#1472)

  • Fixed multilabel in ExactMatch (#1474)

[0.11.1] - 2023-01-30

[0.11.1] - Fixed
  • Fixed type checking on the maximize parameter at the initialization of MetricTracker (#1428)

  • Fixed mixed precision autocast for SSIM metric (#1454)

  • Fixed checking for nltk.punkt in RougeScore if a machine is not online (#1456)

  • Fixed wrongly reset method in MultioutputWrapper (#1460)

  • Fixed dtype checking in PrecisionRecallCurve for target tensor (#1457)

[0.11.0] - 2022-11-30

[0.11.0] - Added
  • Added MulticlassExactMatch to classification metrics (#1343)

  • Added TotalVariation to image package (#978)

  • Added CLIPScore to new multimodal package (#1314)

  • Added regression metrics:

    • KendallRankCorrCoef (#1271)

    • LogCoshError (#1316)

  • Added new nominal metrics:

  • Added option to pass distributed_available_fn to metrics to allow checks for custom communication backend for making dist_sync_fn actually useful (#1301)

  • Added normalize argument to Inception, FID, KID metrics (#1246)

[0.11.0] - Changed
  • Changed minimum Pytorch version to be 1.8 (#1263)

  • Changed interface for all functional and modular classification metrics after refactor (#1252)

[0.11.0] - Removed
  • Removed deprecated BinnedAveragePrecision, BinnedPrecisionRecallCurve, RecallAtFixedPrecision (#1251)

  • Removed deprecated LabelRankingAveragePrecision, LabelRankingLoss and CoverageError (#1251)

  • Removed deprecated KLDivergence and AUC (#1251)

[0.11.0] - Fixed
  • Fixed precision bug in pairwise_euclidean_distance (#1352)

[0.10.3] - 2022-11-16

[0.10.3] - Fixed
  • Fixed bug in Metrictracker.best_metric when return_step=False (#1306)

  • Fixed bug to prevent users from going into an infinite loop if trying to iterate over a single metric (#1320)

[0.10.2] - 2022-10-31

[0.10.2] - Changed
  • Changed in-place operation to out-of-place operation in pairwise_cosine_similarity (#1288)

[0.10.2] - Fixed
  • Fixed high memory usage for certain classification metrics when average='micro' (#1286)

  • Fixed precision problems when structural_similarity_index_measure was used with autocast (#1291)

  • Fixed slow performance for confusion matrix based metrics (#1302)

  • Fixed restrictive dtype checking in spearman_corrcoef when used with autocast (#1303)

[0.10.1] - 2022-10-21

[0.10.1] - Fixed
  • Fixed broken clone method for classification metrics (#1250)

  • Fixed unintentional downloading of nltk.punkt when lsum not in rouge_keys (#1258)

  • Fixed type casting in MAP metric between bool and float32 (#1150)

[0.10.0] - 2022-10-04

[0.10.0] - Added
  • Added a new NLP metric InfoLM (#915)

  • Added Perplexity metric (#922)

  • Added ConcordanceCorrCoef metric to regression package (#1201)

  • Added argument normalize to LPIPS metric (#1216)

  • Added support for multiprocessing of batches in PESQ metric (#1227)

  • Added support for multioutput in PearsonCorrCoef and SpearmanCorrCoef (#1200)

[0.10.0] - Changed
[0.10.0] - Deprecated
  • Deprecated BinnedAveragePrecision, BinnedPrecisionRecallCurve, BinnedRecallAtFixedPrecision (#1163)

    • BinnedAveragePrecision -> use AveragePrecision with thresholds arg

    • BinnedPrecisionRecallCurve -> use PrecisionRecallCurve with thresholds arg

    • BinnedRecallAtFixedPrecision -> use RecallAtFixedPrecision with thresholds arg

  • Renamed and refactored LabelRankingAveragePrecision, LabelRankingLoss and CoverageError (#1167)

    • LabelRankingAveragePrecision -> MultilabelRankingAveragePrecision

    • LabelRankingLoss -> MultilabelRankingLoss

    • CoverageError -> MultilabelCoverageError

  • Deprecated KLDivergence and AUC from classification package (#1189)

    • KLDivergence moved to regression package

    • Instead of AUC use torchmetrics.utilities.compute.auc

[0.10.0] - Fixed
  • Fixed a bug in ssim when return_full_image=True where the score was still reduced (#1204)

  • Fixed MPS support for:

  • Fixed bug in ClasswiseWrapper such that compute gave wrong result (#1225)

  • Fixed synchronization of empty list states (#1219)

[0.9.3] - 2022-08-22

[0.9.3] - Added
  • Added global option sync_on_compute to disable automatic synchronization when compute is called (#1107)

[0.9.3] - Fixed
  • Fixed missing reset in ClasswiseWrapper (#1129)

  • Fixed JaccardIndex multi-label compute (#1125)

  • Fix SSIM propagate device if gaussian_kernel is False, add test (#1149)

[0.9.2] - 2022-06-29

[0.9.2] - Fixed
  • Fixed mAP calculation for areas with 0 predictions (#1080)

  • Fixed bug where avg precision state and auroc state were not merged when using MetricCollections (#1086)

  • Skip box conversion if no boxes are present in MeanAveragePrecision (#1097)

  • Fixed inconsistency in docs and code when setting average="none" in AveragePrecision metric (#1116)

[0.9.1] - 2022-06-08

[0.9.1] - Added
  • Added specific RuntimeError when metric object is on the wrong device (#1056)

  • Added an option to specify own n-gram weights for BLEUScore and SacreBLEUScore instead of using uniform weights only. (#1075)

[0.9.1] - Fixed
  • Fixed aggregation metrics when input only contains zero (#1070)

  • Fixed TypeError when providing superclass arguments as kwargs (#1069)

  • Fixed bug related to state reference in metric collection when using compute groups (#1076)

[0.9.0] - 2022-05-30

[0.9.0] - Added
  • Added RetrievalPrecisionRecallCurve and RetrievalRecallAtFixedPrecision to retrieval package (#951)

  • Added class property full_state_update that determines whether forward should call update once or twice ( #984, #1033)

  • Added support for nested metric collections (#1003)

  • Added Dice to classification package (#1021)

  • Added support to segmentation type segm as IOU for mean average precision (#822)

[0.9.0] - Changed
  • Renamed reduction argument to average in Jaccard score and added additional options (#874)

[0.9.0] - Removed
[0.9.0] - Fixed
  • Fixed non-empty state dict for a few metrics (#1012)

  • Fixed bug when comparing states while finding compute groups (#1022)

  • Fixed torch.double support in stat score metrics (#1023)

  • Fixed FID calculation for non-equal size real and fake input (#1028)

  • Fixed case where KLDivergence could output Nan (#1030)

  • Fixed deterministic for PyTorch<1.8 (#1035)

  • Fixed default value for mdmc_average in Accuracy (#1036)

  • Fixed missing copy of property when using compute groups in MetricCollection (#1052)

[0.8.2] - 2022-05-06

[0.8.2] - Fixed
  • Fixed multi device aggregation in PearsonCorrCoef (#998)

  • Fixed MAP metric when using custom list of thresholds (#995)

  • Fixed compatibility between compute groups in MetricCollection and prefix/postfix arg (#1007)

  • Fixed compatibility with future Pytorch 1.12 in safe_matmul (#1011, #1014)

[0.8.1] - 2022-04-27

[0.8.1] - Changed
  • Reimplemented the signal_distortion_ratio metric, which removed the absolute requirement of fast-bss-eval (#964)

[0.8.1] - Fixed
  • Fixed “Sort currently does not support bool dtype on CUDA” error in MAP for empty preds (#983)

  • Fixed BinnedPrecisionRecallCurve when thresholds argument is not provided (#968)

  • Fixed CalibrationError to work on logit input (#985)

[0.8.0] - 2022-04-14

[0.8.0] - Added
  • Added WeightedMeanAbsolutePercentageError to regression package (#948)

  • Added new classification metrics:

    • CoverageError (#787)

    • LabelRankingAveragePrecision and LabelRankingLoss (#787)

  • Added new image metric:

    • SpectralAngleMapper (#885)

    • ErrorRelativeGlobalDimensionlessSynthesis (#894)

    • UniversalImageQualityIndex (#824)

    • SpectralDistortionIndex (#873)

  • Added support for MetricCollection in MetricTracker (#718)

  • Added support for 3D image and uniform kernel in StructuralSimilarityIndexMeasure (#818)

  • Added smart update of MetricCollection (#709)

  • Added ClasswiseWrapper for better logging of classification metrics with multiple output values (#832)

  • Added **kwargs argument for passing additional arguments to base class (#833)

  • Added negative ignore_index for the Accuracy metric (#362)

  • Added adaptive_k for the RetrievalPrecision metric (#910)

  • Added reset_real_features argument to image quality assessment metrics (#722)

  • Added new keyword argument compute_on_cpu to all metrics (#867)

[0.8.0] - Changed
  • Made num_classes in jaccard_index a required argument (#853, #914)

  • Added normalizer, tokenizer to ROUGE metric (#838)

  • Improved shape checking of permutation_invariant_training (#864)

  • Allowed reduction None (#891)

  • MetricTracker.best_metric will now give a warning when computing on a metric that does not have a best value (#913)

[0.8.0] - Deprecated
  • Deprecated argument compute_on_step (#792)

  • Deprecated passing in dist_sync_on_step, process_group, dist_sync_fn direct argument (#833)

[0.8.0] - Removed
  • Removed support for versions of Pytorch-Lightning lower than v1.5 (#788)

  • Removed deprecated functions, and warnings in Text (#773)

    • WER and functional.wer

  • Removed deprecated functions and warnings in Image (#796)

    • SSIM and functional.ssim

    • PSNR and functional.psnr

  • Removed deprecated functions, and warnings in classification and regression (#806)

    • FBeta and functional.fbeta

    • F1 and functional.f1

    • Hinge and functional.hinge

    • IoU and functional.iou

    • MatthewsCorrcoef

    • PearsonCorrcoef

    • SpearmanCorrcoef

  • Removed deprecated functions, and warnings in detection and pairwise (#804)

    • MAP and functional.pairwise.manhatten

  • Removed deprecated functions, and warnings in Audio (#805)

    • PESQ and functional.audio.pesq

    • PIT and functional.audio.pit

    • SDR and functional.audio.sdr and functional.audio.si_sdr

    • SNR and functional.audio.snr and functional.audio.si_snr

    • STOI and functional.audio.stoi

  • Removed unused get_num_classes from torchmetrics.utilities.data (#914)

[0.8.0] - Fixed
  • Fixed device mismatch for MAP metric in specific cases (#950)

  • Improved testing speed (#820)

  • Fixed compatibility of ClasswiseWrapper with the prefix argument of MetricCollection (#843)

  • Fixed BestScore on GPU (#912)

  • Fixed Lsum computation for ROUGEScore (#944)

[0.7.3] - 2022-03-23

[0.7.3] - Fixed
  • Fixed unsafe log operation in TweedieDeviace for power=1 (#847)

  • Fixed bug in MAP metric related to either no ground truth or no predictions (#884)

  • Fixed ConfusionMatrix, AUROC and AveragePrecision on GPU when running in deterministic mode (#900)

  • Fixed NaN or Inf results returned by signal_distortion_ratio (#899)

  • Fixed memory leak when using update method with tensor where requires_grad=True (#902)

[0.7.2] - 2022-02-10

[0.7.2] - Fixed
  • Minor patches in JOSS paper.

[0.7.1] - 2022-02-03

[0.7.1] - Changed
  • Used torch.bucketize in calibration error when torch>1.8 for faster computations (#769)

  • Improve mAP performance (#742)

[0.7.1] - Fixed
  • Fixed check for available modules (#772)

  • Fixed Matthews correlation coefficient when the denominator is 0 (#781)

[0.7.0] - 2022-01-17

[0.7.0] - Added
  • Added NLP metrics:

    • MatchErrorRate (#619)

    • WordInfoLost and WordInfoPreserved (#630)

    • SQuAD (#623)

    • CHRFScore (#641)

    • TranslationEditRate (#646)

    • ExtendedEditDistance (#668)

  • Added MultiScaleSSIM into image metrics (#679)

  • Added Signal to Distortion Ratio (SDR) to audio package (#565)

  • Added MinMaxMetric to wrappers (#556)

  • Added ignore_index to retrieval metrics (#676)

  • Added support for multi references in ROUGEScore (#680)

  • Added a default VSCode devcontainer configuration (#621)

[0.7.0] - Changed
  • Scalar metrics will now consistently have additional dimensions squeezed (#622)

  • Metrics having third party dependencies removed from global import (#463)

  • Untokenized input for BLEUScore now stays consistent with all the other text metrics (#640)

  • Arguments reordered for TER, BLEUScore, SacreBLEUScore, CHRFScore now expect input order as predictions first and target second (#696)

  • Changed dtype of metric state from torch.float to torch.long in ConfusionMatrix to accommodate larger values (#715)

  • Unify preds, target input argument’s naming across all text metrics (#723, #727)

    • bert, bleu, chrf, sacre_bleu, wip, wil, cer, ter, wer, mer, rouge, squad

[0.7.0] - Deprecated
  • Renamed IoU -> Jaccard Index (#662)

  • Renamed text WER metric (#714); see the migration sketch after this list

    • functional.wer -> functional.word_error_rate

    • WER -> WordErrorRate

  • Renamed correlation coefficient classes: (#710)

    • MatthewsCorrcoef -> MatthewsCorrCoef

    • PearsonCorrcoef -> PearsonCorrCoef

    • SpearmanCorrcoef -> SpearmanCorrCoef

  • Renamed audio STOI metric: (#753, #758)

    • audio.STOI to audio.ShortTimeObjectiveIntelligibility

    • functional.audio.stoi to functional.audio.short_time_objective_intelligibility

  • Renamed audio PESQ metrics: (#751)

    • functional.audio.pesq -> functional.audio.perceptual_evaluation_speech_quality

    • audio.PESQ -> audio.PerceptualEvaluationSpeechQuality

  • Renamed audio SDR metrics: (#711)

    • functional.sdr -> functional.signal_distortion_ratio

    • functional.si_sdr -> functional.scale_invariant_signal_distortion_ratio

    • SDR -> SignalDistortionRatio

    • SI_SDR -> ScaleInvariantSignalDistortionRatio

  • Renamed audio SNR metrics: (#712)

    • functional.snr -> functional.signal_noise_ratio

    • functional.si_snr -> functional.scale_invariant_signal_noise_ratio

    • SNR -> SignalNoiseRatio

    • SI_SNR -> ScaleInvariantSignalNoiseRatio

  • Renamed F-score metrics: (#731, #740)

    • functional.f1 -> functional.f1_score

    • F1 -> F1Score

    • functional.fbeta -> functional.fbeta_score

    • FBeta -> FBetaScore

  • Renamed Hinge metric: (#734)

    • functional.hinge -> functional.hinge_loss

    • Hinge -> HingeLoss

  • Renamed image PSNR metrics (#732)

    • functional.psnr -> functional.peak_signal_noise_ratio

    • PSNR -> PeakSignalNoiseRatio

  • Renamed audio PIT metric: (#737)

    • functional.pit -> functional.permutation_invariant_training

    • PIT -> PermutationInvariantTraining

  • Renamed image SSIM metric: (#747)

    • functional.ssim -> functional.structural_similarity_index_measure

    • SSIM -> StructuralSimilarityIndexMeasure

  • Renamed detection MAP to MeanAveragePrecision metric (#754)

  • Renamed Fidelity & LPIPS image metric: (#752)

    • image.FID -> image.FrechetInceptionDistance

    • image.KID -> image.KernelInceptionDistance

    • image.LPIPS -> image.LearnedPerceptualImagePatchSimilarity
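
For code that still uses the pre-0.7 names, the snippet below is a minimal migration sketch for the text WER and audio SNR renames listed above; the data is random/dummy and only meant to show the new import paths:

import torch
from torchmetrics.text import WordErrorRate                    # formerly WER
from torchmetrics.audio import SignalNoiseRatio                # formerly SNR
from torchmetrics.functional.text import word_error_rate       # formerly functional.wer
from torchmetrics.functional.audio import signal_noise_ratio   # formerly functional.snr

# text: word error rate under the new names
print(word_error_rate(["hello world"], ["hello there world"]))
print(WordErrorRate()(["hello world"], ["hello there world"]))

# audio: signal-to-noise ratio on dummy waveforms
preds, target = torch.randn(8000), torch.randn(8000)
print(signal_noise_ratio(preds, target))
print(SignalNoiseRatio()(preds, target))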

[0.7.0] - Removed
  • Removed embedding_similarity metric (#638)

  • Removed argument concatenate_texts from wer metric (#638)

  • Removed arguments newline_sep and decimal_places from rouge metric (#638)

[0.7.0] - Fixed
  • Fixed MetricCollection kwargs filtering when no kwargs are present in update signature (#707)

[0.6.2] - 2021-12-15

[0.6.2] - Fixed
  • Fixed torch.sort currently does not support bool dtype on CUDA (#665)

  • Fixed mAP properly checks if ground truths are empty (#684)

  • Fixed initialization of tensors to be on correct device for MAP metric (#673)

[0.6.1] - 2021-12-06

[0.6.1] - Changed
  • Migrate MAP metrics from pycocotools to PyTorch (#632)

  • Use torch.topk instead of torch.argsort in retrieval precision for speedup (#627)

[0.6.1] - Fixed
  • Fix empty predictions in MAP metric (#594, #610, #624)

  • Fix edge case of AUROC with average=weighted on GPU (#606)

  • Fixed forward in compositional metrics (#645)

[0.6.0] - 2021-10-28

[0.6.0] - Added
  • Added audio metrics:

    • Perceptual Evaluation of Speech Quality (PESQ) (#353)

    • Short-Time Objective Intelligibility (STOI) (#353)

  • Added Information retrieval metrics:

    • RetrievalRPrecision (#577)

    • RetrievalHitRate (#576)

  • Added NLP metrics:

    • SacreBLEUScore (#546)

    • CharErrorRate (#575)

  • Added other metrics:

    • Tweedie Deviance Score (#499)

    • Learned Perceptual Image Patch Similarity (LPIPS) (#431)

  • Added MAP (mean average precision) metric to new detection package (#467)

  • Added support for float targets in nDCG metric (#437)

  • Added average argument to AveragePrecision metric for reducing multi-label and multi-class problems (#477)

  • Added MultioutputWrapper (#510)

  • Added metric sweeping:

    • higher_is_better as constant attribute (#544)

    • higher_is_better to rest of codebase (#584)

  • Added simple aggregation metrics: SumMetric, MeanMetric, CatMetric, MinMetric, MaxMetric (#506)

  • Added pairwise submodule with metrics (#553); see the sketch after this list

    • pairwise_cosine_similarity

    • pairwise_euclidean_distance

    • pairwise_linear_similarity

    • pairwise_manhatten_distance
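
The pairwise functions take two batches of feature vectors and return an N x M matrix of pairwise scores; a minimal sketch with random data:

import torch
from torchmetrics.functional import pairwise_cosine_similarity, pairwise_euclidean_distance

x = torch.randn(3, 8)  # 3 samples with 8 features
y = torch.randn(5, 8)  # 5 samples with 8 features
print(pairwise_cosine_similarity(x, y).shape)   # torch.Size([3, 5])
print(pairwise_euclidean_distance(x, y).shape)  # torch.Size([3, 5])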

[0.6.0] - Changed
  • AveragePrecision now outputs the macro average by default for multilabel and multiclass problems (#477)

  • half, double and float will no longer change the dtype of the metric states; use metric.set_dtype instead (#493), as shown in the sketch after this list

  • Renamed AverageMeter to MeanMetric (#506)

  • Changed is_differentiable from property to a constant attribute (#551)

  • ROC and AUROC will no longer throw an error when either the positive or negative class is missing; instead they return a score of 0 and issue a warning
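
Since half, double and float no longer touch the metric states, set_dtype is the supported way to change state precision; a minimal sketch with random data:

import torch
from torchmetrics import MeanSquaredError

metric = MeanSquaredError()
metric.set_dtype(torch.float64)  # changes the dtype of the internal metric states
metric.update(torch.randn(10), torch.randn(10))
print(metric.compute().dtype)    # torch.float64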

[0.6.0] - Deprecated
  • Deprecated functional.self_supervised.embedding_similarity in favour of new pairwise submodule

[0.6.0] - Removed
  • Removed dtype property (#493)

[0.6.0] - Fixed
  • Fixed bug in F1 with average='macro' and ignore_index!=None (#495)

  • Fixed bug in pit by using the returned first result to initialize device and type (#533)

  • Fixed SSIM metric using too much memory (#539)

  • Fixed bug where the device property was not properly updated when the metric was a child of a module (#542)

[0.5.1] - 2021-08-30

[0.5.1] - Added
  • Added device and dtype properties (#462)

  • Added TextTester class for robustly testing text metrics (#450)

[0.5.1] - Changed
  • Added support for float targets in nDCG metric (#437)

[0.5.1] - Removed
  • Removed rouge-score as dependency for text package (#443)

  • Removed jiwer as dependency for text package (#446)

  • Removed bert-score as dependency for text package (#473)

[0.5.1] - Fixed
  • Fixed ranking of samples in SpearmanCorrCoef metric (#448)

  • Fixed bug where compositional metrics were unable to sync because of type mismatch (#454)

  • Fixed metric hashing (#478)

  • Fixed BootStrapper metrics not working on GPU (#462)

  • Fixed the semantic ordering of kernel height and width in SSIM metric (#474)

[0.5.0] - 2021-08-09

[0.5.0] - Added
  • Added Text-related (NLP) metrics:

  • Added MetricTracker wrapper metric for keeping track of the same metric over multiple epochs (#238); see the sketch after this list

  • Added other metrics:

    • Symmetric Mean Absolute Percentage error (SMAPE) (#375)

    • Calibration error (#394)

    • Permutation Invariant Training (PIT) (#384)

  • Added support in nDCG metric for target with values larger than 1 (#349)

  • Added support for negative targets in nDCG metric (#378)

  • Added None as reduction option in CosineSimilarity metric (#400)

  • Allowed passing labels in (n_samples, n_classes) to AveragePrecision (#386)
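
A minimal sketch of the MetricTracker wrapper mentioned above, tracking a regression metric over a few simulated epochs (random data; maximize=False because lower MSE is better):

import torch
from torchmetrics import MeanSquaredError
from torchmetrics.wrappers import MetricTracker

tracker = MetricTracker(MeanSquaredError(), maximize=False)
for epoch in range(3):
    tracker.increment()  # start tracking a new epoch
    for _ in range(5):   # simulated batches
        tracker.update(torch.randn(10), torch.randn(10))
print(tracker.compute_all())  # one value per tracked epoch
print(tracker.best_metric())  # best value across epochs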

[0.5.0] - Changed
  • Moved psnr and ssim from functional.regression.* to functional.image.* (#382)

  • Moved image_gradient from functional.image_gradients to functional.image.gradients (#381)

  • Moved R2Score from regression.r2score to regression.r2 (#371)

  • Pearson metric now only stores 6 statistics instead of all predictions and targets (#380)

  • Use torch.argmax instead of torch.topk when k=1 for better performance (#419)

  • Moved check for number of samples in R2 score to support single sample updating (#426)

[0.5.0] - Deprecated
  • Renamed r2score -> r2_score and kldivergence -> kl_divergence in functional (#371)

  • Moved bleu_score from functional.nlp to functional.text.bleu (#360)

[0.5.0] - Removed
  • Removed restriction that threshold has to be in (0,1) range to support logit input (#351, #401)

  • Removed restriction that preds could not be bigger than num_classes to support logit input (#357)

  • Removed module regression.psnr and regression.ssim (#382)

  • Removed (#379):

    • function functional.mean_relative_error

    • num_thresholds argument in BinnedPrecisionRecallCurve

[0.5.0] - Fixed
  • Fixed bug where classification metrics with average='macro' would lead to wrong result if a class was missing (#303)

  • Fixed weighted, multi-class AUROC computation to allow for 0 observations of some class, as contribution to final AUROC is 0 (#376)

  • Fixed that _forward_cache and _computed attributes are also moved to the correct device if metric is moved (#413)

  • Fixed calculation in IoU metric when using ignore_index argument (#328)

[0.4.1] - 2021-07-05

[0.4.1] - Fixed
  • Fixed DDP synchronization by adding is_sync logic to Metric (#339)

[0.4.0] - 2021-06-29

[0.4.0] - Added
  • Added Image-related metrics:

    • Fréchet inception distance (FID) (#213)

    • Kernel Inception Distance (KID) (#301)

    • Inception Score (#299)

    • KL divergence (#247)

  • Added Audio metrics: SNR, SI_SDR, SI_SNR (#292)

  • Added other metrics:

    • Cosine Similarity (#305)

    • Specificity (#210)

    • Mean Absolute Percentage error (MAPE) (#248)

  • Added add_metrics method to MetricCollection for adding additional metrics after initialization (#221)

  • Added pre-gather reduction in the case of dist_reduce_fx="cat" to reduce communication cost (#217)

  • Added better error message for AUROC when num_classes is not provided for multiclass input (#244)

  • Added support for unnormalized scores (e.g. logits) in Accuracy, Precision, Recall, FBeta, F1, StatScore, Hamming, ConfusionMatrix metrics (#200)

  • Added squared argument to MeanSquaredError for computing RMSE (#249); see the sketch after this list

  • Added is_differentiable property to ConfusionMatrix, F1, FBeta, Hamming, Hinge, IOU, MatthewsCorrcoef, Precision, Recall, PrecisionRecallCurve, ROC, StatScores (#253)

  • Added sync and sync_context methods for manually controlling when metric states are synced (#302)
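
The squared argument toggles between MSE and RMSE, as noted above; a minimal sketch with random data:

import torch
from torchmetrics import MeanSquaredError

preds, target = torch.randn(50), torch.randn(50)
mse = MeanSquaredError()                # default squared=True -> MSE
rmse = MeanSquaredError(squared=False)  # squared=False -> RMSE
print(mse(preds, target), rmse(preds, target))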

[0.4.0] - Changed
  • Forward cache is reset when reset method is called (#260)

  • Improved per-class metric handling for imbalanced datasets for precision, recall, precision_recall, fbeta, f1, accuracy, and specificity (#204)

  • Applied the torch.jit.unused decorator to MetricCollection forward (#307)

  • Renamed thresholds argument to binned metrics for manually controlling the thresholds (#322)

  • Extend typing (#324, #326, #327)

[0.4.0] - Deprecated
  • Deprecated functional.mean_relative_error, use functional.mean_absolute_percentage_error (#248)

  • Deprecated num_thresholds argument in BinnedPrecisionRecallCurve (#322)

[0.4.0] - Removed
  • Removed argument is_multiclass (#319)

[0.4.0] - Fixed
  • AUC now supports higher-dimensional inputs when all but one dimension are of size 1 (#242)

  • Fixed dtype of modular metrics after reset has been called (#243)

  • Fixed calculation in matthews_corrcoef to correctly match formula (#321)

[0.3.2] - 2021-05-10

[0.3.2] - Added
  • Added is_differentiable property:

    • To AUC, AUROC, CohenKappa and AveragePrecision (#178)

    • To PearsonCorrCoef, SpearmanCorrcoef, R2Score and ExplainedVariance (#225)

[0.3.2] - Changed
  • MetricCollection now returns metrics with the prefix applied on items() and keys() (#209)

  • Calling compute before update will now give a warning (#164)

[0.3.2] - Removed
  • Removed numpy as direct dependency (#212)

[0.3.2] - Fixed
  • Fixed auc calculation and added tests (#197)

  • Fixed loading persisted metric states using load_state_dict() (#202)

  • Fixed PSNR not working with DDP (#214)

  • Fixed metric calculation with unequal batch sizes (#220)

  • Fixed metric concatenation for list states for zero-dim input (#229)

  • Fixed numerical instability in AUROC metric for large input (#230)

[0.3.1] - 2021-04-21

  • Cleaned up remaining inconsistencies and fixed PL develop integration (#191, #192, #193, #194)

[0.3.0] - 2021-04-20

[0.3.0] - Added
  • Added BootStrapper to easily calculate confidence intervals for metrics (#101); see the sketch after this list

  • Added Binned metrics (#128)

  • Added metrics for Information Retrieval (PL^5032):

    • RetrievalMAP (PL^5032)

    • RetrievalMRR (#119)

    • RetrievalPrecision (#139)

    • RetrievalRecall (#146)

    • RetrievalNormalizedDCG (#160)

    • RetrievalFallOut (#161)

  • Added other metrics:

    • CohenKappa (#69)

    • MatthewsCorrcoef (#98)

    • PearsonCorrcoef (#157)

    • SpearmanCorrcoef (#158)

    • Hinge (#120)

  • Added average='micro' as an option in AUROC for multilabel problems (#110)

  • Added multilabel support to ROC metric (#114)

  • Added testing for half precision (#77, #135)

  • Added AverageMeter for ad-hoc averages of values (#138)

  • Added prefix argument to MetricCollection (#70)

  • Added __getitem__ as metric arithmetic operation (#142)

  • Added property is_differentiable to metrics and test for differentiability (#154)

  • Added support for average, ignore_index and mdmc_average in Accuracy metric (#166)

  • Added postfix arg to MetricCollection (#188)
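
A minimal sketch of the BootStrapper wrapper mentioned above, applied to a regression metric on random data; compute() returns the bootstrapped mean and standard deviation, which can serve as a rough confidence estimate:

import torch
from torchmetrics import MeanSquaredError
from torchmetrics.wrappers import BootStrapper

bootstrap = BootStrapper(MeanSquaredError(), num_bootstraps=20)
bootstrap.update(torch.randn(100), torch.randn(100))
print(bootstrap.compute())  # {'mean': tensor(...), 'std': tensor(...)}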

[0.3.0] - Changed
  • Changed ExplainedVariance from storing all preds/targets to tracking 5 statistics (#68)

  • Changed behaviour of ConfusionMatrix for multilabel data to better match multilabel_confusion_matrix from sklearn (#134)

  • Updated FBeta arguments (#111)

  • Changed reset method to use detach().clone() instead of deepcopy when resetting to default (#163)

  • Metrics passed as dict to MetricCollection will now always be in deterministic order (#173)

  • Allowed passing metrics as positional arguments to MetricCollection (#176)

[0.3.0] - Deprecated
  • Rename argument is_multiclass -> multiclass (#162)

[0.3.0] - Removed
  • Pruned remaining deprecated code (#92)

[0.3.0] - Fixed
  • Fixed _stable_1d_sort to work when n >= N (PL^6177)

  • Fixed _computed attribute not being correctly reset (#147)

  • Fixed BLEU score computation (#165)

  • Fixed backwards compatibility for logging with older version of pytorch-lightning (#182)

[0.2.0] - 2021-03-12

[0.2.0] - Changed
  • Decoupled PL dependency (#13)

  • Refactored functional metrics to mimic the module-like structure: classification, regression, etc. (#16)

  • Refactored utilities - split into topics/submodules (#14)

  • Refactored MetricCollection (#19)

[0.2.0] - Removed
  • Removed deprecated metrics from PL base (#12, #15)

[0.1.0] - 2021-02-22

  • The Accuracy metric now generalizes to Top-k accuracy for (multi-dimensional) multi-class inputs using the top_k parameter (PL^4838); see the sketch at the end of this list

  • The Accuracy metric now enables the computation of subset accuracy for multi-label or multi-dimensional multi-class inputs with the subset_accuracy parameter (PL^4838)

  • Added HammingDistance metric to compute the hamming distance (loss) (PL^4838)

  • Added StatScores metric to compute the number of true positives, false positives, true negatives and false negatives (PL^4839)

  • Added R2Score metric (PL^5241)

  • Added MetricCollection (PL^4318)

  • Added .clone() method to metrics (PL^4318)

  • Added IoU class interface (PL^4704)

  • The Recall and Precision metrics (and their functional counterparts recall and precision) can now be generalized to Recall@K and Precision@K with the use of top_k parameter (PL^4842)

  • Added compositional metrics (PL^5464)

  • Added AUC/AUROC class interface (PL^5479)

  • Added QuantizationAwareTraining callback (PL^5706)

  • Added ConfusionMatrix class interface (PL^4348)

  • Added multiclass AUROC metric (PL^4236)

  • Added PrecisionRecallCurve, ROC, AveragePrecision class metric (PL^4549)

  • Classification metrics overhaul (PL^4837)

  • Added F1 class metric (PL^4656)

  • Added metrics aggregation in Horovod and fixed early stopping (PL^3775)

  • Added persistent(mode) method to metrics, to enable and disable metric states being added to state_dict (PL^4482)

  • Added unification of regression metrics (PL^4166)

  • Added persistent flag to Metric.add_state (PL^4195)

  • Added classification metrics (PL^4043)

  • Added new Metrics API (PL^3868, PL^3921)

  • Added EMB similarity (PL^3349)

  • Added SSIM metrics (PL^2671)

  • Added BLEU metrics (PL^2535)
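
The top_k parameter is still available in the current classification API; the sketch below uses MulticlassAccuracy (today's task-specific class) on simulated predictions:

import torch
from torchmetrics.classification import MulticlassAccuracy

preds = torch.randn(16, 5).softmax(dim=-1)  # simulated class probabilities
target = torch.randint(5, (16,))

top1 = MulticlassAccuracy(num_classes=5, top_k=1)
top2 = MulticlassAccuracy(num_classes=5, top_k=2)
print(top1(preds, target), top2(preds, target))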
