Welcome to TorchMetrics
TorchMetrics is a collection of 100+ PyTorch metrics implementations and an easy-to-use API to create custom metrics. It offers:
A standardized interface to increase reproducibility
Reduces Boilerplate
Distributed-training compatible
Rigorously tested
Automatic accumulation over batches
Automatic synchronization between multiple devices
You can use TorchMetrics in any PyTorch model, or within PyTorch Lightning to enjoy the following additional benefits:
Your data will always be placed on the same device as your metrics
You can log Metric objects directly in Lightning to reduce even more boilerplate
Install TorchMetrics
For pip users
pip install torchmetrics
Or directly from conda
conda install -c conda-forge torchmetrics
Quick Start
You can use TorchMetrics in any PyTorch model, or within PyTorch Lightning to enjoy additional features:
Your data will always be placed on the same device as your metrics.
Native support for logging metrics in Lightning to reduce even more boilerplate.
Install
You can install TorchMetrics using pip or conda:
# Python Package Index (PyPI)
pip install torchmetrics
# Conda
conda install -c conda-forge torchmetrics
If there is no PyTorch wheel for your OS or Python version, you can compile PyTorch from source:
# Optional: skip compiling GPU support if you do not need it
export USE_CUDA=0 # just to keep it simple
# you can install the latest state from master
pip install git+https://github.com/pytorch/pytorch.git
# OR set a particular PyTorch release
pip install git+https://github.com/pytorch/pytorch.git@<release-tag>
# and finalize with installing TorchMetrics
pip install torchmetrics
Using TorchMetrics
Functional metrics
Similar to torch.nn, most metrics have both a class-based and a functional version.
The functional versions implement the basic operations required for computing each metric.
They are simple Python functions that take torch.tensor inputs and return the corresponding metric as a torch.tensor.
The code snippet below shows a simple example for calculating the accuracy using the functional interface:
import torch
# import our library
import torchmetrics
# simulate a classification problem
preds = torch.randn(10, 5).softmax(dim=-1)
target = torch.randint(5, (10,))
acc = torchmetrics.functional.accuracy(preds, target, task="multiclass", num_classes=5)
Module metrics
Nearly all functional metrics have a corresponding class-based metric that calls its functional counterpart underneath. The class-based metrics are characterized by having one or more internal metric states (similar to the parameters of a PyTorch module) that allow them to offer additional functionalities:
Accumulation of multiple batches
Automatic synchronization between multiple devices
Metric arithmetic
The code below shows how to use the class-based interface:
import torch
# import our library
import torchmetrics
# initialize metric
metric = torchmetrics.classification.Accuracy(task="multiclass", num_classes=5)
n_batches = 10
for i in range(n_batches):
    # simulate a classification problem
    preds = torch.randn(10, 5).softmax(dim=-1)
    target = torch.randint(5, (10,))
    # metric on current batch
    acc = metric(preds, target)
    print(f"Accuracy on batch {i}: {acc}")
# metric on all batches using custom accumulation
acc = metric.compute()
print(f"Accuracy on all data: {acc}")
# Resetting internal state so that the metric is ready for new data
metric.reset()
Implementing your own metric
Implementing your own metric is as easy as subclassing a torch.nn.Module. Simply subclass Metric and do the following:
Implement __init__, where you call self.add_state for every internal state that is needed for the metric computations
Implement the update method, where all logic that is necessary for updating metric states goes
Implement the compute method, where the final metric computations happen
For practical examples and more info about implementing a metric, please see this page.
Development Environment
TorchMetrics provides a Devcontainer configuration for Visual Studio Code to use a Docker container as a pre-configured development environment.
This avoids the hassle of setting up a development environment by hand and makes environments reproducible and consistent.
Please follow the installation instructions and make yourself familiar with the container tutorials if you want to use them.
To use GPUs, you can enable them in the .devcontainer/devcontainer.json file.
All TorchMetrics
Structure Overview
TorchMetrics is a Metrics API created for easy metric development and usage in PyTorch and PyTorch Lightning. It is rigorously tested for all edge cases and includes a growing list of common metric implementations.
The metrics API provides update(), compute() and reset() functions to the user. The metric base class inherits torch.nn.Module, which allows us to call metric(...) directly. The forward() method of the base Metric class serves the dual purpose of calling update() on its input and simultaneously returning the value of the metric over the provided input.
These metrics work with DDP in PyTorch and PyTorch Lightning by default. When .compute() is called in distributed mode, the internal state of each metric is synced and reduced across each process, so that the logic present in .compute() is applied to state information from all processes.
This metrics API is independent of PyTorch Lightning. Metrics can directly be used in PyTorch as shown in the example:
import torch
from torchmetrics.classification import BinaryAccuracy
# toy model and data (hypothetical) so that the example runs end to end
epochs = 2
model = torch.nn.Sequential(torch.nn.Linear(4, 1), torch.nn.Sigmoid())
train_data = [(torch.randn(8, 4), torch.randint(2, (8,))) for _ in range(5)]
valid_data = [(torch.randn(8, 4), torch.randint(2, (8,))) for _ in range(2)]
train_accuracy = BinaryAccuracy()
valid_accuracy = BinaryAccuracy()
for epoch in range(epochs):
    for i, (x, y) in enumerate(train_data):
        y_hat = model(x).squeeze(-1)
        # training step accuracy
        batch_acc = train_accuracy(y_hat, y)
        print(f"Accuracy of batch {i} is {batch_acc}")
    for x, y in valid_data:
        y_hat = model(x).squeeze(-1)
        valid_accuracy.update(y_hat, y)
    # total accuracy over all training batches
    total_train_accuracy = train_accuracy.compute()
    # total accuracy over all validation batches
    total_valid_accuracy = valid_accuracy.compute()
    print(f"Training acc for epoch {epoch}: {total_train_accuracy}")
    print(f"Validation acc for epoch {epoch}: {total_valid_accuracy}")
    # Reset metric states after each epoch
    train_accuracy.reset()
    valid_accuracy.reset()
Note
Metrics contain internal states that keep track of the data seen so far. Do not mix metric states across training, validation and testing. It is highly recommended to re-initialize the metric per mode as shown in the examples above.
Note
Metric states are not added to the model's state_dict by default.
To change this, after initializing the metric, the method .persistent(mode) can be used to enable (mode=True) or disable (mode=False) this behaviour.
Note
Due to specialized logic around metric states, we generally do not recommend initializing metrics inside other metrics (nested metrics), as this can lead to unexpected behaviour. Instead, consider subclassing a metric or using torchmetrics.MetricCollection.
Metrics and devices
Metrics are simple subclasses of Module and their metric states behave similarly to buffers and parameters of modules. This means that metric states should be moved to the same device as the input of the metric:
import torch
from torchmetrics.classification import BinaryAccuracy
# use the first GPU if one is available, otherwise fall back to CPU
device = torch.device("cuda", 0) if torch.cuda.is_available() else torch.device("cpu")
target = torch.tensor([1, 1, 0, 0], device=device)
preds = torch.tensor([0, 1, 0, 0], device=device)
# Metric states are always initialized on cpu, and need to be moved to
# the correct device
metric = BinaryAccuracy().to(device)
out = metric(preds, target)
print(out.device) # cuda:0 when a GPU is available
However, when properly defined inside a Module or LightningModule, the metric will be automatically moved to the same device as the module when using .to(device). Being properly defined means that the metric is correctly identified as a child module of the model (check the .children() attribute of the model). Therefore, metrics cannot be placed in a native Python list or dict, as they will not be correctly identified as child modules. Instead of list use ModuleList, and instead of dict use ModuleDict. Furthermore, when working with multiple metrics, the native MetricCollection module can also be used to wrap multiple metrics.
from torch import nn
from torchmetrics import MetricCollection
from torchmetrics.classification import BinaryAccuracy
class MyModule(nn.Module):
    def __init__(self):
        super().__init__()
        self.model = nn.Sequential(nn.Linear(4, 1), nn.Sigmoid())  # hypothetical model
        # valid ways metrics will be identified as child modules
        self.metric1 = BinaryAccuracy()
        self.metric2 = nn.ModuleList([BinaryAccuracy()])
        self.metric3 = nn.ModuleDict({'accuracy': BinaryAccuracy()})
        self.metric4 = MetricCollection([BinaryAccuracy()]) # torchmetrics built-in collection class
    def forward(self, batch):
        data, target = batch
        preds = self.model(data).squeeze(-1)
        ...
        val1 = self.metric1(preds, target)
        val2 = self.metric2[0](preds, target)
        val3 = self.metric3['accuracy'](preds, target)
        val4 = self.metric4(preds, target)
        return val1, val2, val3, val4
You can always check which device the metric is located on using the .device property.
Metrics and memory management
As stated before, metrics have states and those states take up a certain amount of memory depending on the metric. In general, metrics can be divided into two categories when we talk about memory management:
Metrics with tensor states: These metrics only have states that are instances of Tensor. When these kinds of metrics are updated, the values of those tensors are updated. Importantly, the size of the tensors is constant, meaning that regardless of how much data is passed to the metric, its memory footprint will not change.
Metrics with list states: These metrics have at least one state that is a list, to which tensors are appended as the metric is updated. Importantly, the size of the list is therefore not constant and will grow as the metric is updated. The growth depends on the particular metric (some metrics only need to store a single value per sample, some much more).
You can always check the current metric state by accessing the .metric_state property, and checking if any of the states are lists.
import torch
from torchmetrics.regression import SpearmanCorrCoef
gen = torch.manual_seed(42)
metric = SpearmanCorrCoef()
metric(torch.rand(2,), torch.rand(2,))
print(metric.metric_state)
metric(torch.rand(2,), torch.rand(2,))
print(metric.metric_state)
{'preds': [tensor([0.8823, 0.9150])], 'target': [tensor([0.3829, 0.9593])]}
{'preds': [tensor([0.8823, 0.9150]), tensor([0.3904, 0.6009])], 'target': [tensor([0.3829, 0.9593]), tensor([0.2566, 0.7936])]}
In general we have a few recommendations for memory management:
When done with a metric, we always recommend calling the reset method, because the Python garbage collector can struggle to fully clean up the metric states if this is not done. In the worst case, this can lead to a memory leak if multiple instances of the same metric for different purposes are created in the same script.
It is better to reuse the same instance of a metric instead of initializing a new one. Calling the reset method returns the metric to its initial state, so the same instance can be reused. However, we still highly recommend using different instances for training, validation and testing.
If only batch-level results are needed (e.g. no aggregation), or if you have a small dataset that fits into a single evaluation iteration, we recommend using the functional API instead, as it does not keep an internal state and memory is therefore freed after each call.
See Advanced metric settings for different advanced settings for controlling the memory footprint of metrics.
Metrics in Distributed Data Parallel (DDP) mode
When using metrics in Distributed Data Parallel (DDP) mode, one should be aware that DDP will add additional samples to your dataset if the size of your dataset is not equally divisible by batch_size * num_processors. The added samples will always be replicates of datapoints already in your dataset. This is done to secure an equal load for all processes. However, this has the consequence that the calculated metric value will be slightly biased towards those replicated samples, leading to a wrong result.
During training and/or validation this may not be important. However, when evaluating on the test dataset, it is highly recommended to either run on a single GPU or use a join context in conjunction with DDP to prevent this behaviour.
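The padding can be inspected without launching several processes. This sketch assumes the padding is done by PyTorch's DistributedSampler (the usual sampler under DDP, which pads the index list to a multiple of the number of processes) and uses a hypothetical per-sample correctness vector to show the resulting bias:

```python
import torch
from torch.utils.data.distributed import DistributedSampler

dataset = list(range(10))  # 10 samples, not divisible by 3 processes
world_size = 3

# passing num_replicas and rank explicitly lets us inspect the sharding
# without calling torch.distributed.init_process_group
shards = [
    list(DistributedSampler(dataset, num_replicas=world_size, rank=r, shuffle=False))
    for r in range(world_size)
]
seen = [idx for shard in shards for idx in shard]
print(shards)  # [[0, 3, 6, 9], [1, 4, 7, 0], [2, 5, 8, 1]]

# hypothetical per-sample correctness: only samples 0 and 1 are predicted correctly
correct = torch.zeros(len(dataset))
correct[:2] = 1.0
true_acc = correct.mean()  # 2 / 10
ddp_acc = correct[seen].mean()  # 4 / 12, biased by the replicated samples 0 and 1
print(true_acc, ddp_acc)
```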
Metrics and 16-bit precision
Most metrics in our collection can be used with 16-bit precision (torch.half) tensors. However, we have found the following limitations:
In general, PyTorch had better support for 16-bit precision much earlier on GPU than on CPU. Therefore, we recommend that anyone who wants to use metrics with half precision on CPU upgrade to at least PyTorch v1.6, where support for operations such as addition, subtraction and multiplication was added.
Some metrics do not work at all in half precision on CPU. We have explicitly stated this in their docstrings.
You can always check the precision/dtype of the metric by checking the .dtype property.
Metric Arithmetics
Metrics support most of Python's built-in operators for arithmetic, logic and bitwise operations.
For example, for a metric that should return the sum of two different metrics, implementing a new metric is unnecessary overhead. It can now be done with:
first_metric = MyFirstMetric()
second_metric = MySecondMetric()
new_metric = first_metric + second_metric
new_metric.update(*args, **kwargs) now calls update of first_metric and second_metric. It forwards all positional arguments but forwards only the keyword arguments that are available in the respective metric's update declaration. Similarly, new_metric.compute() now calls compute of first_metric and second_metric and adds the results up. It is important to note that all implemented operations always return a new metric object. This means that the line first_metric == second_metric will not return a bool indicating if first_metric and second_metric are the same metric, but will return a new metric that checks if first_metric.compute() == second_metric.compute().
This pattern is implemented for the following operators (with a being metrics and b being metrics, tensors, integers or floats):
Addition (a + b)
Bitwise AND (a & b)
Equality (a == b)
Floor Division (a // b)
Greater Equal (a >= b)
Greater (a > b)
Less Equal (a <= b)
Less (a < b)
Matrix Multiplication (a @ b)
Modulo (a % b)
Multiplication (a * b)
Inequality (a != b)
Bitwise OR (a | b)
Power (a ** b)
Subtraction (a - b)
True Division (a / b)
Bitwise XOR (a ^ b)
Absolute Value (abs(a))
Inversion (~a)
Negative Value (neg(a))
Positive Value (pos(a))
Indexing (a[0])
Note
Some of these operations are only fully supported from PyTorch v1.4 onwards; specifically, we found: add, mul, rmatmul, rsub, rmod
MetricCollection
In many cases it is beneficial to evaluate the model output by multiple metrics.
In this case the MetricCollection
class may come in handy. It accepts a sequence
of metrics and wraps these into a single callable metric class, with the same
interface as any other metric.
Example:
import torch
from torchmetrics import MetricCollection
from torchmetrics.classification import MulticlassAccuracy, MulticlassPrecision, MulticlassRecall
target = torch.tensor([0, 2, 0, 2, 0, 1, 0, 2])
preds = torch.tensor([2, 1, 2, 0, 1, 2, 2, 2])
metric_collection = MetricCollection([
MulticlassAccuracy(num_classes=3, average="micro"),
MulticlassPrecision(num_classes=3, average="macro"),
MulticlassRecall(num_classes=3, average="macro")
])
print(metric_collection(preds, target))
{'MulticlassAccuracy': tensor(0.1250),
'MulticlassPrecision': tensor(0.0667),
'MulticlassRecall': tensor(0.1111)}
Similarly, it can also reduce the amount of code required to log multiple metrics inside your LightningModule. In most cases we just have to replace self.log with self.log_dict.
from pytorch_lightning import LightningModule
from torchmetrics import MetricCollection
from torchmetrics.classification import MulticlassAccuracy, MulticlassPrecision, MulticlassRecall
class MyModule(LightningModule):
    def __init__(self, num_classes: int):
        super().__init__()
        metrics = MetricCollection([
            MulticlassAccuracy(num_classes), MulticlassPrecision(num_classes), MulticlassRecall(num_classes)
        ])
        self.train_metrics = metrics.clone(prefix='train_')
        self.valid_metrics = metrics.clone(prefix='val_')
    def training_step(self, batch, batch_idx):
        x, y = batch
        logits = self(x)
        # ...
        output = self.train_metrics(logits, y)
        # use log_dict instead of log
        # metrics are logged with keys: train_MulticlassAccuracy, train_MulticlassPrecision and train_MulticlassRecall
        self.log_dict(output)
    def validation_step(self, batch, batch_idx):
        x, y = batch
        logits = self(x)
        # ...
        self.valid_metrics.update(logits, y)
    def on_validation_epoch_end(self):
        # use log_dict instead of log
        # metrics are logged with keys: val_MulticlassAccuracy, val_MulticlassPrecision and val_MulticlassRecall
        output = self.valid_metrics.compute()
        self.log_dict(output)
        # remember to reset metrics at the end of the epoch
        self.valid_metrics.reset()
Note
MetricCollection by default assumes that all the metrics in the collection have the same call signature. If this is not the case, input that should be given to different metrics can be given as keyword arguments to the collection.
An additional advantage of using the MetricCollection object is that it will automatically try to reduce the computations needed by finding groups of metrics that share the same underlying metric state. If such a group of metrics is found, only one of them is actually updated and the updated state will be broadcast to the rest of the metrics within the group. In the example above, this will lead to a 2x-3x lower computational cost compared to disabling this feature in the case of the validation metrics, where only update is called (this feature does not work in combination with forward). However, this speedup comes with a fixed cost upfront, where the state-groups have to be determined after the first update. In case the groups are known beforehand, these can also be set manually to avoid this extra cost of the dynamic search. See the compute_groups argument in the class docs below for more information on this topic.
- class torchmetrics.MetricCollection(metrics, *additional_metrics, prefix=None, postfix=None, compute_groups=True)
MetricCollection class can be used to chain metrics that have the same call pattern into one single class.
- Parameters:
metrics (Union[Metric, Sequence[Metric], Dict[str, Metric]]) – One of the following:
list or tuple (sequence): if metrics are passed in as a list or tuple, will use the metric's class name as key for the output dict. Therefore, two metrics of the same class cannot be chained this way.
arguments: similar to passing in as a list, metrics passed in as arguments will use their metric class name as key for the output dict.
dict: if metrics are passed in as a dict, will use each key in the dict as key for the output dict. Use this format if you want to chain together multiple of the same metric with different parameters. Note that the keys in the output dict will be sorted alphabetically.
prefix (Optional[str]) – a string to append in front of the keys of the output dict
postfix (Optional[str]) – a string to append after the keys of the output dict
compute_groups (Union[bool, List[List[str]]]) – By default the MetricCollection will try to reduce the computations needed for the metrics in the collection by checking if they belong to the same compute group. All metrics in a compute group share the same metric state and are therefore only different in their compute step, e.g. accuracy, precision and recall can all be computed from the true positives/negatives and false positives/negatives. By default, this argument is True, which enables this feature. Set this argument to False to disable this behaviour. It can also be set to a list of lists of metric names to set the compute groups yourself.
Note
The compute groups feature can significantly speed up the calculation of metrics under the right conditions. First, the feature is only available when calling the update method and not when calling the forward method, because the internal logic of forward prevents this. Secondly, since compute groups share metric states by reference, calling .items(), .values() etc. on the metric collection will break this reference, and a copy of the states is returned instead (the reference will be re-established on the next call to update).
Note
Metric collections can be nested at initialization (see the nested example below), but the output of the collection will still be a single flattened dictionary combining the prefix and postfix arguments from the nested collection.
- Raises:
ValueError – If one of the elements of metrics is not an instance of torchmetrics.Metric.
ValueError – If two elements in metrics have the same name.
ValueError – If metrics is not a list, tuple or a dict.
ValueError – If metrics is a dict and additional_metrics are passed in.
ValueError – If prefix is set and it is not a string.
ValueError – If postfix is set and it is not a string.
- Example::
In the most basic case, the metrics can be passed in as a list or tuple. The keys of the output dict will be the same as the class name of the metric:
>>> from torch import tensor
>>> from pprint import pprint
>>> from torchmetrics import MetricCollection
>>> from torchmetrics.regression import MeanSquaredError
>>> from torchmetrics.classification import MulticlassAccuracy, MulticlassPrecision, MulticlassRecall
>>> target = tensor([0, 2, 0, 2, 0, 1, 0, 2])
>>> preds = tensor([2, 1, 2, 0, 1, 2, 2, 2])
>>> metrics = MetricCollection([MulticlassAccuracy(num_classes=3, average='micro'),
...                             MulticlassPrecision(num_classes=3, average='macro'),
...                             MulticlassRecall(num_classes=3, average='macro')])
>>> metrics(preds, target)
{'MulticlassAccuracy': tensor(0.1250), 'MulticlassPrecision': tensor(0.0667), 'MulticlassRecall': tensor(0.1111)}
- Example::
Alternatively, metrics can be passed in as arguments. The keys of the output dict will be the same as the class name of the metric:
>>> metrics = MetricCollection(MulticlassAccuracy(num_classes=3, average='micro'),
...                            MulticlassPrecision(num_classes=3, average='macro'),
...                            MulticlassRecall(num_classes=3, average='macro'))
>>> metrics(preds, target)
{'MulticlassAccuracy': tensor(0.1250), 'MulticlassPrecision': tensor(0.0667), 'MulticlassRecall': tensor(0.1111)}
- Example::
If multiple of the same metric class (with different parameters) should be chained together, metrics can be passed in as a dict and the output dict will have the same keys as the input dict:
>>> metrics = MetricCollection({'micro_recall': MulticlassRecall(num_classes=3, average='micro'),
...                             'macro_recall': MulticlassRecall(num_classes=3, average='macro')})
>>> same_metric = metrics.clone()
>>> pprint(metrics(preds, target))
{'macro_recall': tensor(0.1111), 'micro_recall': tensor(0.1250)}
>>> pprint(same_metric(preds, target))
{'macro_recall': tensor(0.1111), 'micro_recall': tensor(0.1250)}
- Example::
Metric collections can also be nested, but only one level deep. The output of the collection will still be a single dict with the prefix and postfix arguments from the nested collection:
>>> metrics = MetricCollection([
...     MetricCollection([
...         MulticlassAccuracy(num_classes=3, average='macro'),
...         MulticlassPrecision(num_classes=3, average='macro')
...     ], postfix='_macro'),
...     MetricCollection([
...         MulticlassAccuracy(num_classes=3, average='micro'),
...         MulticlassPrecision(num_classes=3, average='micro')
...     ], postfix='_micro'),
... ], prefix='valmetrics/')
>>> pprint(metrics(preds, target))
{'valmetrics/MulticlassAccuracy_macro': tensor(0.1111),
 'valmetrics/MulticlassAccuracy_micro': tensor(0.1250),
 'valmetrics/MulticlassPrecision_macro': tensor(0.0667),
 'valmetrics/MulticlassPrecision_micro': tensor(0.1250)}
- Example::
The compute_groups argument allows you to specify which metrics should share metric state. By default, this will automatically be derived, but it can also be set manually.
>>> metrics = MetricCollection(
...     MulticlassRecall(num_classes=3, average='macro'),
...     MulticlassPrecision(num_classes=3, average='macro'),
...     MeanSquaredError(),
...     compute_groups=[['MulticlassRecall', 'MulticlassPrecision'], ['MeanSquaredError']]
... )
>>> metrics.update(preds, target)
>>> pprint(metrics.compute())
{'MeanSquaredError': tensor(2.3750),
 'MulticlassPrecision': tensor(0.0667),
 'MulticlassRecall': tensor(0.1111)}
>>> pprint(metrics.compute_groups)
{0: ['MulticlassRecall', 'MulticlassPrecision'], 1: ['MeanSquaredError']}
- add_metrics(metrics, *additional_metrics)
Add new metrics to Metric Collection.
- Return type:
- clone(prefix=None, postfix=None)
Make a copy of the metric collection.
- Parameters:
- Return type:
- items(keep_base=False, copy_state=True)
Return an iterable of the ModuleDict key/value pairs.
- persistent(mode=True)
Change if metric states should be saved to its state_dict after initialization.
- Return type:
- plot(val=None, ax=None, together=False)
Plot a single or multiple values from the metric.
The plot method has two modes of operation. If the argument together is set to False (default), the .plot method of each metric will be called individually and the result will be a list of figures. If together is set to True, the values of all metrics will instead be plotted in the same figure.
- Parameters:
val (Union[Dict, Sequence[Dict], None]) – Either a single result from calling metric.forward or metric.compute, or a list of these results. If no value is provided, will automatically call metric.compute and plot that result.
ax (Union[Axes, Sequence[Axes], None]) – Either a single instance of a matplotlib axis object or a sequence of matplotlib axis objects. If provided, will add the plots to the provided axis objects. If not provided, will create new ones. If the argument together is set to True, a single object is expected. If together is set to False, the number of axis objects needs to be the same length as the number of metrics in the collection.
together (bool) – If True, will plot all metrics in the same axis. If False, will plot each metric in a separate axis.
- Return type:
- Returns:
Either a single tuple of Figure and Axes object or a sequence of tuples with Figure and Axes object for each metric in the collection.
- Raises:
ModuleNotFoundError – If matplotlib is not installed
ValueError – If together is not a bool
ValueError – If ax is not an instance of matplotlib axis object or a sequence of matplotlib axis objects
>>> # Example plotting a single value
>>> import torch
>>> from torchmetrics import MetricCollection
>>> from torchmetrics.classification import BinaryAccuracy, BinaryPrecision, BinaryRecall
>>> metrics = MetricCollection([BinaryAccuracy(), BinaryPrecision(), BinaryRecall()])
>>> metrics.update(torch.rand(10), torch.randint(2, (10,)))
>>> fig_ax_ = metrics.plot()
>>> # Example plotting multiple values
>>> import torch
>>> from torchmetrics import MetricCollection
>>> from torchmetrics.classification import BinaryAccuracy, BinaryPrecision, BinaryRecall
>>> metrics = MetricCollection([BinaryAccuracy(), BinaryPrecision(), BinaryRecall()])
>>> values = []
>>> for _ in range(10):
...     values.append(metrics(torch.rand(10), torch.randint(2, (10,))))
>>> fig_, ax_ = metrics.plot(values, together=True)
- set_dtype(dst_type)
Transfer all metric state to specific dtype. Special version of standard type method.
Module vs Functional Metrics
The functional metrics follow the simple paradigm input in, output out. This means they don't provide any advanced mechanisms for syncing across DDP nodes or aggregation over batches. They simply compute the metric value based on the given inputs.
Also, the integration within other parts of PyTorch Lightning will never be as tight as with the Module-based interface. If you just need to compute the values, the functional metrics are the way to go. However, if you are looking for the best integration and user experience, please consider also using the Module interface.
Metrics and differentiability
Metrics support backpropagation if all computations involved in the metric calculation are differentiable. All modular metric classes have the property is_differentiable that determines if a metric is differentiable or not.
However, note that the cached state is detached from the computational graph and cannot be back-propagated. Not doing this would mean storing the computational graph for each update call, which can lead to out-of-memory errors. In practice this means that:
MyMetric.is_differentiable # returns True if metric is differentiable
metric = MyMetric()
val = metric(pred, target) # this value can be back-propagated
val = metric.compute() # this value cannot be back-propagated
A functional metric is differentiable if its corresponding modular metric is differentiable.
Metrics and hyperparameter optimization
If you want to directly optimize a metric, it needs to support backpropagation (see section above).
However, if you are just interested in using a metric for hyperparameter tuning and are not sure if the metric should be maximized or minimized, all modular metric classes have the higher_is_better property that can be used to determine this:
import torchmetrics
# returns True because accuracy is optimal when it is maximized
torchmetrics.classification.MulticlassAccuracy.higher_is_better
# returns False because the mean squared error is optimal when it is minimized
torchmetrics.MeanSquaredError.higher_is_better
Advanced metric settings
The following is a list of additional arguments that can be given to any metric class (in the **kwargs argument) that will alter how metric states are stored and synced.
If you are running metrics on GPU and are running out of GPU VRAM, then the following arguments can help:
compute_on_cpu: will automatically move the metric states to CPU after calling update, making sure that GPU memory is not filling up. The consequence will be that the compute method will be called on CPU instead of GPU. Only applies to metric states that are lists.
compute_with_cache: This argument indicates if the result after calling the compute method should be cached. By default this is True, meaning that repeated calls to compute (with no change to the metric state in between) do not recompute the metric but just return the cache. By setting it to False the metric will be recomputed every time compute is called, but it can also help clean up a bit of memory.
If you are running in a distributed environment, TorchMetrics will automatically take care of the distributed synchronization for you. However, the following keyword arguments can be given to any metric class for further control over the distributed aggregation:
sync_on_compute: This argument is a bool that indicates if the metrics should automatically sync between devices whenever the compute method is called. By default this is True, but by setting this to False you can manually control when the synchronization happens.
dist_sync_on_step: This argument is a bool that indicates if the metric should synchronize between different devices every time forward is called. Setting this to True is in general not recommended, as synchronization is an expensive operation to do after each batch.
process_group: By default we synchronize across the world, i.e. all processes being computed on. You can provide a torch._C._distributed_c10d.ProcessGroup in this argument to specify exactly what devices should be synchronized over.
dist_sync_fn: By default we use torch.distributed.all_gather() to perform the synchronization between devices. Provide another callable function for this argument to perform custom distributed synchronization.
Plotting¶
Note
The visualization/plotting interface of Torchmetrics requires matplotlib to be installed. Install with either pip install matplotlib or pip install 'torchmetrics[visual]'. If the latter option is chosen, the SciencePlots package is also installed and all plots in Torchmetrics will default to using that style.
Torchmetrics comes with built-in support for quick visualization of your metrics, simply by using the .plot method that all modular metrics implement. This method provides a consistent interface for basic plotting of all metrics.
metric = AnyMetricYouLike()
for i in range(num_updates):
    metric.update(preds[i], target[i])
fig, ax = metric.plot()
.plot will always return two objects: fig is an instance of Figure, which contains figure level attributes, and ax is an instance of Axes, which contains all the elements of the plot. These two objects allow you to change attributes of the plot after it is created. For example, if you want to make the font size of the x-axis label a bit bigger, give the figure a nice title and finally save it, the above example could be extended like this:
ax.xaxis.label.set_fontsize(20)
fig.suptitle("This is a nice plot")
fig.savefig("my_awesome_plot.png")
If you want to include a Torchmetrics plot in a bigger figure that has subfigures and subaxes, all .plot methods support an optional ax argument where you can pass in the subaxes you want the plot to be inserted into:
import matplotlib.pyplot as plt

# combine plotting of two metrics into one figure
fig, ax = plt.subplots(nrows=1, ncols=2)
metric1 = Metric1()
metric2 = Metric2()
for i in range(num_updates):
    metric1.update(preds[i], target[i])
    metric2.update(preds[i], target[i])
metric1.plot(ax=ax[0])
metric2.plot(ax=ax[1])
Plotting a single step¶
At the most basic level the .plot method can be used to plot the value from a single step. This can be done in two ways:
* Either the .plot method is called with no input, in which case metric.compute() is called internally and that value is plotted
* or .plot is called on a single value returned by the metric, for example from metric.forward()
In both cases it will generate a plot like this (Accuracy as an example):
metric = torchmetrics.Accuracy(task="binary")
for _ in range(num_updates):
    metric.update(torch.rand(10,), torch.randint(2, (10,)))
fig, ax = metric.plot()

A single point plot is not that informative in itself, but if available we will try to include additional information, such as the lower and upper bounds the particular metric can take and whether the metric should be minimized or maximized to be optimal. This is true for all metrics that return a scalar tensor.
Some metrics return multiple values (such as a tensor with multiple elements or a dict of scalar tensors), and in that case calling .plot will return a figure similar to this:
metric = torchmetrics.Accuracy(task="multiclass", num_classes=3, average=None)
for _ in range(num_updates):
    metric.update(torch.randint(3, (10,)), torch.randint(3, (10,)))
fig, ax = metric.plot()

Here, each element is assumed to be an independent metric and is plotted as its own point for comparison. This holds for metrics whose output is a scalar or a set of independent values, but if the metric returns a structured tensor (such as a matrix), the .plot method will return a specialized plot for that particular metric. Take for example the ConfusionMatrix metric:
metric = torchmetrics.ConfusionMatrix(task="multiclass", num_classes=3)
for _ in range(num_updates):
    metric.update(torch.randint(3, (10,)), torch.randint(3, (10,)))
fig, ax = metric.plot()

If you prefer to use the functional interface of Torchmetrics, you can also plot the values returned by the functional. However, you would still need to initialize the corresponding metric class to get the information about the metric:
plot_class = torchmetrics.Accuracy(task="multiclass", num_classes=3)
value = torchmetrics.functional.accuracy(
    torch.randint(3, (10,)), torch.randint(3, (10,)), task="multiclass", num_classes=3
)
fig, ax = plot_class.plot(value)
Plotting multi steps¶
In the above examples we have only plotted a single step/single value, but it is also possible to plot multiple steps from the same metric. This is often the case when training a machine learning model, where you are tracking one or multiple metrics that you want to plot as they change over time. This can be done by providing a sequence of outputs from any metric, computed using metric.forward or metric.compute. For example, if we want to plot the accuracy of a model over time, we could do it like this:
metric = torchmetrics.Accuracy(task="binary")
values = []
for step in range(num_steps):
    for _ in range(num_updates):
        metric.update(preds(step), target(step))
    values.append(metric.compute())  # save value
    metric.reset()
fig, ax = metric.plot(values)

Do note that metrics that do not return simple scalar tensors, such as ConfusionMatrix and ROC, which have specialized visualizations, do not support plotting multiple steps out of the box; the user needs to manually plot the values for each step.
Plotting a collection of metrics¶
MetricCollection also supports the .plot method, and by default it works by just returning a collection of plots for all its members. Thus, instead of returning a single (fig, ax) pair, calling the .plot method of MetricCollection will return a sequence of such pairs, one for each member in the collection. In the following example we are forming a collection of binary classification metrics and redirecting the output of .plot to different subplots:
collection = torchmetrics.MetricCollection(
    torchmetrics.Accuracy(task="binary"),
    torchmetrics.Recall(task="binary"),
    torchmetrics.Precision(task="binary"),
)
fig, ax = plt.subplots(nrows=1, ncols=3)
values = []
for step in range(num_steps):
    for _ in range(num_updates):
        collection.update(preds(step), target(step))
    values.append(collection.compute())
    collection.reset()
# pass the axes as a plain list, one axis per metric in the collection
collection.plot(val=values, ax=list(ax))

However, the plot method of MetricCollection also supports an additional argument called together that will automatically try to plot all the metrics in the collection together in the same plot (with appropriate labels). This is only possible if all the metrics in the collection return a scalar tensor.
collection = torchmetrics.MetricCollection(
    torchmetrics.Accuracy(task="binary"),
    torchmetrics.Recall(task="binary"),
    torchmetrics.Precision(task="binary"),
)
values = []
fig, ax = plt.subplots(figsize=(6.8, 4.8))
for step in range(num_steps):
    for _ in range(num_updates):
        collection.update(preds(step), target(step))
    values.append(collection.compute())
    collection.reset()
collection.plot(val=values, together=True, ax=ax)

Advanced example¶
In the following we are going to show how to use the .plot method to create a more advanced plot. We are going to combine the functionality of several metrics using MetricCollection and plot them together. In addition we are going to rely on MetricTracker to keep track of the metrics over multiple steps.
# Define a collection that mixes metrics that return scalar tensors and metrics that do not
confmat = torchmetrics.ConfusionMatrix(task="binary")
roc = torchmetrics.ROC(task="binary")
collection = torchmetrics.MetricCollection(
    torchmetrics.Accuracy(task="binary"),
    torchmetrics.Recall(task="binary"),
    torchmetrics.Precision(task="binary"),
    confmat,
    roc,
)
# Define a tracker over the collection to easily keep track of the metrics over multiple steps
tracker = torchmetrics.wrappers.MetricTracker(collection)
# Run "training" loop
for step in range(num_steps):
    tracker.increment()
    for _ in range(num_updates):
        tracker.update(preds(step), target(step))
# Extract all metrics from all steps
all_results = tracker.compute_all()
# Construct a single figure with appropriate layout for all metrics
fig = plt.figure(layout="constrained")
ax1 = plt.subplot(2, 2, 1)
ax2 = plt.subplot(2, 2, 2)
ax3 = plt.subplot(2, 2, (3, 4))
# For ConfusionMatrix and ROC we just plot the last step, notice how we call the plot method of those metrics
confmat.plot(val=all_results[-1]["BinaryConfusionMatrix"], ax=ax1)
roc.plot(all_results[-1]["BinaryROC"], ax=ax2)
# For the remaining metrics we plot the full history, but we need to extract the scalar values from the results
scalar_results = [
    {k: v for k, v in ar.items() if isinstance(v, torch.Tensor) and v.numel() == 1} for ar in all_results
]
tracker.plot(val=scalar_results, ax=ax3)

Implementing a Metric¶
To implement your own custom metric, subclass the base Metric class and implement the following methods:
__init__(): Each state variable should be registered using self.add_state(...).
update(): Any code needed to update the state given any inputs to the metric.
compute(): Computes a final value from the state of the metric.
We provide the remaining interface, such as reset(), that will make sure to correctly reset all metric states that have been added using add_state. You should therefore not implement reset() yourself.
Additionally, adding metric states with add_state will make sure that states are correctly synchronized in distributed settings (DDP). To see how metric states are synchronized across distributed processes, refer to the add_state() docs from the base Metric class.
Example implementation:
import torch
from torch import Tensor
from torchmetrics import Metric

class MyAccuracy(Metric):
    def __init__(self):
        super().__init__()
        self.add_state("correct", default=torch.tensor(0), dist_reduce_fx="sum")
        self.add_state("total", default=torch.tensor(0), dist_reduce_fx="sum")

    def update(self, preds: Tensor, target: Tensor):
        # convert preds and target to a common format here if needed (e.g. argmax over logits)
        if preds.shape != target.shape:
            raise ValueError("preds and target must have the same shape")
        self.correct += torch.sum(preds == target)
        self.total += target.numel()

    def compute(self):
        return self.correct.float() / self.total
Additionally you may want to set the class properties: is_differentiable, higher_is_better and full_state_update. Note that none of them are strictly required for the metric to work.
from typing import Optional

from torchmetrics import Metric

class MyMetric(Metric):
    # Set to True if the metric is differentiable else set to False
    is_differentiable: Optional[bool] = None

    # Set to True if the metric reaches its optimal value when the metric is maximized.
    # Set to False if it reaches its optimal value when the metric is minimized.
    higher_is_better: Optional[bool] = True

    # Set to True if the metric during 'update' requires access to the global metric
    # state for its calculations. If not, setting this to False indicates that all
    # batch states are independent and we will optimize the runtime of 'forward'
    full_state_update: bool = True
Finally, from torchmetrics v1.0.0 onwards, we also support plotting of metrics through the .plot method. By default this method will raise NotImplementedError, but it can be implemented by the user to provide a custom plot for the metric. For any metric that returns a simple scalar tensor or a dict of scalar tensors, the internal ._plot method can be used, which provides the common plotting functionality for most metrics in torchmetrics.
from typing import Optional, Sequence, Union

from torch import Tensor
from torchmetrics import Metric
from torchmetrics.utilities.plot import _AX_TYPE, _PLOT_OUT_TYPE

class MyMetric(Metric):
    ...
    def plot(
        self, val: Optional[Union[Tensor, Sequence[Tensor]]] = None, ax: Optional[_AX_TYPE] = None
    ) -> _PLOT_OUT_TYPE:
        return self._plot(val, ax)
If the metric returns a more complex output, a custom implementation of the plot method is required. For more details on the plotting API, see the plotting documentation.
Internal implementation details¶
This section briefly describes how metrics work internally. We encourage looking at the source code for more info.
Internally, TorchMetrics wraps the user defined update() and compute() methods. We do this to automatically synchronize and reduce metric states across multiple devices. More precisely, calling update() does the following internally:
Clears computed cache.
Calls user-defined update().
Similarly, calling compute() does the following internally:
Syncs metric states between processes.
Reduces gathered metric states.
Calls the user defined compute() method on the gathered metric states.
Caches computed result.
From a user’s standpoint this has one important side-effect: computed results are cached. This means that no matter how many times compute is called one after another, it will continue to return the same result. The cache is first emptied on the next call to update.
forward serves the dual purpose of both returning the metric on the current data and updating the internal metric state for accumulating over multiple batches. The forward() method achieves this by combining calls to update, compute and reset. Depending on the class property full_state_update, forward can behave in two ways:
If full_state_update is True, it indicates that the metric during update requires access to the full metric state and we therefore need to do two calls to update to secure that the metric is calculated correctly:
Calls update() to update the global metric state (for accumulation over multiple batches).
Caches the global state.
Calls reset() to clear global metric state.
Calls update() to update local metric state.
Calls compute() to calculate metric for current batch.
Restores the global state.
If full_state_update is False (default), the metric state of one batch is completely independent of the state of other batches, which means that we only need to call update once:
Caches the global state.
Calls reset to reset the metric to its default state.
Calls update to update the state with local batch statistics.
Calls compute to calculate the metric for the current batch.
Reduces the global state and batch state into a single state that becomes the new global state.
If implementing your own metric, we recommend trying out the metric with the full_state_update class property set to both True and False. If the results are equal, then setting it to False will usually give the best performance.
- class torchmetrics.Metric(**kwargs)
Base class for all metrics present in the Metrics API.
This class is inherited by all metrics and implements the following functionality:
1. Handles the transfer of metric states to the correct device.
2. Handles the synchronization of metric states across processes.
The three core methods of the base class are
* add_state()
* forward()
* reset()
which should almost never be overwritten by child classes. Instead, the following methods should be overwritten
* update()
* compute()
- Parameters:
kwargs (Any) – additional keyword arguments, see Advanced metric settings for more info.
compute_on_cpu: If metric state should be stored on CPU during computations. Only works for list states.
dist_sync_on_step: If metric state should synchronize on forward(). Default is False
process_group: The process group on which the synchronization is called. Default is the world.
dist_sync_fn: Function that performs the all-gather operation on the metric state. Default is a custom implementation that calls torch.distributed.all_gather internally.
distributed_available_fn: Function that checks if the distributed backend is available. Defaults to a check of torch.distributed.is_available() and torch.distributed.is_initialized().
sync_on_compute: If metric state should synchronize when compute is called. Default is True
compute_with_cache: If results from compute should be cached. Default is True
- add_state(name, default, dist_reduce_fx=None, persistent=False)
Add metric state variable. Only used by subclasses.
Metric state variables are either a Tensor or an empty list, which can be appended to by the metric. Each state variable must have a unique name associated with it. State variables are accessible as attributes of the metric, i.e. if name is "my_state" then its value can be accessed from an instance metric as metric.my_state. Metric states behave like buffers and parameters of Module as they are also updated when .to() is called. Unlike parameters and buffers, metric states are not by default saved in the module's state_dict.
- Parameters:
name (str) – The name of the state variable. The variable will then be accessible at self.name.
default (Union[list, Tensor]) – Default value of the state; can either be a Tensor or an empty list. The state will be reset to this value when self.reset() is called.
dist_reduce_fx (Optional) – Function to reduce state across multiple processes in distributed mode. If value is "sum", "mean", "cat", "min" or "max" we will use torch.sum, torch.mean, torch.cat, torch.min and torch.max respectively, each with argument dim=0. Note that the "cat" reduction only makes sense if the state is a list, and not a tensor. The user can also pass a custom function in this parameter.
persistent (Optional) – whether the state will be saved as part of the module's state_dict. Default is False.
- Return type:
Note
Setting dist_reduce_fx to None will return the metric state synchronized across different processes. However, there won’t be any reduction function applied to the synchronized metric state. The metric states would be synced as follows:
If the metric state is a Tensor, the synced value will be a Tensor stacked across the process dimension. The original Tensor metric state retains its dimensions and hence the synchronized output will be of shape (num_process, ...).
If the metric state is a list, the synced value will be a list containing the combined elements from all processes.
Note
When passing a custom function to dist_reduce_fx, expect the synchronized metric state to follow the format discussed in the above note.
- Raises:
ValueError – If default is not a tensor or an empty list.
ValueError – If dist_reduce_fx is not callable or one of "mean", "sum", "cat", "min", "max" or None.
- abstract compute()
Override this method to compute the final metric value.
This method will automatically synchronize state variables when running in distributed backend.
- Return type:
- double()
Override default and prevent dtype casting.
Please use Metric.set_dtype() instead.
- Return type:
- float()
Override default and prevent dtype casting.
Please use Metric.set_dtype() instead.
- Return type:
- forward(*args, **kwargs)
Aggregate and evaluate batch input directly.
Serves the dual purpose of both computing the metric on the current batch of inputs and adding the batch statistics to the overall accumulating metric state. Input arguments are the exact same as the corresponding update method. The returned output is the exact same as the output of compute.
- Parameters:
- Return type:
- Returns:
The output of the compute method evaluated on the current batch.
- Raises:
TorchMetricsUserError – If the metric is already synced and forward is called again.
- half()
Override default and prevent dtype casting.
Please use Metric.set_dtype() instead.
- Return type:
- persistent(mode=False)
Change post-init if metric states should be saved to its state_dict.
- Return type:
- set_dtype(dst_type)
Transfer all metric state to specific dtype. Special version of standard type method.
- state_dict(destination=None, prefix='', keep_vars=False)
Get the current state of metric as a dictionary.
- Parameters:
destination (Optional[Dict[str, Any]]) – Optional dictionary that, if provided, the state of the module will be updated into, and the same object is returned. Otherwise, an OrderedDict will be created and returned.
prefix (str) – optional string, a prefix added to parameter and buffer names to compose the keys in state_dict.
keep_vars (bool) – by default the Tensors returned in the state dict are detached from autograd. If set to True, detaching will not be performed.
- Return type:
- sync(dist_sync_fn=None, process_group=None, should_sync=True, distributed_available=None)
Sync function for manually controlling when metric states should be synced across processes.
- Parameters:
dist_sync_fn (Optional[Callable]) – Function to be used to perform states synchronization
process_group (Optional[Any]) – Specify the process group on which synchronization is called. default: None (which selects the entire world)
should_sync (bool) – Whether to apply state synchronization. This will have an impact only when running in a distributed setting.
distributed_available (Optional[Callable]) – Function to determine if we are running inside a distributed setting
- Raises:
TorchMetricsUserError – If the metric is already synced and sync is called again.
- Return type:
- sync_context(dist_sync_fn=None, process_group=None, should_sync=True, should_unsync=True, distributed_available=None)
Context manager to synchronize states.
This context manager is used in distributed settings and makes sure that the local cache states are restored after yielding the synchronized state.
- Parameters:
dist_sync_fn (Optional[Callable]) – Function to be used to perform states synchronization
process_group (Optional[Any]) – Specify the process group on which synchronization is called. default: None (which selects the entire world)
should_sync (bool) – Whether to apply state synchronization. This will have an impact only when running in a distributed setting.
should_unsync (bool) – Whether to restore the cache state so that the metrics can continue to be accumulated.
distributed_available (Optional[Callable]) – Function to determine if we are running inside a distributed setting
- Return type:
- type(dst_type)
Override default and prevent dtype casting.
Please use Metric.set_dtype() instead.
- Return type:
- unsync(should_unsync=True)
Unsync function for manually controlling when metric states should be reverted back to their local states.
- abstract update(*_, **__)
Override this method to update the state variables of your metric class.
- Return type:
- property device: device
Return the device of the metric.
- property update_called: bool
Returns True if update or forward has been called since initialization or last reset.
- property update_count: int
Get the number of times update and/or forward has been called since initialization or last reset.
Contributing your metric to TorchMetrics¶
Want to contribute the metric you have implemented? Great, we are always open to adding more metrics to torchmetrics as long as they serve a general purpose. However, to keep all our metrics consistent we request that the implementation and tests get formatted in the following way:
Start by reading our contribution guidelines.
First implement the functional backend. This takes care of all the logic that goes into the metric. The code should be put into a single file placed under torchmetrics/functional/"domain"/"new_metric".py where domain is the type of metric (classification, regression, nlp etc) and new_metric is the name of the metric. In this file, there should be the following three functions:
_new_metric_update(...): everything that has to do with type/shape checking and all logic required before distributed syncing needs to go here.
_new_metric_compute(...): all remaining logic.
new_metric(...): essentially wraps the _update and _compute private functions into one public function that makes up the functional interface for the metric.
Note
In a corresponding file placed in torchmetrics/"domain"/"new_metric".py create the module interface:
Create a new module metric by subclassing torchmetrics.Metric.
In the __init__ of the module call self.add_state for as many metric states as are needed for the metric to properly accumulate metric statistics.
The module interface should essentially call the private _new_metric_update(...) in its update method and similarly the _new_metric_compute(...) function in its compute. No logic should really be implemented in the module interface. We do this to not have duplicate code to maintain.
Note
Remember to add binding to the different relevant __init__ files.
Testing is key to keeping torchmetrics trustworthy. This is why we have a very rigid testing protocol. This means that in most cases we require the metric to be tested against some other common framework (sklearn, scipy etc).
Create a testing file in unittests/"domain"/test_"new_metric".py. Only one file is needed as it is intended to test both the functional and module interface.
In that file, start by defining a number of test inputs that your metric should be evaluated on.
Create a testclass class NewMetric(MetricTester) that inherits from tests.helpers.testers.MetricTester. This testclass should essentially implement the test_"new_metric"_class and test_"new_metric"_fn methods that respectively test the module interface and the functional interface.
The testclass should be parameterized (using @pytest.mark.parametrize) by the different test inputs defined initially. Additionally, the test_"new_metric"_class method should also be parameterized with a ddp parameter such that it gets tested in a distributed setting. If your metric has additional parameters, then make sure to also parameterize these such that different combinations of inputs and parameters get tested.
(optional) If your metric raises any exception, please add tests that showcase this.
Note
The test file for accuracy metric shows how to implement such tests.
If you can only figure out part of the steps, do not be afraid to send a PR. We would much rather receive working metrics that are not formatted exactly like our codebase than not receive any. Formatting can always be applied. We will gladly guide and/or help implement the remaining :]
TorchMetrics in PyTorch Lightning¶
TorchMetrics was originally created as part of PyTorch Lightning, a powerful deep learning research framework designed for scaling models without boilerplate.
Note
TorchMetrics always offers compatibility with the last 2 major PyTorch Lightning versions, but we recommend always keeping both frameworks up-to-date for the best experience.
While TorchMetrics was built to be used with native PyTorch, using TorchMetrics with Lightning offers additional benefits:
Modular metrics are automatically placed on the correct device when properly defined inside a LightningModule. This means that your data will always be placed on the same device as your metrics. No need to call .to(device) anymore!
Native support for logging metrics in Lightning using self.log inside your LightningModule.
The .reset() method of the metric will automatically be called at the end of an epoch.
The example below shows how to use a metric in your LightningModule:
class MyModel(LightningModule):
    def __init__(self, num_classes):
        ...
        self.accuracy = torchmetrics.classification.Accuracy(task="multiclass", num_classes=num_classes)

    def training_step(self, batch, batch_idx):
        x, y = batch
        preds = self(x)
        ...
        # log step metric
        self.accuracy(preds, y)
        self.log('train_acc_step', self.accuracy)
        ...

    def on_train_epoch_end(self):
        # log epoch metric
        self.log('train_acc_epoch', self.accuracy)
Metric logging in Lightning happens through the self.log or self.log_dict method. Both methods only support the logging of scalar tensors. While the vast majority of metrics in torchmetrics return a scalar tensor, some metrics such as ConfusionMatrix, ROC, MeanAveragePrecision and ROUGEScore return outputs that are non-scalar tensors (often dicts or lists of tensors) and should therefore be dealt with separately. For info about the return type and shape please look at the documentation for the compute method of each metric you want to log.
Logging TorchMetrics¶
Logging metrics can be done in two ways: either logging the metric object directly or the computed metric values. When Metric objects that return a scalar tensor are logged directly in Lightning using the LightningModule self.log method, Lightning will log the metric based on the on_step and on_epoch flags present in self.log(...). If on_epoch is True, the logger automatically logs the end of epoch metric value by calling .compute().
Note
The sync_dist, sync_dist_group and reduce_fx flags from self.log(...) don’t affect the metric logging in any manner. The metric class contains its own distributed synchronization logic.
This however is only true for metrics that inherit the base class Metric, and thus the functional metric API provides no support for in-built distributed synchronization or reduction functions.
class MyModule(LightningModule):
    def __init__(self, num_classes):
        ...
        self.train_acc = torchmetrics.classification.Accuracy(task="multiclass", num_classes=num_classes)
        self.valid_acc = torchmetrics.classification.Accuracy(task="multiclass", num_classes=num_classes)

    def training_step(self, batch, batch_idx):
        x, y = batch
        preds = self(x)
        ...
        self.train_acc(preds, y)
        self.log('train_acc', self.train_acc, on_step=True, on_epoch=False)

    def validation_step(self, batch, batch_idx):
        x, y = batch
        logits = self(x)
        ...
        self.valid_acc(logits, y)
        self.log('valid_acc', self.valid_acc, on_step=True, on_epoch=True)
As an alternative to logging the metric object and letting Lightning take care of when to reset the metric etc. you can also manually log the output of the metrics.
class MyModule(LightningModule):
    def __init__(self, num_classes):
        ...
        self.train_acc = torchmetrics.classification.Accuracy(task="multiclass", num_classes=num_classes)
        self.valid_acc = torchmetrics.classification.Accuracy(task="multiclass", num_classes=num_classes)

    def training_step(self, batch, batch_idx):
        x, y = batch
        preds = self(x)
        ...
        batch_value = self.train_acc(preds, y)
        self.log('train_acc_step', batch_value)

    def on_train_epoch_end(self):
        self.train_acc.reset()

    def validation_step(self, batch, batch_idx):
        x, y = batch
        logits = self(x)
        ...
        self.valid_acc.update(logits, y)

    def on_validation_epoch_end(self):
        self.log('valid_acc_epoch', self.valid_acc.compute())
        self.valid_acc.reset()
Note that logging metrics this way will require you to manually reset the metrics at the end of the epoch yourself. In general, we recommend logging the metric object to make sure that metrics are correctly computed and reset. Additionally, we highly recommend that the two ways of logging are not mixed as it can lead to wrong results.
Note
When using any Modular metric, calling self.metric(...) or self.metric.forward(...) serves the dual purpose of calling self.metric.update() on its input and simultaneously returning the metric value over the provided input. So if you are logging a metric only on epoch-level (as in the example above), it is recommended to call self.metric.update() directly to avoid the extra computation.
class MyModule(LightningModule):
    def __init__(self, num_classes):
        ...
        self.valid_acc = torchmetrics.classification.Accuracy(task="multiclass", num_classes=num_classes)

    def validation_step(self, batch, batch_idx):
        x, y = batch
        logits = self(x)
        ...
        self.valid_acc.update(logits, y)
        # epoch-level logging only: update() does not produce a per-step value
        self.log('valid_acc', self.valid_acc, on_step=False, on_epoch=True)
Common Pitfalls¶
The following contains a list of pitfalls to be aware of:
Modular metrics contain internal states that should belong to only one DataLoader. In case you are using multiple DataLoaders, it is recommended to initialize a separate modular metric instance for each DataLoader and use them separately. The same holds for using separate metrics for training, validation and testing.
class MyModule(LightningModule):
    def __init__(self, num_classes):
        ...
        self.val_acc = nn.ModuleList(
            [torchmetrics.classification.Accuracy(task="multiclass", num_classes=num_classes) for _ in range(2)]
        )

    def val_dataloader(self):
        return [DataLoader(...), DataLoader(...)]

    def validation_step(self, batch, batch_idx, dataloader_idx):
        x, y = batch
        preds = self(x)
        ...
        self.val_acc[dataloader_idx](preds, y)
        self.log('val_acc', self.val_acc[dataloader_idx])
Mixing the two logging methods by calling self.log("val", self.metric) in the {training}/{val}/{test}_step method and then calling self.log("val", self.metric.compute()) in the corresponding {training}/{val}/{test}_epoch_end method. Because the object is logged in the first case, Lightning will reset the metric before calling the second line, leading to errors or nonsense results.
Calling self.log("val", self.metric(preds, target)) with the intention of logging the metric object. Because self.metric(preds, target) corresponds to calling the forward method, this will return a tensor and not the metric object. Such logging will be wrong in this case. Instead, it is important to split this into separate lines:
def training_step(self, batch, batch_idx):
    x, y = batch
    preds = self(x)
    ...
    # log step metric
    self.accuracy(preds, y)  # compute metrics
    self.log('train_acc_step', self.accuracy)  # log metric object
Using the MetricTracker wrapper with Lightning is a special case, because the wrapper itself is not a metric, i.e. it does not inherit from the base Metric class but from ModuleList. Thus, to log the output of this wrapper, one needs to manually log the returned values (not the object) using self.log, and for epoch-level logging this should be done in the appropriate on_***_epoch_end method.
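To see why one instance per DataLoader matters, the plain-Python sketch below feeds two loaders with very different scores into one shared instance and into one instance per loader. ToyMean and the two loaders are hypothetical stand-ins, not torchmetrics objects.

```python
class ToyMean:
    """Hypothetical minimal metric that averages every value it sees."""

    def __init__(self):
        self.total = 0.0
        self.count = 0

    def update(self, values):
        self.total += sum(values)
        self.count += len(values)

    def compute(self):
        return self.total / self.count


# two hypothetical validation loaders (lists of batches) with very different scores
loader_a = [[1.0, 1.0]]
loader_b = [[0.0, 0.0], [0.0, 0.0]]

# pitfall: one shared instance silently pools both loaders
shared = ToyMean()
for batch in loader_a + loader_b:
    shared.update(batch)

# recommended: one instance per loader, indexed like self.val_acc[dataloader_idx]
per_loader = [ToyMean(), ToyMean()]
for idx, loader in enumerate([loader_a, loader_b]):
    for batch in loader:
        per_loader[idx].update(batch)

pooled = shared.compute()                     # 2 / 6, neither loader's score
separate = [m.compute() for m in per_loader]  # one score per loader
print(pooled, separate)
```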
Concatenation
Module Interface
- class torchmetrics.aggregation.CatMetric(nan_strategy='warn', **kwargs)
Concatenate a stream of values.
As input to forward and update the metric accepts the following input:
value (float or Tensor): a single float or a tensor of float values with arbitrary shape (...,)
As output of forward and compute the metric returns the following output:
agg (Tensor): 1D float tensor with the concatenated values over all inputs received
- Parameters:
nan_strategy (Union[str, float]) – options:
  - 'error': if any nan values are encountered, a RuntimeError is raised
  - 'warn': if any nan values are encountered, a warning is given and the computation continues
  - 'ignore': all nan values are silently removed
  - a float: if a float is provided, any nan values are imputed with this value
kwargs (Any) – Additional keyword arguments, see Advanced metric settings for more info.
- Raises:
ValueError – If nan_strategy is not one of error, warn, ignore or a float
Example
>>> from torch import tensor
>>> from torchmetrics.aggregation import CatMetric
>>> metric = CatMetric()
>>> metric.update(1)
>>> metric.update(tensor([2, 3]))
>>> metric.compute()
tensor([1., 2., 3.])
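The nan_strategy options can be mimicked on a plain list of floats. This is a sketch only: apply_nan_strategy is a hypothetical helper, and it assumes that 'warn' drops the nan values after warning (just like 'ignore'), which fits the "continue" wording above but is not spelled out there.

```python
import math
import warnings


def apply_nan_strategy(values, nan_strategy="warn"):
    """Hypothetical sketch of the documented nan_strategy options for a list of floats."""
    if nan_strategy not in ("error", "warn", "ignore") and not isinstance(nan_strategy, float):
        raise ValueError("nan_strategy must be 'error', 'warn', 'ignore' or a float")
    if not any(math.isnan(v) for v in values):
        return list(values)
    if nan_strategy == "error":
        raise RuntimeError("Encountered nan values")
    if nan_strategy in ("warn", "ignore"):
        if nan_strategy == "warn":
            # assumption: after warning, nan values are dropped as with 'ignore'
            warnings.warn("Encountered nan values; they will be removed.", UserWarning)
        return [v for v in values if not math.isnan(v)]
    # a float: impute every nan with that value
    return [nan_strategy if math.isnan(v) else v for v in values]


stream = [1.0, float("nan"), 3.0]
print(apply_nan_strategy(stream, "ignore"))  # [1.0, 3.0]
print(apply_nan_strategy(stream, 0.0))       # [1.0, 0.0, 3.0]
```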
Maximum
Module Interface
- class torchmetrics.aggregation.MaxMetric(nan_strategy='warn', **kwargs)
Aggregate a stream of values into their maximum value.
As input to forward and update the metric accepts the following input:
value (float or Tensor): a single float or a tensor of float values with arbitrary shape (...,)
As output of forward and compute the metric returns the following output:
agg (Tensor): scalar float tensor with aggregated maximum value over all inputs received
- Parameters:
nan_strategy (Union[str, float]) – options:
  - 'error': if any nan values are encountered, a RuntimeError is raised
  - 'warn': if any nan values are encountered, a warning is given and the computation continues
  - 'ignore': all nan values are silently removed
  - a float: if a float is provided, any nan values are imputed with this value
kwargs (Any) – Additional keyword arguments, see Advanced metric settings for more info.
- Raises:
ValueError – If nan_strategy is not one of error, warn, ignore or a float
Example
>>> from torch import tensor
>>> from torchmetrics.aggregation import MaxMetric
>>> metric = MaxMetric()
>>> metric.update(1)
>>> metric.update(tensor([2, 3]))
>>> metric.compute()
tensor(3.)
- plot(val=None, ax=None)
Plot a single or multiple values from the metric.
- Parameters:
val (Union[Tensor, Sequence[Tensor], None]) – Either a single result from calling metric.forward or metric.compute, or a list of these results. If no value is provided, metric.compute is called automatically and its result is plotted.
ax (Optional[Axes]) – A matplotlib axis object. If provided, the plot is added to that axis.
- Returns:
Figure and Axes object
- Raises:
ModuleNotFoundError – If matplotlib is not installed
>>> # Example plotting a single value
>>> from torchmetrics.aggregation import MaxMetric
>>> metric = MaxMetric()
>>> metric.update([1, 2, 3])
>>> fig_, ax_ = metric.plot()
>>> # Example plotting multiple values
>>> from torchmetrics.aggregation import MaxMetric
>>> metric = MaxMetric()
>>> values = [ ]
>>> for i in range(10):
...     values.append(metric(i))
>>> fig_, ax_ = metric.plot(values)
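Maximum, minimum and sum can be accumulated batch by batch because aggregating partial results gives the same answer as aggregating the whole stream at once. The plain-Python sketch below checks this for hypothetical chunks of data. It illustrates the property only and does not show how torchmetrics stores its states.

```python
from functools import reduce

# hypothetical per-batch (or per-device) chunks of one stream of values
chunks = [[1.0], [2.0, 3.0], [-4.0, 0.5]]
flat = [v for chunk in chunks for v in chunk]

results = {}
for op in (max, min, sum):
    # reduce every chunk on its own, then fold the partial results together
    partials = [op(chunk) for chunk in chunks]
    merged = reduce(lambda a, b: op([a, b]), partials)
    # store (merged partials, whole stream at once) for comparison
    results[op.__name__] = (merged, op(flat))

print(results)  # both entries of every pair agree
```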
Mean
Module Interface
- class torchmetrics.aggregation.MeanMetric(nan_strategy='warn', **kwargs)
Aggregate a stream of values into their mean value.
As input to forward and update the metric accepts the following input:
value (float or Tensor): a single float or a tensor of float values with arbitrary shape (...,)
weight (float or Tensor): a single float or a tensor of float values with arbitrary shape (...,). Needs to be broadcastable with the shape of the value tensor.
As output of forward and compute the metric returns the following output:
agg (Tensor): scalar float tensor with aggregated (weighted) mean over all inputs received
- Parameters:
nan_strategy (Union[str, float]) – options:
  - 'error': if any nan values are encountered, a RuntimeError is raised
  - 'warn': if any nan values are encountered, a warning is given and the computation continues
  - 'ignore': all nan values are silently removed
  - a float: if a float is provided, any nan values are imputed with this value
kwargs – Additional keyword arguments, see Advanced metric settings for more info.
- Raises:
ValueError – If nan_strategy is not one of error, warn, ignore or a float
Example
>>> import torch
>>> from torchmetrics.aggregation import MeanMetric
>>> metric = MeanMetric()
>>> metric.update(1)
>>> metric.update(torch.tensor([2, 3]))
>>> metric.compute()
tensor(2.)
- plot(val=None, ax=None)
Plot a single or multiple values from the metric.
- Parameters:
val (Union[Tensor, Sequence[Tensor], None]) – Either a single result from calling metric.forward or metric.compute, or a list of these results. If no value is provided, metric.compute is called automatically and its result is plotted.
ax (Optional[Axes]) – A matplotlib axis object. If provided, the plot is added to that axis.
- Returns:
Figure and Axes object
- Raises:
ModuleNotFoundError – If matplotlib is not installed
>>> # Example plotting a single value
>>> from torchmetrics.aggregation import MeanMetric
>>> metric = MeanMetric()
>>> metric.update([1, 2, 3])
>>> fig_, ax_ = metric.plot()
>>> # Example plotting multiple values
>>> from torchmetrics.aggregation import MeanMetric
>>> metric = MeanMetric()
>>> values = [ ]
>>> for i in range(10):
...     values.append(metric([i, i+1]))
>>> fig_, ax_ = metric.plot(values)
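A (weighted) mean cannot be merged by averaging per-batch means unless all batches have the same size. Keeping a running sum of value * weight and a running sum of weights avoids that problem. StreamingMean below is a hypothetical plain-Python sketch of this idea, with a scalar weight broadcast to every value; it reproduces the 2.0 from the doctest above.

```python
class StreamingMean:
    """Hypothetical sketch of a (weighted) mean kept as two running sums."""

    def __init__(self):
        self.value_sum = 0.0   # running sum of value * weight
        self.weight_sum = 0.0  # running sum of weight

    def update(self, value, weight=1.0):
        values = value if isinstance(value, list) else [value]
        # a scalar weight is broadcast to every value
        weights = weight if isinstance(weight, list) else [weight] * len(values)
        if len(weights) != len(values):
            raise ValueError("weight must be a scalar or have one entry per value")
        self.value_sum += sum(v * w for v, w in zip(values, weights))
        self.weight_sum += sum(weights)

    def compute(self):
        return self.value_sum / self.weight_sum


metric = StreamingMean()
metric.update(1)
metric.update([2, 3])
unweighted = metric.compute()       # (1 + 2 + 3) / 3
naive = (1.0 + (2 + 3) / 2) / 2     # averaging the two batch means instead: 1.75

weighted = StreamingMean()
weighted.update([1, 2, 3], [1, 1, 2])
weighted_result = weighted.compute()  # (1 + 2 + 6) / 4
print(unweighted, naive, weighted_result)  # 2.0 1.75 2.25
```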
Minimum
Module Interface
- class torchmetrics.aggregation.MinMetric(nan_strategy='warn', **kwargs)
Aggregate a stream of values into their minimum value.
As input to forward and update the metric accepts the following input:
value (float or Tensor): a single float or a tensor of float values with arbitrary shape (...,)
As output of forward and compute the metric returns the following output:
agg (Tensor): scalar float tensor with aggregated minimum value over all inputs received
- Parameters:
nan_strategy (Union[str, float]) – options:
  - 'error': if any nan values are encountered, a RuntimeError is raised
  - 'warn': if any nan values are encountered, a warning is given and the computation continues
  - 'ignore': all nan values are silently removed
  - a float: if a float is provided, any nan values are imputed with this value
kwargs (Any) – Additional keyword arguments, see Advanced metric settings for more info.
- Raises:
ValueError – If nan_strategy is not one of error, warn, ignore or a float
Example
>>> from torch import tensor
>>> from torchmetrics.aggregation import MinMetric
>>> metric = MinMetric()
>>> metric.update(1)
>>> metric.update(tensor([2, 3]))
>>> metric.compute()
tensor(1.)
- plot(val=None, ax=None)
Plot a single or multiple values from the metric.
- Parameters:
val (Union[Tensor, Sequence[Tensor], None]) – Either a single result from calling metric.forward or metric.compute, or a list of these results. If no value is provided, metric.compute is called automatically and its result is plotted.
ax (Optional[Axes]) – A matplotlib axis object. If provided, the plot is added to that axis.
- Returns:
Figure and Axes object
- Raises:
ModuleNotFoundError – If matplotlib is not installed
>>> # Example plotting a single value
>>> from torchmetrics.aggregation import MinMetric
>>> metric = MinMetric()
>>> metric.update([1, 2, 3])
>>> fig_, ax_ = metric.plot()
>>> # Example plotting multiple values
>>> from torchmetrics.aggregation import MinMetric
>>> metric = MinMetric()
>>> values = [ ]
>>> for i in range(10):
...     values.append(metric(i))
>>> fig_, ax_ = metric.plot(values)
Running Mean
Module Interface
- class torchmetrics.aggregation.RunningMean(window=5, nan_strategy='warn', **kwargs)
Aggregate a stream of values into their mean over a running window.
Compared to MeanMetric, this metric computes the mean over a running window of values instead of over the whole history of values. This is useful when you want an up-to-date estimate of the metric during training without waiting for an epoch to finish to get epoch-level estimates.
As input to forward and update the metric accepts the following input:
value (float or Tensor): a single float or a tensor of float values with arbitrary shape (...,)
As output of forward and compute the metric returns the following output:
agg (Tensor): scalar float tensor with the mean over the inputs in the running window
- Parameters:
window (int) – The size of the running window.
nan_strategy (Union[str, float]) – options:
  - 'error': if any nan values are encountered, a RuntimeError is raised
  - 'warn': if any nan values are encountered, a warning is given and the computation continues
  - 'ignore': all nan values are silently removed
  - a float: if a float is provided, any nan values are imputed with this value
kwargs (Any) – Additional keyword arguments, see Advanced metric settings for more info.
- Raises:
ValueError – If nan_strategy is not one of error, warn, ignore or a float
Example
>>> from torch import tensor
>>> from torchmetrics.aggregation import RunningMean
>>> metric = RunningMean(window=3)
>>> for i in range(6):
...     current_val = metric(tensor([i]))
...     running_val = metric.compute()
...     total_val = tensor(sum(list(range(i+1)))) / (i+1) # total mean over all samples
...     print(f"{current_val=}, {running_val=}, {total_val=}")
current_val=tensor(0.), running_val=tensor(0.), total_val=tensor(0.)
current_val=tensor(1.), running_val=tensor(0.5000), total_val=tensor(0.5000)
current_val=tensor(2.), running_val=tensor(1.), total_val=tensor(1.)
current_val=tensor(3.), running_val=tensor(2.), total_val=tensor(1.5000)
current_val=tensor(4.), running_val=tensor(3.), total_val=tensor(2.)
current_val=tensor(5.), running_val=tensor(4.), total_val=tensor(2.5000)
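The running window can be sketched with a bounded deque: once the window is full, each new value pushes the oldest one out. running_means is a hypothetical helper that treats every value as one step; the doctest above also feeds one value per call, so the sketch reproduces its running_val column for window=3.

```python
from collections import deque


def running_means(stream, window=3):
    """Hypothetical sketch: mean over the last `window` values after each new value."""
    buffer = deque(maxlen=window)  # the oldest value is dropped automatically
    out = []
    for value in stream:
        buffer.append(value)
        out.append(sum(buffer) / len(buffer))
    return out


print(running_means(range(6), window=3))  # [0.0, 0.5, 1.0, 2.0, 3.0, 4.0]
```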
Running Sum
Module Interface
- class torchmetrics.aggregation.RunningSum(window=5, nan_strategy='warn', **kwargs)
Aggregate a stream of values into their sum over a running window.
Compared to SumMetric, this metric computes the sum over a running window of values instead of over the whole history of values. This is useful when you want an up-to-date estimate of the metric during training without waiting for an epoch to finish to get epoch-level estimates.
As input to forward and update the metric accepts the following input:
value (float or Tensor): a single float or a tensor of float values with arbitrary shape (...,)
As output of forward and compute the metric returns the following output:
agg (Tensor): scalar float tensor with the sum over the inputs in the running window
- Parameters:
window (int) – The size of the running window.
nan_strategy (Union[str, float]) – options:
  - 'error': if any nan values are encountered, a RuntimeError is raised
  - 'warn': if any nan values are encountered, a warning is given and the computation continues
  - 'ignore': all nan values are silently removed
  - a float: if a float is provided, any nan values are imputed with this value
kwargs (Any) – Additional keyword arguments, see Advanced metric settings for more info.
- Raises:
ValueError – If nan_strategy is not one of error, warn, ignore or a float
Example
>>> from torch import tensor
>>> from torchmetrics.aggregation import RunningSum
>>> metric = RunningSum(window=3)
>>> for i in range(6):
...     current_val = metric(tensor([i]))
...     running_val = metric.compute()
...     total_val = tensor(sum(list(range(i+1)))) # total sum over all samples
...     print(f"{current_val=}, {running_val=}, {total_val=}")
current_val=tensor(0.), running_val=tensor(0.), total_val=tensor(0)
current_val=tensor(1.), running_val=tensor(1.), total_val=tensor(1)
current_val=tensor(2.), running_val=tensor(3.), total_val=tensor(3)
current_val=tensor(3.), running_val=tensor(6.), total_val=tensor(6)
current_val=tensor(4.), running_val=tensor(9.), total_val=tensor(10)
current_val=tensor(5.), running_val=tensor(12.), total_val=tensor(15)
Sum
Module Interface
- class torchmetrics.aggregation.SumMetric(nan_strategy='warn', **kwargs)
Aggregate a stream of values into their sum.
As input to forward and update the metric accepts the following input:
value (float or Tensor): a single float or a tensor of float values with arbitrary shape (...,)
As output of forward and compute the metric returns the following output:
agg (Tensor): scalar float tensor with aggregated sum over all inputs received
- Parameters:
nan_strategy (Union[str, float]) – options:
  - 'error': if any nan values are encountered, a RuntimeError is raised
  - 'warn': if any nan values are encountered, a warning is given and the computation continues
  - 'ignore': all nan values are silently removed
  - a float: if a float is provided, any nan values are imputed with this value
kwargs (Any) – Additional keyword arguments, see Advanced metric settings for more info.
- Raises:
ValueError – If nan_strategy is not one of error, warn, ignore or a float
Example
>>> from torch import tensor
>>> from torchmetrics.aggregation import SumMetric
>>> metric = SumMetric()
>>> metric.update(1)
>>> metric.update(tensor([2, 3]))
>>> metric.compute()
tensor(6.)
- plot(val=None, ax=None)
Plot a single or multiple values from the metric.
- Parameters:
val (Union[Tensor, Sequence[Tensor], None]) – Either a single result from calling metric.forward or metric.compute, or a list of these results. If no value is provided, metric.compute is called automatically and its result is plotted.
ax (Optional[Axes]) – A matplotlib axis object. If provided, the plot is added to that axis.
- Returns:
Figure and Axes object
- Raises:
ModuleNotFoundError – If matplotlib is not installed
>>> # Example plotting a single value
>>> from torchmetrics.aggregation import SumMetric
>>> metric = SumMetric()
>>> metric.update([1, 2, 3])
>>> fig_, ax_ = metric.plot()
>>> # Example plotting multiple values
>>> from torchmetrics.aggregation import SumMetric
>>> metric = SumMetric()
>>> values = [ ]
>>> for i in range(10):
...     values.append(metric([i, i+1]))
>>> fig_, ax_ = metric.plot(values)
Complex Scale-Invariant Signal-to-Noise Ratio (C-SI-SNR)
Module Interface
- class torchmetrics.audio.ComplexScaleInvariantSignalNoiseRatio(zero_mean=False, **kwargs)
Calculate the complex scale-invariant signal-to-noise ratio (C-SI-SNR) metric for evaluating the quality of audio.
As input to forward and update the metric accepts the following input:
preds (Tensor): real float tensor with shape (...,frequency,time,2) or complex float tensor with shape (...,frequency,time)
target (Tensor): real float tensor with shape (...,frequency,time,2) or complex float tensor with shape (...,frequency,time)
As output of forward and compute the metric returns the following output:
c_si_snr (Tensor): float scalar tensor with average C-SI-SNR value over samples
- Parameters:
zero_mean (bool) – whether to zero-mean target and preds before computing the metric
kwargs (Any) – Additional keyword arguments, see Advanced metric settings for more info.
- Raises:
ValueError – If zero_mean is not a bool
TypeError – If preds is not of shape (..., frequency, time, 2) (after being converted to real if it is complex), or if preds and target do not have the same shape.
Example
>>> import torch
>>> from torch import tensor
>>> from torchmetrics.audio import ComplexScaleInvariantSignalNoiseRatio
>>> g = torch.manual_seed(1)
>>> preds = torch.randn((1,257,100,2))
>>> target = torch.randn((1,257,100,2))
>>> c_si_snr = ComplexScaleInvariantSignalNoiseRatio()
>>> c_si_snr(preds, target)
tensor(-63.4849)
- plot(val=None, ax=None)
Plot a single or multiple values from the metric.
- Parameters:
val (Union[Tensor, Sequence[Tensor], None]) – Either a single result from calling metric.forward or metric.compute, or a list of these results. If no value is provided, metric.compute is called automatically and its result is plotted.
ax (Optional[Axes]) – A matplotlib axis object. If provided, the plot is added to that axis.
- Returns:
Figure and Axes object
- Raises:
ModuleNotFoundError – If matplotlib is not installed
>>> # Example plotting a single value
>>> import torch
>>> from torchmetrics.audio import ComplexScaleInvariantSignalNoiseRatio
>>> metric = ComplexScaleInvariantSignalNoiseRatio()
>>> metric.update(torch.rand(1,257,100,2), torch.rand(1,257,100,2))
>>> fig_, ax_ = metric.plot()
>>> # Example plotting multiple values
>>> import torch
>>> from torchmetrics.audio import ComplexScaleInvariantSignalNoiseRatio
>>> metric = ComplexScaleInvariantSignalNoiseRatio()
>>> values = [ ]
>>> for _ in range(10):
...     values.append(metric(torch.rand(1,257,100,2), torch.rand(1,257,100,2)))
>>> fig_, ax_ = metric.plot(values)
Functional Interface
- torchmetrics.functional.audio.complex_scale_invariant_signal_noise_ratio(preds, target, zero_mean=False)
Complex scale-invariant signal-to-noise ratio (C-SI-SNR).
- Parameters:
preds (Tensor) – real float tensor with shape (...,frequency,time,2) or complex float tensor with shape (...,frequency,time)
target (Tensor) – real float tensor with shape (...,frequency,time,2) or complex float tensor with shape (...,frequency,time)
zero_mean (bool) – When set to True, the mean of all signals is subtracted prior to computation of the metrics
- Return type:
Tensor
- Returns:
Float tensor with shape (...,) of C-SI-SNR values per sample
- Raises:
RuntimeError – If preds is not of shape (...,frequency,time,2) (after being converted to real if it is complex), or if preds and target do not have the same shape.
Example
>>> import torch
>>> from torchmetrics.functional.audio import complex_scale_invariant_signal_noise_ratio
>>> g = torch.manual_seed(1)
>>> preds = torch.randn((1,257,100,2))
>>> target = torch.randn((1,257,100,2))
>>> complex_scale_invariant_signal_noise_ratio(preds, target)
tensor([-63.4849])
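The sketch below shows one way complex inputs can be scored. It assumes, as the accepted (...,frequency,time,2) layout suggests, that each complex bin is treated as a (real, imaginary) pair and that the frequency, time and pair dimensions are flattened before a scale-invariant ratio is computed without mean removal. Both helpers are hypothetical plain-Python stand-ins, not the torchmetrics implementation.

```python
import math


def si_ratio(preds, target, eps=1e-8):
    """Scale-invariant ratio in dB for two equal-length lists of floats (no mean removal)."""
    alpha = (sum(p * t for p, t in zip(preds, target)) + eps) / (sum(t * t for t in target) + eps)
    noise = [alpha * t - p for p, t in zip(preds, target)]
    signal_energy = alpha * alpha * sum(t * t for t in target)
    return 10 * math.log10((signal_energy + eps) / (sum(n * n for n in noise) + eps))


def complex_si_snr(preds, target):
    """Assumed approach: flatten complex bins into (real, imag) pairs, then score."""
    def as_real(spec):
        # spec: list of frequency rows, each a list of complex values over time
        return [part for row in spec for z in row for part in (z.real, z.imag)]

    return si_ratio(as_real(preds), as_real(target))


target = [[1 + 2j, 0.5 - 1j], [2 - 0.5j, -1 + 1j]]  # 2 frequency bins x 2 frames
preds = [[1 + 1.8j, 0.4 - 1j], [2.1 - 0.5j, -1 + 0.9j]]
score = complex_si_snr(preds, target)
# scaling the estimate by a real factor leaves the score (almost exactly) unchanged
scaled = complex_si_snr([[2 * z for z in row] for row in preds], target)
print(round(score, 2), round(scaled, 2))
```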
Perceptual Evaluation of Speech Quality (PESQ)
Module Interface
- class torchmetrics.audio.pesq.PerceptualEvaluationSpeechQuality(fs, mode, n_processes=1, **kwargs)
Calculate Perceptual Evaluation of Speech Quality (PESQ).
It is a recognized industry standard for audio quality that takes into consideration characteristics such as audio sharpness, call volume, background noise, clipping and audio interference. PESQ returns a score between -0.5 and 4.5, with higher scores indicating better quality.
This metric is a wrapper for the pesq package. Note that input will be moved to cpu to perform the metric calculation.
As input to forward and update the metric accepts the following input:
preds (Tensor): float tensor with shape (...,time)
target (Tensor): float tensor with shape (...,time)
As output of forward and compute the metric returns the following output:
pesq (Tensor): float scalar tensor with average PESQ value over samples
Note
Using this metric requires you to have pesq installed. Either install it as pip install torchmetrics[audio] or pip install pesq. Note that pesq will compile with your currently installed version of numpy, meaning that if you upgrade numpy at some point in the future you will most likely have to reinstall pesq.
- Parameters:
fs (int) – sampling frequency, should be 16000 or 8000 (Hz)
mode (str) – 'wb' (wide-band) or 'nb' (narrow-band)
n_processes (int) – integer specifying the number of processes to run in parallel for the metric calculation. Only applies to batches of data and if the multiprocessing package is installed.
kwargs (Any) – Additional keyword arguments, see Advanced metric settings for more info.
- Raises:
ModuleNotFoundError – If pesq package is not installed
ValueError – If fs is not either 8000 or 16000
ValueError – If mode is not either "wb" or "nb"
Example
>>> import torch
>>> from torchmetrics.audio import PerceptualEvaluationSpeechQuality
>>> g = torch.manual_seed(1)
>>> preds = torch.randn(8000)
>>> target = torch.randn(8000)
>>> nb_pesq = PerceptualEvaluationSpeechQuality(8000, 'nb')
>>> nb_pesq(preds, target)
tensor(2.2076)
>>> wb_pesq = PerceptualEvaluationSpeechQuality(16000, 'wb')
>>> wb_pesq(preds, target)
tensor(1.7359)
- plot(val=None, ax=None)
Plot a single or multiple values from the metric.
- Parameters:
val (Union[Tensor, Sequence[Tensor], None]) – Either a single result from calling metric.forward or metric.compute, or a list of these results. If no value is provided, metric.compute is called automatically and its result is plotted.
ax (Optional[Axes]) – A matplotlib axis object. If provided, the plot is added to that axis.
- Returns:
Figure and Axes object
- Raises:
ModuleNotFoundError – If matplotlib is not installed
>>> # Example plotting a single value
>>> import torch
>>> from torchmetrics.audio import PerceptualEvaluationSpeechQuality
>>> metric = PerceptualEvaluationSpeechQuality(8000, 'nb')
>>> metric.update(torch.rand(8000), torch.rand(8000))
>>> fig_, ax_ = metric.plot()
>>> # Example plotting multiple values
>>> import torch
>>> from torchmetrics.audio import PerceptualEvaluationSpeechQuality
>>> metric = PerceptualEvaluationSpeechQuality(8000, 'nb')
>>> values = [ ]
>>> for _ in range(10):
...     values.append(metric(torch.rand(8000), torch.rand(8000)))
>>> fig_, ax_ = metric.plot(values)
Functional Interface
- torchmetrics.functional.audio.pesq.perceptual_evaluation_speech_quality(preds, target, fs, mode, keep_same_device=False, n_processes=1)
Calculate Perceptual Evaluation of Speech Quality (PESQ).
It is a recognized industry standard for audio quality that takes into consideration characteristics such as audio sharpness, call volume, background noise, clipping and audio interference. PESQ returns a score between -0.5 and 4.5, with higher scores indicating better quality.
This metric is a wrapper for the pesq package. Note that input will be moved to cpu to perform the metric calculation.
Note
Using this metric requires you to have pesq installed. Either install it as pip install torchmetrics[audio] or pip install pesq. Note that pesq will compile with your currently installed version of numpy, meaning that if you upgrade numpy at some point in the future you will most likely have to reinstall pesq.
- Parameters:
preds (Tensor) – float tensor with shape (...,time)
target (Tensor) – float tensor with shape (...,time)
fs (int) – sampling frequency, should be 16000 or 8000 (Hz)
mode (str) – 'wb' (wide-band) or 'nb' (narrow-band)
keep_same_device (bool) – whether to move the pesq value to the device of preds
n_processes (int) – integer specifying the number of processes to run in parallel for the metric calculation. Only applies to batches of data and if the multiprocessing package is installed.
- Return type:
Tensor
- Returns:
Float tensor with shape (...,) of PESQ values per sample
- Raises:
ModuleNotFoundError – If pesq package is not installed
ValueError – If fs is not either 8000 or 16000
ValueError – If mode is not either "wb" or "nb"
RuntimeError – If preds and target do not have the same shape
Example
>>> import torch
>>> from torch import randn
>>> from torchmetrics.functional.audio.pesq import perceptual_evaluation_speech_quality
>>> g = torch.manual_seed(1)
>>> preds = randn(8000)
>>> target = randn(8000)
>>> perceptual_evaluation_speech_quality(preds, target, 8000, 'nb')
tensor(2.2076)
>>> perceptual_evaluation_speech_quality(preds, target, 16000, 'wb')
tensor(1.7359)
Permutation Invariant Training (PIT)
Module Interface
- class torchmetrics.audio.PermutationInvariantTraining(metric_func, mode='speaker-wise', eval_func='max', **kwargs)
Calculate permutation invariant training (PIT).
This metric can evaluate models for speaker-independent multi-talker speech separation in a permutation-invariant way.
As input to forward and update the metric accepts the following input:
preds (Tensor): float tensor with shape (batch_size,num_speakers,...)
target (Tensor): float tensor with shape (batch_size,num_speakers,...)
As output of forward and compute the metric returns the following output:
pit (Tensor): float scalar tensor with the best-permutation metric value averaged over samples
- Parameters:
metric_func (Callable) – a metric function that accepts a batch of target and estimate. If mode == 'speaker-wise', then metric_func(preds[:, i, ...], target[:, j, ...]) is called and is expected to return a batch of metric tensors (batch,); if mode == 'permutation-wise', then metric_func(preds[:, p, ...], target[:, :, ...]) is called, where p is one possible permutation, e.g. [0,1] or [1,0] for the 2-speaker case, and is expected to return a batch of metric tensors (batch,).
mode (Literal['speaker-wise', 'permutation-wise']) – can be 'speaker-wise' or 'permutation-wise'.
eval_func (Literal['max', 'min']) – the function used to find the best permutation, either 'min' or 'max', i.e. the smaller the better or the larger the better.
kwargs (Any) – Additional keyword arguments for either the metric_func or distributed communication, see Advanced metric settings for more info.
Example
>>> import torch
>>> from torchmetrics.audio import PermutationInvariantTraining
>>> from torchmetrics.functional.audio import scale_invariant_signal_noise_ratio
>>> _ = torch.manual_seed(42)
>>> preds = torch.randn(3, 2, 5) # [batch, spk, time]
>>> target = torch.randn(3, 2, 5) # [batch, spk, time]
>>> pit = PermutationInvariantTraining(scale_invariant_signal_noise_ratio,
...     mode="speaker-wise", eval_func="max")
>>> pit(preds, target)
tensor(-2.1065)
- plot(val=None, ax=None)
Plot a single or multiple values from the metric.
- Parameters:
val (Union[Tensor, Sequence[Tensor], None]) – Either a single result from calling metric.forward or metric.compute, or a list of these results. If no value is provided, metric.compute is called automatically and its result is plotted.
ax (Optional[Axes]) – A matplotlib axis object. If provided, the plot is added to that axis.
- Returns:
Figure and Axes object
- Raises:
ModuleNotFoundError – If matplotlib is not installed
>>> # Example plotting a single value
>>> import torch
>>> from torchmetrics.audio import PermutationInvariantTraining
>>> from torchmetrics.functional.audio import scale_invariant_signal_noise_ratio
>>> preds = torch.randn(3, 2, 5) # [batch, spk, time]
>>> target = torch.randn(3, 2, 5) # [batch, spk, time]
>>> metric = PermutationInvariantTraining(scale_invariant_signal_noise_ratio,
...     mode="speaker-wise", eval_func="max")
>>> metric.update(preds, target)
>>> fig_, ax_ = metric.plot()
>>> # Example plotting multiple values
>>> import torch
>>> from torchmetrics.audio import PermutationInvariantTraining
>>> from torchmetrics.functional.audio import scale_invariant_signal_noise_ratio
>>> preds = torch.randn(3, 2, 5) # [batch, spk, time]
>>> target = torch.randn(3, 2, 5) # [batch, spk, time]
>>> metric = PermutationInvariantTraining(scale_invariant_signal_noise_ratio,
...     mode="speaker-wise", eval_func="max")
>>> values = [ ]
>>> for _ in range(10):
...     values.append(metric(preds, target))
>>> fig_, ax_ = metric.plot(values)
Functional Interface
- torchmetrics.functional.audio.permutation_invariant_training(preds, target, metric_func, mode='speaker-wise', eval_func='max', **kwargs)
Calculate permutation invariant training (PIT).
This metric can evaluate models for speaker-independent multi-talker speech separation in a permutation-invariant way.
- Parameters:
preds (Tensor) – float tensor with shape (batch_size,num_speakers,...)
target (Tensor) – float tensor with shape (batch_size,num_speakers,...)
metric_func (Callable) – a metric function that accepts a batch of target and estimate. If mode == 'speaker-wise', then metric_func(preds[:, i, ...], target[:, j, ...]) is called and is expected to return a batch of metric tensors (batch,); if mode == 'permutation-wise', then metric_func(preds[:, p, ...], target[:, :, ...]) is called, where p is one possible permutation, e.g. [0,1] or [1,0] for the 2-speaker case, and is expected to return a batch of metric tensors (batch,).
mode (Literal['speaker-wise', 'permutation-wise']) – can be 'speaker-wise' or 'permutation-wise'.
eval_func (Literal['max', 'min']) – the function used to find the best permutation, either 'min' or 'max', i.e. the smaller the better or the larger the better.
kwargs (Any) – Additional arguments for metric_func
- Return type:
Tuple[Tensor, Tensor]
- Returns:
Tuple of two tensors. The first, a float tensor with shape (batch,), contains the best metric value for each sample; the second, with shape (batch,num_speakers), contains the best permutation.
Example
>>> import torch
>>> from torchmetrics.functional.audio import permutation_invariant_training, scale_invariant_signal_distortion_ratio
>>> from torchmetrics.functional.audio.pit import pit_permutate
>>> # [batch, spk, time]
>>> preds = torch.tensor([[[-0.0579, 0.3560, -0.9604], [-0.1719, 0.3205, 0.2951]]])
>>> target = torch.tensor([[[ 1.0958, -0.1648, 0.5228], [-0.4100, 1.1942, -0.5103]]])
>>> best_metric, best_perm = permutation_invariant_training(
...     preds, target, scale_invariant_signal_distortion_ratio,
...     mode="speaker-wise", eval_func="max")
>>> best_metric
tensor([-5.1091])
>>> best_perm
tensor([[0, 1]])
>>> pit_permutate(preds, best_perm)
tensor([[[-0.0579, 0.3560, -0.9604], [-0.1719, 0.3205, 0.2951]]])
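Speaker-wise PIT with eval_func='max' can be sketched as an exhaustive search: score every pairing of estimates to references, average over speakers and keep the best permutation. The helpers below are hypothetical plain-Python stand-ins (neg_mse replaces a real separation metric such as SI-SNR), and permutate plays the role that pit_permutate plays in the example above.

```python
from itertools import permutations


def neg_mse(est, ref):
    """Toy per-speaker metric where higher is better (a stand-in for e.g. SI-SNR)."""
    return -sum((e - r) ** 2 for e, r in zip(est, ref)) / len(ref)


def pit_best(preds, target, metric_func=neg_mse, eval_func="max"):
    """Hypothetical exhaustive speaker-wise PIT for one sample.

    preds and target hold one signal (list of floats) per speaker.
    Returns (best speaker-averaged score, permutation achieving it).
    """
    num_spk = len(target)
    candidates = []
    for perm in permutations(range(num_spk)):
        # estimate perm[j] is paired with reference speaker j
        score = sum(metric_func(preds[perm[j]], target[j]) for j in range(num_spk)) / num_spk
        candidates.append((score, list(perm)))
    pick = max if eval_func == "max" else min
    return pick(candidates, key=lambda c: c[0])


def permutate(preds, perm):
    """Reorder the estimates so that entry j lines up with reference speaker j."""
    return [preds[i] for i in perm]


target = [[1.0, 0.0, 1.0], [0.0, 2.0, 0.0]]
preds = [[0.0, 2.0, 0.0], [1.0, 0.0, 1.0]]  # perfect estimates, but speakers swapped
best_score, best_perm = pit_best(preds, target)
print(best_score, best_perm)  # 0.0 [1, 0]
```

Trying all num_speakers! permutations is only practical for a handful of speakers; for larger counts an assignment solver such as the Hungarian algorithm avoids the factorial cost.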
Scale-Invariant Signal-to-Distortion Ratio (SI-SDR)
Module Interface
- class torchmetrics.audio.ScaleInvariantSignalDistortionRatio(zero_mean=False, **kwargs)
Scale-invariant signal-to-distortion ratio (SI-SDR).
The SI-SDR value is in general considered an overall measure of how good a source sounds.
As input to forward and update the metric accepts the following input:
preds (Tensor): float tensor with shape (...,time)
target (Tensor): float tensor with shape (...,time)
As output of forward and compute the metric returns the following output:
si_sdr (Tensor): float scalar tensor with average SI-SDR value over samples
- Parameters:
zero_mean (bool) – whether to zero-mean target and preds before computing the metric
kwargs (Any) – Additional keyword arguments, see Advanced metric settings for more info.
- Raises:
TypeError – If target and preds have different shapes
Example
>>> from torch import tensor
>>> from torchmetrics.audio import ScaleInvariantSignalDistortionRatio
>>> target = tensor([3.0, -0.5, 2.0, 7.0])
>>> preds = tensor([2.5, 0.0, 2.0, 8.0])
>>> si_sdr = ScaleInvariantSignalDistortionRatio()
>>> si_sdr(preds, target)
tensor(18.4030)
- plot(val=None, ax=None)
Plot a single or multiple values from the metric.
- Parameters:
val (Union[Tensor, Sequence[Tensor], None]) – Either a single result from calling metric.forward or metric.compute, or a list of these results. If no value is provided, metric.compute is called automatically and its result is plotted.
ax (Optional[Axes]) – A matplotlib axis object. If provided, the plot is added to that axis.
- Returns:
Figure and Axes object
- Raises:
ModuleNotFoundError – If matplotlib is not installed
>>> # Example plotting a single value
>>> import torch
>>> from torchmetrics.audio import ScaleInvariantSignalDistortionRatio
>>> target = torch.randn(5)
>>> preds = torch.randn(5)
>>> metric = ScaleInvariantSignalDistortionRatio()
>>> metric.update(preds, target)
>>> fig_, ax_ = metric.plot()
>>> # Example plotting multiple values
>>> import torch
>>> from torchmetrics.audio import ScaleInvariantSignalDistortionRatio
>>> target = torch.randn(5)
>>> preds = torch.randn(5)
>>> metric = ScaleInvariantSignalDistortionRatio()
>>> values = [ ]
>>> for _ in range(10):
...     values.append(metric(preds, target))
>>> fig_, ax_ = metric.plot(values)
Functional Interface¶
- torchmetrics.functional.audio.scale_invariant_signal_distortion_ratio(preds, target, zero_mean=False)[source]¶
Scale-invariant signal-to-distortion ratio (SI-SDR).
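The SI-SDR value is generally considered an overall measure of how good a source sound is.
- Parameters:
preds (Tensor) – float tensor with shape (...,time)
target (Tensor) – float tensor with shape (...,time)
zero_mean (bool) – whether to zero-mean target and preds before computing the metric
- Returns:
Float tensor with shape (...,) of SI-SDR values per sample
- Raises:
RuntimeError – If preds and target do not have the same shape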
Example
>>> import torch
>>> from torchmetrics.functional.audio import scale_invariant_signal_distortion_ratio
>>> target = torch.tensor([3.0, -0.5, 2.0, 7.0])
>>> preds = torch.tensor([2.5, 0.0, 2.0, 8.0])
>>> scale_invariant_signal_distortion_ratio(preds, target)
tensor(18.4030)
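To make the definition concrete, here is a minimal plain-Python sketch of SI-SDR. It assumes the common formulation: rescale the target by the optimal scalar alpha = <preds, target> / ||target||^2, treat whatever preds cannot explain as distortion, and report the energy ratio in dB. The helper name `si_sdr_sketch` and the small `eps` guard are hypothetical choices of this sketch, not the torchmetrics API. On the four-sample example above it gives about 18.4030.

```python
import math


def si_sdr_sketch(preds, target, zero_mean=False, eps=1e-8):
    """Hypothetical plain-Python SI-SDR in dB for one pair of 1-D signals."""
    if len(preds) != len(target):
        raise RuntimeError("preds and target must have the same shape")
    if zero_mean:
        # Remove the DC offset of both signals before comparing them
        mean_p = sum(preds) / len(preds)
        mean_t = sum(target) / len(target)
        preds = [p - mean_p for p in preds]
        target = [t - mean_t for t in target]
    # Optimal scaling of the target onto preds: alpha = <preds, target> / ||target||^2
    dot = sum(p * t for p, t in zip(preds, target))
    alpha = (dot + eps) / (sum(t * t for t in target) + eps)
    scaled_target = [alpha * t for t in target]
    # Whatever the scaled target does not explain counts as distortion
    noise = [s - p for s, p in zip(scaled_target, preds)]
    ratio = (sum(s * s for s in scaled_target) + eps) / (sum(n * n for n in noise) + eps)
    return 10 * math.log10(ratio)


print(round(si_sdr_sketch([2.5, 0.0, 2.0, 8.0], [3.0, -0.5, 2.0, 7.0]), 4))  # prints 18.403
```

Because alpha absorbs any global gain, multiplying preds by a constant leaves the result unchanged, which is what "scale-invariant" refers to.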
Scale-Invariant Signal-to-Noise Ratio (SI-SNR)¶
Module Interface¶
- class torchmetrics.audio.ScaleInvariantSignalNoiseRatio(**kwargs)[source]¶
Calculate Scale-invariant signal-to-noise ratio (SI-SNR) metric for evaluating quality of audio.
As input to forward and update the metric accepts the following input
preds (Tensor): float tensor with shape (...,time)
target (Tensor): float tensor with shape (...,time)
As output of forward and compute the metric returns the following output
si_snr (Tensor): float scalar tensor with average SI-SNR value over samples
- Parameters:
kwargs (Any) – Additional keyword arguments, see Advanced metric settings for more info.
- Raises:
TypeError – if target and preds have a different shape
Example
>>> import torch
>>> from torch import tensor
>>> from torchmetrics.audio import ScaleInvariantSignalNoiseRatio
>>> target = tensor([3.0, -0.5, 2.0, 7.0])
>>> preds = tensor([2.5, 0.0, 2.0, 8.0])
>>> si_snr = ScaleInvariantSignalNoiseRatio()
>>> si_snr(preds, target)
tensor(15.0918)
- plot(val=None, ax=None)[source]¶
Plot a single or multiple values from the metric.
- Parameters:
val (Union[Tensor, Sequence[Tensor], None]) – Either a single result from calling metric.forward or metric.compute, or a list of these results. If no value is provided, metric.compute is called automatically and its result is plotted.
ax (Optional[Axes]) – A matplotlib axis object. If provided, the plot is added to that axis.
- Returns:
Figure and Axes object
- Raises:
ModuleNotFoundError – If matplotlib is not installed
>>> # Example plotting a single value
>>> import torch
>>> from torchmetrics.audio import ScaleInvariantSignalNoiseRatio
>>> metric = ScaleInvariantSignalNoiseRatio()
>>> metric.update(torch.rand(4), torch.rand(4))
>>> fig_, ax_ = metric.plot()
>>> # Example plotting multiple values
>>> import torch
>>> from torchmetrics.audio import ScaleInvariantSignalNoiseRatio
>>> metric = ScaleInvariantSignalNoiseRatio()
>>> values = [ ]
>>> for _ in range(10):
...     values.append(metric(torch.rand(4), torch.rand(4)))
>>> fig_, ax_ = metric.plot(values)
Functional Interface¶
- torchmetrics.functional.audio.scale_invariant_signal_noise_ratio(preds, target)[source]¶
Scale-invariant signal-to-noise ratio (SI-SNR).
- Parameters:
preds (Tensor) – float tensor with shape (...,time)
target (Tensor) – float tensor with shape (...,time)
- Returns:
Float tensor with shape (...,) of SI-SNR values per sample
- Raises:
RuntimeError – If preds and target do not have the same shape
Example
>>> import torch
>>> from torchmetrics.functional.audio import scale_invariant_signal_noise_ratio
>>> target = torch.tensor([3.0, -0.5, 2.0, 7.0])
>>> preds = torch.tensor([2.5, 0.0, 2.0, 8.0])
>>> scale_invariant_signal_noise_ratio(preds, target)
tensor(15.0918)
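SI-SNR can be read as SI-SDR evaluated on zero-mean signals. The plain-Python sketch below assumes exactly that relationship. This is an assumption checked only against the documented example, where it reproduces 15.0918. `si_snr_sketch` is a hypothetical helper, not the torchmetrics function.

```python
import math


def si_snr_sketch(preds, target, eps=1e-8):
    """Hypothetical plain-Python SI-SNR in dB: SI-SDR computed on zero-mean 1-D signals."""
    if len(preds) != len(target):
        raise RuntimeError("preds and target must have the same shape")
    # Zero-mean both signals so a constant offset counts neither as signal nor as noise
    mean_p = sum(preds) / len(preds)
    mean_t = sum(target) / len(target)
    preds = [p - mean_p for p in preds]
    target = [t - mean_t for t in target]
    # Project preds onto the target direction; the remainder is treated as noise
    alpha = (sum(p * t for p, t in zip(preds, target)) + eps) / (sum(t * t for t in target) + eps)
    projection = [alpha * t for t in target]
    noise = [s - p for s, p in zip(projection, preds)]
    ratio = (sum(s * s for s in projection) + eps) / (sum(n * n for n in noise) + eps)
    return 10 * math.log10(ratio)


print(round(si_snr_sketch([2.5, 0.0, 2.0, 8.0], [3.0, -0.5, 2.0, 7.0]), 4))
```

Because of the mean removal, adding a constant to preds leaves this sketch's value unchanged. By contrast, the SI-SDR value for the same example (18.4030) is computed without removing the mean.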
Short-Time Objective Intelligibility (STOI)¶
Module Interface¶
- class torchmetrics.audio.stoi.ShortTimeObjectiveIntelligibility(fs, extended=False, **kwargs)[source]¶
Calculate STOI (Short-Time Objective Intelligibility) metric for evaluating speech signals.
An intelligibility measure which is highly correlated with the intelligibility of degraded speech signals, e.g., due to additive noise, single-/multi-channel noise reduction, binary masking and vocoded speech as in CI simulations. The STOI measure is intrusive, i.e., a function of both the clean and the degraded speech signals. STOI may be a good alternative to the speech intelligibility index (SII) or the speech transmission index (STI) when you are interested in the effect of nonlinear processing of noisy speech, e.g., noise reduction or binary masking algorithms, on speech intelligibility. Description taken from Cees Taal’s website; for further details see STOI ref1 and STOI ref2.
This metric is a wrapper for the pystoi package. As the backend implementation only supports calculations on CPU, all inputs are automatically moved to CPU to perform the metric calculation before being moved back to the original device.
As input to forward and update the metric accepts the following input
preds (Tensor): float tensor with shape (...,time)
target (Tensor): float tensor with shape (...,time)
As output of forward and compute the metric returns the following output
stoi (Tensor): float scalar tensor
Note
Using this metric requires you to have pystoi installed. Either install as pip install torchmetrics[audio] or pip install pystoi.
- Parameters:
fs (int) – sampling frequency (Hz)
extended (bool) – whether to use the extended STOI described in STOI ref3.
kwargs (Any) – Additional keyword arguments, see Advanced metric settings for more info.
- Raises:
ModuleNotFoundError – If pystoi package is not installed
Example
>>> import torch
>>> from torchmetrics.audio import ShortTimeObjectiveIntelligibility
>>> g = torch.manual_seed(1)
>>> preds = torch.randn(8000)
>>> target = torch.randn(8000)
>>> stoi = ShortTimeObjectiveIntelligibility(8000, False)
>>> stoi(preds, target)
tensor(-0.0100)
- plot(val=None, ax=None)[source]¶
Plot a single or multiple values from the metric.
- Parameters:
val (Union[Tensor, Sequence[Tensor], None]) – Either a single result from calling metric.forward or metric.compute, or a list of these results. If no value is provided, metric.compute is called automatically and its result is plotted.
ax (Optional[Axes]) – A matplotlib axis object. If provided, the plot is added to that axis.
- Returns:
Figure and Axes object
- Raises:
ModuleNotFoundError – If matplotlib is not installed
>>> # Example plotting a single value
>>> import torch
>>> from torchmetrics.audio import ShortTimeObjectiveIntelligibility
>>> g = torch.manual_seed(1)
>>> preds = torch.randn(8000)
>>> target = torch.randn(8000)
>>> metric = ShortTimeObjectiveIntelligibility(8000, False)
>>> metric.update(preds, target)
>>> fig_, ax_ = metric.plot()
>>> # Example plotting multiple values
>>> import torch
>>> from torchmetrics.audio import ShortTimeObjectiveIntelligibility
>>> metric = ShortTimeObjectiveIntelligibility(8000, False)
>>> g = torch.manual_seed(1)
>>> preds = torch.randn(8000)
>>> target = torch.randn(8000)
>>> values = [ ]
>>> for _ in range(10):
...     values.append(metric(preds, target))
>>> fig_, ax_ = metric.plot(values)
Functional Interface¶
- torchmetrics.functional.audio.stoi.short_time_objective_intelligibility(preds, target, fs, extended=False, keep_same_device=False)[source]¶
Calculate STOI (Short-Time Objective Intelligibility) metric for evaluating speech signals.
An intelligibility measure which is highly correlated with the intelligibility of degraded speech signals, e.g., due to additive noise, single-/multi-channel noise reduction, binary masking and vocoded speech as in CI simulations. The STOI measure is intrusive, i.e., a function of both the clean and the degraded speech signals. STOI may be a good alternative to the speech intelligibility index (SII) or the speech transmission index (STI) when you are interested in the effect of nonlinear processing of noisy speech, e.g., noise reduction or binary masking algorithms, on speech intelligibility. Description taken from Cees Taal’s website; for further details see STOI ref1 and STOI ref2.
This metric is a wrapper for the pystoi package. As the backend implementation only supports calculations on CPU, all inputs are automatically moved to CPU to perform the metric calculation before being moved back to the original device.
Note
Using this metric requires you to have pystoi installed. Either install as pip install torchmetrics[audio] or pip install pystoi.
- Parameters:
- Returns:
stoi value of shape […]
- Raises:
ModuleNotFoundError – If pystoi package is not installed
RuntimeError – If preds and target do not have the same shape
Example
>>> import torch
>>> from torchmetrics.functional.audio.stoi import short_time_objective_intelligibility
>>> g = torch.manual_seed(1)
>>> preds = torch.randn(8000)
>>> target = torch.randn(8000)
>>> short_time_objective_intelligibility(preds, target, 8000).float()
tensor(-0.0100)
Signal to Distortion Ratio (SDR)¶
Module Interface¶
- class torchmetrics.audio.SignalDistortionRatio(use_cg_iter=None, filter_length=512, zero_mean=False, load_diag=None, **kwargs)[source]¶
Calculate Signal to Distortion Ratio (SDR) metric.
See SDR ref1 and SDR ref2 for details on the metric.
As input to forward and update the metric accepts the following input
preds (Tensor): float tensor with shape (...,time)
target (Tensor): float tensor with shape (...,time)
As output of forward and compute the metric returns the following output
sdr (Tensor): float scalar tensor with average SDR value over samples
- Parameters:
use_cg_iter (Optional[int]) – If provided, conjugate gradient descent is used to solve for the distortion filter coefficients instead of direct Gaussian elimination, which requires that fast-bss-eval is installed and PyTorch version >= 1.8. This can speed up the computation of the metrics when the filters are long. A value of 10 has been shown to provide good accuracy in most cases and is sufficient when using this loss to train neural separation networks.
filter_length (int) – The length of the distortion filter allowed
zero_mean (bool) – When set to True, the mean of all signals is subtracted prior to computation of the metrics
load_diag (Optional[float]) – If provided, this small value is added to the diagonal coefficients of the system matrix when solving for the filter coefficients. This can help stabilize the metric in cases where some reference signals may be zero
kwargs (Any) – Additional keyword arguments, see Advanced metric settings for more info.
Example
>>> import torch
>>> from torchmetrics.audio import SignalDistortionRatio
>>> g = torch.manual_seed(1)
>>> preds = torch.randn(8000)
>>> target = torch.randn(8000)
>>> sdr = SignalDistortionRatio()
>>> sdr(preds, target)
tensor(-12.0589)
>>> # use with pit
>>> from torchmetrics.audio import PermutationInvariantTraining
>>> from torchmetrics.functional.audio import signal_distortion_ratio
>>> preds = torch.randn(4, 2, 8000)  # [batch, spk, time]
>>> target = torch.randn(4, 2, 8000)
>>> pit = PermutationInvariantTraining(signal_distortion_ratio,
...     mode="speaker-wise", eval_func="max")
>>> pit(preds, target)
tensor(-11.6051)
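Permutation invariant training evaluates every assignment of predicted speakers to target speakers and keeps the best one. The sketch below illustrates the "speaker-wise" idea with eval_func="max" for a single example, with no batch dimension. It scores every speaker pair, averages the scores over speakers for each permutation, and returns the best mean together with its permutation. `pit_sketch` and its pairing convention (preds[perm[j]] is matched to target[j]) are assumptions of this sketch, not the torchmetrics implementation. A plain-Python SI-SDR stands in for the metric to keep the example self-contained.

```python
import itertools
import math


def si_sdr(preds, target, eps=1e-8):
    """Plain-Python SI-SDR in dB for two 1-D signals (used as the per-speaker metric)."""
    alpha = (sum(p * t for p, t in zip(preds, target)) + eps) / (sum(t * t for t in target) + eps)
    scaled = [alpha * t for t in target]
    noise = [s - p for s, p in zip(scaled, preds)]
    return 10 * math.log10((sum(s * s for s in scaled) + eps) / (sum(n * n for n in noise) + eps))


def pit_sketch(preds, target, metric=si_sdr):
    """Hypothetical speaker-wise PIT for one example; preds/target are lists of per-speaker signals."""
    num_spk = len(target)
    # Precompute the metric for every (predicted speaker, target speaker) pair
    scores = [[metric(preds[i], target[j]) for j in range(num_spk)] for i in range(num_spk)]
    best_score, best_perm = -math.inf, None
    for perm in itertools.permutations(range(num_spk)):
        # perm[j] is the predicted speaker assigned to target speaker j
        mean_score = sum(scores[perm[j]][j] for j in range(num_spk)) / num_spk
        if mean_score > best_score:
            best_score, best_perm = mean_score, perm
    return best_score, best_perm


target = [[1.0, 2.0, -1.0, 0.5], [0.0, -1.0, 3.0, 1.0]]
preds = [[0.1, -0.9, 3.1, 1.0], [1.1, 1.9, -1.0, 0.4]]  # speakers swapped, slightly noisy
score, perm = pit_sketch(preds, target)
print(perm)  # (1, 0)
```

Precomputing the pair scores means the metric runs S*S times instead of S*S! times. Enumerating all permutations remains factorial in the number of speakers.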
- plot(val=None, ax=None)[source]¶
Plot a single or multiple values from the metric.
- Parameters:
val (Union[Tensor, Sequence[Tensor], None]) – Either a single result from calling metric.forward or metric.compute, or a list of these results. If no value is provided, metric.compute is called automatically and its result is plotted.
ax (Optional[Axes]) – A matplotlib axis object. If provided, the plot is added to that axis.
- Returns:
Figure and Axes object
- Raises:
ModuleNotFoundError – If matplotlib is not installed
>>> # Example plotting a single value
>>> import torch
>>> from torchmetrics.audio import SignalDistortionRatio
>>> metric = SignalDistortionRatio()
>>> metric.update(torch.rand(8000), torch.rand(8000))
>>> fig_, ax_ = metric.plot()
>>> # Example plotting multiple values
>>> import torch
>>> from torchmetrics.audio import SignalDistortionRatio
>>> metric = SignalDistortionRatio()
>>> values = [ ]
>>> for _ in range(10):
...     values.append(metric(torch.rand(8000), torch.rand(8000)))
>>> fig_, ax_ = metric.plot(values)
Functional Interface¶
- torchmetrics.functional.audio.signal_distortion_ratio(preds, target, use_cg_iter=None, filter_length=512, zero_mean=False, load_diag=None)[source]¶
Calculate Signal to Distortion Ratio (SDR) metric. See SDR ref1 and SDR ref2 for details on the metric.
- Parameters:
preds (Tensor) – float tensor with shape (...,time)
target (Tensor) – float tensor with shape (...,time)
use_cg_iter (Optional[int]) – If provided, conjugate gradient descent is used to solve for the distortion filter coefficients instead of direct Gaussian elimination, which requires that fast-bss-eval is installed and PyTorch version >= 1.8. This can speed up the computation of the metrics when the filters are long. A value of 10 has been shown to provide good accuracy in most cases and is sufficient when using this loss to train neural separation networks.
filter_length (int) – The length of the distortion filter allowed
zero_mean (bool) – When set to True, the mean of all signals is subtracted prior to computation of the metrics
load_diag (Optional[float]) – If provided, this small value is added to the diagonal coefficients of the system matrix when solving for the filter coefficients. This can help stabilize the metric in cases where some reference signals may be zero
- Returns:
Float tensor with shape (...,) of SDR values per sample
- Raises:
RuntimeError – If preds and target do not have the same shape
Example
>>> import torch
>>> from torchmetrics.functional.audio import signal_distortion_ratio
>>> g = torch.manual_seed(1)
>>> preds = torch.randn(8000)
>>> target = torch.randn(8000)
>>> signal_distortion_ratio(preds, target)
tensor(-12.0589)
>>> # use with permutation_invariant_training
>>> from torchmetrics.functional.audio import permutation_invariant_training
>>> preds = torch.randn(4, 2, 8000)  # [batch, spk, time]
>>> target = torch.randn(4, 2, 8000)
>>> best_metric, best_perm = permutation_invariant_training(preds, target, signal_distortion_ratio)
>>> best_metric
tensor([-11.6375, -11.4358, -11.7148, -11.6325])
>>> best_perm
tensor([[1, 0],
        [0, 1],
        [1, 0],
        [0, 1]])
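The "distortion filter" mentioned in the parameters can be illustrated with a small least-squares sketch. The idea is to find the FIR filter h of length filter_length that, applied to the target, best explains preds. The filtered target then counts as signal and the remainder as distortion.

This sketch makes several assumptions:
- It is plain Python with direct Gaussian elimination.
- It uses a full, zero-padded linear convolution.
- It caps the filter at the signal length for brevity.

It is not the torchmetrics or fast-bss-eval code path, which handles long signals and can optionally use conjugate gradient. `sdr_sketch` and `solve` are hypothetical helpers. With filter_length=1 the filter degenerates to a single gain, so the result coincides with SI-SDR (about 18.4030 on the four-sample example used in the SI-SDR section).

```python
import math


def solve(a, b):
    """Solve a @ x = b by Gaussian elimination with partial pivoting (inputs are not modified)."""
    n = len(b)
    m = [row[:] + [rhs] for row, rhs in zip(a, b)]  # augmented matrix
    for col in range(n):
        pivot = max(range(col, n), key=lambda r: abs(m[r][col]))
        m[col], m[pivot] = m[pivot], m[col]
        for r in range(col + 1, n):
            factor = m[r][col] / m[col][col]
            for c in range(col, n + 1):
                m[r][c] -= factor * m[col][c]
    x = [0.0] * n
    for i in range(n - 1, -1, -1):
        x[i] = (m[i][n] - sum(m[i][j] * x[j] for j in range(i + 1, n))) / m[i][i]
    return x


def sdr_sketch(preds, target, filter_length=512, zero_mean=False, load_diag=None):
    """Hypothetical least-squares SDR in dB, intended for short 1-D toy signals."""
    if len(preds) != len(target):
        raise RuntimeError("preds and target must have the same shape")
    if zero_mean:
        mean_p, mean_t = sum(preds) / len(preds), sum(target) / len(target)
        preds = [p - mean_p for p in preds]
        target = [t - mean_t for t in target]
    n = len(target)
    taps = min(filter_length, n)  # simplification: cap the filter at the signal length
    # Toeplitz system matrix from the target autocorrelation; right-hand side from
    # the cross-correlation between target and preds
    r = [sum(target[i] * target[i + k] for i in range(n - k)) for k in range(taps)]
    rhs = [sum(target[i] * preds[i + k] for i in range(n - k)) for k in range(taps)]
    system = [[r[abs(i - j)] for j in range(taps)] for i in range(taps)]
    if load_diag is not None:
        for i in range(taps):
            system[i][i] += load_diag  # diagonal loading for numerical stability
    h = solve(system, rhs)
    # Full linear convolution of the target with h (length n + taps - 1)
    out_len = n + taps - 1
    proj = [sum(h[k] * target[j - k] for k in range(taps) if 0 <= j - k < n) for j in range(out_len)]
    padded_preds = list(preds) + [0.0] * (taps - 1)
    distortion = [p - q for p, q in zip(padded_preds, proj)]
    return 10 * math.log10(sum(q * q for q in proj) / sum(d * d for d in distortion))


print(round(sdr_sketch([2.5, 0.0, 2.0, 8.0], [3.0, -0.5, 2.0, 7.0], filter_length=1), 4))
```

A longer filter can only explain more of preds, so the value never decreases as filter_length grows (without diagonal loading). This is one reason SDR is more permissive than SI-SDR.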
Signal-to-Noise Ratio (SNR)¶
Module Interface¶
- class torchmetrics.audio.SignalNoiseRatio(zero_mean=False, **kwargs)[source]¶
Calculate the Signal-to-noise ratio (SNR) metric for evaluating the quality of audio.
\[\text{SNR} = 10 \log_{10}\left(\frac{P_{signal}}{P_{noise}}\right)\]where \(P\) denotes the power of each signal and the result is expressed in decibels (dB). The SNR metric compares the level of the desired signal to the level of background noise. Therefore, a high value of SNR means that the audio is clear.
As input to forward and update the metric accepts the following input
preds (Tensor): float tensor with shape (...,time)
target (Tensor): float tensor with shape (...,time)
As output of forward and compute the metric returns the following output
snr (Tensor): float scalar tensor with average SNR value over samples
- Parameters:
zero_mean (bool) – whether to zero-mean target and preds before computing the metric
kwargs (Any) – Additional keyword arguments, see Advanced metric settings for more info.
- Raises:
TypeError – if target and preds have a different shape
Example
>>> from torch import tensor
>>> from torchmetrics.audio import SignalNoiseRatio
>>> target = tensor([3.0, -0.5, 2.0, 7.0])
>>> preds = tensor([2.5, 0.0, 2.0, 8.0])
>>> snr = SignalNoiseRatio()
>>> snr(preds, target)
tensor(16.1805)
- plot(val=None, ax=None)[source]¶
Plot a single or multiple values from the metric.
- Parameters:
val (Union[Tensor, Sequence[Tensor], None]) – Either a single result from calling metric.forward or metric.compute, or a list of these results. If no value is provided, metric.compute is called automatically and its result is plotted.
ax (Optional[Axes]) – A matplotlib axis object. If provided, the plot is added to that axis.
- Returns:
Figure and Axes object
- Raises:
ModuleNotFoundError – If matplotlib is not installed
>>> # Example plotting a single value
>>> import torch
>>> from torchmetrics.audio import SignalNoiseRatio
>>> metric = SignalNoiseRatio()
>>> metric.update(torch.rand(4), torch.rand(4))
>>> fig_, ax_ = metric.plot()
>>> # Example plotting multiple values
>>> import torch
>>> from torchmetrics.audio import SignalNoiseRatio
>>> metric = SignalNoiseRatio()
>>> values = [ ]
>>> for _ in range(10):
...     values.append(metric(torch.rand(4), torch.rand(4)))
>>> fig_, ax_ = metric.plot(values)
Functional Interface¶
- torchmetrics.functional.audio.signal_noise_ratio(preds, target, zero_mean=False)[source]¶
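Calculate the Signal-to-noise ratio (SNR) metric for evaluating the quality of audio.
\[\text{SNR} = 10 \log_{10}\left(\frac{P_{signal}}{P_{noise}}\right)\]where \(P\) denotes the power of each signal and the result is expressed in decibels (dB). The SNR metric compares the level of the desired signal to the level of background noise. Therefore, a high value of SNR means that the audio is clear.
- Parameters:
preds (Tensor) – float tensor with shape (...,time)
target (Tensor) – float tensor with shape (...,time)
zero_mean (bool) – whether to zero-mean target and preds before computing the metric
- Returns:
Float tensor with shape (...,) of SNR values per sample
- Raises:
RuntimeError – If preds and target do not have the same shape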
Example
>>> import torch
>>> from torchmetrics.functional.audio import signal_noise_ratio
>>> target = torch.tensor([3.0, -0.5, 2.0, 7.0])
>>> preds = torch.tensor([2.5, 0.0, 2.0, 8.0])
>>> signal_noise_ratio(preds, target)
tensor(16.1805)
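The example value can be reproduced by hand. The noise is taken as target minus preds, and the power ratio is expressed in decibels. A minimal plain-Python sketch follows, assuming that definition and assuming that zero_mean simply subtracts each signal's mean first. `snr_sketch` is a hypothetical helper, not the torchmetrics function.

```python
import math


def snr_sketch(preds, target, zero_mean=False):
    """Hypothetical plain-Python SNR in dB for two 1-D signals."""
    if len(preds) != len(target):
        raise RuntimeError("preds and target must have the same shape")
    if zero_mean:
        mean_p, mean_t = sum(preds) / len(preds), sum(target) / len(target)
        preds = [p - mean_p for p in preds]
        target = [t - mean_t for t in target]
    signal_power = sum(t * t for t in target)
    # Everything that differs from the target is treated as noise
    noise_power = sum((t - p) ** 2 for p, t in zip(preds, target))
    return 10 * math.log10(signal_power / noise_power)


# Signal power 62.25, noise power 1.5: 10 * log10(41.5) is about 16.1805
print(round(snr_sketch([2.5, 0.0, 2.0, 8.0], [3.0, -0.5, 2.0, 7.0]), 4))
```

Unlike SI-SNR, the target is not rescaled before the noise is computed, so a prediction that is merely louder or quieter than the target is penalized.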
Source Aggregated Signal-to-Distortion Ratio (SA-SDR)¶
Module Interface¶
- class torchmetrics.audio.sdr.SourceAggregatedSignalDistortionRatio(scale_invariant=True, zero_mean=False, **kwargs)[source]¶
Source-aggregated signal-to-distortion ratio (SA-SDR).
SA-SDR was proposed to provide a stable gradient for meeting-style source separation, where one-speaker and multiple-speaker scenes coexist.
As input to forward and update the metric accepts the following input
preds (Tensor): float tensor with shape (..., spk, time)
target (Tensor): float tensor with shape (..., spk, time)
As output of forward and compute the metric returns the following output
sa_sdr (Tensor): float scalar tensor with average SA-SDR value over samples
- Parameters:
preds – float tensor with shape (..., spk, time)
target – float tensor with shape (..., spk, time)
scale_invariant (bool) – if True, scale the targets of different speakers with the same alpha
zero_mean (bool) – whether to zero-mean target and preds before computing the metric
kwargs (Any) – Additional keyword arguments, see Advanced metric settings for more info.
Example
>>> import torch
>>> from torchmetrics.audio import SourceAggregatedSignalDistortionRatio
>>> g = torch.manual_seed(1)
>>> preds = torch.randn(2, 8000)  # [..., spk, time]
>>> target = torch.randn(2, 8000)
>>> sasdr = SourceAggregatedSignalDistortionRatio()
>>> sasdr(preds, target)
tensor(-41.6579)
>>> # use with pit
>>> from torchmetrics.audio import PermutationInvariantTraining
>>> from torchmetrics.functional.audio import source_aggregated_signal_distortion_ratio
>>> preds = torch.randn(4, 2, 8000)  # [batch, spk, time]
>>> target = torch.randn(4, 2, 8000)
>>> pit = PermutationInvariantTraining(source_aggregated_signal_distortion_ratio,
...     mode="permutation-wise", eval_func="max")
>>> pit(preds, target)
tensor(-41.2790)
- plot(val=None, ax=None)[source]¶
Plot a single or multiple values from the metric.
- Parameters:
val (Union[Tensor, Sequence[Tensor], None]) – Either a single result from calling metric.forward or metric.compute, or a list of these results. If no value is provided, metric.compute is called automatically and its result is plotted.
ax (Optional[Axes]) – A matplotlib axis object. If provided, the plot is added to that axis.
- Returns:
Figure and Axes object
- Raises:
ModuleNotFoundError – If matplotlib is not installed
>>> # Example plotting a single value
>>> import torch
>>> from torchmetrics.audio import SourceAggregatedSignalDistortionRatio
>>> metric = SourceAggregatedSignalDistortionRatio()
>>> metric.update(torch.rand(2,8000), torch.rand(2,8000))
>>> fig_, ax_ = metric.plot()
>>> # Example plotting multiple values
>>> import torch
>>> from torchmetrics.audio import SourceAggregatedSignalDistortionRatio
>>> metric = SourceAggregatedSignalDistortionRatio()
>>> values = [ ]
>>> for _ in range(10):
...     values.append(metric(torch.rand(2,8000), torch.rand(2,8000)))
>>> fig_, ax_ = metric.plot(values)
Functional Interface¶
- torchmetrics.functional.audio.sdr.source_aggregated_signal_distortion_ratio(preds, target, scale_invariant=True, zero_mean=False)[source]¶
Source-aggregated signal-to-distortion ratio (SA-SDR).
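SA-SDR was proposed to provide a stable gradient for meeting-style source separation, where one-speaker and multiple-speaker scenes coexist.
- Parameters:
preds (Tensor) – float tensor with shape (..., spk, time)
target (Tensor) – float tensor with shape (..., spk, time)
scale_invariant (bool) – if True, scale the targets of different speakers with the same alpha
zero_mean (bool) – whether to zero-mean target and preds before computing the metric
- Returns:
SA-SDR with shape (...)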
Example
>>> import torch
>>> from torchmetrics.functional.audio import source_aggregated_signal_distortion_ratio
>>> g = torch.manual_seed(1)
>>> preds = torch.randn(2, 8000)  # [..., spk, time]
>>> target = torch.randn(2, 8000)
>>> source_aggregated_signal_distortion_ratio(preds, target)
tensor(-41.6579)
>>> # use with permutation_invariant_training
>>> from torchmetrics.functional.audio import permutation_invariant_training
>>> preds = torch.randn(4, 2, 8000)  # [batch, spk, time]
>>> target = torch.randn(4, 2, 8000)
>>> best_metric, best_perm = permutation_invariant_training(preds, target,
...     source_aggregated_signal_distortion_ratio, mode="permutation-wise")
>>> best_metric
tensor([-37.9511, -41.9124, -42.7369, -42.5155])
>>> best_perm
tensor([[1, 0],
        [1, 0],
        [0, 1],
        [1, 0]])
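The source-aggregated idea can be sketched as follows. With scale_invariant=True, a single alpha is shared by all speakers. The signal and distortion energies are then summed over speakers before taking one ratio, rather than averaging per-speaker ratios. This plain-Python sketch handles a single example given as a list of per-speaker signals and follows the parameter descriptions above. `sa_sdr_sketch` is hypothetical and may differ from the torchmetrics implementation in details such as numerical guards.

```python
import math


def _zero_mean(signal):
    """Subtract the mean of a 1-D signal."""
    mean = sum(signal) / len(signal)
    return [x - mean for x in signal]


def sa_sdr_sketch(preds, target, scale_invariant=True, zero_mean=False):
    """Hypothetical SA-SDR in dB; preds and target are lists of per-speaker 1-D signals."""
    if zero_mean:
        preds = [_zero_mean(s) for s in preds]
        target = [_zero_mean(s) for s in target]
    if scale_invariant:
        # One alpha shared by all speakers: dot products and energies summed over speakers
        dot = sum(p * t for ps, ts in zip(preds, target) for p, t in zip(ps, ts))
        energy = sum(t * t for ts in target for t in ts)
        alpha = dot / energy
        target = [[alpha * t for t in ts] for ts in target]
    # Aggregate signal and distortion energies over all speakers, then take one ratio
    signal_energy = sum(t * t for ts in target for t in ts)
    distortion_energy = sum((t - p) ** 2 for ps, ts in zip(preds, target) for p, t in zip(ps, ts))
    return 10 * math.log10(signal_energy / distortion_energy)


print(round(sa_sdr_sketch([[2.5, 0.0, 2.0, 8.0]], [[3.0, -0.5, 2.0, 7.0]]), 4))
```

With a single speaker the aggregation is trivial, and the value coincides with SI-SDR (about 18.4030 for the four-sample example).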
Speech-to-Reverberation Modulation Energy Ratio (SRMR)¶
Module Interface¶
- class torchmetrics.audio.srmr.SpeechReverberationModulationEnergyRatio(fs, n_cochlear_filters=23, low_freq=125, min_cf=4, max_cf=None, norm=False, fast=False, **kwargs)[source]¶
Calculate Speech-to-Reverberation Modulation Energy Ratio (SRMR).
SRMR is a non-intrusive metric for speech quality and intelligibility based on a modulation spectral representation of the speech signal. This code is translated from SRMRToolbox and SRMRpy.
As input to forward and update the metric accepts the following input
preds (Tensor): float tensor with shape (...,time)
As output of forward and compute the metric returns the following output
srmr (Tensor): float scalar tensor
Note
Using this metric requires you to have gammatone and torchaudio installed. Either install as pip install torchmetrics[audio] or pip install torchaudio and pip install git+https://github.com/detly/gammatone.
Note
This implementation is experimental and might not be consistent with the MATLAB implementation SRMRToolbox, especially the fast implementation. The slow versions, a) fast=False, norm=False, max_cf=128 and b) fast=False, norm=True, max_cf=30, show a relatively small inconsistency.
- Parameters:
fs (int) – the sampling rate
n_cochlear_filters (int) – Number of filters in the acoustic filterbank
low_freq (float) – determines the frequency cutoff for the corresponding gammatone filterbank.
min_cf (float) – Center frequency in Hz of the first modulation filter.
max_cf (Optional[float]) – Center frequency in Hz of the last modulation filter. If None is given, then 30 Hz will be used for norm==False, otherwise 128 Hz will be used.
norm (bool) – Use modulation spectrum energy normalization
fast (bool) – Use the faster version based on the gammatonegram. Note: this argument is inherited from SRMRpy. As the translated code is based on PyTorch, setting fast=True may slow down the calculation of this metric on GPU.
- Raises:
ModuleNotFoundError – If gammatone or torchaudio package is not installed
Example
>>> import torch
>>> from torchmetrics.audio import SpeechReverberationModulationEnergyRatio
>>> g = torch.manual_seed(1)
>>> preds = torch.randn(8000)
>>> srmr = SpeechReverberationModulationEnergyRatio(8000)
>>> srmr(preds)
tensor(0.3354)
- plot(val=None, ax=None)[source]¶
Plot a single or multiple values from the metric.
- Parameters:
val (Union[Tensor, Sequence[Tensor], None]) – Either a single result from calling metric.forward or metric.compute, or a list of these results. If no value is provided, metric.compute is called automatically and its result is plotted.
ax (Optional[Axes]) – A matplotlib axis object. If provided, the plot is added to that axis.
- Returns:
Figure and Axes object
- Raises:
ModuleNotFoundError – If matplotlib is not installed
>>> # Example plotting a single value
>>> import torch
>>> from torchmetrics.audio import SpeechReverberationModulationEnergyRatio
>>> metric = SpeechReverberationModulationEnergyRatio(8000)
>>> metric.update(torch.rand(8000))
>>> fig_, ax_ = metric.plot()
>>> # Example plotting multiple values
>>> import torch
>>> from torchmetrics.audio import SpeechReverberationModulationEnergyRatio
>>> metric = SpeechReverberationModulationEnergyRatio(8000)
>>> values = [ ]
>>> for _ in range(10):
...     values.append(metric(torch.rand(8000)))
>>> fig_, ax_ = metric.plot(values)
Functional Interface¶
- torchmetrics.functional.audio.srmr.speech_reverberation_modulation_energy_ratio(preds, fs, n_cochlear_filters=23, low_freq=125, min_cf=4, max_cf=None, norm=False, fast=False)[source]¶
Calculate Speech-to-Reverberation Modulation Energy Ratio (SRMR).
SRMR is a non-intrusive metric for speech quality and intelligibility based on a modulation spectral representation of the speech signal. This code is translated from SRMRToolbox and SRMRpy.
- Parameters:
preds (Tensor) – shape (..., time)
fs (int) – the sampling rate
n_cochlear_filters (int) – Number of filters in the acoustic filterbank
low_freq (float) – determines the frequency cutoff for the corresponding gammatone filterbank.
min_cf (float) – Center frequency in Hz of the first modulation filter.
max_cf (Optional[float]) – Center frequency in Hz of the last modulation filter. If None is given, then 30 Hz will be used for norm==False, otherwise 128 Hz will be used.
norm (bool) – Use modulation spectrum energy normalization
fast (bool) – Use the faster version based on the gammatonegram. Note: this argument is inherited from SRMRpy. As the translated code is based on PyTorch, setting fast=True may slow down the calculation of this metric on GPU.
Note
Using this metric requires you to have gammatone and torchaudio installed. Either install as pip install torchmetrics[audio] or pip install torchaudio and pip install git+https://github.com/detly/gammatone.
Note
This implementation is experimental and might not be consistent with the MATLAB implementation SRMRToolbox, especially the fast implementation. The slow versions, a) fast=False, norm=False, max_cf=128 and b) fast=False, norm=True, max_cf=30, show a relatively small inconsistency.
- Returns:
Scalar tensor with srmr value with shape (...)
- Raises:
ModuleNotFoundError – If gammatone or torchaudio package is not installed
Example
>>> import torch
>>> from torchmetrics.functional.audio import speech_reverberation_modulation_energy_ratio
>>> g = torch.manual_seed(1)
>>> preds = torch.randn(8000)
>>> speech_reverberation_modulation_energy_ratio(preds, 8000)
tensor([0.3354], dtype=torch.float64)
Accuracy¶
Module Interface¶
- class torchmetrics.Accuracy(**kwargs)[source]¶
Compute Accuracy.
\[\text{Accuracy} = \frac{1}{N}\sum_i^N 1(y_i = \hat{y}_i)\]Where \(y\) is a tensor of target values, and \(\hat{y}\) is a tensor of predictions.
This module is a simple wrapper to get the task-specific versions of this metric, which is done by setting the task argument to either 'binary', 'multiclass' or 'multilabel'. See the documentation of BinaryAccuracy, MulticlassAccuracy and MultilabelAccuracy for the specific details of how each argument influences the metric, and for examples.
- Legacy Example:
>>> from torch import tensor
>>> from torchmetrics import Accuracy
>>> target = tensor([0, 1, 2, 3])
>>> preds = tensor([0, 2, 1, 3])
>>> accuracy = Accuracy(task="multiclass", num_classes=4)
>>> accuracy(preds, target)
tensor(0.5000)
>>> target = tensor([0, 1, 2])
>>> preds = tensor([[0.1, 0.9, 0], [0.3, 0.1, 0.6], [0.2, 0.5, 0.3]])
>>> accuracy = Accuracy(task="multiclass", num_classes=3, top_k=2)
>>> accuracy(preds, target)
tensor(0.6667)
BinaryAccuracy¶
- class torchmetrics.classification.BinaryAccuracy(threshold=0.5, multidim_average='global', ignore_index=None, validate_args=True, **kwargs)[source]¶
Compute Accuracy for binary tasks.
\[\text{Accuracy} = \frac{1}{N}\sum_i^N 1(y_i = \hat{y}_i)\]Where \(y\) is a tensor of target values, and \(\hat{y}\) is a tensor of predictions.
As input to forward and update the metric accepts the following input:
preds (Tensor): An int or float tensor of shape (N, ...). If preds is a floating point tensor with values outside the [0,1] range, the input is considered to be logits and sigmoid is automatically applied per element. Additionally, it is converted to an int tensor by thresholding with the value in threshold.
target (Tensor): An int tensor of shape (N, ...)
As output to forward and compute the metric returns the following output:
ba (Tensor): If multidim_average is set to global, the metric returns a scalar value. If multidim_average is set to samplewise, the metric returns an (N,) vector consisting of a scalar value per sample.
Additional dimensions ... will be flattened into the batch dimension.
- Parameters:
threshold (float) – Threshold for transforming probability to binary {0,1} predictions
multidim_average (Literal['global', 'samplewise']) – Defines how additional dimensions ... should be handled. Should be one of the following:
global: Additional dimensions are flattened along the batch dimension
samplewise: Statistic will be calculated independently for each sample on the N axis. The statistics in this case are calculated over the additional dimensions.
ignore_index (Optional[int]) – Specifies a target value that is ignored and does not contribute to the metric calculation
validate_args (bool) – bool indicating if input arguments and tensors should be validated for correctness. Set to False for faster computations.
- Example (preds is int tensor):
>>> from torch import tensor
>>> from torchmetrics.classification import BinaryAccuracy
>>> target = tensor([0, 1, 0, 1, 0, 1])
>>> preds = tensor([0, 0, 1, 1, 0, 1])
>>> metric = BinaryAccuracy()
>>> metric(preds, target)
tensor(0.6667)
- Example (preds is float tensor):
>>> from torch import tensor
>>> from torchmetrics.classification import BinaryAccuracy
>>> target = tensor([0, 1, 0, 1, 0, 1])
>>> preds = tensor([0.11, 0.22, 0.84, 0.73, 0.33, 0.92])
>>> metric = BinaryAccuracy()
>>> metric(preds, target)
tensor(0.6667)
- Example (multidim tensors):
>>> from torch import tensor
>>> from torchmetrics.classification import BinaryAccuracy
>>> target = tensor([[[0, 1], [1, 0], [0, 1]], [[1, 1], [0, 0], [1, 0]]])
>>> preds = tensor([[[0.59, 0.91], [0.91, 0.99], [0.63, 0.04]],
...                 [[0.38, 0.04], [0.86, 0.780], [0.45, 0.37]]])
>>> metric = BinaryAccuracy(multidim_average='samplewise')
>>> metric(preds, target)
tensor([0.3333, 0.1667])
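The input handling described above can be summarized in a small plain-Python sketch:
1. If any float prediction falls outside [0, 1], all predictions are treated as logits and passed through a sigmoid.
2. Predictions are thresholded into 0/1 labels.
3. Matches are counted either over everything ("global") or per sample ("samplewise").

`binary_accuracy_sketch` and `_flatten` are hypothetical helpers working on nested lists. The strict `>` comparison with the threshold is this sketch's choice, and ignore_index is left out. The sketch reproduces the three examples above.

```python
import math


def _flatten(x):
    """Flatten arbitrarily nested lists into a flat list of numbers."""
    if isinstance(x, (list, tuple)):
        return [v for item in x for v in _flatten(item)]
    return [x]


def binary_accuracy_sketch(preds, target, threshold=0.5, multidim_average="global"):
    """Hypothetical binary accuracy; preds/target are nested lists with leading dimension N."""
    # Float inputs outside [0, 1] mean the whole input is treated as logits
    is_logits = any(isinstance(v, float) and not 0.0 <= v <= 1.0 for v in _flatten(preds))

    def to_label(v):
        if isinstance(v, int):
            return v  # already a hard 0/1 prediction
        prob = 1.0 / (1.0 + math.exp(-v)) if is_logits else v
        return int(prob > threshold)

    per_sample = []
    for sample_preds, sample_target in zip(preds, target):
        p = [to_label(v) for v in _flatten(sample_preds)]
        t = _flatten(sample_target)
        per_sample.append((sum(int(a == b) for a, b in zip(p, t)), len(t)))
    if multidim_average == "samplewise":
        # One accuracy per sample, computed over that sample's extra dimensions
        return [hits / total for hits, total in per_sample]
    # "global": extra dimensions are flattened into the batch dimension
    return sum(h for h, _ in per_sample) / sum(n for _, n in per_sample)


print(binary_accuracy_sketch([0.11, 0.22, 0.84, 0.73, 0.33, 0.92], [0, 1, 0, 1, 0, 1]))
```

Because the logit check looks at the whole input, a single out-of-range value switches every element to the sigmoid path.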
- plot(val=None, ax=None)[source]¶
Plot a single or multiple values from the metric.
- Parameters:
val (Union[Tensor, Sequence[Tensor], None]) – Either a single result from calling metric.forward or metric.compute, or a list of these results. If no value is provided, metric.compute is called automatically and its result is plotted.
ax (Optional[Axes]) – A matplotlib axis object. If provided, the plot is added to that axis.
- Returns:
Figure object and Axes object
- Raises:
ModuleNotFoundError – If matplotlib is not installed
>>> from torch import rand, randint
>>> # Example plotting a single value
>>> from torchmetrics.classification import BinaryAccuracy
>>> metric = BinaryAccuracy()
>>> metric.update(rand(10), randint(2,(10,)))
>>> fig_, ax_ = metric.plot()
>>> from torch import rand, randint
>>> # Example plotting multiple values
>>> from torchmetrics.classification import BinaryAccuracy
>>> metric = BinaryAccuracy()
>>> values = [ ]
>>> for _ in range(10):
...     values.append(metric(rand(10), randint(2,(10,))))
>>> fig_, ax_ = metric.plot(values)
MulticlassAccuracy¶
- class torchmetrics.classification.MulticlassAccuracy(num_classes, top_k=1, average='macro', multidim_average='global', ignore_index=None, validate_args=True, **kwargs)[source]¶
Compute Accuracy for multiclass tasks.
\[\text{Accuracy} = \frac{1}{N}\sum_{i=1}^{N} 1(y_i = \hat{y}_i)\]
where \(y\) is a tensor of target values, and \(\hat{y}\) is a tensor of predictions.
As input to forward and update the metric accepts the following input:
preds (Tensor): An int tensor of shape (N, ...) or a float tensor of shape (N, C, ...). If preds is a floating point tensor, torch.argmax is applied along the C dimension to convert probabilities/logits into an int tensor.
target (Tensor): An int tensor of shape (N, ...)
As output to forward and compute the metric returns the following output:
mca (Tensor): A tensor with the accuracy score whose returned shape depends on the average and multidim_average arguments:
If multidim_average is set to global:
  If average='micro'/'macro'/'weighted', the output will be a scalar tensor
  If average=None/'none', the shape will be (C,)
If multidim_average is set to samplewise:
  If average='micro'/'macro'/'weighted', the shape will be (N,)
  If average=None/'none', the shape will be (N, C)
- Parameters:
num_classes (int) – Integer specifying the number of classes
average (Optional[Literal['micro', 'macro', 'weighted', 'none']]) – Defines the reduction that is applied over labels. Should be one of the following:
micro: Sum statistics over all labels
macro: Calculate statistics for each label and average them
weighted: Calculate statistics for each label and compute a weighted average using their support
"none" or None: Calculate statistics for each label and apply no reduction
top_k (int) – Number of highest probability or logit score predictions considered to find the correct label. Only works when preds contain probabilities/logits.
multidim_average (Literal['global', 'samplewise']) – Defines how additional dimensions ... should be handled. Should be one of the following:
global: Additional dimensions are flattened along the batch dimension
samplewise: Statistic will be calculated independently for each sample on the N axis. The statistics in this case are calculated over the additional dimensions.
ignore_index (Optional[int]) – Specifies a target value that is ignored and does not contribute to the metric calculation
validate_args (bool) – bool indicating if input arguments and tensors should be validated for correctness. Set to False for faster computations.
- Example (preds is int tensor):
>>> from torch import tensor
>>> from torchmetrics.classification import MulticlassAccuracy
>>> target = tensor([2, 1, 0, 0])
>>> preds = tensor([2, 1, 0, 1])
>>> metric = MulticlassAccuracy(num_classes=3)
>>> metric(preds, target)
tensor(0.8333)
>>> mca = MulticlassAccuracy(num_classes=3, average=None)
>>> mca(preds, target)
tensor([0.5000, 1.0000, 1.0000])
- Example (preds is float tensor):
>>> from torch import tensor
>>> from torchmetrics.classification import MulticlassAccuracy
>>> target = tensor([2, 1, 0, 0])
>>> preds = tensor([[0.16, 0.26, 0.58],
...                 [0.22, 0.61, 0.17],
...                 [0.71, 0.09, 0.20],
...                 [0.05, 0.82, 0.13]])
>>> metric = MulticlassAccuracy(num_classes=3)
>>> metric(preds, target)
tensor(0.8333)
>>> mca = MulticlassAccuracy(num_classes=3, average=None)
>>> mca(preds, target)
tensor([0.5000, 1.0000, 1.0000])
- Example (multidim tensors):
>>> from torch import tensor
>>> from torchmetrics.classification import MulticlassAccuracy
>>> target = tensor([[[0, 1], [2, 1], [0, 2]], [[1, 1], [2, 0], [1, 2]]])
>>> preds = tensor([[[0, 2], [2, 0], [0, 1]], [[2, 2], [2, 1], [1, 0]]])
>>> metric = MulticlassAccuracy(num_classes=3, multidim_average='samplewise')
>>> metric(preds, target)
tensor([0.5000, 0.2778])
>>> mca = MulticlassAccuracy(num_classes=3, multidim_average='samplewise', average=None)
>>> mca(preds, target)
tensor([[1.0000, 0.0000, 0.5000],
        [0.0000, 0.3333, 0.5000]])
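The four average options can be illustrated with a plain-Python sketch that counts, per class, how many targets of that class were predicted correctly (tp) and how many targets the class has (support). multiclass_accuracy_sketch is a hypothetical helper, not the library's implementation: it assumes 1-D inputs (no extra dimensions, no ignore_index), reduces float rows with an argmax, and its macro average simply divides by num_classes; how TorchMetrics treats classes that never appear is not modeled here.

```python
def multiclass_accuracy_sketch(preds, target, num_classes, average="macro"):
    """Hypothetical helper: per-class accuracy and its reductions (sketch only)."""
    # Float preds of shape (N, C) are reduced to labels with an argmax over C.
    if preds and isinstance(preds[0], (list, tuple)):
        preds = [max(range(num_classes), key=lambda c: row[c]) for row in preds]

    tp = [0] * num_classes       # correctly predicted targets per class
    support = [0] * num_classes  # number of targets per class
    for p, t in zip(preds, target):
        support[t] += 1
        tp[t] += int(p == t)

    if average == "micro":
        return sum(tp) / sum(support)
    per_class = [tp[c] / support[c] if support[c] else 0.0 for c in range(num_classes)]
    if average in (None, "none"):
        return per_class
    if average == "weighted":
        return sum(s * a for s, a in zip(support, per_class)) / sum(support)
    return sum(per_class) / num_classes  # "macro"


print(round(multiclass_accuracy_sketch([2, 1, 0, 1], [2, 1, 0, 0], num_classes=3), 4))
# 0.8333
```

For single-label multiclass accuracy, 'weighted' and 'micro' coincide in this sketch: weighting each class's tp / support by its support just gives sum(tp) / sum(support), which is 0.75 for the first example above.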
- plot(val=None, ax=None)[source]¶
Plot a single or multiple values from the metric.
- Parameters:
val (Union[Tensor, Sequence[Tensor], None]) – Either a single result from calling metric.forward or metric.compute or a list of these results. If no value is provided, metric.compute is called automatically and that result is plotted.
ax (Optional[Axes]) – A matplotlib axis object. If provided, the plot is added to that axis.
- Return type:
- Returns:
Figure object and Axes object
- Raises:
ModuleNotFoundError – If matplotlib is not installed
>>> from torch import randint
>>> # Example plotting a single value per class
>>> from torchmetrics.classification import MulticlassAccuracy
>>> metric = MulticlassAccuracy(num_classes=3, average=None)
>>> metric.update(randint(3, (20,)), randint(3, (20,)))
>>> fig_, ax_ = metric.plot()

>>> from torch import randint
>>> # Example plotting multiple values per class
>>> from torchmetrics.classification import MulticlassAccuracy
>>> metric = MulticlassAccuracy(num_classes=3, average=None)
>>> values = []
>>> for _ in range(20):
...     values.append(metric(randint(3, (20,)), randint(3, (20,))))
>>> fig_, ax_ = metric.plot(values)
MultilabelAccuracy¶
- class torchmetrics.classification.MultilabelAccuracy(num_labels, threshold=0.5, average='macro', multidim_average='global', ignore_index=None, validate_args=True, **kwargs)[source]¶
Compute Accuracy for multilabel tasks.
\[\text{Accuracy} = \frac{1}{N}\sum_{i=1}^{N} 1(y_i = \hat{y}_i)\]
where \(y\) is a tensor of target values, and \(\hat{y}\) is a tensor of predictions.
As input to forward and update the metric accepts the following input:
preds (Tensor): An int or float tensor of shape (N, C, ...). If preds is a floating point tensor with values outside the [0, 1] range, the input is considered to be logits and a sigmoid is applied per element. Additionally, it is converted to an int tensor by thresholding with the value in threshold.
target (Tensor): An int tensor of shape (N, C, ...)
As output to forward and compute the metric returns the following output:
mla (Tensor): A tensor with the accuracy score whose returned shape depends on the average and multidim_average arguments:
If multidim_average is set to global:
  If average='micro'/'macro'/'weighted', the output will be a scalar tensor
  If average=None/'none', the shape will be (C,)
If multidim_average is set to samplewise:
  If average='micro'/'macro'/'weighted', the shape will be (N,)
  If average=None/'none', the shape will be (N, C)
- Parameters:
num_labels (int) – Integer specifying the number of labels
threshold (float) – Threshold for transforming probability to binary (0,1) predictions
average (Optional[Literal['micro', 'macro', 'weighted', 'none']]) – Defines the reduction that is applied over labels. Should be one of the following:
micro: Sum statistics over all labels
macro: Calculate statistics for each label and average them
weighted: Calculate statistics for each label and compute a weighted average using their support
"none" or None: Calculate statistics for each label and apply no reduction
multidim_average (Literal['global', 'samplewise']) – Defines how additional dimensions ... should be handled. Should be one of the following:
global: Additional dimensions are flattened along the batch dimension
samplewise: Statistic will be calculated independently for each sample on the N axis. The statistics in this case are calculated over the additional dimensions.
ignore_index (Optional[int]) – Specifies a target value that is ignored and does not contribute to the metric calculation
validate_args (bool) – bool indicating if input arguments and tensors should be validated for correctness. Set to False for faster computations.
- Example (preds is int tensor):
>>> from torch import tensor
>>> from torchmetrics.classification import MultilabelAccuracy
>>> target = tensor([[0, 1, 0], [1, 0, 1]])
>>> preds = tensor([[0, 0, 1], [1, 0, 1]])
>>> metric = MultilabelAccuracy(num_labels=3)
>>> metric(preds, target)
tensor(0.6667)
>>> mla = MultilabelAccuracy(num_labels=3, average=None)
>>> mla(preds, target)
tensor([1.0000, 0.5000, 0.5000])
- Example (preds is float tensor):
>>> from torch import tensor
>>> from torchmetrics.classification import MultilabelAccuracy
>>> target = tensor([[0, 1, 0], [1, 0, 1]])
>>> preds = tensor([[0.11, 0.22, 0.84], [0.73, 0.33, 0.92]])
>>> metric = MultilabelAccuracy(num_labels=3)
>>> metric(preds, target)
tensor(0.6667)
>>> mla = MultilabelAccuracy(num_labels=3, average=None)
>>> mla(preds, target)
tensor([1.0000, 0.5000, 0.5000])
- Example (multidim tensors):
>>> from torch import tensor
>>> from torchmetrics.classification import MultilabelAccuracy
>>> target = tensor([[[0, 1], [1, 0], [0, 1]], [[1, 1], [0, 0], [1, 0]]])
>>> preds = tensor(
...     [
...         [[0.59, 0.91], [0.91, 0.99], [0.63, 0.04]],
...         [[0.38, 0.04], [0.86, 0.78], [0.45, 0.37]],
...     ]
... )
>>> mla = MultilabelAccuracy(num_labels=3, multidim_average='samplewise')
>>> mla(preds, target)
tensor([0.3333, 0.1667])
>>> mla = MultilabelAccuracy(num_labels=3, multidim_average='samplewise', average=None)
>>> mla(preds, target)
tensor([[0.5000, 0.5000, 0.0000],
        [0.0000, 0.0000, 0.5000]])
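In the multilabel case each label is scored like an independent binary problem: its accuracy is the fraction of positions where the thresholded prediction equals the target (true positives plus true negatives over all positions). The sketch below is a hypothetical plain-Python helper, not the library code; it assumes inputs of shape (N, C) or (N, C, X) with at most one extra dimension, assumes preds are already probabilities or {0,1} ints (no sigmoid step), and uses the number of positive targets as support for 'weighted'.

```python
def _as_list(x):
    """A label entry is either a scalar or a list over one extra dimension."""
    return list(x) if isinstance(x, (list, tuple)) else [x]


def _label_counts(preds, target, num_labels, threshold):
    """Per label: matching positions, total positions and positive targets (support)."""
    correct, total, support = [0] * num_labels, [0] * num_labels, [0] * num_labels
    for p_sample, t_sample in zip(preds, target):
        for c in range(num_labels):
            for p, t in zip(_as_list(p_sample[c]), _as_list(t_sample[c])):
                correct[c] += int((1 if p > threshold else 0) == t)
                total[c] += 1
                support[c] += t
    return correct, total, support


def _reduce(correct, total, support, average):
    if average == "micro":
        return sum(correct) / sum(total)
    per_label = [c / n for c, n in zip(correct, total)]
    if average in (None, "none"):
        return per_label
    if average == "weighted":
        # Weight each label's accuracy by its number of positive targets.
        return sum(s * a for s, a in zip(support, per_label)) / sum(support) if sum(support) else 0.0
    return sum(per_label) / len(per_label)  # "macro"


def multilabel_accuracy_sketch(preds, target, num_labels, threshold=0.5,
                               average="macro", multidim_average="global"):
    """Hypothetical helper: multilabel accuracy on (N, C) or (N, C, X) nested lists."""
    if multidim_average == "global":
        return _reduce(*_label_counts(preds, target, num_labels, threshold), average)
    # "samplewise": statistics are computed for each sample on the N axis separately.
    return [_reduce(*_label_counts([p], [t], num_labels, threshold), average)
            for p, t in zip(preds, target)]


print(multilabel_accuracy_sketch([[0, 0, 1], [1, 0, 1]], [[0, 1, 0], [1, 0, 1]], 3, average=None))
# [1.0, 0.5, 0.5]
```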
- plot(val=None, ax=None)[source]¶
Plot a single or multiple values from the metric.
- Parameters:
val (Union[Tensor, Sequence[Tensor], None]) – Either a single result from calling metric.forward or metric.compute or a list of these results. If no value is provided, metric.compute is called automatically and that result is plotted.
ax (Optional[Axes]) – A matplotlib axis object. If provided, the plot is added to that axis.
- Return type:
- Returns:
Figure and Axes object
- Raises:
ModuleNotFoundError – If matplotlib is not installed
>>> from torch import randint
>>> # Example plotting a single value
>>> from torchmetrics.classification import MultilabelAccuracy
>>> metric = MultilabelAccuracy(num_labels=3)
>>> metric.update(randint(2, (20, 3)), randint(2, (20, 3)))
>>> fig_, ax_ = metric.plot()

>>> from torch import randint
>>> # Example plotting multiple values
>>> from torchmetrics.classification import MultilabelAccuracy
>>> metric = MultilabelAccuracy(num_labels=3)
>>> values = []
>>> for _ in range(10):
...     values.append(metric(randint(2, (20, 3)), randint(2, (20, 3))))
>>> fig_, ax_ = metric.plot(values)
Functional Interface¶
- torchmetrics.functional.classification.accuracy(preds, target, task, threshold=0.5, num_classes=None, num_labels=None, average='micro', multidim_average='global', top_k=1, ignore_index=None, validate_args=True)[source]¶
Compute Accuracy.
\[\text{Accuracy} = \frac{1}{N}\sum_{i=1}^{N} 1(y_i = \hat{y}_i)\]
where \(y\) is a tensor of target values, and \(\hat{y}\) is a tensor of predictions.
This function is a simple wrapper to get the task specific versions of this metric, which is done by setting the task argument to either 'binary', 'multiclass' or 'multilabel'. See the documentation of binary_accuracy(), multiclass_accuracy() and multilabel_accuracy() for the specific details of each argument and for examples.
- Return type:
Tensor
- Legacy Example:
>>> from torch import tensor
>>> from torchmetrics.functional.classification import accuracy
>>> target = tensor([0, 1, 2, 3])
>>> preds = tensor([0, 2, 1, 3])
>>> accuracy(preds, target, task="multiclass", num_classes=4)
tensor(0.5000)

>>> target = tensor([0, 1, 2])
>>> preds = tensor([[0.1, 0.9, 0], [0.3, 0.1, 0.6], [0.2, 0.5, 0.3]])
>>> accuracy(preds, target, task="multiclass", num_classes=3, top_k=2)
tensor(0.6667)
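The top_k=2 result in the second legacy example can be reproduced by hand: a sample counts as correct when its target is among the k highest-scoring classes. topk_accuracy_sketch below is a hypothetical plain-Python helper computing that fraction over all samples, which corresponds to the micro average that accuracy uses by default; resolving ties in favor of the lower class index is an assumption of the sketch.

```python
def topk_accuracy_sketch(probs, target, k=1):
    """Hypothetical helper: fraction of samples whose target is in the top-k scores (micro)."""
    hits = 0
    for row, t in zip(probs, target):
        # sorted() is stable even with reverse=True, so equal scores keep class-index order.
        top = sorted(range(len(row)), key=lambda c: row[c], reverse=True)[:k]
        hits += int(t in top)
    return hits / len(target)


probs = [[0.1, 0.9, 0.0], [0.3, 0.1, 0.6], [0.2, 0.5, 0.3]]
print(round(topk_accuracy_sketch(probs, [0, 1, 2], k=2), 4))
# 0.6667
```

With k=1 the same data scores 0.0, since no argmax matches its target; raising k to 2 lets samples 0 and 2 count as correct.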
binary_accuracy¶
- torchmetrics.functional.classification.binary_accuracy(preds, target, threshold=0.5, multidim_average='global', ignore_index=None, validate_args=True)[source]¶
Compute Accuracy for binary tasks.
\[\text{Accuracy} = \frac{1}{N}\sum_{i=1}^{N} 1(y_i = \hat{y}_i)\]
where \(y\) is a tensor of target values, and \(\hat{y}\) is a tensor of predictions.
Accepts the following input tensors:
preds (int or float tensor): (N, ...). If preds is a floating point tensor with values outside the [0, 1] range, the input is considered to be logits and a sigmoid is applied per element. Additionally, it is converted to an int tensor by thresholding with the value in threshold.
target (int tensor): (N, ...)
- Parameters:
preds (Tensor) – Tensor with predictions
target (Tensor) – Tensor with true labels
threshold (float) – Threshold for transforming probability to binary {0,1} predictions
multidim_average (Literal['global', 'samplewise']) – Defines how additional dimensions ... should be handled. Should be one of the following:
global: Additional dimensions are flattened along the batch dimension
samplewise: Statistic will be calculated independently for each sample on the N axis. The statistics in this case are calculated over the additional dimensions.
ignore_index (Optional[int]) – Specifies a target value that is ignored and does not contribute to the metric calculation
validate_args (bool) – bool indicating if input arguments and tensors should be validated for correctness. Set to False for faster computations.
- Return type:
- Returns:
If multidim_average is set to global, the metric returns a scalar value. If multidim_average is set to samplewise, the metric returns an (N,) vector consisting of a scalar value per sample.
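The ignore_index parameter can be pictured as a mask: positions whose target equals the ignored value are removed before counting, so they affect neither the numerator nor the denominator. A plain-Python sketch of that idea for the global binary case (binary_accuracy_ignore_sketch is hypothetical; returning 0.0 when every position is ignored is an assumption of the sketch, not documented library behavior):

```python
def binary_accuracy_ignore_sketch(preds, target, threshold=0.5, ignore_index=None):
    """Hypothetical helper: global binary accuracy that skips ignored targets."""
    # Keep only the positions whose target is not the ignored value.
    pairs = [(p, t) for p, t in zip(preds, target) if ignore_index is None or t != ignore_index]
    if not pairs:
        return 0.0  # assumption: nothing left to score
    correct = sum(int((1 if p > threshold else 0) == t) for p, t in pairs)
    return correct / len(pairs)


preds = [0.11, 0.22, 0.84, 0.73, 0.33, 0.92]
print(binary_accuracy_ignore_sketch(preds, [0, -1, 0, 1, 0, 1], ignore_index=-1))
# 0.8
```

Marking the second target as -1 drops one misclassified position, so the score moves from 4/6 to 4/5.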
- Example (preds is int tensor):
>>> from torch import tensor
>>> from torchmetrics.functional.classification import binary_accuracy
>>> target = tensor([0, 1, 0, 1, 0, 1])
>>> preds = tensor([0, 0, 1, 1, 0, 1])
>>> binary_accuracy(preds, target)
tensor(0.6667)
- Example (preds is float tensor):
>>> from torch import tensor
>>> from torchmetrics.functional.classification import binary_accuracy
>>> target = tensor([0, 1, 0, 1, 0, 1])
>>> preds = tensor([0.11, 0.22, 0.84, 0.73, 0.33, 0.92])
>>> binary_accuracy(preds, target)
tensor(0.6667)
- Example (multidim tensors):
>>> from torch import tensor
>>> from torchmetrics.functional.classification import binary_accuracy
>>> target = tensor([[[0, 1], [1, 0], [0, 1]], [[1, 1], [0, 0], [1, 0]]])
>>> preds = tensor([[[0.59, 0.91], [0.91, 0.99], [0.63, 0.04]],
...                 [[0.38, 0.04], [0.86, 0.78], [0.45, 0.37]]])
>>> binary_accuracy(preds, target, multidim_average='samplewise')
tensor([0.3333, 0.1667])
multiclass_accuracy¶
- torchmetrics.functional.classification.multiclass_accuracy(preds, target, num_classes, average='macro', top_k=1, multidim_average='global', ignore_index=None, validate_args=True)[source]¶
Compute Accuracy for multiclass tasks.
\[\text{Accuracy} = \frac{1}{N}\sum_{i=1}^{N} 1(y_i = \hat{y}_i)\]
where \(y\) is a tensor of target values, and \(\hat{y}\) is a tensor of predictions.
Accepts the following input tensors:
preds: (N, ...) (int tensor) or (N, C, ...) (float tensor). If preds is a floating point tensor, torch.argmax is applied along the C dimension to convert probabilities/logits into an int tensor.
target (int tensor): (N, ...)
- Parameters:
preds (Tensor) – Tensor with predictions
target (Tensor) – Tensor with true labels
num_classes (int) – Integer specifying the number of classes
average (Optional[Literal['micro', 'macro', 'weighted', 'none']]) – Defines the reduction that is applied over labels. Should be one of the following:
micro: Sum statistics over all labels
macro: Calculate statistics for each label and average them
weighted: Calculate statistics for each label and compute a weighted average using their support
"none" or None: Calculate statistics for each label and apply no reduction
top_k (int) – Number of highest probability or logit score predictions considered to find the correct label. Only works when preds contain probabilities/logits.
multidim_average (Literal['global', 'samplewise']) – Defines how additional dimensions ... should be handled. Should be one of the following:
global: Additional dimensions are flattened along the batch dimension
samplewise: Statistic will be calculated independently for each sample on the N axis. The statistics in this case are calculated over the additional dimensions.
ignore_index (Optional[int]) – Specifies a target value that is ignored and does not contribute to the metric calculation
validate_args (bool) – bool indicating if input arguments and tensors should be validated for correctness. Set to False for faster computations.
- Returns:
If multidim_average is set to global:
  If average='micro'/'macro'/'weighted', the output will be a scalar tensor
  If average=None/'none', the shape will be (C,)
If multidim_average is set to samplewise:
  If average='micro'/'macro'/'weighted', the shape will be (N,)
  If average=None/'none', the shape will be (N, C)
- Return type:
The returned shape depends on the average and multidim_average arguments
- Example (preds is int tensor):
>>> from torch import tensor
>>> from torchmetrics.functional.classification import multiclass_accuracy
>>> target = tensor([2, 1, 0, 0])
>>> preds = tensor([2, 1, 0, 1])
>>> multiclass_accuracy(preds, target, num_classes=3)
tensor(0.8333)
>>> multiclass_accuracy(preds, target, num_classes=3, average=None)
tensor([0.5000, 1.0000, 1.0000])
- Example (preds is float tensor):
>>> from torch import tensor
>>> from torchmetrics.functional.classification import multiclass_accuracy
>>> target = tensor([2, 1, 0, 0])
>>> preds = tensor([[0.16, 0.26, 0.58],
...                 [0.22, 0.61, 0.17],
...                 [0.71, 0.09, 0.20],
...                 [0.05, 0.82, 0.13]])
>>> multiclass_accuracy(preds, target, num_classes=3)
tensor(0.8333)
>>> multiclass_accuracy(preds, target, num_classes=3, average=None)
tensor([0.5000, 1.0000, 1.0000])
- Example (multidim tensors):
>>> from torch import tensor
>>> from torchmetrics.functional.classification import multiclass_accuracy
>>> target = tensor([[[0, 1], [2, 1], [0, 2]], [[1, 1], [2, 0], [1, 2]]])
>>> preds = tensor([[[0, 2], [2, 0], [0, 1]], [[2, 2], [2, 1], [1, 0]]])
>>> multiclass_accuracy(preds, target, num_classes=3, multidim_average='samplewise')
tensor([0.5000, 0.2778])
>>> multiclass_accuracy(preds, target, num_classes=3, multidim_average='samplewise', average=None)
tensor([[1.0000, 0.0000, 0.5000],
        [0.0000, 0.3333, 0.5000]])
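With multidim_average='samplewise', each sample on the N axis is scored on its own over all positions in its extra dimensions, and average=None then keeps one value per class, which yields the (N, C) shape listed under Returns. The hypothetical plain-Python helper below reproduces the multidim example; per-class accuracy is taken as tp / support and macro as a plain mean over num_classes, both simplifications rather than the library code.

```python
def _flatten(x):
    """Recursively flatten nested lists, collapsing the extra ``...`` dimensions."""
    if isinstance(x, (list, tuple)):
        return [v for item in x for v in _flatten(item)]
    return [x]


def _per_class_accuracy(preds, target, num_classes):
    tp = [0] * num_classes
    support = [0] * num_classes
    for p, t in zip(preds, target):
        support[t] += 1
        tp[t] += int(p == t)
    return [tp[c] / support[c] if support[c] else 0.0 for c in range(num_classes)]


def multiclass_samplewise_sketch(preds, target, num_classes, average="macro"):
    """Hypothetical helper: one score (or one per-class vector) per sample."""
    out = []
    for p_sample, t_sample in zip(preds, target):
        scores = _per_class_accuracy(_flatten(p_sample), _flatten(t_sample), num_classes)
        out.append(scores if average in (None, "none") else sum(scores) / num_classes)
    return out


target = [[[0, 1], [2, 1], [0, 2]], [[1, 1], [2, 0], [1, 2]]]
preds = [[[0, 2], [2, 0], [0, 1]], [[2, 2], [2, 1], [1, 0]]]
print([round(v, 4) for v in multiclass_samplewise_sketch(preds, target, 3)])
# [0.5, 0.2778]
```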
multilabel_accuracy¶
- torchmetrics.functional.classification.multilabel_accuracy(preds, target, num_labels, threshold=0.5, average='macro', multidim_average='global', ignore_index=None, validate_args=True)[source]¶
Compute Accuracy for multilabel tasks.
\[\text{Accuracy} = \frac{1}{N}\sum_{i=1}^{N} 1(y_i = \hat{y}_i)\]
where \(y\) is a tensor of target values, and \(\hat{y}\) is a tensor of predictions.
Accepts the following input tensors:
preds (int or float tensor): (N, C, ...). If preds is a floating point tensor with values outside the [0, 1] range, the input is considered to be logits and a sigmoid is applied per element. Additionally, it is converted to an int tensor by thresholding with the value in threshold.
target (int tensor): (N, C, ...)
- Parameters:
preds (Tensor) – Tensor with predictions
target (Tensor) – Tensor with true labels
num_labels (int) – Integer specifying the number of labels
threshold (float) – Threshold for transforming probability to binary (0,1) predictions
average (Optional[Literal['micro', 'macro', 'weighted', 'none']]) – Defines the reduction that is applied over labels. Should be one of the following:
micro: Sum statistics over all labels
macro: Calculate statistics for each label and average them
weighted: Calculate statistics for each label and compute a weighted average using their support
"none" or None: Calculate statistics for each label and apply no reduction
multidim_average (Literal['global', 'samplewise']) – Defines how additional dimensions ... should be handled. Should be one of the following:
global: Additional dimensions are flattened along the batch dimension
samplewise: Statistic will be calculated independently for each sample on the N axis. The statistics in this case are calculated over the additional dimensions.
ignore_index (Optional[int]) – Specifies a target value that is ignored and does not contribute to the metric calculation
validate_args (bool) – bool indicating if input arguments and tensors should be validated for correctness. Set to False for faster computations.
- Returns:
If multidim_average is set to global:
  If average='micro'/'macro'/'weighted', the output will be a scalar tensor
  If average=None/'none', the shape will be (C,)
If multidim_average is set to samplewise:
  If average='micro'/'macro'/'weighted', the shape will be (N,)
  If average=None/'none', the shape will be (N, C)
- Return type:
The returned shape depends on the average and multidim_average arguments
- Example (preds is int tensor):
>>> from torch import tensor
>>> from torchmetrics.functional.classification import multilabel_accuracy
>>> target = tensor([[0, 1, 0], [1, 0, 1]])
>>> preds = tensor([[0, 0, 1], [1, 0, 1]])
>>> multilabel_accuracy(preds, target, num_labels=3)
tensor(0.6667)
>>> multilabel_accuracy(preds, target, num_labels=3, average=None)
tensor([1.0000, 0.5000, 0.5000])
- Example (preds is float tensor):
>>> from torch import tensor
>>> from torchmetrics.functional.classification import multilabel_accuracy
>>> target = tensor([[0, 1, 0], [1, 0, 1]])
>>> preds = tensor([[0.11, 0.22, 0.84], [0.73, 0.33, 0.92]])
>>> multilabel_accuracy(preds, target, num_labels=3)
tensor(0.6667)
>>> multilabel_accuracy(preds, target, num_labels=3, average=None)
tensor([1.0000, 0.5000, 0.5000])
- Example (multidim tensors):
>>> from torch import tensor
>>> from torchmetrics.functional.classification import multilabel_accuracy
>>> target = tensor([[[0, 1], [1, 0], [0, 1]], [[1, 1], [0, 0], [1, 0]]])
>>> preds = tensor([[[0.59, 0.91], [0.91, 0.99], [0.63, 0.04]],
...                 [[0.38, 0.04], [0.86, 0.78], [0.45, 0.37]]])
>>> multilabel_accuracy(preds, target, num_labels=3, multidim_average='samplewise')
tensor([0.3333, 0.1667])
>>> multilabel_accuracy(preds, target, num_labels=3, multidim_average='samplewise', average=None)
tensor([[0.5000, 0.5000, 0.0000],
        [0.0000, 0.0000, 0.5000]])
AUROC¶
Module Interface¶
- class torchmetrics.AUROC(**kwargs)[source]¶
Compute Area Under the Receiver Operating Characteristic Curve (ROC AUC).
The AUROC score summarizes the ROC curve into a single number that describes the performance of a model across multiple thresholds at the same time. Notably, an AUROC score of 1 is a perfect score and an AUROC score of 0.5 corresponds to random guessing.
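One way to read the score: the ROC AUC equals the probability that a randomly chosen positive sample receives a higher score than a randomly chosen negative one, with ties counted as one half (the normalized Mann-Whitney U statistic). The quadratic plain-Python sketch below is for illustration only; auroc_pairwise_sketch is a hypothetical helper and not how TorchMetrics computes the metric.

```python
def auroc_pairwise_sketch(scores, target):
    """Hypothetical helper: P(score of a positive > score of a negative), ties count 1/2."""
    pos = [s for s, t in zip(scores, target) if t == 1]
    neg = [s for s, t in zip(scores, target) if t == 0]
    if not pos or not neg:
        raise ValueError("need at least one positive and one negative sample")
    wins = 0.0
    for p in pos:
        for n in neg:
            wins += 1.0 if p > n else 0.5 if p == n else 0.0
    return wins / (len(pos) * len(neg))


print(auroc_pairwise_sketch([0.13, 0.26, 0.08, 0.19, 0.34], [0, 0, 1, 1, 1]))
# 0.5
```

For the binary legacy example, 3 of the 6 positive/negative pairs are ranked correctly, which is why the score lands exactly at the random-guessing value of 0.5.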
This module is a simple wrapper to get the task specific versions of this metric, which is done by setting the task argument to either 'binary', 'multiclass' or 'multilabel'. See the documentation of BinaryAUROC, MulticlassAUROC and MultilabelAUROC for the specific details of each argument and for examples.
- Legacy Example:
>>> from torch import tensor
>>> from torchmetrics import AUROC
>>> preds = tensor([0.13, 0.26, 0.08, 0.19, 0.34])
>>> target = tensor([0, 0, 1, 1, 1])
>>> auroc = AUROC(task="binary")
>>> auroc(preds, target)
tensor(0.5000)

>>> preds = tensor([[0.90, 0.05, 0.05],
...                 [0.05, 0.90, 0.05],
...                 [0.05, 0.05, 0.90],
...                 [0.85, 0.05, 0.10],
...                 [0.10, 0.10, 0.80]])
>>> target = tensor([0, 1, 1, 2, 2])
>>> auroc = AUROC(task="multiclass", num_classes=3)
>>> auroc(preds, target)
tensor(0.7778)
BinaryAUROC¶
- class torchmetrics.classification.BinaryAUROC(max_fpr=None, thresholds=None, ignore_index=None, validate_args=True, **kwargs)[source]¶
Compute Area Under the Receiver Operating Characteristic Curve (ROC AUC) for binary tasks.
The AUROC score summarizes the ROC curve into a single number that describes the performance of a model across multiple thresholds at the same time. Notably, an AUROC score of 1 is a perfect score and an AUROC score of 0.5 corresponds to random guessing.
As input to forward and update the metric accepts the following input:
preds (Tensor): A float tensor of shape (N, ...) containing probabilities or logits for each observation. If preds has values outside the [0, 1] range, the input is considered to be logits and a sigmoid is applied per element.
target (Tensor): An int tensor of shape (N, ...) containing ground truth labels, and therefore only containing {0,1} values (except if ignore_index is specified). The value 1 always encodes the positive class.
As output to forward and compute the metric returns the following output:
b_auroc (Tensor): A single scalar with the auroc score.
Additional dimensions ... will be flattened into the batch dimension.
The implementation supports both calculating the metric in a non-binned but accurate version and a binned version that is less accurate but more memory efficient. Setting the thresholds argument to None will activate the non-binned version that uses memory of size \(\mathcal{O}(n_{samples})\) whereas setting the thresholds argument to either an integer, list or a 1d tensor will use a binned version that uses memory of size \(\mathcal{O}(n_{thresholds})\) (constant memory with respect to the number of samples).
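The constant-memory property of the binned version comes from keeping only per-threshold counters. BinnedAUROCSketch below is a hypothetical class, not the TorchMetrics API: update accumulates true and false positive counts at num_thresholds thresholds linearly spaced from 0 to 1 (a prediction counts as positive when it is >= the threshold, an assumption), and compute integrates the resulting ROC points with the trapezoid rule. Its state size depends only on the number of thresholds, never on the number of samples seen.

```python
class BinnedAUROCSketch:
    """Hypothetical helper: streaming, binned binary AUROC with O(n_thresholds) state."""

    def __init__(self, num_thresholds=5):
        if num_thresholds < 2:
            raise ValueError("num_thresholds must be larger than 1")
        # Thresholds linearly spaced from 0 to 1.
        self.thresholds = [i / (num_thresholds - 1) for i in range(num_thresholds)]
        self.tp = [0] * num_thresholds
        self.fp = [0] * num_thresholds
        self.pos = 0
        self.neg = 0

    def update(self, preds, target):
        # Only counters are updated; the individual predictions are not stored.
        for p, t in zip(preds, target):
            self.pos += t
            self.neg += 1 - t
            for i, thr in enumerate(self.thresholds):
                if p >= thr:
                    if t == 1:
                        self.tp[i] += 1
                    else:
                        self.fp[i] += 1

    def compute(self):
        if self.pos == 0 or self.neg == 0:
            raise ValueError("need both positive and negative samples")
        # From the highest to the lowest threshold, FPR and TPR are non-decreasing.
        points = [(0.0, 0.0)] + [
            (self.fp[i] / self.neg, self.tp[i] / self.pos)
            for i in reversed(range(len(self.thresholds)))
        ]
        # Trapezoid rule over consecutive (FPR, TPR) points.
        area = 0.0
        for (x0, y0), (x1, y1) in zip(points, points[1:]):
            area += (x1 - x0) * (y0 + y1) / 2
        return area


metric = BinnedAUROCSketch(num_thresholds=5)
metric.update([0, 0.5], [0, 1])    # first batch
metric.update([0.7, 0.8], [1, 0])  # second batch
print(metric.compute())
# 0.5
```

Feeding the BinaryAUROC example data in two batches gives 0.5, the same value the example reports for thresholds=5.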
- Parameters:
max_fpr (Optional[float]) – If not None, calculates standardized partial AUC over the range [0, max_fpr].
thresholds (Union[int, List[float], Tensor, None]) – Can be one of:
If set to None, will use a non-binned approach where thresholds are dynamically calculated from all the data. Most accurate but also most memory consuming approach.
If set to an int (larger than 1), will use that number of thresholds linearly spaced from 0 to 1 as bins for the calculation.
If set to a list of floats, will use the indicated thresholds in the list as bins for the calculation.
If set to a 1d tensor of floats, will use the indicated thresholds in the tensor as bins for the calculation.
validate_args (bool) – bool indicating if input arguments and tensors should be validated for correctness. Set to False for faster computations.
kwargs (Any) – Additional keyword arguments, see Advanced metric settings for more info.
Example
>>> from torch import tensor
>>> from torchmetrics.classification import BinaryAUROC
>>> preds = tensor([0, 0.5, 0.7, 0.8])
>>> target = tensor([0, 1, 1, 0])
>>> metric = BinaryAUROC(thresholds=None)
>>> metric(preds, target)
tensor(0.5000)
>>> b_auroc = BinaryAUROC(thresholds=5)
>>> b_auroc(preds, target)
tensor(0.5000)
- plot(val=None, ax=None)[source]¶
Plot a single or multiple values from the metric.
- Parameters:
val (Union[Tensor, Sequence[Tensor], None]) – Either a single result from calling metric.forward or metric.compute or a list of these results. If no value is provided, metric.compute is called automatically and that result is plotted.
ax (Optional[Axes]) – A matplotlib axis object. If provided, the plot is added to that axis.
- Return type:
- Returns:
Figure and Axes object
- Raises:
ModuleNotFoundError – If matplotlib is not installed
>>> # Example plotting a single value
>>> import torch
>>> from torchmetrics.classification import BinaryAUROC
>>> metric = BinaryAUROC()
>>> metric.update(torch.rand(20,), torch.randint(2, (20,)))
>>> fig_, ax_ = metric.plot()

>>> # Example plotting multiple values
>>> import torch
>>> from torchmetrics.classification import BinaryAUROC
>>> metric = BinaryAUROC()
>>> values = []
>>> for _ in range(10):
...     values.append(metric(torch.rand(20,), torch.randint(2, (20,))))
>>> fig_, ax_ = metric.plot(values)
MulticlassAUROC¶
- class torchmetrics.classification.MulticlassAUROC(num_classes, average='macro', thresholds=None, ignore_index=None, validate_args=True, **kwargs)[source]¶
Compute Area Under the Receiver Operating Characteristic Curve (ROC AUC) for multiclass tasks.
The AUROC score summarizes the ROC curve into a single number that describes the performance of a model across multiple thresholds at the same time. Notably, an AUROC score of 1 is a perfect score and an AUROC score of 0.5 corresponds to random guessing.
For multiclass the metric is calculated by iteratively treating each class as the positive class and all other classes as the negative, which is referred to as the one-vs-rest approach. One-vs-one is currently not supported by this metric. By default the reported metric is then the average over all classes, but this behavior can be changed by setting the average argument.
As input to forward and update the metric accepts the following input:
preds (Tensor): A float tensor of shape (N, C, ...) containing probabilities or logits for each observation. If preds has values outside the [0, 1] range, the input is considered to be logits and a softmax is applied per sample.
target (Tensor): An int tensor of shape (N, ...) containing ground truth labels, and therefore only containing values in the [0, n_classes-1] range (except if ignore_index is specified).
As output to forward and compute the metric returns the following output:
mc_auroc (Tensor): If average=None|"none", a 1d tensor of shape (n_classes,) is returned with the auroc score per class. If average="macro"|"weighted", a single scalar is returned.
Additional dimensions ... will be flattened into the batch dimension.
The implementation supports both calculating the metric in a non-binned but accurate version and a binned version that is less accurate but more memory efficient. Setting the thresholds argument to None will activate the non-binned version that uses memory of size \(\mathcal{O}(n_{samples})\) whereas setting the thresholds argument to either an integer, list or a 1d tensor will use a binned version that uses memory of size \(\mathcal{O}(n_{thresholds} \times n_{classes})\) (constant memory with respect to the number of samples).
- Parameters:
num_classes (int) – Integer specifying the number of classes
average (Optional[Literal['macro', 'weighted', 'none']]) – Defines the reduction that is applied over classes. Should be one of the following:
macro: Calculate score for each class and average them
weighted: Calculate score for each class and compute a weighted average using their support
"none" or None: Calculate score for each class and apply no reduction
thresholds (Union[int, List[float], Tensor, None]) – Can be one of:
If set to None, will use a non-binned approach where thresholds are dynamically calculated from all the data. Most accurate but also most memory consuming approach.
If set to an int (larger than 1), will use that number of thresholds linearly spaced from 0 to 1 as bins for the calculation.
If set to a list of floats, will use the indicated thresholds in the list as bins for the calculation.
If set to a 1d tensor of floats, will use the indicated thresholds in the tensor as bins for the calculation.
validate_args (bool) – bool indicating if input arguments and tensors should be validated for correctness. Set to False for faster computations.
kwargs (Any) – Additional keyword arguments, see Advanced metric settings for more info.
Example
>>> from torch import tensor
>>> from torchmetrics.classification import MulticlassAUROC
>>> preds = tensor([[0.75, 0.05, 0.05, 0.05, 0.05],
...                 [0.05, 0.75, 0.05, 0.05, 0.05],
...                 [0.05, 0.05, 0.75, 0.05, 0.05],
...                 [0.05, 0.05, 0.05, 0.75, 0.05]])
>>> target = tensor([0, 1, 3, 2])
>>> metric = MulticlassAUROC(num_classes=5, average="macro", thresholds=None)
>>> metric(preds, target)
tensor(0.5333)
>>> mc_auroc = MulticlassAUROC(num_classes=5, average=None, thresholds=None)
>>> mc_auroc(preds, target)
tensor([1.0000, 1.0000, 0.3333, 0.3333, 0.0000])
>>> mc_auroc = MulticlassAUROC(num_classes=5, average="macro", thresholds=5)
>>> mc_auroc(preds, target)
tensor(0.5333)
>>> mc_auroc = MulticlassAUROC(num_classes=5, average=None, thresholds=5)
>>> mc_auroc(preds, target)
tensor([1.0000, 1.0000, 0.3333, 0.3333, 0.0000])
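The one-vs-rest procedure can be sketched in plain Python: for each class c, the column preds[:, c] serves as the score and target == c as the binary label, and the per-class scores are then reduced. multiclass_auroc_ovr_sketch is a hypothetical helper, not the library code: it computes each binary AUROC by pairwise ranking, applies a softmax only when some value lies outside [0, 1], and reports 0.0 for a class without positive samples, which matches class 4 in the example above but is otherwise an assumption.

```python
import math


def _binary_auroc(scores, labels):
    """Pairwise-ranking AUROC; ties count as 1/2."""
    pos = [s for s, y in zip(scores, labels) if y]
    neg = [s for s, y in zip(scores, labels) if not y]
    if not pos or not neg:
        return 0.0  # assumption: a class without positives (or negatives) scores 0.0
    wins = sum(1.0 if p > n else 0.5 if p == n else 0.0 for p in pos for n in neg)
    return wins / (len(pos) * len(neg))


def _maybe_softmax(rows):
    """Apply a per-sample softmax only if some value lies outside [0, 1]."""
    if all(0.0 <= v <= 1.0 for row in rows for v in row):
        return rows
    out = []
    for row in rows:
        m = max(row)  # subtract the max for numerical stability
        exps = [math.exp(v - m) for v in row]
        total = sum(exps)
        out.append([e / total for e in exps])
    return out


def multiclass_auroc_ovr_sketch(preds, target, num_classes, average="macro"):
    """Hypothetical helper: one-vs-rest multiclass AUROC on (N, C) nested lists."""
    preds = _maybe_softmax(preds)
    per_class = [
        _binary_auroc([row[c] for row in preds], [t == c for t in target])
        for c in range(num_classes)
    ]
    if average in (None, "none"):
        return per_class
    if average == "weighted":
        support = [sum(t == c for t in target) for c in range(num_classes)]
        return sum(s * a for s, a in zip(support, per_class)) / sum(support)
    return sum(per_class) / num_classes  # "macro"


preds = [[0.75, 0.05, 0.05, 0.05, 0.05],
         [0.05, 0.75, 0.05, 0.05, 0.05],
         [0.05, 0.05, 0.75, 0.05, 0.05],
         [0.05, 0.05, 0.05, 0.75, 0.05]]
print(round(multiclass_auroc_ovr_sketch(preds, [0, 1, 3, 2], num_classes=5), 4))
# 0.5333
```

Classes 2 and 3 score 1/3 because their single positive sample ties with two negatives and loses to the third, which explains the per-class values in the example above.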
- plot(val=None, ax=None)[source]¶
Plot a single or multiple values from the metric.
- Parameters:
val (Union[Tensor, Sequence[Tensor], None]) – Either a single result from calling metric.forward or metric.compute or a list of these results. If no value is provided, metric.compute is called automatically and that result is plotted.
ax (Optional[Axes]) – A matplotlib axis object. If provided, the plot is added to that axis.
- Return type:
- Returns:
Figure and Axes object
- Raises:
ModuleNotFoundError – If matplotlib is not installed
>>> # Example plotting a single value
>>> import torch
>>> from torchmetrics.classification import MulticlassAUROC
>>> metric = MulticlassAUROC(num_classes=3)
>>> metric.update(torch.randn(20, 3), torch.randint(3, (20,)))
>>> fig_, ax_ = metric.plot()

>>> # Example plotting multiple values
>>> import torch
>>> from torchmetrics.classification import MulticlassAUROC
>>> metric = MulticlassAUROC(num_classes=3)
>>> values = []
>>> for _ in range(10):
...     values.append(metric(torch.randn(20, 3), torch.randint(3, (20,))))
>>> fig_, ax_ = metric.plot(values)
MultilabelAUROC
- class torchmetrics.classification.MultilabelAUROC(num_labels, average='macro', thresholds=None, ignore_index=None, validate_args=True, **kwargs)
Compute Area Under the Receiver Operating Characteristic Curve (ROC AUC) for multilabel tasks.
The AUROC score summarizes the ROC curve into a single number that describes the performance of a model across multiple thresholds at the same time. An AUROC score of 1 is a perfect score, and an AUROC score of 0.5 corresponds to random guessing.
As input to forward and update the metric accepts the following input:
preds (Tensor): A float tensor of shape (N, C, ...) containing probabilities or logits for each observation. If preds has values outside the [0, 1] range, the input is considered to be logits and sigmoid is automatically applied per element.
target (Tensor): An int tensor of shape (N, C, ...) containing ground truth labels, which therefore only contain {0,1} values (except if ignore_index is specified).
As output to forward and compute the metric returns the following output:
ml_auroc (Tensor): If average=None|"none", a 1d tensor of shape (n_labels, ) is returned with the AUROC score per label. If average="micro"|"macro"|"weighted", a single scalar is returned.
Additional dimensions ... will be flattened into the batch dimension.
The implementation supports both a non-binned but accurate version and a binned version that is less accurate but more memory efficient. Setting the thresholds argument to None activates the non-binned version, which uses memory of size \(\mathcal{O}(n_{samples})\), whereas setting the thresholds argument to an integer, a list or a 1d tensor activates the binned version, which uses memory of size \(\mathcal{O}(n_{thresholds} \times n_{labels})\) (constant memory).
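To make the constant-memory claim concrete, here is a hypothetical pure-Python sketch of a binned multilabel AUROC. It is not the library's code: it assumes that "linearly spaced from 0 to 1" means thresholds such as [0.0, 0.25, 0.5, 0.75, 1.0] for thresholds=5, that a sample counts as predicted positive when pred >= threshold, and that the area is taken with the trapezoidal rule. Its state is a fixed table of counts, so memory depends only on the number of thresholds and labels, not on how many batches are passed to update.

```python
from typing import List, Sequence


class BinnedMultilabelAUROCSketch:
    """Hypothetical binned AUROC whose state size is O(n_thresholds * n_labels)."""

    def __init__(self, num_labels: int, num_thresholds: int) -> None:
        # Linearly spaced thresholds from 0 to 1 (assumed binning scheme).
        self.thresholds = [i / (num_thresholds - 1) for i in range(num_thresholds)]
        # Per (threshold, label): true and false positives at pred >= threshold.
        self.tp = [[0] * num_labels for _ in self.thresholds]
        self.fp = [[0] * num_labels for _ in self.thresholds]
        # Per label: number of positive and negative targets seen so far.
        self.pos = [0] * num_labels
        self.neg = [0] * num_labels

    def update(self, preds: Sequence[Sequence[float]], target: Sequence[Sequence[int]]) -> None:
        # Only counters are incremented; the raw samples are not stored.
        for row_p, row_t in zip(preds, target):
            for label, (p, t) in enumerate(zip(row_p, row_t)):
                if t == 1:
                    self.pos[label] += 1
                else:
                    self.neg[label] += 1
                for i, thr in enumerate(self.thresholds):
                    if p >= thr:
                        if t == 1:
                            self.tp[i][label] += 1
                        else:
                            self.fp[i][label] += 1

    def compute(self) -> List[float]:
        # Assumes every label has at least one positive and one negative target.
        scores = []
        for label in range(len(self.pos)):
            # Walk from the highest to the lowest threshold so FPR and TPR never decrease.
            points = [(0.0, 0.0)]
            for i in reversed(range(len(self.thresholds))):
                fpr = self.fp[i][label] / self.neg[label]
                tpr = self.tp[i][label] / self.pos[label]
                points.append((fpr, tpr))
            points.append((1.0, 1.0))
            # Trapezoidal rule over the piecewise-linear ROC curve.
            area = sum((x1 - x0) * (y0 + y1) / 2 for (x0, y0), (x1, y1) in zip(points, points[1:]))
            scores.append(area)
        return scores


preds = [[0.75, 0.05, 0.35], [0.45, 0.75, 0.05], [0.05, 0.55, 0.75], [0.05, 0.65, 0.05]]
target = [[1, 0, 1], [0, 0, 0], [0, 1, 1], [1, 1, 1]]
metric = BinnedMultilabelAUROCSketch(num_labels=3, num_thresholds=5)
metric.update(preds[:2], target[:2])  # first batch
metric.update(preds[2:], target[2:])  # second batch
scores = metric.compute()
print([round(s, 4) for s in scores])  # [0.625, 0.5, 0.8333]
```

These per-label values match the thresholds=5 output of the example further down. Adding thresholds generally brings the binned curve closer to the exact one, at the cost of a proportionally larger count table.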
- Parameters:
num_labels (int) – Integer specifying the number of labels
average (Optional[Literal['micro', 'macro', 'weighted', 'none']]) – Defines the reduction that is applied over labels. Should be one of the following:
micro: Sum score over all labels
macro: Calculate score for each label and average them
weighted: Calculate score for each label and compute a weighted average using their support
"none" or None: Calculate score for each label and apply no reduction
thresholds (Union[int, List[float], Tensor, None]) – Can be one of:
If set to None, will use a non-binned approach where thresholds are dynamically calculated from all the data. Most accurate but also most memory consuming approach.
If set to an int (larger than 1), will use that number of thresholds linearly spaced from 0 to 1 as bins for the calculation.
If set to a list of floats, will use the indicated thresholds in the list as bins for the calculation.
If set to a 1d tensor of floats, will use the indicated thresholds in the tensor as bins for the calculation.
ignore_index (Optional[int]) – Specifies a target value that is ignored and does not contribute to the metric calculation
validate_args (bool) – bool indicating if input arguments and tensors should be validated for correctness. Set to False for faster computations.
kwargs (Any) – Additional keyword arguments, see Advanced metric settings for more info.
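The average options differ only in how per-label scores are combined. The sketch below is a hypothetical helper, not the library's API; it assumes that a label's "support" is its number of positive targets, and it leaves out micro averaging, which works on pooled data rather than on per-label scores. The per-label inputs are the average=None results from the example below.

```python
from typing import List, Optional, Sequence, Union


def reduce_scores(
    scores: Sequence[float], support: Sequence[int], average: Optional[str]
) -> Union[float, List[float]]:
    """Combine per-label scores according to one of the documented ``average`` options."""
    if average is None or average == "none":
        return list(scores)  # no reduction: one score per label
    if average == "macro":
        return sum(scores) / len(scores)  # every label counts equally
    if average == "weighted":
        # Labels with more positive targets get proportionally more weight.
        return sum(s * w for s, w in zip(scores, support)) / sum(support)
    raise ValueError(f"unsupported average: {average!r}")


target = [[1, 0, 1], [0, 0, 0], [0, 1, 1], [1, 1, 1]]
per_label = [0.6250, 0.5000, 0.8333]  # average=None output from the example
support = [sum(column) for column in zip(*target)]  # positives per label
print(support)  # [2, 2, 3]
print(round(reduce_scores(per_label, support, "macro"), 4))  # 0.6528
print(round(reduce_scores(per_label, support, "weighted"), 4))  # 0.6786
```

The macro value agrees with the average="macro" result in the example; the weighted value is only what this sketch produces under its support assumption.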
Example
>>> from torch import tensor
>>> from torchmetrics.classification import MultilabelAUROC
>>> preds = tensor([[0.75, 0.05, 0.35],
...                 [0.45, 0.75, 0.05],
...                 [0.05, 0.55, 0.75],
...                 [0.05, 0.65, 0.05]])
>>> target = tensor([[1, 0, 1],
...                  [0, 0, 0],
...                  [0, 1, 1],
...                  [1, 1, 1]])
>>> ml_auroc = MultilabelAUROC(num_labels=3, average="macro", thresholds=None)
>>> ml_auroc(preds, target)
tensor(0.6528)
>>> ml_auroc = MultilabelAUROC(num_labels=3, average=None, thresholds=None)
>>> ml_auroc(preds, target)
tensor([0.6250, 0.5000, 0.8333])
>>> ml_auroc = MultilabelAUROC(num_labels=3, average="macro", thresholds=5)
>>> ml_auroc(preds, target)
tensor(0.6528)
>>> ml_auroc = MultilabelAUROC(num_labels=3, average=None, thresholds=5)
>>> ml_auroc(preds, target)
tensor([0.6250, 0.5000, 0.8333])
- plot(val=None, ax=None)
Plot a single value or multiple values from the metric.
- Parameters:
val (Union[Tensor, Sequence[Tensor], None]) – Either a single result from calling metric.forward or metric.compute, or a list of these results. If no value is provided, metric.compute is called automatically and that result is plotted.
ax (Optional[Axes]) – A matplotlib axis object. If provided, the plot is added to that axis.
- Returns:
Figure and Axes object
- Raises:
ModuleNotFoundError – If matplotlib is not installed
>>> # Example plotting a single value
>>> import torch
>>> from torchmetrics.classification import MultilabelAUROC
>>> metric = MultilabelAUROC(num_labels=3)
>>> metric.update(torch.rand(20, 3), torch.randint(2, (20, 3)))
>>> fig_, ax_ = metric.plot()
>>> # Example plotting multiple values
>>> import torch
>>> from torchmetrics.classification import MultilabelAUROC
>>> metric = MultilabelAUROC(num_labels=3)
>>> values = []
>>> for _ in range(10):
...     values.append(metric(torch.rand(20, 3), torch.randint(2, (20, 3))))
>>> fig_, ax_ = metric.plot(values)
Functional Interface
- torchmetrics.functional.auroc(preds, target, task, thresholds=None, num_classes=None, num_labels=None, average='macro', max_fpr=None, ignore_index=None, validate_args=True)
Compute Area Under the Receiver Operating Characteristic Curve (ROC AUC).
The AUROC score summarizes the ROC curve into a single number that describes the performance of a model across multiple thresholds at the same time. An AUROC score of 1 is a perfect score, and an AUROC score of 0.5 corresponds to random guessing.
This function is a simple wrapper to get the task-specific versions of this metric, which is done by setting the task argument to either 'binary', 'multiclass' or 'multilabel'. See the documentation of binary_auroc(), multiclass_auroc() and multilabel_auroc() for the specific details of how each argument influences the result, and for examples.
- Legacy Example:
>>> import torch
>>> from torchmetrics.functional import auroc
>>> preds = torch.tensor([0.13, 0.26, 0.08, 0.19, 0.34])
>>> target = torch.tensor([0, 0, 1, 1, 1])
>>> auroc(preds, target, task='binary')
tensor(0.5000)
>>> preds = torch.tensor([[0.90, 0.05, 0.05],
...                       [0.05, 0.90, 0.05],
...                       [0.05, 0.05, 0.90],
...                       [0.85, 0.05, 0.10],
...                       [0.10, 0.10, 0.80]])
>>> target = torch.tensor([0, 1, 1, 2, 2])
>>> auroc(preds, target, task='multiclass', num_classes=3)
tensor(0.7778)
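The task argument routes the call to a task-specific implementation. The sketch below mimics that routing in pure Python under simplifying assumptions: macro averaging only, no thresholds, max_fpr or ignore_index, and probabilities rather than logits as input. auroc_dispatch and its helper are hypothetical names; the sketch reproduces both legacy results above.

```python
from typing import List, Optional, Sequence


def _pairwise_auroc(scores: Sequence[float], labels: Sequence[int]) -> float:
    # Fraction of (positive, negative) pairs ranked correctly; ties count 1/2.
    pos = [s for s, y in zip(scores, labels) if y == 1]
    neg = [s for s, y in zip(scores, labels) if y == 0]
    wins = sum(1.0 if p > n else 0.5 if p == n else 0.0 for p in pos for n in neg)
    return wins / (len(pos) * len(neg))


def auroc_dispatch(
    preds: list,
    target: list,
    task: str,
    num_classes: Optional[int] = None,
    num_labels: Optional[int] = None,
) -> float:
    """Hypothetical router in the spirit of the ``task`` argument (macro average only)."""
    if task == "binary":
        return _pairwise_auroc(preds, target)
    if task == "multiclass":
        if not isinstance(num_classes, int):
            raise ValueError("num_classes must be an int when task='multiclass'")
        per_class: List[float] = [
            _pairwise_auroc([row[c] for row in preds], [int(t == c) for t in target])
            for c in range(num_classes)
        ]
        return sum(per_class) / num_classes
    if task == "multilabel":
        if not isinstance(num_labels, int):
            raise ValueError("num_labels must be an int when task='multilabel'")
        per_label = [
            _pairwise_auroc([row[j] for row in preds], [row[j] for row in target])
            for j in range(num_labels)
        ]
        return sum(per_label) / num_labels
    raise ValueError(f"task must be 'binary', 'multiclass' or 'multilabel', got {task!r}")


print(auroc_dispatch([0.13, 0.26, 0.08, 0.19, 0.34], [0, 0, 1, 1, 1], task="binary"))  # 0.5
mc_preds = [
    [0.90, 0.05, 0.05],
    [0.05, 0.90, 0.05],
    [0.05, 0.05, 0.90],
    [0.85, 0.05, 0.10],
    [0.10, 0.10, 0.80],
]
print(round(auroc_dispatch(mc_preds, [0, 1, 1, 2, 2], task="multiclass", num_classes=3), 4))  # 0.7778
```

Keeping the per-task logic in separate functions and validating only the arguments each task needs is what lets a single entry point accept num_classes or num_labels as optional.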
binary_auroc
- torchmetrics.functional.classification.binary_auroc(preds, target, max_fpr=None, thresholds=None, ignore_index=None, validate_args=True)
Compute Area Under the Receiver Operating Characteristic Curve (ROC AUC) for binary tasks.
The AUROC score summarizes the ROC curve into a single number that describes the performance of a model across multiple thresholds at the same time. An AUROC score of 1 is a perfect score, and an AUROC score of 0.5 corresponds to random guessing.
Accepts the following input tensors:
preds (float tensor): (N, ...). Preds should be a tensor containing probabilities or logits for each observation. If preds has values outside the [0, 1] range, the input is considered to be logits and sigmoid is automatically applied per element.
target (int tensor): (N, ...). Target should be a tensor containing ground truth labels, and therefore only contain {0,1} values (except if ignore_index is specified). The value 1 always encodes the positive class.
Additional dimensions ... will be flattened into the batch dimension.
The implementation supports both a non-binned but accurate version and a binned version that is less accurate but more memory efficient. Setting the thresholds argument to None activates the non-binned version, which uses memory of size \(\mathcal{O}(n_{samples})\), whereas setting the thresholds argument to an integer, a list or a 1d tensor activates the binned version, which uses memory of size \(\mathcal{O}(n_{thresholds})\) (constant memory).
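The logit handling described above can be sketched as follows. This is an illustration, not the library's code: the range check mirrors the rule stated in the text, and the pairwise AUROC helper is a hypothetical stand-in. Because the sigmoid is strictly increasing it never changes the order of the scores, so the binary AUROC is the same whether or not it is applied.

```python
import math
from typing import List, Sequence


def maybe_sigmoid(preds: Sequence[float]) -> List[float]:
    """Apply a per-element sigmoid only if some value lies outside [0, 1]."""
    if any(p < 0.0 or p > 1.0 for p in preds):
        return [1.0 / (1.0 + math.exp(-p)) for p in preds]
    return list(preds)  # already looks like probabilities: left unchanged


def pairwise_auroc(scores: Sequence[float], labels: Sequence[int]) -> float:
    # Fraction of (positive, negative) pairs ranked correctly; ties count 1/2.
    pos = [s for s, y in zip(scores, labels) if y == 1]
    neg = [s for s, y in zip(scores, labels) if y == 0]
    wins = sum(1.0 if p > n else 0.5 if p == n else 0.0 for p in pos for n in neg)
    return wins / (len(pos) * len(neg))


logits = [-2.0, 0.0, 0.8, 1.5]  # contains values outside [0, 1], so treated as logits
target = [0, 1, 1, 0]
probs = maybe_sigmoid(logits)
print([round(p, 3) for p in probs])
print(pairwise_auroc(logits, target), pairwise_auroc(probs, target))  # 0.5 0.5
```

A consequence of the rule as stated is that logits which all happen to fall inside [0, 1] are left untransformed; for a binary AUROC on a single tensor this does not change the result, since only the ranking matters.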
- Parameters:
preds (Tensor) – Tensor with predictions
target (Tensor) – Tensor with true labels
max_fpr (Optional[float]) – If not None, calculates standardized partial AUC over the range [0, max_fpr].
thresholds (Union[int, List[float], Tensor, None]) – Can be one of:
If set to None, will use a non-binned approach where thresholds are dynamically calculated from all the data. Most accurate but also most memory consuming approach.
If set to an int (larger than 1), will use that number of thresholds linearly spaced from 0 to 1 as bins for the calculation.
If set to a list of floats, will use the indicated thresholds in the list as bins for the calculation.
If set to a 1d tensor of floats, will use the indicated thresholds in the tensor as bins for the calculation.
ignore_index (Optional[int]) – Specifies a target value that is ignored and does not contribute to the metric calculation
validate_args (bool) – bool indicating if input arguments and tensors should be validated for correctness. Set to False for faster computations.
- Returns:
A single scalar with the AUROC score
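The standardized partial AUC behind max_fpr can be sketched in pure Python. The standardization below is the McClish correction, 0.5 * (1 + (A - A_min) / (A_max - A_min)) with A_min = max_fpr**2 / 2 (the chance diagonal) and A_max = max_fpr (a perfect classifier). This is the formula scikit-learn uses for its max_fpr option; that TorchMetrics applies the same one is an assumption here, not something this page states. The function names are hypothetical.

```python
from typing import List, Sequence, Tuple


def roc_points(scores: Sequence[float], labels: Sequence[int]) -> List[Tuple[float, float]]:
    """(FPR, TPR) pairs, using every distinct score as a threshold, highest first."""
    n_pos = sum(labels)
    n_neg = len(labels) - n_pos
    ranked = sorted(zip(scores, labels), key=lambda pair: pair[0], reverse=True)
    points = [(0.0, 0.0)]
    tp = fp = i = 0
    while i < len(ranked):
        threshold = ranked[i][0]
        # Samples with tied scores cross the threshold together.
        while i < len(ranked) and ranked[i][0] == threshold:
            tp += ranked[i][1]
            fp += 1 - ranked[i][1]
            i += 1
        points.append((fp / n_neg, tp / n_pos))
    return points


def standardized_partial_auc(scores: Sequence[float], labels: Sequence[int], max_fpr: float) -> float:
    points = roc_points(scores, labels)
    area = 0.0
    for (x0, y0), (x1, y1) in zip(points, points[1:]):
        if x0 >= max_fpr:
            break
        if x1 > max_fpr:
            # Cut the last segment at max_fpr by linearly interpolating the TPR.
            y1 = y0 + (y1 - y0) * (max_fpr - x0) / (x1 - x0)
            x1 = max_fpr
        area += (x1 - x0) * (y0 + y1) / 2
    min_area = 0.5 * max_fpr**2  # area under the chance diagonal up to max_fpr
    max_area = max_fpr  # area of a perfect classifier up to max_fpr
    return 0.5 * (1 + (area - min_area) / (max_area - min_area))


print(standardized_partial_auc([0.1, 0.2, 0.8, 0.9], [0, 0, 1, 1], max_fpr=0.1))  # 1.0
print(standardized_partial_auc([0.5, 0.5, 0.5, 0.5], [0, 1, 0, 1], max_fpr=0.1))
```

With max_fpr=1.0 the correction reduces to the ordinary AUROC, because A_min = 0.5 and A_max = 1; a scorer that cannot separate the classes lands at 0.5 for any max_fpr.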
Example
>>> import torch
>>> from torchmetrics.functional.classification import binary_auroc
>>> preds = torch.tensor([0, 0.5, 0.7, 0.8])
>>> target = torch.tensor([0, 1, 1, 0])
>>> binary_auroc(preds, target, thresholds=None)
tensor(0.5000)
>>> binary_auroc(preds, target, thresholds=5)
tensor(0.5000)
multiclass_auroc
- torchmetrics.functional.classification.multiclass_auroc(preds, target, num_classes, average='macro', thresholds=None, ignore_index=None, validate_args=True)
Compute Area Under the Receiver Operating Characteristic Curve (ROC AUC) for multiclass tasks.
The AUROC score summarizes the ROC curve into a single number that describes the performance of a model across multiple thresholds at the same time. An AUROC score of 1 is a perfect score, and an AUROC score of 0.5 corresponds to random guessing.
Accepts the following input tensors:
preds (float tensor): (N, C, ...). Preds should be a tensor containing probabilities or logits for each observation. If preds has values outside the [0, 1] range, the input is considered to be logits and softmax is automatically applied per sample.
target (int tensor): (N, ...). Target should be a tensor containing ground truth labels, and therefore only contain values in the [0, n_classes-1] range (except if ignore_index is specified).
Additional dimensions ... will be flattened into the batch dimension.
The implementation supports both a non-binned but accurate version and a binned version that is less accurate but more memory efficient. Setting the thresholds argument to None activates the non-binned version, which uses memory of size \(\mathcal{O}(n_{samples})\), whereas setting the thresholds argument to an integer, a list or a 1d tensor activates the binned version, which uses memory of size \(\mathcal{O}(n_{thresholds} \times n_{classes})\) (constant memory).
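The per-sample softmax mentioned above can be sketched as follows (pure Python, a hypothetical helper rather than the library's code). Unlike the element-wise sigmoid used for binary and multilabel inputs, softmax couples the classes of a sample, so it can change how two samples are ranked for a given class; the one-vs-rest scores are therefore computed on the normalized values, not on the raw logits.

```python
import math
from typing import List, Sequence


def maybe_softmax(preds: Sequence[Sequence[float]]) -> List[List[float]]:
    """Apply softmax to each row if any value lies outside [0, 1]."""
    if not any(p < 0.0 or p > 1.0 for row in preds for p in row):
        return [list(row) for row in preds]  # already looks like probabilities
    out = []
    for row in preds:
        m = max(row)  # subtract the row maximum for numerical stability
        exps = [math.exp(p - m) for p in row]
        total = sum(exps)
        out.append([e / total for e in exps])
    return out


logits = [[2.0, 0.0], [1.0, -5.0]]  # two samples, two classes
probs = maybe_softmax(logits)
# Raw class-0 logits rank sample 0 above sample 1 ...
print(logits[0][0] > logits[1][0])  # True
# ... but after softmax sample 1 has the larger class-0 probability.
print(probs[0][0] < probs[1][0])  # True
```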
- Parameters:
preds (Tensor) – Tensor with predictions
target (Tensor) – Tensor with true labels
num_classes (int) – Integer specifying the number of classes
average (Optional[Literal['macro', 'weighted', 'none']]) – Defines the reduction that is applied over classes. Should be one of the following:
macro: Calculate score for each class and average them
weighted: Calculate score for each class and compute a weighted average using their support
"none" or None: Calculate score for each class and apply no reduction
thresholds (Union[int, List[float], Tensor, None]) – Can be one of:
If set to None, will use a non-binned approach where thresholds are dynamically calculated from all the data. Most accurate but also most memory consuming approach.
If set to an int (larger than 1), will use that number of thresholds linearly spaced from 0 to 1 as bins for the calculation.
If set to a list of floats, will use the indicated thresholds in the list as bins for the calculation.
If set to a 1d tensor of floats, will use the indicated thresholds in the tensor as bins for the calculation.
ignore_index (Optional[int]) – Specifies a target value that is ignored and does not contribute to the metric calculation
validate_args (bool) – bool indicating if input arguments and tensors should be validated for correctness. Set to False for faster computations.
- Returns:
If average=None|"none", a 1d tensor of shape (n_classes, ) is returned with the AUROC score per class. If average="macro"|"weighted", a single scalar is returned.
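Conceptually, ignore_index removes samples before any counting happens. The sketch below shows that filtering step for a multiclass layout with one target per sample; drop_ignored is a hypothetical helper, and -1 is only an example value for the ignored label. With the extra ignored row removed, the inputs are exactly those of the example below.

```python
from typing import List, Optional, Sequence, Tuple


def drop_ignored(
    preds: Sequence[Sequence[float]], target: Sequence[int], ignore_index: Optional[int]
) -> Tuple[List[List[float]], List[int]]:
    """Keep only the samples whose target differs from ignore_index."""
    if ignore_index is None:
        return [list(row) for row in preds], list(target)
    keep = [i for i, t in enumerate(target) if t != ignore_index]
    return [list(preds[i]) for i in keep], [target[i] for i in keep]


preds = [
    [0.75, 0.05, 0.05, 0.05, 0.05],
    [0.05, 0.75, 0.05, 0.05, 0.05],
    [0.05, 0.05, 0.75, 0.05, 0.05],
    [0.05, 0.05, 0.05, 0.75, 0.05],
    [0.20, 0.20, 0.20, 0.20, 0.20],  # unlabeled sample, marked with -1
]
target = [0, 1, 3, 2, -1]
kept_preds, kept_target = drop_ignored(preds, target, ignore_index=-1)
print(len(kept_preds), kept_target)  # 4 [0, 1, 3, 2]
```

Without this filtering, a naive one-vs-rest computation would treat the -1 row as a negative sample for every class and shift the per-class scores.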
Example
>>> import torch
>>> from torchmetrics.functional.classification import multiclass_auroc
>>> preds = torch.tensor([[0.75, 0.05, 0.05, 0.05, 0.05],
...                       [0.05, 0.75, 0.05, 0.05, 0.05],
...                       [0.05, 0.05, 0.75, 0.05, 0.05],
...                       [0.05, 0.05, 0.05, 0.75, 0.05]])
>>> target = torch.tensor([0, 1, 3, 2])
>>> multiclass_auroc(preds, target, num_classes=5, average="macro", thresholds=None)
tensor(0.5333)
>>> multiclass_auroc(preds, target, num_classes=5, average=None, thresholds=None)
tensor([1.0000, 1.0000, 0.3333, 0.3333, 0.0000])
>>> multiclass_auroc(preds, target, num_classes=5, average="macro", thresholds=5)
tensor(0.5333)
>>> multiclass_auroc(preds, target, num_classes=5, average=None, thresholds=5)
tensor([1.0000, 1.0000, 0.3333, 0.3333, 0.0000])
multilabel_auroc
- torchmetrics.functional.classification.multilabel_auroc(preds, target, num_labels, average='macro', thresholds=None, ignore_index=None, validate_args=True)
Compute Area Under the Receiver Operating Characteristic Curve (ROC AUC) for multilabel tasks.
The AUROC score summarizes the ROC curve into a single number that describes the performance of a model across multiple thresholds at the same time. An AUROC score of 1 is a perfect score, and an AUROC score of 0.5 corresponds to random guessing.
Accepts the following input tensors:
preds (float tensor): (N, C, ...). Preds should be a tensor containing probabilities or logits for each observation. If preds has values outside the [0, 1] range, the input is considered to be logits and sigmoid is automatically applied per element.
target (int tensor): (N, C, ...). Target should be a tensor containing ground truth labels, and therefore only contain {0,1} values (except if ignore_index is specified).
Additional dimensions ... will be flattened into the batch dimension.
The implementation supports both a non-binned but accurate version and a binned version that is less accurate but more memory efficient. Setting the thresholds argument to None activates the non-binned version, which uses memory of size \(\mathcal{O}(n_{samples})\), whereas setting the thresholds argument to an integer, a list or a 1d tensor activates the binned version, which uses memory of size \(\mathcal{O}(n_{thresholds} \times n_{labels})\) (constant memory).
- Parameters:
preds (Tensor) – Tensor with predictions
target (Tensor) – Tensor with true labels
num_labels (int) – Integer specifying the number of labels
average (Optional[Literal['micro', 'macro', 'weighted', 'none']]) – Defines the reduction that is applied over labels. Should be one of the following:
micro: Sum score over all labels
macro: Calculate score for each label and average them
weighted: Calculate score for each label and compute a weighted average using their support
"none" or None: Calculate score for each label and apply no reduction
thresholds (Union[int, List[float], Tensor, None]) – Can be one of:
If set to None, will use a non-binned approach where thresholds are dynamically calculated from all the data. Most accurate but also most memory consuming approach.
If set to an int (larger than 1), will use that number of thresholds linearly spaced from 0 to 1 as bins for the calculation.
If set to a list of floats, will use the indicated thresholds in the list as bins for the calculation.
If set to a 1d tensor of floats, will use the indicated thresholds in the tensor as bins for the calculation.
ignore_index (Optional[int]) – Specifies a target value that is ignored and does not contribute to the metric calculation
validate_args (bool) – bool indicating if input arguments and tensors should be validated for correctness. Set to False for faster computations.
- Returns:
If average=None|"none", a 1d tensor of shape (n_labels, ) is returned with the AUROC score per label. If average="micro"|"macro"|"weighted", a single scalar is returned.
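For multilabel inputs, micro averaging is commonly understood as pooling every (sample, label) pair into one binary problem. Assuming that reading of "micro" (the one-line description above does not spell it out), the following pure-Python sketch computes it for the data of the example below and contrasts it with macro averaging; the helper name is hypothetical.

```python
from typing import Sequence


def pairwise_auroc(scores: Sequence[float], labels: Sequence[int]) -> float:
    # Fraction of (positive, negative) pairs ranked correctly; ties count 1/2.
    pos = [s for s, y in zip(scores, labels) if y == 1]
    neg = [s for s, y in zip(scores, labels) if y == 0]
    wins = sum(1.0 if p > n else 0.5 if p == n else 0.0 for p in pos for n in neg)
    return wins / (len(pos) * len(neg))


preds = [[0.75, 0.05, 0.35], [0.45, 0.75, 0.05], [0.05, 0.55, 0.75], [0.05, 0.65, 0.05]]
target = [[1, 0, 1], [0, 0, 0], [0, 1, 1], [1, 1, 1]]

# Micro (assumed): flatten the (N, C) grids and treat them as one binary problem.
flat_preds = [p for row in preds for p in row]
flat_target = [t for row in target for t in row]
micro = pairwise_auroc(flat_preds, flat_target)

# Macro, for comparison: one AUROC per label, then the plain mean.
per_label = [pairwise_auroc([r[j] for r in preds], [r[j] for r in target]) for j in range(3)]
macro = sum(per_label) / len(per_label)
print(round(micro, 4), round(macro, 4))  # 0.6571 0.6528
```

The two differ because micro also compares scores across labels, whereas macro only ranks samples within each label.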
Example
>>> import torch
>>> from torchmetrics.functional.classification import multilabel_auroc
>>> preds = torch.tensor([[0.75, 0.05, 0.35],
...                       [0.45, 0.75, 0.05],
...                       [0.05, 0.55, 0.75],
...                       [0.05, 0.65, 0.05]])
>>> target = torch.tensor([[1, 0, 1],
...                        [0, 0, 0],
...                        [0, 1, 1],
...                        [1, 1, 1]])
>>> multilabel_auroc(preds, target, num_labels=3, average="macro", thresholds=None)
tensor(0.6528)
>>> multilabel_auroc(preds, target, num_labels=3, average=None, thresholds=None)
tensor([0.6250, 0.5000, 0.8333])
>>> multilabel_auroc(preds, target, num_labels=3, average="macro", thresholds=5)
tensor(0.6528)
>>> multilabel_auroc(preds, target, num_labels=3, average=None, thresholds=5)
tensor([0.6250, 0.5000, 0.8333])
Average Precision
Module Interface
- class torchmetrics.AveragePrecision(**kwargs)
Compute the average precision (AP) score.
The AP score summarizes a precision-recall curve as a weighted mean of precisions at each threshold, with the difference in recall from the previous threshold as weight:
\[AP = \sum_{n} (R_n - R_{n-1}) P_n\]
where \(P_n, R_n\) are the respective precision and recall at threshold index \(n\). This value is equivalent to the area under the precision-recall curve (AUPRC).
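The sum above can be evaluated directly. In this minimal pure-Python sketch (not the library's implementation) every distinct score serves as a threshold, visited from highest to lowest, and tied scores are grouped into a single threshold; average_precision_sketch is a hypothetical name. It reproduces the 1.0 and 0.5833 values from the non-binned binary examples on this page.

```python
from typing import Sequence


def average_precision_sketch(scores: Sequence[float], labels: Sequence[int]) -> float:
    """AP = sum_n (R_n - R_{n-1}) * P_n, thresholds taken from the distinct scores."""
    n_pos = sum(labels)  # assumed > 0, otherwise recall is undefined
    ranked = sorted(zip(scores, labels), key=lambda pair: pair[0], reverse=True)
    ap, prev_recall, tp, fp, i = 0.0, 0.0, 0, 0, 0
    while i < len(ranked):
        threshold = ranked[i][0]
        # All samples tied at this score become positive predictions together.
        while i < len(ranked) and ranked[i][0] == threshold:
            tp += ranked[i][1]
            fp += 1 - ranked[i][1]
            i += 1
        precision = tp / (tp + fp)
        recall = tp / n_pos
        ap += (recall - prev_recall) * precision  # (R_n - R_{n-1}) * P_n
        prev_recall = recall
    return ap


print(round(average_precision_sketch([0, 0.1, 0.8, 0.4], [0, 1, 1, 1]), 4))  # 1.0
print(round(average_precision_sketch([0, 0.5, 0.7, 0.8], [0, 1, 1, 0]), 4))  # 0.5833
```

For the second input the terms are 0.5 * 0.5 at threshold 0.7 and 0.5 * (2/3) at threshold 0.5, which sum to 7/12.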
This function is a simple wrapper to get the task-specific versions of this metric, which is done by setting the task argument to either 'binary', 'multiclass' or 'multilabel'. See the documentation of BinaryAveragePrecision, MulticlassAveragePrecision and MultilabelAveragePrecision for the specific details of how each argument influences the result, and for examples.
- Legacy Example:
>>> from torch import tensor
>>> from torchmetrics import AveragePrecision
>>> pred = tensor([0, 0.1, 0.8, 0.4])
>>> target = tensor([0, 1, 1, 1])
>>> average_precision = AveragePrecision(task="binary")
>>> average_precision(pred, target)
tensor(1.)
>>> pred = tensor([[0.75, 0.05, 0.05, 0.05, 0.05],
...                [0.05, 0.75, 0.05, 0.05, 0.05],
...                [0.05, 0.05, 0.75, 0.05, 0.05],
...                [0.05, 0.05, 0.05, 0.75, 0.05]])
>>> target = tensor([0, 1, 3, 2])
>>> average_precision = AveragePrecision(task="multiclass", num_classes=5, average=None)
>>> average_precision(pred, target)
tensor([1.0000, 1.0000, 0.2500, 0.2500, nan])
BinaryAveragePrecision
- class torchmetrics.classification.BinaryAveragePrecision(thresholds=None, ignore_index=None, validate_args=True, **kwargs)
Compute the average precision (AP) score for binary tasks.
The AP score summarizes a precision-recall curve as a weighted mean of precisions at each threshold, with the difference in recall from the previous threshold as weight:
\[AP = \sum_{n} (R_n - R_{n-1}) P_n\]
where \(P_n, R_n\) are the respective precision and recall at threshold index \(n\). This value is equivalent to the area under the precision-recall curve (AUPRC).
As input to forward and update the metric accepts the following input:
preds (Tensor): A float tensor of shape (N, ...) containing probabilities or logits for each observation. If preds has values outside the [0, 1] range, the input is considered to be logits and sigmoid is automatically applied per element.
target (Tensor): An int tensor of shape (N, ...) containing ground truth labels, which therefore only contain {0,1} values (except if ignore_index is specified). The value 1 always encodes the positive class.
As output to forward and compute the metric returns the following output:
bap (Tensor): A single scalar with the average precision score
Additional dimensions ... will be flattened into the batch dimension.
The implementation supports both a non-binned but accurate version and a binned version that is less accurate but more memory efficient. Setting the thresholds argument to None activates the non-binned version, which uses memory of size \(\mathcal{O}(n_{samples})\), whereas setting the thresholds argument to an integer, a list or a 1d tensor activates the binned version, which uses memory of size \(\mathcal{O}(n_{thresholds})\) (constant memory).
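A binned variant keeps only a true-positive and a false-positive counter per threshold, which is where the constant memory comes from. The pure-Python sketch below is hypothetical; it assumes thresholds linearly spaced from 0 to 1, a sample counted as predicted positive when pred >= threshold, and a precision of 1.0 when no sample passes a threshold (that term is multiplied by a zero recall step in this sketch anyway). With thresholds=5 it gives 2/3 for the data of the example below, matching the tensor(0.6667) shown there.

```python
from typing import Sequence


class BinnedAveragePrecisionSketch:
    """Hypothetical binned AP whose state size is O(n_thresholds)."""

    def __init__(self, num_thresholds: int) -> None:
        self.thresholds = [i / (num_thresholds - 1) for i in range(num_thresholds)]
        self.tp = [0] * num_thresholds  # positives with pred >= threshold
        self.fp = [0] * num_thresholds  # negatives with pred >= threshold
        self.n_pos = 0  # total positive targets seen

    def update(self, preds: Sequence[float], target: Sequence[int]) -> None:
        for p, t in zip(preds, target):
            self.n_pos += t
            for i, thr in enumerate(self.thresholds):
                if p >= thr:
                    if t == 1:
                        self.tp[i] += 1
                    else:
                        self.fp[i] += 1

    def compute(self) -> float:
        ap, prev_recall = 0.0, 0.0
        # Highest threshold first, so recall never decreases.
        for i in reversed(range(len(self.thresholds))):
            tp, fp = self.tp[i], self.fp[i]
            precision = tp / (tp + fp) if tp + fp > 0 else 1.0
            recall = tp / self.n_pos
            ap += (recall - prev_recall) * precision
            prev_recall = recall
        return ap


metric = BinnedAveragePrecisionSketch(num_thresholds=5)
metric.update([0, 0.5], [0, 1])  # first batch
metric.update([0.7, 0.8], [1, 0])  # second batch
print(round(metric.compute(), 4))  # 0.6667
```

The binned result differs from the exact 0.5833 because thresholds 0.5 and 0.25 both admit the samples scored 0.5 and 0.7 at once, so the intermediate point with recall 0.5 and precision 0.5 is never seen.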
- Parameters:
thresholds (Union[int, List[float], Tensor, None]) – Can be one of:
If set to None, will use a non-binned approach where thresholds are dynamically calculated from all the data. Most accurate but also most memory consuming approach.
If set to an int (larger than 1), will use that number of thresholds linearly spaced from 0 to 1 as bins for the calculation.
If set to a list of floats, will use the indicated thresholds in the list as bins for the calculation.
If set to a 1d tensor of floats, will use the indicated thresholds in the tensor as bins for the calculation.
ignore_index (Optional[int]) – Specifies a target value that is ignored and does not contribute to the metric calculation
validate_args (bool) – bool indicating if input arguments and tensors should be validated for correctness. Set to False for faster computations.
kwargs (Any) – Additional keyword arguments, see Advanced metric settings for more info.
Example
>>> from torch import tensor
>>> from torchmetrics.classification import BinaryAveragePrecision
>>> preds = tensor([0, 0.5, 0.7, 0.8])
>>> target = tensor([0, 1, 1, 0])
>>> metric = BinaryAveragePrecision(thresholds=None)
>>> metric(preds, target)
tensor(0.5833)
>>> bap = BinaryAveragePrecision(thresholds=5)
>>> bap(preds, target)
tensor(0.6667)
- plot(val=None, ax=None)
Plot a single value or multiple values from the metric.
- Parameters:
val (Union[Tensor, Sequence[Tensor], None]) – Either a single result from calling metric.forward or metric.compute, or a list of these results. If no value is provided, metric.compute is called automatically and that result is plotted.
ax (Optional[Axes]) – A matplotlib axis object. If provided, the plot is added to that axis.
- Returns:
Figure and Axes object
- Raises:
ModuleNotFoundError – If matplotlib is not installed
>>> # Example plotting a single value
>>> import torch
>>> from torchmetrics.classification import BinaryAveragePrecision
>>> metric = BinaryAveragePrecision()
>>> metric.update(torch.rand(20,), torch.randint(2, (20,)))
>>> fig_, ax_ = metric.plot()
>>> # Example plotting multiple values
>>> import torch
>>> from torchmetrics.classification import BinaryAveragePrecision
>>> metric = BinaryAveragePrecision()
>>> values = []
>>> for _ in range(10):
...     values.append(metric(torch.rand(20,), torch.randint(2, (20,))))
>>> fig_, ax_ = metric.plot(values)
MulticlassAveragePrecision
- class torchmetrics.classification.MulticlassAveragePrecision(num_classes, average='macro', thresholds=None, ignore_index=None, validate_args=True, **kwargs)
Compute the average precision (AP) score for multiclass tasks.
The AP score summarizes a precision-recall curve as a weighted mean of precisions at each threshold, with the difference in recall from the previous threshold as weight:
\[AP = \sum_{n} (R_n - R_{n-1}) P_n\]
where \(P_n, R_n\) are the respective precision and recall at threshold index \(n\). This value is equivalent to the area under the precision-recall curve (AUPRC).
For multiclass the metric is calculated by iteratively treating each class as the positive class and all other classes as the negative, which is referred to as the one-vs-rest approach. One-vs-one is currently not supported by this metric. By default the reported metric is then the average over all classes, but this behavior can be changed by setting the average argument.
As input to forward and update the metric accepts the following input:
preds (Tensor): A float tensor of shape (N, C, ...) containing probabilities or logits for each observation. If preds has values outside the [0, 1] range, the input is considered to be logits and softmax is automatically applied per sample.
target (Tensor): An int tensor of shape (N, ...) containing ground truth labels, which therefore only contain values in the [0, n_classes-1] range (except if ignore_index is specified).
As output to forward and compute the metric returns the following output:
mcap (Tensor): If average=None|"none", a 1d tensor of shape (n_classes, ) is returned with the AP score per class. If average="macro"|"weighted", a single scalar is returned.
Additional dimensions ... will be flattened into the batch dimension.
The implementation supports both a non-binned but accurate version and a binned version that is less accurate but more memory efficient. Setting the thresholds argument to None activates the non-binned version, which uses memory of size \(\mathcal{O}(n_{samples})\), whereas setting the thresholds argument to an integer, a list or a 1d tensor activates the binned version, which uses memory of size \(\mathcal{O}(n_{thresholds} \times n_{classes})\) (constant memory).
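The one-vs-rest procedure, and the nan in the example below, can be illustrated with a short pure-Python sketch. It is not the library's code: a class with no positive targets gets nan because its recall is undefined, and macro averaging skips nan entries, which is what the example's tensor(0.6250) implies (the mean of the four finite scores). The helper names are hypothetical.

```python
import math
from typing import Sequence


def binary_ap(scores: Sequence[float], labels: Sequence[int]) -> float:
    """Non-binned AP = sum_n (R_n - R_{n-1}) * P_n; nan when there are no positives."""
    n_pos = sum(labels)
    if n_pos == 0:
        return math.nan  # recall is undefined without positive targets
    ranked = sorted(zip(scores, labels), key=lambda pair: pair[0], reverse=True)
    ap, prev_recall, tp, fp, i = 0.0, 0.0, 0, 0, 0
    while i < len(ranked):
        threshold = ranked[i][0]
        while i < len(ranked) and ranked[i][0] == threshold:  # group tied scores
            tp += ranked[i][1]
            fp += 1 - ranked[i][1]
            i += 1
        ap += (tp / n_pos - prev_recall) * tp / (tp + fp)
        prev_recall = tp / n_pos
    return ap


preds = [
    [0.75, 0.05, 0.05, 0.05, 0.05],
    [0.05, 0.75, 0.05, 0.05, 0.05],
    [0.05, 0.05, 0.75, 0.05, 0.05],
    [0.05, 0.05, 0.05, 0.75, 0.05],
]
target = [0, 1, 3, 2]
# One-vs-rest: class c is the positive class, all other classes are negative.
per_class = [
    binary_ap([row[c] for row in preds], [int(t == c) for t in target]) for c in range(5)
]
finite = [s for s in per_class if not math.isnan(s)]
macro = sum(finite) / len(finite)
print(per_class)  # [1.0, 1.0, 0.25, 0.25, nan]
print(macro)  # 0.625
```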
- Parameters:
num_classes (int) – Integer specifying the number of classes
average (Optional[Literal['macro', 'weighted', 'none']]) – Defines the reduction that is applied over classes. Should be one of the following:
macro: Calculate score for each class and average them
weighted: Calculate score for each class and compute a weighted average using their support
"none" or None: Calculate score for each class and apply no reduction
thresholds (Union[int, List[float], Tensor, None]) – Can be one of:
If set to None, will use a non-binned approach where thresholds are dynamically calculated from all the data. Most accurate but also most memory consuming approach.
If set to an int (larger than 1), will use that number of thresholds linearly spaced from 0 to 1 as bins for the calculation.
If set to a list of floats, will use the indicated thresholds in the list as bins for the calculation.
If set to a 1d tensor of floats, will use the indicated thresholds in the tensor as bins for the calculation.
ignore_index (Optional[int]) – Specifies a target value that is ignored and does not contribute to the metric calculation
validate_args (bool) – bool indicating if input arguments and tensors should be validated for correctness. Set to False for faster computations.
kwargs (Any) – Additional keyword arguments, see Advanced metric settings for more info.
Example
>>> from torch import tensor
>>> from torchmetrics.classification import MulticlassAveragePrecision
>>> preds = tensor([[0.75, 0.05, 0.05, 0.05, 0.05],
...                 [0.05, 0.75, 0.05, 0.05, 0.05],
...                 [0.05, 0.05, 0.75, 0.05, 0.05],
...                 [0.05, 0.05, 0.05, 0.75, 0.05]])
>>> target = tensor([0, 1, 3, 2])
>>> metric = MulticlassAveragePrecision(num_classes=5, average="macro", thresholds=None)
>>> metric(preds, target)
tensor(0.6250)
>>> mcap = MulticlassAveragePrecision(num_classes=5, average=None, thresholds=None)
>>> mcap(preds, target)
tensor([1.0000, 1.0000, 0.2500, 0.2500, nan])
>>> mcap = MulticlassAveragePrecision(num_classes=5, average="macro", thresholds=5)
>>> mcap(preds, target)
tensor(0.5000)
>>> mcap = MulticlassAveragePrecision(num_classes=5, average=None, thresholds=5)
>>> mcap(preds, target)
tensor([1.0000, 1.0000, 0.2500, 0.2500, -0.0000])
- plot(val=None, ax=None)
Plot a single value or multiple values from the metric.
- Parameters:
val (Union[Tensor, Sequence[Tensor], None]) – Either a single result from calling metric.forward or metric.compute, or a list of these results. If no value is provided, metric.compute is called automatically and that result is plotted.
ax (Optional[Axes]) – A matplotlib axis object. If provided, the plot is added to that axis.
- Returns:
Figure and Axes object
- Raises:
ModuleNotFoundError – If matplotlib is not installed
>>> # Example plotting a single value
>>> import torch
>>> from torchmetrics.classification import MulticlassAveragePrecision
>>> metric = MulticlassAveragePrecision(num_classes=3)
>>> metric.update(torch.randn(20, 3), torch.randint(3, (20,)))
>>> fig_, ax_ = metric.plot()
>>> # Example plotting multiple values
>>> import torch
>>> from torchmetrics.classification import MulticlassAveragePrecision
>>> metric = MulticlassAveragePrecision(num_classes=3)
>>> values = []
>>> for _ in range(10):
...     values.append(metric(torch.randn(20, 3), torch.randint(3, (20,))))
>>> fig_, ax_ = metric.plot(values)
MultilabelAveragePrecision
- class torchmetrics.classification.MultilabelAveragePrecision(num_labels, average='macro', thresholds=None, ignore_index=None, validate_args=True, **kwargs)
Compute the average precision (AP) score for multilabel tasks.
The AP score summarizes a precision-recall curve as a weighted mean of precisions at each threshold, with the difference in recall from the previous threshold as weight:
\[AP = \sum_{n} (R_n - R_{n-1}) P_n\]
where \(P_n, R_n\) are the respective precision and recall at threshold index \(n\). This value is equivalent to the area under the precision-recall curve (AUPRC).
As input to forward and update the metric accepts the following input:
preds (Tensor): A float tensor of shape (N, C, ...) containing probabilities or logits for each observation. If preds has values outside the [0, 1] range, the input is considered to be logits and sigmoid is automatically applied per element.
target (Tensor): An int tensor of shape (N, C, ...) containing ground truth labels, which therefore only contain {0,1} values (except if ignore_index is specified).
As output to forward and compute the metric returns the following output:
mlap (Tensor): If average=None|"none", a 1d tensor of shape (n_labels, ) is returned with the AP score per label. If average="micro"|"macro"|"weighted", a single scalar is returned.
Additional dimensions ... will be flattened into the batch dimension.
The implementation supports both a non-binned but accurate version and a binned version that is less accurate but more memory efficient. Setting the thresholds argument to None activates the non-binned version, which uses memory of size \(\mathcal{O}(n_{samples})\), whereas setting the thresholds argument to an integer, a list or a 1d tensor activates the binned version, which uses memory of size \(\mathcal{O}(n_{thresholds} \times n_{labels})\) (constant memory).
- Parameters:
num_labels (int) – Integer specifying the number of labels
average (Optional[Literal['micro', 'macro', 'weighted', 'none']]) – Defines the reduction that is applied over labels. Should be one of the following:
micro: Sum score over all labels
macro: Calculate score for each label and average them
weighted: Calculate score for each label and compute a weighted average using their support
"none" or None: Calculate score for each label and apply no reduction
thresholds (Union[int, List[float], Tensor, None]) – Can be one of:
If set to None, will use a non-binned approach where thresholds are dynamically calculated from all the data. Most accurate but also most memory consuming approach.
If set to an int (larger than 1), will use that number of thresholds linearly spaced from 0 to 1 as bins for the calculation.
If set to a list of floats, will use the indicated thresholds in the list as bins for the calculation.
If set to a 1d tensor of floats, will use the indicated thresholds in the tensor as bins for the calculation.
ignore_index (Optional[int]) – Specifies a target value that is ignored and does not contribute to the metric calculation
validate_args (bool) – bool indicating if input arguments and tensors should be validated for correctness. Set to False for faster computations.
kwargs (Any) – Additional keyword arguments, see Advanced metric settings for more info.
Example
>>> from torch import tensor
>>> from torchmetrics.classification import MultilabelAveragePrecision
>>> preds = tensor([[0.75, 0.05, 0.35],
...                 [0.45, 0.75, 0.05],
...                 [0.05, 0.55, 0.75],
...                 [0.05, 0.65, 0.05]])
>>> target = tensor([[1, 0, 1],
...                  [0, 0, 0],
...                  [0, 1, 1],
...                  [1, 1, 1]])
>>> metric = MultilabelAveragePrecision(num_labels=3, average="macro", thresholds=None)
>>> metric(preds, target)
tensor(0.7500)
>>> mlap = MultilabelAveragePrecision(num_labels=3, average=None, thresholds=None)
>>> mlap(preds, target)
tensor([0.7500, 0.5833, 0.9167])
>>> mlap = MultilabelAveragePrecision(num_labels=3, average="macro", thresholds=5)
>>> mlap(preds, target)
tensor(0.7778)
>>> mlap = MultilabelAveragePrecision(num_labels=3, average=None, thresholds=5)
>>> mlap(preds, target)
tensor([0.7500, 0.6667, 0.9167])
- plot(val=None, ax=None)
Plot a single value or multiple values from the metric.
- Parameters:
val (Union[Tensor, Sequence[Tensor], None]) – Either a single result from calling metric.forward or metric.compute, or a list of these results. If no value is provided, metric.compute is called automatically and that result is plotted.
ax (Optional[Axes]) – A matplotlib axis object. If provided, the plot is added to that axis.
- Returns:
Figure and Axes object
- Raises:
ModuleNotFoundError – If matplotlib is not installed
>>> # Example plotting a single value
>>> import torch
>>> from torchmetrics.classification import MultilabelAveragePrecision
>>> metric = MultilabelAveragePrecision(num_labels=3)
>>> metric.update(torch.rand(20, 3), torch.randint(2, (20, 3)))
>>> fig_, ax_ = metric.plot()
>>> # Example plotting multiple values
>>> import torch
>>> from torchmetrics.classification import MultilabelAveragePrecision
>>> metric = MultilabelAveragePrecision(num_labels=3)
>>> values = []
>>> for _ in range(10):
...     values.append(metric(torch.rand(20, 3), torch.randint(2, (20, 3))))
>>> fig_, ax_ = metric.plot(values)
Functional Interface
- torchmetrics.functional.average_precision(preds, target, task, thresholds=None, num_classes=None, num_labels=None, average='macro', ignore_index=None, validate_args=True)
Compute the average precision (AP) score.
The AP score summarizes a precision-recall curve as a weighted mean of precisions at each threshold, with the difference in recall from the previous threshold as weight:
\[AP = \sum_{n} (R_n - R_{n-1}) P_n\]
where \(P_n\) and \(R_n\) are the precision and recall at threshold index \(n\). This value is equivalent to the area under the precision-recall curve (AUPRC).
- Return type: Optional[Tensor]
This function is a simple wrapper to get the task-specific versions of this metric, which is done by setting the task argument to either 'binary', 'multiclass' or 'multilabel'. See the documentation of binary_average_precision(), multiclass_average_precision() and multilabel_average_precision() for the specific details of each argument's influence and examples.
- Legacy Example:
>>> import torch
>>> from torchmetrics.functional.classification import average_precision
>>> pred = torch.tensor([0.0, 1.0, 2.0, 3.0])
>>> target = torch.tensor([0, 1, 1, 1])
>>> average_precision(pred, target, task="binary")
tensor(1.)
>>> pred = torch.tensor([[0.75, 0.05, 0.05, 0.05, 0.05],
...                      [0.05, 0.75, 0.05, 0.05, 0.05],
...                      [0.05, 0.05, 0.75, 0.05, 0.05],
...                      [0.05, 0.05, 0.05, 0.75, 0.05]])
>>> target = torch.tensor([0, 1, 3, 2])
>>> average_precision(pred, target, task="multiclass", num_classes=5, average=None)
tensor([1.0000, 1.0000, 0.2500, 0.2500, nan])
binary_average_precision
- torchmetrics.functional.classification.binary_average_precision(preds, target, thresholds=None, ignore_index=None, validate_args=True)
Compute the average precision (AP) score for binary tasks.
The AP score summarizes a precision-recall curve as a weighted mean of precisions at each threshold, with the difference in recall from the previous threshold as weight:
\[AP = \sum_{n} (R_n - R_{n-1}) P_n\]
where \(P_n\) and \(R_n\) are the precision and recall at threshold index \(n\). This value is equivalent to the area under the precision-recall curve (AUPRC).
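To make the sum concrete, here is a minimal pure-Python sketch of the non-binned formula. It is not the TorchMetrics implementation: the function name is hypothetical, inputs are assumed to already be probabilities (no sigmoid is applied), and ignore_index is not handled. Each distinct score is treated as one threshold, so tied scores enter the sum together.

```python
from typing import Sequence


def binary_average_precision_exact(preds: Sequence[float], target: Sequence[int]) -> float:
    """Hypothetical helper: AP = sum_n (R_n - R_{n-1}) * P_n over all distinct scores."""
    n_pos = sum(target)
    if n_pos == 0:
        return float("nan")  # recall (and hence AP) is undefined without positives
    ap = 0.0
    prev_recall = 0.0
    # Sweep the thresholds from the highest observed score to the lowest.
    for thr in sorted(set(preds), reverse=True):
        tp = sum(1 for p, t in zip(preds, target) if p >= thr and t == 1)
        fp = sum(1 for p, t in zip(preds, target) if p >= thr and t == 0)
        precision = tp / (tp + fp)  # tp + fp >= 1 because thr is an observed score
        recall = tp / n_pos
        ap += (recall - prev_recall) * precision
        prev_recall = recall
    return ap


print(round(binary_average_precision_exact([0, 0.5, 0.7, 0.8], [0, 1, 1, 0]), 4))  # 0.5833
```

With the inputs of the binary_average_precision example below this gives 0.5833, the thresholds=None value shown there. The sketch re-scans the data for every threshold, which is quadratic; sorting once and using cumulative counts is the usual efficient approach.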
Accepts the following input tensors:
preds (float tensor): (N, ...). Preds should be a tensor containing probabilities or logits for each observation. If preds has values outside the [0,1] range, we consider the input to be logits and automatically apply sigmoid per element.
target (int tensor): (N, ...). Target should be a tensor containing ground truth labels, and therefore only contain {0,1} values (except if ignore_index is specified). The value 1 always encodes the positive class.
Additional dimensions ... will be flattened into the batch dimension.
The implementation supports both a non-binned but accurate version and a binned version that is less accurate but more memory efficient. Setting the thresholds argument to None activates the non-binned version, which uses memory of size \(\mathcal{O}(n_{samples})\), whereas setting the thresholds argument to an integer, a list or a 1d tensor activates the binned version, which uses memory of size \(\mathcal{O}(n_{thresholds})\) (constant memory).
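The binned variant can be sketched in the same spirit. Its only state is a pair of counters per threshold, which is why memory stays at \(\mathcal{O}(n_{thresholds})\) no matter how many samples are seen. This is a hypothetical pure-Python illustration, not the library code; counting a sample as positive when its score is >= the threshold, and taking precision as 1 when nothing is predicted positive, are assumptions of the sketch.

```python
from typing import List, Sequence


def binary_average_precision_binned(
    preds: Sequence[float], target: Sequence[int], num_thresholds: int
) -> float:
    """Hypothetical helper: AP from fixed, linearly spaced thresholds in [0, 1]."""
    thresholds: List[float] = [k / (num_thresholds - 1) for k in range(num_thresholds)]
    tps = [0] * num_thresholds  # true positives per threshold
    fps = [0] * num_thresholds  # false positives per threshold
    n_pos = 0
    # A single pass only updates fixed-size counters, so batches could be streamed in.
    for p, t in zip(preds, target):
        n_pos += t
        for k, thr in enumerate(thresholds):
            if p >= thr:
                if t == 1:
                    tps[k] += 1
                else:
                    fps[k] += 1
    if n_pos == 0:
        return float("nan")
    precision = [tp / (tp + fp) if tp + fp > 0 else 1.0 for tp, fp in zip(tps, fps)]
    recall = [tp / n_pos for tp in tps]
    # Close the curve at (recall=0, precision=1). Recall shrinks as the threshold
    # grows, so the step sum is negated to obtain a positive area.
    precision.append(1.0)
    recall.append(0.0)
    return -sum((recall[k + 1] - recall[k]) * precision[k] for k in range(num_thresholds))


print(round(binary_average_precision_binned([0, 0.5, 0.7, 0.8], [0, 1, 1, 0], 5), 4))  # 0.6667
```

With num_thresholds=5 the grid is 0, 0.25, 0.5, 0.75 and 1, so there is no threshold between 0.5 and 0.7 and the operating point that keeps only the 0.7 and 0.8 scores is lost. That is why this returns 0.6667 (the thresholds=5 value in the example below) instead of the exact 0.5833.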
- Parameters:
preds (Tensor) – Tensor with predictions
target (Tensor) – Tensor with true labels
thresholds (Union[int, List[float], Tensor, None]) – Can be one of:
If set to None, will use a non-binned approach where thresholds are dynamically calculated from all the data. Most accurate but also most memory consuming approach.
If set to an int (larger than 1), will use that number of thresholds linearly spaced from 0 to 1 as bins for the calculation.
If set to a list of floats, will use the indicated thresholds in the list as bins for the calculation.
If set to a 1d tensor of floats, will use the indicated thresholds in the tensor as bins for the calculation.
ignore_index (Optional[int]) – Specifies a target value that is ignored and does not contribute to the metric calculation
validate_args (bool) – bool indicating if input arguments and tensors should be validated for correctness. Set to False for faster computations.
- Return type:
- Returns:
A single scalar with the average precision score
Example
>>> import torch
>>> from torchmetrics.functional.classification import binary_average_precision
>>> preds = torch.tensor([0, 0.5, 0.7, 0.8])
>>> target = torch.tensor([0, 1, 1, 0])
>>> binary_average_precision(preds, target, thresholds=None)
tensor(0.5833)
>>> binary_average_precision(preds, target, thresholds=5)
tensor(0.6667)
multiclass_average_precision
- torchmetrics.functional.classification.multiclass_average_precision(preds, target, num_classes, average='macro', thresholds=None, ignore_index=None, validate_args=True)
Compute the average precision (AP) score for multiclass tasks.
The AP score summarizes a precision-recall curve as a weighted mean of precisions at each threshold, with the difference in recall from the previous threshold as weight:
\[AP = \sum_{n} (R_n - R_{n-1}) P_n\]
where \(P_n\) and \(R_n\) are the precision and recall at threshold index \(n\). This value is equivalent to the area under the precision-recall curve (AUPRC).
Accepts the following input tensors:
preds (float tensor): (N, C, ...). Preds should be a tensor containing probabilities or logits for each observation. If preds has values outside the [0,1] range, we consider the input to be logits and automatically apply softmax per sample.
target (int tensor): (N, ...). Target should be a tensor containing ground truth labels, and therefore only contain values in the [0, n_classes-1] range (except if ignore_index is specified).
Additional dimensions ... will be flattened into the batch dimension.
The implementation supports both a non-binned but accurate version and a binned version that is less accurate but more memory efficient. Setting the thresholds argument to None activates the non-binned version, which uses memory of size \(\mathcal{O}(n_{samples})\), whereas setting the thresholds argument to an integer, a list or a 1d tensor activates the binned version, which uses memory of size \(\mathcal{O}(n_{thresholds} \times n_{classes})\) (constant memory).
- Parameters:
preds (Tensor) – Tensor with predictions
target (Tensor) – Tensor with true labels
num_classes (int) – Integer specifying the number of classes
average (Optional[Literal['macro', 'weighted', 'none']]) – Defines the reduction that is applied over classes. Should be one of the following:
macro: Calculate score for each class and average them
weighted: calculates score for each class and computes weighted average using their support
"none" or None: calculates score for each class and applies no reduction
thresholds (Union[int, List[float], Tensor, None]) – Can be one of:
If set to None, will use a non-binned approach where thresholds are dynamically calculated from all the data. Most accurate but also most memory consuming approach.
If set to an int (larger than 1), will use that number of thresholds linearly spaced from 0 to 1 as bins for the calculation.
If set to a list of floats, will use the indicated thresholds in the list as bins for the calculation.
If set to a 1d tensor of floats, will use the indicated thresholds in the tensor as bins for the calculation.
ignore_index (Optional[int]) – Specifies a target value that is ignored and does not contribute to the metric calculation
validate_args (bool) – bool indicating if input arguments and tensors should be validated for correctness. Set to False for faster computations.
- Return type:
- Returns:
If average=None or "none", a 1d tensor of shape (n_classes,) is returned with the AP score per class. If average="macro" or "weighted", a single scalar is returned.
Example
>>> import torch
>>> from torchmetrics.functional.classification import multiclass_average_precision
>>> preds = torch.tensor([[0.75, 0.05, 0.05, 0.05, 0.05],
...                       [0.05, 0.75, 0.05, 0.05, 0.05],
...                       [0.05, 0.05, 0.75, 0.05, 0.05],
...                       [0.05, 0.05, 0.05, 0.75, 0.05]])
>>> target = torch.tensor([0, 1, 3, 2])
>>> multiclass_average_precision(preds, target, num_classes=5, average="macro", thresholds=None)
tensor(0.6250)
>>> multiclass_average_precision(preds, target, num_classes=5, average=None, thresholds=None)
tensor([1.0000, 1.0000, 0.2500, 0.2500, nan])
>>> multiclass_average_precision(preds, target, num_classes=5, average="macro", thresholds=5)
tensor(0.5000)
>>> multiclass_average_precision(preds, target, num_classes=5, average=None, thresholds=5)
tensor([1.0000, 1.0000, 0.2500, 0.2500, -0.0000])
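One way to read these numbers is as one-vs-rest binary problems: class c is scored with column c of preds against the indicator target == c. The pure-Python sketch below (hypothetical helper names, probabilities assumed, no ignore_index) reproduces the non-binned per-class values above. That the macro average skips classes whose AP is nan (class 4 never occurs in target) is inferred from the 0.6250 result, which is the mean of the four defined scores.

```python
import math
from typing import List, Sequence


def binary_ap(scores: Sequence[float], labels: Sequence[int]) -> float:
    """Exact (non-binned) AP for one class; nan if the class has no positives."""
    n_pos = sum(labels)
    if n_pos == 0:
        return math.nan
    ap, prev_recall = 0.0, 0.0
    for thr in sorted(set(scores), reverse=True):  # distinct scores act as thresholds
        tp = sum(1 for s, y in zip(scores, labels) if s >= thr and y == 1)
        fp = sum(1 for s, y in zip(scores, labels) if s >= thr and y == 0)
        recall = tp / n_pos
        ap += (recall - prev_recall) * tp / (tp + fp)
        prev_recall = recall
    return ap


def multiclass_ap_per_class(preds: List[List[float]], target: List[int], num_classes: int) -> List[float]:
    """One-vs-rest: column c of preds against the binary indicator (target == c)."""
    return [
        binary_ap([row[c] for row in preds], [int(t == c) for t in target])
        for c in range(num_classes)
    ]


def macro_average(scores: List[float]) -> float:
    """Unweighted mean over the classes whose AP is defined (nan entries skipped)."""
    defined = [s for s in scores if not math.isnan(s)]
    return sum(defined) / len(defined)


preds = [
    [0.75, 0.05, 0.05, 0.05, 0.05],
    [0.05, 0.75, 0.05, 0.05, 0.05],
    [0.05, 0.05, 0.75, 0.05, 0.05],
    [0.05, 0.05, 0.05, 0.75, 0.05],
]
target = [0, 1, 3, 2]
per_class = multiclass_ap_per_class(preds, target, num_classes=5)
print(per_class)  # [1.0, 1.0, 0.25, 0.25, nan]
print(macro_average(per_class))  # 0.625
```

Classes 2 and 3 score 0.25 because their single positive sample shares the score 0.05 with three negatives, so all four enter the curve at the same threshold.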
multilabel_average_precision
- torchmetrics.functional.classification.multilabel_average_precision(preds, target, num_labels, average='macro', thresholds=None, ignore_index=None, validate_args=True)
Compute the average precision (AP) score for multilabel tasks.
The AP score summarizes a precision-recall curve as a weighted mean of precisions at each threshold, with the difference in recall from the previous threshold as weight:
\[AP = \sum_{n} (R_n - R_{n-1}) P_n\]
where \(P_n\) and \(R_n\) are the precision and recall at threshold index \(n\). This value is equivalent to the area under the precision-recall curve (AUPRC).
Accepts the following input tensors:
preds (float tensor): (N, C, ...). Preds should be a tensor containing probabilities or logits for each observation. If preds has values outside the [0,1] range, we consider the input to be logits and automatically apply sigmoid per element.
target (int tensor): (N, C, ...). Target should be a tensor containing ground truth labels, and therefore only contain {0,1} values (except if ignore_index is specified).
Additional dimensions ... will be flattened into the batch dimension.
The implementation supports both a non-binned but accurate version and a binned version that is less accurate but more memory efficient. Setting the thresholds argument to None activates the non-binned version, which uses memory of size \(\mathcal{O}(n_{samples})\), whereas setting the thresholds argument to an integer, a list or a 1d tensor activates the binned version, which uses memory of size \(\mathcal{O}(n_{thresholds} \times n_{labels})\) (constant memory).
- Parameters:
preds (Tensor) – Tensor with predictions
target (Tensor) – Tensor with true labels
num_labels (int) – Integer specifying the number of labels
average (Optional[Literal['micro', 'macro', 'weighted', 'none']]) – Defines the reduction that is applied over labels. Should be one of the following:
micro: Sum score over all labels
macro: Calculate score for each label and average them
weighted: calculates score for each label and computes weighted average using their support
"none" or None: calculates score for each label and applies no reduction
thresholds (Union[int, List[float], Tensor, None]) – Can be one of:
If set to None, will use a non-binned approach where thresholds are dynamically calculated from all the data. Most accurate but also most memory consuming approach.
If set to an int (larger than 1), will use that number of thresholds linearly spaced from 0 to 1 as bins for the calculation.
If set to a list of floats, will use the indicated thresholds in the list as bins for the calculation.
If set to a 1d tensor of floats, will use the indicated thresholds in the tensor as bins for the calculation.
ignore_index (Optional[int]) – Specifies a target value that is ignored and does not contribute to the metric calculation
validate_args (bool) – bool indicating if input arguments and tensors should be validated for correctness. Set to False for faster computations.
- Return type:
- Returns:
If average=None or "none", a 1d tensor of shape (n_labels,) is returned with the AP score per label. If average="micro", "macro" or "weighted", a single scalar is returned.
Example
>>> import torch
>>> from torchmetrics.functional.classification import multilabel_average_precision
>>> preds = torch.tensor([[0.75, 0.05, 0.35],
...                       [0.45, 0.75, 0.05],
...                       [0.05, 0.55, 0.75],
...                       [0.05, 0.65, 0.05]])
>>> target = torch.tensor([[1, 0, 1],
...                        [0, 0, 0],
...                        [0, 1, 1],
...                        [1, 1, 1]])
>>> multilabel_average_precision(preds, target, num_labels=3, average="macro", thresholds=None)
tensor(0.7500)
>>> multilabel_average_precision(preds, target, num_labels=3, average=None, thresholds=None)
tensor([0.7500, 0.5833, 0.9167])
>>> multilabel_average_precision(preds, target, num_labels=3, average="macro", thresholds=5)
tensor(0.7778)
>>> multilabel_average_precision(preds, target, num_labels=3, average=None, thresholds=5)
tensor([0.7500, 0.6667, 0.9167])
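In the multilabel case each label is an independent binary problem, so the non-binned per-label values above can be recovered by scoring each column of preds against the same column of target. Below is a minimal pure-Python sketch under the same assumptions as before (hypothetical helper names, probabilities rather than logits, no ignore_index); the macro value is the plain mean of the three per-label scores.

```python
from typing import List, Sequence


def binary_ap(scores: Sequence[float], labels: Sequence[int]) -> float:
    """Exact (non-binned) AP for a single label; nan if the label has no positives."""
    n_pos = sum(labels)
    if n_pos == 0:
        return float("nan")
    ap, prev_recall = 0.0, 0.0
    for thr in sorted(set(scores), reverse=True):  # distinct scores act as thresholds
        tp = sum(1 for s, y in zip(scores, labels) if s >= thr and y == 1)
        fp = sum(1 for s, y in zip(scores, labels) if s >= thr and y == 0)
        recall = tp / n_pos
        ap += (recall - prev_recall) * tp / (tp + fp)
        prev_recall = recall
    return ap


def multilabel_ap_per_label(preds: List[List[float]], target: List[List[int]]) -> List[float]:
    """Column j of preds is scored against column j of target."""
    num_labels = len(preds[0])
    return [binary_ap([row[j] for row in preds], [row[j] for row in target]) for j in range(num_labels)]


preds = [[0.75, 0.05, 0.35], [0.45, 0.75, 0.05], [0.05, 0.55, 0.75], [0.05, 0.65, 0.05]]
target = [[1, 0, 1], [0, 0, 0], [0, 1, 1], [1, 1, 1]]
per_label = multilabel_ap_per_label(preds, target)
print([round(v, 4) for v in per_label])  # [0.75, 0.5833, 0.9167]
print(round(sum(per_label) / len(per_label), 4))  # 0.75
```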
Calibration Error
Module Interface
- class torchmetrics.CalibrationError(**kwargs)
The expected calibration error can be used to quantify how well a given model is calibrated, e.g. how well the predicted output probabilities of the model match the actual probabilities of the ground truth distribution. Three different norms are implemented, each corresponding to a variation of the calibration error metric.
\[\text{ECE} = \sum_i^N b_i |p_i - c_i|, \text{L1 norm (Expected Calibration Error)}\]
\[\text{MCE} = \max_{i} |p_i - c_i|, \text{Infinity norm (Maximum Calibration Error)}\]
\[\text{RMSCE} = \sqrt{\sum_i^N b_i (p_i - c_i)^2}, \text{L2 norm (Root Mean Square Calibration Error)}\]
where \(p_i\) is the top-1 prediction accuracy in bin \(i\), \(c_i\) is the average confidence of predictions in bin \(i\), and \(b_i\) is the fraction of data points in bin \(i\). Bins are constructed uniformly over the [0,1] range.
This metric is a simple wrapper to get the task-specific versions of this metric, which is done by setting the task argument to either 'binary' or 'multiclass'. See the documentation of BinaryCalibrationError and MulticlassCalibrationError for the specific details of each argument's influence and examples.
BinaryCalibrationError
- class torchmetrics.classification.BinaryCalibrationError(n_bins=15, norm='l1', ignore_index=None, validate_args=True, **kwargs)
Top-label Calibration Error for binary tasks.
The expected calibration error can be used to quantify how well a given model is calibrated, e.g. how well the predicted output probabilities of the model match the actual probabilities of the ground truth distribution. Three different norms are implemented, each corresponding to a variation of the calibration error metric.
\[\text{ECE} = \sum_i^N b_i |p_i - c_i|, \text{L1 norm (Expected Calibration Error)}\]
\[\text{MCE} = \max_{i} |p_i - c_i|, \text{Infinity norm (Maximum Calibration Error)}\]
\[\text{RMSCE} = \sqrt{\sum_i^N b_i (p_i - c_i)^2}, \text{L2 norm (Root Mean Square Calibration Error)}\]
where \(p_i\) is the top-1 prediction accuracy in bin \(i\), \(c_i\) is the average confidence of predictions in bin \(i\), and \(b_i\) is the fraction of data points in bin \(i\). Bins are constructed uniformly over the [0,1] range.
As input to forward and update the metric accepts the following input:
preds (Tensor): A float tensor of shape (N, ...) containing probabilities or logits for each observation. If preds has values outside the [0,1] range, we consider the input to be logits and automatically apply sigmoid per element.
target (Tensor): An int tensor of shape (N, ...) containing ground truth labels, and therefore only containing {0,1} values (except if ignore_index is specified). The value 1 always encodes the positive class.
As output to forward and compute the metric returns the following output:
bce (Tensor): A scalar tensor containing the calibration error
Additional dimensions ... will be flattened into the batch dimension.
- Parameters:
n_bins (int) – Number of bins to use when computing the metric.
norm (Literal['l1', 'l2', 'max']) – Norm used to compare empirical and expected probability bins.
ignore_index (Optional[int]) – Specifies a target value that is ignored and does not contribute to the metric calculation
validate_args (bool) – bool indicating if input arguments and tensors should be validated for correctness. Set to False for faster computations.
kwargs (Any) – Additional keyword arguments, see Advanced metric settings for more info.
Example
>>> from torch import tensor
>>> from torchmetrics.classification import BinaryCalibrationError
>>> preds = tensor([0.25, 0.25, 0.55, 0.75, 0.75])
>>> target = tensor([0, 0, 1, 1, 1])
>>> metric = BinaryCalibrationError(n_bins=2, norm='l1')
>>> metric(preds, target)
tensor(0.2900)
>>> bce = BinaryCalibrationError(n_bins=2, norm='l2')
>>> bce(preds, target)
tensor(0.2918)
>>> bce = BinaryCalibrationError(n_bins=2, norm='max')
>>> bce(preds, target)
tensor(0.3167)
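The three outputs above can be reproduced by hand. The pure-Python sketch below treats each value of preds as the predicted probability of the positive class (the confidence) and each target as the observed outcome, then compares their per-bin averages. That reading is an assumption, but it matches all three values in the example; the bin-edge convention (left-closed bins, with a score of exactly 1.0 placed in the last bin) and the function name are hypothetical.

```python
import math
from typing import Sequence


def binary_calibration_error_sketch(
    preds: Sequence[float], target: Sequence[int], n_bins: int, norm: str = "l1"
) -> float:
    """Hypothetical helper: binned gap between mean confidence and observed frequency."""
    counts = [0] * n_bins
    conf_sum = [0.0] * n_bins
    pos_sum = [0.0] * n_bins
    for p, y in zip(preds, target):
        b = min(int(p * n_bins), n_bins - 1)  # uniform bins over [0, 1]
        counts[b] += 1
        conf_sum[b] += p
        pos_sum[b] += y
    gaps, weights = [], []
    for c, conf, pos in zip(counts, conf_sum, pos_sum):
        if c == 0:
            continue  # empty bins contribute nothing
        gaps.append(abs(pos / c - conf / c))  # |p_i - c_i|
        weights.append(c / len(preds))  # b_i, the fraction of samples in bin i
    if norm == "l1":
        return sum(w * g for w, g in zip(weights, gaps))
    if norm == "l2":
        return math.sqrt(sum(w * g * g for w, g in zip(weights, gaps)))
    if norm == "max":
        return max(gaps)
    raise ValueError(f"unsupported norm: {norm}")


preds = [0.25, 0.25, 0.55, 0.75, 0.75]
target = [0, 0, 1, 1, 1]
for norm in ("l1", "l2", "max"):
    print(norm, round(binary_calibration_error_sketch(preds, target, n_bins=2, norm=norm), 4))
# l1 0.29
# l2 0.2918
# max 0.3167
```

With two bins, the two 0.25 scores land in the lower bin (gap 0.25, weight 0.4) and the rest in the upper bin (mean confidence about 0.683 against an observed frequency of 1, weight 0.6), which gives 0.4 * 0.25 + 0.6 * 0.317 = 0.29 for the L1 norm.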
- plot(val=None, ax=None)
Plot a single or multiple values from the metric.
- Parameters:
val (Union[Tensor, Sequence[Tensor], None]) – Either a single result from calling metric.forward or metric.compute or a list of these results. If no value is provided, will automatically call metric.compute and plot that result.
ax (Optional[Axes]) – A matplotlib axis object. If provided, will add the plot to that axis.
- Return type:
- Returns:
Figure object and Axes object
- Raises:
ModuleNotFoundError – If matplotlib is not installed
>>> from torch import rand, randint
>>> # Example plotting a single value
>>> from torchmetrics.classification import BinaryCalibrationError
>>> metric = BinaryCalibrationError(n_bins=2, norm='l1')
>>> metric.update(rand(10), randint(2, (10,)))
>>> fig_, ax_ = metric.plot()

>>> from torch import rand, randint
>>> # Example plotting multiple values
>>> from torchmetrics.classification import BinaryCalibrationError
>>> metric = BinaryCalibrationError(n_bins=2, norm='l1')
>>> values = []
>>> for _ in range(10):
...     values.append(metric(rand(10), randint(2, (10,))))
>>> fig_, ax_ = metric.plot(values)
MulticlassCalibrationError
- class torchmetrics.classification.MulticlassCalibrationError(num_classes, n_bins=15, norm='l1', ignore_index=None, validate_args=True, **kwargs)
Top-label Calibration Error for multiclass tasks.
The expected calibration error can be used to quantify how well a given model is calibrated, e.g. how well the predicted output probabilities of the model match the actual probabilities of the ground truth distribution. Three different norms are implemented, each corresponding to a variation of the calibration error metric.
\[\text{ECE} = \sum_i^N b_i |p_i - c_i|, \text{L1 norm (Expected Calibration Error)}\]
\[\text{MCE} = \max_{i} |p_i - c_i|, \text{Infinity norm (Maximum Calibration Error)}\]
\[\text{RMSCE} = \sqrt{\sum_i^N b_i (p_i - c_i)^2}, \text{L2 norm (Root Mean Square Calibration Error)}\]
where \(p_i\) is the top-1 prediction accuracy in bin \(i\), \(c_i\) is the average confidence of predictions in bin \(i\), and \(b_i\) is the fraction of data points in bin \(i\). Bins are constructed uniformly over the [0,1] range.
As input to forward and update the metric accepts the following input:
preds (Tensor): A float tensor of shape (N, C, ...) containing probabilities or logits for each observation. If preds has values outside the [0,1] range, we consider the input to be logits and automatically apply softmax per sample.
target (Tensor): An int tensor of shape (N, ...) containing ground truth labels, and therefore only containing values in the [0, n_classes-1] range (except if ignore_index is specified).
Note
Additional dimensions ... will be flattened into the batch dimension.
As output to forward and compute the metric returns the following output:
mcce (Tensor): A scalar tensor containing the calibration error
- Parameters:
num_classes (int) – Integer specifying the number of classes
n_bins (int) – Number of bins to use when computing the metric.
norm (Literal['l1', 'l2', 'max']) – Norm used to compare empirical and expected probability bins.
ignore_index (Optional[int]) – Specifies a target value that is ignored and does not contribute to the metric calculation
validate_args (bool) – bool indicating if input arguments and tensors should be validated for correctness. Set to False for faster computations.
kwargs (Any) – Additional keyword arguments, see Advanced metric settings for more info.
Example
>>> from torch import tensor
>>> from torchmetrics.classification import MulticlassCalibrationError
>>> preds = tensor([[0.25, 0.20, 0.55],
...                 [0.55, 0.05, 0.40],
...                 [0.10, 0.30, 0.60],
...                 [0.90, 0.05, 0.05]])
>>> target = tensor([0, 1, 2, 0])
>>> metric = MulticlassCalibrationError(num_classes=3, n_bins=3, norm='l1')
>>> metric(preds, target)
tensor(0.2000)
>>> mcce = MulticlassCalibrationError(num_classes=3, n_bins=3, norm='l2')
>>> mcce(preds, target)
tensor(0.2082)
>>> mcce = MulticlassCalibrationError(num_classes=3, n_bins=3, norm='max')
>>> mcce(preds, target)
tensor(0.2333)
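For the multiclass case, the top-label reduction first turns every row into a single (confidence, correct) pair: the confidence is the largest probability, and the sample counts as correct when that class equals the target. A pure-Python sketch of that reduction followed by the three norms is shown below. The helper name and the bin-edge convention are assumptions; the printed values match the outputs of the example above.

```python
import math
from typing import List, Sequence, Tuple


def top_label_calibration_error(
    probs: Sequence[Sequence[float]], target: Sequence[int], n_bins: int, norm: str = "l1"
) -> float:
    """Hypothetical helper: top-1 confidence vs. top-1 accuracy, compared per bin."""
    bins: List[List[Tuple[float, int]]] = [[] for _ in range(n_bins)]
    for row, y in zip(probs, target):
        pred = max(range(len(row)), key=lambda k: row[k])  # argmax, first index on ties
        conf = row[pred]
        b = min(int(conf * n_bins), n_bins - 1)  # uniform bins over [0, 1]
        bins[b].append((conf, int(pred == y)))
    gaps, weights = [], []
    for members in bins:
        if not members:
            continue  # empty bins contribute nothing
        accuracy = sum(ok for _, ok in members) / len(members)  # p_i
        confidence = sum(c for c, _ in members) / len(members)  # c_i
        gaps.append(abs(accuracy - confidence))
        weights.append(len(members) / len(target))  # b_i
    if norm == "l1":
        return sum(w * g for w, g in zip(weights, gaps))
    if norm == "l2":
        return math.sqrt(sum(w * g * g for w, g in zip(weights, gaps)))
    if norm == "max":
        return max(gaps)
    raise ValueError(f"unsupported norm: {norm}")


probs = [[0.25, 0.20, 0.55], [0.55, 0.05, 0.40], [0.10, 0.30, 0.60], [0.90, 0.05, 0.05]]
target = [0, 1, 2, 0]
for norm in ("l1", "l2", "max"):
    print(norm, round(top_label_calibration_error(probs, target, n_bins=3, norm=norm), 4))
# l1 0.2
# l2 0.2082
# max 0.2333
```

Here three samples fall into the middle bin with mean confidence about 0.567 but only one correct prediction (accuracy 1/3), and the 0.90 sample is correct, which yields 0.75 * 0.233 + 0.25 * 0.1 = 0.2 for the L1 norm.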
- plot(val=None, ax=None)
Plot a single or multiple values from the metric.
- Parameters:
val (Union[Tensor, Sequence[Tensor], None]) – Either a single result from calling metric.forward or metric.compute or a list of these results. If no value is provided, will automatically call metric.compute and plot that result.
ax (Optional[Axes]) – A matplotlib axis object. If provided, will add the plot to that axis.
- Return type:
- Returns:
Figure object and Axes object
- Raises:
ModuleNotFoundError – If matplotlib is not installed
>>> from torch import randn, randint
>>> # Example plotting a single value
>>> from torchmetrics.classification import MulticlassCalibrationError
>>> metric = MulticlassCalibrationError(num_classes=3, n_bins=3, norm='l1')
>>> metric.update(randn(20, 3).softmax(dim=-1), randint(3, (20,)))
>>> fig_, ax_ = metric.plot()

>>> from torch import randn, randint
>>> # Example plotting multiple values
>>> from torchmetrics.classification import MulticlassCalibrationError
>>> metric = MulticlassCalibrationError(num_classes=3, n_bins=3, norm='l1')
>>> values = []
>>> for _ in range(20):
...     values.append(metric(randn(20, 3).softmax(dim=-1), randint(3, (20,))))
>>> fig_, ax_ = metric.plot(values)
Functional Interface
- torchmetrics.functional.calibration_error(preds, target, task, n_bins=15, norm='l1', num_classes=None, ignore_index=None, validate_args=True)
The expected calibration error can be used to quantify how well a given model is calibrated, e.g. how well the predicted output probabilities of the model match the actual probabilities of the ground truth distribution. Three different norms are implemented, each corresponding to a variation of the calibration error metric.
\[\text{ECE} = \sum_i^N b_i |p_i - c_i|, \text{L1 norm (Expected Calibration Error)}\]
\[\text{MCE} = \max_{i} |p_i - c_i|, \text{Infinity norm (Maximum Calibration Error)}\]
\[\text{RMSCE} = \sqrt{\sum_i^N b_i (p_i - c_i)^2}, \text{L2 norm (Root Mean Square Calibration Error)}\]
where \(p_i\) is the top-1 prediction accuracy in bin \(i\), \(c_i\) is the average confidence of predictions in bin \(i\), and \(b_i\) is the fraction of data points in bin \(i\). Bins are constructed uniformly over the [0,1] range.
- Return type: Tensor
This function is a simple wrapper to get the task-specific versions of this metric, which is done by setting the task argument to either 'binary' or 'multiclass'. See the documentation of binary_calibration_error() and multiclass_calibration_error() for the specific details of each argument's influence and examples.
binary_calibration_error
- torchmetrics.functional.classification.binary_calibration_error(preds, target, n_bins=15, norm='l1', ignore_index=None, validate_args=True)
Top-label Calibration Error for binary tasks.
The expected calibration error can be used to quantify how well a given model is calibrated, e.g. how well the predicted output probabilities of the model match the actual probabilities of the ground truth distribution. Three different norms are implemented, each corresponding to a variation of the calibration error metric.
\[\text{ECE} = \sum_i^N b_i |p_i - c_i|, \text{L1 norm (Expected Calibration Error)}\]
\[\text{MCE} = \max_{i} |p_i - c_i|, \text{Infinity norm (Maximum Calibration Error)}\]
\[\text{RMSCE} = \sqrt{\sum_i^N b_i (p_i - c_i)^2}, \text{L2 norm (Root Mean Square Calibration Error)}\]
where \(p_i\) is the top-1 prediction accuracy in bin \(i\), \(c_i\) is the average confidence of predictions in bin \(i\), and \(b_i\) is the fraction of data points in bin \(i\). Bins are constructed uniformly over the [0,1] range.
Accepts the following input tensors:
preds (float tensor): (N, ...). Preds should be a tensor containing probabilities or logits for each observation. If preds has values outside the [0,1] range, we consider the input to be logits and automatically apply sigmoid per element.
target (int tensor): (N, ...). Target should be a tensor containing ground truth labels, and therefore only contain {0,1} values (except if ignore_index is specified). The value 1 always encodes the positive class.
Additional dimensions ... will be flattened into the batch dimension.
- Parameters:
preds (Tensor) – Tensor with predictions
target (Tensor) – Tensor with true labels
n_bins (int) – Number of bins to use when computing the metric.
norm (Literal['l1', 'l2', 'max']) – Norm used to compare empirical and expected probability bins.
ignore_index (Optional[int]) – Specifies a target value that is ignored and does not contribute to the metric calculation
validate_args (bool) – bool indicating if input arguments and tensors should be validated for correctness. Set to False for faster computations.
- Return type:
Example
>>> import torch
>>> from torchmetrics.functional.classification import binary_calibration_error
>>> preds = torch.tensor([0.25, 0.25, 0.55, 0.75, 0.75])
>>> target = torch.tensor([0, 0, 1, 1, 1])
>>> binary_calibration_error(preds, target, n_bins=2, norm='l1')
tensor(0.2900)
>>> binary_calibration_error(preds, target, n_bins=2, norm='l2')
tensor(0.2918)
>>> binary_calibration_error(preds, target, n_bins=2, norm='max')
tensor(0.3167)
multiclass_calibration_error
- torchmetrics.functional.classification.multiclass_calibration_error(preds, target, num_classes, n_bins=15, norm='l1', ignore_index=None, validate_args=True)
Top-label Calibration Error for multiclass tasks.
The expected calibration error can be used to quantify how well a given model is calibrated, e.g. how well the predicted output probabilities of the model match the actual probabilities of the ground truth distribution. Three different norms are implemented, each corresponding to a variation of the calibration error metric.
\[\text{ECE} = \sum_i^N b_i |p_i - c_i|, \text{L1 norm (Expected Calibration Error)}\]
\[\text{MCE} = \max_{i} |p_i - c_i|, \text{Infinity norm (Maximum Calibration Error)}\]
\[\text{RMSCE} = \sqrt{\sum_i^N b_i (p_i - c_i)^2}, \text{L2 norm (Root Mean Square Calibration Error)}\]
where \(p_i\) is the top-1 prediction accuracy in bin \(i\), \(c_i\) is the average confidence of predictions in bin \(i\), and \(b_i\) is the fraction of data points in bin \(i\). Bins are constructed uniformly over the [0,1] range.
Accepts the following input tensors:
preds (float tensor): (N, C, ...). Preds should be a tensor containing probabilities or logits for each observation. If preds has values outside the [0,1] range, we consider the input to be logits and automatically apply softmax per sample.
target (int tensor): (N, ...). Target should be a tensor containing ground truth labels, and therefore only contain values in the [0, n_classes-1] range (except if ignore_index is specified).
Additional dimensions ... will be flattened into the batch dimension.
- Parameters:
preds (Tensor) – Tensor with predictions
target (Tensor) – Tensor with true labels
num_classes (int) – Integer specifying the number of classes
n_bins (int) – Number of bins to use when computing the metric.
norm (Literal['l1', 'l2', 'max']) – Norm used to compare empirical and expected probability bins.
ignore_index (Optional[int]) – Specifies a target value that is ignored and does not contribute to the metric calculation
validate_args (bool) – bool indicating if input arguments and tensors should be validated for correctness. Set to False for faster computations.
- Return type:
Example
>>> import torch
>>> from torchmetrics.functional.classification import multiclass_calibration_error
>>> preds = torch.tensor([[0.25, 0.20, 0.55],
...                       [0.55, 0.05, 0.40],
...                       [0.10, 0.30, 0.60],
...                       [0.90, 0.05, 0.05]])
>>> target = torch.tensor([0, 1, 2, 0])
>>> multiclass_calibration_error(preds, target, num_classes=3, n_bins=3, norm='l1')
tensor(0.2000)
>>> multiclass_calibration_error(preds, target, num_classes=3, n_bins=3, norm='l2')
tensor(0.2082)
>>> multiclass_calibration_error(preds, target, num_classes=3, n_bins=3, norm='max')
tensor(0.2333)
Cohen Kappa
Module Interface
- class torchmetrics.CohenKappa(**kwargs)
Calculate Cohen’s kappa score that measures inter-annotator agreement.
\[\kappa = (p_o - p_e) / (1 - p_e)\]
where \(p_o\) is the empirical probability of agreement and \(p_e\) is the expected agreement when both annotators assign labels randomly. Note that \(p_e\) is estimated using a per-annotator empirical prior over the class labels.
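As a worked instance of the formula, the sketch below builds the confusion matrix, takes its diagonal share as \(p_o\), and takes the sum over classes of the product of the two annotators' label frequencies as \(p_e\). It is a minimal pure-Python illustration with a hypothetical function name, not the library implementation, and it expects integer labels in [0, num_classes). For the legacy example below (target [1, 1, 0, 0], preds [0, 1, 0, 0]) it gives \(p_o = 0.75\), \(p_e = 0.5 \cdot 0.75 + 0.5 \cdot 0.25 = 0.5\) and therefore \(\kappa = 0.5\).

```python
from typing import List, Sequence


def cohen_kappa_sketch(preds: Sequence[int], target: Sequence[int], num_classes: int) -> float:
    """Hypothetical helper: kappa = (p_o - p_e) / (1 - p_e) from a confusion matrix."""
    n = len(target)
    # confusion[i][j] counts samples whose target is i and whose prediction is j.
    confusion: List[List[int]] = [[0] * num_classes for _ in range(num_classes)]
    for t, p in zip(target, preds):
        confusion[t][p] += 1
    p_o = sum(confusion[k][k] for k in range(num_classes)) / n  # observed agreement
    target_freq = [sum(confusion[k]) / n for k in range(num_classes)]
    pred_freq = [sum(confusion[i][k] for i in range(num_classes)) / n for k in range(num_classes)]
    # Chance agreement from each annotator's empirical label distribution.
    p_e = sum(target_freq[k] * pred_freq[k] for k in range(num_classes))
    return (p_o - p_e) / (1 - p_e)  # undefined (division by zero) when p_e == 1


print(cohen_kappa_sketch(preds=[0, 1, 0, 0], target=[1, 1, 0, 0], num_classes=2))  # 0.5
```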
This metric is a simple wrapper to get the task-specific versions of this metric, which is done by setting the task argument to either 'binary' or 'multiclass'. See the documentation of BinaryCohenKappa and MulticlassCohenKappa for the specific details of each argument's influence and examples.
- Legacy Example:
>>> from torch import tensor
>>> from torchmetrics import CohenKappa
>>> target = tensor([1, 1, 0, 0])
>>> preds = tensor([0, 1, 0, 0])
>>> cohenkappa = CohenKappa(task="multiclass", num_classes=2)
>>> cohenkappa(preds, target)
tensor(0.5000)
BinaryCohenKappa
- class torchmetrics.classification.BinaryCohenKappa(threshold=0.5, ignore_index=None, weights=None, validate_args=True, **kwargs)
Calculate Cohen’s kappa score that measures inter-annotator agreement for binary tasks.
\[\kappa = (p_o - p_e) / (1 - p_e)\]
where \(p_o\) is the empirical probability of agreement and \(p_e\) is the expected agreement when both annotators assign labels randomly. Note that \(p_e\) is estimated using a per-annotator empirical prior over the class labels.
As input to forward and update the metric accepts the following input:
preds (Tensor): An int or float tensor of shape (N, ...). If preds is a floating point tensor with values outside the [0,1] range, we consider the input to be logits and automatically apply sigmoid per element. Additionally, we convert it to an int tensor by thresholding with the value in threshold.
target (Tensor): An int tensor of shape (N, ...).
Note
Additional dimensions ... will be flattened into the batch dimension.
As output to forward and compute the metric returns the following output:
bck (Tensor): A tensor containing the Cohen kappa score
- Parameters:
threshold (float) – Threshold for transforming probability to binary (0,1) predictions
ignore_index (Optional[int]) – Specifies a target value that is ignored and does not contribute to the metric calculation
weights (Optional[Literal['linear', 'quadratic', 'none']]) – Weighting type to calculate the score. Choose from:
None or 'none': no weighting
'linear': linear weighting
'quadratic': quadratic weighting
validate_args (bool) – bool indicating if input arguments and tensors should be validated for correctness. Set to False for faster computations.
kwargs (Any) – Additional keyword arguments, see Advanced metric settings for more info.
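The weights option refers to weighted kappa, which charges each disagreement according to how far apart the two labels are. A common formulation, assumed here rather than taken from the library source, is \(\kappa_w = 1 - \sum_{ij} w_{ij} O_{ij} / \sum_{ij} w_{ij} E_{ij}\), with \(O\) the observed confusion counts, \(E_{ij}\) the counts expected from the row and column totals, and \(w_{ij} = |i - j|\) (linear), \((i - j)^2\) (quadratic), or 1 for every off-diagonal cell (no weighting, which reduces to the plain formula above). The function name below is hypothetical. Under this formulation all three weightings coincide when there are only two classes, so the binary examples below would give 0.5 regardless of weights.

```python
from typing import Optional, Sequence


def weighted_cohen_kappa_sketch(
    preds: Sequence[int], target: Sequence[int], num_classes: int, weights: Optional[str] = None
) -> float:
    """Hypothetical helper: kappa_w = 1 - sum(w * observed) / sum(w * expected)."""
    n = len(target)
    observed = [[0] * num_classes for _ in range(num_classes)]
    for t, p in zip(target, preds):
        observed[t][p] += 1
    row_tot = [sum(observed[i]) for i in range(num_classes)]
    col_tot = [sum(observed[i][j] for i in range(num_classes)) for j in range(num_classes)]
    num = den = 0.0
    for i in range(num_classes):
        for j in range(num_classes):
            if weights == "linear":
                w = abs(i - j)
            elif weights == "quadratic":
                w = (i - j) ** 2
            elif weights in (None, "none"):
                w = 0 if i == j else 1
            else:
                raise ValueError(f"unsupported weights: {weights}")
            expected = row_tot[i] * col_tot[j] / n  # count expected under independence
            num += w * observed[i][j]
            den += w * expected
    return 1 - num / den


# Both disagreements are between neighbouring labels (1 vs 2), so the
# distance-aware weightings judge them more leniently than no weighting.
for w in (None, "linear", "quadratic"):
    print(w, round(weighted_cohen_kappa_sketch([0, 2, 2, 1], [0, 1, 2, 2], 3, w), 4))
# None 0.2
# linear 0.4286
# quadratic 0.6364
```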
- Example (preds is int tensor):
>>> from torch import tensor >>> from torchmetrics.classification import BinaryCohenKappa >>> target = tensor([1, 1, 0, 0]) >>> preds = tensor([0, 1, 0, 0]) >>> metric = BinaryCohenKappa() >>> metric(preds, target) tensor(0.5000)
- Example (preds is float tensor):
>>> from torch import tensor
>>> from torchmetrics.classification import BinaryCohenKappa
>>> target = tensor([1, 1, 0, 0])
>>> preds = tensor([0.35, 0.85, 0.48, 0.01])
>>> metric = BinaryCohenKappa()
>>> metric(preds, target)
tensor(0.5000)
- plot(val=None, ax=None)
Plot a single or multiple values from the metric.
- Parameters:
val (Union[Tensor, Sequence[Tensor], None]) – Either a single result from calling metric.forward or metric.compute or a list of these results. If no value is provided, will automatically call metric.compute and plot that result.
ax (Optional[Axes]) – A matplotlib axis object. If provided, will add the plot to that axis
- Return type:
- Returns:
Figure object and Axes object
- Raises:
ModuleNotFoundError – If matplotlib is not installed
>>> from torch import rand, randint
>>> # Example plotting a single value
>>> from torchmetrics.classification import BinaryCohenKappa
>>> metric = BinaryCohenKappa()
>>> metric.update(rand(10), randint(2, (10,)))
>>> fig_, ax_ = metric.plot()
>>> from torch import rand, randint
>>> # Example plotting multiple values
>>> from torchmetrics.classification import BinaryCohenKappa
>>> metric = BinaryCohenKappa()
>>> values = [ ]
>>> for _ in range(10):
...     values.append(metric(rand(10), randint(2, (10,))))
>>> fig_, ax_ = metric.plot(values)
MulticlassCohenKappa
- class torchmetrics.classification.MulticlassCohenKappa(num_classes, ignore_index=None, weights=None, validate_args=True, **kwargs)
Calculate Cohen’s kappa score that measures inter-annotator agreement for multiclass tasks.
\[\kappa = (p_o - p_e) / (1 - p_e)\]
where \(p_o\) is the empirical probability of agreement and \(p_e\) is the expected agreement when both annotators assign labels randomly. Note that \(p_e\) is estimated using a per-annotator empirical prior over the class labels.
As input to forward and update the metric accepts the following input:
preds (Tensor): Either an int tensor of shape (N, ...) or a float tensor of shape (N, C, ...). If preds is a floating point tensor, we apply torch.argmax along the C dimension to automatically convert probabilities/logits into an int tensor.
target (Tensor): An int tensor of shape (N, ...).
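The argmax conversion for float inputs amounts to picking, for each sample, the index of the largest score along C. The plain-Python sketch below (hypothetical helper argmax_labels; ties resolve to the first maximal index in this sketch) applies it to the float example used in this section.

```python
def argmax_labels(probs):
    """Convert an (N, C) nested list of scores into N hard labels via argmax over C.

    Illustrative sketch mirroring the torch.argmax step described above.
    """
    # max() over the column indices, keyed by the score in that column
    return [max(range(len(row)), key=row.__getitem__) for row in probs]


preds = [[0.16, 0.26, 0.58],
         [0.22, 0.61, 0.17],
         [0.71, 0.09, 0.20],
         [0.05, 0.82, 0.13]]
print(argmax_labels(preds))  # [2, 1, 0, 1]
```

The recovered labels equal the integer preds of the first example, which is why both examples report the same kappa.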
Note
Additional dimension ... will be flattened into the batch dimension.
As output to forward and compute the metric returns the following output:
mcck (Tensor): A tensor containing the Cohen’s kappa score
- Parameters:
num_classes (int) – Integer specifying the number of classes
ignore_index (Optional[int]) – Specifies a target value that is ignored and does not contribute to the metric calculation
weights (Optional[Literal['linear', 'quadratic', 'none']]) – Weighting type to calculate the score. Choose from:
None or 'none': no weighting
'linear': linear weighting
'quadratic': quadratic weighting
validate_args (bool) – bool indicating if input arguments and tensors should be validated for correctness. Set to False for faster computations.
kwargs (Any) – Additional keyword arguments, see Advanced metric settings for more info.
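The weights option only matters when class indices are ordinal: disagreements between distant classes are penalized more. As a sketch of the standard weighted-kappa definition (the one scikit-learn's cohen_kappa_score uses, \(\kappa = 1 - \sum_{ij} w_{ij} O_{ij} / \sum_{ij} w_{ij} E_{ij}\) with observed counts \(O\) and expected counts \(E\)), the plain-Python function below computes all three variants. The helper weighted_cohen_kappa is hypothetical, and whether TorchMetrics reduces its confusion matrix in exactly this way is an assumption.

```python
def weighted_cohen_kappa(preds, target, num_classes, weights=None):
    """Cohen's kappa from a confusion matrix with optional 'linear'/'quadratic' weights.

    Sketch of the textbook weighted-kappa definition (as in scikit-learn's
    cohen_kappa_score); not TorchMetrics' implementation.
    """
    n = len(target)
    # Observed counts: rows = target class, columns = predicted class
    observed = [[0] * num_classes for _ in range(num_classes)]
    for p, t in zip(preds, target):
        observed[t][p] += 1
    row_sums = [sum(row) for row in observed]
    col_sums = [sum(observed[i][j] for i in range(num_classes)) for j in range(num_classes)]
    # Expected counts if the two annotators were independent with these marginals
    expected = [[row_sums[i] * col_sums[j] / n for j in range(num_classes)]
                for i in range(num_classes)]

    def w(i, j):
        if weights in (None, "none"):
            return 0 if i == j else 1  # every disagreement costs the same
        if weights == "linear":
            return abs(i - j)
        if weights == "quadratic":
            return (i - j) ** 2
        raise ValueError(f"unknown weights: {weights!r}")

    cells = [(i, j) for i in range(num_classes) for j in range(num_classes)]
    num = sum(w(i, j) * observed[i][j] for i, j in cells)
    den = sum(w(i, j) * expected[i][j] for i, j in cells)
    return 1 - num / den


target, preds = [2, 1, 0, 0], [2, 1, 0, 1]
for mode in (None, "linear", "quadratic"):
    print(mode, round(weighted_cohen_kappa(preds, target, 3, mode), 4))
# prints: None 0.6364, then linear 0.7143, then quadratic 0.8
```

With no weighting (\(w_{ij} = 1\) for \(i \neq j\), 0 otherwise) the expression reduces to \((p_o - p_e) / (1 - p_e)\) and reproduces the 0.6364 shown in the examples; the linear and quadratic values are outputs of this sketch only.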
- Example (pred is integer tensor):
>>> from torch import tensor
>>> from torchmetrics.classification import MulticlassCohenKappa
>>> target = tensor([2, 1, 0, 0])
>>> preds = tensor([2, 1, 0, 1])
>>> metric = MulticlassCohenKappa(num_classes=3)
>>> metric(preds, target)
tensor(0.6364)
- Example (pred is float tensor):
>>> from torch import tensor
>>> from torchmetrics.classification import MulticlassCohenKappa
>>> target = tensor([2, 1, 0, 0])
>>> preds = tensor([[0.16, 0.26, 0.58],
...                 [0.22, 0.61, 0.17],
...                 [0.71, 0.09, 0.20],
...                 [0.05, 0.82, 0.13]])
>>> metric = MulticlassCohenKappa(num_classes=3)
>>> metric(preds, target)
tensor(0.6364)
- plot(val=None, ax=None)
Plot a single or multiple values from the metric.
- Parameters:
val (Union[Tensor, Sequence[Tensor], None]) – Either a single result from calling metric.forward or metric.compute or a list of these results. If no value is provided, will automatically call metric.compute and plot that result.
ax (Optional[Axes]) – A matplotlib axis object. If provided, will add the plot to that axis
- Return type:
- Returns:
Figure object and Axes object
- Raises:
ModuleNotFoundError – If matplotlib is not installed
>>> from torch import randn, randint
>>> # Example plotting a single value
>>> from torchmetrics.classification import MulticlassCohenKappa
>>> metric = MulticlassCohenKappa(num_classes=3)
>>> metric.update(randn(20, 3).softmax(dim=-1), randint(3, (20,)))
>>> fig_, ax_ = metric.plot()
>>> from torch import randn, randint
>>> # Example plotting multiple values
>>> from torchmetrics.classification import MulticlassCohenKappa
>>> metric = MulticlassCohenKappa(num_classes=3)
>>> values = []
>>> for _ in range(20):
...     values.append(metric(randn(20, 3).softmax(dim=-1), randint(3, (20,))))
>>> fig_, ax_ = metric.plot(values)
Functional Interface
cohen_kappa
- torchmetrics.functional.cohen_kappa(preds, target, task, threshold=0.5, num_classes=None, weights=None, ignore_index=None, validate_args=True)
Calculate Cohen’s kappa score that measures inter-annotator agreement and return it as a Tensor. It is defined as:
\[\kappa = (p_o - p_e) / (1 - p_e)\]
where \(p_o\) is the empirical probability of agreement and \(p_e\) is the expected agreement when both annotators assign labels randomly. Note that \(p_e\) is estimated using a per-annotator empirical prior over the class labels.
This function is a simple wrapper to get the task specific versions of this metric, which is done by setting the task argument to either 'binary' or 'multiclass'. See the documentation of binary_cohen_kappa() and multiclass_cohen_kappa() for the specific details of each argument’s influence and examples.
- Legacy Example:
>>> from torch import tensor
>>> from torchmetrics.functional import cohen_kappa
>>> target = tensor([1, 1, 0, 0])
>>> preds = tensor([0, 1, 0, 0])
>>> cohen_kappa(preds, target, task="multiclass", num_classes=2)
tensor(0.5000)
binary_cohen_kappa
- torchmetrics.functional.classification.binary_cohen_kappa(preds, target, threshold=0.5, weights=None, ignore_index=None, validate_args=True)
Calculate Cohen’s kappa score that measures inter-annotator agreement for binary tasks.
\[\kappa = (p_o - p_e) / (1 - p_e)\]
where \(p_o\) is the empirical probability of agreement and \(p_e\) is the expected agreement when both annotators assign labels randomly. Note that \(p_e\) is estimated using a per-annotator empirical prior over the class labels.
Accepts the following input tensors:
preds (int or float tensor): (N, ...). If preds is a floating point tensor with values outside the [0, 1] range, we consider the input to be logits and automatically apply sigmoid per element. Additionally, we convert it to an int tensor by thresholding with the value in threshold.
target (int tensor): (N, ...)
Additional dimension ... will be flattened into the batch dimension.
- Parameters:
preds (Tensor) – Tensor with predictions
target (Tensor) – Tensor with true labels
threshold (float) – Threshold for transforming probabilities to binary (0, 1) predictions
weights (Optional[Literal['linear', 'quadratic', 'none']]) – Weighting type to calculate the score. Choose from:
None or 'none': no weighting
'linear': linear weighting
'quadratic': quadratic weighting
ignore_index (Optional[int]) – Specifies a target value that is ignored and does not contribute to the metric calculation
validate_args (bool) – bool indicating if input arguments and tensors should be validated for correctness. Set to False for faster computations.
- Return type:
Tensor
- Example (preds is int tensor):
>>> from torch import tensor
>>> from torchmetrics.functional.classification import binary_cohen_kappa
>>> target = tensor([1, 1, 0, 0])
>>> preds = tensor([0, 1, 0, 0])
>>> binary_cohen_kappa(preds, target)
tensor(0.5000)
- Example (preds is float tensor):
>>> from torch import tensor
>>> from torchmetrics.functional.classification import binary_cohen_kappa
>>> target = tensor([1, 1, 0, 0])
>>> preds = tensor([0.35, 0.85, 0.48, 0.01])
>>> binary_cohen_kappa(preds, target)
tensor(0.5000)
multiclass_cohen_kappa
- torchmetrics.functional.classification.multiclass_cohen_kappa(preds, target, num_classes, weights=None, ignore_index=None, validate_args=True)
Calculate Cohen’s kappa score that measures inter-annotator agreement for multiclass tasks.
\[\kappa = (p_o - p_e) / (1 - p_e)\]
where \(p_o\) is the empirical probability of agreement and \(p_e\) is the expected agreement when both annotators assign labels randomly. Note that \(p_e\) is estimated using a per-annotator empirical prior over the class labels.
Accepts the following input tensors:
preds: (N, ...) (int tensor) or (N, C, ...) (float tensor). If preds is a floating point tensor, we apply torch.argmax along the C dimension to automatically convert probabilities/logits into an int tensor.
target (int tensor): (N, ...)
Additional dimension ... will be flattened into the batch dimension.
- Parameters:
preds (Tensor) – Tensor with predictions
target (Tensor) – Tensor with true labels
num_classes (int) – Integer specifying the number of classes
weights (Optional[Literal['linear', 'quadratic', 'none']]) – Weighting type to calculate the score. Choose from:
None or 'none': no weighting
'linear': linear weighting
'quadratic': quadratic weighting
ignore_index (Optional[int]) – Specifies a target value that is ignored and does not contribute to the metric calculation
validate_args (bool) – bool indicating if input arguments and tensors should be validated for correctness. Set to False for faster computations.
- Return type:
Tensor
- Example (pred is integer tensor):
>>> from torch import tensor
>>> from torchmetrics.functional.classification import multiclass_cohen_kappa
>>> target = tensor([2, 1, 0, 0])
>>> preds = tensor([2, 1, 0, 1])
>>> multiclass_cohen_kappa(preds, target, num_classes=3)
tensor(0.6364)
- Example (pred is float tensor):
>>> from torch import tensor
>>> from torchmetrics.functional.classification import multiclass_cohen_kappa
>>> target = tensor([2, 1, 0, 0])
>>> preds = tensor([[0.16, 0.26, 0.58],
...                 [0.22, 0.61, 0.17],
...                 [0.71, 0.09, 0.20],
...                 [0.05, 0.82, 0.13]])
>>> multiclass_cohen_kappa(preds, target, num_classes=3)
tensor(0.6364)
Confusion Matrix
Module Interface
- class torchmetrics.ConfusionMatrix(**kwargs)
Compute the confusion matrix.
This metric is a simple wrapper to get the task specific versions of this metric, which is done by setting the task argument to either 'binary', 'multiclass' or 'multilabel'. See the documentation of BinaryConfusionMatrix, MulticlassConfusionMatrix and MultilabelConfusionMatrix for the specific details of each argument’s influence and examples.
- Legacy Example:
>>> from torch import tensor
>>> from torchmetrics import ConfusionMatrix
>>> target = tensor([1, 1, 0, 0])
>>> preds = tensor([0, 1, 0, 0])
>>> confmat = ConfusionMatrix(task="binary", num_classes=2)
>>> confmat(preds, target)
tensor([[2, 0], [1, 1]])
>>> target = tensor([2, 1, 0, 0])
>>> preds = tensor([2, 1, 0, 1])
>>> confmat = ConfusionMatrix(task="multiclass", num_classes=3)
>>> confmat(preds, target)
tensor([[1, 1, 0], [0, 1, 0], [0, 0, 1]])
>>> target = tensor([[0, 1, 0], [1, 0, 1]])
>>> preds = tensor([[0, 0, 1], [1, 0, 1]])
>>> confmat = ConfusionMatrix(task="multilabel", num_labels=3)
>>> confmat(preds, target)
tensor([[[1, 0], [0, 1]], [[1, 0], [1, 0]], [[0, 1], [0, 1]]])
BinaryConfusionMatrix
- class torchmetrics.classification.BinaryConfusionMatrix(threshold=0.5, ignore_index=None, normalize=None, validate_args=True, **kwargs)
Compute the confusion matrix for binary tasks.
The confusion matrix \(C\) is constructed such that \(C_{i, j}\) is equal to the number of observations known to be in class \(i\) but predicted to be in class \(j\). Thus row indices of the confusion matrix correspond to the true class labels and column indices correspond to the predicted class labels.
For binary tasks, the confusion matrix is a 2x2 matrix with the following structure:
\(C_{0, 0}\): True negatives
\(C_{0, 1}\): False positives
\(C_{1, 0}\): False negatives
\(C_{1, 1}\): True positives
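The counting rule above (row = true class, column = predicted class) can be written in a few lines of plain Python. The helper confusion_matrix_counts is hypothetical and works on hard integer labels only; it reproduces the binary example in this section and also handles more classes.

```python
def confusion_matrix_counts(preds, target, num_classes=2):
    """Count C[i][j] = number of samples with true class i predicted as class j (sketch)."""
    matrix = [[0] * num_classes for _ in range(num_classes)]
    for p, t in zip(preds, target):
        matrix[t][p] += 1  # row = true label, column = predicted label
    return matrix


cm = confusion_matrix_counts([0, 1, 0, 0], [1, 1, 0, 0])
(tn, fp), (fn, tp) = cm  # unpack the 2x2 layout listed above
print(cm)              # [[2, 0], [1, 1]]
print(tn, fp, fn, tp)  # 2 0 1 1
```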
As input to forward and update the metric accepts the following input:
preds (Tensor): An int or float tensor of shape (N, ...). If preds is a floating point tensor with values outside the [0, 1] range, we consider the input to be logits and automatically apply sigmoid per element. Additionally, we convert it to an int tensor by thresholding with the value in threshold.
target (Tensor): An int tensor of shape (N, ...).
As output to forward and compute the metric returns the following output:
confusion_matrix (Tensor): A tensor containing a (2, 2) matrix
Additional dimension ... will be flattened into the batch dimension.
- Parameters:
threshold (float) – Threshold for transforming probabilities to binary (0, 1) predictions
ignore_index (Optional[int]) – Specifies a target value that is ignored and does not contribute to the metric calculation
normalize (Optional[Literal['true', 'pred', 'all', 'none']]) – Normalization mode for confusion matrix. Choose from:
None or 'none': no normalization (default)
'true': normalization over the targets (most commonly used)
'pred': normalization over the predictions
'all': normalization over the whole matrix
validate_args (bool) – bool indicating if input arguments and tensors should be validated for correctness. Set to False for faster computations.
kwargs (Any) – Additional keyword arguments, see Advanced metric settings for more info.
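The normalize modes divide the raw counts by different totals: 'true' by row sums (targets), 'pred' by column sums (predictions) and 'all' by the grand total. The plain-Python sketch below (hypothetical helper normalize_confusion_matrix) applies them to the binary example matrix [[2, 0], [1, 1]]; mapping an all-zero row or column to 0.0 rather than NaN is an assumption of the sketch.

```python
def normalize_confusion_matrix(matrix, normalize=None):
    """Sketch of the 'true' / 'pred' / 'all' normalization modes (not TorchMetrics' code)."""
    if normalize in (None, "none"):
        return [row[:] for row in matrix]
    n_rows, n_cols = len(matrix), len(matrix[0])
    if normalize == "true":  # divide each row by its sum (number of targets of that class)
        sums = [[sum(matrix[i])] * n_cols for i in range(n_rows)]
    elif normalize == "pred":  # divide each column by its sum (number of predictions of that class)
        col = [sum(matrix[i][j] for i in range(n_rows)) for j in range(n_cols)]
        sums = [col[:] for _ in range(n_rows)]
    elif normalize == "all":  # divide every cell by the grand total
        total = sum(map(sum, matrix))
        sums = [[total] * n_cols for _ in range(n_rows)]
    else:
        raise ValueError(f"unknown normalize mode: {normalize!r}")
    # Assumption: an empty row/column (sum 0) yields 0.0 instead of NaN
    return [[matrix[i][j] / sums[i][j] if sums[i][j] else 0.0 for j in range(n_cols)]
            for i in range(n_rows)]


cm = [[2, 0], [1, 1]]  # binary example from this section
print(normalize_confusion_matrix(cm, "true"))  # [[1.0, 0.0], [0.5, 0.5]]
print(normalize_confusion_matrix(cm, "all"))   # [[0.5, 0.0], [0.25, 0.25]]
```

With 'true', each row sums to 1, so the diagonal directly reads as per-class recall.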
- Example (preds is int tensor):
>>> import torch
>>> from torchmetrics.classification import BinaryConfusionMatrix
>>> target = torch.tensor([1, 1, 0, 0])
>>> preds = torch.tensor([0, 1, 0, 0])
>>> bcm = BinaryConfusionMatrix()
>>> bcm(preds, target)
tensor([[2, 0], [1, 1]])
- Example (preds is float tensor):
>>> import torch
>>> from torchmetrics.classification import BinaryConfusionMatrix
>>> target = torch.tensor([1, 1, 0, 0])
>>> preds = torch.tensor([0.35, 0.85, 0.48, 0.01])
>>> bcm = BinaryConfusionMatrix()
>>> bcm(preds, target)
tensor([[2, 0], [1, 1]])
- plot(val=None, ax=None, add_text=True, labels=None)
Plot a single or multiple values from the metric.
- Parameters:
val (Optional[Tensor]) – Either a single result from calling metric.forward or metric.compute or a list of these results. If no value is provided, will automatically call metric.compute and plot that result.
ax (Optional[Axes]) – A matplotlib axis object. If provided, will add the plot to that axis
add_text (bool) – if the value of each cell should be added to the plot
labels (Optional[List[str]]) – a list of strings, if provided will be added to the plot to indicate the different classes
- Return type:
- Returns:
Figure and Axes object
- Raises:
ModuleNotFoundError – If matplotlib is not installed
>>> from torch import randint
>>> from torchmetrics.classification import BinaryConfusionMatrix
>>> metric = BinaryConfusionMatrix()
>>> metric.update(randint(2, (20,)), randint(2, (20,)))
>>> fig_, ax_ = metric.plot()
MulticlassConfusionMatrix
- class torchmetrics.classification.MulticlassConfusionMatrix(num_classes, ignore_index=None, normalize=None, validate_args=True, **kwargs)
Compute the confusion matrix for multiclass tasks.
The confusion matrix \(C\) is constructed such that \(C_{i, j}\) is equal to the number of observations known to be in class \(i\) but predicted to be in class \(j\). Thus row indices of the confusion matrix correspond to the true class labels and column indices correspond to the predicted class labels.
For multiclass tasks, the confusion matrix is an NxN matrix (N being the number of classes), where:
\(C_{i, i}\) represents the number of true positives for class \(i\)
\(\sum_{j=1, j\neq i}^N C_{i, j}\) represents the number of false negatives for class \(i\)
\(\sum_{j=1, j\neq i}^N C_{j, i}\) represents the number of false positives for class \(i\)
the sum of the remaining cells in the matrix represents the number of true negatives for class \(i\)
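Reading those per-class quantities off the matrix is mechanical: row \(i\) gives the true positives and false negatives, column \(i\) the false positives, and the rest are true negatives. The plain-Python sketch below (hypothetical helper per_class_counts, classes indexed from 0) does it for the multiclass example matrix in this section.

```python
def per_class_counts(matrix):
    """Derive (TP, FN, FP, TN) for every class from an N x N confusion matrix (sketch)."""
    n = len(matrix)
    total = sum(map(sum, matrix))
    stats = []
    for i in range(n):
        tp = matrix[i][i]
        fn = sum(matrix[i]) - tp                       # rest of row i
        fp = sum(matrix[j][i] for j in range(n)) - tp  # rest of column i
        tn = total - tp - fn - fp                      # every remaining cell
        stats.append((tp, fn, fp, tn))
    return stats


cm = [[1, 1, 0], [0, 1, 0], [0, 0, 1]]  # multiclass example from this section
print(per_class_counts(cm))  # [(1, 1, 0, 2), (1, 0, 1, 2), (1, 0, 0, 3)]
```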
As input to forward and update the metric accepts the following input:
preds (Tensor): Either an int tensor of shape (N, ...) or a float tensor of shape (N, C, ...). If preds is a floating point tensor, we apply torch.argmax along the C dimension to automatically convert probabilities/logits into an int tensor.
target (Tensor): An int tensor of shape (N, ...).
As output to forward and compute the metric returns the following output:
confusion_matrix: [num_classes, num_classes] matrix
- Parameters:
num_classes (int) – Integer specifying the number of classes
ignore_index (Optional[int]) – Specifies a target value that is ignored and does not contribute to the metric calculation
normalize (Optional[Literal['true', 'pred', 'all', 'none']]) – Normalization mode for confusion matrix. Choose from:
None or 'none': no normalization (default)
'true': normalization over the targets (most commonly used)
'pred': normalization over the predictions
'all': normalization over the whole matrix
validate_args (bool) – bool indicating if input arguments and tensors should be validated for correctness. Set to False for faster computations.
kwargs (Any) – Additional keyword arguments, see Advanced metric settings for more info.
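The ignore_index parameter is described as excluding a target value from the calculation. A plain-Python sketch of that behavior, assuming that every (prediction, target) pair whose target equals ignore_index is simply dropped before counting, is shown below; drop_ignored is a hypothetical helper and -1 is an arbitrary choice of ignored value.

```python
def drop_ignored(preds, target, ignore_index=None):
    """Remove every (pred, target) pair whose target equals ignore_index (sketch)."""
    if ignore_index is None:
        return list(preds), list(target)
    kept = [(p, t) for p, t in zip(preds, target) if t != ignore_index]
    return [p for p, _ in kept], [t for _, t in kept]


# The last sample has target -1 and is discarded before building the matrix.
preds, target = drop_ignored([2, 1, 0, 1, 0], [2, 1, 0, 0, -1], ignore_index=-1)
print(preds, target)  # [2, 1, 0, 1] [2, 1, 0, 0]

matrix = [[0] * 3 for _ in range(3)]
for p, t in zip(preds, target):
    matrix[t][p] += 1  # row = true label, column = predicted label
print(matrix)  # [[1, 1, 0], [0, 1, 0], [0, 0, 1]]
```

Because the ignored sample is removed entirely, the result equals the confusion matrix of the example inputs below.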
- Example (pred is integer tensor):
>>> from torch import tensor
>>> from torchmetrics.classification import MulticlassConfusionMatrix
>>> target = tensor([2, 1, 0, 0])
>>> preds = tensor([2, 1, 0, 1])
>>> metric = MulticlassConfusionMatrix(num_classes=3)
>>> metric(preds, target)
tensor([[1, 1, 0], [0, 1, 0], [0, 0, 1]])
- Example (pred is float tensor):
>>> from torch import tensor
>>> from torchmetrics.classification import MulticlassConfusionMatrix
>>> target = tensor([2, 1, 0, 0])
>>> preds = tensor([[0.16, 0.26, 0.58],
...                 [0.22, 0.61, 0.17],
...                 [0.71, 0.09, 0.20],
...                 [0.05, 0.82, 0.13]])
>>> metric = MulticlassConfusionMatrix(num_classes=3)
>>> metric(preds, target)
tensor([[1, 1, 0], [0, 1, 0], [0, 0, 1]])
- plot(val=None, ax=None, add_text=True, labels=None)
Plot a single or multiple values from the metric.
- Parameters:
val (Optional[Tensor]) – Either a single result from calling metric.forward or metric.compute or a list of these results. If no value is provided, will automatically call metric.compute and plot that result.
ax (Optional[Axes]) – A matplotlib axis object. If provided, will add the plot to that axis
add_text (bool) – if the value of each cell should be added to the plot
labels (Optional[List[str]]) – a list of strings, if provided will be added to the plot to indicate the different classes
- Return type:
- Returns:
Figure and Axes object
- Raises:
ModuleNotFoundError – If matplotlib is not installed
>>> from torch import randint
>>> from torchmetrics.classification import MulticlassConfusionMatrix
>>> metric = MulticlassConfusionMatrix(num_classes=5)
>>> metric.update(randint(5, (20,)), randint(5, (20,)))
>>> fig_, ax_ = metric.plot()
MultilabelConfusionMatrix
- class torchmetrics.classification.MultilabelConfusionMatrix(num_labels, threshold=0.5, ignore_index=None, normalize=None, validate_args=True, **kwargs)
Compute the confusion matrix for multilabel tasks.
The confusion matrix \(C\) is constructed such that \(C_{i, j}\) is equal to the number of observations known to be in class \(i\) but predicted to be in class \(j\). Thus row indices of the confusion matrix correspond to the true class labels and column indices correspond to the predicted class labels.
For multilabel tasks, the confusion matrix is an Nx2x2 tensor (N being the number of labels), where each 2x2 matrix is the confusion matrix for that label. The structure of each 2x2 matrix is as follows:
\(C_{0, 0}\): True negatives
\(C_{0, 1}\): False positives
\(C_{1, 0}\): False negatives
\(C_{1, 1}\): True positives
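Each per-label 2x2 matrix is just a binary confusion matrix over one column of the (N, L) inputs. The plain-Python sketch below (hypothetical helper multilabel_confusion_counts, hard 0/1 inputs only) rebuilds the result of the integer example in this section.

```python
def multilabel_confusion_counts(preds, target):
    """Build one 2x2 matrix [[TN, FP], [FN, TP]] per label from (N, L) 0/1 lists (sketch)."""
    num_labels = len(target[0])
    result = []
    for label in range(num_labels):
        m = [[0, 0], [0, 0]]
        for p_row, t_row in zip(preds, target):
            m[t_row[label]][p_row[label]] += 1  # row = true value, column = predicted value
        result.append(m)
    return result


target = [[0, 1, 0], [1, 0, 1]]
preds = [[0, 0, 1], [1, 0, 1]]
print(multilabel_confusion_counts(preds, target))
# [[[1, 0], [0, 1]], [[1, 0], [1, 0]], [[0, 1], [0, 1]]]
```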
As input to ‘update’ the metric accepts the following input:
preds (int or float tensor): (N, C, ...). If preds is a floating point tensor with values outside the [0, 1] range, we consider the input to be logits and automatically apply sigmoid per element. Additionally, we convert it to an int tensor by thresholding with the value in threshold.
target (int tensor): (N, C, ...)
As output of ‘compute’ the metric returns the following output:
confusion_matrix: [num_labels, 2, 2] matrix
- Parameters:
num_labels (int) – Integer specifying the number of labels
threshold (float) – Threshold for transforming probabilities to binary (0, 1) predictions
ignore_index (Optional[int]) – Specifies a target value that is ignored and does not contribute to the metric calculation
normalize (Optional[Literal['true', 'pred', 'all', 'none']]) – Normalization mode for confusion matrix. Choose from:
None or 'none': no normalization (default)
'true': normalization over the targets (most commonly used)
'pred': normalization over the predictions
'all': normalization over the whole matrix
validate_args (bool) – bool indicating if input arguments and tensors should be validated for correctness. Set to False for faster computations.
kwargs (Any) – Additional keyword arguments, see Advanced metric settings for more info.
- Example (preds is int tensor):
>>> from torch import tensor
>>> from torchmetrics.classification import MultilabelConfusionMatrix
>>> target = tensor([[0, 1, 0], [1, 0, 1]])
>>> preds = tensor([[0, 0, 1], [1, 0, 1]])
>>> metric = MultilabelConfusionMatrix(num_labels=3)
>>> metric(preds, target)
tensor([[[1, 0], [0, 1]], [[1, 0], [1, 0]], [[0, 1], [0, 1]]])
- Example (preds is float tensor):
>>> from torch import tensor
>>> from torchmetrics.classification import MultilabelConfusionMatrix
>>> target = tensor([[0, 1, 0], [1, 0, 1]])
>>> preds = tensor([[0.11, 0.22, 0.84], [0.73, 0.33, 0.92]])
>>> metric = MultilabelConfusionMatrix(num_labels=3)
>>> metric(preds, target)
tensor([[[1, 0], [0, 1]], [[1, 0], [1, 0]], [[0, 1], [0, 1]]])
- plot(val=None, ax=None, add_text=True, labels=None)
Plot a single or multiple values from the metric.
- Parameters:
val (Optional[Tensor]) – Either a single result from calling metric.forward or metric.compute or a list of these results. If no value is provided, will automatically call metric.compute and plot that result.
ax (Optional[Axes]) – A matplotlib axis object. If provided, will add the plot to that axis
add_text (bool) – if the value of each cell should be added to the plot
labels (Optional[List[str]]) – a list of strings, if provided will be added to the plot to indicate the different classes
- Return type:
- Returns:
Figure and Axes object
- Raises:
ModuleNotFoundError – If matplotlib is not installed
>>> from torch import randint
>>> from torchmetrics.classification import MultilabelConfusionMatrix
>>> metric = MultilabelConfusionMatrix(num_labels=3)
>>> metric.update(randint(2, (20, 3)), randint(2, (20, 3)))
>>> fig_, ax_ = metric.plot()
Functional Interface
confusion_matrix
- torchmetrics.functional.confusion_matrix(preds, target, task, threshold=0.5, num_classes=None, num_labels=None, normalize=None, ignore_index=None, validate_args=True)
Compute the confusion matrix.
This function is a simple wrapper to get the task specific versions of this metric, which is done by setting the task argument to either 'binary', 'multiclass' or 'multilabel'. See the documentation of binary_confusion_matrix(), multiclass_confusion_matrix() and multilabel_confusion_matrix() for the specific details of each argument’s influence and examples.
- Return type:
Tensor
- Legacy Example:
>>> from torch import tensor
>>> from torchmetrics.classification import ConfusionMatrix
>>> target = tensor([1, 1, 0, 0])
>>> preds = tensor([0, 1, 0, 0])
>>> confmat = ConfusionMatrix(task="binary")
>>> confmat(preds, target)
tensor([[2, 0], [1, 1]])
>>> target = tensor([2, 1, 0, 0])
>>> preds = tensor([2, 1, 0, 1])
>>> confmat = ConfusionMatrix(task="multiclass", num_classes=3)
>>> confmat(preds, target)
tensor([[1, 1, 0], [0, 1, 0], [0, 0, 1]])
>>> target = tensor([[0, 1, 0], [1, 0, 1]])
>>> preds = tensor([[0, 0, 1], [1, 0, 1]])
>>> confmat = ConfusionMatrix(task="multilabel", num_labels=3)
>>> confmat(preds, target)
tensor([[[1, 0], [0, 1]], [[1, 0], [1, 0]], [[0, 1], [0, 1]]])
binary_confusion_matrix
- torchmetrics.functional.classification.binary_confusion_matrix(preds, target, threshold=0.5, normalize=None, ignore_index=None, validate_args=True)
Compute the confusion matrix for binary tasks.
Accepts the following input tensors:
preds (int or float tensor): (N, ...). If preds is a floating point tensor with values outside the [0, 1] range, we consider the input to be logits and automatically apply sigmoid per element. Additionally, we convert it to an int tensor by thresholding with the value in threshold.
target (int tensor): (N, ...)
Additional dimension ... will be flattened into the batch dimension.
- Parameters:
preds (Tensor) – Tensor with predictions
target (Tensor) – Tensor with true labels
threshold (float) – Threshold for transforming probabilities to binary (0, 1) predictions
normalize (Optional[Literal['true', 'pred', 'all', 'none']]) – Normalization mode for confusion matrix. Choose from:
None or 'none': no normalization (default)
'true': normalization over the targets (most commonly used)
'pred': normalization over the predictions
'all': normalization over the whole matrix
ignore_index (Optional[int]) – Specifies a target value that is ignored and does not contribute to the metric calculation
validate_args (bool) – bool indicating if input arguments and tensors should be validated for correctness. Set to False for faster computations.
- Return type:
Tensor
- Returns:
A [2, 2] tensor
- Example (preds is int tensor):
>>> from torch import tensor
>>> from torchmetrics.functional.classification import binary_confusion_matrix
>>> target = tensor([1, 1, 0, 0])
>>> preds = tensor([0, 1, 0, 0])
>>> binary_confusion_matrix(preds, target)
tensor([[2, 0], [1, 1]])
- Example (preds is float tensor):
>>> from torch import tensor
>>> from torchmetrics.functional.classification import binary_confusion_matrix
>>> target = tensor([1, 1, 0, 0])
>>> preds = tensor([0.35, 0.85, 0.48, 0.01])
>>> binary_confusion_matrix(preds, target)
tensor([[2, 0], [1, 1]])
multiclass_confusion_matrix
- torchmetrics.functional.classification.multiclass_confusion_matrix(preds, target, num_classes, normalize=None, ignore_index=None, validate_args=True)
Compute the confusion matrix for multiclass tasks.
Accepts the following input tensors:
preds: (N, ...) (int tensor) or (N, C, ...) (float tensor). If preds is a floating point tensor, we apply torch.argmax along the C dimension to automatically convert probabilities/logits into an int tensor.
target (int tensor): (N, ...)
Additional dimension ... will be flattened into the batch dimension.
- Parameters:
preds (Tensor) – Tensor with predictions
target (Tensor) – Tensor with true labels
num_classes (int) – Integer specifying the number of classes
normalize (Optional[Literal['true', 'pred', 'all', 'none']]) – Normalization mode for confusion matrix. Choose from:
None or 'none': no normalization (default)
'true': normalization over the targets (most commonly used)
'pred': normalization over the predictions
'all': normalization over the whole matrix
ignore_index (Optional[int]) – Specifies a target value that is ignored and does not contribute to the metric calculation
validate_args (bool) – bool indicating if input arguments and tensors should be validated for correctness. Set to False for faster computations.
- Return type:
Tensor
- Returns:
A [num_classes, num_classes] tensor
- Example (pred is integer tensor):
>>> from torch import tensor
>>> from torchmetrics.functional.classification import multiclass_confusion_matrix
>>> target = tensor([2, 1, 0, 0])
>>> preds = tensor([2, 1, 0, 1])
>>> multiclass_confusion_matrix(preds, target, num_classes=3)
tensor([[1, 1, 0], [0, 1, 0], [0, 0, 1]])
- Example (pred is float tensor):
>>> from torch import tensor
>>> from torchmetrics.functional.classification import multiclass_confusion_matrix
>>> target = tensor([2, 1, 0, 0])
>>> preds = tensor([[0.16, 0.26, 0.58],
...                 [0.22, 0.61, 0.17],
...                 [0.71, 0.09, 0.20],
...                 [0.05, 0.82, 0.13]])
>>> multiclass_confusion_matrix(preds, target, num_classes=3)
tensor([[1, 1, 0], [0, 1, 0], [0, 0, 1]])
multilabel_confusion_matrix
- torchmetrics.functional.classification.multilabel_confusion_matrix(preds, target, num_labels, threshold=0.5, normalize=None, ignore_index=None, validate_args=True)
Compute the confusion matrix for multilabel tasks.
Accepts the following input tensors:
preds (int or float tensor): (N, C, ...). If preds is a floating point tensor with values outside the [0, 1] range, we consider the input to be logits and automatically apply sigmoid per element. Additionally, we convert it to an int tensor by thresholding with the value in threshold.
target (int tensor): (N, C, ...)
Additional dimension ... will be flattened into the batch dimension.
- Parameters:
preds (Tensor) – Tensor with predictions
target (Tensor) – Tensor with true labels
num_labels (int) – Integer specifying the number of labels
threshold (float) – Threshold for transforming probabilities to binary (0, 1) predictions
normalize (Optional[Literal['true', 'pred', 'all', 'none']]) – Normalization mode for confusion matrix. Choose from:
None or 'none': no normalization (default)
'true': normalization over the targets (most commonly used)
'pred': normalization over the predictions
'all': normalization over the whole matrix
ignore_index (Optional[int]) – Specifies a target value that is ignored and does not contribute to the metric calculation
validate_args (bool) – bool indicating if input arguments and tensors should be validated for correctness. Set to False for faster computations.
- Return type:
Tensor
- Returns:
A [num_labels, 2, 2] tensor
- Example (preds is int tensor):
>>> from torch import tensor
>>> from torchmetrics.functional.classification import multilabel_confusion_matrix
>>> target = tensor([[0, 1, 0], [1, 0, 1]])
>>> preds = tensor([[0, 0, 1], [1, 0, 1]])
>>> multilabel_confusion_matrix(preds, target, num_labels=3)
tensor([[[1, 0], [0, 1]], [[1, 0], [1, 0]], [[0, 1], [0, 1]]])
- Example (preds is float tensor):
>>> from torch import tensor
>>> from torchmetrics.functional.classification import multilabel_confusion_matrix
>>> target = tensor([[0, 1, 0], [1, 0, 1]])
>>> preds = tensor([[0.11, 0.22, 0.84], [0.73, 0.33, 0.92]])
>>> multilabel_confusion_matrix(preds, target, num_labels=3)
tensor([[[1, 0], [0, 1]], [[1, 0], [1, 0]], [[0, 1], [0, 1]]])
Coverage Error
Module Interface
- class torchmetrics.classification.MultilabelCoverageError(num_labels, ignore_index=None, validate_args=True, **kwargs)
Compute Multilabel coverage error.
The score measures how far we need to go through the ranked scores to cover all true labels. The best value is equal to the average number of labels in the target tensor per sample.
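Concretely, for each sample one counts how many labels score at least as high as the lowest-scored true label, then averages over samples. The plain-Python sketch below follows that standard definition (as used by scikit-learn's coverage_error); the helper coverage_error, the treatment of ties (>=) and of samples without true labels (contributing 0) are assumptions, not a description of TorchMetrics' internals.

```python
def coverage_error(scores, target):
    """Average, over samples, of how many top-ranked labels are needed to cover all true labels.

    Sketch of the standard definition: count the labels whose score is >= the
    lowest score among the true labels. Samples without any true label
    contribute 0 here (an assumption).
    """
    total = 0
    for row_scores, row_target in zip(scores, target):
        true_scores = [s for s, t in zip(row_scores, row_target) if t == 1]
        if true_scores:
            lowest = min(true_scores)
            total += sum(s >= lowest for s in row_scores)
    return total / len(scores)


scores = [[0.9, 0.1, 0.5], [0.2, 0.8, 0.3]]
target = [[1, 0, 1], [1, 0, 0]]
# sample 1: true labels scored 0.9 and 0.5, and 2 labels reach 0.5
# sample 2: the true label scored 0.2, and all 3 labels reach 0.2
print(coverage_error(scores, target))  # 2.5
```

Here the best achievable value would be 1.5, the average number of true labels per sample.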
As input to forward and update the metric accepts the following input:
preds (Tensor): A float tensor of shape (N, C, ...). Preds should be a tensor containing probabilities or logits for each observation. If preds has values outside the [0, 1] range, we consider the input to be logits and automatically apply sigmoid per element.
target (Tensor): An int tensor of shape (N, C, ...). Target should be a tensor containing ground truth labels, and therefore only contain {0, 1} values (except if ignore_index is specified).
Note
Additional dimension ... will be flattened into the batch dimension.
As output to forward and compute the metric returns the following output:
mlce (Tensor): A tensor containing the multilabel coverage error.
- Parameters:
num_labels (int) – Integer specifying the number of labels
ignore_index (Optional[int]) – Specifies a target value that is ignored and does not contribute to the metric calculation
validate_args (bool) – bool indicating if input arguments and tensors should be validated for correctness. Set to False for faster computations.
Example
>>> import torch
>>> from torchmetrics.classification import MultilabelCoverageError
>>> _ = torch.manual_seed(42)
>>> preds = torch.rand(10, 5)
>>> target = torch.randint(2, (10, 5))
>>> mlce = MultilabelCoverageError(num_labels=5)
>>> mlce(preds, target)
tensor(3.9000)
- plot(val=None, ax=None)
Plot a single or multiple values from the metric.
- Parameters:
val (Union[Tensor, Sequence[Tensor], None]) – Either a single result from calling metric.forward or metric.compute or a list of these results. If no value is provided, metric.compute is called automatically and that result is plotted.
ax (Optional[Axes]) – A matplotlib axis object. If provided, the plot is added to that axis.
- Return type:
- Returns:
Figure object and Axes object
- Raises:
ModuleNotFoundError – If matplotlib is not installed
>>> from torch import rand, randint
>>> # Example plotting a single value
>>> from torchmetrics.classification import MultilabelCoverageError
>>> metric = MultilabelCoverageError(num_labels=3)
>>> metric.update(rand(20, 3), randint(2, (20, 3)))
>>> fig_, ax_ = metric.plot()

>>> from torch import rand, randint
>>> # Example plotting multiple values
>>> from torchmetrics.classification import MultilabelCoverageError
>>> metric = MultilabelCoverageError(num_labels=3)
>>> values = [ ]
>>> for _ in range(10):
...     values.append(metric(rand(20, 3), randint(2, (20, 3))))
>>> fig_, ax_ = metric.plot(values)
Functional Interface
- torchmetrics.functional.classification.multilabel_coverage_error(preds, target, num_labels, ignore_index=None, validate_args=True)
Compute multilabel coverage error [1].
The score measures how far down the ranked scores one has to go to cover all true labels. The best value is equal to the average number of true labels per sample in the target tensor.
Accepts the following input tensors:
preds (float tensor): (N, C, ...). Preds should be a tensor containing probabilities or logits for each observation. If preds has values outside the [0, 1] range, we consider the input to be logits and automatically apply sigmoid per element.
target (int tensor): (N, C, ...). Target should be a tensor containing ground truth labels, and therefore only contain {0, 1} values (except if ignore_index is specified).
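The rule that values outside [0, 1] are treated as logits can be sketched in plain Python. The helper name to_probabilities is hypothetical, and the sketch only mirrors the rule as stated: if any element falls outside [0, 1], sigmoid is applied to every element; otherwise the values are used unchanged.

```python
import math
from typing import List


def to_probabilities(preds: List[float]) -> List[float]:
    """Apply sigmoid element-wise only if some value lies outside [0, 1]."""
    if all(0.0 <= p <= 1.0 for p in preds):
        return list(preds)  # already looks like probabilities: leave unchanged
    return [1.0 / (1.0 + math.exp(-p)) for p in preds]


print(to_probabilities([0.11, 0.84]))  # [0.11, 0.84]
print(to_probabilities([-2.0, 0.5]))   # both values pass through sigmoid; 0.5 becomes about 0.62
```

One consequence of the rule is that a single out-of-range value changes how every other element in the tensor is interpreted.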
Additional dimensions ... will be flattened into the batch dimension.
- Parameters:
preds (Tensor) – Tensor with predictions.
target (Tensor) – Tensor with true labels.
num_labels (int) – Integer specifying the number of labels.
ignore_index (Optional[int]) – Specifies a target value that is ignored and does not contribute to the metric calculation.
validate_args (bool) – bool indicating if input arguments and tensors should be validated for correctness. Set to False for faster computations.
- Return type:
Example
>>> import torch
>>> from torchmetrics.functional.classification import multilabel_coverage_error
>>> _ = torch.manual_seed(42)
>>> preds = torch.rand(10, 5)
>>> target = torch.randint(2, (10, 5))
>>> multilabel_coverage_error(preds, target, num_labels=5)
tensor(3.9000)
References
[1] Tsoumakas, G., Katakis, I., & Vlahavas, I. (2010). Mining multi-label data. In Data mining and knowledge discovery handbook (pp. 667-685). Springer US.
Dice
Module Interface
- class torchmetrics.Dice(zero_division=0, num_classes=None, threshold=0.5, average='micro', mdmc_average='global', ignore_index=None, top_k=None, multiclass=None, **kwargs)
Compute Dice.
\[\text{Dice} = \frac{2\,\text{TP}}{2\,\text{TP} + \text{FP} + \text{FN}}\]
where \(\text{TP}\), \(\text{FP}\) and \(\text{FN}\) represent the number of true positives, false positives and false negatives, respectively.
It is recommended to set ignore_index to the index of the background class.
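The Dice formula can be checked by hand on the example given for this class (preds [2, 0, 2, 1], target [1, 1, 2, 0], average='micro'). The plain-Python sketch below pools TP, FP and FN over all classes (micro averaging) and falls back to zero_division when the denominator is 0. micro_dice is a hypothetical name and this is not TorchMetrics' code.

```python
from typing import List


def micro_dice(preds: List[int], target: List[int], num_classes: int, zero_division: float = 0.0) -> float:
    """Micro-averaged Dice for single-label multiclass predictions."""
    tp = fp = fn = 0
    for c in range(num_classes):
        for p, t in zip(preds, target):
            if p == c and t == c:
                tp += 1  # predicted c and the truth is c
            elif p == c:
                fp += 1  # predicted c, but the truth is another class
            elif t == c:
                fn += 1  # the truth is c, but another class was predicted
    denom = 2 * tp + fp + fn
    return zero_division if denom == 0 else 2 * tp / denom


print(micro_dice([2, 0, 2, 1], [1, 1, 2, 0], num_classes=3))  # 0.25
```

For single-label multiclass inputs, the micro score reduces to plain accuracy (here 1 correct out of 4), because every wrong prediction is counted once as a false positive and once as a false negative.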
The reduction method (how the Dice scores are aggregated) is controlled by the average parameter, and additionally by the mdmc_average parameter in the multi-dimensional multi-class case.
As input to forward and update the metric accepts the following input:
preds (Tensor): Predictions from model (probabilities, logits or labels).
target (Tensor): Ground truth values.
As output to forward and compute the metric returns the following output:
dice (Tensor): A tensor containing the Dice score.
If average in ['micro', 'macro', 'weighted', 'samples'], a one-element tensor will be returned.
If average in ['none', None], the shape will be (C,), where C stands for the number of classes.
- Parameters:
num_classes – Number of classes. Necessary for 'macro' and None average methods.
threshold – Threshold for transforming probability or logit predictions to binary (0, 1) predictions, in the case of binary or multi-label inputs. The default value of 0.5 corresponds to inputs being probabilities.
zero_division – The value to use for the score if the denominator equals zero.
average –
Defines the reduction that is applied. Should be one of the following:
'micro' [default]: Calculate the metric globally, across all samples and classes.
'macro': Calculate the metric for each class separately, and average the metrics across classes (with equal weights for each class).
'weighted': Calculate the metric for each class separately, and average the metrics across classes, weighting each class by its support (tp + fn).
'none' or None: Calculate the metric for each class separately, and return the metric for every class.
'samples': Calculate the metric for each sample, and average the metrics across samples (with equal weights for each sample).
Note
What is considered a sample in the multi-dimensional multi-class case depends on the value of mdmc_average.
mdmc_average –
Defines how averaging is done for multi-dimensional multi-class inputs (on top of the average parameter). Should be one of the following:
None: Intended for data that is not multi-dimensional multi-class.
'samplewise': In this case, the statistics are computed separately for each sample on the N axis, and then averaged over samples. The computation for each sample is done by treating the flattened extra axes ... as the N dimension within the sample, and computing the metric for the sample based on that.
'global' [default]: In this case the N and ... dimensions of the inputs are flattened into a new N_X sample axis, i.e. the inputs are treated as if they were (N_X, C). From here on the average parameter applies as usual.
ignore_index – Integer specifying a target class to ignore. If given, this class index does not contribute to the returned score, regardless of reduction method. If an index is ignored, and average=None or 'none', the score for the ignored class will be returned as nan.
top_k – Number of the highest probability or logit score predictions considered when finding the correct label, relevant only for (multi-dimensional) multi-class inputs. The default value (None) will be interpreted as 1 for these inputs. Should be left at default (None) for all other types of inputs.
multiclass – Used only in certain special cases, where you want to treat inputs as a different type than what they appear to be.
kwargs – Additional keyword arguments, see Advanced metric settings for more info.
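The difference between 'micro', 'macro' and 'weighted' becomes visible once per-class counts are available. The sketch below is a minimal plain-Python version, assuming per-class (tp, fp, fn) tuples have already been counted; dice_score and average_dice are hypothetical names. Following the parameter description, 'weighted' weights each class by its support tp + fn. TorchMetrics may treat classes that never occur, or the ignore_index class, differently; this sketch simply averages over all classes.

```python
from typing import List, Tuple


def dice_score(tp: int, fp: int, fn: int, zero_division: float = 0.0) -> float:
    denom = 2 * tp + fp + fn
    return zero_division if denom == 0 else 2 * tp / denom


def average_dice(counts: List[Tuple[int, int, int]], average: str) -> float:
    """counts[c] holds (tp, fp, fn) for class c."""
    if average == "micro":
        # Pool the counts over all classes, then compute a single score.
        tp, fp, fn = (sum(col) for col in zip(*counts))
        return dice_score(tp, fp, fn)
    scores = [dice_score(*c) for c in counts]
    if average == "macro":
        # Every class gets the same weight.
        return sum(scores) / len(scores)
    if average == "weighted":
        # Each class is weighted by its support, tp + fn.
        supports = [tp + fn for tp, _, fn in counts]
        return sum(s * w for s, w in zip(scores, supports)) / sum(supports)
    raise ValueError(f"unsupported average: {average}")


# (tp, fp, fn) per class for preds [2, 0, 2, 1] and target [1, 1, 2, 0]
counts = [(0, 1, 1), (0, 1, 2), (1, 1, 0)]
for avg in ("micro", "macro", "weighted"):
    print(avg, round(average_dice(counts, avg), 4))
# micro 0.25
# macro 0.2222
# weighted 0.1667
```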
- Raises:
ValueError – If average is none of "micro", "macro", "weighted", "samples", "none", None.
ValueError – If mdmc_average is not one of None, "samplewise", "global".
ValueError – If average is set but num_classes is not provided.
ValueError – If num_classes is set and ignore_index is not in the range [0, num_classes).
Example
>>> from torch import tensor
>>> from torchmetrics.classification import Dice
>>> preds = tensor([2, 0, 2, 1])
>>> target = tensor([1, 1, 2, 0])
>>> dice = Dice(average='micro')
>>> dice(preds, target)
tensor(0.2500)
- plot(val=None, ax=None)
Plot a single or multiple values from the metric.
- Parameters:
val (Union[Tensor, Sequence[Tensor], None]) – Either a single result from calling metric.forward or metric.compute or a list of these results. If no value is provided, metric.compute is called automatically and that result is plotted.
ax (Optional[Axes]) – A matplotlib axis object. If provided, the plot is added to that axis.
- Return type:
- Returns:
Figure object and Axes object
- Raises:
ModuleNotFoundError – If matplotlib is not installed
>>> # Example plotting a single value
>>> from torch import randint
>>> from torchmetrics.classification import Dice
>>> metric = Dice()
>>> metric.update(randint(2, (10,)), randint(2, (10,)))
>>> fig_, ax_ = metric.plot()

>>> # Example plotting multiple values
>>> from torch import randint
>>> from torchmetrics.classification import Dice
>>> metric = Dice()
>>> values = [ ]
>>> for _ in range(10):
...     values.append(metric(randint(2, (10,)), randint(2, (10,))))
>>> fig_, ax_ = metric.plot(values)
Functional Interface
- torchmetrics.functional.dice(preds, target, zero_division=0, average='micro', mdmc_average='global', threshold=0.5, top_k=None, num_classes=None, multiclass=None, ignore_index=None)
Compute Dice.
\[\text{Dice} = \frac{2\,\text{TP}}{2\,\text{TP} + \text{FP} + \text{FN}}\]
where \(\text{TP}\), \(\text{FP}\) and \(\text{FN}\) represent the number of true positives, false positives and false negatives, respectively.
It is recommended to set ignore_index to the index of the background class.
The reduction method (how the Dice scores are aggregated) is controlled by the average parameter, and additionally by the mdmc_average parameter in the multi-dimensional multi-class case.
- Parameters:
preds (Tensor) – Predictions from model (probabilities, logits or labels).
target (Tensor) – Ground truth values.
zero_division (int) – The value to use for the score if the denominator equals zero.
average –
Defines the reduction that is applied. Should be one of the following:
'micro' [default]: Calculate the metric globally, across all samples and classes.
'macro': Calculate the metric for each class separately, and average the metrics across classes (with equal weights for each class).
'weighted': Calculate the metric for each class separately, and average the metrics across classes, weighting each class by its support (tp + fn).
'none' or None: Calculate the metric for each class separately, and return the metric for every class.
'samples': Calculate the metric for each sample, and average the metrics across samples (with equal weights for each sample).
Note
What is considered a sample in the multi-dimensional multi-class case depends on the value of mdmc_average.
Note
If 'none' and a given class doesn't occur in the preds or target, the value for the class will be nan.
mdmc_average (Optional[str]) –
Defines how averaging is done for multi-dimensional multi-class inputs (on top of the average parameter). Should be one of the following:
None: Intended for data that is not multi-dimensional multi-class.
'samplewise': In this case, the statistics are computed separately for each sample on the N axis, and then averaged over samples. The computation for each sample is done by treating the flattened extra axes ... as the N dimension within the sample, and computing the metric for the sample based on that.
'global' [default]: In this case the N and ... dimensions of the inputs are flattened into a new N_X sample axis, i.e. the inputs are treated as if they were (N_X, C). From here on the average parameter applies as usual.
ignore_index (Optional[int]) – Integer specifying a target class to ignore. If given, this class index does not contribute to the returned score, regardless of reduction method. If an index is ignored, and average=None or 'none', the score for the ignored class will be returned as nan.
num_classes (Optional[int]) – Number of classes. Necessary for 'macro', 'weighted' and None average methods.
threshold (float) – Threshold for transforming probability or logit predictions to binary (0, 1) predictions, in the case of binary or multi-label inputs. The default value of 0.5 corresponds to inputs being probabilities.
top_k – Number of the highest probability or logit score predictions considered when finding the correct label, relevant only for (multi-dimensional) multi-class inputs. The default value (None) will be interpreted as 1 for these inputs. Should be left at default (None) for all other types of inputs.
multiclass (Optional[bool]) – Used only in certain special cases, where you want to treat inputs as a different type than what they appear to be.
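To see how 'global' and 'samplewise' pooling can disagree on multi-dimensional inputs, the plain-Python sketch below scores two small "images" (N=2, four positions each). To keep it short it scores a single class of interest (class 1) rather than reproducing any particular average setting; class_dice, mdmc_dice and the toy data are hypothetical.

```python
from typing import List


def class_dice(preds: List[int], target: List[int], cls: int = 1) -> float:
    """Dice score of a single class of interest (0 if the denominator is 0)."""
    tp = sum(p == cls and t == cls for p, t in zip(preds, target))
    fp = sum(p == cls and t != cls for p, t in zip(preds, target))
    fn = sum(p != cls and t == cls for p, t in zip(preds, target))
    denom = 2 * tp + fp + fn
    return 0.0 if denom == 0 else 2 * tp / denom


def mdmc_dice(preds: List[List[int]], target: List[List[int]], mdmc_average: str) -> float:
    if mdmc_average == "global":
        # Flatten N and the extra axis into one long N_X axis, then score once.
        flat_p = [p for sample in preds for p in sample]
        flat_t = [t for sample in target for t in sample]
        return class_dice(flat_p, flat_t)
    if mdmc_average == "samplewise":
        # Score each sample over its own extra axis, then average over N.
        scores = [class_dice(p, t) for p, t in zip(preds, target)]
        return sum(scores) / len(scores)
    raise ValueError(f"unsupported mdmc_average: {mdmc_average}")


# Two "images" (N=2) with four positions each; class 1 is the foreground.
preds = [[1, 1, 0, 0], [0, 0, 0, 1]]
target = [[1, 1, 1, 0], [0, 0, 0, 0]]
print(round(mdmc_dice(preds, target, "global"), 4))      # 0.6667
print(round(mdmc_dice(preds, target, "samplewise"), 4))  # 0.4
```

The gap comes from the second sample: a single false-positive position in an otherwise empty sample scores 0 on its own, and samplewise averaging weighs it as heavily as the first sample.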
- Return type:
- Returns:
The shape of the returned tensor depends on the average parameter:
If average in ['micro', 'macro', 'weighted', 'samples'], a one-element tensor will be returned.
If average in ['none', None], the shape will be (C,), where C stands for the number of classes.
- Raises:
ValueError – If average is not one of "micro", "macro", "weighted", "samples", "none" or None.
ValueError – If mdmc_average is not one of None, "samplewise", "global".
ValueError – If average is set but num_classes is not provided.
ValueError – If num_classes is set and ignore_index is not in the range [0, num_classes).
Example
>>> import torch
>>> from torchmetrics.functional.classification import dice
>>> preds = torch.tensor([2, 0, 2, 1])
>>> target = torch.tensor([1, 1, 2, 0])
>>> dice(preds, target, average='micro')
tensor(0.2500)
Exact Match
Module Interface
- class torchmetrics.ExactMatch(**kwargs)
Compute Exact match (also known as subset accuracy).
Exact Match is a stricter version of accuracy where all labels have to match exactly for the sample to be correctly classified.
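The strict "all labels must match" rule, and the difference between multidim_average='global' and 'samplewise', can be reproduced in plain Python on the legacy example data for this class. This is a sketch of the definition, not TorchMetrics' code; exact_match_sketch is a hypothetical helper.

```python
from typing import Any, List, Union


def exact_match_sketch(preds: List[Any], target: List[Any], multidim_average: str = "global") -> Union[float, List[float]]:
    # A sample is correct only if every element (all extra positions) matches.
    per_sample = [1.0 if p == t else 0.0 for p, t in zip(preds, target)]
    if multidim_average == "samplewise":
        return per_sample
    return sum(per_sample) / len(per_sample)


target = [[[0, 1], [2, 1], [0, 2]], [[1, 1], [2, 0], [1, 2]]]
preds = [[[0, 1], [2, 1], [0, 2]], [[2, 2], [2, 1], [1, 0]]]
print(exact_match_sketch(preds, target))                # 0.5
print(exact_match_sketch(preds, target, "samplewise"))  # [1.0, 0.0]
```

The second sample scores 0 even though 2 of its 6 entries agree with the target, which is what makes exact match stricter than element-wise accuracy.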
This module is a simple wrapper to get the task-specific versions of this metric, which is done by setting the task argument to either 'multiclass' or 'multilabel'. See the documentation of MulticlassExactMatch and MultilabelExactMatch for the specific details of how each argument influences the metric, and for examples.
- Legacy Example:
>>> from torch import tensor
>>> from torchmetrics import ExactMatch
>>> target = tensor([[[0, 1], [2, 1], [0, 2]], [[1, 1], [2, 0], [1, 2]]])
>>> preds = tensor([[[0, 1], [2, 1], [0, 2]], [[2, 2], [2, 1], [1, 0]]])
>>> metric = ExactMatch(task="multiclass", num_classes=3, multidim_average='global')
>>> metric(preds, target)
tensor(0.5000)

>>> target = tensor([[[0, 1], [2, 1], [0, 2]], [[1, 1], [2, 0], [1, 2]]])
>>> preds = tensor([[[0, 1], [2, 1], [0, 2]], [[2, 2], [2, 1], [1, 0]]])
>>> metric = ExactMatch(task="multiclass", num_classes=3, multidim_average='samplewise')
>>> metric(preds, target)
tensor([1., 0.])
MulticlassExactMatch
- class torchmetrics.classification.MulticlassExactMatch(num_classes, multidim_average='global', ignore_index=None, validate_args=True, **kwargs)
Compute Exact match (also known as subset accuracy) for multiclass tasks.
Exact Match is a stricter version of accuracy where all labels have to match exactly for the sample to be correctly classified.
As input to forward and update the metric accepts the following input:
preds (Tensor): An int tensor of shape (N, ...) or float tensor of shape (N, C, ...). If preds is a floating point tensor, we apply torch.argmax along the C dimension to automatically convert probabilities/logits into an int tensor.
target (Tensor): An int tensor of shape (N, ...).
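Converting float preds to labels by taking the argmax along the C dimension can be sketched in plain Python for the (N, C) case. to_labels is a hypothetical helper; ties go to the lowest class index because Python's max() keeps the first maximum, which is a property of this sketch rather than a documented guarantee of the metric.

```python
from typing import List


def to_labels(scores: List[List[float]]) -> List[int]:
    # For each sample, pick the index of the largest score along the class dimension C.
    # Python's max() keeps the first index among equal maxima.
    return [max(range(len(row)), key=row.__getitem__) for row in scores]


probs = [[0.1, 0.7, 0.2], [0.5, 0.3, 0.2], [0.2, 0.2, 0.6]]
print(to_labels(probs))  # [1, 0, 2]
```

Because argmax only looks at the ordering of the scores, probabilities and raw logits produce the same labels, so no sigmoid or softmax is needed first.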
As output to forward and compute the metric returns the following output:
mcem (Tensor): A tensor whose returned shape depends on the multidim_average argument:
If multidim_average is set to global, the output will be a scalar tensor.
If multidim_average is set to samplewise, the output will be a tensor of shape (N,).
- Parameters:
num_classes (int) – Integer specifying the number of classes.
multidim_average (Literal['global', 'samplewise']) –
Defines how additional dimensions ... should be handled. Should be one of the following:
global: Additional dimensions are flattened along the batch dimension.
samplewise: The statistic will be calculated independently for each sample on the N axis. The statistics in this case are calculated over the additional dimensions.
ignore_index (Optional[int]) – Specifies a target value that is ignored and does not contribute to the metric calculation.
validate_args (bool) – bool indicating if input arguments and tensors should be validated for correctness. Set to False for faster computations.
- Example (multidim tensors):
>>> from torch import tensor
>>> from torchmetrics.classification import MulticlassExactMatch
>>> target = tensor([[[0, 1], [2, 1], [0, 2]], [[1, 1], [2, 0], [1, 2]]])
>>> preds = tensor([[[0, 1], [2, 1], [0, 2]], [[2, 2], [2, 1], [1, 0]]])
>>> metric = MulticlassExactMatch(num_classes=3, multidim_average='global')
>>> metric(preds, target)
tensor(0.5000)
- Example (multidim tensors):
>>> from torchmetrics.classification import MulticlassExactMatch
>>> target = tensor([[[0, 1], [2, 1], [0, 2]], [[1, 1], [2, 0], [1, 2]]])
>>> preds = tensor([[[0, 1], [2, 1], [0, 2]], [[2, 2], [2, 1], [1, 0]]])
>>> metric = MulticlassExactMatch(num_classes=3, multidim_average='samplewise')
>>> metric(preds, target)
tensor([1., 0.])
- plot(val=None, ax=None)
Plot a single or multiple values from the metric.
- Parameters:
val (Union[Tensor, Sequence[Tensor], None]) – Either a single result from calling metric.forward or metric.compute or a list of these results. If no value is provided, metric.compute is called automatically and that result is plotted.
ax (Optional[Axes]) – A matplotlib axis object. If provided, the plot is added to that axis.
- Return type:
- Returns:
Figure object and Axes object
- Raises:
ModuleNotFoundError – If matplotlib is not installed
>>> # Example plotting a single value
>>> from torch import randint
>>> from torchmetrics.classification import MulticlassExactMatch
>>> metric = MulticlassExactMatch(num_classes=3)
>>> metric.update(randint(3, (20, 5)), randint(3, (20, 5)))
>>> fig_, ax_ = metric.plot()

>>> from torch import randint
>>> # Example plotting multiple values
>>> from torchmetrics.classification import MulticlassExactMatch
>>> metric = MulticlassExactMatch(num_classes=3)
>>> values = []
>>> for _ in range(20):
...     values.append(metric(randint(3, (20, 5)), randint(3, (20, 5))))
>>> fig_, ax_ = metric.plot(values)
MultilabelExactMatch
- class torchmetrics.classification.MultilabelExactMatch(num_labels, threshold=0.5, multidim_average='global', ignore_index=None, validate_args=True, **kwargs)
Compute Exact match (also known as subset accuracy) for multilabel tasks.
Exact Match is a stricter version of accuracy where all labels have to match exactly for the sample to be correctly classified.
As input to forward and update the metric accepts the following input:
preds (Tensor): An int tensor or float tensor of shape (N, C, ...). If preds is a floating point tensor with values outside the [0, 1] range, we consider the input to be logits and automatically apply sigmoid per element. Additionally, we convert it to an int tensor by thresholding with the value in threshold.
target (Tensor): An int tensor of shape (N, C, ...).
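For multilabel inputs, float predictions are first thresholded (after a sigmoid if any value lies outside [0, 1]), and a sample counts only if all C labels then match. A plain-Python sketch using the float example given for this class; multilabel_exact_match_sketch is hypothetical, and the strict p > threshold comparison is an assumption about how values exactly at the threshold are treated.

```python
import math
from typing import List


def multilabel_exact_match_sketch(preds: List[List[float]], target: List[List[int]], threshold: float = 0.5) -> float:
    flat = [p for row in preds for p in row]
    if any(p < 0.0 or p > 1.0 for p in flat):
        # Values outside [0, 1]: treat everything as logits and apply sigmoid.
        preds = [[1.0 / (1.0 + math.exp(-p)) for p in row] for row in preds]
    # Binarize every label (assumption: strictly greater than threshold),
    # then require the whole label vector of a sample to match.
    hits = [[int(p > threshold) for p in row] == list(t) for row, t in zip(preds, target)]
    return sum(hits) / len(hits)


target = [[0, 1, 0], [1, 0, 1]]
preds = [[0.11, 0.22, 0.84], [0.73, 0.33, 0.92]]
print(multilabel_exact_match_sketch(preds, target))  # 0.5
```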
As output to forward and compute the metric returns the following output:
mlem (Tensor): A tensor whose returned shape depends on the multidim_average argument:
If multidim_average is set to global, the output will be a scalar tensor.
If multidim_average is set to samplewise, the output will be a tensor of shape (N,).
- Parameters:
num_labels (int) – Integer specifying the number of labels.
threshold (float) – Threshold for transforming probability to binary (0, 1) predictions.
multidim_average (Literal['global', 'samplewise']) –
Defines how additional dimensions ... should be handled. Should be one of the following:
global: Additional dimensions are flattened along the batch dimension.
samplewise: The statistic will be calculated independently for each sample on the N axis. The statistics in this case are calculated over the additional dimensions.
ignore_index (Optional[int]) – Specifies a target value that is ignored and does not contribute to the metric calculation.
validate_args (bool) – bool indicating if input arguments and tensors should be validated for correctness. Set to False for faster computations.
- Example (preds is int tensor):
>>> from torch import tensor
>>> from torchmetrics.classification import MultilabelExactMatch
>>> target = tensor([[0, 1, 0], [1, 0, 1]])
>>> preds = tensor([[0, 0, 1], [1, 0, 1]])
>>> metric = MultilabelExactMatch(num_labels=3)
>>> metric(preds, target)
tensor(0.5000)
- Example (preds is float tensor):
>>> from torchmetrics.classification import MultilabelExactMatch
>>> target = tensor([[0, 1, 0], [1, 0, 1]])
>>> preds = tensor([[0.11, 0.22, 0.84], [0.73, 0.33, 0.92]])
>>> metric = MultilabelExactMatch(num_labels=3)
>>> metric(preds, target)
tensor(0.5000)
- Example (multidim tensors):
>>> from torchmetrics.classification import MultilabelExactMatch
>>> target = tensor([[[0, 1], [1, 0], [0, 1]], [[1, 1], [0, 0], [1, 0]]])
>>> preds = tensor([[[0.59, 0.91], [0.91, 0.99], [0.63, 0.04]],
...                 [[0.38, 0.04], [0.86, 0.780], [0.45, 0.37]]])
>>> metric = MultilabelExactMatch(num_labels=3, multidim_average='samplewise')
>>> metric(preds, target)
tensor([0., 0.])
- plot(val=None, ax=None)
Plot a single or multiple values from the metric.
- Parameters:
val (Union[Tensor, Sequence[Tensor], None]) – Either a single result from calling metric.forward or metric.compute or a list of these results. If no value is provided, metric.compute is called automatically and that result is plotted.
ax (Optional[Axes]) – A matplotlib axis object. If provided, the plot is added to that axis.
- Return type:
- Returns:
Figure and Axes object
- Raises:
ModuleNotFoundError – If matplotlib is not installed
>>> # Example plotting a single value
>>> from torch import rand, randint
>>> from torchmetrics.classification import MultilabelExactMatch
>>> metric = MultilabelExactMatch(num_labels=3)
>>> metric.update(randint(2, (20, 3, 5)), randint(2, (20, 3, 5)))
>>> fig_, ax_ = metric.plot()

>>> # Example plotting multiple values
>>> from torch import rand, randint
>>> from torchmetrics.classification import MultilabelExactMatch
>>> metric = MultilabelExactMatch(num_labels=3)
>>> values = [ ]
>>> for _ in range(10):
...     values.append(metric(randint(2, (20, 3, 5)), randint(2, (20, 3, 5))))
>>> fig_, ax_ = metric.plot(values)
Functional Interface
exact_match
- torchmetrics.functional.classification.exact_match(preds, target, task, num_classes=None, num_labels=None, threshold=0.5, multidim_average='global', ignore_index=None, validate_args=True)
Compute Exact match (also known as subset accuracy).
Exact Match is a stricter version of accuracy where all classes/labels have to match exactly for the sample to be correctly classified.
This function is a simple wrapper to get the task-specific versions of this metric, which is done by setting the task argument to either 'multiclass' or 'multilabel'. See the documentation of multiclass_exact_match() and multilabel_exact_match() for the specific details of how each argument influences the metric, and for examples.
- Return type:
- Legacy Example:
>>> from torch import tensor
>>> from torchmetrics.functional.classification import exact_match
>>> target = tensor([[[0, 1], [2, 1], [0, 2]], [[1, 1], [2, 0], [1, 2]]])
>>> preds = tensor([[[0, 1], [2, 1], [0, 2]], [[2, 2], [2, 1], [1, 0]]])
>>> exact_match(preds, target, task="multiclass", num_classes=3, multidim_average='global')
tensor(0.5000)

>>> target = tensor([[[0, 1], [2, 1], [0, 2]], [[1, 1], [2, 0], [1, 2]]])
>>> preds = tensor([[[0, 1], [2, 1], [0, 2]], [[2, 2], [2, 1], [1, 0]]])
>>> exact_match(preds, target, task="multiclass", num_classes=3, multidim_average='samplewise')
tensor([1., 0.])
multiclass_exact_match
- torchmetrics.functional.classification.multiclass_exact_match(preds, target, num_classes, multidim_average='global', ignore_index=None, validate_args=True)
Compute Exact match (also known as subset accuracy) for multiclass tasks.
Exact Match is a stricter version of accuracy where all labels have to match exactly for the sample to be correctly classified.
Accepts the following input tensors:
preds: (N, ...) (int tensor) or (N, C, ...) (float tensor). If preds is a floating point tensor, we apply torch.argmax along the C dimension to automatically convert probabilities/logits into an int tensor.
target (int tensor): (N, ...)
- Parameters:
preds (Tensor) – Tensor with predictions.
target (Tensor) – Tensor with true labels.
num_classes (int) – Integer specifying the number of classes.
multidim_average (Literal['global', 'samplewise']) –
Defines how additional dimensions ... should be handled. Should be one of the following:
global: Additional dimensions are flattened along the batch dimension.
samplewise: The statistic will be calculated independently for each sample on the N axis. The statistics in this case are calculated over the additional dimensions.
ignore_index (Optional[int]) – Specifies a target value that is ignored and does not contribute to the metric calculation.
validate_args (bool) – bool indicating if input arguments and tensors should be validated for correctness. Set to False for faster computations.
- Returns:
The returned shape depends on the multidim_average argument:
If multidim_average is set to global, the output will be a scalar tensor.
If multidim_average is set to samplewise, the output will be a tensor of shape (N,).
- Return type:
Tensor
- Example (multidim tensors):
>>> from torch import tensor
>>> from torchmetrics.functional.classification import multiclass_exact_match
>>> target = tensor([[[0, 1], [2, 1], [0, 2]], [[1, 1], [2, 0], [1, 2]]])
>>> preds = tensor([[[0, 1], [2, 1], [0, 2]], [[2, 2], [2, 1], [1, 0]]])
>>> multiclass_exact_match(preds, target, num_classes=3, multidim_average='global')
tensor(0.5000)
- Example (multidim tensors):
>>> from torchmetrics.functional.classification import multiclass_exact_match
>>> target = tensor([[[0, 1], [2, 1], [0, 2]], [[1, 1], [2, 0], [1, 2]]])
>>> preds = tensor([[[0, 1], [2, 1], [0, 2]], [[2, 2], [2, 1], [1, 0]]])
>>> multiclass_exact_match(preds, target, num_classes=3, multidim_average='samplewise')
tensor([1., 0.])
multilabel_exact_match
- torchmetrics.functional.classification.multilabel_exact_match(preds, target, num_labels, threshold=0.5, multidim_average='global', ignore_index=None, validate_args=True)
Compute Exact match (also known as subset accuracy) for multilabel tasks.
Exact Match is a stricter version of accuracy where all labels have to match exactly for the sample to be correctly classified.
Accepts the following input tensors:
preds (int or float tensor): (N, C, ...). If preds is a floating point tensor with values outside the [0, 1] range, we consider the input to be logits and automatically apply sigmoid per element. Additionally, we convert it to an int tensor by thresholding with the value in threshold.
target (int tensor): (N, C, ...)
- Parameters:
preds (Tensor) – Tensor with predictions.
target (Tensor) – Tensor with true labels.
num_labels (int) – Integer specifying the number of labels.
threshold (float) – Threshold for transforming probability to binary (0, 1) predictions.
multidim_average (Literal['global', 'samplewise']) –
Defines how additional dimensions ... should be handled. Should be one of the following:
global: Additional dimensions are flattened along the batch dimension.
samplewise: The statistic will be calculated independently for each sample on the N axis. The statistics in this case are calculated over the additional dimensions.
ignore_index (Optional[int]) – Specifies a target value that is ignored and does not contribute to the metric calculation.
validate_args (bool) – bool indicating if input arguments and tensors should be validated for correctness. Set to False for faster computations.
- Returns:
The returned shape depends on the multidim_average argument:
If multidim_average is set to global, the output will be a scalar tensor.
If multidim_average is set to samplewise, the output will be a tensor of shape (N,).
- Return type:
Tensor
- Example (preds is int tensor):
>>> from torch import tensor
>>> from torchmetrics.functional.classification import multilabel_exact_match
>>> target = tensor([[0, 1, 0], [1, 0, 1]])
>>> preds = tensor([[0, 0, 1], [1, 0, 1]])
>>> multilabel_exact_match(preds, target, num_labels=3)
tensor(0.5000)
- Example (preds is float tensor):
>>> from torchmetrics.functional.classification import multilabel_exact_match
>>> target = tensor([[0, 1, 0], [1, 0, 1]])
>>> preds = tensor([[0.11, 0.22, 0.84], [0.73, 0.33, 0.92]])
>>> multilabel_exact_match(preds, target, num_labels=3)
tensor(0.5000)
- Example (multidim tensors):
>>> from torchmetrics.functional.classification import multilabel_exact_match
>>> target = tensor([[[0, 1], [1, 0], [0, 1]], [[1, 1], [0, 0], [1, 0]]])
>>> preds = tensor([[[0.59, 0.91], [0.91, 0.99], [0.63, 0.04]],
...                 [[0.38, 0.04], [0.86, 0.780], [0.45, 0.37]]])
>>> multilabel_exact_match(preds, target, num_labels=3, multidim_average='samplewise')
tensor([0., 0.])
F-1 Score
Module Interface
- class torchmetrics.F1Score(**kwargs)
Compute F-1 score.
\[F_{1} = 2\frac{\text{precision} \cdot \text{recall}}{\text{precision} + \text{recall}}\]
The metric is only properly defined when \(\text{TP} + \text{FP} \neq 0 \wedge \text{TP} + \text{FN} \neq 0\), where \(\text{TP}\), \(\text{FP}\) and \(\text{FN}\) represent the number of true positives, false positives and false negatives, respectively. If either condition is violated for any class/label, the metric for that class/label will be set to 0 and the overall metric may therefore be affected in turn.
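As a worked check of the formula, the counts from the BinaryF1Score example in this section (target [0, 1, 0, 1, 0, 1], preds [0, 0, 1, 1, 0, 1]) are TP = 2, FP = 1 and FN = 1, so precision = recall = 2/3 and F1 = 2/3. The plain-Python sketch below computes F1 from counts and returns 0 when precision or recall is undefined, as described above; f1_from_counts is a hypothetical name.

```python
def f1_from_counts(tp: int, fp: int, fn: int) -> float:
    if tp + fp == 0 or tp + fn == 0:
        return 0.0  # precision or recall is undefined: the score is set to 0
    precision = tp / (tp + fp)
    recall = tp / (tp + fn)
    if precision + recall == 0:
        return 0.0  # tp == 0: both are 0, and so is F1
    return 2 * precision * recall / (precision + recall)


# Counts from the BinaryF1Score example: TP = 2, FP = 1, FN = 1
print(round(f1_from_counts(2, 1, 1), 4))  # 0.6667
```

Substituting the definitions of precision and recall shows that the score equals 2TP / (2TP + FP + FN), the same expression as the Dice formula; here 4 / (4 + 1 + 1) = 0.6667.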
This module is a simple wrapper to get the task-specific versions of this metric, which is done by setting the task argument to either 'binary', 'multiclass' or 'multilabel'. See the documentation of BinaryF1Score, MulticlassF1Score and MultilabelF1Score for the specific details of how each argument influences the metric, and for examples.
- Legacy Example:
>>> from torch import tensor
>>> from torchmetrics import F1Score
>>> target = tensor([0, 1, 2, 0, 1, 2])
>>> preds = tensor([0, 2, 1, 0, 0, 1])
>>> f1 = F1Score(task="multiclass", num_classes=3)
>>> f1(preds, target)
tensor(0.3333)
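The legacy example's output of 0.3333 is what pooling the counts over all classes (micro averaging) gives on this data; averaging the per-class F1 scores with equal weight (macro) gives a lower number. The plain-Python sketch below recomputes both with hypothetical helper names. It checks only the arithmetic, so the default average of each class should be read from its own signature (MulticlassF1Score, for instance, lists average='macro').

```python
from typing import List, Tuple


def per_class_counts(preds: List[int], target: List[int], num_classes: int) -> List[Tuple[int, int, int]]:
    counts = []
    for c in range(num_classes):
        tp = sum(p == c and t == c for p, t in zip(preds, target))
        fp = sum(p == c and t != c for p, t in zip(preds, target))
        fn = sum(p != c and t == c for p, t in zip(preds, target))
        counts.append((tp, fp, fn))
    return counts


def f1(tp: int, fp: int, fn: int) -> float:
    denom = 2 * tp + fp + fn  # equivalent form of 2PR / (P + R); 0 when undefined
    return 0.0 if denom == 0 else 2 * tp / denom


target = [0, 1, 2, 0, 1, 2]
preds = [0, 2, 1, 0, 0, 1]
counts = per_class_counts(preds, target, num_classes=3)
micro = f1(*(sum(col) for col in zip(*counts)))  # pool TP, FP, FN first
macro = sum(f1(*c) for c in counts) / len(counts)  # average per-class scores
print(round(micro, 4), round(macro, 4))  # 0.3333 0.2667
```

Only class 0 is ever predicted correctly (F1 = 0.8), so the macro score is 0.8 / 3, while micro pools 2 true positives against 4 false positives and 4 false negatives.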
BinaryF1Score
- class torchmetrics.classification.BinaryF1Score(threshold=0.5, multidim_average='global', ignore_index=None, validate_args=True, **kwargs)
Compute F-1 score for binary tasks.
\[F_{1} = 2\frac{\text{precision} \cdot \text{recall}}{\text{precision} + \text{recall}}\]
The metric is only properly defined when \(\text{TP} + \text{FP} \neq 0 \wedge \text{TP} + \text{FN} \neq 0\), where \(\text{TP}\), \(\text{FP}\) and \(\text{FN}\) represent the number of true positives, false positives and false negatives, respectively. If either condition is violated, a score of 0 is returned.
As input to forward and update the metric accepts the following input:
preds (Tensor): An int or float tensor of shape (N, ...). If preds is a floating point tensor with values outside the [0, 1] range, we consider the input to be logits and automatically apply sigmoid per element. Additionally, we convert it to an int tensor by thresholding with the value in threshold.
target (Tensor): An int tensor of shape (N, ...).
As output to forward and compute the metric returns the following output:
bf1s (Tensor): A tensor whose returned shape depends on the multidim_average argument:
If multidim_average is set to global, the metric returns a scalar value.
If multidim_average is set to samplewise, the metric returns an (N,) vector consisting of a scalar value per sample.
- Parameters:
threshold (float) – Threshold for transforming probability to binary {0, 1} predictions.
multidim_average (Literal['global', 'samplewise']) –
Defines how additional dimensions ... should be handled. Should be one of the following:
global: Additional dimensions are flattened along the batch dimension.
samplewise: The statistic will be calculated independently for each sample on the N axis. The statistics in this case are calculated over the additional dimensions.
ignore_index (Optional[int]) – Specifies a target value that is ignored and does not contribute to the metric calculation.
validate_args (bool) – bool indicating if input arguments and tensors should be validated for correctness. Set to False for faster computations.
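multidim_average='samplewise' can be reproduced on the multidim example for this class: each sample's (3, 2) block is flattened, thresholded at 0.5 and scored on its own, giving [0.5, 0.0]. A plain-Python sketch with hypothetical helper names; the strict p > threshold comparison is an assumption about values exactly at the threshold.

```python
from typing import List


def binary_f1_sketch(preds: List[float], target: List[int], threshold: float = 0.5) -> float:
    labels = [int(p > threshold) for p in preds]
    tp = sum(p == 1 and t == 1 for p, t in zip(labels, target))
    fp = sum(p == 1 and t == 0 for p, t in zip(labels, target))
    fn = sum(p == 0 and t == 1 for p, t in zip(labels, target))
    denom = 2 * tp + fp + fn
    return 0.0 if denom == 0 else 2 * tp / denom


def flatten(block: list) -> list:
    # Remove one level of nesting.
    return [v for row in block for v in row]


target = [[[0, 1], [1, 0], [0, 1]], [[1, 1], [0, 0], [1, 0]]]
preds = [[[0.59, 0.91], [0.91, 0.99], [0.63, 0.04]], [[0.38, 0.04], [0.86, 0.78], [0.45, 0.37]]]

samplewise = [binary_f1_sketch(flatten(p), flatten(t)) for p, t in zip(preds, target)]
global_score = binary_f1_sketch(flatten(flatten(preds)), flatten(flatten(target)))
print(samplewise)              # [0.5, 0.0]
print(round(global_score, 4))  # 0.3077 (pooled counts, not the mean of the per-sample scores)
```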
- Example (preds is int tensor):
>>> from torch import tensor >>> from torchmetrics.classification import BinaryF1Score >>> target = tensor([0, 1, 0, 1, 0, 1]) >>> preds = tensor([0, 0, 1, 1, 0, 1]) >>> metric = BinaryF1Score() >>> metric(preds, target) tensor(0.6667)
- Example (preds is float tensor):
>>> from torchmetrics.classification import BinaryF1Score >>> target = tensor([0, 1, 0, 1, 0, 1]) >>> preds = tensor([0.11, 0.22, 0.84, 0.73, 0.33, 0.92]) >>> metric = BinaryF1Score() >>> metric(preds, target) tensor(0.6667)
- Example (multidim tensors):
>>> from torch import tensor
>>> from torchmetrics.classification import BinaryF1Score
>>> target = tensor([[[0, 1], [1, 0], [0, 1]], [[1, 1], [0, 0], [1, 0]]])
>>> preds = tensor([[[0.59, 0.91], [0.91, 0.99], [0.63, 0.04]],
...                 [[0.38, 0.04], [0.86, 0.780], [0.45, 0.37]]])
>>> metric = BinaryF1Score(multidim_average='samplewise')
>>> metric(preds, target)
tensor([0.5000, 0.0000])
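The difference between multidim_average='global' and 'samplewise' can be checked by hand on the multidim example above. The plain-Python sketch below uses hypothetical helper names and is not the torchmetrics implementation. It takes predictions that have already been thresholded at 0.5 and flattened per sample, pools all counts for global, and counts each sample separately for samplewise.

```python
from typing import List, Sequence, Tuple, Union


def f1_from_counts(tp: int, fp: int, fn: int) -> float:
    # 2TP / (2TP + FP + FN); returns 0.0 in the undefined case
    denom = 2 * tp + fp + fn
    return 2 * tp / denom if denom else 0.0


def binary_counts(preds: Sequence[int], target: Sequence[int]) -> Tuple[int, int, int]:
    tp = sum(p == 1 and t == 1 for p, t in zip(preds, target))
    fp = sum(p == 1 and t == 0 for p, t in zip(preds, target))
    fn = sum(p == 0 and t == 1 for p, t in zip(preds, target))
    return tp, fp, fn


def binary_f1(
    preds: Sequence[Sequence[int]],
    target: Sequence[Sequence[int]],
    multidim_average: str = "global",
) -> Union[float, List[float]]:
    """preds/target hold one flattened 0/1 list per sample."""
    if multidim_average == "global":
        # Pool every sample into a single vector and count once
        flat_p = [p for sample in preds for p in sample]
        flat_t = [t for sample in target for t in sample]
        return f1_from_counts(*binary_counts(flat_p, flat_t))
    if multidim_average == "samplewise":
        # Count and score each sample separately
        return [f1_from_counts(*binary_counts(p, t)) for p, t in zip(preds, target)]
    raise ValueError(f"unknown multidim_average: {multidim_average!r}")


# The multidim example above, thresholded at 0.5 and flattened per sample
preds = [[1, 1, 1, 1, 1, 0], [0, 0, 1, 1, 0, 0]]
target = [[0, 1, 1, 0, 0, 1], [1, 1, 0, 0, 1, 0]]
print(binary_f1(preds, target, "samplewise"))  # [0.5, 0.0]
print(round(binary_f1(preds, target, "global"), 4))  # 0.3077
```

Because global pools the counts before dividing, its value (4/13) is not simply the mean of the samplewise scores.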
- plot(val=None, ax=None)
Plot a single or multiple values from the metric.
- Parameters:
val (Union[Tensor, Sequence[Tensor], None]) – Either a single result from calling metric.forward or metric.compute or a list of these results. If no value is provided, metric.compute is called automatically and that result is plotted.
ax (Optional[Axes]) – A matplotlib axis object. If provided, the plot is added to that axis.
- Return type:
- Returns:
Figure object and Axes object
- Raises:
ModuleNotFoundError – If matplotlib is not installed
>>> from torch import rand, randint
>>> # Example plotting a single value
>>> from torchmetrics.classification import BinaryF1Score
>>> metric = BinaryF1Score()
>>> metric.update(rand(10), randint(2, (10,)))
>>> fig_, ax_ = metric.plot()

>>> from torch import rand, randint
>>> # Example plotting multiple values
>>> from torchmetrics.classification import BinaryF1Score
>>> metric = BinaryF1Score()
>>> values = []
>>> for _ in range(10):
...     values.append(metric(rand(10), randint(2, (10,))))
>>> fig_, ax_ = metric.plot(values)
MulticlassF1Score
- class torchmetrics.classification.MulticlassF1Score(num_classes, top_k=1, average='macro', multidim_average='global', ignore_index=None, validate_args=True, **kwargs)
Compute F-1 score for multiclass tasks.
\[F_{1} = 2\frac{\text{precision} \cdot \text{recall}}{\text{precision} + \text{recall}}\]
The metric is only properly defined when \(\text{TP} + \text{FP} \neq 0 \wedge \text{TP} + \text{FN} \neq 0\), where \(\text{TP}\), \(\text{FP}\) and \(\text{FN}\) represent the number of true positives, false positives and false negatives respectively. If this case is encountered for any class, the metric for that class will be set to 0 and the overall metric may therefore be affected in turn.
As input to forward and update the metric accepts the following input:
preds (Tensor): An int tensor of shape (N, ...) or float tensor of shape (N, C, ...). If preds is a floating point tensor, torch.argmax is applied along the C dimension to automatically convert probabilities/logits into an int tensor.
target (Tensor): An int tensor of shape (N, ...).
As output to forward and compute the metric returns the following output:
mcf1s (Tensor): A tensor whose returned shape depends on the average and multidim_average arguments:
  If multidim_average is set to global:
    If average='micro'/'macro'/'weighted', the output will be a scalar tensor
    If average=None/'none', the shape will be (C,)
  If multidim_average is set to samplewise:
    If average='micro'/'macro'/'weighted', the shape will be (N,)
    If average=None/'none', the shape will be (N, C)
- Parameters:
preds – Tensor with predictions
target – Tensor with true labels
num_classes (int) – Integer specifying the number of classes
average (Optional[Literal['micro', 'macro', 'weighted', 'none']]) – Defines the reduction that is applied over labels. Should be one of the following:
  micro: Sum statistics over all labels
  macro: Calculate statistics for each label and average them
  weighted: Calculate statistics for each label and compute a weighted average using their support
  "none" or None: Calculate the statistic for each label and apply no reduction
top_k (int) – Number of highest probability or logit score predictions considered to find the correct label. Only works when preds contain probabilities/logits.
multidim_average (Literal['global', 'samplewise']) – Defines how additional dimensions ... should be handled. Should be one of the following:
  global: Additional dimensions are flattened along the batch dimension
  samplewise: Statistic will be calculated independently for each sample on the N axis. The statistics in this case are calculated over the additional dimensions.
ignore_index (Optional[int]) – Specifies a target value that is ignored and does not contribute to the metric calculation
validate_args (bool) – bool indicating if input arguments and tensors should be validated for correctness. Set to False for faster computations.
- Example (preds is int tensor):
>>> from torch import tensor
>>> from torchmetrics.classification import MulticlassF1Score
>>> target = tensor([2, 1, 0, 0])
>>> preds = tensor([2, 1, 0, 1])
>>> metric = MulticlassF1Score(num_classes=3)
>>> metric(preds, target)
tensor(0.7778)
>>> mcf1s = MulticlassF1Score(num_classes=3, average=None)
>>> mcf1s(preds, target)
tensor([0.6667, 0.6667, 1.0000])
- Example (preds is float tensor):
>>> from torch import tensor
>>> from torchmetrics.classification import MulticlassF1Score
>>> target = tensor([2, 1, 0, 0])
>>> preds = tensor([[0.16, 0.26, 0.58],
...                 [0.22, 0.61, 0.17],
...                 [0.71, 0.09, 0.20],
...                 [0.05, 0.82, 0.13]])
>>> metric = MulticlassF1Score(num_classes=3)
>>> metric(preds, target)
tensor(0.7778)
>>> mcf1s = MulticlassF1Score(num_classes=3, average=None)
>>> mcf1s(preds, target)
tensor([0.6667, 0.6667, 1.0000])
- Example (multidim tensors):
>>> from torch import tensor
>>> from torchmetrics.classification import MulticlassF1Score
>>> target = tensor([[[0, 1], [2, 1], [0, 2]], [[1, 1], [2, 0], [1, 2]]])
>>> preds = tensor([[[0, 2], [2, 0], [0, 1]], [[2, 2], [2, 1], [1, 0]]])
>>> metric = MulticlassF1Score(num_classes=3, multidim_average='samplewise')
>>> metric(preds, target)
tensor([0.4333, 0.2667])
>>> mcf1s = MulticlassF1Score(num_classes=3, multidim_average='samplewise', average=None)
>>> mcf1s(preds, target)
tensor([[0.8000, 0.0000, 0.5000],
        [0.0000, 0.4000, 0.4000]])
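The averaging options can be reproduced by hand from per-class counts. The following plain-Python sketch uses hypothetical helper names, is not the torchmetrics implementation, and ignores edge cases such as classes that never occur. It applies an argmax to float rows, as described for preds, treats each class one-vs-rest, and then reduces with micro, macro or weighted averaging. The weight of a class is taken to be its support, TP + FN, i.e. how often it occurs in target.

```python
from typing import List, Optional, Sequence, Tuple


def argmax_rows(probs: Sequence[Sequence[float]]) -> List[int]:
    # Float inputs of shape (N, C) are reduced to class indices along C
    return [max(range(len(row)), key=lambda c: row[c]) for row in probs]


def per_class_counts(
    preds: Sequence[int], target: Sequence[int], num_classes: int
) -> Tuple[List[int], List[int], List[int]]:
    tp, fp, fn = [0] * num_classes, [0] * num_classes, [0] * num_classes
    for p, t in zip(preds, target):
        if p == t:
            tp[t] += 1
        else:
            fp[p] += 1  # counted against the predicted class
            fn[t] += 1  # counted against the true class
    return tp, fp, fn


def f1(tp: int, fp: int, fn: int) -> float:
    denom = 2 * tp + fp + fn
    return 2 * tp / denom if denom else 0.0


def multiclass_f1(preds, target, num_classes: int, average: Optional[str] = "macro"):
    tp, fp, fn = per_class_counts(preds, target, num_classes)
    per_class = [f1(*c) for c in zip(tp, fp, fn)]
    if average is None or average == "none":
        return per_class
    if average == "micro":
        return f1(sum(tp), sum(fp), sum(fn))
    if average == "macro":
        return sum(per_class) / num_classes
    if average == "weighted":
        support = [a + b for a, b in zip(tp, fn)]  # target occurrences per class
        return sum(s * w for s, w in zip(per_class, support)) / sum(support)
    raise ValueError(f"unknown average: {average!r}")


target = [2, 1, 0, 0]
preds = argmax_rows([[0.16, 0.26, 0.58], [0.22, 0.61, 0.17],
                     [0.71, 0.09, 0.20], [0.05, 0.82, 0.13]])
print(preds)  # [2, 1, 0, 1]
for avg in (None, "micro", "macro", "weighted"):
    print(avg, multiclass_f1(preds, target, num_classes=3, average=avg))
```

On this data macro averaging gives 7/9 (the 0.7778 above), while micro and weighted averaging both come out to 0.75, because class 0 has twice the support of the other classes and the lowest score.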
- plot(val=None, ax=None)
Plot a single or multiple values from the metric.
- Parameters:
val (Union[Tensor, Sequence[Tensor], None]) – Either a single result from calling metric.forward or metric.compute or a list of these results. If no value is provided, metric.compute is called automatically and that result is plotted.
ax (Optional[Axes]) – A matplotlib axis object. If provided, the plot is added to that axis.
- Return type:
- Returns:
Figure object and Axes object
- Raises:
ModuleNotFoundError – If matplotlib is not installed
>>> from torch import randint
>>> # Example plotting a single value per class
>>> from torchmetrics.classification import MulticlassF1Score
>>> metric = MulticlassF1Score(num_classes=3, average=None)
>>> metric.update(randint(3, (20,)), randint(3, (20,)))
>>> fig_, ax_ = metric.plot()

>>> from torch import randint
>>> # Example plotting multiple values per class
>>> from torchmetrics.classification import MulticlassF1Score
>>> metric = MulticlassF1Score(num_classes=3, average=None)
>>> values = []
>>> for _ in range(20):
...     values.append(metric(randint(3, (20,)), randint(3, (20,))))
>>> fig_, ax_ = metric.plot(values)
MultilabelF1Score
- class torchmetrics.classification.MultilabelF1Score(num_labels, threshold=0.5, average='macro', multidim_average='global', ignore_index=None, validate_args=True, **kwargs)
Compute F-1 score for multilabel tasks.
\[F_{1} = 2\frac{\text{precision} \cdot \text{recall}}{\text{precision} + \text{recall}}\]
The metric is only properly defined when \(\text{TP} + \text{FP} \neq 0 \wedge \text{TP} + \text{FN} \neq 0\), where \(\text{TP}\), \(\text{FP}\) and \(\text{FN}\) represent the number of true positives, false positives and false negatives respectively. If this case is encountered for any label, the metric for that label will be set to 0 and the overall metric may therefore be affected in turn.
As input to forward and update the metric accepts the following input:
preds (Tensor): An int or float tensor of shape (N, C, ...). If preds is a floating point tensor with values outside the [0,1] range, the input is treated as logits and a sigmoid is applied element-wise. Additionally, preds is converted to an int tensor by thresholding with the value in threshold.
target (Tensor): An int tensor of shape (N, C, ...).
As output to forward and compute the metric returns the following output:
mlf1s (Tensor): A tensor whose returned shape depends on the average and multidim_average arguments:
  If multidim_average is set to global:
    If average='micro'/'macro'/'weighted', the output will be a scalar tensor
    If average=None/'none', the shape will be (C,)
  If multidim_average is set to samplewise:
    If average='micro'/'macro'/'weighted', the shape will be (N,)
    If average=None/'none', the shape will be (N, C)
- Parameters:
num_labels (int) – Integer specifying the number of labels
threshold (float) – Threshold for transforming probability to binary {0,1} predictions
average (Optional[Literal['micro', 'macro', 'weighted', 'none']]) – Defines the reduction that is applied over labels. Should be one of the following:
  micro: Sum statistics over all labels
  macro: Calculate statistics for each label and average them
  weighted: Calculate statistics for each label and compute a weighted average using their support
  "none" or None: Calculate the statistic for each label and apply no reduction
multidim_average (Literal['global', 'samplewise']) – Defines how additional dimensions ... should be handled. Should be one of the following:
  global: Additional dimensions are flattened along the batch dimension
  samplewise: Statistic will be calculated independently for each sample on the N axis. The statistics in this case are calculated over the additional dimensions.
ignore_index (Optional[int]) – Specifies a target value that is ignored and does not contribute to the metric calculation
validate_args (bool) – bool indicating if input arguments and tensors should be validated for correctness. Set to False for faster computations.
- Example (preds is int tensor):
>>> from torch import tensor
>>> from torchmetrics.classification import MultilabelF1Score
>>> target = tensor([[0, 1, 0], [1, 0, 1]])
>>> preds = tensor([[0, 0, 1], [1, 0, 1]])
>>> metric = MultilabelF1Score(num_labels=3)
>>> metric(preds, target)
tensor(0.5556)
>>> mlf1s = MultilabelF1Score(num_labels=3, average=None)
>>> mlf1s(preds, target)
tensor([1.0000, 0.0000, 0.6667])
- Example (preds is float tensor):
>>> from torch import tensor
>>> from torchmetrics.classification import MultilabelF1Score
>>> target = tensor([[0, 1, 0], [1, 0, 1]])
>>> preds = tensor([[0.11, 0.22, 0.84], [0.73, 0.33, 0.92]])
>>> metric = MultilabelF1Score(num_labels=3)
>>> metric(preds, target)
tensor(0.5556)
>>> mlf1s = MultilabelF1Score(num_labels=3, average=None)
>>> mlf1s(preds, target)
tensor([1.0000, 0.0000, 0.6667])
- Example (multidim tensors):
>>> from torch import tensor
>>> from torchmetrics.classification import MultilabelF1Score
>>> target = tensor([[[0, 1], [1, 0], [0, 1]], [[1, 1], [0, 0], [1, 0]]])
>>> preds = tensor([[[0.59, 0.91], [0.91, 0.99], [0.63, 0.04]],
...                 [[0.38, 0.04], [0.86, 0.780], [0.45, 0.37]]])
>>> metric = MultilabelF1Score(num_labels=3, multidim_average='samplewise')
>>> metric(preds, target)
tensor([0.4444, 0.0000])
>>> mlf1s = MultilabelF1Score(num_labels=3, multidim_average='samplewise', average=None)
>>> mlf1s(preds, target)
tensor([[0.6667, 0.6667, 0.0000],
        [0.0000, 0.0000, 0.0000]])
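In the multilabel case each label column is scored as its own binary problem. This plain-Python sketch is a hypothetical helper, not the torchmetrics implementation. It assumes a strict > threshold comparison and a logit check over the whole input, collects per-label counts, and then applies no reduction, micro or macro averaging (weighted averaging is left out for brevity).

```python
import math
from typing import List, Optional, Sequence, Union


def multilabel_f1(
    preds: Sequence[Sequence[float]],
    target: Sequence[Sequence[int]],
    threshold: float = 0.5,
    average: Optional[str] = "macro",
) -> Union[float, List[float]]:
    """Hypothetical helper: score each label column as an independent binary problem."""
    num_labels = len(target[0])
    rows = [[float(v) for v in row] for row in preds]
    # Assumed logit check: if any value leaves [0, 1], apply sigmoid to every element
    if any(v < 0.0 or v > 1.0 for row in rows for v in row):
        rows = [[1.0 / (1.0 + math.exp(-v)) for v in row] for row in rows]
    hard = [[1 if v > threshold else 0 for v in row] for row in rows]

    tp, fp, fn = [0] * num_labels, [0] * num_labels, [0] * num_labels
    for p_row, t_row in zip(hard, target):
        for label, (p, t) in enumerate(zip(p_row, t_row)):
            tp[label] += int(p == 1 and t == 1)
            fp[label] += int(p == 1 and t == 0)
            fn[label] += int(p == 0 and t == 1)

    def f1(a: int, b: int, c: int) -> float:
        denom = 2 * a + b + c
        return 2 * a / denom if denom else 0.0

    per_label = [f1(*counts) for counts in zip(tp, fp, fn)]
    if average is None or average == "none":
        return per_label
    if average == "micro":
        return f1(sum(tp), sum(fp), sum(fn))
    if average == "macro":
        return sum(per_label) / num_labels
    raise ValueError(f"average={average!r} is not covered by this sketch")


target = [[0, 1, 0], [1, 0, 1]]
probs = [[0.11, 0.22, 0.84], [0.73, 0.33, 0.92]]
print([round(s, 4) for s in multilabel_f1(probs, target, average=None)])  # [1.0, 0.0, 0.6667]
print(round(multilabel_f1(probs, target), 4))  # 0.5556
```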
- plot(val=None, ax=None)
Plot a single or multiple values from the metric.
- Parameters:
val (Union[Tensor, Sequence[Tensor], None]) – Either a single result from calling metric.forward or metric.compute or a list of these results. If no value is provided, metric.compute is called automatically and that result is plotted.
ax (Optional[Axes]) – A matplotlib axis object. If provided, the plot is added to that axis.
- Return type:
- Returns:
Figure and Axes object
- Raises:
ModuleNotFoundError – If matplotlib is not installed
>>> from torch import rand, randint
>>> # Example plotting a single value
>>> from torchmetrics.classification import MultilabelF1Score
>>> metric = MultilabelF1Score(num_labels=3)
>>> metric.update(randint(2, (20, 3)), randint(2, (20, 3)))
>>> fig_, ax_ = metric.plot()

>>> from torch import rand, randint
>>> # Example plotting multiple values
>>> from torchmetrics.classification import MultilabelF1Score
>>> metric = MultilabelF1Score(num_labels=3)
>>> values = []
>>> for _ in range(10):
...     values.append(metric(randint(2, (20, 3)), randint(2, (20, 3))))
>>> fig_, ax_ = metric.plot(values)
Functional Interface
f1_score
- torchmetrics.functional.f1_score(preds, target, task, threshold=0.5, num_classes=None, num_labels=None, average='micro', multidim_average='global', top_k=1, ignore_index=None, validate_args=True)
Compute F-1 score.
\[F_{1} = 2\frac{\text{precision} \cdot \text{recall}}{\text{precision} + \text{recall}}\]
This function is a simple wrapper to get the task-specific versions of this metric, which is done by setting the task argument to either 'binary', 'multiclass' or 'multilabel'. See the documentation of binary_f1_score(), multiclass_f1_score() and multilabel_f1_score() for the specific details of how each argument influences the result, and for examples.
- Legacy Example:
>>> from torch import tensor
>>> from torchmetrics.functional import f1_score
>>> target = tensor([0, 1, 2, 0, 1, 2])
>>> preds = tensor([0, 2, 1, 0, 0, 1])
>>> f1_score(preds, target, task="multiclass", num_classes=3)
tensor(0.3333)
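Because f1_score defaults to average='micro', the legacy multiclass result above equals the fraction of correct predictions. With exactly one label per sample, each wrong prediction counts once as a false positive (for the predicted class) and once as a false negative (for the true class), so 2TP / (2TP + FP + FN) collapses to TP / N. A plain-Python check of that arithmetic, using a hypothetical helper and assuming top_k=1 and no ignore_index:

```python
from typing import Sequence


def micro_f1_single_label(preds: Sequence[int], target: Sequence[int]) -> float:
    """Hypothetical helper: micro F1 for single-label multiclass predictions."""
    correct = sum(p == t for p, t in zip(preds, target))
    wrong = len(target) - correct
    # Every wrong prediction is one FP (predicted class) and one FN (true class)
    tp, fp, fn = correct, wrong, wrong
    denom = 2 * tp + fp + fn
    return 2 * tp / denom if denom else 0.0


target = [0, 1, 2, 0, 1, 2]
preds = [0, 2, 1, 0, 0, 1]
accuracy = sum(p == t for p, t in zip(preds, target)) / len(target)
print(round(micro_f1_single_label(preds, target), 4))  # 0.3333
print(round(accuracy, 4))  # 0.3333
```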
binary_f1_score
- torchmetrics.functional.classification.binary_f1_score(preds, target, threshold=0.5, multidim_average='global', ignore_index=None, validate_args=True)
Compute F-1 score for binary tasks.
\[F_{1} = 2\frac{\text{precision} \cdot \text{recall}}{\text{precision} + \text{recall}}\]
Accepts the following input tensors:
preds (int or float tensor): (N, ...). If preds is a floating point tensor with values outside the [0,1] range, the input is treated as logits and a sigmoid is applied element-wise. Additionally, preds is converted to an int tensor by thresholding with the value in threshold.
target (int tensor): (N, ...)
- Parameters:
preds (Tensor) – Tensor with predictions
target (Tensor) – Tensor with true labels
threshold (float) – Threshold for transforming probability to binary {0,1} predictions
multidim_average (Literal['global', 'samplewise']) – Defines how additional dimensions ... should be handled. Should be one of the following:
  global: Additional dimensions are flattened along the batch dimension
  samplewise: Statistic will be calculated independently for each sample on the N axis. The statistics in this case are calculated over the additional dimensions.
ignore_index (Optional[int]) – Specifies a target value that is ignored and does not contribute to the metric calculation
validate_args (bool) – bool indicating if input arguments and tensors should be validated for correctness. Set to False for faster computations.
- Return type:
Tensor
- Returns:
If multidim_average is set to global, the metric returns a scalar value. If multidim_average is set to samplewise, the metric returns an (N,) vector consisting of a scalar value per sample.
- Example (preds is int tensor):
>>> from torch import tensor
>>> from torchmetrics.functional.classification import binary_f1_score
>>> target = tensor([0, 1, 0, 1, 0, 1])
>>> preds = tensor([0, 0, 1, 1, 0, 1])
>>> binary_f1_score(preds, target)
tensor(0.6667)
- Example (preds is float tensor):
>>> from torch import tensor
>>> from torchmetrics.functional.classification import binary_f1_score
>>> target = tensor([0, 1, 0, 1, 0, 1])
>>> preds = tensor([0.11, 0.22, 0.84, 0.73, 0.33, 0.92])
>>> binary_f1_score(preds, target)
tensor(0.6667)
- Example (multidim tensors):
>>> from torch import tensor
>>> from torchmetrics.functional.classification import binary_f1_score
>>> target = tensor([[[0, 1], [1, 0], [0, 1]], [[1, 1], [0, 0], [1, 0]]])
>>> preds = tensor([[[0.59, 0.91], [0.91, 0.99], [0.63, 0.04]],
...                 [[0.38, 0.04], [0.86, 0.780], [0.45, 0.37]]])
>>> binary_f1_score(preds, target, multidim_average='samplewise')
tensor([0.5000, 0.0000])
multiclass_f1_score
- torchmetrics.functional.classification.multiclass_f1_score(preds, target, num_classes, average='macro', top_k=1, multidim_average='global', ignore_index=None, validate_args=True)
Compute F-1 score for multiclass tasks.
\[F_{1} = 2\frac{\text{precision} \cdot \text{recall}}{\text{precision} + \text{recall}}\]
Accepts the following input tensors:
preds: (N, ...) (int tensor) or (N, C, ...) (float tensor). If preds is a floating point tensor, torch.argmax is applied along the C dimension to automatically convert probabilities/logits into an int tensor.
target (int tensor): (N, ...)
- Parameters:
preds (Tensor) – Tensor with predictions
target (Tensor) – Tensor with true labels
num_classes (int) – Integer specifying the number of classes
average (Optional[Literal['micro', 'macro', 'weighted', 'none']]) – Defines the reduction that is applied over labels. Should be one of the following:
  micro: Sum statistics over all labels
  macro: Calculate statistics for each label and average them
  weighted: Calculate statistics for each label and compute a weighted average using their support
  "none" or None: Calculate the statistic for each label and apply no reduction
top_k (int) – Number of highest probability or logit score predictions considered to find the correct label. Only works when preds contain probabilities/logits.
multidim_average (Literal['global', 'samplewise']) – Defines how additional dimensions ... should be handled. Should be one of the following:
  global: Additional dimensions are flattened along the batch dimension
  samplewise: Statistic will be calculated independently for each sample on the N axis. The statistics in this case are calculated over the additional dimensions.
ignore_index (Optional[int]) – Specifies a target value that is ignored and does not contribute to the metric calculation
validate_args (bool) – bool indicating if input arguments and tensors should be validated for correctness. Set to False for faster computations.
- Returns:
The returned shape depends on the average and multidim_average arguments:
  If multidim_average is set to global:
    If average='micro'/'macro'/'weighted', the output will be a scalar tensor
    If average=None/'none', the shape will be (C,)
  If multidim_average is set to samplewise:
    If average='micro'/'macro'/'weighted', the shape will be (N,)
    If average=None/'none', the shape will be (N, C)
- Return type:
Tensor
- Example (preds is int tensor):
>>> from torch import tensor
>>> from torchmetrics.functional.classification import multiclass_f1_score
>>> target = tensor([2, 1, 0, 0])
>>> preds = tensor([2, 1, 0, 1])
>>> multiclass_f1_score(preds, target, num_classes=3)
tensor(0.7778)
>>> multiclass_f1_score(preds, target, num_classes=3, average=None)
tensor([0.6667, 0.6667, 1.0000])
- Example (preds is float tensor):
>>> from torch import tensor
>>> from torchmetrics.functional.classification import multiclass_f1_score
>>> target = tensor([2, 1, 0, 0])
>>> preds = tensor([[0.16, 0.26, 0.58],
...                 [0.22, 0.61, 0.17],
...                 [0.71, 0.09, 0.20],
...                 [0.05, 0.82, 0.13]])
>>> multiclass_f1_score(preds, target, num_classes=3)
tensor(0.7778)
>>> multiclass_f1_score(preds, target, num_classes=3, average=None)
tensor([0.6667, 0.6667, 1.0000])
- Example (multidim tensors):
>>> from torch import tensor
>>> from torchmetrics.functional.classification import multiclass_f1_score
>>> target = tensor([[[0, 1], [2, 1], [0, 2]], [[1, 1], [2, 0], [1, 2]]])
>>> preds = tensor([[[0, 2], [2, 0], [0, 1]], [[2, 2], [2, 1], [1, 0]]])
>>> multiclass_f1_score(preds, target, num_classes=3, multidim_average='samplewise')
tensor([0.4333, 0.2667])
>>> multiclass_f1_score(preds, target, num_classes=3, multidim_average='samplewise', average=None)
tensor([[0.8000, 0.0000, 0.5000],
        [0.0000, 0.4000, 0.4000]])
multilabel_f1_score
- torchmetrics.functional.classification.multilabel_f1_score(preds, target, num_labels, threshold=0.5, average='macro', multidim_average='global', ignore_index=None, validate_args=True)
Compute F-1 score for multilabel tasks.
\[F_{1} = 2\frac{\text{precision} \cdot \text{recall}}{\text{precision} + \text{recall}}\]
Accepts the following input tensors:
preds (int or float tensor): (N, C, ...). If preds is a floating point tensor with values outside the [0,1] range, the input is treated as logits and a sigmoid is applied element-wise. Additionally, preds is converted to an int tensor by thresholding with the value in threshold.
target (int tensor): (N, C, ...)
- Parameters:
preds (Tensor) – Tensor with predictions
target (Tensor) – Tensor with true labels
num_labels (int) – Integer specifying the number of labels
threshold (float) – Threshold for transforming probability to binary {0,1} predictions
average (Optional[Literal['micro', 'macro', 'weighted', 'none']]) – Defines the reduction that is applied over labels. Should be one of the following:
  micro: Sum statistics over all labels
  macro: Calculate statistics for each label and average them
  weighted: Calculate statistics for each label and compute a weighted average using their support
  "none" or None: Calculate the statistic for each label and apply no reduction
multidim_average (Literal['global', 'samplewise']) – Defines how additional dimensions ... should be handled. Should be one of the following:
  global: Additional dimensions are flattened along the batch dimension
  samplewise: Statistic will be calculated independently for each sample on the N axis. The statistics in this case are calculated over the additional dimensions.
ignore_index (Optional[int]) – Specifies a target value that is ignored and does not contribute to the metric calculation
validate_args (bool) – bool indicating if input arguments and tensors should be validated for correctness. Set to False for faster computations.
- Returns:
The returned shape depends on the average and multidim_average arguments:
  If multidim_average is set to global:
    If average='micro'/'macro'/'weighted', the output will be a scalar tensor
    If average=None/'none', the shape will be (C,)
  If multidim_average is set to samplewise:
    If average='micro'/'macro'/'weighted', the shape will be (N,)
    If average=None/'none', the shape will be (N, C)
- Return type:
Tensor
- Example (preds is int tensor):
>>> from torch import tensor
>>> from torchmetrics.functional.classification import multilabel_f1_score
>>> target = tensor([[0, 1, 0], [1, 0, 1]])
>>> preds = tensor([[0, 0, 1], [1, 0, 1]])
>>> multilabel_f1_score(preds, target, num_labels=3)
tensor(0.5556)
>>> multilabel_f1_score(preds, target, num_labels=3, average=None)
tensor([1.0000, 0.0000, 0.6667])
- Example (preds is float tensor):
>>> from torch import tensor
>>> from torchmetrics.functional.classification import multilabel_f1_score
>>> target = tensor([[0, 1, 0], [1, 0, 1]])
>>> preds = tensor([[0.11, 0.22, 0.84], [0.73, 0.33, 0.92]])
>>> multilabel_f1_score(preds, target, num_labels=3)
tensor(0.5556)
>>> multilabel_f1_score(preds, target, num_labels=3, average=None)
tensor([1.0000, 0.0000, 0.6667])
- Example (multidim tensors):
>>> from torch import tensor
>>> from torchmetrics.functional.classification import multilabel_f1_score
>>> target = tensor([[[0, 1], [1, 0], [0, 1]], [[1, 1], [0, 0], [1, 0]]])
>>> preds = tensor([[[0.59, 0.91], [0.91, 0.99], [0.63, 0.04]],
...                 [[0.38, 0.04], [0.86, 0.780], [0.45, 0.37]]])
>>> multilabel_f1_score(preds, target, num_labels=3, multidim_average='samplewise')
tensor([0.4444, 0.0000])
>>> multilabel_f1_score(preds, target, num_labels=3, multidim_average='samplewise', average=None)
tensor([[0.6667, 0.6667, 0.0000],
        [0.0000, 0.0000, 0.0000]])
F-Beta Score
Module Interface
- class torchmetrics.FBetaScore(**kwargs)
Compute F-score metric.
\[F_{\beta} = (1 + \beta^2) \cdot \frac{\text{precision} \cdot \text{recall}}{(\beta^2 \cdot \text{precision}) + \text{recall}}\]
The metric is only properly defined when \(\text{TP} + \text{FP} \neq 0 \wedge \text{TP} + \text{FN} \neq 0\), where \(\text{TP}\), \(\text{FP}\) and \(\text{FN}\) represent the number of true positives, false positives and false negatives respectively. If this case is encountered for any class/label, the metric for that class/label will be set to 0 and the overall metric may therefore be affected in turn.
This metric is a simple wrapper to get the task-specific versions of this metric, which is done by setting the task argument to either 'binary', 'multiclass' or 'multilabel'. See the documentation of BinaryFBetaScore, MulticlassFBetaScore and MultilabelFBetaScore for the specific details of how each argument influences the result, and for examples.
- Legacy Example:
>>> from torch import tensor
>>> from torchmetrics import FBetaScore
>>> target = tensor([0, 1, 2, 0, 1, 2])
>>> preds = tensor([0, 2, 1, 0, 0, 1])
>>> f_beta = FBetaScore(task="multiclass", num_classes=3, beta=0.5)
>>> f_beta(preds, target)
tensor(0.3333)
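Substituting precision = TP / (TP + FP) and recall = TP / (TP + FN) into the F-beta formula gives the count form \(F_{\beta} = \frac{(1+\beta^2)\,\text{TP}}{(1+\beta^2)\,\text{TP} + \beta^2\,\text{FN} + \text{FP}}\). This form shows the role of beta: larger beta weights false negatives, and therefore recall, more heavily. The plain-Python sketch below is a hypothetical helper, not the library code. It evaluates the count form for a class with precision 1.0 and recall 0.5 (class 0 in the multiclass examples of this document), and for the micro counts of the legacy example.

```python
def fbeta_from_counts(tp: int, fp: int, fn: int, beta: float) -> float:
    """(1 + b^2) * TP / ((1 + b^2) * TP + b^2 * FN + FP); 0.0 if the denominator is 0."""
    b2 = beta ** 2
    denom = (1 + b2) * tp + b2 * fn + fp
    return (1 + b2) * tp / denom if denom else 0.0


# Precision 1.0, recall 0.5: TP=1, FP=0, FN=1
for beta in (0.5, 1.0, 2.0):
    print(beta, round(fbeta_from_counts(1, 0, 1, beta), 4))
# 0.5 0.8333
# 1.0 0.6667
# 2.0 0.5556

# Legacy example above (micro average): TP=2, FP=4, FN=4, beta=0.5
print(round(fbeta_from_counts(2, 4, 4, 0.5), 4))  # 0.3333
```

With beta=1 the count form reduces to 2TP / (2TP + FP + FN), i.e. the F-1 score.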
BinaryFBetaScore
- class torchmetrics.classification.BinaryFBetaScore(beta, threshold=0.5, multidim_average='global', ignore_index=None, validate_args=True, **kwargs)
Compute F-score metric for binary tasks.
\[F_{\beta} = (1 + \beta^2) \cdot \frac{\text{precision} \cdot \text{recall}}{(\beta^2 \cdot \text{precision}) + \text{recall}}\]
The metric is only properly defined when \(\text{TP} + \text{FP} \neq 0 \wedge \text{TP} + \text{FN} \neq 0\), where \(\text{TP}\), \(\text{FP}\) and \(\text{FN}\) represent the number of true positives, false positives and false negatives respectively. If this case is encountered, a score of 0 is returned.
As input to forward and update the metric accepts the following input:
preds (Tensor): An int tensor or float tensor of shape (N, ...). If preds is a floating point tensor with values outside the [0,1] range, the input is treated as logits and a sigmoid is applied element-wise. Additionally, preds is converted to an int tensor by thresholding with the value in threshold.
target (Tensor): An int tensor of shape (N, ...).
As output to forward and compute the metric returns the following output:
bfbs (Tensor): A tensor whose returned shape depends on the multidim_average argument:
  If multidim_average is set to global, the output will be a scalar tensor.
  If multidim_average is set to samplewise, the output will be a tensor of shape (N,) consisting of a scalar value per sample.
- Parameters:
beta (float) – Weighting between precision and recall in calculation. Setting to 1 corresponds to equal weight.
threshold (float) – Threshold for transforming probability to binary {0,1} predictions
multidim_average (Literal['global', 'samplewise']) – Defines how additional dimensions ... should be handled. Should be one of the following:
  global: Additional dimensions are flattened along the batch dimension
  samplewise: Statistic will be calculated independently for each sample on the N axis. The statistics in this case are calculated over the additional dimensions.
ignore_index (Optional[int]) – Specifies a target value that is ignored and does not contribute to the metric calculation
validate_args (bool) – bool indicating if input arguments and tensors should be validated for correctness. Set to False for faster computations.
- Example (preds is int tensor):
>>> from torch import tensor
>>> from torchmetrics.classification import BinaryFBetaScore
>>> target = tensor([0, 1, 0, 1, 0, 1])
>>> preds = tensor([0, 0, 1, 1, 0, 1])
>>> metric = BinaryFBetaScore(beta=2.0)
>>> metric(preds, target)
tensor(0.6667)
- Example (preds is float tensor):
>>> from torch import tensor
>>> from torchmetrics.classification import BinaryFBetaScore
>>> target = tensor([0, 1, 0, 1, 0, 1])
>>> preds = tensor([0.11, 0.22, 0.84, 0.73, 0.33, 0.92])
>>> metric = BinaryFBetaScore(beta=2.0)
>>> metric(preds, target)
tensor(0.6667)
- Example (multidim tensors):
>>> from torch import tensor
>>> from torchmetrics.classification import BinaryFBetaScore
>>> target = tensor([[[0, 1], [1, 0], [0, 1]], [[1, 1], [0, 0], [1, 0]]])
>>> preds = tensor([[[0.59, 0.91], [0.91, 0.99], [0.63, 0.04]],
...                 [[0.38, 0.04], [0.86, 0.780], [0.45, 0.37]]])
>>> metric = BinaryFBetaScore(beta=2.0, multidim_average='samplewise')
>>> metric(preds, target)
tensor([0.5882, 0.0000])
- plot(val=None, ax=None)
Plot a single or multiple values from the metric.
- Parameters:
val (Union[Tensor, Sequence[Tensor], None]) – Either a single result from calling metric.forward or metric.compute or a list of these results. If no value is provided, metric.compute is called automatically and that result is plotted.
ax (Optional[Axes]) – A matplotlib axis object. If provided, the plot is added to that axis.
- Return type:
- Returns:
Figure object and Axes object
- Raises:
ModuleNotFoundError – If matplotlib is not installed
>>> from torch import rand, randint
>>> # Example plotting a single value
>>> from torchmetrics.classification import BinaryFBetaScore
>>> metric = BinaryFBetaScore(beta=2.0)
>>> metric.update(rand(10), randint(2, (10,)))
>>> fig_, ax_ = metric.plot()

>>> from torch import rand, randint
>>> # Example plotting multiple values
>>> from torchmetrics.classification import BinaryFBetaScore
>>> metric = BinaryFBetaScore(beta=2.0)
>>> values = []
>>> for _ in range(10):
...     values.append(metric(rand(10), randint(2, (10,))))
>>> fig_, ax_ = metric.plot(values)
MulticlassFBetaScore
- class torchmetrics.classification.MulticlassFBetaScore(beta, num_classes, top_k=1, average='macro', multidim_average='global', ignore_index=None, validate_args=True, **kwargs)
Compute F-score metric for multiclass tasks.
\[F_{\beta} = (1 + \beta^2) \cdot \frac{\text{precision} \cdot \text{recall}}{(\beta^2 \cdot \text{precision}) + \text{recall}}\]
The metric is only properly defined when \(\text{TP} + \text{FP} \neq 0 \wedge \text{TP} + \text{FN} \neq 0\), where \(\text{TP}\), \(\text{FP}\) and \(\text{FN}\) represent the number of true positives, false positives and false negatives respectively. If this case is encountered for any class, the metric for that class will be set to 0 and the overall metric may therefore be affected in turn.
As input to forward and update the metric accepts the following input:
preds (Tensor): An int tensor of shape (N, ...) or float tensor of shape (N, C, ...). If preds is a floating point tensor, torch.argmax is applied along the C dimension to automatically convert probabilities/logits into an int tensor.
target (Tensor): An int tensor of shape (N, ...).
As output to forward and compute the metric returns the following output:
mcfbs (Tensor): A tensor whose returned shape depends on the average and multidim_average arguments:
  If multidim_average is set to global:
    If average='micro'/'macro'/'weighted', the output will be a scalar tensor
    If average=None/'none', the shape will be (C,)
  If multidim_average is set to samplewise:
    If average='micro'/'macro'/'weighted', the shape will be (N,)
    If average=None/'none', the shape will be (N, C)
- Parameters:
beta (float) – Weighting between precision and recall in calculation. Setting to 1 corresponds to equal weight.
num_classes (int) – Integer specifying the number of classes
average (Optional[Literal['micro', 'macro', 'weighted', 'none']]) – Defines the reduction that is applied over labels. Should be one of the following:
  micro: Sum statistics over all labels
  macro: Calculate statistics for each label and average them
  weighted: Calculate statistics for each label and compute a weighted average using their support
  "none" or None: Calculate the statistic for each label and apply no reduction
top_k (int) – Number of highest probability or logit score predictions considered to find the correct label. Only works when preds contain probabilities/logits.
multidim_average (Literal['global', 'samplewise']) – Defines how additional dimensions ... should be handled. Should be one of the following:
  global: Additional dimensions are flattened along the batch dimension
  samplewise: Statistic will be calculated independently for each sample on the N axis. The statistics in this case are calculated over the additional dimensions.
ignore_index (Optional[int]) – Specifies a target value that is ignored and does not contribute to the metric calculation
validate_args (bool) – bool indicating if input arguments and tensors should be validated for correctness. Set to False for faster computations.
- Example (preds is int tensor):
>>> from torch import tensor
>>> from torchmetrics.classification import MulticlassFBetaScore
>>> target = tensor([2, 1, 0, 0])
>>> preds = tensor([2, 1, 0, 1])
>>> metric = MulticlassFBetaScore(beta=2.0, num_classes=3)
>>> metric(preds, target)
tensor(0.7963)
>>> mcfbs = MulticlassFBetaScore(beta=2.0, num_classes=3, average=None)
>>> mcfbs(preds, target)
tensor([0.5556, 0.8333, 1.0000])
- Example (preds is float tensor):
>>> from torchmetrics.classification import MulticlassFBetaScore >>> target = tensor([2, 1, 0, 0]) >>> preds = tensor([[0.16, 0.26, 0.58], ... [0.22, 0.61, 0.17], ... [0.71, 0.09, 0.20], ... [0.05, 0.82, 0.13]]) >>> metric = MulticlassFBetaScore(beta=2.0, num_classes=3) >>> metric(preds, target) tensor(0.7963) >>> mcfbs = MulticlassFBetaScore(beta=2.0, num_classes=3, average=None) >>> mcfbs(preds, target) tensor([0.5556, 0.8333, 1.0000])
- Example (multidim tensors):
>>> from torchmetrics.classification import MulticlassFBetaScore >>> target = tensor([[[0, 1], [2, 1], [0, 2]], [[1, 1], [2, 0], [1, 2]]]) >>> preds = tensor([[[0, 2], [2, 0], [0, 1]], [[2, 2], [2, 1], [1, 0]]]) >>> metric = MulticlassFBetaScore(beta=2.0, num_classes=3, multidim_average='samplewise') >>> metric(preds, target) tensor([0.4697, 0.2706]) >>> mcfbs = MulticlassFBetaScore(beta=2.0, num_classes=3, multidim_average='samplewise', average=None) >>> mcfbs(preds, target) tensor([[0.9091, 0.0000, 0.5000], [0.0000, 0.3571, 0.4545]])
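The per-class scores and the macro average in the examples above can be checked by hand. The following is a minimal pure-Python sketch, not TorchMetrics code: the helper names fbeta_from_counts, per_class_counts and argmax are hypothetical, float predictions are converted with an argmax as described for preds, and the weighted value is only our own calculation, assuming that support means the number of target occurrences of each class.

```python
def fbeta_from_counts(tp, fp, fn, beta):
    """F-beta from raw counts; 0.0 when precision or recall is undefined."""
    if tp + fp == 0 or tp + fn == 0:
        return 0.0
    precision = tp / (tp + fp)
    recall = tp / (tp + fn)
    b2 = beta ** 2
    denom = b2 * precision + recall
    return (1 + b2) * precision * recall / denom if denom > 0 else 0.0


def per_class_counts(preds, target, num_classes):
    """Return a (tp, fp, fn) tuple for every class, from integer label lists."""
    counts = []
    for c in range(num_classes):
        tp = sum(p == c and t == c for p, t in zip(preds, target))
        fp = sum(p == c and t != c for p, t in zip(preds, target))
        fn = sum(p != c and t == c for p, t in zip(preds, target))
        counts.append((tp, fp, fn))
    return counts


def argmax(row):
    """Index of the largest score in a row (stands in for torch.argmax over C)."""
    return max(range(len(row)), key=lambda i: row[i])


target = [2, 1, 0, 0]
probs = [[0.16, 0.26, 0.58], [0.22, 0.61, 0.17], [0.71, 0.09, 0.20], [0.05, 0.82, 0.13]]
preds = [argmax(row) for row in probs]  # [2, 1, 0, 1], same as the int example

counts = per_class_counts(preds, target, num_classes=3)
per_class = [fbeta_from_counts(tp, fp, fn, beta=2.0) for tp, fp, fn in counts]  # like average=None
macro = sum(per_class) / len(per_class)  # like average='macro'
support = [tp + fn for tp, _, fn in counts]  # number of targets of each class
weighted = sum(s * f for s, f in zip(support, per_class)) / sum(support)  # like average='weighted'

print([round(f, 4) for f in per_class])  # [0.5556, 0.8333, 1.0]
print(round(macro, 4))                   # 0.7963
print(round(weighted, 4))                # 0.7361
```

Macro averaging gives every class the same weight, while the weighted average here leans towards class 0, which has the lowest score and two of the four targets.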
- plot(val=None, ax=None)[source]¶
Plot a single or multiple values from the metric.
- Parameters:
val (Union[Tensor, Sequence[Tensor], None]) – Either a single result from calling metric.forward or metric.compute, or a list of these results. If no value is provided, metric.compute is called automatically and that result is plotted.
ax (Optional[Axes]) – A matplotlib axis object. If provided, the plot is added to that axis.
- Returns:
Figure object and Axes object
- Raises:
ModuleNotFoundError – If matplotlib is not installed
>>> from torch import randint
>>> # Example plotting a single value per class
>>> from torchmetrics.classification import MulticlassFBetaScore
>>> metric = MulticlassFBetaScore(num_classes=3, beta=2.0, average=None)
>>> metric.update(randint(3, (20,)), randint(3, (20,)))
>>> fig_, ax_ = metric.plot()

>>> from torch import randint
>>> # Example plotting multiple values per class
>>> from torchmetrics.classification import MulticlassFBetaScore
>>> metric = MulticlassFBetaScore(num_classes=3, beta=2.0, average=None)
>>> values = []
>>> for _ in range(20):
...     values.append(metric(randint(3, (20,)), randint(3, (20,))))
>>> fig_, ax_ = metric.plot(values)
MultilabelFBetaScore¶
- class torchmetrics.classification.MultilabelFBetaScore(beta, num_labels, threshold=0.5, average='macro', multidim_average='global', ignore_index=None, validate_args=True, **kwargs)[source]¶
Compute F-score metric for multilabel tasks.
\[F_{\beta} = (1 + \beta^2) * \frac{\text{precision} * \text{recall}} {(\beta^2 * \text{precision}) + \text{recall}}\]
The metric is only properly defined when \(\text{TP} + \text{FP} \neq 0 \wedge \text{TP} + \text{FN} \neq 0\), where \(\text{TP}\), \(\text{FP}\) and \(\text{FN}\) represent the number of true positives, false positives and false negatives respectively. If either condition is violated for a label, the metric for that label is set to 0, which may in turn affect the overall metric.
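The formula and the zero convention above can be written out directly. This is a minimal pure-Python sketch under the stated convention (the score is 0 when TP + FP = 0 or TP + FN = 0); fbeta is a hypothetical helper, not part of TorchMetrics.

```python
def fbeta(tp: int, fp: int, fn: int, beta: float) -> float:
    # Undefined case described in the text: TP + FP == 0 or TP + FN == 0 -> score set to 0
    if tp + fp == 0 or tp + fn == 0:
        return 0.0
    precision = tp / (tp + fp)
    recall = tp / (tp + fn)
    b2 = beta ** 2
    denom = b2 * precision + recall
    # denom is 0 only when tp == 0, in which case the score is 0 anyway
    return (1 + b2) * precision * recall / denom if denom else 0.0


# Same counts (precision = 1.0, recall = 0.5), different beta values
print(round(fbeta(1, 0, 1, beta=2.0), 4))  # 0.5556
print(round(fbeta(1, 0, 1, beta=1.0), 4))  # 0.6667
print(round(fbeta(1, 0, 1, beta=0.5), 4))  # 0.8333
print(fbeta(0, 0, 3, beta=2.0))            # 0.0 (no positive predictions)
```

With precision 1.0 and recall 0.5, raising beta moves the score towards the recall and lowering it moves the score towards the precision.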
As input to forward and update the metric accepts the following input:
preds (Tensor): An int or float tensor of shape (N, C, ...). If preds is a floating point tensor with values outside the [0, 1] range, the input is considered to be logits and sigmoid is automatically applied per element. Additionally, it is converted to an int tensor by thresholding with the value in threshold.
target (Tensor): An int tensor of shape (N, C, ...).
As output to forward and compute the metric returns the following output:
mlfbs (Tensor): A tensor whose returned shape depends on the average and multidim_average arguments:
If multidim_average is set to global:
If average='micro'/'macro'/'weighted', the output will be a scalar tensor
If average=None/'none', the shape will be (C,)
If multidim_average is set to samplewise:
If average='micro'/'macro'/'weighted', the shape will be (N,)
If average=None/'none', the shape will be (N, C)
- Parameters:
beta (float) – Weighting between precision and recall in calculation. Setting to 1 corresponds to equal weight.
num_labels (int) – Integer specifying the number of labels.
threshold (float) – Threshold for transforming probability to binary {0,1} predictions.
average (Optional[Literal['micro', 'macro', 'weighted', 'none']]) – Defines the reduction that is applied over labels. Should be one of the following:
micro: Sum statistics over all labels.
macro: Calculate statistics for each label and average them.
weighted: Calculate statistics for each label and compute the weighted average using their support.
"none" or None: Calculate the statistic for each label and apply no reduction.
multidim_average (Literal['global', 'samplewise']) – Defines how additional dimensions ... should be handled. Should be one of the following:
global: Additional dimensions are flattened along the batch dimension.
samplewise: The statistic is calculated independently for each sample on the N axis. In this case the statistics are calculated over the additional dimensions.
ignore_index (Optional[int]) – Specifies a target value that is ignored and does not contribute to the metric calculation.
validate_args (bool) – Indicates whether input arguments and tensors should be validated for correctness. Set to False for faster computations.
- Example (preds is int tensor):
>>> from torch import tensor
>>> from torchmetrics.classification import MultilabelFBetaScore
>>> target = tensor([[0, 1, 0], [1, 0, 1]])
>>> preds = tensor([[0, 0, 1], [1, 0, 1]])
>>> metric = MultilabelFBetaScore(beta=2.0, num_labels=3)
>>> metric(preds, target)
tensor(0.6111)
>>> mlfbs = MultilabelFBetaScore(beta=2.0, num_labels=3, average=None)
>>> mlfbs(preds, target)
tensor([1.0000, 0.0000, 0.8333])
- Example (preds is float tensor):
>>> from torch import tensor
>>> from torchmetrics.classification import MultilabelFBetaScore
>>> target = tensor([[0, 1, 0], [1, 0, 1]])
>>> preds = tensor([[0.11, 0.22, 0.84], [0.73, 0.33, 0.92]])
>>> metric = MultilabelFBetaScore(beta=2.0, num_labels=3)
>>> metric(preds, target)
tensor(0.6111)
>>> mlfbs = MultilabelFBetaScore(beta=2.0, num_labels=3, average=None)
>>> mlfbs(preds, target)
tensor([1.0000, 0.0000, 0.8333])
- Example (multidim tensors):
>>> from torch import tensor
>>> from torchmetrics.classification import MultilabelFBetaScore
>>> target = tensor([[[0, 1], [1, 0], [0, 1]], [[1, 1], [0, 0], [1, 0]]])
>>> preds = tensor([[[0.59, 0.91], [0.91, 0.99], [0.63, 0.04]],
...                 [[0.38, 0.04], [0.86, 0.780], [0.45, 0.37]]])
>>> metric = MultilabelFBetaScore(num_labels=3, beta=2.0, multidim_average='samplewise')
>>> metric(preds, target)
tensor([0.5556, 0.0000])
>>> mlfbs = MultilabelFBetaScore(num_labels=3, beta=2.0, multidim_average='samplewise', average=None)
>>> mlfbs(preds, target)
tensor([[0.8333, 0.8333, 0.0000],
        [0.0000, 0.0000, 0.0000]])
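The float-tensor example above can be reproduced by hand: each label (column) is treated as an independent binary problem after thresholding, and the per-label scores are macro-averaged. This is a minimal pure-Python sketch, not TorchMetrics code. It assumes a strict "> threshold" comparison and that sigmoid is applied to all values whenever any value falls outside [0, 1]; binarize and fbeta are hypothetical helpers.

```python
import math


def binarize(rows, threshold=0.5):
    """Apply sigmoid if any score is outside [0, 1] (treated as logits), then threshold."""
    flat = [v for row in rows for v in row]
    if any(v < 0 or v > 1 for v in flat):
        rows = [[1 / (1 + math.exp(-v)) for v in row] for row in rows]
    return [[int(v > threshold) for v in row] for row in rows]  # assumed strict comparison


def fbeta(tp, fp, fn, beta):
    """F-beta from counts, 0.0 when precision or recall is undefined."""
    if tp + fp == 0 or tp + fn == 0:
        return 0.0
    p, r = tp / (tp + fp), tp / (tp + fn)
    b2 = beta ** 2
    return (1 + b2) * p * r / (b2 * p + r) if (b2 * p + r) else 0.0


target = [[0, 1, 0], [1, 0, 1]]
probs = [[0.11, 0.22, 0.84], [0.73, 0.33, 0.92]]
preds = binarize(probs)  # [[0, 0, 1], [1, 0, 1]]

num_labels = 3
per_label = []
for j in range(num_labels):  # each label is scored as its own binary problem
    pairs = [(p[j], t[j]) for p, t in zip(preds, target)]
    tp = sum(p == 1 and t == 1 for p, t in pairs)
    fp = sum(p == 1 and t == 0 for p, t in pairs)
    fn = sum(p == 0 and t == 1 for p, t in pairs)
    per_label.append(fbeta(tp, fp, fn, beta=2.0))

print([round(f, 4) for f in per_label])       # [1.0, 0.0, 0.8333]
print(round(sum(per_label) / num_labels, 4))  # 0.6111
```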
- plot(val=None, ax=None)[source]¶
Plot a single or multiple values from the metric.
- Parameters:
val (Union[Tensor, Sequence[Tensor], None]) – Either a single result from calling metric.forward or metric.compute, or a list of these results. If no value is provided, metric.compute is called automatically and that result is plotted.
ax (Optional[Axes]) – A matplotlib axis object. If provided, the plot is added to that axis.
- Returns:
Figure and Axes object
- Raises:
ModuleNotFoundError – If matplotlib is not installed
>>> from torch import randint
>>> # Example plotting a single value
>>> from torchmetrics.classification import MultilabelFBetaScore
>>> metric = MultilabelFBetaScore(num_labels=3, beta=2.0)
>>> metric.update(randint(2, (20, 3)), randint(2, (20, 3)))
>>> fig_, ax_ = metric.plot()

>>> from torch import randint
>>> # Example plotting multiple values
>>> from torchmetrics.classification import MultilabelFBetaScore
>>> metric = MultilabelFBetaScore(num_labels=3, beta=2.0)
>>> values = []
>>> for _ in range(10):
...     values.append(metric(randint(2, (20, 3)), randint(2, (20, 3))))
>>> fig_, ax_ = metric.plot(values)
Functional Interface¶
fbeta_score¶
- torchmetrics.functional.fbeta_score(preds, target, task, beta=1.0, threshold=0.5, num_classes=None, num_labels=None, average='micro', multidim_average='global', top_k=1, ignore_index=None, validate_args=True)[source]¶
Compute F-score metric.
\[F_{\beta} = (1 + \beta^2) * \frac{\text{precision} * \text{recall}} {(\beta^2 * \text{precision}) + \text{recall}}\]
This function is a simple wrapper to get the task-specific versions of this metric, which is done by setting the task argument to either 'binary', 'multiclass' or 'multilabel'. See the documentation of binary_fbeta_score(), multiclass_fbeta_score() and multilabel_fbeta_score() for the specific details of how each argument influences the result, and for examples.
- Return type:
Tensor
- Legacy Example:
>>> from torch import tensor
>>> from torchmetrics.functional import fbeta_score
>>> target = tensor([0, 1, 2, 0, 1, 2])
>>> preds = tensor([0, 2, 1, 0, 0, 1])
>>> fbeta_score(preds, target, task="multiclass", num_classes=3, beta=0.5)
tensor(0.3333)
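With average='micro' (the default of fbeta_score), TP, FP and FN are summed over all classes before the formula is applied. For single-label multiclass data every wrong prediction counts once as an FP (for the predicted class) and once as an FN (for the true class), so micro precision equals micro recall and, in this setting (top_k=1, no ignore_index), the score reduces to the accuracy regardless of beta. The pure-Python sketch below checks this on the legacy example; micro_fbeta is a hypothetical helper, not a TorchMetrics function.

```python
def micro_fbeta(preds, target, num_classes, beta):
    """Sum TP/FP/FN over all classes first, then apply the F-beta formula once."""
    tp = fp = fn = 0
    for c in range(num_classes):
        tp += sum(p == c and t == c for p, t in zip(preds, target))
        fp += sum(p == c and t != c for p, t in zip(preds, target))
        fn += sum(p != c and t == c for p, t in zip(preds, target))
    if tp == 0:
        return 0.0  # covers the undefined cases as well as precision = recall = 0
    precision, recall = tp / (tp + fp), tp / (tp + fn)
    b2 = beta ** 2
    return (1 + b2) * precision * recall / (b2 * precision + recall)


target = [0, 1, 2, 0, 1, 2]
preds = [0, 2, 1, 0, 0, 1]
score = micro_fbeta(preds, target, num_classes=3, beta=0.5)
accuracy = sum(p == t for p, t in zip(preds, target)) / len(target)

print(round(score, 4))     # 0.3333
print(round(accuracy, 4))  # 0.3333
```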
binary_fbeta_score¶
- torchmetrics.functional.classification.binary_fbeta_score(preds, target, beta, threshold=0.5, multidim_average='global', ignore_index=None, validate_args=True)[source]¶
Compute F-score metric for binary tasks.
\[F_{\beta} = (1 + \beta^2) * \frac{\text{precision} * \text{recall}} {(\beta^2 * \text{precision}) + \text{recall}}\]
Accepts the following input tensors:
preds (int or float tensor): (N, ...). If preds is a floating point tensor with values outside the [0, 1] range, the input is considered to be logits and sigmoid is automatically applied per element. Additionally, it is converted to an int tensor by thresholding with the value in threshold.
target (int tensor): (N, ...)
- Parameters:
preds (Tensor) – Tensor with predictions.
target (Tensor) – Tensor with true labels.
beta (float) – Weighting between precision and recall in calculation. Setting to 1 corresponds to equal weight.
threshold (float) – Threshold for transforming probability to binary {0,1} predictions.
multidim_average (Literal['global', 'samplewise']) – Defines how additional dimensions ... should be handled. Should be one of the following:
global: Additional dimensions are flattened along the batch dimension.
samplewise: The statistic is calculated independently for each sample on the N axis. In this case the statistics are calculated over the additional dimensions.
ignore_index (Optional[int]) – Specifies a target value that is ignored and does not contribute to the metric calculation.
validate_args (bool) – Indicates whether input arguments and tensors should be validated for correctness. Set to False for faster computations.
- Return type:
Tensor
- Returns:
If multidim_average is set to global, the metric returns a scalar value. If multidim_average is set to samplewise, the metric returns an (N,) vector consisting of a scalar value per sample.
- Example (preds is int tensor):
>>> from torch import tensor
>>> from torchmetrics.functional.classification import binary_fbeta_score
>>> target = tensor([0, 1, 0, 1, 0, 1])
>>> preds = tensor([0, 0, 1, 1, 0, 1])
>>> binary_fbeta_score(preds, target, beta=2.0)
tensor(0.6667)
- Example (preds is float tensor):
>>> from torch import tensor
>>> from torchmetrics.functional.classification import binary_fbeta_score
>>> target = tensor([0, 1, 0, 1, 0, 1])
>>> preds = tensor([0.11, 0.22, 0.84, 0.73, 0.33, 0.92])
>>> binary_fbeta_score(preds, target, beta=2.0)
tensor(0.6667)
- Example (multidim tensors):
>>> from torch import tensor
>>> from torchmetrics.functional.classification import binary_fbeta_score
>>> target = tensor([[[0, 1], [1, 0], [0, 1]], [[1, 1], [0, 0], [1, 0]]])
>>> preds = tensor([[[0.59, 0.91], [0.91, 0.99], [0.63, 0.04]],
...                 [[0.38, 0.04], [0.86, 0.780], [0.45, 0.37]]])
>>> binary_fbeta_score(preds, target, beta=2.0, multidim_average='samplewise')
tensor([0.5882, 0.0000])
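The difference between multidim_average='samplewise' and 'global' can be seen by recomputing the multidim example by hand: samplewise flattens each sample on its own and yields one score per sample, while global flattens everything together. A minimal pure-Python sketch, assuming the predictions are already probabilities in [0, 1] and a strict "> threshold" comparison; flatten and binary_fbeta are hypothetical helpers, and the global value is only our own hand computation for comparison.

```python
def flatten(x):
    """Recursively flatten nested lists into a flat list."""
    if isinstance(x, list):
        return [v for item in x for v in flatten(item)]
    return [x]


def binary_fbeta(preds, target, beta, threshold=0.5):
    """Binary F-beta on flat lists of probabilities and 0/1 targets."""
    hard = [int(p > threshold) for p in preds]
    tp = sum(h == 1 and t == 1 for h, t in zip(hard, target))
    fp = sum(h == 1 and t == 0 for h, t in zip(hard, target))
    fn = sum(h == 0 and t == 1 for h, t in zip(hard, target))
    if tp + fp == 0 or tp + fn == 0:
        return 0.0
    p, r = tp / (tp + fp), tp / (tp + fn)
    b2 = beta ** 2
    return (1 + b2) * p * r / (b2 * p + r) if (b2 * p + r) else 0.0


target = [[[0, 1], [1, 0], [0, 1]], [[1, 1], [0, 0], [1, 0]]]
preds = [[[0.59, 0.91], [0.91, 0.99], [0.63, 0.04]],
         [[0.38, 0.04], [0.86, 0.780], [0.45, 0.37]]]

# samplewise: one score per entry along the first (N) axis
samplewise = [binary_fbeta(flatten(p), flatten(t), beta=2.0) for p, t in zip(preds, target)]
# global: extra dimensions are flattened together with the batch dimension
global_score = binary_fbeta(flatten(preds), flatten(target), beta=2.0)

print([round(s, 4) for s in samplewise])  # [0.5882, 0.0]
print(round(global_score, 4))             # 0.3226
```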
multiclass_fbeta_score¶
- torchmetrics.functional.classification.multiclass_fbeta_score(preds, target, beta, num_classes, average='macro', top_k=1, multidim_average='global', ignore_index=None, validate_args=True)[source]¶
Compute F-score metric for multiclass tasks.
\[F_{\beta} = (1 + \beta^2) * \frac{\text{precision} * \text{recall}} {(\beta^2 * \text{precision}) + \text{recall}}\]
Accepts the following input tensors:
preds: (N, ...) (int tensor) or (N, C, ...) (float tensor). If preds is a floating point tensor, torch.argmax is applied along the C dimension to automatically convert probabilities/logits into an int tensor.
target (int tensor): (N, ...)
- Parameters:
preds (Tensor) – Tensor with predictions.
target (Tensor) – Tensor with true labels.
beta (float) – Weighting between precision and recall in calculation. Setting to 1 corresponds to equal weight.
num_classes (int) – Integer specifying the number of classes.
average (Optional[Literal['micro', 'macro', 'weighted', 'none']]) – Defines the reduction that is applied over labels. Should be one of the following:
micro: Sum statistics over all labels.
macro: Calculate statistics for each label and average them.
weighted: Calculate statistics for each label and compute the weighted average using their support.
"none" or None: Calculate the statistic for each label and apply no reduction.
top_k (int) – Number of highest probability or logit score predictions considered to find the correct label. Only works when preds contain probabilities/logits.
multidim_average (Literal['global', 'samplewise']) – Defines how additional dimensions ... should be handled. Should be one of the following:
global: Additional dimensions are flattened along the batch dimension.
samplewise: The statistic is calculated independently for each sample on the N axis. In this case the statistics are calculated over the additional dimensions.
ignore_index (Optional[int]) – Specifies a target value that is ignored and does not contribute to the metric calculation.
validate_args (bool) – Indicates whether input arguments and tensors should be validated for correctness. Set to False for faster computations.
- Returns:
If multidim_average is set to global:
If average='micro'/'macro'/'weighted', the output will be a scalar tensor
If average=None/'none', the shape will be (C,)
If multidim_average is set to samplewise:
If average='micro'/'macro'/'weighted', the shape will be (N,)
If average=None/'none', the shape will be (N, C)
- Return type:
The returned shape depends on the average and multidim_average arguments
- Example (preds is int tensor):
>>> from torch import tensor
>>> from torchmetrics.functional.classification import multiclass_fbeta_score
>>> target = tensor([2, 1, 0, 0])
>>> preds = tensor([2, 1, 0, 1])
>>> multiclass_fbeta_score(preds, target, beta=2.0, num_classes=3)
tensor(0.7963)
>>> multiclass_fbeta_score(preds, target, beta=2.0, num_classes=3, average=None)
tensor([0.5556, 0.8333, 1.0000])
- Example (preds is float tensor):
>>> from torch import tensor
>>> from torchmetrics.functional.classification import multiclass_fbeta_score
>>> target = tensor([2, 1, 0, 0])
>>> preds = tensor([[0.16, 0.26, 0.58],
...                 [0.22, 0.61, 0.17],
...                 [0.71, 0.09, 0.20],
...                 [0.05, 0.82, 0.13]])
>>> multiclass_fbeta_score(preds, target, beta=2.0, num_classes=3)
tensor(0.7963)
>>> multiclass_fbeta_score(preds, target, beta=2.0, num_classes=3, average=None)
tensor([0.5556, 0.8333, 1.0000])
- Example (multidim tensors):
>>> from torch import tensor
>>> from torchmetrics.functional.classification import multiclass_fbeta_score
>>> target = tensor([[[0, 1], [2, 1], [0, 2]], [[1, 1], [2, 0], [1, 2]]])
>>> preds = tensor([[[0, 2], [2, 0], [0, 1]], [[2, 2], [2, 1], [1, 0]]])
>>> multiclass_fbeta_score(preds, target, beta=2.0, num_classes=3, multidim_average='samplewise')
tensor([0.4697, 0.2706])
>>> multiclass_fbeta_score(preds, target, beta=2.0, num_classes=3, multidim_average='samplewise', average=None)
tensor([[0.9091, 0.0000, 0.5000],
        [0.0000, 0.3571, 0.4545]])
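For the multiclass multidim example, samplewise averaging means that each sample's extra dimensions are flattened separately, per-class scores are computed within that sample, and the reduction given by average is applied per sample. The pure-Python sketch below reproduces both outputs of that example; flatten, fbeta and per_class_fbeta are hypothetical helpers, not TorchMetrics functions.

```python
def flatten(x):
    """Recursively flatten nested lists into a flat list."""
    if isinstance(x, list):
        return [v for item in x for v in flatten(item)]
    return [x]


def fbeta(tp, fp, fn, beta):
    """F-beta from counts, 0.0 when precision or recall is undefined."""
    if tp + fp == 0 or tp + fn == 0:
        return 0.0
    p, r = tp / (tp + fp), tp / (tp + fn)
    b2 = beta ** 2
    return (1 + b2) * p * r / (b2 * p + r) if (b2 * p + r) else 0.0


def per_class_fbeta(preds, target, num_classes, beta):
    """One F-beta score per class for flat integer label lists."""
    scores = []
    for c in range(num_classes):
        tp = sum(p == c and t == c for p, t in zip(preds, target))
        fp = sum(p == c and t != c for p, t in zip(preds, target))
        fn = sum(p != c and t == c for p, t in zip(preds, target))
        scores.append(fbeta(tp, fp, fn, beta))
    return scores


target = [[[0, 1], [2, 1], [0, 2]], [[1, 1], [2, 0], [1, 2]]]
preds = [[[0, 2], [2, 0], [0, 1]], [[2, 2], [2, 1], [1, 0]]]

# One row of per-class scores per sample (like average=None), then macro per sample
per_sample = [per_class_fbeta(flatten(p), flatten(t), 3, beta=2.0) for p, t in zip(preds, target)]
macro = [sum(s) / len(s) for s in per_sample]

print([[round(v, 4) for v in s] for s in per_sample])  # [[0.9091, 0.0, 0.5], [0.0, 0.3571, 0.4545]]
print([round(m, 4) for m in macro])                    # [0.4697, 0.2706]
```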
multilabel_fbeta_score¶
- torchmetrics.functional.classification.multilabel_fbeta_score(preds, target, beta, num_labels, threshold=0.5, average='macro', multidim_average='global', ignore_index=None, validate_args=True)[source]¶
Compute F-score metric for multilabel tasks.
\[F_{\beta} = (1 + \beta^2) * \frac{\text{precision} * \text{recall}} {(\beta^2 * \text{precision}) + \text{recall}}\]
Accepts the following input tensors:
preds (int or float tensor): (N, C, ...). If preds is a floating point tensor with values outside the [0, 1] range, the input is considered to be logits and sigmoid is automatically applied per element. Additionally, it is converted to an int tensor by thresholding with the value in threshold.
target (int tensor): (N, C, ...)
- Parameters:
preds (Tensor) – Tensor with predictions.
target (Tensor) – Tensor with true labels.
beta (float) – Weighting between precision and recall in calculation. Setting to 1 corresponds to equal weight.
num_labels (int) – Integer specifying the number of labels.
threshold (float) – Threshold for transforming probability to binary {0,1} predictions.
average (Optional[Literal['micro', 'macro', 'weighted', 'none']]) – Defines the reduction that is applied over labels. Should be one of the following:
micro: Sum statistics over all labels.
macro: Calculate statistics for each label and average them.
weighted: Calculate statistics for each label and compute the weighted average using their support.
"none" or None: Calculate the statistic for each label and apply no reduction.
multidim_average (Literal['global', 'samplewise']) – Defines how additional dimensions ... should be handled. Should be one of the following:
global: Additional dimensions are flattened along the batch dimension.
samplewise: The statistic is calculated independently for each sample on the N axis. In this case the statistics are calculated over the additional dimensions.
ignore_index (Optional[int]) – Specifies a target value that is ignored and does not contribute to the metric calculation.
validate_args (bool) – Indicates whether input arguments and tensors should be validated for correctness. Set to False for faster computations.
- Returns:
If multidim_average is set to global:
If average='micro'/'macro'/'weighted', the output will be a scalar tensor
If average=None/'none', the shape will be (C,)
If multidim_average is set to samplewise:
If average='micro'/'macro'/'weighted', the shape will be (N,)
If average=None/'none', the shape will be (N, C)
- Return type:
The returned shape depends on the average and multidim_average arguments
- Example (preds is int tensor):
>>> from torch import tensor
>>> from torchmetrics.functional.classification import multilabel_fbeta_score
>>> target = tensor([[0, 1, 0], [1, 0, 1]])
>>> preds = tensor([[0, 0, 1], [1, 0, 1]])
>>> multilabel_fbeta_score(preds, target, beta=2.0, num_labels=3)
tensor(0.6111)
>>> multilabel_fbeta_score(preds, target, beta=2.0, num_labels=3, average=None)
tensor([1.0000, 0.0000, 0.8333])
- Example (preds is float tensor):
>>> from torch import tensor
>>> from torchmetrics.functional.classification import multilabel_fbeta_score
>>> target = tensor([[0, 1, 0], [1, 0, 1]])
>>> preds = tensor([[0.11, 0.22, 0.84], [0.73, 0.33, 0.92]])
>>> multilabel_fbeta_score(preds, target, beta=2.0, num_labels=3)
tensor(0.6111)
>>> multilabel_fbeta_score(preds, target, beta=2.0, num_labels=3, average=None)
tensor([1.0000, 0.0000, 0.8333])
- Example (multidim tensors):
>>> from torch import tensor
>>> from torchmetrics.functional.classification import multilabel_fbeta_score
>>> target = tensor([[[0, 1], [1, 0], [0, 1]], [[1, 1], [0, 0], [1, 0]]])
>>> preds = tensor([[[0.59, 0.91], [0.91, 0.99], [0.63, 0.04]],
...                 [[0.38, 0.04], [0.86, 0.780], [0.45, 0.37]]])
>>> multilabel_fbeta_score(preds, target, num_labels=3, beta=2.0, multidim_average='samplewise')
tensor([0.5556, 0.0000])
>>> multilabel_fbeta_score(preds, target, num_labels=3, beta=2.0, multidim_average='samplewise', average=None)
tensor([[0.8333, 0.8333, 0.0000],
        [0.0000, 0.0000, 0.0000]])
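In the multilabel case the label axis is the second dimension, so with multidim_average='samplewise' each label j of sample n is scored over the trailing values preds[n][j]. This minimal pure-Python sketch reproduces the multidim example above, assuming probabilities in [0, 1] and a strict "> threshold" comparison; fbeta and label_scores are hypothetical helpers.

```python
def fbeta(tp, fp, fn, beta):
    """F-beta from counts, 0.0 when precision or recall is undefined."""
    if tp + fp == 0 or tp + fn == 0:
        return 0.0
    p, r = tp / (tp + fp), tp / (tp + fn)
    b2 = beta ** 2
    return (1 + b2) * p * r / (b2 * p + r) if (b2 * p + r) else 0.0


def label_scores(sample_preds, sample_target, beta, threshold=0.5):
    """Per-label F-beta for one sample; row j holds the trailing values of label j."""
    scores = []
    for p_row, t_row in zip(sample_preds, sample_target):
        hard = [int(p > threshold) for p in p_row]
        tp = sum(h == 1 and t == 1 for h, t in zip(hard, t_row))
        fp = sum(h == 1 and t == 0 for h, t in zip(hard, t_row))
        fn = sum(h == 0 and t == 1 for h, t in zip(hard, t_row))
        scores.append(fbeta(tp, fp, fn, beta))
    return scores


# Shape (N=2, C=3, 2), as in the multidim example
target = [[[0, 1], [1, 0], [0, 1]], [[1, 1], [0, 0], [1, 0]]]
preds = [[[0.59, 0.91], [0.91, 0.99], [0.63, 0.04]],
         [[0.38, 0.04], [0.86, 0.780], [0.45, 0.37]]]

per_sample = [label_scores(p, t, beta=2.0) for p, t in zip(preds, target)]
macro = [sum(s) / len(s) for s in per_sample]

print([[round(v, 4) for v in s] for s in per_sample])  # [[0.8333, 0.8333, 0.0], [0.0, 0.0, 0.0]]
print([round(m, 4) for m in macro])                    # [0.5556, 0.0]
```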
Group Fairness¶
Module Interface¶
BinaryFairness¶
- class torchmetrics.classification.BinaryFairness(num_groups, task='all', threshold=0.5, ignore_index=None, validate_args=True, **kwargs)[source]¶
Compute demographic parity and equal opportunity ratios for binary classification problems.
Accepts the following input tensors:
preds (int or float tensor): (N, ...). If preds is a floating point tensor with values outside the [0, 1] range, the input is considered to be logits and sigmoid is automatically applied per element. Additionally, it is converted to an int tensor by thresholding with the value in threshold.
groups (int tensor): (N, ...). The group identifiers should be 0, 1, ..., (num_groups - 1).
target (int tensor): (N, ...).
The additional dimensions are flattened along the batch dimension.
This class computes the ratio between the positivity rates and the ratio between the true positive rates of different groups. If more than two groups are present, the disparity between the lowest and highest group is reported. A disparity between positivity rates indicates a potential violation of demographic parity, and a disparity between true positive rates indicates a potential violation of equal opportunity.
The lowest rate is divided by the highest, so a lower value means more discrimination against the group in the numerator. This is also reflected in the keys of the returned dict, which have the form {metric}_{identifier_low_group}_{identifier_high_group}.
- Parameters:
num_groups (int) – The number of groups.
task (Literal['demographic_parity', 'equal_opportunity', 'all']) – The task to compute. Can be either demographic_parity, equal_opportunity or all.
threshold (float) – Threshold for transforming probability to binary {0,1} predictions.
ignore_index (Optional[int]) – Specifies a target value that is ignored and does not contribute to the metric calculation.
validate_args (bool) – Indicates whether input arguments and tensors should be validated for correctness. Set to False for faster computations.
kwargs (Any) – Additional keyword arguments, see Advanced metric settings for more info.
- Returns:
The metric returns a dict where the key identifies the metric and the groups with the lowest and highest rates as follows: {metric}_{identifier_low_group}_{identifier_high_group}. The value is a tensor with the disparity rate.
- Example (preds is int tensor):
>>> import torch
>>> from torchmetrics.classification import BinaryFairness
>>> target = torch.tensor([0, 1, 0, 1, 0, 1])
>>> preds = torch.tensor([0, 1, 0, 1, 0, 1])
>>> groups = torch.tensor([0, 1, 0, 1, 0, 1])
>>> metric = BinaryFairness(2)
>>> metric(preds, target, groups)
{'DP_0_1': tensor(0.), 'EO_0_1': tensor(0.)}
- Example (preds is float tensor):
>>> import torch
>>> from torchmetrics.classification import BinaryFairness
>>> target = torch.tensor([0, 1, 0, 1, 0, 1])
>>> preds = torch.tensor([0.11, 0.84, 0.22, 0.73, 0.33, 0.92])
>>> groups = torch.tensor([0, 1, 0, 1, 0, 1])
>>> metric = BinaryFairness(2)
>>> metric(preds, target, groups)
{'DP_0_1': tensor(0.), 'EO_0_1': tensor(0.)}
- plot(val=None, ax=None)[source]¶
Plot a single or multiple values from the metric.
- Parameters:
val (Union[Tensor, Sequence[Tensor], None]) – Either a single result from calling metric.forward or metric.compute, or a list of these results. If no value is provided, metric.compute is called automatically and that result is plotted.
ax (Optional[Axes]) – A matplotlib axis object. If provided, the plot is added to that axis.
- Returns:
Figure object and Axes object
- Raises:
ModuleNotFoundError – If matplotlib is not installed
>>> import torch
>>> _ = torch.manual_seed(42)
>>> # Example plotting a single value
>>> from torchmetrics.classification import BinaryFairness
>>> metric = BinaryFairness(2)
>>> metric.update(torch.rand(20), torch.randint(2, (20,)), torch.randint(2, (20,)))
>>> fig_, ax_ = metric.plot()

>>> import torch
>>> _ = torch.manual_seed(42)
>>> # Example plotting multiple values
>>> from torchmetrics.classification import BinaryFairness
>>> metric = BinaryFairness(2)
>>> values = []
>>> for _ in range(10):
...     values.append(metric(torch.rand(20), torch.randint(2, (20,)), torch.ones(20).long()))
>>> fig_, ax_ = metric.plot(values)
BinaryGroupStatRates¶
- class torchmetrics.classification.BinaryGroupStatRates(num_groups, threshold=0.5, ignore_index=None, validate_args=True, **kwargs)[source]¶
Computes the true/false positives and true/false negatives rates for binary classification by group.
Related to Type I and Type II errors.
Accepts the following input tensors:
preds (int or float tensor): (N, ...). If preds is a floating point tensor with values outside the [0, 1] range, the input is considered to be logits and sigmoid is automatically applied per element. Additionally, it is converted to an int tensor by thresholding with the value in threshold.
target (int tensor): (N, ...).
groups (int tensor): (N, ...). The group identifiers should be 0, 1, ..., (num_groups - 1).
The additional dimensions are flattened along the batch dimension.
- Parameters:
num_groups (int) – The number of groups.
threshold (float) – Threshold for transforming probability to binary {0,1} predictions.
ignore_index (Optional[int]) – Specifies a target value that is ignored and does not contribute to the metric calculation.
validate_args (bool) – Indicates whether input arguments and tensors should be validated for correctness. Set to False for faster computations.
kwargs (Any) – Additional keyword arguments, see Advanced metric settings for more info.
- Returns:
The metric returns a dict with a group identifier as key and a tensor with the tp, fp, tn and fn rates as value.
- Example (preds is int tensor):
>>> import torch
>>> from torchmetrics.classification import BinaryGroupStatRates
>>> target = torch.tensor([0, 1, 0, 1, 0, 1])
>>> preds = torch.tensor([0, 1, 0, 1, 0, 1])
>>> groups = torch.tensor([0, 1, 0, 1, 0, 1])
>>> metric = BinaryGroupStatRates(num_groups=2)
>>> metric(preds, target, groups)
{'group_0': tensor([0., 0., 1., 0.]), 'group_1': tensor([1., 0., 0., 0.])}
- Example (preds is float tensor):
>>> import torch
>>> from torchmetrics.classification import BinaryGroupStatRates
>>> target = torch.tensor([0, 1, 0, 1, 0, 1])
>>> preds = torch.tensor([0.11, 0.84, 0.22, 0.73, 0.33, 0.92])
>>> groups = torch.tensor([0, 1, 0, 1, 0, 1])
>>> metric = BinaryGroupStatRates(num_groups=2)
>>> metric(preds, target, groups)
{'group_0': tensor([0., 0., 1., 0.]), 'group_1': tensor([1., 0., 0., 0.])}
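To see what the per-group rate vectors contain, the following pure-Python sketch recomputes them for the float example above and for a mixed case. It is an illustration, not TorchMetrics code: group_stat_rates_sketch is a hypothetical helper, it assumes probabilities already in [0, 1] with a strict "> threshold" comparison, and it assumes each of the tp, fp, tn, fn counts is divided by the group size, which is one reading consistent with the example output.

```python
def group_stat_rates_sketch(preds, target, groups, num_groups, threshold=0.5):
    """Hypothetical helper: per-group [tp, fp, tn, fn] rates for binary predictions."""
    out = {}
    for g in range(num_groups):
        idx = [i for i, grp in enumerate(groups) if grp == g]
        hard = [int(preds[i] > threshold) for i in idx]  # assumes preds already in [0, 1]
        tgt = [target[i] for i in idx]
        tp = sum(h == 1 and t == 1 for h, t in zip(hard, tgt))
        fp = sum(h == 1 and t == 0 for h, t in zip(hard, tgt))
        tn = sum(h == 0 and t == 0 for h, t in zip(hard, tgt))
        fn = sum(h == 0 and t == 1 for h, t in zip(hard, tgt))
        total = len(idx)
        # Assumption: each count is divided by the group size, so a group's rates sum to 1
        out[f"group_{g}"] = [c / total if total else 0.0 for c in (tp, fp, tn, fn)]
    return out


# Data from the float example above
print(group_stat_rates_sketch([0.11, 0.84, 0.22, 0.73, 0.33, 0.92],
                              [0, 1, 0, 1, 0, 1], [0, 1, 0, 1, 0, 1], num_groups=2))
# {'group_0': [0.0, 0.0, 1.0, 0.0], 'group_1': [1.0, 0.0, 0.0, 0.0]}

# A mixed case: one TP and one FP in group 0, one TP and one FN in group 1
print(group_stat_rates_sketch([0.9, 0.8, 0.2, 0.7], [1, 0, 1, 1], [0, 0, 1, 1], num_groups=2))
# {'group_0': [0.5, 0.5, 0.0, 0.0], 'group_1': [0.5, 0.0, 0.0, 0.5]}
```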
Functional Interface¶
binary_fairness¶
- torchmetrics.functional.classification.binary_fairness(preds, target, groups, task='all', threshold=0.5, ignore_index=None, validate_args=True)[source]¶
Compute demographic parity, the equal opportunity ratio, or both for binary classification problems.
This is done by setting the task argument to either 'demographic_parity', 'equal_opportunity' or 'all'. See the documentation of demographic_parity() and equal_opportunity() for the specific details of how each argument influences the result, and for examples.
- Parameters:
preds (Tensor) – Tensor with predictions.
target (Tensor) – Tensor with true labels (not required for demographic_parity).
groups (Tensor) – Tensor with group identifiers. The group identifiers should be 0, 1, ..., (num_groups - 1).
task (Literal['demographic_parity', 'equal_opportunity', 'all']) – The task to compute. Can be either demographic_parity, equal_opportunity or all.
threshold (float) – Threshold for transforming probability to binary {0,1} predictions.
ignore_index (Optional[int]) – Specifies a target value that is ignored and does not contribute to the metric calculation.
validate_args (bool) – Indicates whether input arguments and tensors should be validated for correctness. Set to False for faster computations.
- Return type:
Dict[str, Tensor]
demographic_parity¶
- torchmetrics.functional.classification.demographic_parity(preds, groups, threshold=0.5, ignore_index=None, validate_args=True)[source]¶
Demographic parity compares the positivity rates between all groups.
If more than two groups are present, the disparity between the lowest and highest group is reported. The lowest positivity rate is divided by the highest, so a lower value means more discrimination against the group in the numerator. This is also reflected in the keys of the returned dict, which have the form DP_{identifier_low_group}_{identifier_high_group}.
\[\text{DP} = \dfrac{\min_a \text{PR}_a}{\max_a \text{PR}_a},\]
where \(\text{PR}_a\) represents the positivity rate for group \(a\).
Accepts the following input tensors:
preds (int or float tensor): (N, ...). If preds is a floating point tensor with values outside the [0, 1] range, the input is considered to be logits and sigmoid is automatically applied per element. Additionally, it is converted to an int tensor by thresholding with the value in threshold.
groups (int tensor): (N, ...). The group identifiers should be 0, 1, ..., (num_groups - 1).
The additional dimensions are flattened along the batch dimension.
- Parameters:
preds (Tensor) – Tensor with predictions.
groups (Tensor) – Tensor with group identifiers. The group identifiers should be 0, 1, ..., (num_groups - 1).
threshold (float) – Threshold for transforming probability to binary {0,1} predictions.
ignore_index (Optional[int]) – Specifies a target value that is ignored and does not contribute to the metric calculation.
validate_args (bool) – Indicates whether input arguments and tensors should be validated for correctness. Set to False for faster computations.
- Return type:
Dict[str, Tensor]
- Returns:
The metric returns a dict where the key identifies the groups with the lowest and highest positivity rates as follows: DP_{identifier_low_group}_{identifier_high_group}. The value is a tensor with the DP rate.
- Example (preds is int tensor):
>>> import torch
>>> from torchmetrics.functional.classification import demographic_parity
>>> preds = torch.tensor([0, 1, 0, 1, 0, 1])
>>> groups = torch.tensor([0, 1, 0, 1, 0, 1])
>>> demographic_parity(preds, groups)
{'DP_0_1': tensor(0.)}
- Example (preds is float tensor):
>>> import torch
>>> from torchmetrics.functional.classification import demographic_parity
>>> preds = torch.tensor([0.11, 0.84, 0.22, 0.73, 0.33, 0.92])
>>> groups = torch.tensor([0, 1, 0, 1, 0, 1])
>>> demographic_parity(preds, groups)
{'DP_0_1': tensor(0.)}
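The DP formula and the key naming can be illustrated without TorchMetrics. The following is a minimal pure-Python sketch; demographic_parity_sketch is a hypothetical helper, it assumes probabilities in [0, 1] with a strict "> threshold" comparison, and both the tie-breaking (lowest group id wins) and the zero guard are assumed conventions rather than documented behaviour.

```python
def demographic_parity_sketch(preds, groups, threshold=0.5):
    """Hypothetical helper: ratio of the lowest to the highest positivity rate."""
    rates = {}
    for g in sorted(set(groups)):
        hard = [int(p > threshold) for p, grp in zip(preds, groups) if grp == g]
        rates[g] = sum(hard) / len(hard)  # positivity rate PR_g
    low = min(rates, key=rates.get)   # ties resolved by lowest group id (assumption)
    high = max(rates, key=rates.get)
    # Guard for the case where no group has positive predictions (assumed convention)
    ratio = rates[low] / rates[high] if rates[high] > 0 else 0.0
    return {f"DP_{low}_{high}": ratio}


# Data from the float example above
print(demographic_parity_sketch([0.11, 0.84, 0.22, 0.73, 0.33, 0.92], [0, 1, 0, 1, 0, 1]))
# {'DP_0_1': 0.0}

# Group 0 is predicted positive 2/3 of the time, group 1 only 1/3 of the time
print(demographic_parity_sketch([0.9, 0.8, 0.1, 0.7, 0.2, 0.3], [0, 0, 0, 1, 1, 1]))
# {'DP_1_0': 0.5}
```

The second call shows why the key carries the group order: group 1 has the lower rate, so it appears first.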
equal_opportunity¶
- torchmetrics.functional.classification.equal_opportunity(preds, target, groups, threshold=0.5, ignore_index=None, validate_args=True)[source]¶
Equal opportunity compares the true positive rates between all groups.
If more than two groups are present, the disparity between the lowest and highest group is reported. The lowest true positive rate is divided by the highest, so a lower value means more discrimination against the group in the numerator. This is also reflected in the keys of the returned dict, which have the form EO_{identifier_low_group}_{identifier_high_group}.
\[\text{EO} = \dfrac{\min_a \text{TPR}_a}{\max_a \text{TPR}_a},\]
where \(\text{TPR}_a\) represents the true positive rate for group \(a\).
Accepts the following input tensors:
preds (int or float tensor): (N, ...). If preds is a floating point tensor with values outside the [0, 1] range, the input is considered to be logits and sigmoid is automatically applied per element. Additionally, it is converted to an int tensor by thresholding with the value in threshold.
target (int tensor): (N, ...).
groups (int tensor): (N, ...). The group identifiers should be 0, 1, ..., (num_groups - 1).
The additional dimensions are flattened along the batch dimension.
- Parameters:
preds (Tensor) – Tensor with predictions.
target (Tensor) – Tensor with true labels.
groups (Tensor) – Tensor with group identifiers. The group identifiers should be 0, 1, ..., (num_groups - 1).
threshold (float) – Threshold for transforming probability to binary {0,1} predictions.
ignore_index (Optional[int]) – Specifies a target value that is ignored and does not contribute to the metric calculation.
validate_args (bool) – Indicates whether input arguments and tensors should be validated for correctness. Set to False for faster computations.
- Return type:
Dict[str, Tensor]
- Returns:
The metric returns a dict where the key identifies the groups with the lowest and highest true positive rates as follows: EO_{identifier_low_group}_{identifier_high_group}. The value is a tensor with the EO rate.
- Example (preds is int tensor):
>>> from torchmetrics.functional.classification import equal_opportunity >>> target = torch.tensor([0, 1, 0, 1, 0, 1]) >>> preds = torch.tensor([0, 1, 0, 1, 0, 1]) >>> groups = torch.tensor([0, 1, 0, 1, 0, 1]) >>> equal_opportunity(preds, target, groups) {'EO_0_1': tensor(0.)}
- Example (preds is float tensor):
>>> import torch
>>> from torchmetrics.functional.classification import equal_opportunity
>>> target = torch.tensor([0, 1, 0, 1, 0, 1])
>>> preds = torch.tensor([0.11, 0.84, 0.22, 0.73, 0.33, 0.92])
>>> groups = torch.tensor([0, 1, 0, 1, 0, 1])
>>> equal_opportunity(preds, target, groups)
{'EO_0_1': tensor(0.)}
binary_groups_stat_rates¶
- torchmetrics.functional.classification.binary_groups_stat_rates(preds, target, groups, num_groups, threshold=0.5, ignore_index=None, validate_args=True)[source]¶
Compute the true/false positives and true/false negatives rates for binary classification by group.
These rates are related to Type I (false positive) and Type II (false negative) errors.
Accepts the following input tensors:
preds (int or float tensor): (N, ...). If preds is a floating point tensor with values outside the [0, 1] range, we consider the input to be logits and automatically apply sigmoid per element. Additionally, we convert it to an int tensor by thresholding with the value in threshold.
target (int tensor): (N, ...).
groups (int tensor): (N, ...). The group identifiers should be 0, 1, ..., (num_groups - 1).
The additional dimensions are flattened along the batch dimension.
- Parameters:
preds (Tensor) – Tensor with predictions.
target (Tensor) – Tensor with true labels.
groups (Tensor) – Tensor with group identifiers. The group identifiers should be 0, 1, ..., (num_groups - 1).
num_groups (int) – The number of groups.
threshold (float) – Threshold for transforming probability to binary {0,1} predictions.
ignore_index (Optional[int]) – Specifies a target value that is ignored and does not contribute to the metric calculation.
validate_args (bool) – bool indicating if input arguments and tensors should be validated for correctness. Set to False for faster computations.
- Return type:
- Returns:
The metric returns a dict with a group identifier as key and a tensor with the tp, fp, tn and fn rates as value.
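As an illustration of these per-group rates, here is a minimal sketch in plain PyTorch, not the library code. The helper name `groups_stat_rates` is hypothetical. The `[tp, fp, tn, fn]` ordering and the normalization by group size are inferred from the example outputs for this function, for instance `tensor([0., 0., 1., 0.])` for a group that only has true negatives.

```python
import torch


def groups_stat_rates(preds, target, groups, num_groups, threshold=0.5):
    """Hypothetical helper: per-group [tp, fp, tn, fn] counts divided by group size."""
    preds, target, groups = preds.flatten(), target.flatten(), groups.flatten()
    if preds.is_floating_point():
        if ((preds < 0) | (preds > 1)).any():
            preds = preds.sigmoid()  # treat out-of-range values as logits
        preds = (preds > threshold).long()

    rates = {}
    for g in range(num_groups):
        p, t = preds[groups == g], target[groups == g]
        stats = torch.stack([
            ((p == 1) & (t == 1)).sum(),  # tp
            ((p == 1) & (t == 0)).sum(),  # fp
            ((p == 0) & (t == 0)).sum(),  # tn
            ((p == 0) & (t == 1)).sum(),  # fn
        ]).float()
        # An empty group would give 0/0 = nan here; the sketch does not guard it.
        rates[f"group_{g}"] = stats / stats.sum()
    return rates


target = torch.tensor([0, 1, 0, 1, 0, 1])
preds = torch.tensor([0.11, 0.84, 0.22, 0.73, 0.33, 0.92])
groups = torch.tensor([0, 1, 0, 1, 0, 1])
rates = groups_stat_rates(preds, target, groups, num_groups=2)
print(rates)  # {'group_0': tensor([0., 0., 1., 0.]), 'group_1': tensor([1., 0., 0., 0.])}
```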
- Example (preds is int tensor):
>>> import torch
>>> from torchmetrics.functional.classification import binary_groups_stat_rates
>>> target = torch.tensor([0, 1, 0, 1, 0, 1])
>>> preds = torch.tensor([0, 1, 0, 1, 0, 1])
>>> groups = torch.tensor([0, 1, 0, 1, 0, 1])
>>> binary_groups_stat_rates(preds, target, groups, 2)
{'group_0': tensor([0., 0., 1., 0.]), 'group_1': tensor([1., 0., 0., 0.])}
- Example (preds is float tensor):
>>> import torch
>>> from torchmetrics.functional.classification import binary_groups_stat_rates
>>> target = torch.tensor([0, 1, 0, 1, 0, 1])
>>> preds = torch.tensor([0.11, 0.84, 0.22, 0.73, 0.33, 0.92])
>>> groups = torch.tensor([0, 1, 0, 1, 0, 1])
>>> binary_groups_stat_rates(preds, target, groups, 2)
{'group_0': tensor([0., 0., 1., 0.]), 'group_1': tensor([1., 0., 0., 0.])}
Hamming Distance¶
Module Interface¶
- class torchmetrics.HammingDistance(**kwargs)[source]¶
Compute the average Hamming distance (also known as Hamming loss).
\[\text{Hamming distance} = \frac{1}{N \cdot L} \sum_i^N \sum_l^L 1(y_{il} \neq \hat{y}_{il})\]Where \(y\) is a tensor of target values, \(\hat{y}\) is a tensor of predictions, and \(\bullet_{il}\) refers to the \(l\)-th label of the \(i\)-th sample of that tensor.
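For intuition, the formula can be evaluated directly. In this sketch, which uses only PyTorch and not the metric class, the indicator is averaged over all \(N \cdot L\) entries. The data is taken from the legacy example that follows.

```python
import torch

# Legacy example data: N = 2 samples, L = 2 labels.
target = torch.tensor([[0, 1], [1, 1]])
preds = torch.tensor([[0, 1], [0, 1]])

# Indicator 1(y_il != y_hat_il), averaged over all N * L entries.
hamming = (target != preds).float().mean()
print(hamming)  # tensor(0.2500)
```

Only one of the four entries differs, so the distance is 0.25, which matches the legacy example output.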
This function is a simple wrapper to get the task-specific versions of this metric, which is done by setting the task argument to either 'binary', 'multiclass' or 'multilabel'. See the documentation of BinaryHammingDistance, MulticlassHammingDistance and MultilabelHammingDistance for the specific details of how each argument influences the metric, and for examples.
- Legacy Example:
>>> from torch import tensor
>>> from torchmetrics import HammingDistance
>>> target = tensor([[0, 1], [1, 1]])
>>> preds = tensor([[0, 1], [0, 1]])
>>> hamming_distance = HammingDistance(task="multilabel", num_labels=2)
>>> hamming_distance(preds, target)
tensor(0.2500)
BinaryHammingDistance¶
- class torchmetrics.classification.BinaryHammingDistance(threshold=0.5, multidim_average='global', ignore_index=None, validate_args=True, **kwargs)[source]¶
Compute the average Hamming distance (also known as Hamming loss) for binary tasks.
\[\text{Hamming distance} = \frac{1}{N \cdot L} \sum_i^N \sum_l^L 1(y_{il} \neq \hat{y}_{il})\]Where \(y\) is a tensor of target values, \(\hat{y}\) is a tensor of predictions, and \(\bullet_{il}\) refers to the \(l\)-th label of the \(i\)-th sample of that tensor.
As input to forward and update the metric accepts the following input:
preds (Tensor): An int or float tensor of shape (N, ...). If preds is a floating point tensor with values outside the [0, 1] range, we consider the input to be logits and automatically apply sigmoid per element. Additionally, we convert it to an int tensor by thresholding with the value in threshold.
target (Tensor): An int tensor of shape (N, ...).
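The float-input rule above can be sketched as a small preprocessing step. The function `to_binary` is hypothetical, and the strict `>` comparison against `threshold` is an assumption. The sketch only mirrors the documented behaviour: apply sigmoid when any value leaves [0, 1], then threshold.

```python
import torch


def to_binary(preds, threshold=0.5):
    """Hypothetical sketch: sigmoid only if some value lies outside [0, 1], then threshold."""
    if ((preds < 0) | (preds > 1)).any():
        preds = preds.sigmoid()
    return (preds > threshold).long()


logits = torch.tensor([-2.0, -0.1, 0.3, 1.5])   # contains values outside [0, 1]
probs = torch.tensor([0.11, 0.22, 0.84, 0.73])  # already in [0, 1]
print(to_binary(logits))  # tensor([0, 0, 1, 1])
print(to_binary(probs))   # tensor([0, 0, 1, 1])
```

One consequence of this rule: a logits tensor whose values all happen to fall inside [0, 1] is thresholded directly, without sigmoid.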
As output to forward and compute the metric returns the following output:
bhd (Tensor): A tensor whose returned shape depends on the multidim_average argument:
If multidim_average is set to global, the metric returns a scalar value.
If multidim_average is set to samplewise, the metric returns an (N,) vector consisting of a scalar value per sample.
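The difference between the two multidim_average modes can be reproduced by hand. In this minimal sketch, which is plain PyTorch and assumes `> 0.5` thresholding on inputs already in [0, 1], global averages every mismatch while samplewise averages over all non-batch dimensions of each sample. The data is the multidim example for this class.

```python
import torch

target = torch.tensor([[[0, 1], [1, 0], [0, 1]], [[1, 1], [0, 0], [1, 0]]])
probs = torch.tensor([[[0.59, 0.91], [0.91, 0.99], [0.63, 0.04]],
                      [[0.38, 0.04], [0.86, 0.780], [0.45, 0.37]]])
preds = (probs > 0.5).long()  # values already lie in [0, 1], so no sigmoid

mismatch = (preds != target).float()
global_hd = mismatch.mean()                      # 'global': one scalar
samplewise_hd = mismatch.flatten(1).mean(dim=1)  # 'samplewise': shape (N,)
print(global_hd)      # tensor(0.7500)
print(samplewise_hd)  # tensor([0.6667, 0.8333])
```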
- Parameters:
threshold (float) – Threshold for transforming probability to binary {0,1} predictions.
multidim_average (Literal['global', 'samplewise']) – Defines how additional dimensions ... should be handled. Should be one of the following:
global: Additional dimensions are flattened along the batch dimension.
samplewise: The statistic is calculated independently for each sample on the N axis. The statistics in this case are calculated over the additional dimensions.
ignore_index (Optional[int]) – Specifies a target value that is ignored and does not contribute to the metric calculation.
validate_args (bool) – bool indicating if input arguments and tensors should be validated for correctness. Set to False for faster computations.
- Example (preds is int tensor):
>>> from torch import tensor
>>> from torchmetrics.classification import BinaryHammingDistance
>>> target = tensor([0, 1, 0, 1, 0, 1])
>>> preds = tensor([0, 0, 1, 1, 0, 1])
>>> metric = BinaryHammingDistance()
>>> metric(preds, target)
tensor(0.3333)
- Example (preds is float tensor):
>>> from torch import tensor
>>> from torchmetrics.classification import BinaryHammingDistance
>>> target = tensor([0, 1, 0, 1, 0, 1])
>>> preds = tensor([0.11, 0.22, 0.84, 0.73, 0.33, 0.92])
>>> metric = BinaryHammingDistance()
>>> metric(preds, target)
tensor(0.3333)
- Example (multidim tensors):
>>> from torch import tensor
>>> from torchmetrics.classification import BinaryHammingDistance
>>> target = tensor([[[0, 1], [1, 0], [0, 1]], [[1, 1], [0, 0], [1, 0]]])
>>> preds = tensor([[[0.59, 0.91], [0.91, 0.99], [0.63, 0.04]],
...                 [[0.38, 0.04], [0.86, 0.780], [0.45, 0.37]]])
>>> metric = BinaryHammingDistance(multidim_average='samplewise')
>>> metric(preds, target)
tensor([0.6667, 0.8333])
- plot(val=None, ax=None)[source]¶
Plot a single or multiple values from the metric.
- Parameters:
val (Union[Tensor, Sequence[Tensor], None]) – Either a single result from calling metric.forward or metric.compute, or a list of these results. If no value is provided, metric.compute is called automatically and that result is plotted.
ax (Optional[Axes]) – A matplotlib axis object. If provided, the plot is added to that axis.
- Return type:
- Returns:
Figure object and Axes object
- Raises:
ModuleNotFoundError – If matplotlib is not installed
>>> # Example plotting a single value
>>> from torch import rand, randint
>>> from torchmetrics.classification import BinaryHammingDistance
>>> metric = BinaryHammingDistance()
>>> metric.update(rand(10), randint(2, (10,)))
>>> fig_, ax_ = metric.plot()
>>> # Example plotting multiple values
>>> from torch import rand, randint
>>> from torchmetrics.classification import BinaryHammingDistance
>>> metric = BinaryHammingDistance()
>>> values = []
>>> for _ in range(10):
...     values.append(metric(rand(10), randint(2, (10,))))
>>> fig_, ax_ = metric.plot(values)
MulticlassHammingDistance¶
- class torchmetrics.classification.MulticlassHammingDistance(num_classes, top_k=1, average='macro', multidim_average='global', ignore_index=None, validate_args=True, **kwargs)[source]¶
Compute the average Hamming distance (also known as Hamming loss) for multiclass tasks.
\[\text{Hamming distance} = \frac{1}{N \cdot L} \sum_i^N \sum_l^L 1(y_{il} \neq \hat{y}_{il})\]Where \(y\) is a tensor of target values, \(\hat{y}\) is a tensor of predictions, and \(\bullet_{il}\) refers to the \(l\)-th label of the \(i\)-th sample of that tensor.
As input to forward and update the metric accepts the following input:
preds (Tensor): An int tensor of shape (N, ...) or float tensor of shape (N, C, ...). If preds is a floating point tensor, we apply torch.argmax along the C dimension to automatically convert probabilities/logits into an int tensor.
target (Tensor): An int tensor of shape (N, ...).
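The argmax conversion described above collapses the class dimension `C` (dimension 1) into one label per position. This is a minimal sketch using the float predictions from the example for this class.

```python
import torch

preds = torch.tensor([[0.16, 0.26, 0.58],
                      [0.22, 0.61, 0.17],
                      [0.71, 0.09, 0.20],
                      [0.05, 0.82, 0.13]])  # shape (N, C) = (4, 3)
labels = preds.argmax(dim=1)  # collapse the C dimension into class indices
print(labels)  # tensor([2, 1, 0, 1])
```

These labels equal the int `preds` in the first example, which is why both examples report the same distance.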
As output to forward and compute the metric returns the following output:
mchd (Tensor): A tensor whose returned shape depends on the average and multidim_average arguments:
If multidim_average is set to global:
If average='micro'/'macro'/'weighted', the output will be a scalar tensor.
If average=None/'none', the shape will be (C,).
If multidim_average is set to samplewise:
If average='micro'/'macro'/'weighted', the shape will be (N,).
If average=None/'none', the shape will be (N, C).
- Parameters:
num_classes (int) – Integer specifying the number of classes.
average (Optional[Literal['micro', 'macro', 'weighted', 'none']]) – Defines the reduction that is applied over labels. Should be one of the following:
micro: Sum statistics over all labels.
macro: Calculate statistics for each label and average them.
weighted: Calculate statistics for each label and compute the weighted average using their support.
"none" or None: Calculate the statistic for each label and apply no reduction.
top_k (int) – Number of highest probability or logit score predictions considered to find the correct label. Only works when preds contain probabilities/logits.
multidim_average (Literal['global', 'samplewise']) – Defines how additional dimensions ... should be handled. Should be one of the following:
global: Additional dimensions are flattened along the batch dimension.
samplewise: The statistic is calculated independently for each sample on the N axis. The statistics in this case are calculated over the additional dimensions.
ignore_index (Optional[int]) – Specifies a target value that is ignored and does not contribute to the metric calculation.
validate_args (bool) – bool indicating if input arguments and tensors should be validated for correctness. Set to False for faster computations.
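To make the average options concrete, here is a hand computation in plain PyTorch. It rests on an assumption: the per-class distance is taken as 1 minus per-class recall. That assumption reproduces the `average=None` and `macro` values in the examples for this class (`[0.5, 0, 0]` and `0.1667`). The weighted and micro values below are derived from the parameter descriptions only and have not been checked against the library.

```python
import torch

num_classes = 3
target = torch.tensor([2, 1, 0, 0])
preds = torch.tensor([2, 1, 0, 1])

# Per-class true positives and support (number of targets of that class).
tp = torch.stack([((preds == c) & (target == c)).sum() for c in range(num_classes)]).float()
support = torch.stack([(target == c).sum() for c in range(num_classes)]).float()

per_class = 1 - tp / support                            # 'none': tensor([0.5000, 0.0000, 0.0000])
macro = per_class.mean()                                # tensor(0.1667)
weighted = (per_class * support).sum() / support.sum()  # tensor(0.2500)
micro = 1 - tp.sum() / support.sum()                    # tensor(0.2500)
```

A class that never appears in `target` would have zero support and yield `nan` in this sketch. Handling that case is left out here.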
- Example (preds is int tensor):
>>> from torch import tensor
>>> from torchmetrics.classification import MulticlassHammingDistance
>>> target = tensor([2, 1, 0, 0])
>>> preds = tensor([2, 1, 0, 1])
>>> metric = MulticlassHammingDistance(num_classes=3)
>>> metric(preds, target)
tensor(0.1667)
>>> mchd = MulticlassHammingDistance(num_classes=3, average=None)
>>> mchd(preds, target)
tensor([0.5000, 0.0000, 0.0000])
- Example (preds is float tensor):
>>> from torch import tensor
>>> from torchmetrics.classification import MulticlassHammingDistance
>>> target = tensor([2, 1, 0, 0])
>>> preds = tensor([[0.16, 0.26, 0.58],
...                 [0.22, 0.61, 0.17],
...                 [0.71, 0.09, 0.20],
...                 [0.05, 0.82, 0.13]])
>>> metric = MulticlassHammingDistance(num_classes=3)
>>> metric(preds, target)
tensor(0.1667)
>>> mchd = MulticlassHammingDistance(num_classes=3, average=None)
>>> mchd(preds, target)
tensor([0.5000, 0.0000, 0.0000])
- Example (multidim tensors):
>>> from torch import tensor
>>> from torchmetrics.classification import MulticlassHammingDistance
>>> target = tensor([[[0, 1], [2, 1], [0, 2]], [[1, 1], [2, 0], [1, 2]]])
>>> preds = tensor([[[0, 2], [2, 0], [0, 1]], [[2, 2], [2, 1], [1, 0]]])
>>> metric = MulticlassHammingDistance(num_classes=3, multidim_average='samplewise')
>>> metric(preds, target)
tensor([0.5000, 0.7222])
>>> mchd = MulticlassHammingDistance(num_classes=3, multidim_average='samplewise', average=None)
>>> mchd(preds, target)
tensor([[0.0000, 1.0000, 0.5000],
        [1.0000, 0.6667, 0.5000]])
- plot(val=None, ax=None)[source]¶
Plot a single or multiple values from the metric.
- Parameters:
val (Union[Tensor, Sequence[Tensor], None]) – Either a single result from calling metric.forward or metric.compute, or a list of these results. If no value is provided, metric.compute is called automatically and that result is plotted.
ax (Optional[Axes]) – A matplotlib axis object. If provided, the plot is added to that axis.
- Return type:
- Returns:
Figure object and Axes object
- Raises:
ModuleNotFoundError – If matplotlib is not installed
>>> # Example plotting a single value per class
>>> from torch import randint
>>> from torchmetrics.classification import MulticlassHammingDistance
>>> metric = MulticlassHammingDistance(num_classes=3, average=None)
>>> metric.update(randint(3, (20,)), randint(3, (20,)))
>>> fig_, ax_ = metric.plot()
>>> # Example plotting multiple values per class
>>> from torch import randint
>>> from torchmetrics.classification import MulticlassHammingDistance
>>> metric = MulticlassHammingDistance(num_classes=3, average=None)
>>> values = []
>>> for _ in range(20):
...     values.append(metric(randint(3, (20,)), randint(3, (20,))))
>>> fig_, ax_ = metric.plot(values)
MultilabelHammingDistance¶
- class torchmetrics.classification.MultilabelHammingDistance(num_labels, threshold=0.5, average='macro', multidim_average='global', ignore_index=None, validate_args=True, **kwargs)[source]¶
Compute the average Hamming distance (also known as Hamming loss) for multilabel tasks.
\[\text{Hamming distance} = \frac{1}{N \cdot L} \sum_i^N \sum_l^L 1(y_{il} \neq \hat{y}_{il})\]Where \(y\) is a tensor of target values, \(\hat{y}\) is a tensor of predictions, and \(\bullet_{il}\) refers to the \(l\)-th label of the \(i\)-th sample of that tensor.
As input to forward and update the metric accepts the following input:
preds (Tensor): An int or float tensor of shape (N, C, ...). If preds is a floating point tensor with values outside the [0, 1] range, we consider the input to be logits and automatically apply sigmoid per element. Additionally, we convert it to an int tensor by thresholding with the value in threshold.
target (Tensor): An int tensor of shape (N, C, ...).
As output to forward and compute the metric returns the following output:
mlhd (Tensor): A tensor whose returned shape depends on the average and multidim_average arguments:
If multidim_average is set to global:
If average='micro'/'macro'/'weighted', the output will be a scalar tensor.
If average=None/'none', the shape will be (C,).
If multidim_average is set to samplewise:
If average='micro'/'macro'/'weighted', the shape will be (N,).
If average=None/'none', the shape will be (N, C).
- Parameters:
num_labels (int) – Integer specifying the number of labels.
threshold (float) – Threshold for transforming probability to binary {0,1} predictions.
average (Optional[Literal['micro', 'macro', 'weighted', 'none']]) – Defines the reduction that is applied over labels. Should be one of the following:
micro: Sum statistics over all labels.
macro: Calculate statistics for each label and average them.
weighted: Calculate statistics for each label and compute the weighted average using their support.
"none" or None: Calculate the statistic for each label and apply no reduction.
multidim_average (Literal['global', 'samplewise']) – Defines how additional dimensions ... should be handled. Should be one of the following:
global: Additional dimensions are flattened along the batch dimension.
samplewise: The statistic is calculated independently for each sample on the N axis. The statistics in this case are calculated over the additional dimensions.
ignore_index (Optional[int]) – Specifies a target value that is ignored and does not contribute to the metric calculation.
validate_args (bool) – bool indicating if input arguments and tensors should be validated for correctness. Set to False for faster computations.
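In the multilabel case, each label column can be treated as its own binary problem. This minimal sketch in plain PyTorch assumes `> 0.5` thresholding. It computes the per-label distance (the `average=None` view) and its plain mean (`macro`) for the float example of this class.

```python
import torch

target = torch.tensor([[0, 1, 0], [1, 0, 1]])  # (N, C): 2 samples, 3 labels
probs = torch.tensor([[0.11, 0.22, 0.84], [0.73, 0.33, 0.92]])
preds = (probs > 0.5).long()

mismatch = (preds != target).float()
per_label = mismatch.mean(dim=0)  # average over samples: one value per label
macro = per_label.mean()
print(per_label)  # tensor([0.0000, 0.5000, 0.5000])
print(macro)      # tensor(0.3333)
```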
- Example (preds is int tensor):
>>> from torch import tensor
>>> from torchmetrics.classification import MultilabelHammingDistance
>>> target = tensor([[0, 1, 0], [1, 0, 1]])
>>> preds = tensor([[0, 0, 1], [1, 0, 1]])
>>> metric = MultilabelHammingDistance(num_labels=3)
>>> metric(preds, target)
tensor(0.3333)
>>> mlhd = MultilabelHammingDistance(num_labels=3, average=None)
>>> mlhd(preds, target)
tensor([0.0000, 0.5000, 0.5000])
- Example (preds is float tensor):
>>> from torch import tensor
>>> from torchmetrics.classification import MultilabelHammingDistance
>>> target = tensor([[0, 1, 0], [1, 0, 1]])
>>> preds = tensor([[0.11, 0.22, 0.84], [0.73, 0.33, 0.92]])
>>> metric = MultilabelHammingDistance(num_labels=3)
>>> metric(preds, target)
tensor(0.3333)
>>> mlhd = MultilabelHammingDistance(num_labels=3, average=None)
>>> mlhd(preds, target)
tensor([0.0000, 0.5000, 0.5000])
- Example (multidim tensors):
>>> from torch import tensor
>>> from torchmetrics.classification import MultilabelHammingDistance
>>> target = tensor([[[0, 1], [1, 0], [0, 1]], [[1, 1], [0, 0], [1, 0]]])
>>> preds = tensor([[[0.59, 0.91], [0.91, 0.99], [0.63, 0.04]],
...                 [[0.38, 0.04], [0.86, 0.780], [0.45, 0.37]]])
>>> metric = MultilabelHammingDistance(num_labels=3, multidim_average='samplewise')
>>> metric(preds, target)
tensor([0.6667, 0.8333])
>>> mlhd = MultilabelHammingDistance(num_labels=3, multidim_average='samplewise', average=None)
>>> mlhd(preds, target)
tensor([[0.5000, 0.5000, 1.0000],
        [1.0000, 1.0000, 0.5000]])
- plot(val=None, ax=None)[source]¶
Plot a single or multiple values from the metric.
- Parameters:
val (Union[Tensor, Sequence[Tensor], None]) – Either a single result from calling metric.forward or metric.compute, or a list of these results. If no value is provided, metric.compute is called automatically and that result is plotted.
ax (Optional[Axes]) – A matplotlib axis object. If provided, the plot is added to that axis.
- Return type:
- Returns:
Figure object and Axes object
- Raises:
ModuleNotFoundError – If matplotlib is not installed
>>> # Example plotting a single value
>>> from torch import rand, randint
>>> from torchmetrics.classification import MultilabelHammingDistance
>>> metric = MultilabelHammingDistance(num_labels=3)
>>> metric.update(randint(2, (20, 3)), randint(2, (20, 3)))
>>> fig_, ax_ = metric.plot()
>>> # Example plotting multiple values
>>> from torch import rand, randint
>>> from torchmetrics.classification import MultilabelHammingDistance
>>> metric = MultilabelHammingDistance(num_labels=3)
>>> values = []
>>> for _ in range(10):
...     values.append(metric(randint(2, (20, 3)), randint(2, (20, 3))))
>>> fig_, ax_ = metric.plot(values)
Functional Interface¶
hamming_distance¶
- torchmetrics.functional.hamming_distance(preds, target, task, threshold=0.5, num_classes=None, num_labels=None, average='micro', multidim_average='global', top_k=1, ignore_index=None, validate_args=True)[source]¶
Compute the average Hamming distance (also known as Hamming loss). The return type is Tensor.
\[\text{Hamming distance} = \frac{1}{N \cdot L} \sum_i^N \sum_l^L 1(y_{il} \neq \hat{y}_{il})\]Where \(y\) is a tensor of target values, \(\hat{y}\) is a tensor of predictions, and \(\bullet_{il}\) refers to the \(l\)-th label of the \(i\)-th sample of that tensor.
This function is a simple wrapper to get the task-specific versions of this metric, which is done by setting the task argument to either 'binary', 'multiclass' or 'multilabel'. See the documentation of binary_hamming_distance(), multiclass_hamming_distance() and multilabel_hamming_distance() for the specific details of how each argument influences the metric, and for examples.
- Legacy Example:
>>> from torch import tensor
>>> from torchmetrics.functional import hamming_distance
>>> target = tensor([[0, 1], [1, 1]])
>>> preds = tensor([[0, 1], [0, 1]])
>>> hamming_distance(preds, target, task="binary")
tensor(0.2500)
binary_hamming_distance¶
- torchmetrics.functional.classification.binary_hamming_distance(preds, target, threshold=0.5, multidim_average='global', ignore_index=None, validate_args=True)[source]¶
Compute the average Hamming distance (also known as Hamming loss) for binary tasks.
\[\text{Hamming distance} = \frac{1}{N \cdot L} \sum_i^N \sum_l^L 1(y_{il} \neq \hat{y}_{il})\]Where \(y\) is a tensor of target values, \(\hat{y}\) is a tensor of predictions, and \(\bullet_{il}\) refers to the \(l\)-th label of the \(i\)-th sample of that tensor.
Accepts the following input tensors:
preds (int or float tensor): (N, ...). If preds is a floating point tensor with values outside the [0, 1] range, we consider the input to be logits and automatically apply sigmoid per element. Additionally, we convert it to an int tensor by thresholding with the value in threshold.
target (int tensor): (N, ...).
- Parameters:
preds (Tensor) – Tensor with predictions.
target (Tensor) – Tensor with true labels.
threshold (float) – Threshold for transforming probability to binary {0,1} predictions.
multidim_average (Literal['global', 'samplewise']) – Defines how additional dimensions ... should be handled. Should be one of the following:
global: Additional dimensions are flattened along the batch dimension.
samplewise: The statistic is calculated independently for each sample on the N axis. The statistics in this case are calculated over the additional dimensions.
ignore_index (Optional[int]) – Specifies a target value that is ignored and does not contribute to the metric calculation.
validate_args (bool) – bool indicating if input arguments and tensors should be validated for correctness. Set to False for faster computations.
- Return type:
- Returns:
If multidim_average is set to global, the metric returns a scalar value. If multidim_average is set to samplewise, the metric returns an (N,) vector consisting of a scalar value per sample.
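The intended effect of ignore_index can be illustrated with a mask. This sketch shows only the effect stated in the parameter description: positions whose target equals ignore_index do not contribute. It does not show how TorchMetrics handles this internally. The sentinel value `-1` and the data are made up for illustration.

```python
import torch

ignore_index = -1  # illustrative sentinel value
target = torch.tensor([0, 1, -1, 1, 0, -1])
preds = torch.tensor([0, 0, 1, 1, 0, 1])

keep = target != ignore_index  # drop positions whose target equals ignore_index
mismatch = (preds[keep] != target[keep]).float()
print(mismatch.mean())  # tensor(0.2500)
```

Four positions remain after masking and one of them mismatches, so the distance is computed over 4 entries rather than 6.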
- Example (preds is int tensor):
>>> from torch import tensor
>>> from torchmetrics.functional.classification import binary_hamming_distance
>>> target = tensor([0, 1, 0, 1, 0, 1])
>>> preds = tensor([0, 0, 1, 1, 0, 1])
>>> binary_hamming_distance(preds, target)
tensor(0.3333)
- Example (preds is float tensor):
>>> from torch import tensor
>>> from torchmetrics.functional.classification import binary_hamming_distance
>>> target = tensor([0, 1, 0, 1, 0, 1])
>>> preds = tensor([0.11, 0.22, 0.84, 0.73, 0.33, 0.92])
>>> binary_hamming_distance(preds, target)
tensor(0.3333)
- Example (multidim tensors):
>>> from torch import tensor
>>> from torchmetrics.functional.classification import binary_hamming_distance
>>> target = tensor([[[0, 1], [1, 0], [0, 1]], [[1, 1], [0, 0], [1, 0]]])
>>> preds = tensor([[[0.59, 0.91], [0.91, 0.99], [0.63, 0.04]],
...                 [[0.38, 0.04], [0.86, 0.780], [0.45, 0.37]]])
>>> binary_hamming_distance(preds, target, multidim_average='samplewise')
tensor([0.6667, 0.8333])
multiclass_hamming_distance¶
- torchmetrics.functional.classification.multiclass_hamming_distance(preds, target, num_classes, average='macro', top_k=1, multidim_average='global', ignore_index=None, validate_args=True)[source]¶
Compute the average Hamming distance (also known as Hamming loss) for multiclass tasks.
\[\text{Hamming distance} = \frac{1}{N \cdot L} \sum_i^N \sum_l^L 1(y_{il} \neq \hat{y}_{il})\]Where \(y\) is a tensor of target values, \(\hat{y}\) is a tensor of predictions, and \(\bullet_{il}\) refers to the \(l\)-th label of the \(i\)-th sample of that tensor.
Accepts the following input tensors:
preds: (N, ...) (int tensor) or (N, C, ...) (float tensor). If preds is a floating point tensor, we apply torch.argmax along the C dimension to automatically convert probabilities/logits into an int tensor.
target (int tensor): (N, ...).
- Parameters:
preds (Tensor) – Tensor with predictions.
target (Tensor) – Tensor with true labels.
num_classes (int) – Integer specifying the number of classes.
average (Optional[Literal['micro', 'macro', 'weighted', 'none']]) – Defines the reduction that is applied over labels. Should be one of the following:
micro: Sum statistics over all labels.
macro: Calculate statistics for each label and average them.
weighted: Calculate statistics for each label and compute the weighted average using their support.
"none" or None: Calculate the statistic for each label and apply no reduction.
top_k (int) – Number of highest probability or logit score predictions considered to find the correct label. Only works when preds contain probabilities/logits.
multidim_average (Literal['global', 'samplewise']) – Defines how additional dimensions ... should be handled. Should be one of the following:
global: Additional dimensions are flattened along the batch dimension.
samplewise: The statistic is calculated independently for each sample on the N axis. The statistics in this case are calculated over the additional dimensions.
ignore_index (Optional[int]) – Specifies a target value that is ignored and does not contribute to the metric calculation.
validate_args (bool) – bool indicating if input arguments and tensors should be validated for correctness. Set to False for faster computations.
- Returns:
If multidim_average is set to global:
If average='micro'/'macro'/'weighted', the output will be a scalar tensor.
If average=None/'none', the shape will be (C,).
If multidim_average is set to samplewise:
If average='micro'/'macro'/'weighted', the shape will be (N,).
If average=None/'none', the shape will be (N, C).
- Return type:
The returned shape depends on the average and multidim_average arguments.
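The idea behind top_k can be sketched at the sample level. A prediction counts as a hit when the target is among the k highest scores. This is a simplified, micro-style illustration in plain PyTorch. The helper `topk_miss_rate` is hypothetical, and the per-class bookkeeping TorchMetrics performs for macro or weighted averages may differ.

```python
import torch

target = torch.tensor([2, 1, 0, 0])
preds = torch.tensor([[0.16, 0.26, 0.58],
                      [0.22, 0.61, 0.17],
                      [0.71, 0.09, 0.20],
                      [0.05, 0.82, 0.13]])


def topk_miss_rate(preds, target, k):
    """Hypothetical helper: fraction of samples whose target is not among the top-k scores."""
    topk = preds.topk(k, dim=1).indices             # (N, k) class indices
    hit = (topk == target.unsqueeze(1)).any(dim=1)  # (N,) True if target is in the top k
    return 1 - hit.float().mean()


print(topk_miss_rate(preds, target, k=1))  # tensor(0.2500)
print(topk_miss_rate(preds, target, k=3))  # tensor(0.)
```

With k equal to the number of classes, every target is covered, so the distance drops to zero.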
- Example (preds is int tensor):
>>> from torch import tensor
>>> from torchmetrics.functional.classification import multiclass_hamming_distance
>>> target = tensor([2, 1, 0, 0])
>>> preds = tensor([2, 1, 0, 1])
>>> multiclass_hamming_distance(preds, target, num_classes=3)
tensor(0.1667)
>>> multiclass_hamming_distance(preds, target, num_classes=3, average=None)
tensor([0.5000, 0.0000, 0.0000])
- Example (preds is float tensor):
>>> from torch import tensor
>>> from torchmetrics.functional.classification import multiclass_hamming_distance
>>> target = tensor([2, 1, 0, 0])
>>> preds = tensor([[0.16, 0.26, 0.58],
...                 [0.22, 0.61, 0.17],
...                 [0.71, 0.09, 0.20],
...                 [0.05, 0.82, 0.13]])
>>> multiclass_hamming_distance(preds, target, num_classes=3)
tensor(0.1667)
>>> multiclass_hamming_distance(preds, target, num_classes=3, average=None)
tensor([0.5000, 0.0000, 0.0000])
- Example (multidim tensors):
>>> from torch import tensor
>>> from torchmetrics.functional.classification import multiclass_hamming_distance
>>> target = tensor([[[0, 1], [2, 1], [0, 2]], [[1, 1], [2, 0], [1, 2]]])
>>> preds = tensor([[[0, 2], [2, 0], [0, 1]], [[2, 2], [2, 1], [1, 0]]])
>>> multiclass_hamming_distance(preds, target, num_classes=3, multidim_average='samplewise')
tensor([0.5000, 0.7222])
>>> multiclass_hamming_distance(preds, target, num_classes=3, multidim_average='samplewise', average=None)
tensor([[0.0000, 1.0000, 0.5000],
        [1.0000, 0.6667, 0.5000]])
multilabel_hamming_distance¶
- torchmetrics.functional.classification.multilabel_hamming_distance(preds, target, num_labels, threshold=0.5, average='macro', multidim_average='global', ignore_index=None, validate_args=True)[source]¶
Compute the average Hamming distance (also known as Hamming loss) for multilabel tasks.
\[\text{Hamming distance} = \frac{1}{N \cdot L} \sum_i^N \sum_l^L 1(y_{il} \neq \hat{y}_{il})\]Where \(y\) is a tensor of target values, \(\hat{y}\) is a tensor of predictions, and \(\bullet_{il}\) refers to the \(l\)-th label of the \(i\)-th sample of that tensor.
Accepts the following input tensors:
preds (int or float tensor): (N, C, ...). If preds is a floating point tensor with values outside the [0, 1] range, we consider the input to be logits and automatically apply sigmoid per element. Additionally, we convert it to an int tensor by thresholding with the value in threshold.
target (int tensor): (N, C, ...).
- Parameters:
preds (Tensor) – Tensor with predictions.
target (Tensor) – Tensor with true labels.
num_labels (int) – Integer specifying the number of labels.
threshold (float) – Threshold for transforming probability to binary {0,1} predictions.
average (Optional[Literal['micro', 'macro', 'weighted', 'none']]) – Defines the reduction that is applied over labels. Should be one of the following:
micro: Sum statistics over all labels.
macro: Calculate statistics for each label and average them.
weighted: Calculate statistics for each label and compute the weighted average using their support.
"none" or None: Calculate the statistic for each label and apply no reduction.
multidim_average (Literal['global', 'samplewise']) – Defines how additional dimensions ... should be handled. Should be one of the following:
global: Additional dimensions are flattened along the batch dimension.
samplewise: The statistic is calculated independently for each sample on the N axis. The statistics in this case are calculated over the additional dimensions.
ignore_index (Optional[int]) – Specifies a target value that is ignored and does not contribute to the metric calculation.
validate_args (bool) – bool indicating if input arguments and tensors should be validated for correctness. Set to False for faster computations.
- Returns:
If multidim_average is set to global:
If average='micro'/'macro'/'weighted', the output will be a scalar tensor.
If average=None/'none', the shape will be (C,).
If multidim_average is set to samplewise:
If average='micro'/'macro'/'weighted', the shape will be (N,).
If average=None/'none', the shape will be (N, C).
- Return type:
The returned shape depends on the average and multidim_average arguments.
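The (N, C) output for multidim_average='samplewise' with average=None can be reproduced by averaging mismatches over the trailing dimensions only. This minimal sketch in plain PyTorch assumes `> 0.5` thresholding and uses the multidim example data for this function.

```python
import torch

target = torch.tensor([[[0, 1], [1, 0], [0, 1]], [[1, 1], [0, 0], [1, 0]]])  # (N, C, ...) = (2, 3, 2)
probs = torch.tensor([[[0.59, 0.91], [0.91, 0.99], [0.63, 0.04]],
                      [[0.38, 0.04], [0.86, 0.780], [0.45, 0.37]]])
preds = (probs > 0.5).long()

mismatch = (preds != target).float()
per_sample_label = mismatch.flatten(2).mean(dim=2)  # (N, C): samplewise, average=None
per_sample = per_sample_label.mean(dim=1)           # (N,): samplewise, average='macro'
print(per_sample_label)  # tensor([[0.5000, 0.5000, 1.0000], [1.0000, 1.0000, 0.5000]])
print(per_sample)        # tensor([0.6667, 0.8333])
```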
- Example (preds is int tensor):
>>> from torch import tensor
>>> from torchmetrics.functional.classification import multilabel_hamming_distance
>>> target = tensor([[0, 1, 0], [1, 0, 1]])
>>> preds = tensor([[0, 0, 1], [1, 0, 1]])
>>> multilabel_hamming_distance(preds, target, num_labels=3)
tensor(0.3333)
>>> multilabel_hamming_distance(preds, target, num_labels=3, average=None)
tensor([0.0000, 0.5000, 0.5000])
- Example (preds is float tensor):
>>> from torch import tensor
>>> from torchmetrics.functional.classification import multilabel_hamming_distance
>>> target = tensor([[0, 1, 0], [1, 0, 1]])
>>> preds = tensor([[0.11, 0.22, 0.84], [0.73, 0.33, 0.92]])
>>> multilabel_hamming_distance(preds, target, num_labels=3)
tensor(0.3333)
>>> multilabel_hamming_distance(preds, target, num_labels=3, average=None)
tensor([0.0000, 0.5000, 0.5000])
- Example (multidim tensors):
>>> from torch import tensor
>>> from torchmetrics.functional.classification import multilabel_hamming_distance
>>> target = tensor([[[0, 1], [1, 0], [0, 1]], [[1, 1], [0, 0], [1, 0]]])
>>> preds = tensor([[[0.59, 0.91], [0.91, 0.99], [0.63, 0.04]],
...                 [[0.38, 0.04], [0.86, 0.780], [0.45, 0.37]]])
>>> multilabel_hamming_distance(preds, target, num_labels=3, multidim_average='samplewise')
tensor([0.6667, 0.8333])
>>> multilabel_hamming_distance(preds, target, num_labels=3, multidim_average='samplewise', average=None)
tensor([[0.5000, 0.5000, 1.0000],
        [1.0000, 1.0000, 0.5000]])
Hinge Loss¶
Module Interface¶
- class torchmetrics.HingeLoss(**kwargs)[source]¶
Compute the mean Hinge loss typically used for Support Vector Machines (SVMs).
This function is a simple wrapper to get the task-specific versions of this metric, which is done by setting the task argument to either 'binary' or 'multiclass'. See the documentation of BinaryHingeLoss and MulticlassHingeLoss for the specific details of how each argument influences the metric, and for examples.
- Legacy Example:
>>> from torch import tensor
>>> from torchmetrics import HingeLoss
>>> target = tensor([0, 1, 1])
>>> preds = tensor([0.5, 0.7, 0.1])
>>> hinge = HingeLoss(task="binary")
>>> hinge(preds, target)
tensor(0.9000)
>>> target = tensor([0, 1, 2])
>>> preds = tensor([[-1.0, 0.9, 0.2], [0.5, -1.1, 0.8], [2.2, -0.5, 0.3]])
>>> hinge = HingeLoss(task="multiclass", num_classes=3)
>>> hinge(preds, target)
tensor(1.5551)
>>> target = tensor([0, 1, 2])
>>> preds = tensor([[-1.0, 0.9, 0.2], [0.5, -1.1, 0.8], [2.2, -0.5, 0.3]])
>>> hinge = HingeLoss(task="multiclass", num_classes=3, multiclass_mode="one-vs-all")
>>> hinge(preds, target)
tensor([1.3743, 1.1945, 1.2359])
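The `tensor(1.5551)` in the multiclass legacy example above can be reproduced by hand with a Crammer-Singer style margin: the score of the true class minus the best score among the other classes. This sketch rests on two assumptions: that the default multiclass mode is Crammer-Singer, and that out-of-range scores are first normalised with softmax. Both were chosen because together they reproduce 1.5551. This is not the library code.

```python
import torch

target = torch.tensor([0, 1, 2])
logits = torch.tensor([[-1.0, 0.9, 0.2], [0.5, -1.1, 0.8], [2.2, -0.5, 0.3]])

probs = logits.softmax(dim=1)  # assumption: scores are softmax-normalised first
true_score = probs.gather(1, target.unsqueeze(1)).squeeze(1)
# Highest score among the *other* classes: mask out the true class first.
others = probs.scatter(1, target.unsqueeze(1), float("-inf"))
margin = true_score - others.max(dim=1).values
loss = (1 - margin).clamp(min=0).mean()
print(loss)  # tensor(1.5551)
```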
BinaryHingeLoss¶
- class torchmetrics.classification.BinaryHingeLoss(squared=False, ignore_index=None, validate_args=True, **kwargs)[source]¶
Compute the mean Hinge loss typically used for Support Vector Machines (SVMs) for binary tasks.
\[\text{Hinge loss} = \max(0, 1 - y \times \hat{y})\]where \(y \in \{-1, 1\}\) is the target, and \(\hat{y} \in \mathbb{R}\) is the prediction.
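Because targets are given as {0, 1} but the formula uses \(y \in \{-1, 1\}\), a remapping \(y = 2t - 1\) is implied. This minimal sketch in plain PyTorch assumes that remapping and uses inputs already in [0, 1], so no sigmoid is applied. It reproduces both values in the BinaryHingeLoss example.

```python
import torch

preds = torch.tensor([0.25, 0.25, 0.55, 0.75, 0.75])  # already in [0, 1], so no sigmoid
target = torch.tensor([0, 0, 1, 1, 1])

y = 2 * target - 1                     # map {0, 1} -> {-1, 1}
margin = (1 - y * preds).clamp(min=0)  # max(0, 1 - y * y_hat) per element
hinge = margin.mean()                  # tensor(0.6900)
squared_hinge = margin.pow(2).mean()   # tensor(0.6905)
```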
As input to forward and update the metric accepts the following input:
preds (Tensor): A float tensor of shape (N, ...). Preds should be a tensor containing probabilities or logits for each observation. If preds has values outside the [0, 1] range, the input is treated as logits and a sigmoid is applied per element.
target (Tensor): An int tensor of shape (N, ...). Target should be a tensor containing ground truth labels, and therefore only contain {0, 1} values (except if ignore_index is specified). The value 1 always encodes the positive class.
Note
Additional dimensions ... will be flattened into the batch dimension.
As output to forward and compute the metric returns the following output:
bhl (Tensor): A tensor containing the hinge loss.
- Parameters:
squared (bool) – If True, computes the squared hinge loss; otherwise, computes the regular hinge loss.
ignore_index (Optional[int]) – Specifies a target value that is ignored and does not contribute to the metric calculation.
validate_args (bool) – Whether input arguments and tensors should be validated for correctness. Set to False for faster computations.
kwargs (Any) – Additional keyword arguments, see Advanced metric settings for more info.
Example
>>> import torch
>>> from torchmetrics.classification import BinaryHingeLoss
>>> preds = torch.tensor([0.25, 0.25, 0.55, 0.75, 0.75])
>>> target = torch.tensor([0, 0, 1, 1, 1])
>>> bhl = BinaryHingeLoss()
>>> bhl(preds, target)
tensor(0.6900)
>>> bhl = BinaryHingeLoss(squared=True)
>>> bhl(preds, target)
tensor(0.6905)
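To make the formula above concrete, here is a minimal, dependency-free sketch that recomputes the example values with plain Python lists. The helper name binary_hinge_reference is hypothetical and not part of TorchMetrics; it assumes, as described above, that {0, 1} targets are mapped to {-1, 1} and that a sigmoid is applied only when some prediction falls outside [0, 1]. It illustrates the arithmetic and is not the library's implementation.

```python
import math
from typing import Optional, Sequence


def binary_hinge_reference(
    preds: Sequence[float],
    target: Sequence[int],
    squared: bool = False,
    ignore_index: Optional[int] = None,
) -> float:
    """Mean (squared) binary hinge loss. Hypothetical helper, not TorchMetrics code."""
    # Assumption mirroring the docs: any value outside [0, 1] means logits -> sigmoid.
    if any(p < 0.0 or p > 1.0 for p in preds):
        preds = [1.0 / (1.0 + math.exp(-p)) for p in preds]
    losses = []
    for p, t in zip(preds, target):
        if ignore_index is not None and t == ignore_index:
            continue  # ignored targets do not contribute to the mean
        y = 1.0 if t == 1 else -1.0  # map {0, 1} targets to {-1, 1}
        loss = max(0.0, 1.0 - y * p)
        losses.append(loss ** 2 if squared else loss)
    return sum(losses) / len(losses)


preds = [0.25, 0.25, 0.55, 0.75, 0.75]
target = [0, 0, 1, 1, 1]
print(binary_hinge_reference(preds, target))                # approx. 0.69
print(binary_hinge_reference(preds, target, squared=True))  # approx. 0.6905
```

With these inputs the per-sample losses are 1.25, 1.25, 0.45, 0.25 and 0.25, whose mean is 0.69, matching tensor(0.6900) in the example.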
- plot(val=None, ax=None)
Plot a single or multiple values from the metric.
- Parameters:
val (Union[Tensor, Sequence[Tensor], None]) – Either a single result from calling metric.forward or metric.compute or a list of these results. If no value is provided, metric.compute is called automatically and its result is plotted.
ax (Optional[Axes]) – A matplotlib axis object. If provided, the plot is added to that axis.
- Returns:
Figure object and Axes object
- Raises:
ModuleNotFoundError – If matplotlib is not installed
>>> # Example plotting a single value
>>> from torch import rand, randint
>>> from torchmetrics.classification import BinaryHingeLoss
>>> metric = BinaryHingeLoss()
>>> metric.update(rand(10), randint(2, (10,)))
>>> fig_, ax_ = metric.plot()
>>> # Example plotting multiple values
>>> from torch import rand, randint
>>> from torchmetrics.classification import BinaryHingeLoss
>>> metric = BinaryHingeLoss()
>>> values = []
>>> for _ in range(10):
...     values.append(metric(rand(10), randint(2, (10,))))
>>> fig_, ax_ = metric.plot(values)
MulticlassHingeLoss
- class torchmetrics.classification.MulticlassHingeLoss(num_classes, squared=False, multiclass_mode='crammer-singer', ignore_index=None, validate_args=True, **kwargs)
Compute the mean Hinge loss typically used for Support Vector Machines (SVMs) for multiclass tasks.
The metric can be computed in two ways. By default (multiclass_mode='crammer-singer'), the definition by Crammer and Singer is used:
\[\text{Hinge loss} = \max\left(0, 1 - \hat{y}_y + \max_{i \ne y} (\hat{y}_i)\right)\]
where \(y \in \{0, \ldots, \mathrm{C} - 1\}\) is the target class (with \(\mathrm{C}\) the number of classes) and \(\hat{y} \in \mathbb{R}^\mathrm{C}\) is the predicted output per class. Alternatively (multiclass_mode='one-vs-all'), each class is scored against all other classes as a binary problem, and one loss value is returned per class.
As input to forward and update the metric accepts the following input:
preds (Tensor): A float tensor of shape (N, C, ...). Preds should be a tensor containing probabilities or logits for each observation. If preds has values outside the [0, 1] range, the input is treated as logits and a softmax is applied per sample.
target (Tensor): An int tensor of shape (N, ...). Target should be a tensor containing ground truth labels, and therefore only contain values in the [0, num_classes - 1] range (except if ignore_index is specified).
Note
Additional dimensions ... will be flattened into the batch dimension.
As output to forward and compute the metric returns the following output:
mchl (Tensor): A tensor containing the multi-class hinge loss.
- Parameters:
num_classes (int) – Integer specifying the number of classes.
squared (bool) – If True, computes the squared hinge loss; otherwise, computes the regular hinge loss.
multiclass_mode (Literal['crammer-singer', 'one-vs-all']) – Determines how to compute the metric.
ignore_index (Optional[int]) – Specifies a target value that is ignored and does not contribute to the metric calculation.
validate_args (bool) – Whether input arguments and tensors should be validated for correctness. Set to False for faster computations.
kwargs (Any) – Additional keyword arguments, see Advanced metric settings for more info.
Example
>>> import torch
>>> from torchmetrics.classification import MulticlassHingeLoss
>>> preds = torch.tensor([[0.25, 0.20, 0.55],
...                       [0.55, 0.05, 0.40],
...                       [0.10, 0.30, 0.60],
...                       [0.90, 0.05, 0.05]])
>>> target = torch.tensor([0, 1, 2, 0])
>>> mchl = MulticlassHingeLoss(num_classes=3)
>>> mchl(preds, target)
tensor(0.9125)
>>> mchl = MulticlassHingeLoss(num_classes=3, squared=True)
>>> mchl(preds, target)
tensor(1.1131)
>>> mchl = MulticlassHingeLoss(num_classes=3, multiclass_mode='one-vs-all')
>>> mchl(preds, target)
tensor([0.8750, 1.1250, 1.1000])
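The two multiclass_mode options can be compared side by side with a small pure-Python sketch. multiclass_hinge_reference and _softmax are hypothetical helpers written for illustration only; following the description above, the sketch applies a softmax per sample only when some value lies outside [0, 1]. It omits ignore_index and input validation, so treat it as a way to check the arithmetic rather than as the library's code.

```python
import math
from typing import List, Sequence, Union


def _softmax(row: Sequence[float]) -> List[float]:
    m = max(row)  # subtract the max for numerical stability
    exps = [math.exp(v - m) for v in row]
    total = sum(exps)
    return [e / total for e in exps]


def multiclass_hinge_reference(
    preds: Sequence[Sequence[float]],
    target: Sequence[int],
    squared: bool = False,
    multiclass_mode: str = "crammer-singer",
) -> Union[float, List[float]]:
    """Hypothetical illustration of both multiclass hinge variants (not TorchMetrics code)."""
    # Assumption mirroring the docs: any value outside [0, 1] means logits -> softmax per sample.
    if any(v < 0.0 or v > 1.0 for row in preds for v in row):
        preds = [_softmax(row) for row in preds]
    n, num_classes = len(preds), len(preds[0])

    def _loss(margin: float) -> float:
        loss = max(0.0, 1.0 - margin)
        return loss ** 2 if squared else loss

    if multiclass_mode == "crammer-singer":
        # margin = score of the true class minus the best competing score
        total = 0.0
        for row, y in zip(preds, target):
            best_other = max(v for i, v in enumerate(row) if i != y)
            total += _loss(row[y] - best_other)
        return total / n
    # one-vs-all: one binary hinge loss per class, with targets mapped to {-1, 1}
    per_class = []
    for c in range(num_classes):
        losses = [_loss((1.0 if y == c else -1.0) * row[c]) for row, y in zip(preds, target)]
        per_class.append(sum(losses) / n)
    return per_class


preds = [[0.25, 0.20, 0.55], [0.55, 0.05, 0.40], [0.10, 0.30, 0.60], [0.90, 0.05, 0.05]]
target = [0, 1, 2, 0]
print(multiclass_hinge_reference(preds, target))                 # approx. 0.9125
print(multiclass_hinge_reference(preds, target, squared=True))   # approx. 1.1131
print(multiclass_hinge_reference(preds, target, multiclass_mode="one-vs-all"))
```

For the first sample the Crammer-Singer margin is 0.25 - 0.55 = -0.30, giving a loss of 1.30; averaging the four per-sample losses 1.30, 1.50, 0.70 and 0.15 gives 0.9125. Passing the logits from the legacy HingeLoss example instead triggers the softmax branch and yields approximately 1.5551.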
- plot(val=None, ax=None)
Plot a single or multiple values from the metric.
- Parameters:
val (Union[Tensor, Sequence[Tensor], None]) – Either a single result from calling metric.forward or metric.compute or a list of these results. If no value is provided, metric.compute is called automatically and its result is plotted.
ax (Optional[Axes]) – A matplotlib axis object. If provided, the plot is added to that axis.
- Returns:
Figure object and Axes object
- Raises:
ModuleNotFoundError – If matplotlib is not installed
>>> # Example plotting a single value
>>> from torch import randint, randn
>>> from torchmetrics.classification import MulticlassHingeLoss
>>> metric = MulticlassHingeLoss(num_classes=3)
>>> metric.update(randn(20, 3), randint(3, (20,)))
>>> fig_, ax_ = metric.plot()
>>> # Example plotting multiple values
>>> from torch import randint, randn
>>> from torchmetrics.classification import MulticlassHingeLoss
>>> metric = MulticlassHingeLoss(num_classes=3)
>>> values = []
>>> for _ in range(20):
...     values.append(metric(randn(20, 3), randint(3, (20,))))
>>> fig_, ax_ = metric.plot(values)
Functional Interface
- torchmetrics.functional.hinge_loss(preds, target, task, num_classes=None, squared=False, multiclass_mode='crammer-singer', ignore_index=None, validate_args=True)
Compute the mean Hinge loss typically used for Support Vector Machines (SVMs).
This function is a thin wrapper that dispatches to the task-specific version of the metric, selected by setting the task argument to either 'binary' or 'multiclass'. See the documentation of binary_hinge_loss() and multiclass_hinge_loss() for details on how each argument affects the metric and for further examples.
- Return type:
Tensor
- Legacy Example:
>>> from torch import tensor
>>> from torchmetrics.functional import hinge_loss
>>> target = tensor([0, 1, 1])
>>> preds = tensor([0.5, 0.7, 0.1])
>>> hinge_loss(preds, target, task="binary")
tensor(0.9000)
>>> target = tensor([0, 1, 2])
>>> preds = tensor([[-1.0, 0.9, 0.2], [0.5, -1.1, 0.8], [2.2, -0.5, 0.3]])
>>> hinge_loss(preds, target, task="multiclass", num_classes=3)
tensor(1.5551)
>>> target = tensor([0, 1, 2])
>>> preds = tensor([[-1.0, 0.9, 0.2], [0.5, -1.1, 0.8], [2.2, -0.5, 0.3]])
>>> hinge_loss(preds, target, task="multiclass", num_classes=3, multiclass_mode="one-vs-all")
tensor([1.3743, 1.1945, 1.2359])
binary_hinge_loss
- torchmetrics.functional.classification.binary_hinge_loss(preds, target, squared=False, ignore_index=None, validate_args=False)
Compute the mean Hinge loss typically used for Support Vector Machines (SVMs) for binary tasks.
\[\text{Hinge loss} = \max(0, 1 - y \times \hat{y})\]
where \(y \in \{-1, 1\}\) is the target and \(\hat{y} \in \mathbb{R}\) is the prediction. Targets are passed in the {0, 1} encoding described below, with 0 mapped to \(y = -1\) and 1 mapped to \(y = 1\).
Accepts the following input tensors:
preds (float tensor): (N, ...). Preds should be a tensor containing probabilities or logits for each observation. If preds has values outside the [0, 1] range, the input is treated as logits and a sigmoid is applied per element.
target (int tensor): (N, ...). Target should be a tensor containing ground truth labels, and therefore only contain {0, 1} values (except if ignore_index is specified). The value 1 always encodes the positive class.
Additional dimensions ... will be flattened into the batch dimension.
- Parameters:
preds (Tensor) – Tensor with predictions.
target (Tensor) – Tensor with true labels.
squared (bool) – If True, computes the squared hinge loss; otherwise, computes the regular hinge loss.
ignore_index (Optional[int]) – Specifies a target value that is ignored and does not contribute to the metric calculation.
validate_args (bool) – Whether input arguments and tensors should be validated for correctness. Set to False for faster computations.
- Return type:
Tensor
Example
>>> from torch import tensor
>>> from torchmetrics.functional.classification import binary_hinge_loss
>>> preds = tensor([0.25, 0.25, 0.55, 0.75, 0.75])
>>> target = tensor([0, 0, 1, 1, 1])
>>> binary_hinge_loss(preds, target)
tensor(0.6900)
>>> binary_hinge_loss(preds, target, squared=True)
tensor(0.6905)
multiclass_hinge_loss
- torchmetrics.functional.classification.multiclass_hinge_loss(preds, target, num_classes, squared=False, multiclass_mode='crammer-singer', ignore_index=None, validate_args=False)
Compute the mean Hinge loss typically used for Support Vector Machines (SVMs) for multiclass tasks.
The metric can be computed in two ways. By default (multiclass_mode='crammer-singer'), the definition by Crammer and Singer is used:
\[\text{Hinge loss} = \max\left(0, 1 - \hat{y}_y + \max_{i \ne y} (\hat{y}_i)\right)\]
where \(y \in \{0, \ldots, \mathrm{C} - 1\}\) is the target class (with \(\mathrm{C}\) the number of classes) and \(\hat{y} \in \mathbb{R}^\mathrm{C}\) is the predicted output per class. Alternatively (multiclass_mode='one-vs-all'), each class is scored against all other classes as a binary problem, and one loss value is returned per class.
Accepts the following input tensors:
preds (float tensor): (N, C, ...). Preds should be a tensor containing probabilities or logits for each observation. If preds has values outside the [0, 1] range, the input is treated as logits and a softmax is applied per sample.
target (int tensor): (N, ...). Target should be a tensor containing ground truth labels, and therefore only contain values in the [0, num_classes - 1] range (except if ignore_index is specified).
Additional dimensions ... will be flattened into the batch dimension.
- Parameters:
preds (Tensor) – Tensor with predictions.
target (Tensor) – Tensor with true labels.
num_classes (int) – Integer specifying the number of classes.
squared (bool) – If True, computes the squared hinge loss; otherwise, computes the regular hinge loss.
multiclass_mode (Literal['crammer-singer', 'one-vs-all']) – Determines how to compute the metric.
ignore_index (Optional[int]) – Specifies a target value that is ignored and does not contribute to the metric calculation.
validate_args (bool) – Whether input arguments and tensors should be validated for correctness. Set to False for faster computations.
- Return type:
Tensor
Example
>>> from torch import tensor
>>> from torchmetrics.functional.classification import multiclass_hinge_loss
>>> preds = tensor([[0.25, 0.20, 0.55],
...                 [0.55, 0.05, 0.40],
...                 [0.10, 0.30, 0.60],
...                 [0.90, 0.05, 0.05]])
>>> target = tensor([0, 1, 2, 0])
>>> multiclass_hinge_loss(preds, target, num_classes=3)
tensor(0.9125)
>>> multiclass_hinge_loss(preds, target, num_classes=3, squared=True)
tensor(1.1131)
>>> multiclass_hinge_loss(preds, target, num_classes=3, multiclass_mode='one-vs-all')
tensor([0.8750, 1.1250, 1.1000])
Jaccard Index
Module Interface
- class torchmetrics.JaccardIndex(**kwargs)
Calculate the Jaccard index.
The Jaccard index (also known as the intersection over union or Jaccard similarity coefficient) is a statistic that can be used to measure the similarity and diversity of sample sets. It is defined as the size of the intersection divided by the size of the union of the sample sets:
\[J(A,B) = \frac{|A\cap B|}{|A\cup B|}\]
This class is a thin wrapper that returns the task-specific version of the metric, selected by setting the task argument to either 'binary', 'multiclass' or 'multilabel'. See the documentation of BinaryJaccardIndex, MulticlassJaccardIndex and MultilabelJaccardIndex for details on how each argument affects the metric and for further examples.
- Legacy Example:
>>> from torch import randint
>>> from torchmetrics import JaccardIndex
>>> target = randint(0, 2, (10, 25, 25))
>>> pred = target.clone()
>>> pred[2:5, 7:13, 9:15] = 1 - pred[2:5, 7:13, 9:15]
>>> jaccard = JaccardIndex(task="multiclass", num_classes=2)
>>> jaccard(pred, target)
tensor(0.9660)
BinaryJaccardIndex
- class torchmetrics.classification.BinaryJaccardIndex(threshold=0.5, ignore_index=None, validate_args=True, **kwargs)
Calculate the Jaccard index for binary tasks.
The Jaccard index (also known as the intersection over union or Jaccard similarity coefficient) is a statistic that can be used to measure the similarity and diversity of sample sets. It is defined as the size of the intersection divided by the size of the union of the sample sets:
\[J(A,B) = \frac{|A\cap B|}{|A\cup B|}\]
As input to forward and update the metric accepts the following input:
preds (Tensor): An int or float tensor of shape (N, ...). If preds is a floating point tensor with values outside the [0, 1] range, the input is treated as logits and a sigmoid is applied per element. Additionally, floating point preds are converted to an int tensor by thresholding with the value in threshold.
target (Tensor): An int tensor of shape (N, ...).
Note
Additional dimensions ... will be flattened into the batch dimension.
As output to forward and compute the metric returns the following output:
bji (Tensor): A tensor containing the Binary Jaccard Index.
- Parameters:
threshold (float) – Threshold for transforming probabilities into binary (0, 1) predictions.
ignore_index (Optional[int]) – Specifies a target value that is ignored and does not contribute to the metric calculation.
validate_args (bool) – Whether input arguments and tensors should be validated for correctness. Set to False for faster computations.
kwargs (Any) – Additional keyword arguments, see Advanced metric settings for more info.
- Example (preds is int tensor):
>>> from torch import tensor
>>> from torchmetrics.classification import BinaryJaccardIndex
>>> target = tensor([1, 1, 0, 0])
>>> preds = tensor([0, 1, 0, 0])
>>> metric = BinaryJaccardIndex()
>>> metric(preds, target)
tensor(0.5000)
- Example (preds is float tensor):
>>> from torch import tensor
>>> from torchmetrics.classification import BinaryJaccardIndex
>>> target = tensor([1, 1, 0, 0])
>>> preds = tensor([0.35, 0.85, 0.48, 0.01])
>>> metric = BinaryJaccardIndex()
>>> metric(preds, target)
tensor(0.5000)
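For binary labels, the set definition above reduces to TP / (TP + FP + FN): the intersection is the set of true positives, and the union adds the false positives and false negatives. The sketch below spells this out in plain Python. binary_jaccard_reference is a hypothetical name, returning 0.0 for an empty union is an assumption of this sketch, and the logits check mirrors the description above rather than the library's code.

```python
import math
from typing import Optional, Sequence


def binary_jaccard_reference(
    preds: Sequence[float],
    target: Sequence[int],
    threshold: float = 0.5,
    ignore_index: Optional[int] = None,
) -> float:
    """Hypothetical illustration of |A ∩ B| / |A ∪ B| for binary labels."""
    # Values outside [0, 1] are treated as logits and squashed with a sigmoid.
    if any(p < 0.0 or p > 1.0 for p in preds):
        preds = [1.0 / (1.0 + math.exp(-p)) for p in preds]
    tp = fp = fn = 0
    for p, t in zip(preds, target):
        if ignore_index is not None and t == ignore_index:
            continue  # ignored positions count toward neither set
        pred_pos = p > threshold  # for 0 <= threshold < 1, int 0/1 preds keep their value
        true_pos = t == 1
        if pred_pos and true_pos:
            tp += 1  # in both sets (intersection)
        elif pred_pos:
            fp += 1  # predicted positive only
        elif true_pos:
            fn += 1  # actual positive only
    union = tp + fp + fn
    return tp / union if union else 0.0  # assumption: empty union -> 0.0


target = [1, 1, 0, 0]
print(binary_jaccard_reference([0, 1, 0, 0], target))              # 0.5
print(binary_jaccard_reference([0.35, 0.85, 0.48, 0.01], target))  # 0.5
```

Both examples have one true positive and one false negative, so the intersection has size 1 and the union size 2.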
- plot(val=None, ax=None)
Plot a single or multiple values from the metric.
- Parameters:
val (Union[Tensor, Sequence[Tensor], None]) – Either a single result from calling metric.forward or metric.compute or a list of these results. If no value is provided, metric.compute is called automatically and its result is plotted.
ax (Optional[Axes]) – A matplotlib axis object. If provided, the plot is added to that axis.
- Returns:
Figure object and Axes object
- Raises:
ModuleNotFoundError – If matplotlib is not installed
>>> # Example plotting a single value
>>> from torch import rand, randint
>>> from torchmetrics.classification import BinaryJaccardIndex
>>> metric = BinaryJaccardIndex()
>>> metric.update(rand(10), randint(2, (10,)))
>>> fig_, ax_ = metric.plot()
>>> # Example plotting multiple values
>>> from torch import rand, randint
>>> from torchmetrics.classification import BinaryJaccardIndex
>>> metric = BinaryJaccardIndex()
>>> values = []
>>> for _ in range(10):
...     values.append(metric(rand(10), randint(2, (10,))))
>>> fig_, ax_ = metric.plot(values)
MulticlassJaccardIndex
- class torchmetrics.classification.MulticlassJaccardIndex(num_classes, average='macro', ignore_index=None, validate_args=True, **kwargs)
Calculate the Jaccard index for multiclass tasks.
The Jaccard index (also known as the intersection over union or Jaccard similarity coefficient) is a statistic that can be used to measure the similarity and diversity of sample sets. It is defined as the size of the intersection divided by the size of the union of the sample sets:
\[J(A,B) = \frac{|A\cap B|}{|A\cup B|}\]
As input to forward and update the metric accepts the following input:
preds (Tensor): An int tensor of shape (N, ...) or a float tensor of shape (N, C, ...). If preds is a floating point tensor, torch.argmax is applied along the C dimension to automatically convert probabilities/logits into an int tensor.
target (Tensor): An int tensor of shape (N, ...).
Note
Additional dimensions ... will be flattened into the batch dimension.
As output to forward and compute the metric returns the following output:
mcji (Tensor): A tensor containing the Multi-class Jaccard Index.
- Parameters:
num_classes (int) – Integer specifying the number of classes.
ignore_index (Optional[int]) – Specifies a target value that is ignored and does not contribute to the metric calculation.
average (Optional[Literal['micro', 'macro', 'weighted', 'none']]) – Defines the reduction that is applied over labels. Should be one of the following:
micro: Sum statistics over all labels
macro: Calculate statistics for each label and average them
weighted: Calculate statistics for each label and compute the weighted average using their support
"none" or None: Calculate the statistic for each label and apply no reduction
validate_args (bool) – Whether input arguments and tensors should be validated for correctness. Set to False for faster computations.
kwargs (Any) – Additional keyword arguments, see Advanced metric settings for more info.
- Example (pred is integer tensor):
>>> from torch import tensor
>>> from torchmetrics.classification import MulticlassJaccardIndex
>>> target = tensor([2, 1, 0, 0])
>>> preds = tensor([2, 1, 0, 1])
>>> metric = MulticlassJaccardIndex(num_classes=3)
>>> metric(preds, target)
tensor(0.6667)
- Example (pred is float tensor):
>>> from torch import tensor
>>> from torchmetrics.classification import MulticlassJaccardIndex
>>> target = tensor([2, 1, 0, 0])
>>> preds = tensor([[0.16, 0.26, 0.58],
...                 [0.22, 0.61, 0.17],
...                 [0.71, 0.09, 0.20],
...                 [0.05, 0.82, 0.13]])
>>> metric = MulticlassJaccardIndex(num_classes=3)
>>> metric(preds, target)
tensor(0.6667)
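Per class, the multiclass Jaccard index is TP / (TP + FP + FN), and the average argument decides how the per-class counts or scores are combined. The following pure-Python sketch (the helpers argmax and multiclass_jaccard_reference are hypothetical) computes all four reductions for the integer example above. It is a simplification: it ignores ignore_index and does not special-case classes that appear in neither preds nor target, which the library may treat differently.

```python
from typing import List, Optional, Sequence, Union


def argmax(row: Sequence[float]) -> int:
    """Index of the largest score, analogous to torch.argmax along the class dimension."""
    return max(range(len(row)), key=lambda i: row[i])


def multiclass_jaccard_reference(
    preds: Sequence[int],
    target: Sequence[int],
    num_classes: int,
    average: Optional[str] = "macro",
) -> Union[float, List[float]]:
    """Hypothetical sketch of per-class IoU with micro/macro/weighted/none reduction."""
    tp = [0] * num_classes
    fp = [0] * num_classes
    fn = [0] * num_classes
    for p, t in zip(preds, target):
        if p == t:
            tp[t] += 1
        else:
            fp[p] += 1  # class p was predicted but is wrong
            fn[t] += 1  # true class t was missed
    union = [tp[c] + fp[c] + fn[c] for c in range(num_classes)]
    if average == "micro":
        # pool the counts of all classes before dividing
        return sum(tp) / sum(union)
    iou = [tp[c] / union[c] if union[c] else 0.0 for c in range(num_classes)]
    if average is None or average == "none":
        return iou
    if average == "weighted":
        support = [tp[c] + fn[c] for c in range(num_classes)]  # target count per class
        return sum(s * v for s, v in zip(support, iou)) / sum(support)
    return sum(iou) / num_classes  # "macro": unweighted mean over classes


target = [2, 1, 0, 0]
preds = [2, 1, 0, 1]
print(multiclass_jaccard_reference(preds, target, 3, average="none"))  # [0.5, 0.5, 1.0]
print(multiclass_jaccard_reference(preds, target, 3))                  # approx. 0.6667
scores = [[0.16, 0.26, 0.58], [0.22, 0.61, 0.17], [0.71, 0.09, 0.20], [0.05, 0.82, 0.13]]
print([argmax(row) for row in scores])                                 # [2, 1, 0, 1]
```

The per-class scores are 0.5, 0.5 and 1.0, so the macro average is 2/3 (about 0.6667, as in the doctest); under this sketch's definitions, pooling the counts for 'micro' gives 3/5 = 0.6 instead.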
- plot(val=None, ax=None)
Plot a single or multiple values from the metric.
- Parameters:
val (Union[Tensor, Sequence[Tensor], None]) – Either a single result from calling metric.forward or metric.compute or a list of these results. If no value is provided, metric.compute is called automatically and its result is plotted.
ax (Optional[Axes]) – A matplotlib axis object. If provided, the plot is added to that axis.
- Returns:
Figure object and Axes object
- Raises:
ModuleNotFoundError – If matplotlib is not installed
>>> # Example plotting a single value per class
>>> from torch import randint
>>> from torchmetrics.classification import MulticlassJaccardIndex
>>> metric = MulticlassJaccardIndex(num_classes=3, average=None)
>>> metric.update(randint(3, (20,)), randint(3, (20,)))
>>> fig_, ax_ = metric.plot()
>>> # Example plotting multiple values per class
>>> from torch import randint
>>> from torchmetrics.classification import MulticlassJaccardIndex
>>> metric = MulticlassJaccardIndex(num_classes=3, average=None)
>>> values = []
>>> for _ in range(20):
...     values.append(metric(randint(3, (20,)), randint(3, (20,))))
>>> fig_, ax_ = metric.plot(values)
MultilabelJaccardIndex
- class torchmetrics.classification.MultilabelJaccardIndex(num_labels, threshold=0.5, average='macro', ignore_index=None, validate_args=True, **kwargs)
Calculate the Jaccard index for multilabel tasks.
The Jaccard index (also known as the intersection over union or Jaccard similarity coefficient) is a statistic that can be used to measure the similarity and diversity of sample sets. It is defined as the size of the intersection divided by the size of the union of the sample sets:
\[J(A,B) = \frac{|A\cap B|}{|A\cup B|}\]
As input to forward and update the metric accepts the following input:
preds (Tensor): An int or float tensor of shape (N, C, ...). If preds is a floating point tensor with values outside the [0, 1] range, the input is treated as logits and a sigmoid is applied per element. Additionally, floating point preds are converted to an int tensor by thresholding with the value in threshold.
target (Tensor): An int tensor of shape (N, C, ...).
Note
Additional dimensions ... will be flattened into the batch dimension.
As output to forward and compute the metric returns the following output:
mlji (Tensor): A tensor containing the Multi-label Jaccard Index.
- Parameters:
num_labels (int) – Integer specifying the number of labels.
threshold (float) – Threshold for transforming probabilities into binary (0, 1) predictions.
ignore_index (Optional[int]) – Specifies a target value that is ignored and does not contribute to the metric calculation.
average (Optional[Literal['micro', 'macro', 'weighted', 'none']]) – Defines the reduction that is applied over labels. Should be one of the following:
micro: Sum statistics over all labels
macro: Calculate statistics for each label and average them
weighted: Calculate statistics for each label and compute the weighted average using their support
"none" or None: Calculate the statistic for each label and apply no reduction
validate_args (bool) – Whether input arguments and tensors should be validated for correctness. Set to False for faster computations.
kwargs (Any) – Additional keyword arguments, see Advanced metric settings for more info.
- Example (preds is int tensor):
>>> from torch import tensor
>>> from torchmetrics.classification import MultilabelJaccardIndex
>>> target = tensor([[0, 1, 0], [1, 0, 1]])
>>> preds = tensor([[0, 0, 1], [1, 0, 1]])
>>> metric = MultilabelJaccardIndex(num_labels=3)
>>> metric(preds, target)
tensor(0.5000)
- Example (preds is float tensor):
>>> from torch import tensor
>>> from torchmetrics.classification import MultilabelJaccardIndex
>>> target = tensor([[0, 1, 0], [1, 0, 1]])
>>> preds = tensor([[0.11, 0.22, 0.84], [0.73, 0.33, 0.92]])
>>> metric = MultilabelJaccardIndex(num_labels=3)
>>> metric(preds, target)
tensor(0.5000)
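In the multilabel case each label column is its own binary problem. The hypothetical helper below thresholds the predictions, computes one Jaccard index per label, and averages them, which is what average='macro' describes. It is a plain-Python sketch for checking the example by hand, not the library's implementation, and it skips ignore_index handling.

```python
from typing import List, Sequence


def multilabel_jaccard_per_label(
    preds: Sequence[Sequence[float]],
    target: Sequence[Sequence[int]],
    threshold: float = 0.5,
) -> List[float]:
    """Hypothetical sketch: one IoU per label column after thresholding."""
    num_labels = len(target[0])
    scores = []
    for j in range(num_labels):
        tp = fp = fn = 0
        for p_row, t_row in zip(preds, target):
            pred_pos = p_row[j] > threshold
            true_pos = t_row[j] == 1
            if pred_pos and true_pos:
                tp += 1
            elif pred_pos:
                fp += 1
            elif true_pos:
                fn += 1
        union = tp + fp + fn
        scores.append(tp / union if union else 0.0)  # assumption: empty union -> 0.0
    return scores


target = [[0, 1, 0], [1, 0, 1]]
preds = [[0.11, 0.22, 0.84], [0.73, 0.33, 0.92]]
per_label = multilabel_jaccard_per_label(preds, target)
print(per_label)                        # [1.0, 0.0, 0.5]
print(sum(per_label) / len(per_label))  # 0.5 (macro average)
```

Label 0 is predicted perfectly, label 1 is missed entirely, and label 2 has one true positive plus one false positive, so the macro average is (1.0 + 0.0 + 0.5) / 3 = 0.5.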
- plot(val=None, ax=None)
Plot a single or multiple values from the metric.
- Parameters:
val (Union[Tensor, Sequence[Tensor], None]) – Either a single result from calling metric.forward or metric.compute or a list of these results. If no value is provided, metric.compute is called automatically and its result is plotted.
ax (Optional[Axes]) – A matplotlib axis object. If provided, the plot is added to that axis.
- Returns:
Figure object and Axes object
- Raises:
ModuleNotFoundError – If matplotlib is not installed
>>> # Example plotting a single value
>>> from torch import randint
>>> from torchmetrics.classification import MultilabelJaccardIndex
>>> metric = MultilabelJaccardIndex(num_labels=3)
>>> metric.update(randint(2, (20, 3)), randint(2, (20, 3)))
>>> fig_, ax_ = metric.plot()
>>> # Example plotting multiple values
>>> from torch import randint
>>> from torchmetrics.classification import MultilabelJaccardIndex
>>> metric = MultilabelJaccardIndex(num_labels=3)
>>> values = []
>>> for _ in range(10):
...     values.append(metric(randint(2, (20, 3)), randint(2, (20, 3))))
>>> fig_, ax_ = metric.plot(values)
Functional Interface
jaccard_index
- torchmetrics.functional.jaccard_index(preds, target, task, threshold=0.5, num_classes=None, num_labels=None, average='macro', ignore_index=None, validate_args=True)
Calculate the Jaccard index.
The Jaccard index (also known as the intersection over union or Jaccard similarity coefficient) is a statistic that can be used to measure the similarity and diversity of sample sets. It is defined as the size of the intersection divided by the size of the union of the sample sets:
\[J(A,B) = \frac{|A\cap B|}{|A\cup B|}\]
This function is a thin wrapper that dispatches to the task-specific version of the metric, selected by setting the task argument to either 'binary', 'multiclass' or 'multilabel'. See the documentation of binary_jaccard_index(), multiclass_jaccard_index() and multilabel_jaccard_index() for details on how each argument affects the metric and for further examples.
- Return type:
Tensor
- Legacy Example:
>>> from torch import randint
>>> from torchmetrics.functional import jaccard_index
>>> target = randint(0, 2, (10, 25, 25))
>>> pred = target.clone()
>>> pred[2:5, 7:13, 9:15] = 1 - pred[2:5, 7:13, 9:15]
>>> jaccard_index(pred, target, task="multiclass", num_classes=2)
tensor(0.9660)
binary_jaccard_index
- torchmetrics.functional.classification.binary_jaccard_index(preds, target, threshold=0.5, ignore_index=None, validate_args=True)
Calculate the Jaccard index for binary tasks.
The Jaccard index (also known as the intersection over union or Jaccard similarity coefficient) is a statistic that can be used to measure the similarity and diversity of sample sets. It is defined as the size of the intersection divided by the size of the union of the sample sets:
\[J(A,B) = \frac{|A\cap B|}{|A\cup B|}\]
Accepts the following input tensors:
preds (int or float tensor): (N, ...). If preds is a floating point tensor with values outside the [0, 1] range, the input is treated as logits and a sigmoid is applied per element. Additionally, floating point preds are converted to an int tensor by thresholding with the value in threshold.
target (int tensor): (N, ...)
Additional dimensions ... will be flattened into the batch dimension.
- Parameters:
preds (Tensor) – Tensor with predictions.
target (Tensor) – Tensor with true labels.
threshold (float) – Threshold for transforming probabilities into binary (0, 1) predictions.
ignore_index (Optional[int]) – Specifies a target value that is ignored and does not contribute to the metric calculation.
validate_args (bool) – Whether input arguments and tensors should be validated for correctness. Set to False for faster computations.
- Return type:
Tensor
- Example (preds is int tensor):
>>> from torch import tensor
>>> from torchmetrics.functional.classification import binary_jaccard_index
>>> target = tensor([1, 1, 0, 0])
>>> preds = tensor([0, 1, 0, 0])
>>> binary_jaccard_index(preds, target)
tensor(0.5000)
- Example (preds is float tensor):
>>> from torch import tensor
>>> from torchmetrics.functional.classification import binary_jaccard_index
>>> target = tensor([1, 1, 0, 0])
>>> preds = tensor([0.35, 0.85, 0.48, 0.01])
>>> binary_jaccard_index(preds, target)
tensor(0.5000)
multiclass_jaccard_index
- torchmetrics.functional.classification.multiclass_jaccard_index(preds, target, num_classes, average='macro', ignore_index=None, validate_args=True)
Calculate the Jaccard index for multiclass tasks.
The Jaccard index (also known as the intersection over union or Jaccard similarity coefficient) is a statistic that can be used to measure the similarity and diversity of sample sets. It is defined as the size of the intersection divided by the size of the union of the sample sets:
\[J(A,B) = \frac{|A\cap B|}{|A\cup B|}\]
Accepts the following input tensors:
preds: (N, ...) (int tensor) or (N, C, ...) (float tensor). If preds is a floating point tensor, torch.argmax is applied along the C dimension to automatically convert probabilities/logits into an int tensor.
target (int tensor): (N, ...)
Additional dimensions ... will be flattened into the batch dimension.
- Parameters:
preds (Tensor) – Tensor with predictions.
target (Tensor) – Tensor with true labels.
num_classes (int) – Integer specifying the number of classes.
average (Optional[Literal['micro', 'macro', 'weighted', 'none']]) – Defines the reduction that is applied over labels. Should be one of the following:
micro: Sum statistics over all labels
macro: Calculate statistics for each label and average them
weighted: Calculate statistics for each label and compute the weighted average using their support
"none" or None: Calculate the statistic for each label and apply no reduction
ignore_index (Optional[int]) – Specifies a target value that is ignored and does not contribute to the metric calculation.
validate_args (bool) – Whether input arguments and tensors should be validated for correctness. Set to False for faster computations.
- Return type:
Tensor
- Example (pred is integer tensor):
>>> from torch import tensor
>>> from torchmetrics.functional.classification import multiclass_jaccard_index
>>> target = tensor([2, 1, 0, 0])
>>> preds = tensor([2, 1, 0, 1])
>>> multiclass_jaccard_index(preds, target, num_classes=3)
tensor(0.6667)
- Example (pred is float tensor):
>>> from torch import tensor
>>> from torchmetrics.functional.classification import multiclass_jaccard_index
>>> target = tensor([2, 1, 0, 0])
>>> preds = tensor([[0.16, 0.26, 0.58],
...                 [0.22, 0.61, 0.17],
...                 [0.71, 0.09, 0.20],
...                 [0.05, 0.82, 0.13]])
>>> multiclass_jaccard_index(preds, target, num_classes=3)
tensor(0.6667)
multilabel_jaccard_index
- torchmetrics.functional.classification.multilabel_jaccard_index(preds, target, num_labels, threshold=0.5, average='macro', ignore_index=None, validate_args=True)
Calculate the Jaccard index for multilabel tasks.
The Jaccard index (also known as the intersection over union or Jaccard similarity coefficient) is a statistic that can be used to measure the similarity and diversity of sample sets. It is defined as the size of the intersection divided by the size of the union of the sample sets:
\[J(A,B) = \frac{|A\cap B|}{|A\cup B|}\]
Accepts the following input tensors:
preds (int or float tensor): (N, C, ...). If preds is a floating point tensor with values outside the [0, 1] range, the input is treated as logits and a sigmoid is applied per element. Additionally, floating point preds are converted to an int tensor by thresholding with the value in threshold.
target (int tensor): (N, C, ...)
Additional dimensions ... will be flattened into the batch dimension.
- Parameters:
preds (Tensor) – Tensor with predictions.
target (Tensor) – Tensor with true labels.
num_labels (int) – Integer specifying the number of labels.
threshold (float) – Threshold for transforming probabilities into binary (0, 1) predictions.
average (Optional[Literal['micro', 'macro', 'weighted', 'none']]) – Defines the reduction that is applied over labels. Should be one of the following:
micro: Sum statistics over all labels
macro: Calculate statistics for each label and average them
weighted: Calculate statistics for each label and compute the weighted average using their support
"none" or None: Calculate the statistic for each label and apply no reduction
ignore_index (Optional[int]) – Specifies a target value that is ignored and does not contribute to the metric calculation.
validate_args (bool) – Whether input arguments and tensors should be validated for correctness. Set to False for faster computations.
- Return type:
Tensor
- Example (preds is int tensor):
>>> from torch import tensor
>>> from torchmetrics.functional.classification import multilabel_jaccard_index
>>> target = tensor([[0, 1, 0], [1, 0, 1]])
>>> preds = tensor([[0, 0, 1], [1, 0, 1]])
>>> multilabel_jaccard_index(preds, target, num_labels=3)
tensor(0.5000)
- Example (preds is float tensor):
>>> from torch import tensor
>>> from torchmetrics.functional.classification import multilabel_jaccard_index
>>> target = tensor([[0, 1, 0], [1, 0, 1]])
>>> preds = tensor([[0.11, 0.22, 0.84], [0.73, 0.33, 0.92]])
>>> multilabel_jaccard_index(preds, target, num_labels=3)
tensor(0.5000)
Label Ranking Average Precision
Module Interface
- class torchmetrics.classification.MultilabelRankingAveragePrecision(num_labels, ignore_index=None, validate_args=True, **kwargs)
Compute label ranking average precision score for multilabel data [1].
For each sample and each of its ground truth labels, the score takes the ratio of true labels to all labels ranked at or above that label (i.e. with a score greater than or equal to its score); these ratios are averaged over the ground truth labels and then over the samples. The best score is 1.
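The definition can be written out directly. The sketch below is a pure-Python reading of it, not the TorchMetrics code: lrap_sketch is a hypothetical name, inputs are nested lists, tied scores share the worse (larger) rank, and a sample with no relevant labels or with every label relevant scores 1.0, which is the scikit-learn convention and an assumption here. ignore_index is not handled.

```python
def lrap_sketch(scores, relevance):
    """Hypothetical label ranking average precision on nested lists of shape (N, C)."""
    total = 0.0
    for s, r in zip(scores, relevance):
        relevant = [j for j, y in enumerate(r) if y == 1]
        # Assumption (scikit-learn convention): degenerate samples score 1.0.
        if len(relevant) == 0 or len(relevant) == len(r):
            total += 1.0
            continue
        ratios = []
        for j in relevant:
            # Rank of label j: how many labels score at least as high (ties share the worse rank).
            rank = sum(1 for v in s if v >= s[j])
            # How many relevant labels score at least as high as label j.
            hits = sum(1 for k in relevant if s[k] >= s[j])
            ratios.append(hits / rank)
        total += sum(ratios) / len(ratios)
    return total / len(scores)
```

For scores [[0.75, 0.5, 1.0], [1.0, 0.2, 0.1]] with relevant labels [[1, 0, 0], [0, 0, 1]], the two samples score 1/2 and 1/3, giving 5/12 ≈ 0.4167.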
As input to forward and update the metric accepts the following input:
preds (Tensor): A float tensor of shape (N, C, ...). Preds should be a tensor containing probabilities or logits for each observation. If preds has values outside the [0,1] range, we consider the input to be logits and automatically apply sigmoid per element.
target (Tensor): An int tensor of shape (N, C, ...). Target should be a tensor containing ground truth labels, and therefore only contain {0,1} values (except if ignore_index is specified).
Note
Additional dimensions ... will be flattened into the batch dimension.
As output to forward and compute the metric returns the following output:
mlrap (Tensor): A tensor containing the multilabel ranking average precision.
- Parameters:
num_labels (int) – Integer specifying the number of labels
ignore_index (Optional[int]) – Specifies a target value that is ignored and does not contribute to the metric calculation
validate_args (bool) – bool indicating if input arguments and tensors should be validated for correctness. Set to False for faster computations.
Example
>>> import torch
>>> from torchmetrics.classification import MultilabelRankingAveragePrecision
>>> _ = torch.manual_seed(42)
>>> preds = torch.rand(10, 5)
>>> target = torch.randint(2, (10, 5))
>>> mlrap = MultilabelRankingAveragePrecision(num_labels=5)
>>> mlrap(preds, target)
tensor(0.7744)
- plot(val=None, ax=None)
Plot a single or multiple values from the metric.
- Parameters:
val (Union[Tensor, Sequence[Tensor], None]) – Either a single result from calling metric.forward or metric.compute or a list of these results. If no value is provided, metric.compute is called automatically and that result is plotted.
ax (Optional[Axes]) – A matplotlib axis object. If provided, the plot is added to that axis.
- Return type:
- Returns:
Figure object and Axes object
- Raises:
ModuleNotFoundError – If matplotlib is not installed
>>> from torch import rand, randint
>>> # Example plotting a single value
>>> from torchmetrics.classification import MultilabelRankingAveragePrecision
>>> metric = MultilabelRankingAveragePrecision(num_labels=3)
>>> metric.update(rand(20, 3), randint(2, (20, 3)))
>>> fig_, ax_ = metric.plot()
>>> from torch import rand, randint
>>> # Example plotting multiple values
>>> from torchmetrics.classification import MultilabelRankingAveragePrecision
>>> metric = MultilabelRankingAveragePrecision(num_labels=3)
>>> values = []
>>> for _ in range(10):
...     values.append(metric(rand(20, 3), randint(2, (20, 3))))
>>> fig_, ax_ = metric.plot(values)
Functional Interface
- torchmetrics.functional.classification.multilabel_ranking_average_precision(preds, target, num_labels, ignore_index=None, validate_args=True)
Compute label ranking average precision score for multilabel data [1].
For each sample and each of its ground truth labels, the score takes the ratio of true labels to all labels ranked at or above that label (i.e. with a score greater than or equal to its score); these ratios are averaged over the ground truth labels and then over the samples. The best score is 1.
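Because the score is a plain mean over samples, a module-style metric only needs a running sum and a count to accumulate it across batches, and the batched result equals the one computed on all samples at once. The sketch below illustrates that idea with a hypothetical RunningLRAP class whose update/compute methods echo the module interface; the state it keeps, and the degenerate-sample convention borrowed from scikit-learn, are assumptions rather than a description of TorchMetrics internals.

```python
def lrap_one_sample(scores, relevance):
    """Hypothetical LRAP score of a single sample."""
    relevant = [j for j, y in enumerate(relevance) if y == 1]
    if len(relevant) in (0, len(relevance)):
        return 1.0  # assumption: scikit-learn convention for degenerate samples
    ratios = []
    for j in relevant:
        rank = sum(1 for v in scores if v >= scores[j])  # labels scored at least as high
        hits = sum(1 for k in relevant if scores[k] >= scores[j])  # relevant ones among them
        ratios.append(hits / rank)
    return sum(ratios) / len(ratios)


class RunningLRAP:
    """Hypothetical accumulator: a running sum of per-sample scores plus a sample count."""

    def __init__(self):
        self.score_sum = 0.0
        self.num_samples = 0

    def update(self, scores, relevance):
        # Add each sample of the batch to the running state.
        for s, r in zip(scores, relevance):
            self.score_sum += lrap_one_sample(s, r)
            self.num_samples += 1

    def compute(self):
        # Mean over every sample seen so far, regardless of batch boundaries.
        return self.score_sum / self.num_samples
```

Splitting the same samples into different batches only changes the order of the additions, not the final mean.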
Accepts the following input tensors:
preds (float tensor): (N, C, ...). Preds should be a tensor containing probabilities or logits for each observation. If preds has values outside the [0,1] range, we consider the input to be logits and automatically apply sigmoid per element.
target (int tensor): (N, C, ...). Target should be a tensor containing ground truth labels, and therefore only contain {0,1} values (except if ignore_index is specified).
Additional dimensions ... will be flattened into the batch dimension.
- Parameters:
preds (Tensor) – Tensor with predictions
target (Tensor) – Tensor with true labels
num_labels (int) – Integer specifying the number of labels
ignore_index (Optional[int]) – Specifies a target value that is ignored and does not contribute to the metric calculation
validate_args (bool) – bool indicating if input arguments and tensors should be validated for correctness. Set to False for faster computations.
- Return type:
Tensor
Example
>>> import torch
>>> from torchmetrics.functional.classification import multilabel_ranking_average_precision
>>> _ = torch.manual_seed(42)
>>> preds = torch.rand(10, 5)
>>> target = torch.randint(2, (10, 5))
>>> multilabel_ranking_average_precision(preds, target, num_labels=5)
tensor(0.7744)
References
[1] Tsoumakas, G., Katakis, I., & Vlahavas, I. (2010). Mining multi-label data. In Data mining and knowledge discovery handbook (pp. 667-685). Springer US.
Label Ranking Loss
Module Interface
- class torchmetrics.classification.MultilabelRankingLoss(num_labels, ignore_index=None, validate_args=True, **kwargs)
Compute the label ranking loss for multilabel data [1].
For each sample, the score counts the label pairs that are incorrectly ordered by the predictions (an irrelevant label ranked above a relevant one) and divides that count by the number of labels in the label set times the number of labels not in it; the result is averaged over samples. The best score is 0.
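Written out pair by pair, the definition reads as follows. This is a hypothetical pure-Python sketch (ranking_loss_sketch), not the library code: it counts a pair as misordered when the irrelevant label scores at least as high as the relevant one (the tie convention of scikit-learn's definition, assumed here), and samples with no relevant or no irrelevant labels contribute 0 while still counting towards the average.

```python
def ranking_loss_sketch(scores, relevance):
    """Hypothetical label ranking loss on nested lists of shape (N, C)."""
    total = 0.0
    for s, r in zip(scores, relevance):
        pos = [s[j] for j, y in enumerate(r) if y == 1]  # scores of relevant labels
        neg = [s[j] for j, y in enumerate(r) if y == 0]  # scores of irrelevant labels
        # Assumption: degenerate samples add 0 but are still counted in the mean.
        if not pos or not neg:
            continue
        # A (relevant, irrelevant) pair is misordered when the irrelevant label scores at least as high.
        misordered = sum(1 for p in pos for n in neg if n >= p)
        total += misordered / (len(pos) * len(neg))
    return total / len(scores)
```

For scores [[0.75, 0.5, 1.0], [1.0, 0.2, 0.1]] with relevant labels [[1, 0, 0], [0, 0, 1]], one of two pairs is misordered in the first sample and both pairs in the second, so the loss is (0.5 + 1.0) / 2 = 0.75.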
As input to forward and update the metric accepts the following input:
preds (Tensor): A float tensor of shape (N, C, ...). Preds should be a tensor containing probabilities or logits for each observation. If preds has values outside the [0,1] range, we consider the input to be logits and automatically apply sigmoid per element.
target (Tensor): An int tensor of shape (N, C, ...). Target should be a tensor containing ground truth labels, and therefore only contain {0,1} values (except if ignore_index is specified).
Note
Additional dimensions ... will be flattened into the batch dimension.
As output to forward and compute the metric returns the following output:
mlrl (Tensor): A tensor containing the multilabel ranking loss.
- Parameters:
preds – Tensor with predictions
target – Tensor with true labels
num_labels (int) – Integer specifying the number of labels
ignore_index (Optional[int]) – Specifies a target value that is ignored and does not contribute to the metric calculation
validate_args (bool) – bool indicating if input arguments and tensors should be validated for correctness. Set to False for faster computations.
Example
>>> import torch
>>> from torchmetrics.classification import MultilabelRankingLoss
>>> _ = torch.manual_seed(42)
>>> preds = torch.rand(10, 5)
>>> target = torch.randint(2, (10, 5))
>>> mlrl = MultilabelRankingLoss(num_labels=5)
>>> mlrl(preds, target)
tensor(0.4167)
- plot(val=None, ax=None)
Plot a single or multiple values from the metric.
- Parameters:
val (Union[Tensor, Sequence[Tensor], None]) – Either a single result from calling metric.forward or metric.compute or a list of these results. If no value is provided, metric.compute is called automatically and that result is plotted.
ax (Optional[Axes]) – A matplotlib axis object. If provided, the plot is added to that axis.
- Return type:
- Returns:
Figure object and Axes object
- Raises:
ModuleNotFoundError – If matplotlib is not installed
>>> from torch import rand, randint
>>> # Example plotting a single value
>>> from torchmetrics.classification import MultilabelRankingLoss
>>> metric = MultilabelRankingLoss(num_labels=3)
>>> metric.update(rand(20, 3), randint(2, (20, 3)))
>>> fig_, ax_ = metric.plot()
>>> from torch import rand, randint
>>> # Example plotting multiple values
>>> from torchmetrics.classification import MultilabelRankingLoss
>>> metric = MultilabelRankingLoss(num_labels=3)
>>> values = []
>>> for _ in range(10):
...     values.append(metric(rand(20, 3), randint(2, (20, 3))))
>>> fig_, ax_ = metric.plot(values)
Functional Interface
- torchmetrics.functional.classification.multilabel_ranking_loss(preds, target, num_labels, ignore_index=None, validate_args=True)
Compute the label ranking loss for multilabel data [1].
For each sample, the score counts the label pairs that are incorrectly ordered by the predictions (an irrelevant label ranked above a relevant one) and divides that count by the number of labels in the label set times the number of labels not in it; the result is averaged over samples. The best score is 0.
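Counting every (relevant, irrelevant) pair costs O(C²) per sample. When scores contain no ties, the same count follows from the sorted order of the scores: summing, over the n relevant labels, how many labels are ranked above each one and subtracting the n(n-1)/2 relevant-over-relevant pairs leaves exactly the misordered pairs. The sketch below (ranking_loss_via_ranks, a hypothetical name) shows this O(C log C) formulation; it illustrates the idea only, and its handling of tied scores (plain sorted order) is an assumption.

```python
def ranking_loss_via_ranks(scores, relevance):
    """Hypothetical rank-based label ranking loss (assumes no tied scores)."""
    total = 0.0
    for s, r in zip(scores, relevance):
        num_labels = len(s)
        num_rel = sum(r)
        if num_rel == 0 or num_rel == num_labels:
            continue  # degenerate sample: contributes 0 to the mean
        # Position of each label when sorted by ascending score (0 = lowest score).
        order = sorted(range(num_labels), key=lambda j: s[j])
        position = [0] * num_labels
        for pos, j in enumerate(order):
            position[j] = pos
        # Number of labels ranked above each relevant label, summed over relevant labels.
        above = sum(num_labels - 1 - position[j] for j in range(num_labels) if r[j] == 1)
        # Remove relevant-above-relevant pairs, which are not misorderings.
        misordered = above - num_rel * (num_rel - 1) // 2
        total += misordered / (num_rel * (num_labels - num_rel))
    return total / len(scores)
```

For tie-free inputs this gives the same values as counting pairs directly, e.g. 0.75 for scores [[0.75, 0.5, 1.0], [1.0, 0.2, 0.1]] with relevant labels [[1, 0, 0], [0, 0, 1]].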
Accepts the following input tensors:
preds (float tensor): (N, C, ...). Preds should be a tensor containing probabilities or logits for each observation. If preds has values outside the [0,1] range, we consider the input to be logits and automatically apply sigmoid per element.
target (int tensor): (N, C, ...). Target should be a tensor containing ground truth labels, and therefore only contain {0,1} values (except if ignore_index is specified).
Additional dimensions ... will be flattened into the batch dimension.
- Parameters:
preds (Tensor) – Tensor with predictions
target (Tensor) – Tensor with true labels
num_labels (int) – Integer specifying the number of labels
ignore_index (Optional[int]) – Specifies a target value that is ignored and does not contribute to the metric calculation
validate_args (bool) – bool indicating if input arguments and tensors should be validated for correctness. Set to False for faster computations.
- Return type:
Tensor
Example
>>> import torch
>>> from torchmetrics.functional.classification import multilabel_ranking_loss
>>> _ = torch.manual_seed(42)
>>> preds = torch.rand(10, 5)
>>> target = torch.randint(2, (10, 5))
>>> multilabel_ranking_loss(preds, target, num_labels=5)
tensor(0.4167)
References
[1] Tsoumakas, G., Katakis, I., & Vlahavas, I. (2010). Mining multi-label data. In Data mining and knowledge discovery handbook (pp. 667-685). Springer US.
Matthews Correlation Coefficient
Module Interface
- class torchmetrics.MatthewsCorrCoef(**kwargs)
Calculate the Matthews correlation coefficient.
This metric measures the general correlation or quality of a classification.
This class is a simple wrapper to get the task specific versions of this metric, which is done by setting the task argument to either 'binary', 'multiclass' or 'multilabel'. See the documentation of BinaryMatthewsCorrCoef, MulticlassMatthewsCorrCoef and MultilabelMatthewsCorrCoef for details on how each argument affects the result, and for examples.
- Legacy Example:
>>> from torch import tensor
>>> from torchmetrics import MatthewsCorrCoef
>>> target = tensor([1, 1, 0, 0])
>>> preds = tensor([0, 1, 0, 0])
>>> matthews_corrcoef = MatthewsCorrCoef(task='binary')
>>> matthews_corrcoef(preds, target)
tensor(0.5774)
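For binary inputs the coefficient follows from the four confusion-matrix counts: MCC = (TP·TN - FP·FN) / sqrt((TP+FP)(TP+FN)(TN+FP)(TN+FN)). The pure-Python sketch below (binary_mcc_sketch, a hypothetical name) applies that formula. It thresholds float predictions at 0.5, skips logit handling and ignore_index, and returns 0.0 when the denominator is zero; that last choice is an assumption and may differ from how TorchMetrics treats degenerate inputs.

```python
import math


def binary_mcc_sketch(preds, target, threshold=0.5):
    """Hypothetical binary Matthews correlation coefficient on flat lists."""
    # Float probabilities are thresholded; 0/1 integer predictions pass through unchanged.
    binary = [1 if p > threshold else 0 for p in preds]
    pairs = list(zip(binary, target))
    tp = sum(1 for p, t in pairs if p == 1 and t == 1)
    tn = sum(1 for p, t in pairs if p == 0 and t == 0)
    fp = sum(1 for p, t in pairs if p == 1 and t == 0)
    fn = sum(1 for p, t in pairs if p == 0 and t == 1)
    denom = math.sqrt((tp + fp) * (tp + fn) * (tn + fp) * (tn + fn))
    # Assumption: report 0.0 when a marginal is empty and the ratio is undefined.
    return (tp * tn - fp * fn) / denom if denom else 0.0
```

With target [1, 1, 0, 0] and preds [0, 1, 0, 0] the counts are TP=1, TN=2, FP=0, FN=1, so MCC = 2/sqrt(12) ≈ 0.5774, in line with the example above.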
BinaryMatthewsCorrCoef
- class torchmetrics.classification.BinaryMatthewsCorrCoef(threshold=0.5, ignore_index=None, validate_args=True, **kwargs)
Calculate Matthews correlation coefficient for binary tasks.
This metric measures the general correlation or quality of a classification.
As input to forward and update the metric accepts the following input:
preds (Tensor): An int or float tensor of shape (N, ...). If preds is a floating point tensor with values outside the [0,1] range, we consider the input to be logits and automatically apply sigmoid per element. Additionally, we convert it to an int tensor by thresholding with the value in threshold.
target (Tensor): An int tensor of shape (N, ...)
Note
Additional dimensions ... will be flattened into the batch dimension.
As output to forward and compute the metric returns the following output:
bmcc (Tensor): A tensor containing the Binary Matthews Correlation Coefficient.
- Parameters:
threshold (float) – Threshold for transforming probability to binary (0,1) predictions
ignore_index (Optional[int]) – Specifies a target value that is ignored and does not contribute to the metric calculation
validate_args (bool) – bool indicating if input arguments and tensors should be validated for correctness. Set to False for faster computations.
kwargs (Any) – Additional keyword arguments, see Advanced metric settings for more info.
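The input handling described above (sigmoid for out-of-range floats, thresholding, and dropping positions whose target equals ignore_index) can be sketched as a small preprocessing step. format_binary_inputs is a hypothetical helper written for illustration: it works on flat Python lists, decides "float vs int" from the Python type of the values, and is not the function TorchMetrics uses internally.

```python
import math


def format_binary_inputs(preds, target, threshold=0.5, ignore_index=None):
    """Hypothetical helper mirroring the documented input handling for binary metrics."""
    is_float = any(isinstance(p, float) for p in preds)
    # Float predictions outside [0, 1] are treated as logits and squashed with a sigmoid.
    if is_float and any(p < 0 or p > 1 for p in preds):
        preds = [1 / (1 + math.exp(-p)) for p in preds]
    # Float predictions are thresholded into {0, 1}; ints are assumed to be binary already.
    if is_float:
        preds = [1 if p > threshold else 0 for p in preds]
    # Positions whose target equals ignore_index are dropped from the computation.
    kept = [(p, t) for p, t in zip(preds, target) if ignore_index is None or t != ignore_index]
    return [p for p, _ in kept], [t for _, t in kept]
```

For example, logits [-2.0, 3.0, 0.1, -0.5] with target [0, 1, 1, -1] and ignore_index=-1 become predictions [0, 1, 1] against target [0, 1, 1].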
- Example (preds is int tensor):
>>> from torch import tensor
>>> from torchmetrics.classification import BinaryMatthewsCorrCoef
>>> target = tensor([1, 1, 0, 0])
>>> preds = tensor([0, 1, 0, 0])
>>> metric = BinaryMatthewsCorrCoef()
>>> metric(preds, target)
tensor(0.5774)
- Example (preds is float tensor):
>>> from torch import tensor
>>> from torchmetrics.classification import BinaryMatthewsCorrCoef
>>> target = tensor([1, 1, 0, 0])
>>> preds = tensor([0.35, 0.85, 0.48, 0.01])
>>> metric = BinaryMatthewsCorrCoef()
>>> metric(preds, target)
tensor(0.5774)
- plot(val=None, ax=None)
Plot a single or multiple values from the metric.
- Parameters:
val (Union[Tensor, Sequence[Tensor], None]) – Either a single result from calling metric.forward or metric.compute or a list of these results. If no value is provided, metric.compute is called automatically and that result is plotted.
ax (Optional[Axes]) – A matplotlib axis object. If provided, the plot is added to that axis.
- Return type:
- Returns:
Figure object and Axes object
- Raises:
ModuleNotFoundError – If matplotlib is not installed
>>> from torch import rand, randint
>>> # Example plotting a single value
>>> from torchmetrics.classification import BinaryMatthewsCorrCoef
>>> metric = BinaryMatthewsCorrCoef()
>>> metric.update(rand(10), randint(2, (10,)))
>>> fig_, ax_ = metric.plot()
>>> from torch import rand, randint
>>> # Example plotting multiple values
>>> from torchmetrics.classification import BinaryMatthewsCorrCoef
>>> metric = BinaryMatthewsCorrCoef()
>>> values = []
>>> for _ in range(10):
...     values.append(metric(rand(10), randint(2, (10,))))
>>> fig_, ax_ = metric.plot(values)
MulticlassMatthewsCorrCoef
- class torchmetrics.classification.MulticlassMatthewsCorrCoef(num_classes, ignore_index=None, validate_args=True, **kwargs)
Calculate Matthews correlation coefficient for multiclass tasks.
This metric measures the general correlation or quality of a classification.
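For C classes the coefficient generalizes through confusion-matrix totals: with s samples, c correctly classified samples, p_k predictions of class k and t_k true occurrences of class k, MCC = (c·s - Σ_k p_k·t_k) / sqrt((s² - Σ_k p_k²)(s² - Σ_k t_k²)). The pure-Python sketch below (multiclass_mcc_sketch, a hypothetical name) applies that formula after an argmax over class scores; ignore_index is omitted, and returning 0.0 for a zero denominator is an assumption.

```python
import math


def multiclass_mcc_sketch(preds, target, num_classes):
    """Hypothetical multiclass MCC computed from confusion-matrix marginals."""
    # Rows of class scores (probabilities/logits) are reduced with an argmax over C.
    if preds and isinstance(preds[0], (list, tuple)):
        preds = [max(range(num_classes), key=lambda c: row[c]) for row in preds]
    s = len(target)  # number of samples
    c = sum(1 for p, t in zip(preds, target) if p == t)  # correctly classified samples
    p_k = [sum(1 for p in preds if p == k) for k in range(num_classes)]  # predictions per class
    t_k = [sum(1 for t in target if t == k) for k in range(num_classes)]  # true labels per class
    numerator = c * s - sum(p * t for p, t in zip(p_k, t_k))
    denom = math.sqrt((s * s - sum(p * p for p in p_k)) * (s * s - sum(t * t for t in t_k)))
    return numerator / denom if denom else 0.0  # assumption: 0.0 when undefined
```

On the examples below, preds [2, 1, 0, 1] against target [2, 1, 0, 0] give (3·4 - 5) / sqrt(10·10) = 0.7. With two classes the formula reduces to the binary coefficient: preds [0, 1, 0, 0] against target [1, 1, 0, 0] gives 1/sqrt(3) ≈ 0.5774.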
As input to forward and update the metric accepts the following input:
preds (Tensor): An int tensor of shape (N, ...) or float tensor of shape (N, C, ...). If preds is a floating point tensor, we apply torch.argmax along the C dimension to automatically convert probabilities/logits into an int tensor.
target (Tensor): An int tensor of shape (N, ...)
Note
Additional dimensions ... will be flattened into the batch dimension.
As output to forward and compute the metric returns the following output:
mcmcc (Tensor): A tensor containing the Multi-class Matthews Correlation Coefficient.
- Parameters:
num_classes (int) – Integer specifying the number of classes
ignore_index (Optional[int]) – Specifies a target value that is ignored and does not contribute to the metric calculation
validate_args (bool) – bool indicating if input arguments and tensors should be validated for correctness. Set to False for faster computations.
kwargs (Any) – Additional keyword arguments, see Advanced metric settings for more info.
- Example (pred is integer tensor):
>>> from torch import tensor
>>> from torchmetrics.classification import MulticlassMatthewsCorrCoef
>>> target = tensor([2, 1, 0, 0])
>>> preds = tensor([2, 1, 0, 1])
>>> metric = MulticlassMatthewsCorrCoef(num_classes=3)
>>> metric(preds, target)
tensor(0.7000)
- Example (pred is float tensor):
>>> from torch import tensor
>>> from torchmetrics.classification import MulticlassMatthewsCorrCoef
>>> target = tensor([2, 1, 0, 0])
>>> preds = tensor([[0.16, 0.26, 0.58],
...                 [0.22, 0.61, 0.17],
...                 [0.71, 0.09, 0.20],
...                 [0.05, 0.82, 0.13]])
>>> metric = MulticlassMatthewsCorrCoef(num_classes=3)
>>> metric(preds, target)
tensor(0.7000)
- plot(val=None, ax=None)
Plot a single or multiple values from the metric.
- Parameters:
val (Union[Tensor, Sequence[Tensor], None]) – Either a single result from calling metric.forward or metric.compute or a list of these results. If no value is provided, metric.compute is called automatically and that result is plotted.
ax (Optional[Axes]) – A matplotlib axis object. If provided, the plot is added to that axis.
- Return type:
- Returns:
Figure object and Axes object
- Raises:
ModuleNotFoundError – If matplotlib is not installed
>>> from torch import randint
>>> # Example plotting a single value per class
>>> from torchmetrics.classification import MulticlassMatthewsCorrCoef
>>> metric = MulticlassMatthewsCorrCoef(num_classes=3)
>>> metric.update(randint(3, (20,)), randint(3, (20,)))
>>> fig_, ax_ = metric.plot()
>>> from torch import randint
>>> # Example plotting multiple values per class
>>> from torchmetrics.classification import MulticlassMatthewsCorrCoef
>>> metric = MulticlassMatthewsCorrCoef(num_classes=3)
>>> values = []
>>> for _ in range(20):
...     values.append(metric(randint(3, (20,)), randint(3, (20,))))
>>> fig_, ax_ = metric.plot(values)
MultilabelMatthewsCorrCoef
- class torchmetrics.classification.MultilabelMatthewsCorrCoef(num_labels, threshold=0.5, ignore_index=None, validate_args=True, **kwargs)
Calculate Matthews correlation coefficient for multilabel tasks.
This metric measures the general correlation or quality of a classification.
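The example outputs below (0.3333) are consistent with summing the per-label 2×2 confusion matrices into a single table and then applying the binary formula. The sketch below (multilabel_mcc_sketch, a hypothetical name) reproduces that reading on nested lists; treat it as an interpretation checked against the documented examples, not as the library's code. Logits and ignore_index are not handled, and returning 0.0 for a zero denominator is an assumption.

```python
import math


def multilabel_mcc_sketch(preds, target, threshold=0.5):
    """Hypothetical multilabel MCC: pool all (sample, label) cells into one 2x2 table."""
    tp = tn = fp = fn = 0
    for p_row, t_row in zip(preds, target):
        for p, t in zip(p_row, t_row):
            p = 1 if p > threshold else 0  # binarize probabilities; 0/1 ints are unchanged
            if p == 1 and t == 1:
                tp += 1
            elif p == 0 and t == 0:
                tn += 1
            elif p == 1:
                fp += 1  # predicted 1, target 0
            else:
                fn += 1  # predicted 0, target 1
    denom = math.sqrt((tp + fp) * (tp + fn) * (tn + fp) * (tn + fn))
    return (tp * tn - fp * fn) / denom if denom else 0.0
```

For target [[0, 1, 0], [1, 0, 1]] and preds [[0, 0, 1], [1, 0, 1]] the pooled counts are TP=2, TN=2, FP=1, FN=1, giving (4 - 1) / 9 = 1/3.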
As input to forward and update the metric accepts the following input:
preds (Tensor): An int or float tensor of shape (N, C, ...). If preds is a floating point tensor with values outside the [0,1] range, we consider the input to be logits and automatically apply sigmoid per element. Additionally, we convert it to an int tensor by thresholding with the value in threshold.
target (Tensor): An int tensor of shape (N, C, ...)
Note
Additional dimensions ... will be flattened into the batch dimension.
As output to forward and compute the metric returns the following output:
mlmcc (Tensor): A tensor containing the Multi-label Matthews Correlation Coefficient.
- Parameters:
num_labels (int) – Integer specifying the number of labels
threshold (float) – Threshold for transforming probability to binary (0,1) predictions
ignore_index (Optional[int]) – Specifies a target value that is ignored and does not contribute to the metric calculation
validate_args (bool) – bool indicating if input arguments and tensors should be validated for correctness. Set to False for faster computations.
kwargs (Any) – Additional keyword arguments, see Advanced metric settings for more info.
- Example (preds is int tensor):
>>> from torch import tensor
>>> from torchmetrics.classification import MultilabelMatthewsCorrCoef
>>> target = tensor([[0, 1, 0], [1, 0, 1]])
>>> preds = tensor([[0, 0, 1], [1, 0, 1]])
>>> metric = MultilabelMatthewsCorrCoef(num_labels=3)
>>> metric(preds, target)
tensor(0.3333)
- Example (preds is float tensor):
>>> from torch import tensor
>>> from torchmetrics.classification import MultilabelMatthewsCorrCoef
>>> target = tensor([[0, 1, 0], [1, 0, 1]])
>>> preds = tensor([[0.11, 0.22, 0.84], [0.73, 0.33, 0.92]])
>>> metric = MultilabelMatthewsCorrCoef(num_labels=3)
>>> metric(preds, target)
tensor(0.3333)
- plot(val=None, ax=None)
Plot a single or multiple values from the metric.
- Parameters:
val (Union[Tensor, Sequence[Tensor], None]) – Either a single result from calling metric.forward or metric.compute or a list of these results. If no value is provided, metric.compute is called automatically and that result is plotted.
ax (Optional[Axes]) – A matplotlib axis object. If provided, the plot is added to that axis.
- Return type:
- Returns:
Figure object and Axes object
- Raises:
ModuleNotFoundError – If matplotlib is not installed
>>> from torch import rand, randint
>>> # Example plotting a single value
>>> from torchmetrics.classification import MultilabelMatthewsCorrCoef
>>> metric = MultilabelMatthewsCorrCoef(num_labels=3)
>>> metric.update(randint(2, (20, 3)), randint(2, (20, 3)))
>>> fig_, ax_ = metric.plot()
>>> from torch import rand, randint
>>> # Example plotting multiple values
>>> from torchmetrics.classification import MultilabelMatthewsCorrCoef
>>> metric = MultilabelMatthewsCorrCoef(num_labels=3)
>>> values = []
>>> for _ in range(10):
...     values.append(metric(randint(2, (20, 3)), randint(2, (20, 3))))
>>> fig_, ax_ = metric.plot(values)
Functional Interface
matthews_corrcoef
- torchmetrics.functional.matthews_corrcoef(preds, target, task, threshold=0.5, num_classes=None, num_labels=None, ignore_index=None, validate_args=True)
Calculate the Matthews correlation coefficient.
This metric measures the general correlation or quality of a classification.
This function is a simple wrapper to get the task specific versions of this metric, which is done by setting the task argument to either 'binary', 'multiclass' or 'multilabel'. See the documentation of binary_matthews_corrcoef(), multiclass_matthews_corrcoef() and multilabel_matthews_corrcoef() for details on how each argument affects the result, and for examples.
- Return type:
Tensor
- Legacy Example:
>>> from torch import tensor
>>> from torchmetrics.functional import matthews_corrcoef
>>> target = tensor([1, 1, 0, 0])
>>> preds = tensor([0, 1, 0, 0])
>>> matthews_corrcoef(preds, target, task="multiclass", num_classes=2)
tensor(0.5774)
binary_matthews_corrcoef
- torchmetrics.functional.classification.binary_matthews_corrcoef(preds, target, threshold=0.5, ignore_index=None, validate_args=True)
Calculate Matthews correlation coefficient for binary tasks.
This metric measures the general correlation or quality of a classification.
Accepts the following input tensors:
preds (int or float tensor): (N, ...). If preds is a floating point tensor with values outside the [0,1] range, we consider the input to be logits and automatically apply sigmoid per element. Additionally, we convert it to an int tensor by thresholding with the value in threshold.
target (int tensor): (N, ...)
Additional dimensions ... will be flattened into the batch dimension.
- Parameters:
preds (Tensor) – Tensor with predictions
target (Tensor) – Tensor with true labels
threshold (float) – Threshold for transforming probability to binary (0,1) predictions
ignore_index (Optional[int]) – Specifies a target value that is ignored and does not contribute to the metric calculation
validate_args (bool) – bool indicating if input arguments and tensors should be validated for correctness. Set to False for faster computations.
- Return type:
Tensor
- Example (preds is int tensor):
>>> from torch import tensor
>>> from torchmetrics.functional.classification import binary_matthews_corrcoef
>>> target = tensor([1, 1, 0, 0])
>>> preds = tensor([0, 1, 0, 0])
>>> binary_matthews_corrcoef(preds, target)
tensor(0.5774)
- Example (preds is float tensor):
>>> from torch import tensor
>>> from torchmetrics.functional.classification import binary_matthews_corrcoef
>>> target = tensor([1, 1, 0, 0])
>>> preds = tensor([0.35, 0.85, 0.48, 0.01])
>>> binary_matthews_corrcoef(preds, target)
tensor(0.5774)
multiclass_matthews_corrcoef
- torchmetrics.functional.classification.multiclass_matthews_corrcoef(preds, target, num_classes, ignore_index=None, validate_args=True)
Calculate Matthews correlation coefficient for multiclass tasks.
This metric measures the general correlation or quality of a classification.
Accepts the following input tensors:
preds: (N, ...) (int tensor) or (N, C, ...) (float tensor). If preds is a floating point tensor, we apply torch.argmax along the C dimension to automatically convert probabilities/logits into an int tensor.
target (int tensor): (N, ...)
Additional dimensions ... will be flattened into the batch dimension.
- Parameters:
preds (Tensor) – Tensor with predictions
target (Tensor) – Tensor with true labels
num_classes (int) – Integer specifying the number of classes
ignore_index (Optional[int]) – Specifies a target value that is ignored and does not contribute to the metric calculation
validate_args (bool) – bool indicating if input arguments and tensors should be validated for correctness. Set to False for faster computations.
- Return type:
Tensor
- Example (pred is integer tensor):
>>> from torch import tensor
>>> from torchmetrics.functional.classification import multiclass_matthews_corrcoef
>>> target = tensor([2, 1, 0, 0])
>>> preds = tensor([2, 1, 0, 1])
>>> multiclass_matthews_corrcoef(preds, target, num_classes=3)
tensor(0.7000)
- Example (pred is float tensor):
>>> from torch import tensor
>>> from torchmetrics.functional.classification import multiclass_matthews_corrcoef
>>> target = tensor([2, 1, 0, 0])
>>> preds = tensor([[0.16, 0.26, 0.58],
...                 [0.22, 0.61, 0.17],
...                 [0.71, 0.09, 0.20],
...                 [0.05, 0.82, 0.13]])
>>> multiclass_matthews_corrcoef(preds, target, num_classes=3)
tensor(0.7000)
multilabel_matthews_corrcoef
- torchmetrics.functional.classification.multilabel_matthews_corrcoef(preds, target, num_labels, threshold=0.5, ignore_index=None, validate_args=True)
Calculate Matthews correlation coefficient for multilabel tasks.
This metric measures the general correlation or quality of a classification.
Accepts the following input tensors:
preds (int or float tensor): (N, C, ...). If preds is a floating point tensor with values outside the [0,1] range, we consider the input to be logits and automatically apply sigmoid per element. Additionally, we convert it to an int tensor by thresholding with the value in threshold.
target (int tensor): (N, C, ...)
Additional dimensions ... will be flattened into the batch dimension.
- Parameters:
preds (Tensor) – Tensor with predictions
target (Tensor) – Tensor with true labels
num_labels (int) – Integer specifying the number of labels
threshold (float) – Threshold for transforming probability to binary (0,1) predictions
ignore_index (Optional[int]) – Specifies a target value that is ignored and does not contribute to the metric calculation
validate_args (bool) – bool indicating if input arguments and tensors should be validated for correctness. Set to False for faster computations.
- Return type:
Tensor
- Example (preds is int tensor):
>>> from torch import tensor
>>> from torchmetrics.functional.classification import multilabel_matthews_corrcoef
>>> target = tensor([[0, 1, 0], [1, 0, 1]])
>>> preds = tensor([[0, 0, 1], [1, 0, 1]])
>>> multilabel_matthews_corrcoef(preds, target, num_labels=3)
tensor(0.3333)
- Example (preds is float tensor):
>>> from torch import tensor
>>> from torchmetrics.functional.classification import multilabel_matthews_corrcoef
>>> target = tensor([[0, 1, 0], [1, 0, 1]])
>>> preds = tensor([[0.11, 0.22, 0.84], [0.73, 0.33, 0.92]])
>>> multilabel_matthews_corrcoef(preds, target, num_labels=3)
tensor(0.3333)
Precision
Module Interface
- class torchmetrics.Precision(**kwargs)
Compute Precision.
\[\text{Precision} = \frac{\text{TP}}{\text{TP} + \text{FP}}\]
Where \(\text{TP}\) and \(\text{FP}\) represent the number of true positives and false positives respectively. The metric is only properly defined when \(\text{TP} + \text{FP} \neq 0\). If this case is encountered for any class/label, the metric for that class/label will be set to 0, and the overall metric may be affected in turn.
This class is a simple wrapper to get the task specific versions of this metric, which is done by setting the task argument to either 'binary', 'multiclass' or 'multilabel'. See the documentation of BinaryPrecision, MulticlassPrecision and MultilabelPrecision for details on how each argument affects the result, and for examples.
- Legacy Example:
>>> from torch import tensor
>>> from torchmetrics import Precision
>>> preds = tensor([2, 0, 2, 1])
>>> target = tensor([1, 1, 2, 0])
>>> precision = Precision(task="multiclass", average='macro', num_classes=3)
>>> precision(preds, target)
tensor(0.1667)
>>> precision = Precision(task="multiclass", average='micro', num_classes=3)
>>> precision(preds, target)
tensor(0.2500)
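Macro and micro averaging differ only in when the division happens. A hypothetical pure-Python sketch for integer class predictions (multiclass_precision_sketch; no top_k, ignore_index or multidim handling) makes this visible: macro averages the per-class ratios, while micro pools TP and FP over all classes first. Undefined per-class ratios are set to 0, as described above.

```python
def multiclass_precision_sketch(preds, target, num_classes, average="macro"):
    """Hypothetical multiclass precision from integer class predictions."""
    tp = [0] * num_classes
    fp = [0] * num_classes
    for p, t in zip(preds, target):
        if p == t:
            tp[p] += 1
        else:
            fp[p] += 1  # predicted class p, but the sample belongs to another class
    if average == "micro":
        # Pool the counts over all classes before dividing.
        total = sum(tp) + sum(fp)
        return sum(tp) / total if total else 0.0
    # Per-class precision; TP + FP == 0 yields 0.0.
    per_class = [t / (t + f) if (t + f) else 0.0 for t, f in zip(tp, fp)]
    if average == "macro":
        return sum(per_class) / num_classes
    return per_class  # average=None / "none"
```

On the legacy example, the per-class precisions are 0, 0 and 0.5, so macro gives 0.5/3 ≈ 0.1667, while micro pools 1 true positive over 4 predictions for 0.25, matching the outputs shown.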
BinaryPrecision
- class torchmetrics.classification.BinaryPrecision(threshold=0.5, multidim_average='global', ignore_index=None, validate_args=True, **kwargs)
Compute Precision for binary tasks.
\[\text{Precision} = \frac{\text{TP}}{\text{TP} + \text{FP}}\]
Where \(\text{TP}\) and \(\text{FP}\) represent the number of true positives and false positives respectively. The metric is only properly defined when \(\text{TP} + \text{FP} \neq 0\). If this case is encountered, a score of 0 is returned.
As input to forward and update the metric accepts the following input:
preds (Tensor): An int or float tensor of shape (N, ...). If preds is a floating point tensor with values outside the [0,1] range, we consider the input to be logits and automatically apply sigmoid per element. Additionally, we convert it to an int tensor by thresholding with the value in threshold.
target (Tensor): An int tensor of shape (N, ...).
As output to forward and compute the metric returns the following output:
bp (Tensor): If multidim_average is set to global, the metric returns a scalar value. If multidim_average is set to samplewise, the metric returns an (N,) vector consisting of a scalar value per sample.
- Parameters:
threshold (float) – Threshold for transforming probability to binary {0,1} predictions
multidim_average (Literal['global', 'samplewise']) – Defines how additional dimensions ... should be handled. Should be one of the following:
global: Additional dimensions are flattened along the batch dimension
samplewise: Statistic will be calculated independently for each sample on the N axis. The statistics in this case are calculated over the additional dimensions.
ignore_index (Optional[int]) – Specifies a target value that is ignored and does not contribute to the metric calculation
validate_args (bool) – bool indicating if input arguments and tensors should be validated for correctness. Set to False for faster computations.
- Example (preds is int tensor):
>>> from torch import tensor
>>> from torchmetrics.classification import BinaryPrecision
>>> target = tensor([0, 1, 0, 1, 0, 1])
>>> preds = tensor([0, 0, 1, 1, 0, 1])
>>> metric = BinaryPrecision()
>>> metric(preds, target)
tensor(0.6667)
- Example (preds is float tensor):
>>> from torch import tensor
>>> from torchmetrics.classification import BinaryPrecision
>>> target = tensor([0, 1, 0, 1, 0, 1])
>>> preds = tensor([0.11, 0.22, 0.84, 0.73, 0.33, 0.92])
>>> metric = BinaryPrecision()
>>> metric(preds, target)
tensor(0.6667)
- Example (multidim tensors):
>>> from torch import tensor
>>> from torchmetrics.classification import BinaryPrecision
>>> target = tensor([[[0, 1], [1, 0], [0, 1]], [[1, 1], [0, 0], [1, 0]]])
>>> preds = tensor([[[0.59, 0.91], [0.91, 0.99], [0.63, 0.04]],
...                 [[0.38, 0.04], [0.86, 0.780], [0.45, 0.37]]])
>>> metric = BinaryPrecision(multidim_average='samplewise')
>>> metric(preds, target)
tensor([0.4000, 0.0000])
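The two multidim_average modes differ only in which cells are pooled before dividing. In the hypothetical sketch below (binary_precision_sketch, pure Python on nested lists; logits and ignore_index are not handled), global pools every element, while samplewise computes one precision per entry along the first axis.

```python
def binary_precision_sketch(preds, target, threshold=0.5, multidim_average="global"):
    """Hypothetical binary precision for nested-list inputs of shape (N, ...)."""

    def flatten(x):
        # Recursively flatten nested lists into one flat list.
        return [v for item in x for v in flatten(item)] if isinstance(x, list) else [x]

    def precision(p_flat, t_flat):
        binary = [1 if p > threshold else 0 for p in p_flat]
        tp = sum(1 for p, t in zip(binary, t_flat) if p == 1 and t == 1)
        fp = sum(1 for p, t in zip(binary, t_flat) if p == 1 and t == 0)
        return tp / (tp + fp) if (tp + fp) else 0.0  # undefined case scored 0

    if multidim_average == "global":
        # Extra dimensions are flattened into the batch dimension: one scalar overall.
        return precision(flatten(preds), flatten(target))
    # "samplewise": one score per sample along the first (N) axis.
    return [precision(flatten(p), flatten(t)) for p, t in zip(preds, target)]
```

For the multidim example above, sample 0 has 2 true positives out of 5 positive predictions (0.4) and sample 1 has none out of 2 (0.0), matching the samplewise output; pooling both samples globally would give 2/7 ≈ 0.2857 in this sketch.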
- plot(val=None, ax=None)
Plot a single or multiple values from the metric.
- Parameters:
val (Union[Tensor, Sequence[Tensor], None]) – Either a single result from calling metric.forward or metric.compute or a list of these results. If no value is provided, metric.compute is called automatically and that result is plotted.
ax (Optional[Axes]) – A matplotlib axis object. If provided, the plot is added to that axis.
- Return type:
- Returns:
Figure object and Axes object
- Raises:
ModuleNotFoundError – If matplotlib is not installed
>>> from torch import rand, randint
>>> # Example plotting a single value
>>> from torchmetrics.classification import BinaryPrecision
>>> metric = BinaryPrecision()
>>> metric.update(rand(10), randint(2, (10,)))
>>> fig_, ax_ = metric.plot()
>>> from torch import rand, randint
>>> # Example plotting multiple values
>>> from torchmetrics.classification import BinaryPrecision
>>> metric = BinaryPrecision()
>>> values = []
>>> for _ in range(10):
...     values.append(metric(rand(10), randint(2, (10,))))
>>> fig_, ax_ = metric.plot(values)
MulticlassPrecision
- class torchmetrics.classification.MulticlassPrecision(num_classes, top_k=1, average='macro', multidim_average='global', ignore_index=None, validate_args=True, **kwargs)
Compute Precision for multiclass tasks.
\[\text{Precision} = \frac{\text{TP}}{\text{TP} + \text{FP}}\]
Where \(\text{TP}\) and \(\text{FP}\) represent the number of true positives and false positives respectively. The metric is only properly defined when \(\text{TP} + \text{FP} \neq 0\). If this case is encountered for any class, the metric for that class will be set to 0, and the overall metric may be affected in turn.
As input to forward and update the metric accepts the following input:
preds (Tensor): An int tensor of shape (N, ...) or a float tensor of shape (N, C, ...). If preds is a floating point tensor, torch.argmax is applied along the C dimension to automatically convert probabilities/logits into an int tensor.
target (Tensor): An int tensor of shape (N, ...).
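When preds holds probabilities or logits of shape (N, C), it is reduced to class indices by an argmax over the C dimension. A plain-Python sketch of that conversion, using the float preds from the example further down (`argmax_over_classes` is a hypothetical name):

```python
from typing import List, Sequence


def argmax_over_classes(preds: Sequence[Sequence[float]]) -> List[int]:
    """Convert (N, C) scores into (N,) class indices, like an argmax along dim C."""
    # Python's max() returns the first index that attains the maximum score in each row
    return [max(range(len(row)), key=row.__getitem__) for row in preds]


preds = [[0.16, 0.26, 0.58],
         [0.22, 0.61, 0.17],
         [0.71, 0.09, 0.20],
         [0.05, 0.82, 0.13]]
print(argmax_over_classes(preds))  # [2, 1, 0, 1]
```

The result equals the int preds used in the first example, which is why both examples report the same precision.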
As output to forward and compute the metric returns the following output:
mcp (Tensor): The returned shape depends on the average and multidim_average arguments:
If multidim_average is set to global:
  - If average='micro'/'macro'/'weighted', the output will be a scalar tensor
  - If average=None/'none', the shape will be (C,)
If multidim_average is set to samplewise:
  - If average='micro'/'macro'/'weighted', the shape will be (N,)
  - If average=None/'none', the shape will be (N, C)
- Parameters:
num_classes (int) – Integer specifying the number of classes
average (Optional[Literal['micro', 'macro', 'weighted', 'none']]) – Defines the reduction that is applied over labels. Should be one of the following:
  - micro: Sum statistics over all labels
  - macro: Calculate statistics for each label and average them
  - weighted: Calculate statistics for each label and compute a weighted average using their support
  - "none" or None: Calculate statistics for each label and apply no reduction
top_k (int) – Number of highest probability or logit score predictions considered to find the correct label. Only works when preds contain probabilities/logits.
multidim_average (Literal['global', 'samplewise']) – Defines how additional dimensions ... should be handled. Should be one of the following:
  - global: Additional dimensions are flattened along the batch dimension
  - samplewise: Statistics are calculated independently for each sample on the N axis. The statistics in this case are calculated over the additional dimensions.
ignore_index (Optional[int]) – Specifies a target value that is ignored and does not contribute to the metric calculation
validate_args (bool) – Whether input arguments and tensors should be validated for correctness. Set to False for faster computations.
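To illustrate how the average options combine per-class statistics, here is a plain-Python sketch, not the library's code. It assumes that weighted uses each class's support (its number of occurrences in target) as the weights, and it averages macro over all classes; the library's treatment of classes that never appear may differ. The data reproduces the functional legacy example further down (macro 0.1667, micro 0.25). `multiclass_precision_sketch` is a hypothetical name.

```python
from typing import List, Optional, Sequence, Union


def multiclass_precision_sketch(
    preds: Sequence[int], target: Sequence[int], num_classes: int, average: Optional[str] = "macro"
) -> Union[float, List[float]]:
    """Multiclass precision with micro/macro/weighted/none reduction (illustrative only)."""
    tp = [0] * num_classes
    fp = [0] * num_classes
    support = [0] * num_classes  # how often each class occurs in target
    for p, t in zip(preds, target):
        support[t] += 1
        if p == t:
            tp[p] += 1
        else:
            fp[p] += 1
    per_class = [tp[c] / (tp[c] + fp[c]) if tp[c] + fp[c] else 0.0 for c in range(num_classes)]
    if average == "micro":
        # Sum the statistics over all classes first, then divide once
        denom = sum(tp) + sum(fp)
        return sum(tp) / denom if denom else 0.0
    if average == "macro":
        return sum(per_class) / num_classes
    if average == "weighted":
        total = sum(support)
        return sum(score * w for score, w in zip(per_class, support)) / total if total else 0.0
    if average in (None, "none"):
        return per_class
    raise ValueError(f"Unsupported average: {average!r}")


preds, target = [2, 0, 2, 1], [1, 1, 2, 0]
print(round(multiclass_precision_sketch(preds, target, 3, "macro"), 4))  # 0.1667
print(multiclass_precision_sketch(preds, target, 3, "micro"))  # 0.25
print(multiclass_precision_sketch(preds, target, 3, "weighted"))  # 0.125
print(multiclass_precision_sketch(preds, target, 3, None))  # [0.0, 0.0, 0.5]
```

Micro pools all predictions (1 correct out of 4), while macro gives every class equal weight, so the two can differ sharply on imbalanced data.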
- Example (preds is int tensor):
>>> from torch import tensor
>>> from torchmetrics.classification import MulticlassPrecision
>>> target = tensor([2, 1, 0, 0])
>>> preds = tensor([2, 1, 0, 1])
>>> metric = MulticlassPrecision(num_classes=3)
>>> metric(preds, target)
tensor(0.8333)
>>> mcp = MulticlassPrecision(num_classes=3, average=None)
>>> mcp(preds, target)
tensor([1.0000, 0.5000, 1.0000])
- Example (preds is float tensor):
>>> from torchmetrics.classification import MulticlassPrecision
>>> target = tensor([2, 1, 0, 0])
>>> preds = tensor([[0.16, 0.26, 0.58],
...                 [0.22, 0.61, 0.17],
...                 [0.71, 0.09, 0.20],
...                 [0.05, 0.82, 0.13]])
>>> metric = MulticlassPrecision(num_classes=3)
>>> metric(preds, target)
tensor(0.8333)
>>> mcp = MulticlassPrecision(num_classes=3, average=None)
>>> mcp(preds, target)
tensor([1.0000, 0.5000, 1.0000])
- Example (multidim tensors):
>>> from torchmetrics.classification import MulticlassPrecision
>>> target = tensor([[[0, 1], [2, 1], [0, 2]], [[1, 1], [2, 0], [1, 2]]])
>>> preds = tensor([[[0, 2], [2, 0], [0, 1]], [[2, 2], [2, 1], [1, 0]]])
>>> metric = MulticlassPrecision(num_classes=3, multidim_average='samplewise')
>>> metric(preds, target)
tensor([0.3889, 0.2778])
>>> mcp = MulticlassPrecision(num_classes=3, multidim_average='samplewise', average=None)
>>> mcp(preds, target)
tensor([[0.6667, 0.0000, 0.5000],
        [0.0000, 0.5000, 0.3333]])
- plot(val=None, ax=None)[source]¶
Plot a single or multiple values from the metric.
- Parameters:
val (Union[Tensor, Sequence[Tensor], None]) – Either a single result from calling metric.forward or metric.compute, or a list of these results. If no value is provided, metric.compute is called automatically and its result is plotted.
ax (Optional[Axes]) – A matplotlib axis object. If provided, the plot is added to that axis.
- Returns:
Figure object and Axes object
- Raises:
ModuleNotFoundError – If matplotlib is not installed
>>> from torch import randint
>>> # Example plotting a single value per class
>>> from torchmetrics.classification import MulticlassPrecision
>>> metric = MulticlassPrecision(num_classes=3, average=None)
>>> metric.update(randint(3, (20,)), randint(3, (20,)))
>>> fig_, ax_ = metric.plot()
>>> from torch import randint
>>> # Example plotting multiple values per class
>>> from torchmetrics.classification import MulticlassPrecision
>>> metric = MulticlassPrecision(num_classes=3, average=None)
>>> values = []
>>> for _ in range(20):
...     values.append(metric(randint(3, (20,)), randint(3, (20,))))
>>> fig_, ax_ = metric.plot(values)
MultilabelPrecision¶
- class torchmetrics.classification.MultilabelPrecision(num_labels, threshold=0.5, average='macro', multidim_average='global', ignore_index=None, validate_args=True, **kwargs)[source]¶
Compute Precision for multilabel tasks.
\[\text{Precision} = \frac{\text{TP}}{\text{TP} + \text{FP}}\]
where \(\text{TP}\) and \(\text{FP}\) represent the number of true positives and false positives respectively. The metric is only properly defined when \(\text{TP} + \text{FP} \neq 0\). If this case is encountered for any label, the metric for that label is set to 0, which may in turn affect the overall metric.
As input to forward and update the metric accepts the following input:
preds (Tensor): An int tensor or float tensor of shape (N, C, ...). If preds is a floating point tensor with values outside the [0, 1] range, the input is considered to be logits and sigmoid is automatically applied per element. Additionally, preds is converted to an int tensor by thresholding with the value in threshold.
target (Tensor): An int tensor of shape (N, C, ...).
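The preprocessing described above (sigmoid only when some value lies outside [0, 1], then thresholding) can be sketched in plain Python as follows. The sketch assumes the whole input is checked at once and that thresholding uses a strict `>` comparison; `binarize_multilabel` is a hypothetical helper.

```python
import math
from typing import List, Sequence


def binarize_multilabel(preds: Sequence[Sequence[float]], threshold: float = 0.5) -> List[List[int]]:
    """Turn (N, C) probabilities or logits into (N, C) 0/1 predictions."""
    values = [v for row in preds for v in row]
    is_logits = any(v < 0.0 or v > 1.0 for v in values)
    out = []
    for row in preds:
        # Treat the input as logits and squash it only if some value falls outside [0, 1]
        probs = [1.0 / (1.0 + math.exp(-v)) for v in row] if is_logits else list(row)
        out.append([1 if p > threshold else 0 for p in probs])
    return out


print(binarize_multilabel([[0.11, 0.22, 0.84], [0.73, 0.33, 0.92]]))  # [[0, 0, 1], [1, 0, 1]]
print(binarize_multilabel([[-2.0, 1.0, 3.0]]))  # [[0, 1, 1]]
```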
As output to forward and compute the metric returns the following output:
mlp (Tensor): The returned shape depends on the average and multidim_average arguments:
If multidim_average is set to global:
  - If average='micro'/'macro'/'weighted', the output will be a scalar tensor
  - If average=None/'none', the shape will be (C,)
If multidim_average is set to samplewise:
  - If average='micro'/'macro'/'weighted', the shape will be (N,)
  - If average=None/'none', the shape will be (N, C)
- Parameters:
num_labels (int) – Integer specifying the number of labels
threshold (float) – Threshold for transforming probability to binary (0,1) predictions
average (Optional[Literal['micro', 'macro', 'weighted', 'none']]) – Defines the reduction that is applied over labels. Should be one of the following:
  - micro: Sum statistics over all labels
  - macro: Calculate statistics for each label and average them
  - weighted: Calculate statistics for each label and compute a weighted average using their support
  - "none" or None: Calculate statistics for each label and apply no reduction
multidim_average (Literal['global', 'samplewise']) – Defines how additional dimensions ... should be handled. Should be one of the following:
  - global: Additional dimensions are flattened along the batch dimension
  - samplewise: Statistics are calculated independently for each sample on the N axis. The statistics in this case are calculated over the additional dimensions.
ignore_index (Optional[int]) – Specifies a target value that is ignored and does not contribute to the metric calculation
validate_args (bool) – Whether input arguments and tensors should be validated for correctness. Set to False for faster computations.
- Example (preds is int tensor):
>>> from torch import tensor
>>> from torchmetrics.classification import MultilabelPrecision
>>> target = tensor([[0, 1, 0], [1, 0, 1]])
>>> preds = tensor([[0, 0, 1], [1, 0, 1]])
>>> metric = MultilabelPrecision(num_labels=3)
>>> metric(preds, target)
tensor(0.5000)
>>> mlp = MultilabelPrecision(num_labels=3, average=None)
>>> mlp(preds, target)
tensor([1.0000, 0.0000, 0.5000])
- Example (preds is float tensor):
>>> from torchmetrics.classification import MultilabelPrecision
>>> target = tensor([[0, 1, 0], [1, 0, 1]])
>>> preds = tensor([[0.11, 0.22, 0.84], [0.73, 0.33, 0.92]])
>>> metric = MultilabelPrecision(num_labels=3)
>>> metric(preds, target)
tensor(0.5000)
>>> mlp = MultilabelPrecision(num_labels=3, average=None)
>>> mlp(preds, target)
tensor([1.0000, 0.0000, 0.5000])
- Example (multidim tensors):
>>> from torchmetrics.classification import MultilabelPrecision
>>> target = tensor([[[0, 1], [1, 0], [0, 1]], [[1, 1], [0, 0], [1, 0]]])
>>> preds = tensor([[[0.59, 0.91], [0.91, 0.99], [0.63, 0.04]],
...                 [[0.38, 0.04], [0.86, 0.780], [0.45, 0.37]]])
>>> metric = MultilabelPrecision(num_labels=3, multidim_average='samplewise')
>>> metric(preds, target)
tensor([0.3333, 0.0000])
>>> mlp = MultilabelPrecision(num_labels=3, multidim_average='samplewise', average=None)
>>> mlp(preds, target)
tensor([[0.5000, 0.5000, 0.0000],
        [0.0000, 0.0000, 0.0000]])
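For multilabel inputs of shape (N, C, ...), multidim_average='samplewise' yields one precision per sample and label, computed over the trailing dimensions, and average='macro' then averages over the C labels. This plain-Python sketch reproduces the multidim example above under the assumption of a strict `> 0.5` threshold; `samplewise_multilabel_precision` is a hypothetical helper, not the library's implementation.

```python
from typing import List, Sequence, Tuple


def samplewise_multilabel_precision(
    preds: Sequence[Sequence[Sequence[float]]],
    target: Sequence[Sequence[Sequence[int]]],
    threshold: float = 0.5,
) -> Tuple[List[float], List[List[float]]]:
    """Return (macro average per sample, per-label precision per sample) for (N, C, X) inputs."""
    per_label_all = []
    for p_sample, t_sample in zip(preds, target):
        per_label = []
        for p_label, t_label in zip(p_sample, t_sample):  # iterate over the C labels
            hits = [(1 if p > threshold else 0, t) for p, t in zip(p_label, t_label)]
            tp = sum(1 for p, t in hits if p == 1 and t == 1)
            fp = sum(1 for p, t in hits if p == 1 and t == 0)
            per_label.append(tp / (tp + fp) if tp + fp else 0.0)  # 0/0 is reported as 0
        per_label_all.append(per_label)
    macro = [sum(row) / len(row) for row in per_label_all]
    return macro, per_label_all


target = [[[0, 1], [1, 0], [0, 1]], [[1, 1], [0, 0], [1, 0]]]
preds = [[[0.59, 0.91], [0.91, 0.99], [0.63, 0.04]],
         [[0.38, 0.04], [0.86, 0.780], [0.45, 0.37]]]
macro, per_label = samplewise_multilabel_precision(preds, target)
print([round(x, 4) for x in macro])  # [0.3333, 0.0]
print(per_label)  # [[0.5, 0.5, 0.0], [0.0, 0.0, 0.0]]
```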
- plot(val=None, ax=None)[source]¶
Plot a single or multiple values from the metric.
- Parameters:
val (Union[Tensor, Sequence[Tensor], None]) – Either a single result from calling metric.forward or metric.compute, or a list of these results. If no value is provided, metric.compute is called automatically and its result is plotted.
ax (Optional[Axes]) – A matplotlib axis object. If provided, the plot is added to that axis.
- Returns:
Figure object and Axes object
- Raises:
ModuleNotFoundError – If matplotlib is not installed
>>> from torch import rand, randint
>>> # Example plotting a single value
>>> from torchmetrics.classification import MultilabelPrecision
>>> metric = MultilabelPrecision(num_labels=3)
>>> metric.update(randint(2, (20, 3)), randint(2, (20, 3)))
>>> fig_, ax_ = metric.plot()
>>> from torch import rand, randint
>>> # Example plotting multiple values
>>> from torchmetrics.classification import MultilabelPrecision
>>> metric = MultilabelPrecision(num_labels=3)
>>> values = []
>>> for _ in range(10):
...     values.append(metric(randint(2, (20, 3)), randint(2, (20, 3))))
>>> fig_, ax_ = metric.plot(values)
Functional Interface¶
- torchmetrics.functional.precision(preds, target, task, threshold=0.5, num_classes=None, num_labels=None, average='micro', multidim_average='global', top_k=1, ignore_index=None, validate_args=True)[source]¶
Compute Precision.
\[\text{Precision} = \frac{\text{TP}}{\text{TP} + \text{FP}}\]
where \(\text{TP}\) and \(\text{FP}\) represent the number of true positives and false positives respectively.
- Return type:
Tensor
This function is a simple wrapper to get the task-specific versions of this metric, which is done by setting the task argument to either 'binary', 'multiclass' or 'multilabel'. See the documentation of binary_precision(), multiclass_precision() and multilabel_precision() for the specific details of each argument's influence and for examples.
- Legacy Example:
>>> from torch import tensor
>>> from torchmetrics.functional import precision
>>> preds = tensor([2, 0, 2, 1])
>>> target = tensor([1, 1, 2, 0])
>>> precision(preds, target, task="multiclass", average='macro', num_classes=3)
tensor(0.1667)
>>> precision(preds, target, task="multiclass", average='micro', num_classes=3)
tensor(0.2500)
binary_precision¶
- torchmetrics.functional.classification.binary_precision(preds, target, threshold=0.5, multidim_average='global', ignore_index=None, validate_args=True)[source]¶
Compute Precision for binary tasks.
\[\text{Precision} = \frac{\text{TP}}{\text{TP} + \text{FP}}\]
where \(\text{TP}\) and \(\text{FP}\) represent the number of true positives and false positives respectively.
Accepts the following input tensors:
preds (int or float tensor): (N, ...). If preds is a floating point tensor with values outside the [0, 1] range, the input is considered to be logits and sigmoid is automatically applied per element. Additionally, preds is converted to an int tensor by thresholding with the value in threshold.
target (int tensor): (N, ...)
- Parameters:
preds (Tensor) – Tensor with predictions
target (Tensor) – Tensor with true labels
threshold (float) – Threshold for transforming probability to binary {0,1} predictions
multidim_average (Literal['global', 'samplewise']) – Defines how additional dimensions ... should be handled. Should be one of the following:
  - global: Additional dimensions are flattened along the batch dimension
  - samplewise: Statistics are calculated independently for each sample on the N axis. The statistics in this case are calculated over the additional dimensions.
ignore_index (Optional[int]) – Specifies a target value that is ignored and does not contribute to the metric calculation
validate_args (bool) – Whether input arguments and tensors should be validated for correctness. Set to False for faster computations.
- Return type:
Tensor
- Returns:
If multidim_average is set to global, the metric returns a scalar value. If multidim_average is set to samplewise, the metric returns an (N,) vector consisting of a scalar value per sample.
- Example (preds is int tensor):
>>> from torch import tensor
>>> from torchmetrics.functional.classification import binary_precision
>>> target = tensor([0, 1, 0, 1, 0, 1])
>>> preds = tensor([0, 0, 1, 1, 0, 1])
>>> binary_precision(preds, target)
tensor(0.6667)
- Example (preds is float tensor):
>>> from torchmetrics.functional.classification import binary_precision
>>> target = tensor([0, 1, 0, 1, 0, 1])
>>> preds = tensor([0.11, 0.22, 0.84, 0.73, 0.33, 0.92])
>>> binary_precision(preds, target)
tensor(0.6667)
- Example (multidim tensors):
>>> from torchmetrics.functional.classification import binary_precision
>>> target = tensor([[[0, 1], [1, 0], [0, 1]], [[1, 1], [0, 0], [1, 0]]])
>>> preds = tensor([[[0.59, 0.91], [0.91, 0.99], [0.63, 0.04]],
...                 [[0.38, 0.04], [0.86, 0.780], [0.45, 0.37]]])
>>> binary_precision(preds, target, multidim_average='samplewise')
tensor([0.4000, 0.0000])
multiclass_precision¶
- torchmetrics.functional.classification.multiclass_precision(preds, target, num_classes, average='macro', top_k=1, multidim_average='global', ignore_index=None, validate_args=True)[source]¶
Compute Precision for multiclass tasks.
\[\text{Precision} = \frac{\text{TP}}{\text{TP} + \text{FP}}\]
where \(\text{TP}\) and \(\text{FP}\) represent the number of true positives and false positives respectively.
Accepts the following input tensors:
preds: (N, ...) (int tensor) or (N, C, ...) (float tensor). If preds is a floating point tensor, torch.argmax is applied along the C dimension to automatically convert probabilities/logits into an int tensor.
target (int tensor): (N, ...)
- Parameters:
preds (Tensor) – Tensor with predictions
target (Tensor) – Tensor with true labels
num_classes (int) – Integer specifying the number of classes
average (Optional[Literal['micro', 'macro', 'weighted', 'none']]) – Defines the reduction that is applied over labels. Should be one of the following:
  - micro: Sum statistics over all labels
  - macro: Calculate statistics for each label and average them
  - weighted: Calculate statistics for each label and compute a weighted average using their support
  - "none" or None: Calculate statistics for each label and apply no reduction
top_k (int) – Number of highest probability or logit score predictions considered to find the correct label. Only works when preds contain probabilities/logits.
multidim_average (Literal['global', 'samplewise']) – Defines how additional dimensions ... should be handled. Should be one of the following:
  - global: Additional dimensions are flattened along the batch dimension
  - samplewise: Statistics are calculated independently for each sample on the N axis. The statistics in this case are calculated over the additional dimensions.
ignore_index (Optional[int]) – Specifies a target value that is ignored and does not contribute to the metric calculation
validate_args (bool) – Whether input arguments and tensors should be validated for correctness. Set to False for faster computations.
- Returns:
The returned shape depends on the average and multidim_average arguments:
If multidim_average is set to global:
  - If average='micro'/'macro'/'weighted', the output will be a scalar tensor
  - If average=None/'none', the shape will be (C,)
If multidim_average is set to samplewise:
  - If average='micro'/'macro'/'weighted', the shape will be (N,)
  - If average=None/'none', the shape will be (N, C)
- Return type:
Tensor
- Example (preds is int tensor):
>>> from torch import tensor
>>> from torchmetrics.functional.classification import multiclass_precision
>>> target = tensor([2, 1, 0, 0])
>>> preds = tensor([2, 1, 0, 1])
>>> multiclass_precision(preds, target, num_classes=3)
tensor(0.8333)
>>> multiclass_precision(preds, target, num_classes=3, average=None)
tensor([1.0000, 0.5000, 1.0000])
- Example (preds is float tensor):
>>> from torchmetrics.functional.classification import multiclass_precision
>>> target = tensor([2, 1, 0, 0])
>>> preds = tensor([[0.16, 0.26, 0.58],
...                 [0.22, 0.61, 0.17],
...                 [0.71, 0.09, 0.20],
...                 [0.05, 0.82, 0.13]])
>>> multiclass_precision(preds, target, num_classes=3)
tensor(0.8333)
>>> multiclass_precision(preds, target, num_classes=3, average=None)
tensor([1.0000, 0.5000, 1.0000])
- Example (multidim tensors):
>>> from torchmetrics.functional.classification import multiclass_precision
>>> target = tensor([[[0, 1], [2, 1], [0, 2]], [[1, 1], [2, 0], [1, 2]]])
>>> preds = tensor([[[0, 2], [2, 0], [0, 1]], [[2, 2], [2, 1], [1, 0]]])
>>> multiclass_precision(preds, target, num_classes=3, multidim_average='samplewise')
tensor([0.3889, 0.2778])
>>> multiclass_precision(preds, target, num_classes=3, multidim_average='samplewise', average=None)
tensor([[0.6667, 0.0000, 0.5000],
        [0.0000, 0.5000, 0.3333]])
multilabel_precision¶
- torchmetrics.functional.classification.multilabel_precision(preds, target, num_labels, threshold=0.5, average='macro', multidim_average='global', ignore_index=None, validate_args=True)[source]¶
Compute Precision for multilabel tasks.
\[\text{Precision} = \frac{\text{TP}}{\text{TP} + \text{FP}}\]
where \(\text{TP}\) and \(\text{FP}\) represent the number of true positives and false positives respectively.
Accepts the following input tensors:
preds (int or float tensor): (N, C, ...). If preds is a floating point tensor with values outside the [0, 1] range, the input is considered to be logits and sigmoid is automatically applied per element. Additionally, preds is converted to an int tensor by thresholding with the value in threshold.
target (int tensor): (N, C, ...)
- Parameters:
preds (Tensor) – Tensor with predictions
target (Tensor) – Tensor with true labels
num_labels (int) – Integer specifying the number of labels
threshold (float) – Threshold for transforming probability to binary (0,1) predictions
average (Optional[Literal['micro', 'macro', 'weighted', 'none']]) – Defines the reduction that is applied over labels. Should be one of the following:
  - micro: Sum statistics over all labels
  - macro: Calculate statistics for each label and average them
  - weighted: Calculate statistics for each label and compute a weighted average using their support
  - "none" or None: Calculate statistics for each label and apply no reduction
multidim_average (Literal['global', 'samplewise']) – Defines how additional dimensions ... should be handled. Should be one of the following:
  - global: Additional dimensions are flattened along the batch dimension
  - samplewise: Statistics are calculated independently for each sample on the N axis. The statistics in this case are calculated over the additional dimensions.
ignore_index (Optional[int]) – Specifies a target value that is ignored and does not contribute to the metric calculation
validate_args (bool) – Whether input arguments and tensors should be validated for correctness. Set to False for faster computations.
- Returns:
The returned shape depends on the average and multidim_average arguments:
If multidim_average is set to global:
  - If average='micro'/'macro'/'weighted', the output will be a scalar tensor
  - If average=None/'none', the shape will be (C,)
If multidim_average is set to samplewise:
  - If average='micro'/'macro'/'weighted', the shape will be (N,)
  - If average=None/'none', the shape will be (N, C)
- Return type:
Tensor
- Example (preds is int tensor):
>>> from torch import tensor
>>> from torchmetrics.functional.classification import multilabel_precision
>>> target = tensor([[0, 1, 0], [1, 0, 1]])
>>> preds = tensor([[0, 0, 1], [1, 0, 1]])
>>> multilabel_precision(preds, target, num_labels=3)
tensor(0.5000)
>>> multilabel_precision(preds, target, num_labels=3, average=None)
tensor([1.0000, 0.0000, 0.5000])
- Example (preds is float tensor):
>>> from torchmetrics.functional.classification import multilabel_precision
>>> target = tensor([[0, 1, 0], [1, 0, 1]])
>>> preds = tensor([[0.11, 0.22, 0.84], [0.73, 0.33, 0.92]])
>>> multilabel_precision(preds, target, num_labels=3)
tensor(0.5000)
>>> multilabel_precision(preds, target, num_labels=3, average=None)
tensor([1.0000, 0.0000, 0.5000])
- Example (multidim tensors):
>>> from torchmetrics.functional.classification import multilabel_precision
>>> target = tensor([[[0, 1], [1, 0], [0, 1]], [[1, 1], [0, 0], [1, 0]]])
>>> preds = tensor([[[0.59, 0.91], [0.91, 0.99], [0.63, 0.04]],
...                 [[0.38, 0.04], [0.86, 0.780], [0.45, 0.37]]])
>>> multilabel_precision(preds, target, num_labels=3, multidim_average='samplewise')
tensor([0.3333, 0.0000])
>>> multilabel_precision(preds, target, num_labels=3, multidim_average='samplewise', average=None)
tensor([[0.5000, 0.5000, 0.0000],
        [0.0000, 0.0000, 0.0000]])
Precision At Fixed Recall¶
Module Interface¶
- class torchmetrics.PrecisionAtFixedRecall(**kwargs)[source]¶
Compute the highest possible precision value given the minimum recall thresholds provided.
This is done by first calculating the precision-recall curve for different thresholds and then finding the precision for a given recall level.
This class is a simple wrapper to get the task-specific versions of this metric, which is done by setting the task argument to either 'binary', 'multiclass' or 'multilabel'. See the documentation of BinaryPrecisionAtFixedRecall, MulticlassPrecisionAtFixedRecall and MultilabelPrecisionAtFixedRecall for the specific details of each argument's influence and for examples.
BinaryPrecisionAtFixedRecall¶
- class torchmetrics.classification.BinaryPrecisionAtFixedRecall(min_recall, thresholds=None, ignore_index=None, validate_args=True, **kwargs)[source]¶
Compute the highest possible precision value given the minimum recall thresholds provided.
This is done by first calculating the precision-recall curve for different thresholds and then finding the precision for a given recall level.
As input to forward and update the metric accepts the following input:
preds (Tensor): A float tensor of shape (N, ...). Preds should be a tensor containing probabilities or logits for each observation. If preds has values outside the [0, 1] range, the input is considered to be logits and sigmoid is automatically applied per element.
target (Tensor): An int tensor of shape (N, ...). Target should be a tensor containing ground truth labels, and therefore only contain {0,1} values (except if ignore_index is specified). The value 1 always encodes the positive class.
Note
Additional dimensions ... will be flattened into the batch dimension.
As output to forward and compute the metric returns the following output:
precision (Tensor): A scalar tensor with the maximum precision for the given recall level
threshold (Tensor): A scalar tensor with the corresponding threshold level
Note
The implementation supports both a non-binned but accurate version and a binned version that is less accurate but more memory efficient. Setting the thresholds argument to None activates the non-binned version, which uses memory of size \(\mathcal{O}(n_{samples})\), whereas setting the thresholds argument to an integer, a list or a 1d tensor activates the binned version, which uses memory of size \(\mathcal{O}(n_{thresholds})\) (constant memory).
- Parameters:
min_recall (float) – float value specifying the minimum recall threshold.
thresholds (Union[int, List[float], Tensor, None]) – Can be one of:
  - If set to None, a non-binned approach is used where thresholds are dynamically calculated from all the data. This is the most accurate but also the most memory-consuming approach.
  - If set to an int (larger than 1), that number of thresholds linearly spaced from 0 to 1 is used as bins for the calculation.
  - If set to a list of floats, the indicated thresholds in the list are used as bins for the calculation.
  - If set to a 1d Tensor of floats, the indicated thresholds in the tensor are used as bins for the calculation.
ignore_index (Optional[int]) – Specifies a target value that is ignored and does not contribute to the metric calculation
validate_args (bool) – Whether input arguments and tensors should be validated for correctness. Set to False for faster computations.
kwargs (Any) – Additional keyword arguments, see Advanced metric settings for more info.
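The curve-then-select procedure described above can be sketched in plain Python. This is an illustration under stated assumptions, not the library's implementation: a sample counts as a positive prediction when its score is `>=` the threshold; with thresholds=None the candidates are the distinct scores, with an int they are evenly spaced on [0, 1], and with a list they are used as given (tensors are not handled here); a threshold with no positive predictions is assumed to have precision 1.0; precision ties are broken toward higher recall and then the higher threshold; and `(0.0, 1e6)` is returned when no usable threshold exists, mirroring the `1.0000e+06` placeholder in the multiclass example. `precision_at_fixed_recall` is a hypothetical name.

```python
from typing import List, Optional, Sequence, Tuple, Union


def precision_at_fixed_recall(
    preds: Sequence[float],
    target: Sequence[int],
    min_recall: float,
    thresholds: Optional[Union[int, List[float]]] = None,
) -> Tuple[float, float]:
    """Return (best precision, threshold) among thresholds whose recall >= min_recall."""
    if thresholds is None:
        candidates = sorted(set(preds))  # non-binned: every distinct score is a candidate
    elif isinstance(thresholds, int):
        if thresholds < 2:
            raise ValueError("an integer number of thresholds must be larger than 1")
        candidates = [i / (thresholds - 1) for i in range(thresholds)]  # binned: evenly spaced on [0, 1]
    else:
        candidates = list(thresholds)  # binned: user-supplied bins
    n_pos = sum(1 for t in target if t == 1)
    points = []
    for thr in candidates:
        # A sample is predicted positive when its score is >= the threshold (assumption)
        tp = sum(1 for p, t in zip(preds, target) if p >= thr and t == 1)
        fp = sum(1 for p, t in zip(preds, target) if p >= thr and t == 0)
        precision = tp / (tp + fp) if tp + fp else 1.0  # no positive predictions: assumed 1.0
        recall = tp / n_pos if n_pos else 0.0
        points.append((precision, recall, thr))
    eligible = [pt for pt in points if pt[1] >= min_recall]
    if not eligible:
        return 0.0, 1e6  # no threshold reaches min_recall
    # Tuple comparison breaks precision ties by higher recall, then by higher threshold
    best_precision, _, best_threshold = max(eligible)
    if best_precision == 0.0:
        return 0.0, 1e6
    return best_precision, best_threshold


preds = [0, 0.5, 0.7, 0.8]
target = [0, 1, 1, 0]
print(precision_at_fixed_recall(preds, target, min_recall=0.5))  # (0.6666666666666666, 0.5)
print(precision_at_fixed_recall(preds, target, min_recall=0.5, thresholds=5))  # (0.6666666666666666, 0.5)
```

Both calls agree with the example that follows: at threshold 0.5, two of the three positive predictions are correct while recall is 1.0.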
Example
>>> from torch import tensor
>>> from torchmetrics.classification import BinaryPrecisionAtFixedRecall
>>> preds = tensor([0, 0.5, 0.7, 0.8])
>>> target = tensor([0, 1, 1, 0])
>>> metric = BinaryPrecisionAtFixedRecall(min_recall=0.5, thresholds=None)
>>> metric(preds, target)
(tensor(0.6667), tensor(0.5000))
>>> metric = BinaryPrecisionAtFixedRecall(min_recall=0.5, thresholds=5)
>>> metric(preds, target)
(tensor(0.6667), tensor(0.5000))
- plot(val=None, ax=None)[source]¶
Plot a single or multiple values from the metric.
- Parameters:
val (Union[Tensor, Sequence[Tensor], None]) – Either a single result from calling metric.forward or metric.compute, or a list of these results. If no value is provided, metric.compute is called automatically and its result is plotted.
ax (Optional[Axes]) – A matplotlib axis object. If provided, the plot is added to that axis.
- Returns:
Figure object and Axes object
- Raises:
ModuleNotFoundError – If matplotlib is not installed
>>> from torch import rand, randint
>>> # Example plotting a single value
>>> from torchmetrics.classification import BinaryPrecisionAtFixedRecall
>>> metric = BinaryPrecisionAtFixedRecall(min_recall=0.5)
>>> metric.update(rand(10), randint(2, (10,)))
>>> fig_, ax_ = metric.plot()  # the returned plot only shows the maximum precision value by default
>>> from torch import rand, randint
>>> # Example plotting multiple values
>>> from torchmetrics.classification import BinaryPrecisionAtFixedRecall
>>> metric = BinaryPrecisionAtFixedRecall(min_recall=0.5)
>>> values = []
>>> for _ in range(10):
...     # we index by 0 such that only the maximum precision value is plotted
...     values.append(metric(rand(10), randint(2, (10,)))[0])
>>> fig_, ax_ = metric.plot(values)
MulticlassPrecisionAtFixedRecall¶
- class torchmetrics.classification.MulticlassPrecisionAtFixedRecall(num_classes, min_recall, thresholds=None, ignore_index=None, validate_args=True, **kwargs)[source]¶
Compute the highest possible precision value given the minimum recall thresholds provided.
This is done by first calculating the precision-recall curve for different thresholds and then finding the precision for a given recall level.
As input to forward and update the metric accepts the following input:
preds (Tensor): A float tensor of shape (N, C, ...). Preds should be a tensor containing probabilities or logits for each observation. If preds has values outside the [0, 1] range, the input is considered to be logits and softmax is automatically applied per sample.
target (Tensor): An int tensor of shape (N, ...). Target should be a tensor containing ground truth labels, and therefore only contain values in the [0, n_classes-1] range (except if ignore_index is specified).
Note
Additional dimensions ... will be flattened into the batch dimension.
As output to forward and compute the metric returns a tuple of either 2 tensors or 2 lists containing:
precision (Tensor): A 1d tensor of size (n_classes, ) with the maximum precision for the given recall level per class
threshold (Tensor): A 1d tensor of size (n_classes, ) with the corresponding threshold level per class
Note
The implementation supports both a non-binned but accurate version and a binned version that is less accurate but more memory efficient. Setting the thresholds argument to None activates the non-binned version, which uses memory of size \(\mathcal{O}(n_{samples})\), whereas setting the thresholds argument to an integer, a list or a 1d tensor activates the binned version, which uses memory of size \(\mathcal{O}(n_{thresholds} \times n_{classes})\) (constant memory).
- Parameters:
num_classes (int) – Integer specifying the number of classes
min_recall (float) – float value specifying the minimum recall threshold.
thresholds (Union[int, List[float], Tensor, None]) – Can be one of:
  - If set to None, a non-binned approach is used where thresholds are dynamically calculated from all the data. This is the most accurate but also the most memory-consuming approach.
  - If set to an int (larger than 1), that number of thresholds linearly spaced from 0 to 1 is used as bins for the calculation.
  - If set to a list of floats, the indicated thresholds in the list are used as bins for the calculation.
  - If set to a 1d Tensor of floats, the indicated thresholds in the tensor are used as bins for the calculation.
ignore_index (Optional[int]) – Specifies a target value that is ignored and does not contribute to the metric calculation
validate_args (bool) – Whether input arguments and tensors should be validated for correctness. Set to False for faster computations.
kwargs (Any) – Additional keyword arguments, see Advanced metric settings for more info.
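One way to picture the multiclass computation is one-vs-rest: each class's score column is treated as a binary problem with target == c as the positive label. The plain-Python sketch below follows that reading under these assumptions: `>=` thresholding, candidates taken from the distinct scores (thresholds=None) or evenly spaced on [0, 1] (int thresholds), precision 1.0 for a threshold with no positive predictions, ties broken toward higher recall and then the higher threshold, and `(0.0, 1e6)` when no usable threshold exists. It applies no softmax, so it expects probabilities. Both function names are hypothetical.

```python
from typing import List, Optional, Sequence, Tuple


def _binary_pafr(scores: Sequence[float], labels: Sequence[int], min_recall: float,
                 thresholds: Optional[int]) -> Tuple[float, float]:
    """Best precision with recall >= min_recall for one binary (one-vs-rest) problem."""
    if thresholds is None:
        candidates = sorted(set(scores))  # non-binned: distinct scores
    else:
        if thresholds < 2:
            raise ValueError("an integer number of thresholds must be larger than 1")
        candidates = [i / (thresholds - 1) for i in range(thresholds)]  # binned: evenly spaced
    n_pos = sum(labels)
    points = []
    for thr in candidates:
        tp = sum(1 for s, y in zip(scores, labels) if s >= thr and y == 1)
        fp = sum(1 for s, y in zip(scores, labels) if s >= thr and y == 0)
        precision = tp / (tp + fp) if tp + fp else 1.0
        recall = tp / n_pos if n_pos else 0.0
        points.append((precision, recall, thr))
    eligible = [pt for pt in points if pt[1] >= min_recall]
    if not eligible or max(eligible)[0] == 0.0:
        return 0.0, 1e6  # no usable threshold for this class
    best_precision, _, best_threshold = max(eligible)
    return best_precision, best_threshold


def multiclass_pafr(preds: Sequence[Sequence[float]], target: Sequence[int], num_classes: int,
                    min_recall: float, thresholds: Optional[int] = None) -> Tuple[List[float], List[float]]:
    """One-vs-rest precision at fixed recall, one result per class."""
    precisions, best_thresholds = [], []
    for c in range(num_classes):
        scores = [row[c] for row in preds]  # score column for class c
        labels = [1 if t == c else 0 for t in target]  # class c vs. the rest
        p, thr = _binary_pafr(scores, labels, min_recall, thresholds)
        precisions.append(p)
        best_thresholds.append(thr)
    return precisions, best_thresholds


preds = [[0.75, 0.05, 0.05, 0.05, 0.05],
         [0.05, 0.75, 0.05, 0.05, 0.05],
         [0.05, 0.05, 0.75, 0.05, 0.05],
         [0.05, 0.05, 0.05, 0.75, 0.05]]
target = [0, 1, 3, 2]
print(multiclass_pafr(preds, target, num_classes=5, min_recall=0.5))
# ([1.0, 1.0, 0.25, 0.25, 0.0], [0.75, 0.75, 0.05, 0.05, 1000000.0])
print(multiclass_pafr(preds, target, num_classes=5, min_recall=0.5, thresholds=5))
# ([1.0, 1.0, 0.25, 0.25, 0.0], [0.75, 0.75, 0.0, 0.0, 1000000.0])
```

Under these assumptions the sketch lands on the same numbers as the example that follows; class 4 never occurs in target, so its recall is always 0 and it falls back to the placeholder.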
Example
>>> from torch import tensor
>>> from torchmetrics.classification import MulticlassPrecisionAtFixedRecall
>>> preds = tensor([[0.75, 0.05, 0.05, 0.05, 0.05],
...                 [0.05, 0.75, 0.05, 0.05, 0.05],
...                 [0.05, 0.05, 0.75, 0.05, 0.05],
...                 [0.05, 0.05, 0.05, 0.75, 0.05]])
>>> target = tensor([0, 1, 3, 2])
>>> metric = MulticlassPrecisionAtFixedRecall(num_classes=5, min_recall=0.5, thresholds=None)
>>> metric(preds, target)
(tensor([1.0000, 1.0000, 0.2500, 0.2500, 0.0000]), tensor([7.5000e-01, 7.5000e-01, 5.0000e-02, 5.0000e-02, 1.0000e+06]))
>>> mcrafp = MulticlassPrecisionAtFixedRecall(num_classes=5, min_recall=0.5, thresholds=5)
>>> mcrafp(preds, target)
(tensor([1.0000, 1.0000, 0.2500, 0.2500, 0.0000]), tensor([7.5000e-01, 7.5000e-01, 0.0000e+00, 0.0000e+00, 1.0000e+06]))
- plot(val=None, ax=None)[source]¶
Plot a single or multiple values from the metric.
- Parameters:
val (Union[Tensor, Sequence[Tensor], None]) – Either a single result from calling metric.forward or metric.compute, or a list of these results. If no value is provided, metric.compute is called automatically and its result is plotted.
ax (Optional[Axes]) – A matplotlib axis object. If provided, the plot is added to that axis.
- Returns:
Figure object and Axes object
- Raises:
ModuleNotFoundError – If matplotlib is not installed
>>> from torch import rand, randint
>>> # Example plotting a single value per class
>>> from torchmetrics.classification import MulticlassPrecisionAtFixedRecall
>>> metric = MulticlassPrecisionAtFixedRecall(num_classes=3, min_recall=0.5)
>>> metric.update(rand(20, 3).softmax(dim=-1), randint(3, (20,)))
>>> fig_, ax_ = metric.plot()  # the returned plot only shows the maximum precision value by default
>>> from torch import rand, randint
>>> # Example plotting multiple values per class
>>> from torchmetrics.classification import MulticlassPrecisionAtFixedRecall
>>> metric = MulticlassPrecisionAtFixedRecall(num_classes=3, min_recall=0.5)
>>> values = []
>>> for _ in range(20):
...     # we index by 0 such that only the maximum precision value is plotted
...     values.append(metric(rand(20, 3).softmax(dim=-1), randint(3, (20,)))[0])
>>> fig_, ax_ = metric.plot(values)
MultilabelPrecisionAtFixedRecall¶
- class torchmetrics.classification.MultilabelPrecisionAtFixedRecall(num_labels, min_recall, thresholds=None, ignore_index=None, validate_args=True, **kwargs)[source]¶
Compute the highest possible precision value given the minimum recall thresholds provided.
This is done by first calculating the precision-recall curve for different thresholds and then finding the precision for a given recall level.
As input to forward and update the metric accepts the following input:
preds (Tensor): A float tensor of shape (N, C, ...). Preds should be a tensor containing probabilities or logits for each observation. If preds has values outside the [0, 1] range, the input is considered to be logits and sigmoid is automatically applied per element.
target (Tensor): An int tensor of shape (N, C, ...). Target should be a tensor containing ground truth labels, and therefore only contain {0,1} values (except if ignore_index is specified). The value 1 always encodes the positive class.
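The logit rule described above can be sketched for a flat list of scores in pure Python. The helper name is hypothetical, and applying the range check to the whole input at once (rather than element by element) is an assumption; one consequence of such a check is that logits which all happen to fall inside [0, 1] would be treated as probabilities.

```python
import math


def maybe_apply_sigmoid(preds):
    """Treat preds as logits if any value lies outside [0, 1]; otherwise keep them as probabilities."""
    if all(0.0 <= p <= 1.0 for p in preds):
        return list(preds)
    # at least one value is out of range: apply the sigmoid to every element
    return [1.0 / (1.0 + math.exp(-p)) for p in preds]


probs = maybe_apply_sigmoid([0.2, 0.9, 0.5])     # all in [0, 1]: returned unchanged
logits = maybe_apply_sigmoid([-2.0, 0.0, 3.0])   # out of range: sigmoid per element
print(probs)      # [0.2, 0.9, 0.5]
print(logits[1])  # 0.5, since sigmoid(0) = 0.5
```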
Note
Additional dimensions ... will be flattened into the batch dimension.
As output to forward and compute the metric returns a tuple of either 2 tensors or 2 lists containing:
precision (Tensor): A 1d tensor of size (n_labels, ) with the maximum precision for the given recall level per label
threshold (Tensor): A 1d tensor of size (n_labels, ) with the corresponding threshold level per label
Note
The implementation supports both a non-binned but accurate version and a binned version that is less accurate but more memory efficient. Setting the thresholds argument to None activates the non-binned version, which uses memory of size \(\mathcal{O}(n_{samples})\), whereas setting the thresholds argument to an integer, a list or a 1d tensor uses a binned version with memory of size \(\mathcal{O}(n_{thresholds} \times n_{labels})\) (constant memory).
- Parameters:
num_labels (int) – Integer specifying the number of labels
min_recall (float) – float value specifying the minimum recall threshold.
thresholds (Union[int, List[float], Tensor, None]) – Can be one of:
If set to None, will use a non-binned approach where thresholds are dynamically calculated from all the data. Most accurate but also most memory consuming approach.
If set to an int (larger than 1), will use that number of thresholds linearly spaced from 0 to 1 as bins for the calculation.
If set to a list of floats, will use the indicated thresholds in the list as bins for the calculation.
If set to a 1d Tensor of floats, will use the indicated thresholds in the tensor as bins for the calculation.
ignore_index (Optional[int]) – Specifies a target value that is ignored and does not contribute to the metric calculation
validate_args (bool) – bool indicating if input arguments and tensors should be validated for correctness. Set to False for faster computations.
kwargs (Any) – Additional keyword arguments, see Advanced metric settings for more info.
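Why the binned mode needs constant memory can be seen from a streaming sketch: with a fixed set of thresholds, each batch only increments true-positive, false-positive and false-negative counters per threshold, so the state never grows with the number of samples. This hypothetical class tracks a single label (a multilabel version would keep one set of counters per label, hence \(\mathcal{O}(n_{thresholds} \times n_{labels})\)). It assumes that an integer thresholds=n means the n values i / (n - 1) and it uses the same tie-breaking as a max over (precision, recall, threshold); the data is the last label column of the example below, split into two batches.

```python
class BinnedPrecisionAtRecall:
    """Hypothetical binned accumulator for one label: memory is O(num_thresholds)."""

    def __init__(self, num_thresholds, min_recall):
        # num_thresholds values linearly spaced from 0 to 1
        self.thresholds = [i / (num_thresholds - 1) for i in range(num_thresholds)]
        self.min_recall = min_recall
        self.tp = [0] * num_thresholds
        self.fp = [0] * num_thresholds
        self.fn = [0] * num_thresholds

    def update(self, scores, labels):
        # only per-threshold counters change; individual samples are not stored
        for s, y in zip(scores, labels):
            for i, t in enumerate(self.thresholds):
                predicted_pos = s >= t
                if predicted_pos and y == 1:
                    self.tp[i] += 1
                elif predicted_pos:
                    self.fp[i] += 1
                elif y == 1:
                    self.fn[i] += 1

    def compute(self):
        candidates = []
        for tp, fp, fn, t in zip(self.tp, self.fp, self.fn, self.thresholds):
            precision = tp / (tp + fp) if tp + fp else 0.0
            recall = tp / (tp + fn) if tp + fn else float("nan")
            if recall >= self.min_recall:
                candidates.append((precision, recall, t))
        if not candidates:
            return 0.0, None
        precision, _, threshold = max(candidates)
        return precision, threshold


metric = BinnedPrecisionAtRecall(num_thresholds=5, min_recall=0.5)
metric.update([0.35, 0.05], [1, 0])  # first batch
metric.update([0.75, 0.05], [1, 1])  # second batch
print(metric.compute())  # (1.0, 0.25)
```

The result agrees with the last entries of the documented thresholds=5 output (precision 1.0000, threshold 0.2500).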
Example
>>> from torch import tensor
>>> from torchmetrics.classification import MultilabelPrecisionAtFixedRecall
>>> preds = tensor([[0.75, 0.05, 0.35],
...                 [0.45, 0.75, 0.05],
...                 [0.05, 0.55, 0.75],
...                 [0.05, 0.65, 0.05]])
>>> target = tensor([[1, 0, 1],
...                  [0, 0, 0],
...                  [0, 1, 1],
...                  [1, 1, 1]])
>>> metric = MultilabelPrecisionAtFixedRecall(num_labels=3, min_recall=0.5, thresholds=None)
>>> metric(preds, target)
(tensor([1.0000, 0.6667, 1.0000]), tensor([0.7500, 0.5500, 0.3500]))
>>> mlrafp = MultilabelPrecisionAtFixedRecall(num_labels=3, min_recall=0.5, thresholds=5)
>>> mlrafp(preds, target)
(tensor([1.0000, 0.6667, 1.0000]), tensor([0.7500, 0.5000, 0.2500]))
- plot(val=None, ax=None)[source]¶
Plot a single or multiple values from the metric.
- Parameters:
val (Union[Tensor, Sequence[Tensor], None]) – Either a single result from calling metric.forward or metric.compute, or a list of these results. If no value is provided, metric.compute is called automatically and that result is plotted.
ax (Optional[Axes]) – A matplotlib axis object. If provided, the plot is added to that axis.
- Return type:
- Returns:
Figure object and Axes object
- Raises:
ModuleNotFoundError – If matplotlib is not installed
>>> from torch import rand, randint
>>> # Example plotting a single value
>>> from torchmetrics.classification import MultilabelPrecisionAtFixedRecall
>>> metric = MultilabelPrecisionAtFixedRecall(num_labels=3, min_recall=0.5)
>>> metric.update(rand(20, 3), randint(2, (20, 3)))
>>> fig_, ax_ = metric.plot()  # the returned plot only shows the maximum precision value by default
>>> from torch import rand, randint
>>> # Example plotting multiple values
>>> from torchmetrics.classification import MultilabelPrecisionAtFixedRecall
>>> metric = MultilabelPrecisionAtFixedRecall(num_labels=3, min_recall=0.5)
>>> values = []
>>> for _ in range(10):
...     # index 0 selects the precision part of the output, so only the maximum precision values are plotted
...     values.append(metric(rand(20, 3), randint(2, (20, 3)))[0])
>>> fig_, ax_ = metric.plot(values)
Functional Interface¶
binary_precision_at_fixed_recall¶
- torchmetrics.functional.classification.binary_precision_at_fixed_recall(preds, target, min_recall, thresholds=None, ignore_index=None, validate_args=True)[source]¶
Compute the highest possible precision value given the provided minimum recall threshold for binary tasks.
This is done by first calculating the precision-recall curve for different thresholds and then finding the highest precision among the thresholds whose recall is at least min_recall.
Accepts the following input tensors:
preds (float tensor): (N, ...). Preds should be a tensor containing probabilities or logits for each observation. If preds has values outside the [0, 1] range, the input is considered to be logits and sigmoid is automatically applied per element.
target (int tensor): (N, ...). Target should be a tensor containing ground truth labels, and therefore only contain {0,1} values (except if ignore_index is specified). The value 1 always encodes the positive class.
Additional dimensions ... will be flattened into the batch dimension.
The implementation supports both a non-binned but accurate version and a binned version that is less accurate but more memory efficient. Setting the thresholds argument to None activates the non-binned version, which uses memory of size \(\mathcal{O}(n_{samples})\), whereas setting the thresholds argument to an integer, a list or a 1d tensor uses a binned version with memory of size \(\mathcal{O}(n_{thresholds})\) (constant memory).
- Parameters:
preds (Tensor) – Tensor with predictions
target (Tensor) – Tensor with true labels
min_recall (float) – float value specifying the minimum recall threshold.
thresholds (Union[int, List[float], Tensor, None]) – Can be one of:
If set to None, will use a non-binned approach where thresholds are dynamically calculated from all the data. Most accurate but also most memory consuming approach.
If set to an int (larger than 1), will use that number of thresholds linearly spaced from 0 to 1 as bins for the calculation.
If set to a list of floats, will use the indicated thresholds in the list as bins for the calculation.
If set to a 1d Tensor of floats, will use the indicated thresholds in the tensor as bins for the calculation.
ignore_index (Optional[int]) – Specifies a target value that is ignored and does not contribute to the metric calculation
validate_args (bool) – bool indicating if input arguments and tensors should be validated for correctness. Set to False for faster computations.
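The effect of ignore_index can be illustrated with a pure-Python sketch: samples whose target equals ignore_index are dropped before the curve is built, so they count neither as positives nor as negatives. The function name and the use of -1 as the ignored value are illustrative assumptions; the first call uses the data of the example in this section.

```python
def precision_at_fixed_recall(preds, target, min_recall, ignore_index=None):
    """Sketch: drop ignored targets, then pick the best precision with recall >= min_recall."""
    if ignore_index is not None:
        pairs = [(p, y) for p, y in zip(preds, target) if y != ignore_index]
    else:
        pairs = list(zip(preds, target))
    total_pos = sum(1 for _, y in pairs if y == 1)
    candidates = []
    for t in sorted({p for p, _ in pairs}):
        tp = sum(1 for p, y in pairs if p >= t and y == 1)
        fp = sum(1 for p, y in pairs if p >= t and y == 0)
        precision = tp / (tp + fp)  # t is one of the scores, so tp + fp >= 1
        recall = tp / total_pos if total_pos else float("nan")
        if recall >= min_recall:
            candidates.append((precision, recall, t))
    if not candidates:
        return 0.0, None
    precision, _, threshold = max(candidates)  # ties: higher recall, then higher threshold
    return precision, threshold


preds = [0.0, 0.5, 0.7, 0.8]
print(precision_at_fixed_recall(preds, [0, 1, 1, 0], min_recall=0.5))
# (0.6666666666666666, 0.5), the same values as the documented (tensor(0.6667), tensor(0.5000))
print(precision_at_fixed_recall(preds, [0, 1, 1, -1], min_recall=0.5, ignore_index=-1))
# (1.0, 0.5): the sample scored 0.8 no longer counts as a false positive
```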
- Returns:
a tuple of 2 tensors containing:
precision: a scalar tensor with the maximum precision for the given recall level
threshold: a scalar tensor with the corresponding threshold level
- Return type:
(tuple)
Example
>>> import torch
>>> from torchmetrics.functional.classification import binary_precision_at_fixed_recall
>>> preds = torch.tensor([0, 0.5, 0.7, 0.8])
>>> target = torch.tensor([0, 1, 1, 0])
>>> binary_precision_at_fixed_recall(preds, target, min_recall=0.5, thresholds=None)
(tensor(0.6667), tensor(0.5000))
>>> binary_precision_at_fixed_recall(preds, target, min_recall=0.5, thresholds=5)
(tensor(0.6667), tensor(0.5000))
multiclass_precision_at_fixed_recall¶
- torchmetrics.functional.classification.multiclass_precision_at_fixed_recall(preds, target, num_classes, min_recall, thresholds=None, ignore_index=None, validate_args=True)[source]¶
Compute the highest possible precision value given the provided minimum recall threshold for multiclass tasks.
This is done by first calculating the precision-recall curve for different thresholds and then finding the highest precision among the thresholds whose recall is at least min_recall.
Accepts the following input tensors:
preds (float tensor): (N, C, ...). Preds should be a tensor containing probabilities or logits for each observation. If preds has values outside the [0, 1] range, the input is considered to be logits and softmax is automatically applied per sample.
target (int tensor): (N, ...). Target should be a tensor containing ground truth labels, and therefore only contain values in the [0, n_classes-1] range (except if ignore_index is specified).
Additional dimensions ... will be flattened into the batch dimension.
The implementation supports both a non-binned but accurate version and a binned version that is less accurate but more memory efficient. Setting the thresholds argument to None activates the non-binned version, which uses memory of size \(\mathcal{O}(n_{samples})\), whereas setting the thresholds argument to an integer, a list or a 1d tensor uses a binned version with memory of size \(\mathcal{O}(n_{thresholds} \times n_{classes})\) (constant memory).
- Parameters:
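The softmax rule and the one-vs-rest view it feeds into can be sketched in pure Python. The helper name is hypothetical, and checking the range over the whole input before softmaxing every row is an assumption about how the rule is applied.

```python
import math


def maybe_apply_softmax(preds):
    """preds: N rows of C scores. Softmax every row if any value lies outside [0, 1]."""
    if all(0.0 <= v <= 1.0 for row in preds for v in row):
        return [list(row) for row in preds]  # already probabilities: keep as-is
    out = []
    for row in preds:
        m = max(row)  # subtract the row maximum for numerical stability
        exps = [math.exp(v - m) for v in row]
        total = sum(exps)
        out.append([e / total for e in exps])
    return out


logits = [[2.0, 1.0, 0.1],
          [0.0, 0.0, 0.0]]
probs = maybe_apply_softmax(logits)  # 2.0 is outside [0, 1], so every row is softmaxed
print([round(sum(row), 6) for row in probs])  # [1.0, 1.0]

# One-vs-rest view for class c: column c of probs is the score, target == c is the positive label
target = [0, 2]
class_0_scores = [row[0] for row in probs]
class_0_labels = [int(y == 0) for y in target]
print(class_0_labels)  # [1, 0]
```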
preds (Tensor) – Tensor with predictions
target (Tensor) – Tensor with true labels
num_classes (int) – Integer specifying the number of classes
min_recall (float) – float value specifying the minimum recall threshold.
thresholds (Union[int, List[float], Tensor, None]) – Can be one of:
If set to None, will use a non-binned approach where thresholds are dynamically calculated from all the data. Most accurate but also most memory consuming approach.
If set to an int (larger than 1), will use that number of thresholds linearly spaced from 0 to 1 as bins for the calculation.
If set to a list of floats, will use the indicated thresholds in the list as bins for the calculation.
If set to a 1d Tensor of floats, will use the indicated thresholds in the tensor as bins for the calculation.
ignore_index (Optional[int]) – Specifies a target value that is ignored and does not contribute to the metric calculation
validate_args (bool) – bool indicating if input arguments and tensors should be validated for correctness. Set to False for faster computations.
- Returns:
a tuple of either 2 tensors or 2 lists containing
precision: a 1d tensor of size (n_classes, ) with the maximum precision for the given recall level per class
thresholds: a 1d tensor of size (n_classes, ) with the corresponding threshold level per class
- Return type:
(tuple)
Example
>>> import torch
>>> from torchmetrics.functional.classification import multiclass_precision_at_fixed_recall
>>> preds = torch.tensor([[0.75, 0.05, 0.05, 0.05, 0.05],
...                       [0.05, 0.75, 0.05, 0.05, 0.05],
...                       [0.05, 0.05, 0.75, 0.05, 0.05],
...                       [0.05, 0.05, 0.05, 0.75, 0.05]])
>>> target = torch.tensor([0, 1, 3, 2])
>>> multiclass_precision_at_fixed_recall(
...     preds, target, num_classes=5, min_recall=0.5, thresholds=None)
(tensor([1.0000, 1.0000, 0.2500, 0.2500, 0.0000]), tensor([7.5000e-01, 7.5000e-01, 5.0000e-02, 5.0000e-02, 1.0000e+06]))
>>> multiclass_precision_at_fixed_recall(
...     preds, target, num_classes=5, min_recall=0.5, thresholds=5)
(tensor([1.0000, 1.0000, 0.2500, 0.2500, 0.0000]), tensor([7.5000e-01, 7.5000e-01, 0.0000e+00, 0.0000e+00, 1.0000e+06]))
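The non-binned output above can be reproduced with a pure-Python one-vs-rest sketch. Two details are inferred from the example output rather than stated in the text, so treat them as assumptions: a class that never appears in target (class 4 here) has undefined recall and therefore no qualifying threshold, and whenever the best precision is 0 the threshold is reported as the placeholder 1e6. The function name is hypothetical.

```python
def class_precision_at_recall(scores, is_pos, min_recall):
    """One-vs-rest sketch for a single class: is_pos[i] is True when target[i] == class."""
    total_pos = sum(is_pos)
    candidates = []
    for t in sorted(set(scores)):
        tp = sum(1 for s, y in zip(scores, is_pos) if s >= t and y)
        fp = sum(1 for s, y in zip(scores, is_pos) if s >= t and not y)
        precision = tp / (tp + fp)  # t is one of the scores, so tp + fp >= 1
        # a class absent from target has undefined recall (nan), and nan >= min_recall is False
        recall = tp / total_pos if total_pos else float("nan")
        if recall >= min_recall:
            candidates.append((precision, recall, t))
    precision, _, threshold = max(candidates) if candidates else (0.0, 0.0, None)
    if precision == 0.0:
        threshold = 1e6  # placeholder threshold, mirroring the 1.0000e+06 in the example output
    return precision, threshold


preds = [[0.75, 0.05, 0.05, 0.05, 0.05],
         [0.05, 0.75, 0.05, 0.05, 0.05],
         [0.05, 0.05, 0.75, 0.05, 0.05],
         [0.05, 0.05, 0.05, 0.75, 0.05]]
target = [0, 1, 3, 2]

results = [
    class_precision_at_recall([row[c] for row in preds], [y == c for y in target], 0.5)
    for c in range(5)
]
print(results)
# [(1.0, 0.75), (1.0, 0.75), (0.25, 0.05), (0.25, 0.05), (0.0, 1000000.0)]
```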
multilabel_precision_at_fixed_recall¶
- torchmetrics.functional.classification.multilabel_precision_at_fixed_recall(preds, target, num_labels, min_recall, thresholds=None, ignore_index=None, validate_args=True)[source]¶
Compute the highest possible precision value given the provided minimum recall threshold for multilabel tasks.
This is done by first calculating the precision-recall curve for different thresholds and then finding the highest precision among the thresholds whose recall is at least min_recall.
Accepts the following input tensors:
preds (float tensor): (N, C, ...). Preds should be a tensor containing probabilities or logits for each observation. If preds has values outside the [0, 1] range, the input is considered to be logits and sigmoid is automatically applied per element.
target (int tensor): (N, C, ...). Target should be a tensor containing ground truth labels, and therefore only contain {0,1} values (except if ignore_index is specified).
Additional dimensions ... will be flattened into the batch dimension.
The implementation supports both a non-binned but accurate version and a binned version that is less accurate but more memory efficient. Setting the thresholds argument to None activates the non-binned version, which uses memory of size \(\mathcal{O}(n_{samples})\), whereas setting the thresholds argument to an integer, a list or a 1d tensor uses a binned version with memory of size \(\mathcal{O}(n_{thresholds} \times n_{labels})\) (constant memory).
- Parameters:
preds (Tensor) – Tensor with predictions
target (Tensor) – Tensor with true labels
num_labels (int) – Integer specifying the number of labels
min_recall (float) – float value specifying the minimum recall threshold.
thresholds (Union[int, List[float], Tensor, None]) – Can be one of:
If set to None, will use a non-binned approach where thresholds are dynamically calculated from all the data. Most accurate but also most memory consuming approach.
If set to an int (larger than 1), will use that number of thresholds linearly spaced from 0 to 1 as bins for the calculation.
If set to a list of floats, will use the indicated thresholds in the list as bins for the calculation.
If set to a 1d Tensor of floats, will use the indicated thresholds in the tensor as bins for the calculation.
ignore_index (Optional[int]) – Specifies a target value that is ignored and does not contribute to the metric calculation
validate_args (bool) – bool indicating if input arguments and tensors should be validated for correctness. Set to False for faster computations.
- Returns:
a tuple of either 2 tensors or 2 lists containing
precision: a 1d tensor of size (n_labels, ) with the maximum precision for the given recall level per label
thresholds: a 1d tensor of size (n_labels, ) with the corresponding threshold level per label
- Return type:
(tuple)
Example
>>> import torch
>>> from torchmetrics.functional.classification import multilabel_precision_at_fixed_recall
>>> preds = torch.tensor([[0.75, 0.05, 0.35],
...                       [0.45, 0.75, 0.05],
...                       [0.05, 0.55, 0.75],
...                       [0.05, 0.65, 0.05]])
>>> target = torch.tensor([[1, 0, 1],
...                        [0, 0, 0],
...                        [0, 1, 1],
...                        [1, 1, 1]])
>>> multilabel_precision_at_fixed_recall(preds, target, num_labels=3, min_recall=0.5, thresholds=None)
(tensor([1.0000, 0.6667, 1.0000]), tensor([0.7500, 0.5500, 0.3500]))
>>> multilabel_precision_at_fixed_recall(preds, target, num_labels=3, min_recall=0.5, thresholds=5)
(tensor([1.0000, 0.6667, 1.0000]), tensor([0.7500, 0.5000, 0.2500]))
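Inputs with extra dimensions, such as preds of shape (N, C, L), are described as being flattened into the batch dimension. A pure-Python sketch of that reshaping turns every (sample, position) pair into its own row of C label scores; in PyTorch terms this corresponds to preds.movedim(1, -1).reshape(-1, C). The exact row order used by the library is an assumption here, but it does not change the metric as long as preds and target are flattened the same way, because the curve only depends on the collection of (score, label) pairs. The function name is hypothetical.

```python
def flatten_extra_dims(x):
    """Move a trailing extra dimension into the batch dimension.

    x is a nested list of shape (N, C, L); the result has shape (N * L, C), so every
    (sample, position) pair becomes its own row with one value per label.
    """
    rows = []
    for sample in x:                  # sample has shape (C, L)
        num_positions = len(sample[0])
        for pos in range(num_positions):
            rows.append([label_values[pos] for label_values in sample])
    return rows


# N=2 samples, C=3 labels, L=2 extra positions
preds = [[[0.1, 0.2], [0.3, 0.4], [0.5, 0.6]],
         [[0.7, 0.8], [0.9, 1.0], [0.0, 0.5]]]
flat = flatten_extra_dims(preds)
print(len(flat), len(flat[0]))  # 4 3
print(flat[0])                  # [0.1, 0.3, 0.5]
```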
Precision Recall Curve¶
Module Interface¶
- class torchmetrics.PrecisionRecallCurve(**kwargs)[source]¶
Compute the precision-recall curve.
The curve consists of multiple pairs of precision and recall values evaluated at different thresholds, such that the tradeoff between the two values can be seen.
This class is a simple wrapper to get the task specific versions of this metric, which is done by setting the task argument to either 'binary', 'multiclass' or 'multilabel'. See the documentation of BinaryPrecisionRecallCurve, MulticlassPrecisionRecallCurve and MultilabelPrecisionRecallCurve for the specific details of each argument influence and examples.
- Legacy Example:
>>> import torch
>>> from torchmetrics import PrecisionRecallCurve
>>> pred = torch.tensor([0, 0.1, 0.8, 0.4])
>>> target = torch.tensor([0, 1, 1, 0])
>>> pr_curve = PrecisionRecallCurve(task="binary")
>>> precision, recall, thresholds = pr_curve(pred, target)
>>> precision
tensor([0.5000, 0.6667, 0.5000, 1.0000, 1.0000])
>>> recall
tensor([1.0000, 1.0000, 0.5000, 0.5000, 0.0000])
>>> thresholds
tensor([0.0000, 0.1000, 0.4000, 0.8000])
>>> pred = torch.tensor([[0.75, 0.05, 0.05, 0.05, 0.05],
...                      [0.05, 0.75, 0.05, 0.05, 0.05],
...                      [0.05, 0.05, 0.75, 0.05, 0.05],
...                      [0.05, 0.05, 0.05, 0.75, 0.05]])
>>> target = torch.tensor([0, 1, 3, 2])
>>> pr_curve = PrecisionRecallCurve(task="multiclass", num_classes=5)
>>> precision, recall, thresholds = pr_curve(pred, target)
>>> precision
[tensor([0.2500, 1.0000, 1.0000]), tensor([0.2500, 1.0000, 1.0000]), tensor([0.2500, 0.0000, 1.0000]), tensor([0.2500, 0.0000, 1.0000]), tensor([0., 1.])]
>>> recall
[tensor([1., 1., 0.]), tensor([1., 1., 0.]), tensor([1., 0., 0.]), tensor([1., 0., 0.]), tensor([nan, 0.])]
>>> thresholds
[tensor([0.0500, 0.7500]), tensor([0.0500, 0.7500]), tensor([0.0500, 0.7500]), tensor([0.0500, 0.7500]), tensor(0.0500)]
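The wrapper behaviour described for the task argument can be pictured with a small, entirely hypothetical dispatch sketch: it maps the task string to the name of the task-specific class and checks that the size argument that task needs is present. The function name, the dictionary and the error messages are illustrative assumptions, not the library's code.

```python
# Hypothetical mapping, for illustration only: task string -> task-specific class name
TASK_TO_CLASS = {
    "binary": "BinaryPrecisionRecallCurve",
    "multiclass": "MulticlassPrecisionRecallCurve",
    "multilabel": "MultilabelPrecisionRecallCurve",
}


def resolve_task(task, num_classes=None, num_labels=None):
    """Return the task-specific class name after checking the required arguments (sketch)."""
    if task not in TASK_TO_CLASS:
        raise ValueError(f"task must be one of {sorted(TASK_TO_CLASS)}, got {task!r}")
    if task == "multiclass" and num_classes is None:
        raise ValueError("num_classes is required for task='multiclass'")
    if task == "multilabel" and num_labels is None:
        raise ValueError("num_labels is required for task='multilabel'")
    return TASK_TO_CLASS[task]


print(resolve_task("binary"))                     # BinaryPrecisionRecallCurve
print(resolve_task("multiclass", num_classes=5))  # MulticlassPrecisionRecallCurve
```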
BinaryPrecisionRecallCurve¶
- class torchmetrics.classification.BinaryPrecisionRecallCurve(thresholds=None, ignore_index=None, validate_args=True, **kwargs)[source]¶
Compute the precision-recall curve for binary tasks.
The curve consists of multiple pairs of precision and recall values evaluated at different thresholds, such that the tradeoff between the two values can be seen.
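The non-binned curve can be sketched in pure Python: every distinct prediction serves as a threshold, a sample is a predicted positive when its score is at or above the threshold, and a final (precision=1, recall=0) point closes the curve, giving one more precision/recall entry than there are thresholds. This is an illustrative reimplementation with a hypothetical function name, checked against the example output of this class rather than taken from the library source.

```python
def binary_pr_curve(preds, target):
    """Non-binned PR curve sketch; assumes target contains at least one positive."""
    thresholds = sorted(set(preds))
    total_pos = sum(target)
    precision, recall = [], []
    for t in thresholds:
        tp = sum(1 for p, y in zip(preds, target) if p >= t and y == 1)
        fp = sum(1 for p, y in zip(preds, target) if p >= t and y == 0)
        precision.append(tp / (tp + fp))  # t is one of the preds, so tp + fp >= 1
        recall.append(tp / total_pos)
    # closing point (precision=1, recall=0): n_thresholds + 1 entries in total
    precision.append(1.0)
    recall.append(0.0)
    return precision, recall, thresholds


precision, recall, thresholds = binary_pr_curve([0.0, 0.5, 0.7, 0.8], [0, 1, 1, 0])
print([round(p, 4) for p in precision])  # [0.5, 0.6667, 0.5, 0.0, 1.0]
print(recall)                            # [1.0, 1.0, 0.5, 0.0, 0.0]
print(thresholds)                        # [0.0, 0.5, 0.7, 0.8]
```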
As input to forward and update the metric accepts the following input:
preds (Tensor): A float tensor of shape (N, ...). Preds should be a tensor containing probabilities or logits for each observation. If preds has values outside the [0, 1] range, the input is considered to be logits and sigmoid is automatically applied per element.
target (Tensor): An int tensor of shape (N, ...). Target should be a tensor containing ground truth labels, and therefore only contain {0,1} values (except if ignore_index is specified). The value 1 always encodes the positive class.
Note
Additional dimensions ... will be flattened into the batch dimension.
As output to forward and compute the metric returns the following output:
precision (Tensor): A 1d tensor of size (n_thresholds+1, ) with precision values
recall (Tensor): A 1d tensor of size (n_thresholds+1, ) with recall values
thresholds (Tensor): A 1d tensor of size (n_thresholds, ) with increasing threshold values
Note
The implementation supports both a non-binned but accurate version and a binned version that is less accurate but more memory efficient. Setting the thresholds argument to None activates the non-binned version, which uses memory of size \(\mathcal{O}(n_{samples})\), whereas setting the thresholds argument to an integer, a list or a 1d tensor uses a binned version with memory of size \(\mathcal{O}(n_{thresholds})\) (constant memory).
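The binned variant can be sketched in pure Python under two assumptions: an integer thresholds=n means the n values i / (n - 1), and a threshold with no predicted positives reports precision 0 instead of dividing by zero (which is what the thresholds=5 example output of this class suggests). Because only per-threshold counts enter the result, a streaming version would never need to keep the samples themselves. The function name is hypothetical.

```python
def binned_pr_curve(preds, target, num_thresholds):
    """Binned PR curve sketch with fixed thresholds linearly spaced from 0 to 1."""
    thresholds = [i / (num_thresholds - 1) for i in range(num_thresholds)]
    total_pos = sum(target)  # assumes at least one positive
    precision, recall = [], []
    for t in thresholds:
        tp = sum(1 for p, y in zip(preds, target) if p >= t and y == 1)
        fp = sum(1 for p, y in zip(preds, target) if p >= t and y == 0)
        # no predicted positives at this threshold: report precision 0
        precision.append(tp / (tp + fp) if tp + fp else 0.0)
        recall.append(tp / total_pos)
    precision.append(1.0)  # closing point, as in the non-binned curve
    recall.append(0.0)
    return precision, recall, thresholds


precision, recall, thresholds = binned_pr_curve([0.0, 0.5, 0.7, 0.8], [0, 1, 1, 0], num_thresholds=5)
print([round(p, 4) for p in precision])  # [0.5, 0.6667, 0.6667, 0.0, 0.0, 1.0]
print(recall)                            # [1.0, 1.0, 1.0, 0.0, 0.0, 0.0]
print(thresholds)                        # [0.0, 0.25, 0.5, 0.75, 1.0]
```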
- Parameters:
thresholds (Union[int, List[float], Tensor, None]) – Can be one of:
If set to None, will use a non-binned approach where thresholds are dynamically calculated from all the data. Most accurate but also most memory consuming approach.
If set to an int (larger than 1), will use that number of thresholds linearly spaced from 0 to 1 as bins for the calculation.
If set to a list of floats, will use the indicated thresholds in the list as bins for the calculation.
If set to a 1d tensor of floats, will use the indicated thresholds in the tensor as bins for the calculation.
ignore_index (Optional[int]) – Specifies a target value that is ignored and does not contribute to the metric calculation
validate_args (bool) – bool indicating if input arguments and tensors should be validated for correctness. Set to False for faster computations.
kwargs (Any) – Additional keyword arguments, see Advanced metric settings for more info.
Example
>>> import torch
>>> from torchmetrics.classification import BinaryPrecisionRecallCurve
>>> preds = torch.tensor([0, 0.5, 0.7, 0.8])
>>> target = torch.tensor([0, 1, 1, 0])
>>> bprc = BinaryPrecisionRecallCurve(thresholds=None)
>>> bprc(preds, target)
(tensor([0.5000, 0.6667, 0.5000, 0.0000, 1.0000]), tensor([1.0000, 1.0000, 0.5000, 0.0000, 0.0000]), tensor([0.0000, 0.5000, 0.7000, 0.8000]))
>>> bprc = BinaryPrecisionRecallCurve(thresholds=5)
>>> bprc(preds, target)
(tensor([0.5000, 0.6667, 0.6667, 0.0000, 0.0000, 1.0000]), tensor([1., 1., 1., 0., 0., 0.]), tensor([0.0000, 0.2500, 0.5000, 0.7500, 1.0000]))
- plot(curve=None, score=None, ax=None)[source]¶
Plot a single curve from the metric.
- Parameters:
curve (Optional[Tuple[Tensor, Tensor, Tensor]]) – the output of either metric.compute or metric.forward. If no value is provided, metric.compute is called automatically and that result is plotted.
score (Union[Tensor, bool, None]) – Provide an area-under-the-curve score to be displayed on the plot. If True and no curve is provided, the score is computed automatically.
ax (Optional[Axes]) – A matplotlib axis object. If provided, the plot is added to that axis.
- Return type:
- Returns:
Figure and Axes object
- Raises:
ModuleNotFoundError – If matplotlib is not installed
>>> from torch import rand, randint
>>> from torchmetrics.classification import BinaryPrecisionRecallCurve
>>> preds = rand(20)
>>> target = randint(2, (20,))
>>> metric = BinaryPrecisionRecallCurve()
>>> metric.update(preds, target)
>>> fig_, ax_ = metric.plot(score=True)
MulticlassPrecisionRecallCurve¶
- class torchmetrics.classification.MulticlassPrecisionRecallCurve(num_classes, thresholds=None, ignore_index=None, validate_args=True, **kwargs)[source]¶
Compute the precision-recall curve for multiclass tasks.
The curve consists of multiple pairs of precision and recall values evaluated at different thresholds, such that the tradeoff between the two values can be seen.
For multiclass the metric is calculated by iteratively treating each class as the positive class and all other classes as the negative, which is referred to as the one-vs-rest approach. One-vs-one is currently not supported by this metric.
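The one-vs-rest approach amounts to building one binary curve per class, using column c of preds as the score and target == c as the positive label. In this pure-Python sketch (hypothetical function name), a class that never occurs in target gets an undefined (nan) recall, which matches the class 4 entries of the example output below.

```python
def one_vs_rest_pr_curves(preds, target, num_classes):
    """Treat each class in turn as positive (one-vs-rest) and build a non-binned PR curve."""
    curves = []
    for c in range(num_classes):
        scores = [row[c] for row in preds]
        is_pos = [y == c for y in target]
        total_pos = sum(is_pos)
        thresholds = sorted(set(scores))
        precision, recall = [], []
        for t in thresholds:
            tp = sum(1 for s, pos in zip(scores, is_pos) if s >= t and pos)
            fp = sum(1 for s, pos in zip(scores, is_pos) if s >= t and not pos)
            precision.append(tp / (tp + fp))
            # a class that never occurs in target has undefined recall
            recall.append(tp / total_pos if total_pos else float("nan"))
        precision.append(1.0)  # closing point (precision=1, recall=0)
        recall.append(0.0)
        curves.append((precision, recall, thresholds))
    return curves


preds = [[0.75, 0.05, 0.05, 0.05, 0.05],
         [0.05, 0.75, 0.05, 0.05, 0.05],
         [0.05, 0.05, 0.75, 0.05, 0.05],
         [0.05, 0.05, 0.05, 0.75, 0.05]]
target = [0, 1, 3, 2]
curves = one_vs_rest_pr_curves(preds, target, num_classes=5)
print(curves[0])  # ([0.25, 1.0, 1.0], [1.0, 1.0, 0.0], [0.05, 0.75])
print(curves[4])  # ([0.0, 1.0], [nan, 0.0], [0.05])
```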
As input to forward and update the metric accepts the following input:
preds (Tensor): A float tensor of shape (N, C, ...). Preds should be a tensor containing probabilities or logits for each observation. If preds has values outside the [0, 1] range, the input is considered to be logits and softmax is automatically applied per sample.
target (Tensor): An int tensor of shape (N, ...). Target should be a tensor containing ground truth labels, and therefore only contain values in the [0, n_classes-1] range (except if ignore_index is specified).
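Note
Additional dimensions ... will be flattened into the batch dimension.
As output to forward and compute the metric returns a tuple of either 3 tensors or 3 lists containing:
precision (Tensor or List): if thresholds=None, a list with one 1d tensor of size (n_thresholds+1, ) per class is returned with precision values (length may differ between classes). If thresholds is set to something else, a single 2d tensor of size (n_classes, n_thresholds+1) with precision values is returned.
recall (Tensor or List): if thresholds=None, a list with one 1d tensor of size (n_thresholds+1, ) per class is returned with recall values (length may differ between classes). If thresholds is set to something else, a single 2d tensor of size (n_classes, n_thresholds+1) with recall values is returned.
thresholds (Tensor or List): if thresholds=None, a list with one 1d tensor of size (n_thresholds, ) per class is returned with increasing threshold values (length may differ between classes). If thresholds is set to something else, a single 1d tensor of size (n_thresholds, ) is returned with shared threshold values for all classes.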
Note
The implementation supports both a non-binned but accurate version and a binned version that is less accurate but more memory efficient. Setting the thresholds argument to None activates the non-binned version, which uses memory of size \(\mathcal{O}(n_{samples})\), whereas setting the thresholds argument to an integer, a list or a 1d tensor uses a binned version with memory of size \(\mathcal{O}(n_{thresholds} \times n_{classes})\) (constant memory).
- Parameters:
num_classes (int) – Integer specifying the number of classes
thresholds (Union[int, List[float], Tensor, None]) – Can be one of:
If set to None, will use a non-binned approach where thresholds are dynamically calculated from all the data. Most accurate but also most memory consuming approach.
If set to an int (larger than 1), will use that number of thresholds linearly spaced from 0 to 1 as bins for the calculation.
If set to a list of floats, will use the indicated thresholds in the list as bins for the calculation.
If set to a 1d tensor of floats, will use the indicated thresholds in the tensor as bins for the calculation.
ignore_index (Optional[int]) – Specifies a target value that is ignored and does not contribute to the metric calculation
validate_args (bool) – bool indicating if input arguments and tensors should be validated for correctness. Set to False for faster computations.
kwargs (Any) – Additional keyword arguments, see Advanced metric settings for more info.
Example
>>> import torch
>>> from torchmetrics.classification import MulticlassPrecisionRecallCurve
>>> preds = torch.tensor([[0.75, 0.05, 0.05, 0.05, 0.05],
...                       [0.05, 0.75, 0.05, 0.05, 0.05],
...                       [0.05, 0.05, 0.75, 0.05, 0.05],
...                       [0.05, 0.05, 0.05, 0.75, 0.05]])
>>> target = torch.tensor([0, 1, 3, 2])
>>> mcprc = MulticlassPrecisionRecallCurve(num_classes=5, thresholds=None)
>>> precision, recall, thresholds = mcprc(preds, target)
>>> precision
[tensor([0.2500, 1.0000, 1.0000]), tensor([0.2500, 1.0000, 1.0000]), tensor([0.2500, 0.0000, 1.0000]), tensor([0.2500, 0.0000, 1.0000]), tensor([0., 1.])]
>>> recall
[tensor([1., 1., 0.]), tensor([1., 1., 0.]), tensor([1., 0., 0.]), tensor([1., 0., 0.]), tensor([nan, 0.])]
>>> thresholds
[tensor([0.0500, 0.7500]), tensor([0.0500, 0.7500]), tensor([0.0500, 0.7500]), tensor([0.0500, 0.7500]), tensor(0.0500)]
>>> mcprc = MulticlassPrecisionRecallCurve(num_classes=5, thresholds=5)
>>> mcprc(preds, target)
(tensor([[0.2500, 1.0000, 1.0000, 1.0000, 0.0000, 1.0000],
         [0.2500, 1.0000, 1.0000, 1.0000, 0.0000, 1.0000],
         [0.2500, 0.0000, 0.0000, 0.0000, 0.0000, 1.0000],
         [0.2500, 0.0000, 0.0000, 0.0000, 0.0000, 1.0000],
         [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 1.0000]]),
 tensor([[1., 1., 1., 1., 0., 0.],
         [1., 1., 1., 1., 0., 0.],
         [1., 0., 0., 0., 0., 0.],
         [1., 0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0., 0.]]),
 tensor([0.0000, 0.2500, 0.5000, 0.7500, 1.0000]))
- plot(curve=None, score=None, ax=None)[source]¶
Plot a single or multiple values from the metric.
- Parameters:
curve (Union[Tuple[Tensor, Tensor, Tensor], Tuple[List[Tensor], List[Tensor], List[Tensor]], None]) – the output of either metric.compute or metric.forward. If no value is provided, metric.compute is called automatically and that result is plotted.
score (Union[Tensor, bool, None]) – Provide an area-under-the-curve score to be displayed on the plot. If True and no curve is provided, the score is computed automatically.
ax (Optional[Axes]) – A matplotlib axis object. If provided, the plot is added to that axis.
- Return type:
- Returns:
Figure and Axes object
- Raises:
ModuleNotFoundError – If matplotlib is not installed
>>> from torch import randn, randint
>>> from torchmetrics.classification import MulticlassPrecisionRecallCurve
>>> preds = randn(20, 3).softmax(dim=-1)
>>> target = randint(3, (20,))
>>> metric = MulticlassPrecisionRecallCurve(num_classes=3)
>>> metric.update(preds, target)
>>> fig_, ax_ = metric.plot(score=True)
MultilabelPrecisionRecallCurve¶
- class torchmetrics.classification.MultilabelPrecisionRecallCurve(num_labels, thresholds=None, ignore_index=None, validate_args=True, **kwargs)[source]¶
Compute the precision-recall curve for multilabel tasks.
The curve consists of multiple pairs of precision and recall values evaluated at different thresholds, such that the tradeoff between the two values can be seen.
As input to forward and update the metric accepts the following input:
preds (Tensor): A float tensor of shape (N, C, ...). Preds should be a tensor containing probabilities or logits for each observation. If preds has values outside the [0, 1] range, the input is considered to be logits and sigmoid is automatically applied per element.
target (Tensor): An int tensor of shape (N, C, ...). Target should be a tensor containing ground truth labels, and therefore only contain {0,1} values (except if ignore_index is specified).
Note
Additional dimensions ... will be flattened into the batch dimension.
As output to forward and compute the metric returns a tuple of either 3 tensors or 3 lists containing:
precision (Tensor or List): if thresholds=None, a list with one 1d tensor of size (n_thresholds+1, ) per label is returned with precision values (length may differ between labels). If thresholds is set to something else, a single 2d tensor of size (n_labels, n_thresholds+1) with precision values is returned.
recall (Tensor or List): if thresholds=None, a list with one 1d tensor of size (n_thresholds+1, ) per label is returned with recall values (length may differ between labels). If thresholds is set to something else, a single 2d tensor of size (n_labels, n_thresholds+1) with recall values is returned.
thresholds (Tensor or List): if thresholds=None, a list with one 1d tensor of size (n_thresholds, ) per label is returned with increasing threshold values (length may differ between labels). If thresholds is set to something else, a single 1d tensor of size (n_thresholds, ) is returned with shared threshold values for all labels.
Note
The implementation supports both a non-binned but accurate version and a binned version that is less accurate but more memory efficient. Setting the thresholds argument to None activates the non-binned version, which uses memory of size \(\mathcal{O}(n_{samples})\), whereas setting the thresholds argument to an integer, a list or a 1d tensor uses a binned version with memory of size \(\mathcal{O}(n_{thresholds} \times n_{labels})\) (constant memory).
- Parameters:
preds – Tensor with predictions
target – Tensor with true labels
num_labels (int) – Integer specifying the number of labels
thresholds (Union[int, List[float], Tensor, None]) – Can be one of:
If set to None, will use a non-binned approach where thresholds are dynamically calculated from all the data. Most accurate but also most memory consuming approach.
If set to an int (larger than 1), will use that number of thresholds linearly spaced from 0 to 1 as bins for the calculation.
If set to a list of floats, will use the indicated thresholds in the list as bins for the calculation.
If set to a 1d tensor of floats, will use the indicated thresholds in the tensor as bins for the calculation.
ignore_index (Optional[int]) – Specifies a target value that is ignored and does not contribute to the metric calculation
validate_args (bool) – bool indicating if input arguments and tensors should be validated for correctness. Set to False for faster computations.
kwargs (Any) – Additional keyword arguments, see Advanced metric settings for more info.
Example
>>> import torch
>>> from torchmetrics.classification import MultilabelPrecisionRecallCurve
>>> preds = torch.tensor([[0.75, 0.05, 0.35],
...                       [0.45, 0.75, 0.05],
...                       [0.05, 0.55, 0.75],
...                       [0.05, 0.65, 0.05]])
>>> target = torch.tensor([[1, 0, 1],
...                        [0, 0, 0],
...                        [0, 1, 1],
...                        [1, 1, 1]])
>>> mlprc = MultilabelPrecisionRecallCurve(num_labels=3, thresholds=None)
>>> precision, recall, thresholds = mlprc(preds, target)
>>> precision
[tensor([0.5000, 0.5000, 1.0000, 1.0000]), tensor([0.5000, 0.6667, 0.5000, 0.0000, 1.0000]), tensor([0.7500, 1.0000, 1.0000, 1.0000])]
>>> recall
[tensor([1.0000, 0.5000, 0.5000, 0.0000]), tensor([1.0000, 1.0000, 0.5000, 0.0000, 0.0000]), tensor([1.0000, 0.6667, 0.3333, 0.0000])]
>>> thresholds
[tensor([0.0500, 0.4500, 0.7500]), tensor([0.0500, 0.5500, 0.6500, 0.7500]), tensor([0.0500, 0.3500, 0.7500])]
>>> mlprc = MultilabelPrecisionRecallCurve(num_labels=3, thresholds=5)
>>> mlprc(preds, target)
(tensor([[0.5000, 0.5000, 1.0000, 1.0000, 0.0000, 1.0000],
         [0.5000, 0.6667, 0.6667, 0.0000, 0.0000, 1.0000],
         [0.7500, 1.0000, 1.0000, 1.0000, 0.0000, 1.0000]]),
 tensor([[1.0000, 0.5000, 0.5000, 0.5000, 0.0000, 0.0000],
         [1.0000, 1.0000, 1.0000, 0.0000, 0.0000, 0.0000],
         [1.0000, 0.6667, 0.3333, 0.3333, 0.0000, 0.0000]]),
 tensor([0.0000, 0.2500, 0.5000, 0.7500, 1.0000]))
- plot(curve=None, score=None, ax=None)[source]¶
Plot a single or multiple values from the metric.
- Parameters:
curve (Union[Tuple[Tensor, Tensor, Tensor], Tuple[List[Tensor], List[Tensor], List[Tensor]], None]) – the output of either metric.compute or metric.forward. If no value is provided, metric.compute is called automatically and that result is plotted.
score (Union[Tensor, bool, None]) – Provide an area-under-the-curve score to be displayed on the plot. If True and no curve is provided, the score is computed automatically.
ax (Optional[Axes]) – A matplotlib axis object. If provided, the plot is added to that axis.
- Return type:
- Returns:
Figure and Axes object
- Raises:
ModuleNotFoundError – If matplotlib is not installed
>>> from torch import rand, randint
>>> from torchmetrics.classification import MultilabelPrecisionRecallCurve
>>> preds = rand(20, 3)
>>> target = randint(2, (20, 3))
>>> metric = MultilabelPrecisionRecallCurve(num_labels=3)
>>> metric.update(preds, target)
>>> fig_, ax_ = metric.plot(score=True)
Functional Interface
- torchmetrics.functional.precision_recall_curve(preds, target, task, thresholds=None, num_classes=None, num_labels=None, ignore_index=None, validate_args=True)
Compute the precision-recall curve.
The curve consists of multiple pairs of precision and recall values evaluated at different thresholds, so that the tradeoff between the two values can be seen.
This function is a simple wrapper that dispatches to the task-specific version of this metric, selected by setting the task argument to either 'binary', 'multiclass' or 'multilabel'. See the documentation of binary_precision_recall_curve(), multiclass_precision_recall_curve() and multilabel_precision_recall_curve() for the specific details of how each argument is interpreted, and for examples.
- Legacy Example:
>>> import torch
>>> from torchmetrics.functional import precision_recall_curve
>>> pred = torch.tensor([0, 0.1, 0.8, 0.4])
>>> target = torch.tensor([0, 1, 1, 0])
>>> precision, recall, thresholds = precision_recall_curve(pred, target, task='binary')
>>> precision
tensor([0.5000, 0.6667, 0.5000, 1.0000, 1.0000])
>>> recall
tensor([1.0000, 1.0000, 0.5000, 0.5000, 0.0000])
>>> thresholds
tensor([0.0000, 0.1000, 0.4000, 0.8000])
>>> pred = torch.tensor([[0.75, 0.05, 0.05, 0.05, 0.05],
...                      [0.05, 0.75, 0.05, 0.05, 0.05],
...                      [0.05, 0.05, 0.75, 0.05, 0.05],
...                      [0.05, 0.05, 0.05, 0.75, 0.05]])
>>> target = torch.tensor([0, 1, 3, 2])
>>> precision, recall, thresholds = precision_recall_curve(pred, target, task='multiclass', num_classes=5)
>>> precision
[tensor([0.2500, 1.0000, 1.0000]), tensor([0.2500, 1.0000, 1.0000]), tensor([0.2500, 0.0000, 1.0000]), tensor([0.2500, 0.0000, 1.0000]), tensor([0., 1.])]
>>> recall
[tensor([1., 1., 0.]), tensor([1., 1., 0.]), tensor([1., 0., 0.]), tensor([1., 0., 0.]), tensor([nan, 0.])]
>>> thresholds
[tensor([0.0500, 0.7500]), tensor([0.0500, 0.7500]), tensor([0.0500, 0.7500]), tensor([0.0500, 0.7500]), tensor([0.0500])]
binary_precision_recall_curve
- torchmetrics.functional.classification.binary_precision_recall_curve(preds, target, thresholds=None, ignore_index=None, validate_args=True)
Compute the precision-recall curve for binary tasks.
The curve consists of multiple pairs of precision and recall values evaluated at different thresholds, so that the tradeoff between the two values can be seen.
Accepts the following input tensors:
preds (float tensor): (N, ...). Preds should be a tensor containing probabilities or logits for each observation. If preds has values outside the [0,1] range, the input is treated as logits and sigmoid is applied per element.
target (int tensor): (N, ...). Target should be a tensor containing ground truth labels, and therefore only contain {0,1} values (except if ignore_index is specified). The value 1 always encodes the positive class.
Additional dimensions ... are flattened into the batch dimension.
The implementation supports both a non-binned version, which is accurate, and a binned version, which is less accurate but more memory efficient. Setting the thresholds argument to None activates the non-binned version, which uses memory of size \(\mathcal{O}(n_{samples})\), whereas setting the thresholds argument to an integer, a list or a 1d tensor activates the binned version, which uses memory of size \(\mathcal{O}(n_{thresholds})\) (constant memory).
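To make the non-binned computation concrete, here is a plain-Python sketch (no torch) of the idea: every distinct score becomes a threshold, and a sample counts as a predicted positive when its score is greater than or equal to that threshold. The function name `binary_pr_curve_exact` is hypothetical and this is not the TorchMetrics implementation; it only assumes the comparison rule above, which is consistent with the documented outputs.

```python
def binary_pr_curve_exact(preds, target):
    """Non-binned binary precision-recall curve (illustrative sketch).

    Every distinct score in `preds` becomes a threshold; a sample is a
    predicted positive when its score is >= that threshold. All scores are
    held at once, so memory grows with the number of samples.
    """
    thresholds = sorted(set(preds))
    n_pos = sum(target)
    precision, recall = [], []
    for t in thresholds:
        tp = sum(1 for p, y in zip(preds, target) if p >= t and y == 1)
        fp = sum(1 for p, y in zip(preds, target) if p >= t and y == 0)
        # t is one of the scores, so at least one sample is predicted positive.
        precision.append(tp / (tp + fp))
        # Without any positive targets, recall is undefined (NaN).
        recall.append(tp / n_pos if n_pos else float("nan"))
    # Close the curve with the end point precision=1, recall=0.
    return precision + [1.0], recall + [0.0], thresholds


precision, recall, thresholds = binary_pr_curve_exact([0, 0.5, 0.7, 0.8], [0, 1, 1, 0])
print([round(p, 4) for p in precision])  # [0.5, 0.6667, 0.5, 0.0, 1.0]
print(recall)                            # [1.0, 1.0, 0.5, 0.0, 0.0]
print(thresholds)                        # [0, 0.5, 0.7, 0.8]
```

These values match the `thresholds=None` call in this function's example.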
- Parameters:
preds (Tensor) – Tensor with predictions
target (Tensor) – Tensor with true labels
thresholds (Union[int, List[float], Tensor, None]) – Can be one of:
If set to None, will use a non-binned approach where thresholds are dynamically calculated from all the data. This is the most accurate but also the most memory-consuming approach.
If set to an int (larger than 1), will use that number of thresholds linearly spaced from 0 to 1 as bins for the calculation.
If set to a list of floats, will use the indicated thresholds in the list as bins for the calculation.
If set to a 1d tensor of floats, will use the indicated thresholds in the tensor as bins for the calculation.
ignore_index (Optional[int]) – Specifies a target value that is ignored and does not contribute to the metric calculation
validate_args (bool) – bool indicating if input arguments and tensors should be validated for correctness. Set to False for faster computations.
- Returns:
a tuple of 3 tensors containing:
precision: a 1d tensor of size (n_thresholds+1, ) with precision values
recall: a 1d tensor of size (n_thresholds+1, ) with recall values
thresholds: a 1d tensor of size (n_thresholds, ) with increasing threshold values
- Return type:
(tuple)
Example
>>> import torch
>>> from torchmetrics.functional.classification import binary_precision_recall_curve
>>> preds = torch.tensor([0, 0.5, 0.7, 0.8])
>>> target = torch.tensor([0, 1, 1, 0])
>>> binary_precision_recall_curve(preds, target, thresholds=None)
(tensor([0.5000, 0.6667, 0.5000, 0.0000, 1.0000]), tensor([1.0000, 1.0000, 0.5000, 0.0000, 0.0000]), tensor([0.0000, 0.5000, 0.7000, 0.8000]))
>>> binary_precision_recall_curve(preds, target, thresholds=5)
(tensor([0.5000, 0.6667, 0.6667, 0.0000, 0.0000, 1.0000]), tensor([1., 1., 1., 0., 0., 0.]), tensor([0.0000, 0.2500, 0.5000, 0.7500, 1.0000]))
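The binned variant (an integer `thresholds`) can be sketched in plain Python as a single pass that updates one tp/fp/fn counter triple per threshold. This is why its memory depends on the number of thresholds rather than on the number of samples. The function name is hypothetical; the sketch assumes evenly spaced thresholds from 0 to 1, a `>=` comparison, and 0.0 for empty denominators, all of which agree with the `thresholds=5` output above.

```python
def binary_pr_curve_binned(preds, target, num_thresholds):
    """Binned binary precision-recall curve (illustrative sketch).

    Only one tp/fp/fn counter triple is kept per threshold, so memory depends
    on num_thresholds, not on how many samples are streamed through.
    """
    # num_thresholds evenly spaced values from 0 to 1, as with thresholds=<int>.
    thresholds = [i / (num_thresholds - 1) for i in range(num_thresholds)]
    tp = [0] * num_thresholds
    fp = [0] * num_thresholds
    fn = [0] * num_thresholds
    for p, y in zip(preds, target):  # a single pass over the samples
        for i, t in enumerate(thresholds):
            if p >= t and y == 1:
                tp[i] += 1
            elif p >= t:
                fp[i] += 1
            elif y == 1:
                fn[i] += 1
    # Empty denominators (e.g. nothing predicted positive at t=1.0) give 0.0.
    precision = [a / (a + b) if a + b else 0.0 for a, b in zip(tp, fp)]
    recall = [a / (a + b) if a + b else 0.0 for a, b in zip(tp, fn)]
    return precision + [1.0], recall + [0.0], thresholds


precision, recall, thresholds = binary_pr_curve_binned([0, 0.5, 0.7, 0.8], [0, 1, 1, 0], 5)
print([round(p, 4) for p in precision])  # [0.5, 0.6667, 0.6667, 0.0, 0.0, 1.0]
print(recall)                            # [1.0, 1.0, 1.0, 0.0, 0.0, 0.0]
print(thresholds)                        # [0.0, 0.25, 0.5, 0.75, 1.0]
```

With coarse bins the 0.7 threshold is lost, so the binned curve is less precise than the exact one, as the text notes.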
multiclass_precision_recall_curve
- torchmetrics.functional.classification.multiclass_precision_recall_curve(preds, target, num_classes, thresholds=None, ignore_index=None, validate_args=True)
Compute the precision-recall curve for multiclass tasks.
The curve consists of multiple pairs of precision and recall values evaluated at different thresholds, so that the tradeoff between the two values can be seen.
Accepts the following input tensors:
preds (float tensor): (N, C, ...). Preds should be a tensor containing probabilities or logits for each observation. If preds has values outside the [0,1] range, the input is treated as logits and softmax is applied per sample.
target (int tensor): (N, ...). Target should be a tensor containing ground truth labels, and therefore only contain values in the [0, n_classes-1] range (except if ignore_index is specified).
Additional dimensions ... are flattened into the batch dimension.
The implementation supports both a non-binned version, which is accurate, and a binned version, which is less accurate but more memory efficient. Setting the thresholds argument to None activates the non-binned version, which uses memory of size \(\mathcal{O}(n_{samples})\), whereas setting the thresholds argument to an integer, a list or a 1d tensor activates the binned version, which uses memory of size \(\mathcal{O}(n_{thresholds} \times n_{classes})\) (constant memory).
- Parameters:
preds (Tensor) – Tensor with predictions
target (Tensor) – Tensor with true labels
num_classes (int) – Integer specifying the number of classes
thresholds (Union[int, List[float], Tensor, None]) – Can be one of:
If set to None, will use a non-binned approach where thresholds are dynamically calculated from all the data. This is the most accurate but also the most memory-consuming approach.
If set to an int (larger than 1), will use that number of thresholds linearly spaced from 0 to 1 as bins for the calculation.
If set to a list of floats, will use the indicated thresholds in the list as bins for the calculation.
If set to a 1d tensor of floats, will use the indicated thresholds in the tensor as bins for the calculation.
ignore_index (Optional[int]) – Specifies a target value that is ignored and does not contribute to the metric calculation
validate_args (bool) – bool indicating if input arguments and tensors should be validated for correctness. Set to False for faster computations.
- Returns:
a tuple of either 3 tensors or 3 lists containing:
precision: if thresholds=None, a list with one 1d tensor of size (n_thresholds+1, ) of precision values per class is returned (the length may differ between classes). If thresholds is set to anything else, a single 2d tensor of size (n_classes, n_thresholds+1) with precision values is returned.
recall: if thresholds=None, a list with one 1d tensor of size (n_thresholds+1, ) of recall values per class is returned (the length may differ between classes). If thresholds is set to anything else, a single 2d tensor of size (n_classes, n_thresholds+1) with recall values is returned.
thresholds: if thresholds=None, a list with one 1d tensor of size (n_thresholds, ) of increasing threshold values per class is returned (the length may differ between classes). If thresholds is set to anything else, a single 1d tensor of size (n_thresholds, ) is returned, with threshold values shared by all classes.
- Return type:
(tuple)
Example
>>> import torch
>>> from torchmetrics.functional.classification import multiclass_precision_recall_curve
>>> preds = torch.tensor([[0.75, 0.05, 0.05, 0.05, 0.05],
...                       [0.05, 0.75, 0.05, 0.05, 0.05],
...                       [0.05, 0.05, 0.75, 0.05, 0.05],
...                       [0.05, 0.05, 0.05, 0.75, 0.05]])
>>> target = torch.tensor([0, 1, 3, 2])
>>> precision, recall, thresholds = multiclass_precision_recall_curve(
...    preds, target, num_classes=5, thresholds=None
... )
>>> precision
[tensor([0.2500, 1.0000, 1.0000]), tensor([0.2500, 1.0000, 1.0000]), tensor([0.2500, 0.0000, 1.0000]), tensor([0.2500, 0.0000, 1.0000]), tensor([0., 1.])]
>>> recall
[tensor([1., 1., 0.]), tensor([1., 1., 0.]), tensor([1., 0., 0.]), tensor([1., 0., 0.]), tensor([nan, 0.])]
>>> thresholds
[tensor([0.0500, 0.7500]), tensor([0.0500, 0.7500]), tensor([0.0500, 0.7500]), tensor([0.0500, 0.7500]), tensor([0.0500])]
>>> multiclass_precision_recall_curve(
...     preds, target, num_classes=5, thresholds=5
... )
(tensor([[0.2500, 1.0000, 1.0000, 1.0000, 0.0000, 1.0000],
         [0.2500, 1.0000, 1.0000, 1.0000, 0.0000, 1.0000],
         [0.2500, 0.0000, 0.0000, 0.0000, 0.0000, 1.0000],
         [0.2500, 0.0000, 0.0000, 0.0000, 0.0000, 1.0000],
         [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 1.0000]]),
 tensor([[1., 1., 1., 1., 0., 0.],
         [1., 1., 1., 1., 0., 0.],
         [1., 0., 0., 0., 0., 0.],
         [1., 0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0., 0.]]),
 tensor([0.0000, 0.2500, 0.5000, 0.7500, 1.0000]))
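The per-class lists in the non-binned output can be read as one-vs-rest binary curves: class c uses column c of preds as its scores and `target == c` as its positives. The plain-Python sketch below illustrates that reading. The function names are hypothetical. It skips the softmax step by assuming preds are already probabilities in [0, 1], and it uses a `>=` comparison with the curve closed at (precision=1, recall=0).

```python
import math


def binary_pr_curve_exact(scores, is_positive):
    """Exact binary curve: each distinct score is a threshold (>= comparison)."""
    thresholds = sorted(set(scores))
    n_pos = sum(is_positive)
    precision, recall = [], []
    for t in thresholds:
        tp = sum(1 for s, y in zip(scores, is_positive) if s >= t and y)
        fp = sum(1 for s, y in zip(scores, is_positive) if s >= t and not y)
        precision.append(tp / (tp + fp))
        # A class that never appears in target has an undefined (NaN) recall.
        recall.append(tp / n_pos if n_pos else math.nan)
    return precision + [1.0], recall + [0.0], thresholds


def multiclass_pr_curve_exact(preds, target, num_classes):
    """One-vs-rest: class c scores are column c, positives are target == c."""
    curves = []
    for c in range(num_classes):
        scores = [row[c] for row in preds]
        is_positive = [int(y == c) for y in target]
        curves.append(binary_pr_curve_exact(scores, is_positive))
    # Regroup into (precisions, recalls, thresholds), one list entry per class.
    return tuple(list(x) for x in zip(*curves))


preds = [[0.75, 0.05, 0.05, 0.05, 0.05],
         [0.05, 0.75, 0.05, 0.05, 0.05],
         [0.05, 0.05, 0.75, 0.05, 0.05],
         [0.05, 0.05, 0.05, 0.75, 0.05]]
target = [0, 1, 3, 2]
precision, recall, thresholds = multiclass_pr_curve_exact(preds, target, num_classes=5)
print(precision[0], recall[0], thresholds[0])  # [0.25, 1.0, 1.0] [1.0, 1.0, 0.0] [0.05, 0.75]
print(precision[4], recall[4], thresholds[4])  # [0.0, 1.0] [nan, 0.0] [0.05]
```

The `nan` for class 4 arises because no sample has target 4, so recall is 0/0. This matches the `tensor([nan, 0.])` entry in the example above.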
multilabel_precision_recall_curve
- torchmetrics.functional.classification.multilabel_precision_recall_curve(preds, target, num_labels, thresholds=None, ignore_index=None, validate_args=True)
Compute the precision-recall curve for multilabel tasks.
The curve consists of multiple pairs of precision and recall values evaluated at different thresholds, so that the tradeoff between the two values can be seen.
Accepts the following input tensors:
preds (float tensor): (N, C, ...). Preds should be a tensor containing probabilities or logits for each observation. If preds has values outside the [0,1] range, the input is treated as logits and sigmoid is applied per element.
target (int tensor): (N, C, ...). Target should be a tensor containing ground truth labels, and therefore only contain {0,1} values (except if ignore_index is specified).
Additional dimensions ... are flattened into the batch dimension.
The implementation supports both a non-binned version, which is accurate, and a binned version, which is less accurate but more memory efficient. Setting the thresholds argument to None activates the non-binned version, which uses memory of size \(\mathcal{O}(n_{samples})\), whereas setting the thresholds argument to an integer, a list or a 1d tensor activates the binned version, which uses memory of size \(\mathcal{O}(n_{thresholds} \times n_{labels})\) (constant memory).
- Parameters:
preds (Tensor) – Tensor with predictions
target (Tensor) – Tensor with true labels
num_labels (int) – Integer specifying the number of labels
thresholds (Union[int, List[float], Tensor, None]) – Can be one of:
If set to None, will use a non-binned approach where thresholds are dynamically calculated from all the data. This is the most accurate but also the most memory-consuming approach.
If set to an int (larger than 1), will use that number of thresholds linearly spaced from 0 to 1 as bins for the calculation.
If set to a list of floats, will use the indicated thresholds in the list as bins for the calculation.
If set to a 1d tensor of floats, will use the indicated thresholds in the tensor as bins for the calculation.
ignore_index (Optional[int]) – Specifies a target value that is ignored and does not contribute to the metric calculation
validate_args (bool) – bool indicating if input arguments and tensors should be validated for correctness. Set to False for faster computations.
- Returns:
a tuple of either 3 tensors or 3 lists containing:
precision: if thresholds=None, a list with one 1d tensor of size (n_thresholds+1, ) of precision values per label is returned (the length may differ between labels). If thresholds is set to anything else, a single 2d tensor of size (n_labels, n_thresholds+1) with precision values is returned.
recall: if thresholds=None, a list with one 1d tensor of size (n_thresholds+1, ) of recall values per label is returned (the length may differ between labels). If thresholds is set to anything else, a single 2d tensor of size (n_labels, n_thresholds+1) with recall values is returned.
thresholds: if thresholds=None, a list with one 1d tensor of size (n_thresholds, ) of increasing threshold values per label is returned (the length may differ between labels). If thresholds is set to anything else, a single 1d tensor of size (n_thresholds, ) is returned, with threshold values shared by all labels.
- Return type:
(tuple)
Example
>>> import torch
>>> from torchmetrics.functional.classification import multilabel_precision_recall_curve
>>> preds = torch.tensor([[0.75, 0.05, 0.35],
...                       [0.45, 0.75, 0.05],
...                       [0.05, 0.55, 0.75],
...                       [0.05, 0.65, 0.05]])
>>> target = torch.tensor([[1, 0, 1],
...                        [0, 0, 0],
...                        [0, 1, 1],
...                        [1, 1, 1]])
>>> precision, recall, thresholds = multilabel_precision_recall_curve(
...    preds, target, num_labels=3, thresholds=None
... )
>>> precision
[tensor([0.5000, 0.5000, 1.0000, 1.0000]), tensor([0.5000, 0.6667, 0.5000, 0.0000, 1.0000]), tensor([0.7500, 1.0000, 1.0000, 1.0000])]
>>> recall
[tensor([1.0000, 0.5000, 0.5000, 0.0000]), tensor([1.0000, 1.0000, 0.5000, 0.0000, 0.0000]), tensor([1.0000, 0.6667, 0.3333, 0.0000])]
>>> thresholds
[tensor([0.0500, 0.4500, 0.7500]), tensor([0.0500, 0.5500, 0.6500, 0.7500]), tensor([0.0500, 0.3500, 0.7500])]
>>> multilabel_precision_recall_curve(
...     preds, target, num_labels=3, thresholds=5
... )
(tensor([[0.5000, 0.5000, 1.0000, 1.0000, 0.0000, 1.0000],
         [0.5000, 0.6667, 0.6667, 0.0000, 0.0000, 1.0000],
         [0.7500, 1.0000, 1.0000, 1.0000, 0.0000, 1.0000]]),
 tensor([[1.0000, 0.5000, 0.5000, 0.5000, 0.0000, 0.0000],
         [1.0000, 1.0000, 1.0000, 0.0000, 0.0000, 0.0000],
         [1.0000, 0.6667, 0.3333, 0.3333, 0.0000, 0.0000]]),
 tensor([0.0000, 0.2500, 0.5000, 0.7500, 1.0000]))
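The \(\mathcal{O}(n_{thresholds} \times n_{labels})\) memory of the binned multilabel version corresponds to keeping one tp/fp/fn counter per (label, threshold) pair, with each label column treated as an independent binary problem. The plain-Python sketch below illustrates this. The function name is hypothetical. It assumes probabilities in [0, 1], evenly spaced thresholds, a `>=` comparison, and 0.0 for empty denominators, and it returns nested lists where the library returns 2d tensors.

```python
def multilabel_pr_curve_binned(preds, target, num_thresholds):
    """Binned multilabel precision-recall curves (illustrative sketch).

    Keeps tp/fp/fn counters for every (label, threshold) pair, i.e. memory of
    size n_labels x n_thresholds regardless of the number of samples.
    """
    num_labels = len(preds[0])
    thresholds = [i / (num_thresholds - 1) for i in range(num_thresholds)]
    tp = [[0] * num_thresholds for _ in range(num_labels)]
    fp = [[0] * num_thresholds for _ in range(num_labels)]
    fn = [[0] * num_thresholds for _ in range(num_labels)]
    for row_preds, row_target in zip(preds, target):
        # Each label (column) is an independent binary problem.
        for label, (p, y) in enumerate(zip(row_preds, row_target)):
            for i, t in enumerate(thresholds):
                if p >= t and y == 1:
                    tp[label][i] += 1
                elif p >= t:
                    fp[label][i] += 1
                elif y == 1:
                    fn[label][i] += 1
    precision, recall = [], []
    for tps, fps, fns in zip(tp, fp, fn):
        precision.append([a / (a + b) if a + b else 0.0 for a, b in zip(tps, fps)] + [1.0])
        recall.append([a / (a + b) if a + b else 0.0 for a, b in zip(tps, fns)] + [0.0])
    return precision, recall, thresholds


preds = [[0.75, 0.05, 0.35],
         [0.45, 0.75, 0.05],
         [0.05, 0.55, 0.75],
         [0.05, 0.65, 0.05]]
target = [[1, 0, 1],
          [0, 0, 0],
          [0, 1, 1],
          [1, 1, 1]]
precision, recall, thresholds = multilabel_pr_curve_binned(preds, target, 5)
print(precision[0])                      # [0.5, 0.5, 1.0, 1.0, 0.0, 1.0]
print([round(r, 4) for r in recall[2]])  # [1.0, 0.6667, 0.3333, 0.3333, 0.0, 0.0]
```

These rows agree with the corresponding rows of the `thresholds=5` output in the example above.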
Recall
Module Interface
- class torchmetrics.Recall(**kwargs)
Compute Recall.
\[\text{Recall} = \frac{\text{TP}}{\text{TP} + \text{FN}}\]
Where \(\text{TP}\) and \(\text{FN}\) represent the number of true positives and false negatives respectively. The metric is only properly defined when \(\text{TP} + \text{FN} \neq 0\). If this case is encountered for any class/label, the metric for that class/label will be set to 0 and the overall metric may therefore be affected in turn.
This class is a simple wrapper that dispatches to the task-specific version of this metric, selected by setting the task argument to either 'binary', 'multiclass' or 'multilabel'. See the documentation of BinaryRecall, MulticlassRecall and MultilabelRecall for the specific details of how each argument is interpreted, and for examples.
- Legacy Example:
>>> from torch import tensor
>>> from torchmetrics import Recall
>>> preds = tensor([2, 0, 2, 1])
>>> target = tensor([1, 1, 2, 0])
>>> recall = Recall(task="multiclass", average='macro', num_classes=3)
>>> recall(preds, target)
tensor(0.3333)
>>> recall = Recall(task="multiclass", average='micro', num_classes=3)
>>> recall(preds, target)
tensor(0.2500)
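The gap between the macro (0.3333) and micro (0.2500) results above comes from the order of averaging. Class 2 is fully recalled while classes 0 and 1 are never recalled. Macro averages the three per-class recalls, giving 1/3, while micro pools the counts, giving 1 true positive out of 4 targets. The plain-Python sketch below makes this explicit. The function name `recall_sketch` is hypothetical; it handles integer predictions only and averages over all classes, scoring a class with TP + FN = 0 as 0 as described above.

```python
def recall_sketch(preds, target, num_classes, average="macro"):
    """Multiclass recall from integer predictions (illustrative sketch)."""
    tp = [0] * num_classes
    fn = [0] * num_classes
    for p, y in zip(preds, target):
        if p == y:
            tp[y] += 1
        else:
            fn[y] += 1  # a sample of class y that was not recovered
    if average == "micro":
        # Pool the counts of all classes before dividing.
        return sum(tp) / (sum(tp) + sum(fn))
    # Per-class recall; a class with TP + FN == 0 is scored as 0.
    per_class = [t / (t + f) if t + f else 0.0 for t, f in zip(tp, fn)]
    if average == "macro":
        # Unweighted mean over all classes.
        return sum(per_class) / num_classes
    return per_class  # average=None: no reduction


preds = [2, 0, 2, 1]
target = [1, 1, 2, 0]
print(round(recall_sketch(preds, target, 3, "macro"), 4))  # 0.3333
print(recall_sketch(preds, target, 3, "micro"))            # 0.25
print(recall_sketch(preds, target, 3, None))               # [0.0, 0.0, 1.0]
```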
BinaryRecall
- class torchmetrics.classification.BinaryRecall(threshold=0.5, multidim_average='global', ignore_index=None, validate_args=True, **kwargs)
Compute Recall for binary tasks.
\[\text{Recall} = \frac{\text{TP}}{\text{TP} + \text{FN}}\]
Where \(\text{TP}\) and \(\text{FN}\) represent the number of true positives and false negatives respectively. The metric is only properly defined when \(\text{TP} + \text{FN} \neq 0\). If this case is encountered, a score of 0 is returned.
As input to forward and update the metric accepts the following input:
preds (Tensor): An int tensor or float tensor of shape (N, ...). If preds is a floating point tensor with values outside the [0,1] range, the input is treated as logits and sigmoid is applied per element. Additionally, it is converted to an int tensor by thresholding with the value in threshold.
target (Tensor): An int tensor of shape (N, ...)
As output to forward and compute the metric returns the following output:
br (Tensor): If multidim_average is set to global, the metric returns a scalar value. If multidim_average is set to samplewise, the metric returns an (N,) vector consisting of a scalar value per sample.
- Parameters:
threshold (float) – Threshold for transforming probability to binary {0,1} predictions
multidim_average (Literal['global', 'samplewise']) – Defines how additional dimensions ... should be handled. Should be one of the following:
global: Additional dimensions are flattened along the batch dimension
samplewise: Statistic will be calculated independently for each sample on the N axis. The statistics in this case are calculated over the additional dimensions.
ignore_index (Optional[int]) – Specifies a target value that is ignored and does not contribute to the metric calculation
validate_args (bool) – bool indicating if input arguments and tensors should be validated for correctness. Set to False for faster computations.
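The difference between the two multidim_average modes can be sketched in plain Python on nested lists shaped (N, ...). With global, all elements are pooled into one recall. With samplewise, each entry along the first axis gets its own recall. All function names are hypothetical. The sketch assumes preds are probabilities in [0, 1], so the sigmoid step is skipped, and it assumes a score strictly above threshold counts as positive.

```python
def flatten(x):
    """Recursively flatten nested lists into one flat list of numbers."""
    if isinstance(x, list):
        return [v for item in x for v in flatten(item)]
    return [x]


def recall_from_labels(pred_labels, target):
    tp = sum(1 for p, y in zip(pred_labels, target) if p == 1 and y == 1)
    fn = sum(1 for p, y in zip(pred_labels, target) if p == 0 and y == 1)
    return tp / (tp + fn) if tp + fn else 0.0  # 0 when there are no positives


def binary_recall_multidim(preds, target, multidim_average="global", threshold=0.5):
    """Binary recall over (N, ...) nested lists (illustrative sketch)."""
    def binarize(x):
        return [1 if v > threshold else 0 for v in flatten(x)]

    if multidim_average == "global":
        # Every extra dimension is folded into one pool of elements.
        return recall_from_labels(binarize(preds), flatten(target))
    # "samplewise": one recall per entry along the first (N) axis.
    return [recall_from_labels(binarize(p), flatten(t)) for p, t in zip(preds, target)]


target = [[[0, 1], [1, 0], [0, 1]], [[1, 1], [0, 0], [1, 0]]]
preds = [[[0.59, 0.91], [0.91, 0.99], [0.63, 0.04]],
         [[0.38, 0.04], [0.86, 0.78], [0.45, 0.37]]]
print([round(r, 4) for r in binary_recall_multidim(preds, target, "samplewise")])  # [0.6667, 0.0]
print(round(binary_recall_multidim(preds, target, "global"), 4))                  # 0.3333
```

The samplewise values match the multidim example for this class. The global value is simply what the sketch computes when pooling: 2 true positives out of 6 positive targets.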
- Example (preds is int tensor):
>>> from torch import tensor
>>> from torchmetrics.classification import BinaryRecall
>>> target = tensor([0, 1, 0, 1, 0, 1])
>>> preds = tensor([0, 0, 1, 1, 0, 1])
>>> metric = BinaryRecall()
>>> metric(preds, target)
tensor(0.6667)
- Example (preds is float tensor):
>>> from torch import tensor
>>> from torchmetrics.classification import BinaryRecall
>>> target = tensor([0, 1, 0, 1, 0, 1])
>>> preds = tensor([0.11, 0.22, 0.84, 0.73, 0.33, 0.92])
>>> metric = BinaryRecall()
>>> metric(preds, target)
tensor(0.6667)
- Example (multidim tensors):
>>> from torch import tensor
>>> from torchmetrics.classification import BinaryRecall
>>> target = tensor([[[0, 1], [1, 0], [0, 1]], [[1, 1], [0, 0], [1, 0]]])
>>> preds = tensor([[[0.59, 0.91], [0.91, 0.99], [0.63, 0.04]],
...                 [[0.38, 0.04], [0.86, 0.78], [0.45, 0.37]]])
>>> metric = BinaryRecall(multidim_average='samplewise')
>>> metric(preds, target)
tensor([0.6667, 0.0000])
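The preprocessing described for preds (logit detection, sigmoid, then thresholding) explains why probabilities and logits that fall on the same side of 0.5 after the sigmoid give the same recall. The plain-Python sketch below illustrates the pipeline. The function name is hypothetical, and so is the logit list, which was chosen only so that each sigmoid output lands on the same side of 0.5 as the float example above. It assumes a score strictly above threshold counts as positive.

```python
import math


def binary_recall_sketch(preds, target, threshold=0.5):
    """Binary recall with logit detection and thresholding (illustrative sketch)."""
    # Float scores outside [0, 1] are treated as logits and squashed by a sigmoid.
    if any(p < 0 or p > 1 for p in preds):
        preds = [1 / (1 + math.exp(-p)) for p in preds]
    # Assumption: a score strictly above the threshold counts as positive.
    labels = [1 if p > threshold else 0 for p in preds]
    tp = sum(1 for p, y in zip(labels, target) if p == 1 and y == 1)
    fn = sum(1 for p, y in zip(labels, target) if p == 0 and y == 1)
    return tp / (tp + fn) if tp + fn else 0.0  # 0 when there are no positives


target = [0, 1, 0, 1, 0, 1]
probs = [0.11, 0.22, 0.84, 0.73, 0.33, 0.92]
logits = [-2.1, -1.3, 1.7, 1.0, -0.7, 2.4]  # hypothetical logits
print(round(binary_recall_sketch(probs, target), 4))   # 0.6667
print(round(binary_recall_sketch(logits, target), 4))  # 0.6667
```

Integer predictions such as `[0, 0, 1, 1, 0, 1]` pass through the sketch unchanged apart from the threshold, which reproduces the int example as well.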
- plot(val=None, ax=None)
Plot a single or multiple values from the metric.
- Parameters:
val (Union[Tensor, Sequence[Tensor], None]) – Either a single result from calling metric.forward or metric.compute, or a list of these results. If no value is provided, metric.compute is called automatically and its result is plotted.
ax (Optional[Axes]) – A matplotlib axis object. If provided, the plot is added to that axis.
- Returns:
Figure object and Axes object
- Raises:
ModuleNotFoundError – If matplotlib is not installed
>>> from torch import rand, randint
>>> # Example plotting a single value
>>> from torchmetrics.classification import BinaryRecall
>>> metric = BinaryRecall()
>>> metric.update(rand(10), randint(2, (10,)))
>>> fig_, ax_ = metric.plot()
>>> from torch import rand, randint
>>> # Example plotting multiple values
>>> from torchmetrics.classification import BinaryRecall
>>> metric = BinaryRecall()
>>> values = []
>>> for _ in range(10):
...     values.append(metric(rand(10), randint(2, (10,))))
>>> fig_, ax_ = metric.plot(values)
MulticlassRecall
- class torchmetrics.classification.MulticlassRecall(num_classes, top_k=1, average='macro', multidim_average='global', ignore_index=None, validate_args=True, **kwargs)
Compute Recall for multiclass tasks.
\[\text{Recall} = \frac{\text{TP}}{\text{TP} + \text{FN}}\]
Where \(\text{TP}\) and \(\text{FN}\) represent the number of true positives and false negatives respectively. The metric is only properly defined when \(\text{TP} + \text{FN} \neq 0\). If this case is encountered for any class, the metric for that class will be set to 0 and the overall metric may therefore be affected in turn.
As input to forward and update the metric accepts the following input:
preds (Tensor): An int tensor of shape (N, ...) or float tensor of shape (N, C, ...). If preds is a floating point tensor, torch.argmax is applied along the C dimension to automatically convert probabilities/logits into an int tensor.
target (Tensor): An int tensor of shape (N, ...)
As output to forward and compute the metric returns the following output:
mcr (Tensor): The returned shape depends on the average and multidim_average arguments:
If multidim_average is set to global:
If average='micro'/'macro'/'weighted', the output will be a scalar tensor
If average=None/'none', the shape will be (C,)
If multidim_average is set to samplewise:
If average='micro'/'macro'/'weighted', the shape will be (N,)
If average=None/'none', the shape will be (N, C)
- Parameters:
num_classes (int) – Integer specifying the number of classes
average (Optional[Literal['micro', 'macro', 'weighted', 'none']]) – Defines the reduction that is applied over classes. Should be one of the following:
micro: Sum statistics over all classes
macro: Calculate statistics for each class and average them
weighted: Calculate statistics for each class and compute a weighted average using their support
"none" or None: Calculate statistics for each class and apply no reduction
top_k (int) – Number of highest probability or logit score predictions considered to find the correct label. Only works when preds contain probabilities/logits.
multidim_average (Literal['global', 'samplewise']) – Defines how additional dimensions ... should be handled. Should be one of the following:
global: Additional dimensions are flattened along the batch dimension
samplewise: Statistic will be calculated independently for each sample on the N axis. The statistics in this case are calculated over the additional dimensions.
ignore_index (Optional[int]) – Specifies a target value that is ignored and does not contribute to the metric calculation
validate_args (bool) – bool indicating if input arguments and tensors should be validated for correctness. Set to False for faster computations.
- Example (preds is int tensor):
>>> from torch import tensor
>>> from torchmetrics.classification import MulticlassRecall
>>> target = tensor([2, 1, 0, 0])
>>> preds = tensor([2, 1, 0, 1])
>>> metric = MulticlassRecall(num_classes=3)
>>> metric(preds, target)
tensor(0.8333)
>>> mcr = MulticlassRecall(num_classes=3, average=None)
>>> mcr(preds, target)
tensor([0.5000, 1.0000, 1.0000])
- Example (preds is float tensor):
>>> from torch import tensor
>>> from torchmetrics.classification import MulticlassRecall
>>> target = tensor([2, 1, 0, 0])
>>> preds = tensor([[0.16, 0.26, 0.58],
...                 [0.22, 0.61, 0.17],
...                 [0.71, 0.09, 0.20],
...                 [0.05, 0.82, 0.13]])
>>> metric = MulticlassRecall(num_classes=3)
>>> metric(preds, target)
tensor(0.8333)
>>> mcr = MulticlassRecall(num_classes=3, average=None)
>>> mcr(preds, target)
tensor([0.5000, 1.0000, 1.0000])
- Example (multidim tensors):
>>> from torch import tensor
>>> from torchmetrics.classification import MulticlassRecall
>>> target = tensor([[[0, 1], [2, 1], [0, 2]], [[1, 1], [2, 0], [1, 2]]])
>>> preds = tensor([[[0, 2], [2, 0], [0, 1]], [[2, 2], [2, 1], [1, 0]]])
>>> metric = MulticlassRecall(num_classes=3, multidim_average='samplewise')
>>> metric(preds, target)
tensor([0.5000, 0.2778])
>>> mcr = MulticlassRecall(num_classes=3, multidim_average='samplewise', average=None)
>>> mcr(preds, target)
tensor([[1.0000, 0.0000, 0.5000],
        [0.0000, 0.3333, 0.5000]])
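The float-tensor example reduces each score row to its argmax before counting. The plain-Python sketch below mirrors that step and contrasts the four average settings. The function names are hypothetical, and it only covers top_k=1 with no ignore_index. In this sketch, weighted recall always equals micro recall, because weighting each class recall \(TP_c/S_c\) by its support \(S_c\) and dividing by \(\sum_c S_c\) gives \(\sum_c TP_c / \sum_c S_c\).

```python
def argmax(row):
    """Index of the largest score (the first one on ties)."""
    return max(range(len(row)), key=lambda i: row[i])


def multiclass_recall_sketch(preds, target, num_classes, average="macro"):
    """Multiclass recall for top_k=1 and no ignore_index (illustrative sketch)."""
    # Float score rows are reduced to a class index, mirroring the argmax step.
    labels = [argmax(p) if isinstance(p, list) else p for p in preds]
    tp = [0] * num_classes
    support = [0] * num_classes  # TP + FN: how often each class is the target
    for p, y in zip(labels, target):
        support[y] += 1
        tp[y] += int(p == y)
    per_class = [t / s if s else 0.0 for t, s in zip(tp, support)]
    if average == "micro":
        return sum(tp) / sum(support)
    if average == "macro":
        return sum(per_class) / num_classes
    if average == "weighted":
        return sum(r * s for r, s in zip(per_class, support)) / sum(support)
    return per_class  # average=None


target = [2, 1, 0, 0]
preds = [[0.16, 0.26, 0.58],
         [0.22, 0.61, 0.17],
         [0.71, 0.09, 0.20],
         [0.05, 0.82, 0.13]]
print(multiclass_recall_sketch(preds, target, 3, None))              # [0.5, 1.0, 1.0]
print(round(multiclass_recall_sketch(preds, target, 3, "macro"), 4))  # 0.8333
print(multiclass_recall_sketch(preds, target, 3, "weighted"))        # 0.75
print(multiclass_recall_sketch(preds, target, 3, "micro"))           # 0.75
```

The per-class and macro values match the documented example. The weighted and micro values are what this sketch computes for the same data.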
- plot(val=None, ax=None)
Plot a single or multiple values from the metric.
- Parameters:
val (Union[Tensor, Sequence[Tensor], None]) – Either a single result from calling metric.forward or metric.compute, or a list of these results. If no value is provided, metric.compute is called automatically and its result is plotted.
ax (Optional[Axes]) – A matplotlib axis object. If provided, the plot is added to that axis.
- Returns:
Figure object and Axes object
- Raises:
ModuleNotFoundError – If matplotlib is not installed
>>> from torch import randint
>>> # Example plotting a single value per class
>>> from torchmetrics.classification import MulticlassRecall
>>> metric = MulticlassRecall(num_classes=3, average=None)
>>> metric.update(randint(3, (20,)), randint(3, (20,)))
>>> fig_, ax_ = metric.plot()
>>> from torch import randint
>>> # Example plotting multiple values per class
>>> from torchmetrics.classification import MulticlassRecall
>>> metric = MulticlassRecall(num_classes=3, average=None)
>>> values = []
>>> for _ in range(20):
...     values.append(metric(randint(3, (20,)), randint(3, (20,))))
>>> fig_, ax_ = metric.plot(values)
MultilabelRecall
- class torchmetrics.classification.MultilabelRecall(num_labels, threshold=0.5, average='macro', multidim_average='global', ignore_index=None, validate_args=True, **kwargs)
Compute Recall for multilabel tasks.
\[\text{Recall} = \frac{\text{TP}}{\text{TP} + \text{FN}}\]
Where \(\text{TP}\) and \(\text{FN}\) represent the number of true positives and false negatives respectively. The metric is only properly defined when \(\text{TP} + \text{FN} \neq 0\). If this case is encountered for any label, the metric for that label will be set to 0 and the overall metric may therefore be affected in turn.
As input to forward and update the metric accepts the following input:
preds (Tensor): An int or float tensor of shape (N, C, ...). If preds is a floating point tensor with values outside the [0,1] range, the input is treated as logits and sigmoid is applied per element. Additionally, it is converted to an int tensor by thresholding with the value in threshold.
target (Tensor): An int tensor of shape (N, C, ...)
As output to forward and compute the metric returns the following output:
mlr (Tensor): The returned shape depends on the average and multidim_average arguments:
If multidim_average is set to global:
If average='micro'/'macro'/'weighted', the output will be a scalar tensor
If average=None/'none', the shape will be (C,)
If multidim_average is set to samplewise:
If average='micro'/'macro'/'weighted', the shape will be (N,)
If average=None/'none', the shape will be (N, C)
- Parameters:
num_labels (int) – Integer specifying the number of labels
threshold (float) – Threshold for transforming probability to binary {0,1} predictions
average (Optional[Literal['micro', 'macro', 'weighted', 'none']]) – Defines the reduction that is applied over labels. Should be one of the following:
micro: Sum statistics over all labels
macro: Calculate statistics for each label and average them
weighted: Calculate statistics for each label and compute a weighted average using their support
"none" or None: Calculate statistics for each label and apply no reduction
multidim_average (Literal['global', 'samplewise']) – Defines how additional dimensions ... should be handled. Should be one of the following:
global: Additional dimensions are flattened along the batch dimension
samplewise: Statistic will be calculated independently for each sample on the N axis. The statistics in this case are calculated over the additional dimensions.
ignore_index (Optional[int]) – Specifies a target value that is ignored and does not contribute to the metric calculation
validate_args (bool) – bool indicating if input arguments and tensors should be validated for correctness. Set to False for faster computations.
- Example (preds is int tensor):
>>> from torch import tensor
>>> from torchmetrics.classification import MultilabelRecall
>>> target = tensor([[0, 1, 0], [1, 0, 1]])
>>> preds = tensor([[0, 0, 1], [1, 0, 1]])
>>> metric = MultilabelRecall(num_labels=3)
>>> metric(preds, target)
tensor(0.6667)
>>> mlr = MultilabelRecall(num_labels=3, average=None)
>>> mlr(preds, target)
tensor([1., 0., 1.])
- Example (preds is float tensor):
>>> from torch import tensor
>>> from torchmetrics.classification import MultilabelRecall
>>> target = tensor([[0, 1, 0], [1, 0, 1]])
>>> preds = tensor([[0.11, 0.22, 0.84], [0.73, 0.33, 0.92]])
>>> metric = MultilabelRecall(num_labels=3)
>>> metric(preds, target)
tensor(0.6667)
>>> mlr = MultilabelRecall(num_labels=3, average=None)
>>> mlr(preds, target)
tensor([1., 0., 1.])
- Example (multidim tensors):
>>> from torch import tensor
>>> from torchmetrics.classification import MultilabelRecall
>>> target = tensor([[[0, 1], [1, 0], [0, 1]], [[1, 1], [0, 0], [1, 0]]])
>>> preds = tensor([[[0.59, 0.91], [0.91, 0.99], [0.63, 0.04]],
...                 [[0.38, 0.04], [0.86, 0.78], [0.45, 0.37]]])
>>> metric = MultilabelRecall(num_labels=3, multidim_average='samplewise')
>>> metric(preds, target)
tensor([0.6667, 0.0000])
>>> mlr = MultilabelRecall(num_labels=3, multidim_average='samplewise', average=None)
>>> mlr(preds, target)
tensor([[1., 1., 0.],
        [0., 0., 0.]])
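In the multilabel case, each column of an (N, C) input is scored as its own binary problem, and average decides how the per-label results are combined. The plain-Python sketch below reproduces the float example. The function name is hypothetical. It assumes probabilities in [0, 1], so no sigmoid is applied, and it assumes a score strictly above threshold is a positive prediction. Only false negatives and true positives are counted, since recall does not depend on false positives.

```python
def multilabel_recall_sketch(preds, target, threshold=0.5, average="macro"):
    """Multilabel recall over (N, C) nested lists (illustrative sketch)."""
    num_labels = len(preds[0])
    tp = [0] * num_labels
    fn = [0] * num_labels
    for row_preds, row_target in zip(preds, target):
        for c, (p, y) in enumerate(zip(row_preds, row_target)):
            if y == 1 and p > threshold:
                tp[c] += 1
            elif y == 1:
                fn[c] += 1  # a positive label that was missed
    if average == "micro":
        # Pool counts over all labels before dividing.
        return sum(tp) / (sum(tp) + sum(fn))
    # Per-label recall; a label with no positive targets is scored as 0.
    per_label = [t / (t + f) if t + f else 0.0 for t, f in zip(tp, fn)]
    if average == "macro":
        return sum(per_label) / num_labels
    return per_label  # average=None


target = [[0, 1, 0], [1, 0, 1]]
preds = [[0.11, 0.22, 0.84], [0.73, 0.33, 0.92]]
print(multilabel_recall_sketch(preds, target, average=None))               # [1.0, 0.0, 1.0]
print(round(multilabel_recall_sketch(preds, target), 4))                   # 0.6667
print(round(multilabel_recall_sketch(preds, target, average="micro"), 4))  # 0.6667
```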
- plot(val=None, ax=None)
Plot a single or multiple values from the metric.
- Parameters:
val (Union[Tensor, Sequence[Tensor], None]) – Either a single result from calling metric.forward or metric.compute, or a list of these results. If no value is provided, metric.compute is called automatically and its result is plotted.
ax (Optional[Axes]) – A matplotlib axis object. If provided, the plot is added to that axis.
- Returns:
Figure object and Axes object
- Raises:
ModuleNotFoundError – If matplotlib is not installed
>>> from torch import rand, randint
>>> # Example plotting a single value
>>> from torchmetrics.classification import MultilabelRecall
>>> metric = MultilabelRecall(num_labels=3)
>>> metric.update(randint(2, (20, 3)), randint(2, (20, 3)))
>>> fig_, ax_ = metric.plot()
>>> from torch import rand, randint
>>> # Example plotting multiple values
>>> from torchmetrics.classification import MultilabelRecall
>>> metric = MultilabelRecall(num_labels=3)
>>> values = []
>>> for _ in range(10):
...     values.append(metric(randint(2, (20, 3)), randint(2, (20, 3))))
>>> fig_, ax_ = metric.plot(values)
Functional Interface
- torchmetrics.functional.recall(preds, target, task, threshold=0.5, num_classes=None, num_labels=None, average='micro', multidim_average='global', top_k=1, ignore_index=None, validate_args=True)
Compute Recall.
\[\text{Recall} = \frac{\text{TP}}{\text{TP} + \text{FN}}\]
Where \(\text{TP}\) and \(\text{FN}\) represent the number of true positives and false negatives respectively.
This function is a simple wrapper that dispatches to the task-specific version of this metric, selected by setting the task argument to either 'binary', 'multiclass' or 'multilabel'. See the documentation of binary_recall(), multiclass_recall() and multilabel_recall() for the specific details of how each argument is interpreted, and for examples.
- Return type:
Tensor
- Legacy Example:
>>> from torch import tensor
>>> from torchmetrics.functional import recall
>>> preds = tensor([2, 0, 2, 1])
>>> target = tensor([1, 1, 2, 0])
>>> recall(preds, target, task="multiclass", average='macro', num_classes=3)
tensor(0.3333)
>>> recall(preds, target, task="multiclass", average='micro', num_classes=3)
tensor(0.2500)
binary_recall
- torchmetrics.functional.classification.binary_recall(preds, target, threshold=0.5, multidim_average='global', ignore_index=None, validate_args=True)
Compute Recall for binary tasks.
\[\text{Recall} = \frac{\text{TP}}{\text{TP} + \text{FN}}\]
Where \(\text{TP}\) and \(\text{FN}\) represent the number of true positives and false negatives respectively.
Accepts the following input tensors:
preds (int or float tensor): (N, ...). If preds is a floating point tensor with values outside the [0,1] range, the input is treated as logits and sigmoid is applied per element. Additionally, it is converted to an int tensor by thresholding with the value in threshold.
target (int tensor): (N, ...)
- Parameters:
preds (Tensor) – Tensor with predictions
target (Tensor) – Tensor with true labels
threshold (float) – Threshold for transforming probability to binary {0,1} predictions
multidim_average (Literal['global', 'samplewise']) – Defines how additional dimensions ... should be handled. Should be one of the following:
global: Additional dimensions are flattened along the batch dimension
samplewise: Statistic will be calculated independently for each sample on the N axis. The statistics in this case are calculated over the additional dimensions.
ignore_index (Optional[int]) – Specifies a target value that is ignored and does not contribute to the metric calculation
validate_args (bool) – bool indicating if input arguments and tensors should be validated for correctness. Set to False for faster computations.
- Return type:
Tensor
- Returns:
If multidim_average is set to global, the metric returns a scalar value. If multidim_average is set to samplewise, the metric returns an (N,) vector consisting of a scalar value per sample.
- Example (preds is int tensor):
>>> from torch import tensor
>>> from torchmetrics.functional.classification import binary_recall
>>> target = tensor([0, 1, 0, 1, 0, 1])
>>> preds = tensor([0, 0, 1, 1, 0, 1])
>>> binary_recall(preds, target)
tensor(0.6667)
- Example (preds is float tensor):
>>> from torch import tensor
>>> from torchmetrics.functional.classification import binary_recall
>>> target = tensor([0, 1, 0, 1, 0, 1])
>>> preds = tensor([0.11, 0.22, 0.84, 0.73, 0.33, 0.92])
>>> binary_recall(preds, target)
tensor(0.6667)
- Example (multidim tensors):
>>> from torch import tensor
>>> from torchmetrics.functional.classification import binary_recall
>>> target = tensor([[[0, 1], [1, 0], [0, 1]], [[1, 1], [0, 0], [1, 0]]])
>>> preds = tensor([[[0.59, 0.91], [0.91, 0.99], [0.63, 0.04]],
...                 [[0.38, 0.04], [0.86, 0.780], [0.45, 0.37]]])
>>> binary_recall(preds, target, multidim_average='samplewise')
tensor([0.6667, 0.0000])
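To make the two multidim_average modes concrete, the plain-Python sketch below reproduces the samplewise result of the example above and also computes the global value, where all elements of all samples are pooled before dividing. The helper names are hypothetical, only probability inputs are handled, and a strict > threshold comparison is assumed.

```python
def flatten(x):
    """Recursively flatten nested lists into a flat list."""
    if isinstance(x, list):
        return [v for item in x for v in flatten(item)]
    return [x]


def recall_from_counts(hard, target):
    tp = sum(1 for p, t in zip(hard, target) if t == 1 and p == 1)
    fn = sum(1 for p, t in zip(hard, target) if t == 1 and p == 0)
    return tp / (tp + fn) if tp + fn else 0.0  # 0.0 without positives (assumption)


def binary_recall_multidim(preds, target, threshold=0.5, multidim_average="global"):
    if multidim_average == "global":
        # Every extra dimension is folded into one long batch
        hard = [1 if v > threshold else 0 for v in flatten(preds)]
        return recall_from_counts(hard, flatten(target))
    # samplewise: one recall per entry along the first (N) axis
    return [
        recall_from_counts([1 if v > threshold else 0 for v in flatten(p)], flatten(t))
        for p, t in zip(preds, target)
    ]


target = [[[0, 1], [1, 0], [0, 1]], [[1, 1], [0, 0], [1, 0]]]
preds = [[[0.59, 0.91], [0.91, 0.99], [0.63, 0.04]],
         [[0.38, 0.04], [0.86, 0.780], [0.45, 0.37]]]
print([round(r, 4) for r in binary_recall_multidim(preds, target, multidim_average="samplewise")])  # [0.6667, 0.0]
print(round(binary_recall_multidim(preds, target), 4))  # 0.3333
```

The global value is not the mean of the samplewise values in general; here the two coincide only because both samples contain three positives.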
multiclass_recall¶
- torchmetrics.functional.classification.multiclass_recall(preds, target, num_classes, average='macro', top_k=1, multidim_average='global', ignore_index=None, validate_args=True)[source]¶
Compute Recall for multiclass tasks.
\[\text{Recall} = \frac{\text{TP}}{\text{TP} + \text{FN}}\]
where \(\text{TP}\) and \(\text{FN}\) represent the number of true positives and false negatives, respectively.
Accepts the following input tensors:
preds: (N, ...) (int tensor) or (N, C, ...) (float tensor). If preds is a floating point tensor, we apply torch.argmax along the C dimension to automatically convert probabilities/logits into an int tensor.
target (int tensor): (N, ...)
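The argmax conversion for float inputs can be sketched as below (plain Python, hypothetical helper names; ties resolve to the lowest index in this sketch, which may differ from torch.argmax). Because softmax preserves the ordering of values within each row, applying it first does not change the predicted class, which is why probabilities and logits can be handled the same way.

```python
import math


def argmax_rows(scores):
    """Map (N, C) scores to (N,) class indices; ties go to the lowest index."""
    return [max(range(len(row)), key=row.__getitem__) for row in scores]


def softmax(row):
    m = max(row)  # subtract the max for numerical stability
    exps = [math.exp(v - m) for v in row]
    total = sum(exps)
    return [e / total for e in exps]


probs = [[0.16, 0.26, 0.58],
         [0.22, 0.61, 0.17],
         [0.71, 0.09, 0.20],
         [0.05, 0.82, 0.13]]
print(argmax_rows(probs))  # [2, 1, 0, 1]
logits = [[-1.0, 0.5, 3.0], [0.2, 2.0, -0.3]]
print(argmax_rows(logits) == argmax_rows([softmax(r) for r in logits]))  # True
```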
- Parameters:
preds (Tensor) – Tensor with predictions
target (Tensor) – Tensor with true labels
num_classes (int) – Integer specifying the number of classes
average (Optional[Literal['micro', 'macro', 'weighted', 'none']]) – Defines the reduction that is applied over labels. Should be one of the following:
micro: Sum statistics over all labels
macro: Calculate statistics for each label and average them
weighted: Calculate statistics for each label and compute a weighted average using their support
"none" or None: Calculate the statistic for each label and apply no reduction
top_k (int) – Number of highest probability or logit score predictions considered to find the correct label. Only works when preds contain probabilities/logits.
multidim_average (Literal['global', 'samplewise']) – Defines how additional dimensions ... should be handled. Should be one of the following:
global: Additional dimensions are flattened along the batch dimension
samplewise: Statistic will be calculated independently for each sample on the N axis. The statistics in this case are calculated over the additional dimensions.
ignore_index (Optional[int]) – Specifies a target value that is ignored and does not contribute to the metric calculation
validate_args (bool) – bool indicating if input arguments and tensors should be validated for correctness. Set to False for faster computations.
- Returns:
If multidim_average is set to global:
If average='micro'/'macro'/'weighted', the output will be a scalar tensor
If average=None/'none', the shape will be (C,)
If multidim_average is set to samplewise:
If average='micro'/'macro'/'weighted', the shape will be (N,)
If average=None/'none', the shape will be (N, C)
- Return type:
The returned shape depends on the average and multidim_average arguments
- Example (preds is int tensor):
>>> from torch import tensor
>>> from torchmetrics.functional.classification import multiclass_recall
>>> target = tensor([2, 1, 0, 0])
>>> preds = tensor([2, 1, 0, 1])
>>> multiclass_recall(preds, target, num_classes=3)
tensor(0.8333)
>>> multiclass_recall(preds, target, num_classes=3, average=None)
tensor([0.5000, 1.0000, 1.0000])
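The average options can be compared on the same inputs with the plain-Python sketch below (hypothetical helper name; scoring a class with no support as 0 is an assumption). For single-label multiclass data, weighted recall always equals micro recall: weighting each per-class recall TP_c / support_c by support_c cancels the denominators, leaving the total number of true positives divided by the number of samples.

```python
def multiclass_recall_sketch(preds, target, num_classes, average="macro"):
    tp = [0] * num_classes
    support = [0] * num_classes  # targets per class, i.e. TP + FN
    for p, t in zip(preds, target):
        support[t] += 1
        if p == t:
            tp[t] += 1
    # Assumption: a class with no support gets recall 0
    per_class = [tp[c] / support[c] if support[c] else 0.0 for c in range(num_classes)]
    if average in (None, "none"):
        return per_class
    if average == "micro":
        return sum(tp) / sum(support)
    if average == "macro":
        return sum(per_class) / num_classes
    if average == "weighted":
        return sum(r * s for r, s in zip(per_class, support)) / sum(support)
    raise ValueError(f"unknown average: {average!r}")


preds, target = [2, 1, 0, 1], [2, 1, 0, 0]
for avg in ("micro", "macro", "weighted", None):
    print(avg, multiclass_recall_sketch(preds, target, 3, average=avg))
```

The macro and None results agree with the example above (0.8333 and [0.5, 1.0, 1.0]), while micro and weighted both give 0.75 because 3 of the 4 targets are recovered.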
- Example (preds is float tensor):
>>> from torch import tensor
>>> from torchmetrics.functional.classification import multiclass_recall
>>> target = tensor([2, 1, 0, 0])
>>> preds = tensor([[0.16, 0.26, 0.58],
...                 [0.22, 0.61, 0.17],
...                 [0.71, 0.09, 0.20],
...                 [0.05, 0.82, 0.13]])
>>> multiclass_recall(preds, target, num_classes=3)
tensor(0.8333)
>>> multiclass_recall(preds, target, num_classes=3, average=None)
tensor([0.5000, 1.0000, 1.0000])
- Example (multidim tensors):
>>> from torch import tensor
>>> from torchmetrics.functional.classification import multiclass_recall
>>> target = tensor([[[0, 1], [2, 1], [0, 2]], [[1, 1], [2, 0], [1, 2]]])
>>> preds = tensor([[[0, 2], [2, 0], [0, 1]], [[2, 2], [2, 1], [1, 0]]])
>>> multiclass_recall(preds, target, num_classes=3, multidim_average='samplewise')
tensor([0.5000, 0.2778])
>>> multiclass_recall(preds, target, num_classes=3, multidim_average='samplewise', average=None)
tensor([[1.0000, 0.0000, 0.5000],
        [0.0000, 0.3333, 0.5000]])
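With multidim_average='samplewise', each sample's extra dimensions are flattened and the per-class statistics are computed within that sample only. The plain-Python sketch below reproduces the example values; its helper names are hypothetical, only macro averaging and None are sketched, and classes without support are scored as 0 by assumption.

```python
def flatten(x):
    if isinstance(x, list):
        return [v for item in x for v in flatten(item)]
    return [x]


def per_class_recall(preds, target, num_classes):
    tp = [0] * num_classes
    support = [0] * num_classes
    for p, t in zip(preds, target):
        support[t] += 1
        tp[t] += int(p == t)
    return [tp[c] / support[c] if support[c] else 0.0 for c in range(num_classes)]


def samplewise_recall(preds, target, num_classes, average="macro"):
    out = []
    for p, t in zip(preds, target):  # loop over the N axis
        per_class = per_class_recall(flatten(p), flatten(t), num_classes)
        # only macro (default) and None are sketched here
        out.append(per_class if average is None else sum(per_class) / num_classes)
    return out


target = [[[0, 1], [2, 1], [0, 2]], [[1, 1], [2, 0], [1, 2]]]
preds = [[[0, 2], [2, 0], [0, 1]], [[2, 2], [2, 1], [1, 0]]]
print([round(v, 4) for v in samplewise_recall(preds, target, 3)])  # [0.5, 0.2778]
```

For the second sample the per-class recalls are 0, 1/3 and 1/2, whose mean is 0.2778.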
multilabel_recall¶
- torchmetrics.functional.classification.multilabel_recall(preds, target, num_labels, threshold=0.5, average='macro', multidim_average='global', ignore_index=None, validate_args=True)[source]¶
Compute Recall for multilabel tasks.
\[\text{Recall} = \frac{\text{TP}}{\text{TP} + \text{FN}}\]
where \(\text{TP}\) and \(\text{FN}\) represent the number of true positives and false negatives, respectively.
Accepts the following input tensors:
preds (int or float tensor): (N, C, ...). If preds is a floating point tensor with values outside the [0,1] range, we consider the input to be logits and automatically apply sigmoid per element. Additionally, we convert it to an int tensor by thresholding with the value in threshold.
target (int tensor): (N, C, ...)
- Parameters:
preds (Tensor) – Tensor with predictions
target (Tensor) – Tensor with true labels
num_labels (int) – Integer specifying the number of labels
threshold (float) – Threshold for transforming probability to binary (0,1) predictions
average (Optional[Literal['micro', 'macro', 'weighted', 'none']]) – Defines the reduction that is applied over labels. Should be one of the following:
micro: Sum statistics over all labels
macro: Calculate statistics for each label and average them
weighted: Calculate statistics for each label and compute a weighted average using their support
"none" or None: Calculate the statistic for each label and apply no reduction
multidim_average (Literal['global', 'samplewise']) – Defines how additional dimensions ... should be handled. Should be one of the following:
global: Additional dimensions are flattened along the batch dimension
samplewise: Statistic will be calculated independently for each sample on the N axis. The statistics in this case are calculated over the additional dimensions.
ignore_index (Optional[int]) – Specifies a target value that is ignored and does not contribute to the metric calculation
validate_args (bool) – bool indicating if input arguments and tensors should be validated for correctness. Set to False for faster computations.
- Returns:
If multidim_average is set to global:
If average='micro'/'macro'/'weighted', the output will be a scalar tensor
If average=None/'none', the shape will be (C,)
If multidim_average is set to samplewise:
If average='micro'/'macro'/'weighted', the shape will be (N,)
If average=None/'none', the shape will be (N, C)
- Return type:
The returned shape depends on the average and multidim_average arguments
- Example (preds is int tensor):
>>> from torch import tensor
>>> from torchmetrics.functional.classification import multilabel_recall
>>> target = tensor([[0, 1, 0], [1, 0, 1]])
>>> preds = tensor([[0, 0, 1], [1, 0, 1]])
>>> multilabel_recall(preds, target, num_labels=3)
tensor(0.6667)
>>> multilabel_recall(preds, target, num_labels=3, average=None)
tensor([1., 0., 1.])
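In the multilabel case every label column is treated as its own binary problem. The plain-Python sketch below reproduces both label-wise examples. Its helper name is hypothetical, sigmoid handling for logits is omitted, a strict > threshold is assumed, and a label without positives scores 0 by assumption.

```python
def multilabel_recall_sketch(preds, target, num_labels, threshold=0.5, average="macro"):
    per_label = []
    for c in range(num_labels):
        tp = fn = 0
        for p_row, t_row in zip(preds, target):
            p = p_row[c]
            hit = p > threshold if isinstance(p, float) else p == 1  # floats are thresholded
            if t_row[c] == 1:
                if hit:
                    tp += 1
                else:
                    fn += 1
        per_label.append(tp / (tp + fn) if tp + fn else 0.0)
    if average is None:
        return per_label
    return sum(per_label) / num_labels  # macro average


target = [[0, 1, 0], [1, 0, 1]]
print(multilabel_recall_sketch([[0, 0, 1], [1, 0, 1]], target, 3, average=None))  # [1.0, 0.0, 1.0]
print(round(multilabel_recall_sketch([[0.11, 0.22, 0.84], [0.73, 0.33, 0.92]], target, 3), 4))  # 0.6667
```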
- Example (preds is float tensor):
>>> from torch import tensor
>>> from torchmetrics.functional.classification import multilabel_recall
>>> target = tensor([[0, 1, 0], [1, 0, 1]])
>>> preds = tensor([[0.11, 0.22, 0.84], [0.73, 0.33, 0.92]])
>>> multilabel_recall(preds, target, num_labels=3)
tensor(0.6667)
>>> multilabel_recall(preds, target, num_labels=3, average=None)
tensor([1., 0., 1.])
- Example (multidim tensors):
>>> from torch import tensor
>>> from torchmetrics.functional.classification import multilabel_recall
>>> target = tensor([[[0, 1], [1, 0], [0, 1]], [[1, 1], [0, 0], [1, 0]]])
>>> preds = tensor([[[0.59, 0.91], [0.91, 0.99], [0.63, 0.04]],
...                 [[0.38, 0.04], [0.86, 0.780], [0.45, 0.37]]])
>>> multilabel_recall(preds, target, num_labels=3, multidim_average='samplewise')
tensor([0.6667, 0.0000])
>>> multilabel_recall(preds, target, num_labels=3, multidim_average='samplewise', average=None)
tensor([[1., 1., 0.],
        [0., 0., 0.]])
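For multidim inputs of shape (N, C, ...), samplewise averaging computes the per-label statistics separately within each sample, over the trailing dimensions. The sketch below assumes exactly one trailing dimension, probability inputs and a strict > threshold. The helper name is hypothetical, and a label with no positives scores 0, which is consistent with the second sample in the example.

```python
def multilabel_samplewise_recall(preds, target, num_labels, threshold=0.5):
    """Return an (N, C) nested list of per-label recalls, one row per sample."""
    out = []
    for p_n, t_n in zip(preds, target):  # loop over the N axis
        row = []
        for c in range(num_labels):  # loop over the label (C) axis
            pairs = list(zip(p_n[c], t_n[c]))  # values along the trailing dimension
            tp = sum(1 for p, t in pairs if t == 1 and p > threshold)
            fn = sum(1 for p, t in pairs if t == 1 and p <= threshold)
            row.append(tp / (tp + fn) if tp + fn else 0.0)
        out.append(row)
    return out


target = [[[0, 1], [1, 0], [0, 1]], [[1, 1], [0, 0], [1, 0]]]
preds = [[[0.59, 0.91], [0.91, 0.99], [0.63, 0.04]],
         [[0.38, 0.04], [0.86, 0.780], [0.45, 0.37]]]
per_label = multilabel_samplewise_recall(preds, target, 3)
print(per_label)  # [[1.0, 1.0, 0.0], [0.0, 0.0, 0.0]]
print([round(sum(row) / 3, 4) for row in per_label])  # [0.6667, 0.0]
```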
Recall At Fixed Precision¶
Module Interface¶
- class torchmetrics.RecallAtFixedPrecision(**kwargs)[source]¶
Compute the highest possible recall value given the minimum precision thresholds provided.
This is done by first calculating the precision-recall curve for different thresholds and then finding the recall for a given precision level.
This is a simple wrapper to get the task-specific versions of this metric, which is done by setting the task argument to either 'binary', 'multiclass' or 'multilabel'. See the documentation of BinaryRecallAtFixedPrecision, MulticlassRecallAtFixedPrecision and MultilabelRecallAtFixedPrecision for the specific details of how each argument influences the result, and for examples.
BinaryRecallAtFixedPrecision¶
- class torchmetrics.classification.BinaryRecallAtFixedPrecision(min_precision, thresholds=None, ignore_index=None, validate_args=True, **kwargs)[source]¶
Compute the highest possible recall value given the minimum precision thresholds provided.
This is done by first calculating the precision-recall curve for different thresholds and then finding the recall for a given precision level.
As input to forward and update the metric accepts the following input:
preds (Tensor): A float tensor of shape (N, ...). Preds should be a tensor containing probabilities or logits for each observation. If preds has values outside the [0,1] range, we consider the input to be logits and automatically apply sigmoid per element.
target (Tensor): An int tensor of shape (N, ...). Target should be a tensor containing ground truth labels, and therefore only contain {0,1} values (except if ignore_index is specified). The value 1 always encodes the positive class.
Note
Additional dimensions ... will be flattened into the batch dimension.
As output to forward and compute the metric returns the following output:
recall (Tensor): A scalar tensor with the maximum recall for the given precision level
threshold (Tensor): A scalar tensor with the corresponding threshold level
Note
The implementation supports both a non-binned but accurate version and a binned version that is less accurate but more memory efficient. Setting the thresholds argument to None will activate the non-binned version that uses memory of size \(\mathcal{O}(n_{samples})\), whereas setting the thresholds argument to either an integer, list or a 1d tensor will use a binned version that uses memory of size \(\mathcal{O}(n_{thresholds})\) (constant memory).
- Parameters:
min_precision (float) – float value specifying minimum precision threshold.
thresholds (Union[int, List[float], Tensor, None]) – Can be one of:
If set to None, will use a non-binned approach where thresholds are dynamically calculated from all the data. Most accurate but also most memory consuming approach.
If set to an int (larger than 1), will use that number of thresholds linearly spaced from 0 to 1 as bins for the calculation.
If set to a list of floats, will use the indicated thresholds in the list as bins for the calculation.
If set to a 1d Tensor of floats, will use the indicated thresholds in the tensor as bins for the calculation.
validate_args (bool) – bool indicating if input arguments and tensors should be validated for correctness. Set to False for faster computations.
kwargs (Any) – Additional keyword arguments, see Advanced metric settings for more info.
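The constant-memory behaviour of the binned version can be illustrated with the plain-Python sketch below. The class name is hypothetical and this is not the TorchMetrics implementation. It assumes that an integer thresholds value means a linearly spaced grid from 0 to 1, that a sample counts as positive when its score is >= the threshold, and that ties in recall are broken by higher precision and then higher threshold (chosen so the sketch reproduces the example output). The state is three count lists of length n_thresholds, so it does not grow with the number of samples seen.

```python
class BinnedRecallAtPrecision:
    """Keep only TP/FP/FN counts per threshold: O(n_thresholds) state."""

    def __init__(self, num_thresholds):
        self.thresholds = [i / (num_thresholds - 1) for i in range(num_thresholds)]
        self.tp = [0] * num_thresholds
        self.fp = [0] * num_thresholds
        self.fn = [0] * num_thresholds

    def update(self, preds, target):
        # Each batch only increments counters; the raw scores are not stored
        for i, thr in enumerate(self.thresholds):
            for p, y in zip(preds, target):
                if p >= thr:
                    if y == 1:
                        self.tp[i] += 1
                    else:
                        self.fp[i] += 1
                elif y == 1:
                    self.fn[i] += 1

    def compute(self, min_precision):
        best = None
        for i, thr in enumerate(self.thresholds):
            predicted = self.tp[i] + self.fp[i]
            positives = self.tp[i] + self.fn[i]
            if predicted == 0 or positives == 0:
                continue  # precision or recall undefined at this threshold
            precision = self.tp[i] / predicted
            recall = self.tp[i] / positives
            if precision >= min_precision:
                candidate = (recall, precision, thr)  # tuple order = tie-break order
                if best is None or candidate > best:
                    best = candidate
        if best is None or best[0] == 0.0:
            return 0.0, 1e6  # sentinel threshold, as in the multiclass example output
        return best[0], best[2]


metric = BinnedRecallAtPrecision(num_thresholds=5)
metric.update([0.0, 0.5], [0, 1])  # first batch
metric.update([0.7, 0.8], [1, 0])  # second batch
print(metric.compute(min_precision=0.5))  # (1.0, 0.5)
```

Feeding the example data in two batches gives the same result as the thresholds=5 call in the example.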
Example
>>> from torch import tensor
>>> from torchmetrics.classification import BinaryRecallAtFixedPrecision
>>> preds = tensor([0, 0.5, 0.7, 0.8])
>>> target = tensor([0, 1, 1, 0])
>>> metric = BinaryRecallAtFixedPrecision(min_precision=0.5, thresholds=None)
>>> metric(preds, target)
(tensor(1.), tensor(0.5000))
>>> metric = BinaryRecallAtFixedPrecision(min_precision=0.5, thresholds=5)
>>> metric(preds, target)
(tensor(1.), tensor(0.5000))
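In the non-binned mode every distinct score is a candidate threshold. The plain-Python sketch below (hypothetical function name, not the library implementation) scans those candidates for the binary example above, keeps the ones whose precision reaches min_precision and returns the highest recall. Ties are broken by higher precision and then higher threshold, an assumption that reproduces the (1.0, 0.5) output.

```python
def recall_at_fixed_precision_exact(preds, target, min_precision):
    positives = sum(target)
    best = None
    for thr in sorted(set(preds)):  # each distinct score is a candidate threshold
        tp = sum(1 for p, y in zip(preds, target) if p >= thr and y == 1)
        fp = sum(1 for p, y in zip(preds, target) if p >= thr and y == 0)
        # tp + fp >= 1 because the sample whose score equals thr is predicted positive
        precision = tp / (tp + fp)
        recall = tp / positives if positives else 0.0
        if precision >= min_precision:
            candidate = (recall, precision, thr)
            if best is None or candidate > best:
                best = candidate
    if best is None or best[0] == 0.0:
        return 0.0, 1e6  # no usable threshold (sentinel value, assumption)
    return best[0], best[2]


print(recall_at_fixed_precision_exact([0, 0.5, 0.7, 0.8], [0, 1, 1, 0], min_precision=0.5))  # (1.0, 0.5)
```

Deriving these candidates requires keeping every score, which is why the memory cost of the non-binned mode grows with n_samples.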
- plot(val=None, ax=None)[source]¶
Plot a single or multiple values from the metric.
- Parameters:
val (Union[Tensor, Sequence[Tensor], None]) – Either a single result from calling metric.forward or metric.compute or a list of these results. If no value is provided, will automatically call metric.compute and plot that result.
ax (Optional[Axes]) – A matplotlib axis object. If provided, will add the plot to that axis
- Return type:
- Returns:
Figure object and Axes object
- Raises:
ModuleNotFoundError – If matplotlib is not installed
>>> from torch import rand, randint
>>> # Example plotting a single value
>>> from torchmetrics.classification import BinaryRecallAtFixedPrecision
>>> metric = BinaryRecallAtFixedPrecision(min_precision=0.5)
>>> metric.update(rand(10), randint(2, (10,)))
>>> fig_, ax_ = metric.plot()  # the returned plot only shows the maximum recall value by default
>>> from torch import rand, randint
>>> # Example plotting multiple values
>>> from torchmetrics.classification import BinaryRecallAtFixedPrecision
>>> metric = BinaryRecallAtFixedPrecision(min_precision=0.5)
>>> values = []
>>> for _ in range(10):
...     # we index by 0 such that only the maximum recall value is plotted
...     values.append(metric(rand(10), randint(2, (10,)))[0])
>>> fig_, ax_ = metric.plot(values)
MulticlassRecallAtFixedPrecision¶
- class torchmetrics.classification.MulticlassRecallAtFixedPrecision(num_classes, min_precision, thresholds=None, ignore_index=None, validate_args=True, **kwargs)[source]¶
Compute the highest possible recall value given the minimum precision thresholds provided.
This is done by first calculating the precision-recall curve for different thresholds and then finding the recall for a given precision level.
For multiclass the metric is calculated by iteratively treating each class as the positive class and all other classes as the negative, which is referred to as the one-vs-rest approach. One-vs-one is currently not supported by this metric.
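The one-vs-rest reduction can be sketched in plain Python: column c of preds becomes the score vector and target == c becomes the binary label, after which the binary search over thresholds is reused. Function names are hypothetical, preds are assumed to already be probabilities (the softmax step is skipped), and the tie-breaking and the 1e6 sentinel are assumptions chosen to match the example output of this class.

```python
def binary_recall_at_precision(scores, labels, min_precision):
    positives = sum(labels)
    best = None
    for thr in sorted(set(scores)):
        tp = sum(1 for s, y in zip(scores, labels) if s >= thr and y == 1)
        fp = sum(1 for s, y in zip(scores, labels) if s >= thr and y == 0)
        precision = tp / (tp + fp)  # at least one score equals thr
        recall = tp / positives if positives else 0.0
        if precision >= min_precision and (best is None or (recall, precision, thr) > best):
            best = (recall, precision, thr)
    if best is None or best[0] == 0.0:
        return 0.0, 1e6
    return best[0], best[2]


def multiclass_recall_at_precision_ovr(preds, target, num_classes, min_precision):
    recalls, thresholds = [], []
    for c in range(num_classes):
        scores = [row[c] for row in preds]  # column c of the (N, C) scores
        labels = [1 if t == c else 0 for t in target]  # class c vs. the rest
        r, thr = binary_recall_at_precision(scores, labels, min_precision)
        recalls.append(r)
        thresholds.append(thr)
    return recalls, thresholds


preds = [[0.75, 0.05, 0.05, 0.05, 0.05],
         [0.05, 0.75, 0.05, 0.05, 0.05],
         [0.05, 0.05, 0.75, 0.05, 0.05],
         [0.05, 0.05, 0.05, 0.75, 0.05]]
target = [0, 1, 3, 2]
print(multiclass_recall_at_precision_ovr(preds, target, num_classes=5, min_precision=0.5))
```

This yields recalls [1.0, 1.0, 0.0, 0.0, 0.0] and thresholds [0.75, 0.75, 1e6, 1e6, 1e6], matching the example output for this class: classes 2 and 3 are swapped in the predictions, and class 4 never occurs in target.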
As input to forward and update the metric accepts the following input:
preds (Tensor): A float tensor of shape (N, C, ...). Preds should be a tensor containing probabilities or logits for each observation. If preds has values outside the [0,1] range, we consider the input to be logits and automatically apply softmax per sample.
target (Tensor): An int tensor of shape (N, ...). Target should be a tensor containing ground truth labels, and therefore only contain values in the [0, n_classes-1] range (except if ignore_index is specified).
Note
Additional dimensions ... will be flattened into the batch dimension.
As output to forward and compute the metric returns a tuple of either 2 tensors or 2 lists containing:
recall (Tensor): A 1d tensor of size (n_classes, ) with the maximum recall for the given precision level per class
threshold (Tensor): A 1d tensor of size (n_classes, ) with the corresponding threshold level per class
Note
The implementation supports both a non-binned but accurate version and a binned version that is less accurate but more memory efficient. Setting the thresholds argument to None will activate the non-binned version that uses memory of size \(\mathcal{O}(n_{samples})\), whereas setting the thresholds argument to either an integer, list or a 1d tensor will use a binned version that uses memory of size \(\mathcal{O}(n_{thresholds} \times n_{classes})\) (constant memory).
- Parameters:
num_classes (int) – Integer specifying the number of classes
min_precision (float) – float value specifying minimum precision threshold.
thresholds (Union[int, List[float], Tensor, None]) – Can be one of:
If set to None, will use a non-binned approach where thresholds are dynamically calculated from all the data. Most accurate but also most memory consuming approach.
If set to an int (larger than 1), will use that number of thresholds linearly spaced from 0 to 1 as bins for the calculation.
If set to a list of floats, will use the indicated thresholds in the list as bins for the calculation.
If set to a 1d Tensor of floats, will use the indicated thresholds in the tensor as bins for the calculation.
validate_args (bool) – bool indicating if input arguments and tensors should be validated for correctness. Set to False for faster computations.
kwargs (Any) – Additional keyword arguments, see Advanced metric settings for more info.
Example
>>> from torch import tensor
>>> from torchmetrics.classification import MulticlassRecallAtFixedPrecision
>>> preds = tensor([[0.75, 0.05, 0.05, 0.05, 0.05],
...                 [0.05, 0.75, 0.05, 0.05, 0.05],
...                 [0.05, 0.05, 0.75, 0.05, 0.05],
...                 [0.05, 0.05, 0.05, 0.75, 0.05]])
>>> target = tensor([0, 1, 3, 2])
>>> metric = MulticlassRecallAtFixedPrecision(num_classes=5, min_precision=0.5, thresholds=None)
>>> metric(preds, target)
(tensor([1., 1., 0., 0., 0.]), tensor([7.5000e-01, 7.5000e-01, 1.0000e+06, 1.0000e+06, 1.0000e+06]))
>>> mcrafp = MulticlassRecallAtFixedPrecision(num_classes=5, min_precision=0.5, thresholds=5)
>>> mcrafp(preds, target)
(tensor([1., 1., 0., 0., 0.]), tensor([7.5000e-01, 7.5000e-01, 1.0000e+06, 1.0000e+06, 1.0000e+06]))
- plot(val=None, ax=None)[source]¶
Plot a single or multiple values from the metric.
- Parameters:
val (Union[Tensor, Sequence[Tensor], None]) – Either a single result from calling metric.forward or metric.compute or a list of these results. If no value is provided, will automatically call metric.compute and plot that result.
ax (Optional[Axes]) – A matplotlib axis object. If provided, will add the plot to that axis
- Return type:
- Returns:
Figure object and Axes object
- Raises:
ModuleNotFoundError – If matplotlib is not installed
>>> from torch import rand, randint
>>> # Example plotting a single value per class
>>> from torchmetrics.classification import MulticlassRecallAtFixedPrecision
>>> metric = MulticlassRecallAtFixedPrecision(num_classes=3, min_precision=0.5)
>>> metric.update(rand(20, 3).softmax(dim=-1), randint(3, (20,)))
>>> fig_, ax_ = metric.plot()  # the returned plot only shows the maximum recall value by default
>>> from torch import rand, randint
>>> # Example plotting multiple values per class
>>> from torchmetrics.classification import MulticlassRecallAtFixedPrecision
>>> metric = MulticlassRecallAtFixedPrecision(num_classes=3, min_precision=0.5)
>>> values = []
>>> for _ in range(20):
...     # we index by 0 such that only the maximum recall value is plotted
...     values.append(metric(rand(20, 3).softmax(dim=-1), randint(3, (20,)))[0])
>>> fig_, ax_ = metric.plot(values)
MultilabelRecallAtFixedPrecision¶
- class torchmetrics.classification.MultilabelRecallAtFixedPrecision(num_labels, min_precision, thresholds=None, ignore_index=None, validate_args=True, **kwargs)[source]¶
Compute the highest possible recall value given the minimum precision thresholds provided.
This is done by first calculating the precision-recall curve for different thresholds and then finding the recall for a given precision level.
As input to forward and update the metric accepts the following input:
preds (Tensor): A float tensor of shape (N, C, ...). Preds should be a tensor containing probabilities or logits for each observation. If preds has values outside the [0,1] range, we consider the input to be logits and automatically apply sigmoid per element.
target (Tensor): An int tensor of shape (N, C, ...). Target should be a tensor containing ground truth labels, and therefore only contain {0,1} values (except if ignore_index is specified). The value 1 always encodes the positive class.
Note
Additional dimensions ... will be flattened into the batch dimension.
As output to forward and compute the metric returns a tuple of either 2 tensors or 2 lists containing:
recall (Tensor): A 1d tensor of size (n_labels, ) with the maximum recall for the given precision level per label
threshold (Tensor): A 1d tensor of size (n_labels, ) with the corresponding threshold level per label
Note
The implementation supports both a non-binned but accurate version and a binned version that is less accurate but more memory efficient. Setting the thresholds argument to None will activate the non-binned version that uses memory of size \(\mathcal{O}(n_{samples})\), whereas setting the thresholds argument to either an integer, list or a 1d tensor will use a binned version that uses memory of size \(\mathcal{O}(n_{thresholds} \times n_{labels})\) (constant memory).
- Parameters:
num_labels (int) – Integer specifying the number of labels
min_precision (float) – float value specifying minimum precision threshold.
thresholds (Union[int, List[float], Tensor, None]) – Can be one of:
If set to None, will use a non-binned approach where thresholds are dynamically calculated from all the data. Most accurate but also most memory consuming approach.
If set to an int (larger than 1), will use that number of thresholds linearly spaced from 0 to 1 as bins for the calculation.
If set to a list of floats, will use the indicated thresholds in the list as bins for the calculation.
If set to a 1d Tensor of floats, will use the indicated thresholds in the tensor as bins for the calculation.
validate_args (bool) – bool indicating if input arguments and tensors should be validated for correctness. Set to False for faster computations.
kwargs (Any) – Additional keyword arguments, see Advanced metric settings for more info.
Example
>>> from torch import tensor
>>> from torchmetrics.classification import MultilabelRecallAtFixedPrecision
>>> preds = tensor([[0.75, 0.05, 0.35],
...                 [0.45, 0.75, 0.05],
...                 [0.05, 0.55, 0.75],
...                 [0.05, 0.65, 0.05]])
>>> target = tensor([[1, 0, 1],
...                  [0, 0, 0],
...                  [0, 1, 1],
...                  [1, 1, 1]])
>>> metric = MultilabelRecallAtFixedPrecision(num_labels=3, min_precision=0.5, thresholds=None)
>>> metric(preds, target)
(tensor([1., 1., 1.]), tensor([0.0500, 0.5500, 0.0500]))
>>> mlrafp = MultilabelRecallAtFixedPrecision(num_labels=3, min_precision=0.5, thresholds=5)
>>> mlrafp(preds, target)
(tensor([1., 1., 1.]), tensor([0.0000, 0.5000, 0.0000]))
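For multilabel inputs each label column is an independent binary problem. The plain-Python sketch below (hypothetical function names, probability inputs assumed, ties broken by higher precision then higher threshold as an assumption) evaluates either every distinct score (thresholds=None) or a linearly spaced grid (thresholds=int). It reproduces both outputs of the example above.

```python
def recall_at_precision(scores, labels, min_precision, grid):
    positives = sum(labels)
    best = None
    for thr in grid:
        tp = sum(1 for s, y in zip(scores, labels) if s >= thr and y == 1)
        fp = sum(1 for s, y in zip(scores, labels) if s >= thr and y == 0)
        if tp + fp == 0 or positives == 0:
            continue  # precision or recall undefined here
        precision, recall = tp / (tp + fp), tp / positives
        if precision >= min_precision and (best is None or (recall, precision, thr) > best):
            best = (recall, precision, thr)
    if best is None or best[0] == 0.0:
        return 0.0, 1e6
    return best[0], best[2]


def multilabel_recall_at_precision(preds, target, num_labels, min_precision, thresholds=None):
    recalls, best_thresholds = [], []
    for c in range(num_labels):
        scores = [row[c] for row in preds]
        labels = [row[c] for row in target]
        if thresholds is None:
            grid = sorted(set(scores))  # non-binned: every distinct score
        else:
            grid = [i / (thresholds - 1) for i in range(thresholds)]  # binned: 0..1
        r, thr = recall_at_precision(scores, labels, min_precision, grid)
        recalls.append(r)
        best_thresholds.append(thr)
    return recalls, best_thresholds


preds = [[0.75, 0.05, 0.35], [0.45, 0.75, 0.05], [0.05, 0.55, 0.75], [0.05, 0.65, 0.05]]
target = [[1, 0, 1], [0, 0, 0], [0, 1, 1], [1, 1, 1]]
print(multilabel_recall_at_precision(preds, target, 3, 0.5))  # ([1.0, 1.0, 1.0], [0.05, 0.55, 0.05])
print(multilabel_recall_at_precision(preds, target, 3, 0.5, thresholds=5))  # ([1.0, 1.0, 1.0], [0.0, 0.5, 0.0])
```

For the second label, both 0.05 and 0.55 reach full recall, but 0.55 gives higher precision (2/3 instead of 1/2), so it wins the tie.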
- plot(val=None, ax=None)[source]¶
Plot a single or multiple values from the metric.
- Parameters:
val (Union[Tensor, Sequence[Tensor], None]) – Either a single result from calling metric.forward or metric.compute or a list of these results. If no value is provided, will automatically call metric.compute and plot that result.
ax (Optional[Axes]) – A matplotlib axis object. If provided, will add the plot to that axis
- Return type:
- Returns:
Figure object and Axes object
- Raises:
ModuleNotFoundError – If matplotlib is not installed
>>> from torch import rand, randint
>>> # Example plotting a single value
>>> from torchmetrics.classification import MultilabelRecallAtFixedPrecision
>>> metric = MultilabelRecallAtFixedPrecision(num_labels=3, min_precision=0.5)
>>> metric.update(rand(20, 3), randint(2, (20, 3)))
>>> fig_, ax_ = metric.plot()  # the returned plot only shows the maximum recall value by default
>>> from torch import rand, randint
>>> # Example plotting multiple values
>>> from torchmetrics.classification import MultilabelRecallAtFixedPrecision
>>> metric = MultilabelRecallAtFixedPrecision(num_labels=3, min_precision=0.5)
>>> values = []
>>> for _ in range(10):
...     # we index by 0 such that only the maximum recall value is plotted
...     values.append(metric(rand(20, 3), randint(2, (20, 3)))[0])
>>> fig_, ax_ = metric.plot(values)
Functional Interface¶
binary_recall_at_fixed_precision¶
- torchmetrics.functional.classification.binary_recall_at_fixed_precision(preds, target, min_precision, thresholds=None, ignore_index=None, validate_args=True)[source]¶
Compute the highest possible recall value given the minimum precision thresholds provided for binary tasks.
This is done by first calculating the precision-recall curve for different thresholds and then finding the recall for a given precision level.
Accepts the following input tensors:
preds (float tensor): (N, ...). Preds should be a tensor containing probabilities or logits for each observation. If preds has values outside the [0,1] range, we consider the input to be logits and automatically apply sigmoid per element.
target (int tensor): (N, ...). Target should be a tensor containing ground truth labels, and therefore only contain {0,1} values (except if ignore_index is specified). The value 1 always encodes the positive class.
Additional dimensions ... will be flattened into the batch dimension.
The implementation supports both a non-binned but accurate version and a binned version that is less accurate but more memory efficient. Setting the thresholds argument to None will activate the non-binned version that uses memory of size \(\mathcal{O}(n_{samples})\), whereas setting the thresholds argument to either an integer, list or a 1d tensor will use a binned version that uses memory of size \(\mathcal{O}(n_{thresholds})\) (constant memory).
- Parameters:
preds (Tensor) – Tensor with predictions
target (Tensor) – Tensor with true labels
min_precision (float) – float value specifying minimum precision threshold.
thresholds (Union[int, List[float], Tensor, None]) – Can be one of:
If set to None, will use a non-binned approach where thresholds are dynamically calculated from all the data. Most accurate but also most memory consuming approach.
If set to an int (larger than 1), will use that number of thresholds linearly spaced from 0 to 1 as bins for the calculation.
If set to a list of floats, will use the indicated thresholds in the list as bins for the calculation.
If set to a 1d Tensor of floats, will use the indicated thresholds in the tensor as bins for the calculation.
ignore_index (Optional[int]) – Specifies a target value that is ignored and does not contribute to the metric calculation
validate_args (bool) – bool indicating if input arguments and tensors should be validated for correctness. Set to False for faster computations.
- Returns:
a tuple of 2 tensors containing:
recall: a scalar tensor with the maximum recall for the given precision level
threshold: a scalar tensor with the corresponding threshold level
- Return type:
(tuple)
Example
>>> import torch
>>> from torchmetrics.functional.classification import binary_recall_at_fixed_precision
>>> preds = torch.tensor([0, 0.5, 0.7, 0.8])
>>> target = torch.tensor([0, 1, 1, 0])
>>> binary_recall_at_fixed_precision(preds, target, min_precision=0.5, thresholds=None)
(tensor(1.), tensor(0.5000))
>>> binary_recall_at_fixed_precision(preds, target, min_precision=0.5, thresholds=5)
(tensor(1.), tensor(0.5000))
multiclass_recall_at_fixed_precision¶
- torchmetrics.functional.classification.multiclass_recall_at_fixed_precision(preds, target, num_classes, min_precision, thresholds=None, ignore_index=None, validate_args=True)[source]¶
Compute the highest possible recall value given the minimum precision thresholds provided for multiclass tasks.
This is done by first calculating the precision-recall curve for different thresholds and then finding the recall for a given precision level.
Accepts the following input tensors:
preds (float tensor): (N, C, ...). Preds should be a tensor containing probabilities or logits for each observation. If preds has values outside the [0,1] range, we consider the input to be logits and automatically apply softmax per sample.
target (int tensor): (N, ...). Target should be a tensor containing ground truth labels, and therefore only contain values in the [0, n_classes-1] range (except if ignore_index is specified).
Additional dimensions ... will be flattened into the batch dimension.
The implementation supports both a non-binned but accurate version and a binned version that is less accurate but more memory efficient. Setting the thresholds argument to None will activate the non-binned version that uses memory of size \(\mathcal{O}(n_{samples})\), whereas setting the thresholds argument to either an integer, list or a 1d tensor will use a binned version that uses memory of size \(\mathcal{O}(n_{thresholds} \times n_{classes})\) (constant memory).
- Parameters:
preds (Tensor) – Tensor with predictions
target (Tensor) – Tensor with true labels
num_classes (int) – Integer specifying the number of classes
min_precision (float) – float value specifying minimum precision threshold.
thresholds (Union[int, List[float], Tensor, None]) – Can be one of:
If set to None, will use a non-binned approach where thresholds are dynamically calculated from all the data. Most accurate but also most memory consuming approach.
If set to an int (larger than 1), will use that number of thresholds linearly spaced from 0 to 1 as bins for the calculation.
If set to a list of floats, will use the indicated thresholds in the list as bins for the calculation.
If set to a 1d Tensor of floats, will use the indicated thresholds in the tensor as bins for the calculation.
ignore_index (Optional[int]) – Specifies a target value that is ignored and does not contribute to the metric calculation
validate_args (bool) – bool indicating if input arguments and tensors should be validated for correctness. Set to False for faster computations.
- Returns:
a tuple of either 2 tensors or 2 lists containing:
recall: a 1d tensor of size (n_classes, ) with the maximum recall for the given precision level per class
thresholds: a 1d tensor of size (n_classes, ) with the corresponding threshold level per class
- Return type:
(tuple)
Example
>>> import torch
>>> from torchmetrics.functional.classification import multiclass_recall_at_fixed_precision
>>> preds = torch.tensor([[0.75, 0.05, 0.05, 0.05, 0.05],
...                       [0.05, 0.75, 0.05, 0.05, 0.05],
...                       [0.05, 0.05, 0.75, 0.05, 0.05],
...                       [0.05, 0.05, 0.05, 0.75, 0.05]])
>>> target = torch.tensor([0, 1, 3, 2])
>>> multiclass_recall_at_fixed_precision(preds, target, num_classes=5, min_precision=0.5, thresholds=None)
(tensor([1., 1., 0., 0., 0.]), tensor([7.5000e-01, 7.5000e-01, 1.0000e+06, 1.0000e+06, 1.0000e+06]))
>>> multiclass_recall_at_fixed_precision(preds, target, num_classes=5, min_precision=0.5, thresholds=5)
(tensor([1., 1., 0., 0., 0.]), tensor([7.5000e-01, 7.5000e-01, 1.0000e+06, 1.0000e+06, 1.0000e+06]))
multilabel_recall_at_fixed_precision¶
- torchmetrics.functional.classification.multilabel_recall_at_fixed_precision(preds, target, num_labels, min_precision, thresholds=None, ignore_index=None, validate_args=True)[source]¶
Compute the highest possible recall value given the minimum precision thresholds provided for multilabel tasks.
This is done by first calculating the precision-recall curve for different thresholds and then finding the recall for a given precision level.
Accepts the following input tensors:
preds (float tensor): (N, C, ...). Preds should be a tensor containing probabilities or logits for each observation. If preds has values outside the [0,1] range, we consider the input to be logits and automatically apply sigmoid per element.
target (int tensor): (N, C, ...). Target should be a tensor containing ground truth labels, and therefore only contain {0,1} values (except if ignore_index is specified).
Additional dimensions ... will be flattened into the batch dimension.
The implementation supports both a non-binned but accurate version and a binned version that is less accurate but more memory efficient. Setting the thresholds argument to None will activate the non-binned version that uses memory of size \(\mathcal{O}(n_{samples})\), whereas setting the thresholds argument to either an integer, list or a 1d tensor will use a binned version that uses memory of size \(\mathcal{O}(n_{thresholds} \times n_{labels})\) (constant memory).
- Parameters:
preds (Tensor) – Tensor with predictions
target (Tensor) – Tensor with true labels
num_labels (int) – Integer specifying the number of labels
min_precision (float) – Float value specifying the minimum precision threshold.
thresholds (Union[int, List[float], Tensor, None]) – Can be one of:
If set to None, will use a non-binned approach where thresholds are dynamically calculated from all the data. Most accurate but also most memory-consuming approach.
If set to an int (larger than 1), will use that number of thresholds linearly spaced from 0 to 1 as bins for the calculation.
If set to a list of floats, will use the indicated thresholds in the list as bins for the calculation.
If set to a 1d tensor of floats, will use the indicated thresholds in the tensor as bins for the calculation.
ignore_index (Optional[int]) – Specifies a target value that is ignored and does not contribute to the metric calculation.
validate_args (bool) – Boolean indicating whether input arguments and tensors should be validated for correctness. Set to False for faster computations.
- Returns:
a tuple of either 2 tensors or 2 lists containing
recall: a 1d tensor of size (n_labels, ) with the maximum recall for the given precision level per label
thresholds: a 1d tensor of size (n_labels, ) with the corresponding threshold level per label
- Return type:
(tuple)
Example
>>> import torch
>>> from torchmetrics.functional.classification import multilabel_recall_at_fixed_precision
>>> preds = torch.tensor([[0.75, 0.05, 0.35],
...                       [0.45, 0.75, 0.05],
...                       [0.05, 0.55, 0.75],
...                       [0.05, 0.65, 0.05]])
>>> target = torch.tensor([[1, 0, 1],
...                        [0, 0, 0],
...                        [0, 1, 1],
...                        [1, 1, 1]])
>>> multilabel_recall_at_fixed_precision(preds, target, num_labels=3, min_precision=0.5, thresholds=None)
(tensor([1., 1., 1.]), tensor([0.0500, 0.5500, 0.0500]))
>>> multilabel_recall_at_fixed_precision(preds, target, num_labels=3, min_precision=0.5, thresholds=5)
(tensor([1., 1., 1.]), tensor([0.0000, 0.5000, 0.0000]))
ROC
Module Interface
- class torchmetrics.ROC(**kwargs)
Compute the Receiver Operating Characteristic (ROC).
The curve consists of multiple pairs of true positive rate (TPR) and false positive rate (FPR) values evaluated at different thresholds, so that the tradeoff between the two values can be seen.
This function is a simple wrapper to get the task-specific versions of this metric, which is done by setting the task argument to either 'binary', 'multiclass' or 'multilabel'. See the documentation of BinaryROC, MulticlassROC and MultilabelROC for details on how each argument influences the result, and for examples.
- Legacy Example:
>>> from torch import tensor
>>> from torchmetrics import ROC
>>> pred = tensor([0.0, 1.0, 2.0, 3.0])
>>> target = tensor([0, 1, 1, 1])
>>> roc = ROC(task="binary")
>>> fpr, tpr, thresholds = roc(pred, target)
>>> fpr
tensor([0., 0., 0., 0., 1.])
>>> tpr
tensor([0.0000, 0.3333, 0.6667, 1.0000, 1.0000])
>>> thresholds
tensor([1.0000, 0.9526, 0.8808, 0.7311, 0.5000])
>>> pred = tensor([[0.75, 0.05, 0.05, 0.05],
...                [0.05, 0.75, 0.05, 0.05],
...                [0.05, 0.05, 0.75, 0.05],
...                [0.05, 0.05, 0.05, 0.75]])
>>> target = tensor([0, 1, 3, 2])
>>> roc = ROC(task="multiclass", num_classes=4)
>>> fpr, tpr, thresholds = roc(pred, target)
>>> fpr
[tensor([0., 0., 1.]), tensor([0., 0., 1.]), tensor([0.0000, 0.3333, 1.0000]), tensor([0.0000, 0.3333, 1.0000])]
>>> tpr
[tensor([0., 1., 1.]), tensor([0., 1., 1.]), tensor([0., 0., 1.]), tensor([0., 0., 1.])]
>>> thresholds
[tensor([1.0000, 0.7500, 0.0500]), tensor([1.0000, 0.7500, 0.0500]), tensor([1.0000, 0.7500, 0.0500]), tensor([1.0000, 0.7500, 0.0500])]
>>> pred = tensor([[0.8191, 0.3680, 0.1138],
...                [0.3584, 0.7576, 0.1183],
...                [0.2286, 0.3468, 0.1338],
...                [0.8603, 0.0745, 0.1837]])
>>> target = tensor([[1, 1, 0], [0, 1, 0], [0, 0, 0], [0, 1, 1]])
>>> roc = ROC(task='multilabel', num_labels=3)
>>> fpr, tpr, thresholds = roc(pred, target)
>>> fpr
[tensor([0.0000, 0.3333, 0.3333, 0.6667, 1.0000]), tensor([0., 0., 0., 1., 1.]), tensor([0.0000, 0.0000, 0.3333, 0.6667, 1.0000])]
>>> tpr
[tensor([0., 0., 1., 1., 1.]), tensor([0.0000, 0.3333, 0.6667, 0.6667, 1.0000]), tensor([0., 1., 1., 1., 1.])]
>>> thresholds
[tensor([1.0000, 0.8603, 0.8191, 0.3584, 0.2286]), tensor([1.0000, 0.7576, 0.3680, 0.3468, 0.0745]), tensor([1.0000, 0.1837, 0.1338, 0.1183, 0.1138])]
BinaryROC
- class torchmetrics.classification.BinaryROC(thresholds=None, ignore_index=None, validate_args=True, **kwargs)
Compute the Receiver Operating Characteristic (ROC) for binary tasks.
The curve consists of multiple pairs of true positive rate (TPR) and false positive rate (FPR) values evaluated at different thresholds, so that the tradeoff between the two values can be seen.
As input to forward and update the metric accepts the following input:
preds (Tensor): A float tensor of shape (N, ...). Preds should be a tensor containing probabilities or logits for each observation. If preds has values outside the [0, 1] range, the input is considered to be logits and sigmoid is automatically applied per element.
target (Tensor): An int tensor of shape (N, ...). Target should be a tensor containing ground truth labels, and therefore only contain {0, 1} values (except if ignore_index is specified). The value 1 always encodes the positive class.
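The logits rule above can be sketched in pure Python. This is an illustration under the assumption that the check is made over the whole input rather than per element; maybe_sigmoid is a hypothetical helper, not part of TorchMetrics. Because a single out-of-range value switches the whole input to logits, the scores 0.0 and 1.0 in the legacy ROC example above are also transformed, which is why that example reports thresholds such as 0.5000 and 0.7311.

```python
import math


def maybe_sigmoid(preds):
    """Hypothetical sketch: treat preds as logits if any value lies outside [0, 1]."""
    if all(0.0 <= p <= 1.0 for p in preds):
        return list(preds)  # already probabilities: left untouched
    # at least one value is out of range, so every element is passed through sigmoid
    return [1.0 / (1.0 + math.exp(-p)) for p in preds]


print([round(p, 4) for p in maybe_sigmoid([0.0, 1.0, 2.0, 3.0])])  # logits
print(maybe_sigmoid([0.2, 0.9]))  # probabilities, returned unchanged
```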
Note
Additional dimensions ... will be flattened into the batch dimension.
As output to forward and compute the metric returns a tuple of 3 tensors containing:
fpr (Tensor): A 1d tensor of size (n_thresholds+1, ) with false positive rate values
tpr (Tensor): A 1d tensor of size (n_thresholds+1, ) with true positive rate values
thresholds (Tensor): A 1d tensor of size (n_thresholds, ) with decreasing threshold values
Note
The implementation supports both a non-binned but accurate version and a binned version that is less accurate but more memory efficient. Setting the thresholds argument to None activates the non-binned version, which uses memory of size \(\mathcal{O}(n_{samples})\), whereas setting it to an integer, a list or a 1d tensor activates the binned version, which uses memory of size \(\mathcal{O}(n_{thresholds})\) (constant memory).
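To see why the binned version needs only constant memory, here is a pure-Python sketch of a binned binary ROC accumulator. It is an illustration, not TorchMetrics code: the class name BinnedROCSketch and its update/compute methods are hypothetical, and preds are assumed to be probabilities. The state is one true-positive and one false-positive counter per threshold, so its size does not grow with the number of samples seen across batches. The price is resolution: scores between two grid points fall into the same bin, so with the data of the BinaryROC example below the binned tpr jumps from 0 straight to 1, whereas the exact curve passes through 0.5.

```python
class BinnedROCSketch:
    """Hypothetical binned binary ROC: two counters per threshold, whatever the sample count."""

    def __init__(self, n_thresholds):
        # n_thresholds values linearly spaced from 0 to 1
        self.thresholds = [i / (n_thresholds - 1) for i in range(n_thresholds)]
        self.tp = [0] * n_thresholds
        self.fp = [0] * n_thresholds
        self.n_pos = 0
        self.n_neg = 0

    def update(self, preds, target):
        # each batch only increments counters; raw samples are not stored
        for p, y in zip(preds, target):
            self.n_pos += y
            self.n_neg += 1 - y
            for i, t in enumerate(self.thresholds):
                if p >= t:
                    if y == 1:
                        self.tp[i] += 1
                    else:
                        self.fp[i] += 1

    def compute(self):
        fpr = [c / self.n_neg if self.n_neg else 0.0 for c in self.fp]
        tpr = [c / self.n_pos if self.n_pos else 0.0 for c in self.tp]
        # reverse so thresholds decrease while fpr and tpr increase
        return fpr[::-1], tpr[::-1], self.thresholds[::-1]


roc = BinnedROCSketch(n_thresholds=5)
roc.update([0.0, 0.5], [0, 1])  # first batch
roc.update([0.7, 0.8], [1, 0])  # second batch
print(roc.compute())
```

Fed the same four samples as the BinaryROC example below, split into two batches, the sketch yields the same values as BinaryROC(thresholds=5) there.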
Note
The output thresholds are in reversed (decreasing) order so that they correspond to both fpr and tpr, which are sorted in reversed order during their calculation so that they are monotonically increasing.
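The ordering described in this note can be made concrete with a pure-Python sketch of the non-binned binary curve. It assumes preds are already probabilities, uses the hypothetical helper name exact_roc, and is not the TorchMetrics code. Samples are sorted by decreasing score, so the true and false positive counts can only grow as the threshold drops; one point is emitted per distinct score, and a (0, 0) starting point with threshold 1.0 is prepended, matching the examples in this section.

```python
def exact_roc(preds, target):
    """Hypothetical non-binned binary ROC sketch: one point per distinct score."""
    n_pos = sum(target)
    n_neg = len(target) - n_pos
    # sort by decreasing score: every prefix is the "predicted positive" set
    pairs = sorted(zip(preds, target), key=lambda pair: pair[0], reverse=True)
    fps, tps, thresholds = [], [], []
    tp = fp = 0
    for i, (score, label) in enumerate(pairs):
        tp += label
        fp += 1 - label
        # emit a point only after the last sample sharing this score
        if i == len(pairs) - 1 or pairs[i + 1][0] != score:
            tps.append(tp)
            fps.append(fp)
            thresholds.append(score)
    # prepend the (0, 0) starting point with threshold 1.0
    fpr = [0.0] + [f / n_neg if n_neg else 0.0 for f in fps]
    tpr = [0.0] + [t / n_pos if n_pos else 0.0 for t in tps]
    return fpr, tpr, [1.0] + thresholds


print(exact_roc([0.0, 0.5, 0.7, 0.8], [0, 1, 1, 0]))
```

With the inputs of the BinaryROC example below, the sketch reproduces the thresholds=None output: thresholds decrease while fpr and tpr never decrease.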
- Parameters:
thresholds (Union[int, List[float], Tensor, None]) – Can be one of:
If set to None, will use a non-binned approach where thresholds are dynamically calculated from all the data. Most accurate but also most memory-consuming approach.
If set to an int (larger than 1), will use that number of thresholds linearly spaced from 0 to 1 as bins for the calculation.
If set to a list of floats, will use the indicated thresholds in the list as bins for the calculation.
If set to a 1d tensor of floats, will use the indicated thresholds in the tensor as bins for the calculation.
ignore_index (Optional[int]) – Specifies a target value that is ignored and does not contribute to the metric calculation.
validate_args (bool) – Boolean indicating whether input arguments and tensors should be validated for correctness. Set to False for faster computations.
kwargs (Any) – Additional keyword arguments, see Advanced metric settings for more info.
Example
>>> from torch import tensor
>>> from torchmetrics.classification import BinaryROC
>>> preds = tensor([0, 0.5, 0.7, 0.8])
>>> target = tensor([0, 1, 1, 0])
>>> metric = BinaryROC(thresholds=None)
>>> metric(preds, target)
(tensor([0.0000, 0.5000, 0.5000, 0.5000, 1.0000]),
 tensor([0.0000, 0.0000, 0.5000, 1.0000, 1.0000]),
 tensor([1.0000, 0.8000, 0.7000, 0.5000, 0.0000]))
>>> broc = BinaryROC(thresholds=5)
>>> broc(preds, target)
(tensor([0.0000, 0.5000, 0.5000, 0.5000, 1.0000]),
 tensor([0., 0., 1., 1., 1.]),
 tensor([1.0000, 0.7500, 0.5000, 0.2500, 0.0000]))
- plot(curve=None, score=None, ax=None)
Plot a single or multiple values from the metric.
- Parameters:
curve (Optional[Tuple[Tensor, Tensor, Tensor]]) – The output of either metric.compute or metric.forward. If no value is provided, will automatically call metric.compute and plot that result.
score (Union[Tensor, bool, None]) – Provide an area-under-the-curve score to be displayed on the plot. If True and no curve is provided, will automatically compute the score.
ax (Optional[Axes]) – A matplotlib axis object. If provided, will add the plot to that axis.
- Return type:
- Returns:
Figure and Axes object
- Raises:
ModuleNotFoundError – If matplotlib is not installed
>>> from torch import rand, randint
>>> from torchmetrics.classification import BinaryROC
>>> preds = rand(20)
>>> target = randint(2, (20,))
>>> metric = BinaryROC()
>>> metric.update(preds, target)
>>> fig_, ax_ = metric.plot(score=True)
MulticlassROC
- class torchmetrics.classification.MulticlassROC(num_classes, thresholds=None, ignore_index=None, validate_args=True, **kwargs)
Compute the Receiver Operating Characteristic (ROC) for multiclass tasks.
The curve consists of multiple pairs of true positive rate (TPR) and false positive rate (FPR) values evaluated at different thresholds, so that the tradeoff between the two values can be seen.
For multiclass tasks the metric is calculated by iteratively treating each class as the positive class and all other classes as the negative class, which is referred to as the one-vs-rest approach. One-vs-one is currently not supported by this metric.
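The one-vs-rest decomposition can be sketched in pure Python: for each class c, column c of preds becomes the score and the target is binarized to "is class c". A compact non-binned binary ROC routine is included so the block is self-contained. Both helper names are hypothetical, preds are assumed to already be probabilities, and this is an illustration rather than the TorchMetrics implementation. A class with no positive samples (class 4 in the example below) gets a tpr of 0 everywhere in this sketch.

```python
def binary_roc_points(scores, labels):
    """Hypothetical compact non-binned binary ROC; assumes probabilities."""
    n_pos = sum(labels)
    n_neg = len(labels) - n_pos
    pairs = sorted(zip(scores, labels), key=lambda pair: pair[0], reverse=True)
    fpr, tpr, thresholds = [0.0], [0.0], [1.0]  # (0, 0) starting point
    tp = fp = 0
    for i, (score, label) in enumerate(pairs):
        tp += label
        fp += 1 - label
        if i == len(pairs) - 1 or pairs[i + 1][0] != score:
            fpr.append(fp / n_neg if n_neg else 0.0)
            tpr.append(tp / n_pos if n_pos else 0.0)
            thresholds.append(score)
    return fpr, tpr, thresholds


def one_vs_rest_roc(preds, target, num_classes):
    """preds: list of per-sample score rows; target: list of class indices."""
    curves = []
    for c in range(num_classes):
        scores = [row[c] for row in preds]  # column c scores class c
        labels = [1 if t == c else 0 for t in target]  # class c vs. all others
        curves.append(binary_roc_points(scores, labels))
    return curves


preds = [
    [0.75, 0.05, 0.05, 0.05, 0.05],
    [0.05, 0.75, 0.05, 0.05, 0.05],
    [0.05, 0.05, 0.75, 0.05, 0.05],
    [0.05, 0.05, 0.05, 0.75, 0.05],
]
target = [0, 1, 3, 2]
for fpr, tpr, thresholds in one_vs_rest_roc(preds, target, num_classes=5):
    print([round(v, 4) for v in fpr], tpr, thresholds)
```

On the data of the MulticlassROC example below, the per-class curves printed by this sketch agree with the thresholds=None output shown there.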
As input to forward and update the metric accepts the following input:
preds (Tensor): A float tensor of shape (N, C, ...). Preds should be a tensor containing probabilities or logits for each observation. If preds has values outside the [0, 1] range, the input is considered to be logits and softmax is automatically applied per sample.
target (Tensor): An int tensor of shape (N, ...). Target should be a tensor containing ground truth labels, and therefore only contain values in the [0, n_classes-1] range (except if ignore_index is specified).
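A pure-Python sketch of the per-sample softmax rule, assuming (as for the binary sigmoid rule) that the range check is made over the whole input and that softmax is taken across the classes of each sample. maybe_softmax is a hypothetical helper, not part of TorchMetrics. Note that the rows in the MulticlassROC example below sum to 0.95 rather than 1, but since every value lies inside [0, 1] they are used as is under this rule.

```python
import math


def maybe_softmax(preds):
    """Hypothetical sketch. preds: list of per-sample score rows (one score per class)."""
    if all(0.0 <= p <= 1.0 for row in preds for p in row):
        return [list(row) for row in preds]  # treated as probabilities: unchanged
    out = []
    for row in preds:
        m = max(row)  # subtract the row maximum for numerical stability
        exps = [math.exp(p - m) for p in row]
        total = sum(exps)
        out.append([e / total for e in exps])  # each row now sums to 1
    return out


print(maybe_softmax([[2.0, 0.0], [0.0, 0.0]]))  # logits
print(maybe_softmax([[0.75, 0.25], [0.5, 0.5]]))  # probabilities, unchanged
```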
Note
Additional dimensions ... will be flattened into the batch dimension.
As output to forward and compute the metric returns a tuple of either 3 tensors or 3 lists containing:
fpr (Tensor): if thresholds=None, a list with one 1d tensor of size (n_thresholds+1, ) per class is returned, containing false positive rate values (length may differ between classes). If thresholds is set to something else, a single 2d tensor of size (n_classes, n_thresholds+1) with false positive rate values is returned.
tpr (Tensor): if thresholds=None, a list with one 1d tensor of size (n_thresholds+1, ) per class is returned, containing true positive rate values (length may differ between classes). If thresholds is set to something else, a single 2d tensor of size (n_classes, n_thresholds+1) with true positive rate values is returned.
thresholds (Tensor): if thresholds=None, a list with one 1d tensor of size (n_thresholds, ) per class is returned, containing decreasing threshold values (length may differ between classes). If thresholds is set to something else, a single 1d tensor of size (n_thresholds, ) with threshold values shared by all classes is returned.
Note
The implementation supports both a non-binned but accurate version and a binned version that is less accurate but more memory efficient. Setting the thresholds argument to None activates the non-binned version, which uses memory of size \(\mathcal{O}(n_{samples})\), whereas setting it to an integer, a list or a 1d tensor activates the binned version, which uses memory of size \(\mathcal{O}(n_{thresholds} \times n_{classes})\) (constant memory).
Note
The output thresholds are in reversed (decreasing) order so that they correspond to both fpr and tpr, which are sorted in reversed order during their calculation so that they are monotonically increasing.
- Parameters:
num_classes (int) – Integer specifying the number of classes
thresholds (Union[int, List[float], Tensor, None]) – Can be one of:
If set to None, will use a non-binned approach where thresholds are dynamically calculated from all the data. Most accurate but also most memory-consuming approach.
If set to an int (larger than 1), will use that number of thresholds linearly spaced from 0 to 1 as bins for the calculation.
If set to a list of floats, will use the indicated thresholds in the list as bins for the calculation.
If set to a 1d tensor of floats, will use the indicated thresholds in the tensor as bins for the calculation.
ignore_index (Optional[int]) – Specifies a target value that is ignored and does not contribute to the metric calculation.
validate_args (bool) – Boolean indicating whether input arguments and tensors should be validated for correctness. Set to False for faster computations.
kwargs (Any) – Additional keyword arguments, see Advanced metric settings for more info.
Example
>>> from torch import tensor
>>> from torchmetrics.classification import MulticlassROC
>>> preds = tensor([[0.75, 0.05, 0.05, 0.05, 0.05],
...                 [0.05, 0.75, 0.05, 0.05, 0.05],
...                 [0.05, 0.05, 0.75, 0.05, 0.05],
...                 [0.05, 0.05, 0.05, 0.75, 0.05]])
>>> target = tensor([0, 1, 3, 2])
>>> metric = MulticlassROC(num_classes=5, thresholds=None)
>>> fpr, tpr, thresholds = metric(preds, target)
>>> fpr
[tensor([0., 0., 1.]), tensor([0., 0., 1.]), tensor([0.0000, 0.3333, 1.0000]), tensor([0.0000, 0.3333, 1.0000]), tensor([0., 1.])]
>>> tpr
[tensor([0., 1., 1.]), tensor([0., 1., 1.]), tensor([0., 0., 1.]), tensor([0., 0., 1.]), tensor([0., 0.])]
>>> thresholds
[tensor([1.0000, 0.7500, 0.0500]), tensor([1.0000, 0.7500, 0.0500]), tensor([1.0000, 0.7500, 0.0500]), tensor([1.0000, 0.7500, 0.0500]), tensor([1.0000, 0.0500])]
>>> mcroc = MulticlassROC(num_classes=5, thresholds=5)
>>> mcroc(preds, target)
(tensor([[0.0000, 0.0000, 0.0000, 0.0000, 1.0000],
         [0.0000, 0.0000, 0.0000, 0.0000, 1.0000],
         [0.0000, 0.3333, 0.3333, 0.3333, 1.0000],
         [0.0000, 0.3333, 0.3333, 0.3333, 1.0000],
         [0.0000, 0.0000, 0.0000, 0.0000, 1.0000]]),
 tensor([[0., 1., 1., 1., 1.],
         [0., 1., 1., 1., 1.],
         [0., 0., 0., 0., 1.],
         [0., 0., 0., 0., 1.],
         [0., 0., 0., 0., 0.]]),
 tensor([1.0000, 0.7500, 0.5000, 0.2500, 0.0000]))
- plot(curve=None, score=None, ax=None)
Plot a single or multiple values from the metric.
- Parameters:
curve (Union[Tuple[Tensor, Tensor, Tensor], Tuple[List[Tensor], List[Tensor], List[Tensor]], None]) – The output of either metric.compute or metric.forward. If no value is provided, will automatically call metric.compute and plot that result.
score (Union[Tensor, bool, None]) – Provide an area-under-the-curve score to be displayed on the plot. If True and no curve is provided, will automatically compute the score.
ax (Optional[Axes]) – A matplotlib axis object. If provided, will add the plot to that axis.
- Return type:
- Returns:
Figure and Axes object
- Raises:
ModuleNotFoundError – If matplotlib is not installed
>>> from torch import randn, randint
>>> from torchmetrics.classification import MulticlassROC
>>> preds = randn(20, 3).softmax(dim=-1)
>>> target = randint(3, (20,))
>>> metric = MulticlassROC(num_classes=3)
>>> metric.update(preds, target)
>>> fig_, ax_ = metric.plot(score=True)
MultilabelROC
- class torchmetrics.classification.MultilabelROC(num_labels, thresholds=None, ignore_index=None, validate_args=True, **kwargs)
Compute the Receiver Operating Characteristic (ROC) for multilabel tasks.
The curve consists of multiple pairs of true positive rate (TPR) and false positive rate (FPR) values evaluated at different thresholds, so that the tradeoff between the two values can be seen.
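In the multilabel case each label column is an independent binary problem, so one curve is produced per label. The pure-Python sketch below illustrates this by slicing column j out of both preds and target and running a compact non-binned binary ROC on it. The helper names are hypothetical, preds are assumed to be probabilities, and the sketch is not the TorchMetrics implementation.

```python
def binary_roc_points(scores, labels):
    """Hypothetical compact non-binned binary ROC; assumes probabilities."""
    n_pos = sum(labels)
    n_neg = len(labels) - n_pos
    pairs = sorted(zip(scores, labels), key=lambda pair: pair[0], reverse=True)
    fpr, tpr, thresholds = [0.0], [0.0], [1.0]  # (0, 0) starting point
    tp = fp = 0
    for i, (score, label) in enumerate(pairs):
        tp += label
        fp += 1 - label
        if i == len(pairs) - 1 or pairs[i + 1][0] != score:
            fpr.append(fp / n_neg if n_neg else 0.0)
            tpr.append(tp / n_pos if n_pos else 0.0)
            thresholds.append(score)
    return fpr, tpr, thresholds


def multilabel_roc_sketch(preds, target):
    """preds, target: lists of rows with one entry per label."""
    num_labels = len(preds[0])
    return [
        binary_roc_points([row[j] for row in preds], [row[j] for row in target])
        for j in range(num_labels)  # each label is scored independently
    ]


preds = [[0.75, 0.05, 0.35], [0.45, 0.75, 0.05], [0.05, 0.55, 0.75], [0.05, 0.65, 0.05]]
target = [[1, 0, 1], [0, 0, 0], [0, 1, 1], [1, 1, 1]]
for fpr, tpr, thresholds in multilabel_roc_sketch(preds, target):
    print(fpr, [round(v, 4) for v in tpr], thresholds)
```

With the inputs of the MultilabelROC example below, the three per-label curves from this sketch match its thresholds=None output.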
As input to forward and update the metric accepts the following input:
preds (Tensor): A float tensor of shape (N, C, ...). Preds should be a tensor containing probabilities or logits for each observation. If preds has values outside the [0, 1] range, the input is considered to be logits and sigmoid is automatically applied per element.
target (Tensor): An int tensor of shape (N, C, ...). Target should be a tensor containing ground truth labels, and therefore only contain {0, 1} values (except if ignore_index is specified).
Note
Additional dimensions ... will be flattened into the batch dimension.
As output to forward and compute the metric returns a tuple of either 3 tensors or 3 lists containing:
fpr (Tensor): if thresholds=None, a list with one 1d tensor of size (n_thresholds+1, ) per label is returned, containing false positive rate values (length may differ between labels). If thresholds is set to something else, a single 2d tensor of size (n_labels, n_thresholds+1) with false positive rate values is returned.
tpr (Tensor): if thresholds=None, a list with one 1d tensor of size (n_thresholds+1, ) per label is returned, containing true positive rate values (length may differ between labels). If thresholds is set to something else, a single 2d tensor of size (n_labels, n_thresholds+1) with true positive rate values is returned.
thresholds (Tensor): if thresholds=None, a list with one 1d tensor of size (n_thresholds, ) per label is returned, containing decreasing threshold values (length may differ between labels). If thresholds is set to something else, a single 1d tensor of size (n_thresholds, ) with threshold values shared by all labels is returned.
Note
The implementation supports both a non-binned but accurate version and a binned version that is less accurate but more memory efficient. Setting the thresholds argument to None activates the non-binned version, which uses memory of size \(\mathcal{O}(n_{samples})\), whereas setting it to an integer, a list or a 1d tensor activates the binned version, which uses memory of size \(\mathcal{O}(n_{thresholds} \times n_{labels})\) (constant memory).
Note
The output thresholds are in reversed (decreasing) order so that they correspond to both fpr and tpr, which are sorted in reversed order during their calculation so that they are monotonically increasing.
- Parameters:
num_labels (int) – Integer specifying the number of labels
thresholds (Union[int, List[float], Tensor, None]) – Can be one of:
If set to None, will use a non-binned approach where thresholds are dynamically calculated from all the data. Most accurate but also most memory-consuming approach.
If set to an int (larger than 1), will use that number of thresholds linearly spaced from 0 to 1 as bins for the calculation.
If set to a list of floats, will use the indicated thresholds in the list as bins for the calculation.
If set to a 1d tensor of floats, will use the indicated thresholds in the tensor as bins for the calculation.
ignore_index (Optional[int]) – Specifies a target value that is ignored and does not contribute to the metric calculation.
validate_args (bool) – Boolean indicating whether input arguments and tensors should be validated for correctness. Set to False for faster computations.
kwargs (Any) – Additional keyword arguments, see Advanced metric settings for more info.
Example
>>> from torch import tensor
>>> from torchmetrics.classification import MultilabelROC
>>> preds = tensor([[0.75, 0.05, 0.35],
...                 [0.45, 0.75, 0.05],
...                 [0.05, 0.55, 0.75],
...                 [0.05, 0.65, 0.05]])
>>> target = tensor([[1, 0, 1],
...                  [0, 0, 0],
...                  [0, 1, 1],
...                  [1, 1, 1]])
>>> metric = MultilabelROC(num_labels=3, thresholds=None)
>>> fpr, tpr, thresholds = metric(preds, target)
>>> fpr
[tensor([0.0000, 0.0000, 0.5000, 1.0000]), tensor([0.0000, 0.5000, 0.5000, 0.5000, 1.0000]), tensor([0., 0., 0., 1.])]
>>> tpr
[tensor([0.0000, 0.5000, 0.5000, 1.0000]), tensor([0.0000, 0.0000, 0.5000, 1.0000, 1.0000]), tensor([0.0000, 0.3333, 0.6667, 1.0000])]
>>> thresholds
[tensor([1.0000, 0.7500, 0.4500, 0.0500]), tensor([1.0000, 0.7500, 0.6500, 0.5500, 0.0500]), tensor([1.0000, 0.7500, 0.3500, 0.0500])]
>>> mlroc = MultilabelROC(num_labels=3, thresholds=5)
>>> mlroc(preds, target)
(tensor([[0.0000, 0.0000, 0.0000, 0.5000, 1.0000],
         [0.0000, 0.5000, 0.5000, 0.5000, 1.0000],
         [0.0000, 0.0000, 0.0000, 0.0000, 1.0000]]),
 tensor([[0.0000, 0.5000, 0.5000, 0.5000, 1.0000],
         [0.0000, 0.0000, 1.0000, 1.0000, 1.0000],
         [0.0000, 0.3333, 0.3333, 0.6667, 1.0000]]),
 tensor([1.0000, 0.7500, 0.5000, 0.2500, 0.0000]))
- plot(curve=None, score=None, ax=None)
Plot a single or multiple values from the metric.
- Parameters:
curve (Union[Tuple[Tensor, Tensor, Tensor], Tuple[List[Tensor], List[Tensor], List[Tensor]], None]) – The output of either metric.compute or metric.forward. If no value is provided, will automatically call metric.compute and plot that result.
score (Union[Tensor, bool, None]) – Provide an area-under-the-curve score to be displayed on the plot. If True and no curve is provided, will automatically compute the score.
ax (Optional[Axes]) – A matplotlib axis object. If provided, will add the plot to that axis.
- Return type:
- Returns:
Figure and Axes object
- Raises:
ModuleNotFoundError – If matplotlib is not installed
>>> from torch import rand, randint
>>> from torchmetrics.classification import MultilabelROC
>>> preds = rand(20, 3)
>>> target = randint(2, (20, 3))
>>> metric = MultilabelROC(num_labels=3)
>>> metric.update(preds, target)
>>> fig_, ax_ = metric.plot(score=True)
Functional Interface
- torchmetrics.functional.roc(preds, target, task, thresholds=None, num_classes=None, num_labels=None, ignore_index=None, validate_args=True)
Compute the Receiver Operating Characteristic (ROC).
The curve consists of multiple pairs of true positive rate (TPR) and false positive rate (FPR) values evaluated at different thresholds, so that the tradeoff between the two values can be seen.
This function is a simple wrapper to get the task-specific versions of this metric, which is done by setting the task argument to either 'binary', 'multiclass' or 'multilabel'. See the documentation of binary_roc(), multiclass_roc() and multilabel_roc() for details on how each argument influences the result, and for examples.
- Legacy Example:
>>> import torch
>>> from torchmetrics.functional import roc
>>> pred = torch.tensor([0.0, 1.0, 2.0, 3.0])
>>> target = torch.tensor([0, 1, 1, 1])
>>> fpr, tpr, thresholds = roc(pred, target, task='binary')
>>> fpr
tensor([0., 0., 0., 0., 1.])
>>> tpr
tensor([0.0000, 0.3333, 0.6667, 1.0000, 1.0000])
>>> thresholds
tensor([1.0000, 0.9526, 0.8808, 0.7311, 0.5000])
>>> pred = torch.tensor([[0.75, 0.05, 0.05, 0.05],
...                      [0.05, 0.75, 0.05, 0.05],
...                      [0.05, 0.05, 0.75, 0.05],
...                      [0.05, 0.05, 0.05, 0.75]])
>>> target = torch.tensor([0, 1, 3, 2])
>>> fpr, tpr, thresholds = roc(pred, target, task='multiclass', num_classes=4)
>>> fpr
[tensor([0., 0., 1.]), tensor([0., 0., 1.]), tensor([0.0000, 0.3333, 1.0000]), tensor([0.0000, 0.3333, 1.0000])]
>>> tpr
[tensor([0., 1., 1.]), tensor([0., 1., 1.]), tensor([0., 0., 1.]), tensor([0., 0., 1.])]
>>> thresholds
[tensor([1.0000, 0.7500, 0.0500]), tensor([1.0000, 0.7500, 0.0500]), tensor([1.0000, 0.7500, 0.0500]), tensor([1.0000, 0.7500, 0.0500])]
>>> pred = torch.tensor([[0.8191, 0.3680, 0.1138],
...                      [0.3584, 0.7576, 0.1183],
...                      [0.2286, 0.3468, 0.1338],
...                      [0.8603, 0.0745, 0.1837]])
>>> target = torch.tensor([[1, 1, 0], [0, 1, 0], [0, 0, 0], [0, 1, 1]])
>>> fpr, tpr, thresholds = roc(pred, target, task='multilabel', num_labels=3)
>>> fpr
[tensor([0.0000, 0.3333, 0.3333, 0.6667, 1.0000]), tensor([0., 0., 0., 1., 1.]), tensor([0.0000, 0.0000, 0.3333, 0.6667, 1.0000])]
>>> tpr
[tensor([0., 0., 1., 1., 1.]), tensor([0.0000, 0.3333, 0.6667, 0.6667, 1.0000]), tensor([0., 1., 1., 1., 1.])]
>>> thresholds
[tensor([1.0000, 0.8603, 0.8191, 0.3584, 0.2286]), tensor([1.0000, 0.7576, 0.3680, 0.3468, 0.0745]), tensor([1.0000, 0.1837, 0.1338, 0.1183, 0.1138])]
binary_roc
- torchmetrics.functional.classification.binary_roc(preds, target, thresholds=None, ignore_index=None, validate_args=True)
Compute the Receiver Operating Characteristic (ROC) for binary tasks.
The curve consists of multiple pairs of true positive rate (TPR) and false positive rate (FPR) values evaluated at different thresholds, so that the tradeoff between the two values can be seen.
Accepts the following input tensors:
preds (float tensor): (N, ...). Preds should be a tensor containing probabilities or logits for each observation. If preds has values outside the [0, 1] range, the input is considered to be logits and sigmoid is automatically applied per element.
target (int tensor): (N, ...). Target should be a tensor containing ground truth labels, and therefore only contain {0, 1} values (except if ignore_index is specified). The value 1 always encodes the positive class.
Additional dimensions ... will be flattened into the batch dimension.
The implementation supports both a non-binned but accurate version and a binned version that is less accurate but more memory efficient. Setting the thresholds argument to None activates the non-binned version, which uses memory of size \(\mathcal{O}(n_{samples})\), whereas setting it to an integer, a list or a 1d tensor activates the binned version, which uses memory of size \(\mathcal{O}(n_{thresholds})\) (constant memory).
The output thresholds are in reversed (decreasing) order so that they correspond to both fpr and tpr, which are sorted in reversed order during their calculation so that they are monotonically increasing.
- Parameters:
preds (Tensor) – Tensor with predictions
target (Tensor) – Tensor with true labels
thresholds (Union[int, List[float], Tensor, None]) – Can be one of:
If set to None, will use a non-binned approach where thresholds are dynamically calculated from all the data. Most accurate but also most memory-consuming approach.
If set to an int (larger than 1), will use that number of thresholds linearly spaced from 0 to 1 as bins for the calculation.
If set to a list of floats, will use the indicated thresholds in the list as bins for the calculation.
If set to a 1d tensor of floats, will use the indicated thresholds in the tensor as bins for the calculation.
ignore_index (Optional[int]) – Specifies a target value that is ignored and does not contribute to the metric calculation.
validate_args (bool) – Boolean indicating whether input arguments and tensors should be validated for correctness. Set to False for faster computations.
- Returns:
a tuple of 3 tensors containing:
fpr: a 1d tensor of size (n_thresholds+1, ) with false positive rate values
tpr: a 1d tensor of size (n_thresholds+1, ) with true positive rate values
thresholds: a 1d tensor of size (n_thresholds, ) with decreasing threshold values
- Return type:
(tuple)
Example
>>> import torch
>>> from torchmetrics.functional.classification import binary_roc
>>> preds = torch.tensor([0, 0.5, 0.7, 0.8])
>>> target = torch.tensor([0, 1, 1, 0])
>>> binary_roc(preds, target, thresholds=None)
(tensor([0.0000, 0.5000, 0.5000, 0.5000, 1.0000]),
 tensor([0.0000, 0.0000, 0.5000, 1.0000, 1.0000]),
 tensor([1.0000, 0.8000, 0.7000, 0.5000, 0.0000]))
>>> binary_roc(preds, target, thresholds=5)
(tensor([0.0000, 0.5000, 0.5000, 0.5000, 1.0000]),
 tensor([0., 0., 1., 1., 1.]),
 tensor([1.0000, 0.7500, 0.5000, 0.2500, 0.0000]))
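The ignore_index parameter of binary_roc can be pictured as a filtering step applied before any true/false positive counting: samples whose target equals the ignored value are dropped together with their predictions. The pure-Python sketch below shows that idea for the non-binned case only; drop_ignored is a hypothetical helper, the value -1 is chosen purely for illustration, and the sketch makes no claim about how the binned code path handles ignored samples internally.

```python
def drop_ignored(preds, target, ignore_index=None):
    """Hypothetical sketch: remove samples whose target equals ignore_index."""
    if ignore_index is None:
        return list(preds), list(target)  # nothing to ignore
    kept = [(p, t) for p, t in zip(preds, target) if t != ignore_index]
    return [p for p, _ in kept], [t for _, t in kept]


preds, target = drop_ignored([0.1, 0.4, 0.35, 0.8], [0, -1, 1, 1], ignore_index=-1)
print(preds, target)  # [0.1, 0.35, 0.8] [0, 1, 1]
```

The remaining predictions and targets would then be passed to the ROC computation as usual.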
multiclass_roc
- torchmetrics.functional.classification.multiclass_roc(preds, target, num_classes, thresholds=None, ignore_index=None, validate_args=True)
Compute the Receiver Operating Characteristic (ROC) for multiclass tasks.
The curve consists of multiple pairs of true positive rate (TPR) and false positive rate (FPR) values evaluated at different thresholds, so that the tradeoff between the two values can be seen.
Accepts the following input tensors:
preds (float tensor): (N, C, ...). Preds should be a tensor containing probabilities or logits for each observation. If preds has values outside the [0, 1] range, the input is considered to be logits and softmax is automatically applied per sample.
target (int tensor): (N, ...). Target should be a tensor containing ground truth labels, and therefore only contain values in the [0, n_classes-1] range (except if ignore_index is specified).
Additional dimensions ... will be flattened into the batch dimension.
The implementation supports both a non-binned but accurate version and a binned version that is less accurate but more memory efficient. Setting the thresholds argument to None activates the non-binned version, which uses memory of size \(\mathcal{O}(n_{samples})\), whereas setting it to an integer, a list or a 1d tensor activates the binned version, which uses memory of size \(\mathcal{O}(n_{thresholds} \times n_{classes})\) (constant memory).
The output thresholds are in reversed (decreasing) order so that they correspond to both fpr and tpr, which are sorted in reversed order during their calculation so that they are monotonically increasing.
- Parameters:
preds (Tensor) – Tensor with predictions
target (Tensor) – Tensor with true labels
num_classes (int) – Integer specifying the number of classes
thresholds (Union[int, List[float], Tensor, None]) – Can be one of:
If set to None, will use a non-binned approach where thresholds are dynamically calculated from all the data. Most accurate but also most memory-consuming approach.
If set to an int (larger than 1), will use that number of thresholds linearly spaced from 0 to 1 as bins for the calculation.
If set to a list of floats, will use the indicated thresholds in the list as bins for the calculation.
If set to a 1d tensor of floats, will use the indicated thresholds in the tensor as bins for the calculation.
ignore_index (Optional[int]) – Specifies a target value that is ignored and does not contribute to the metric calculation.
validate_args (bool) – Boolean indicating whether input arguments and tensors should be validated for correctness. Set to False for faster computations.
- Returns:
a tuple of either 3 tensors or 3 lists containing
fpr: if thresholds=None, a list with one 1d tensor of size (n_thresholds+1, ) per class is returned, containing false positive rate values (length may differ between classes). If thresholds is set to something else, a single 2d tensor of size (n_classes, n_thresholds+1) with false positive rate values is returned.
tpr: if thresholds=None, a list with one 1d tensor of size (n_thresholds+1, ) per class is returned, containing true positive rate values (length may differ between classes). If thresholds is set to something else, a single 2d tensor of size (n_classes, n_thresholds+1) with true positive rate values is returned.
thresholds: if thresholds=None, a list with one 1d tensor of size (n_thresholds, ) per class is returned, containing decreasing threshold values (length may differ between classes). If thresholds is set to something else, a single 1d tensor of size (n_thresholds, ) with threshold values shared by all classes is returned.
- Return type:
(tuple)
Example
>>> import torch
>>> from torchmetrics.functional.classification import multiclass_roc
>>> preds = torch.tensor([[0.75, 0.05, 0.05, 0.05, 0.05],
...                       [0.05, 0.75, 0.05, 0.05, 0.05],
...                       [0.05, 0.05, 0.75, 0.05, 0.05],
...                       [0.05, 0.05, 0.05, 0.75, 0.05]])
>>> target = torch.tensor([0, 1, 3, 2])
>>> fpr, tpr, thresholds = multiclass_roc(
...     preds, target, num_classes=5, thresholds=None
... )
>>> fpr
[tensor([0., 0., 1.]), tensor([0., 0., 1.]), tensor([0.0000, 0.3333, 1.0000]), tensor([0.0000, 0.3333, 1.0000]), tensor([0., 1.])]
>>> tpr
[tensor([0., 1., 1.]), tensor([0., 1., 1.]), tensor([0., 0., 1.]), tensor([0., 0., 1.]), tensor([0., 0.])]
>>> thresholds
[tensor([1.0000, 0.7500, 0.0500]), tensor([1.0000, 0.7500, 0.0500]), tensor([1.0000, 0.7500, 0.0500]), tensor([1.0000, 0.7500, 0.0500]), tensor([1.0000, 0.0500])]
>>> multiclass_roc(
...     preds, target, num_classes=5, thresholds=5
... )
(tensor([[0.0000, 0.0000, 0.0000, 0.0000, 1.0000],
         [0.0000, 0.0000, 0.0000, 0.0000, 1.0000],
         [0.0000, 0.3333, 0.3333, 0.3333, 1.0000],
         [0.0000, 0.3333, 0.3333, 0.3333, 1.0000],
         [0.0000, 0.0000, 0.0000, 0.0000, 1.0000]]),
 tensor([[0., 1., 1., 1., 1.],
         [0., 1., 1., 1., 1.],
         [0., 0., 0., 0., 1.],
         [0., 0., 0., 0., 1.],
         [0., 0., 0., 0., 0.]]),
 tensor([1.0000, 0.7500, 0.5000, 0.2500, 0.0000]))
multilabel_roc¶
- torchmetrics.functional.classification.multilabel_roc(preds, target, num_labels, thresholds=None, ignore_index=None, validate_args=True)[source]¶
Compute the Receiver Operating Characteristic (ROC) for multilabel tasks.
The curve consists of multiple pairs of true positive rate (TPR) and false positive rate (FPR) values evaluated at different thresholds, such that the tradeoff between the two values can be seen.
Accepts the following input tensors:
preds (float tensor): (N, C, ...). Preds should be a tensor containing probabilities or logits for each observation. If preds has values outside the [0,1] range, we consider the input to be logits and will automatically apply sigmoid per element.
target (int tensor): (N, C, ...). Target should be a tensor containing ground truth labels, and therefore only contain {0,1} values (except if ignore_index is specified).
Additional dimensions ... will be flattened into the batch dimension.
The implementation supports both a non-binned but accurate version and a binned version that is less accurate but more memory efficient. Setting the thresholds argument to None will activate the non-binned version that uses memory of size \(\mathcal{O}(n_{samples})\), whereas setting the thresholds argument to either an integer, list or a 1d tensor will use a binned version that uses memory of size \(\mathcal{O}(n_{thresholds} \times n_{labels})\) (constant memory).
Note that the returned thresholds will be in reversed order to ensure that they correspond to both fpr and tpr, which are sorted in reversed order during their calculation such that they are monotonically increasing.
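The following minimal plain-Python sketch illustrates the binned mode. It is not the torchmetrics implementation. It evaluates each label at num_thresholds linearly spaced thresholds, then reverses the results so the thresholds decrease while fpr and tpr increase. The function name `binned_multilabel_roc` is hypothetical, and a sample is assumed to count as positive when its score is >= the threshold. A streaming implementation would only need to keep one pair of true/false positive counts per threshold and label, which is where the \(\mathcal{O}(n_{thresholds} \times n_{labels})\) bound comes from. This sketch keeps all samples in memory for simplicity.

```python
def binned_multilabel_roc(preds, target, num_labels, num_thresholds):
    """Binned ROC for every label (hypothetical helper, not the library API).

    preds:  N rows of num_labels probabilities
    target: N rows of num_labels values in {0, 1}
    """
    # num_thresholds values linearly spaced from 0 to 1, shared by all labels
    thresholds = [i / (num_thresholds - 1) for i in range(num_thresholds)]
    all_fpr, all_tpr = [], []
    for label in range(num_labels):
        scores = [row[label] for row in preds]
        labels = [row[label] for row in target]
        pos = sum(labels)
        neg = len(labels) - pos
        fpr, tpr = [], []
        for thr in thresholds:
            # One (tp, fp) pair per threshold and label
            tp = sum(1 for s, y in zip(scores, labels) if s >= thr and y == 1)
            fp = sum(1 for s, y in zip(scores, labels) if s >= thr and y == 0)
            tpr.append(tp / pos if pos else 0.0)
            fpr.append(fp / neg if neg else 0.0)
        # Reverse so that thresholds decrease while fpr and tpr increase
        all_fpr.append(fpr[::-1])
        all_tpr.append(tpr[::-1])
    return all_fpr, all_tpr, thresholds[::-1]


preds = [[0.75, 0.05, 0.35],
         [0.45, 0.75, 0.05],
         [0.05, 0.55, 0.75],
         [0.05, 0.65, 0.05]]
target = [[1, 0, 1],
          [0, 0, 0],
          [0, 1, 1],
          [1, 1, 1]]
fpr, tpr, thresholds = binned_multilabel_roc(preds, target, num_labels=3, num_thresholds=5)
print(thresholds)  # [1.0, 0.75, 0.5, 0.25, 0.0]
print(fpr[0])      # [0.0, 0.0, 0.0, 0.5, 1.0]
```

With the preds and target from the example in this section, the returned rows match the thresholds=5 output shown there.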
- Parameters:
preds (Tensor) – Tensor with predictions
target (Tensor) – Tensor with true labels
num_labels (int) – Integer specifying the number of labels
thresholds (Union[int, List[float], Tensor, None]) – Can be one of:
If set to None, will use a non-binned approach where thresholds are dynamically calculated from all the data. Most accurate but also most memory consuming approach.
If set to an int (larger than 1), will use that number of thresholds linearly spaced from 0 to 1 as bins for the calculation.
If set to a list of floats, will use the indicated thresholds in the list as bins for the calculation.
If set to a 1d tensor of floats, will use the indicated thresholds in the tensor as bins for the calculation.
ignore_index (Optional[int]) – Specifies a target value that is ignored and does not contribute to the metric calculation
validate_args (bool) – bool indicating if input arguments and tensors should be validated for correctness. Set to False for faster computations.
- Returns:
a tuple of either 3 tensors or 3 lists containing
fpr: if thresholds=None, a list with one 1d tensor per label is returned, each of size (n_thresholds+1, ) and holding false positive rate values (the length may differ between labels). If thresholds is set to anything else, a single 2d tensor of size (n_labels, n_thresholds+1) with false positive rate values is returned.
tpr: if thresholds=None, a list with one 1d tensor per label is returned, each of size (n_thresholds+1, ) and holding true positive rate values (the length may differ between labels). If thresholds is set to anything else, a single 2d tensor of size (n_labels, n_thresholds+1) with true positive rate values is returned.
thresholds: if thresholds=None, a list with one 1d tensor per label is returned, each of size (n_thresholds, ) and holding decreasing threshold values (the length may differ between labels). If thresholds is set to anything else, a single 1d tensor of size (n_thresholds, ) is returned with threshold values shared by all labels.
- Return type:
(tuple)
Example
>>> import torch
>>> from torchmetrics.functional.classification import multilabel_roc
>>> preds = torch.tensor([[0.75, 0.05, 0.35],
...                       [0.45, 0.75, 0.05],
...                       [0.05, 0.55, 0.75],
...                       [0.05, 0.65, 0.05]])
>>> target = torch.tensor([[1, 0, 1],
...                        [0, 0, 0],
...                        [0, 1, 1],
...                        [1, 1, 1]])
>>> fpr, tpr, thresholds = multilabel_roc(
...    preds, target, num_labels=3, thresholds=None
... )
>>> fpr
[tensor([0.0000, 0.0000, 0.5000, 1.0000]), tensor([0.0000, 0.5000, 0.5000, 0.5000, 1.0000]), tensor([0., 0., 0., 1.])]
>>> tpr
[tensor([0.0000, 0.5000, 0.5000, 1.0000]), tensor([0.0000, 0.0000, 0.5000, 1.0000, 1.0000]), tensor([0.0000, 0.3333, 0.6667, 1.0000])]
>>> thresholds
[tensor([1.0000, 0.7500, 0.4500, 0.0500]), tensor([1.0000, 0.7500, 0.6500, 0.5500, 0.0500]), tensor([1.0000, 0.7500, 0.3500, 0.0500])]
>>> multilabel_roc(
...    preds, target, num_labels=3, thresholds=5
... )
(tensor([[0.0000, 0.0000, 0.0000, 0.5000, 1.0000],
         [0.0000, 0.5000, 0.5000, 0.5000, 1.0000],
         [0.0000, 0.0000, 0.0000, 0.0000, 1.0000]]),
 tensor([[0.0000, 0.5000, 0.5000, 0.5000, 1.0000],
         [0.0000, 0.0000, 1.0000, 1.0000, 1.0000],
         [0.0000, 0.3333, 0.3333, 0.6667, 1.0000]]),
 tensor([1.0000, 0.7500, 0.5000, 0.2500, 0.0000]))
Specificity¶
Module Interface¶
- class torchmetrics.Specificity(**kwargs)[source]¶
Compute Specificity.
\[\text{Specificity} = \frac{\text{TN}}{\text{TN} + \text{FP}}\]
Where \(\text{TN}\) and \(\text{FP}\) represent the number of true negatives and false positives respectively. The metric is only properly defined when \(\text{TN} + \text{FP} \neq 0\). If this case is encountered for any class/label, the metric for that class/label will be set to 0 and the overall metric may therefore be affected in turn.
This function is a simple wrapper to get the task specific versions of this metric, which is done by setting the task argument to either 'binary', 'multiclass' or 'multilabel'. See the documentation of BinarySpecificity, MulticlassSpecificity and MultilabelSpecificity for the specific details of how each argument influences the result and for examples.
- Legacy Example:
>>> from torch import tensor
>>> from torchmetrics import Specificity
>>> preds = tensor([2, 0, 2, 1])
>>> target = tensor([1, 1, 2, 0])
>>> specificity = Specificity(task="multiclass", average='macro', num_classes=3)
>>> specificity(preds, target)
tensor(0.6111)
>>> specificity = Specificity(task="multiclass", average='micro', num_classes=3)
>>> specificity(preds, target)
tensor(0.6250)
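The difference between average='macro' and average='micro' in the legacy example can be checked by hand. The sketch below is plain Python rather than the library code. Its hypothetical helpers compute one-vs-rest true negatives and false positives per class. Macro then averages the per-class scores, while micro pools the counts before dividing. Returning 0 for an empty denominator follows the behaviour described above.

```python
def per_class_counts(preds, target, num_classes):
    """One-vs-rest (TN, FP) pair for every class (hypothetical helper)."""
    counts = []
    for c in range(num_classes):
        tn = sum(1 for p, t in zip(preds, target) if p != c and t != c)
        fp = sum(1 for p, t in zip(preds, target) if p == c and t != c)
        counts.append((tn, fp))
    return counts


def specificity_from_counts(tn, fp):
    # TN / (TN + FP), falling back to 0 when there are no negatives
    return tn / (tn + fp) if tn + fp else 0.0


preds = [2, 0, 2, 1]
target = [1, 1, 2, 0]
counts = per_class_counts(preds, target, num_classes=3)  # [(2, 1), (1, 1), (2, 1)]

# macro: score every class on its own, then take the unweighted mean
macro = sum(specificity_from_counts(tn, fp) for tn, fp in counts) / len(counts)
# micro: pool TN and FP over all classes, then apply the formula once
micro = specificity_from_counts(sum(tn for tn, _ in counts), sum(fp for _, fp in counts))
print(round(macro, 4), round(micro, 4))  # 0.6111 0.625
```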
BinarySpecificity¶
- class torchmetrics.classification.BinarySpecificity(threshold=0.5, multidim_average='global', ignore_index=None, validate_args=True, **kwargs)[source]¶
Compute Specificity for binary tasks.
\[\text{Specificity} = \frac{\text{TN}}{\text{TN} + \text{FP}}\]
Where \(\text{TN}\) and \(\text{FP}\) represent the number of true negatives and false positives respectively. The metric is only properly defined when \(\text{TN} + \text{FP} \neq 0\). If this case is encountered, a score of 0 is returned.
As input to forward and update the metric accepts the following input:
preds (Tensor): An int or float tensor of shape (N, ...). If preds is a floating point tensor with values outside the [0,1] range, we consider the input to be logits and will automatically apply sigmoid per element. Additionally, we convert it to an int tensor by thresholding with the value in threshold.
target (Tensor): An int tensor of shape (N, ...)
As output to forward and compute the metric returns the following output:
bs (Tensor): If multidim_average is set to global, the metric returns a scalar value. If multidim_average is set to samplewise, the metric returns an (N,) vector consisting of a scalar value per sample.
- Parameters:
threshold (float) – Threshold for transforming probability to binary {0,1} predictions
multidim_average (Literal['global', 'samplewise']) – Defines how additional dimensions ... should be handled. Should be one of the following:
global: Additional dimensions are flattened along the batch dimension
samplewise: Statistic will be calculated independently for each sample on the N axis. The statistics in this case are calculated over the additional dimensions.
ignore_index (Optional[int]) – Specifies a target value that is ignored and does not contribute to the metric calculation
validate_args (bool) – bool indicating if input arguments and tensors should be validated for correctness. Set to False for faster computations.
- Example (preds is int tensor):
>>> from torch import tensor
>>> from torchmetrics.classification import BinarySpecificity
>>> target = tensor([0, 1, 0, 1, 0, 1])
>>> preds = tensor([0, 0, 1, 1, 0, 1])
>>> metric = BinarySpecificity()
>>> metric(preds, target)
tensor(0.6667)
- Example (preds is float tensor):
>>> from torch import tensor
>>> from torchmetrics.classification import BinarySpecificity
>>> target = tensor([0, 1, 0, 1, 0, 1])
>>> preds = tensor([0.11, 0.22, 0.84, 0.73, 0.33, 0.92])
>>> metric = BinarySpecificity()
>>> metric(preds, target)
tensor(0.6667)
- Example (multidim tensors):
>>> from torch import tensor
>>> from torchmetrics.classification import BinarySpecificity
>>> target = tensor([[[0, 1], [1, 0], [0, 1]], [[1, 1], [0, 0], [1, 0]]])
>>> preds = tensor([[[0.59, 0.91], [0.91, 0.99], [0.63, 0.04]],
...                 [[0.38, 0.04], [0.86, 0.780], [0.45, 0.37]]])
>>> metric = BinarySpecificity(multidim_average='samplewise')
>>> metric(preds, target)
tensor([0.0000, 0.3333])
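Below is a minimal plain-Python sketch of the preprocessing and the samplewise reduction described above. It is not the torchmetrics code. The helper names are hypothetical, and the logit check is applied to whatever list is passed in, which means per sample in samplewise mode. The sketch also assumes a value counts as positive when it is strictly greater than threshold.

```python
import math


def binary_specificity_sketch(preds, target, threshold=0.5):
    """Specificity for flat lists of predictions and {0, 1} targets."""
    # Any value outside [0, 1] means the inputs are treated as logits
    if any(p < 0 or p > 1 for p in preds):
        preds = [1.0 / (1.0 + math.exp(-p)) for p in preds]
    # Hard {0, 1} predictions (assumption: strictly greater than threshold)
    hard = [1 if p > threshold else 0 for p in preds]
    tn = sum(1 for p, t in zip(hard, target) if p == 0 and t == 0)
    fp = sum(1 for p, t in zip(hard, target) if p == 1 and t == 0)
    # No negatives at all means TN + FP == 0, which is reported as 0
    return tn / (tn + fp) if tn + fp else 0.0


def samplewise_sketch(preds, target, threshold=0.5):
    """multidim_average='samplewise': one score per sample on the N axis."""
    scores = []
    for sample_p, sample_t in zip(preds, target):
        # Flatten the extra dimensions of this sample (two levels deep here)
        flat_p = [v for row in sample_p for v in row]
        flat_t = [v for row in sample_t for v in row]
        scores.append(binary_specificity_sketch(flat_p, flat_t, threshold))
    return scores


target = [[[0, 1], [1, 0], [0, 1]], [[1, 1], [0, 0], [1, 0]]]
preds = [[[0.59, 0.91], [0.91, 0.99], [0.63, 0.04]],
         [[0.38, 0.04], [0.86, 0.780], [0.45, 0.37]]]
print(samplewise_sketch(preds, target))  # [0.0, 0.333...]
```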
- plot(val=None, ax=None)[source]¶
Plot a single or multiple values from the metric.
- Parameters:
val (Union[Tensor, Sequence[Tensor], None]) – Either a single result from calling metric.forward or metric.compute or a list of these results. If no value is provided, will automatically call metric.compute and plot that result.
ax (Optional[Axes]) – A matplotlib axis object. If provided, the plot will be added to that axis.
- Return type:
- Returns:
Figure object and Axes object
- Raises:
ModuleNotFoundError – If matplotlib is not installed
>>> from torch import rand, randint
>>> # Example plotting a single value
>>> from torchmetrics.classification import BinarySpecificity
>>> metric = BinarySpecificity()
>>> metric.update(rand(10), randint(2, (10,)))
>>> fig_, ax_ = metric.plot()
>>> from torch import rand, randint
>>> # Example plotting multiple values
>>> from torchmetrics.classification import BinarySpecificity
>>> metric = BinarySpecificity()
>>> values = []
>>> for _ in range(10):
...     values.append(metric(rand(10), randint(2, (10,))))
>>> fig_, ax_ = metric.plot(values)
MulticlassSpecificity¶
- class torchmetrics.classification.MulticlassSpecificity(num_classes, top_k=1, average='macro', multidim_average='global', ignore_index=None, validate_args=True, **kwargs)[source]¶
Compute Specificity for multiclass tasks.
\[\text{Specificity} = \frac{\text{TN}}{\text{TN} + \text{FP}}\]
Where \(\text{TN}\) and \(\text{FP}\) represent the number of true negatives and false positives respectively. The metric is only properly defined when \(\text{TN} + \text{FP} \neq 0\). If this case is encountered for any class, the metric for that class will be set to 0 and the overall metric may therefore be affected in turn.
As input to forward and update the metric accepts the following input:
preds (Tensor): An int tensor of shape (N, ...) or float tensor of shape (N, C, ...). If preds is a floating point tensor, we apply torch.argmax along the C dimension to automatically convert probabilities/logits into an int tensor.
target (Tensor): An int tensor of shape (N, ...)
As output to forward and compute the metric returns the following output:
mcs (Tensor): The returned shape depends on the average and multidim_average arguments:
If multidim_average is set to global:
If average='micro'/'macro'/'weighted', the output will be a scalar tensor
If average=None/'none', the shape will be (C,)
If multidim_average is set to samplewise:
If average='micro'/'macro'/'weighted', the shape will be (N,)
If average=None/'none', the shape will be (N, C)
- Parameters:
num_classes (int) – Integer specifying the number of classes
average (Optional[Literal['micro', 'macro', 'weighted', 'none']]) – Defines the reduction that is applied over classes. Should be one of the following:
micro: Sum statistics over all classes
macro: Calculate statistics for each class and average them
weighted: Calculate statistics for each class and compute the weighted average using their support
"none" or None: Calculate the statistic for each class and apply no reduction
top_k (int) – Number of highest probability or logit score predictions considered to find the correct label. Only works when preds contain probabilities/logits.
multidim_average (Literal['global', 'samplewise']) – Defines how additional dimensions ... should be handled. Should be one of the following:
global: Additional dimensions are flattened along the batch dimension
samplewise: Statistic will be calculated independently for each sample on the N axis. The statistics in this case are calculated over the additional dimensions.
ignore_index (Optional[int]) – Specifies a target value that is ignored and does not contribute to the metric calculation
validate_args (bool) – bool indicating if input arguments and tensors should be validated for correctness. Set to False for faster computations.
- Example (preds is int tensor):
>>> from torch import tensor
>>> from torchmetrics.classification import MulticlassSpecificity
>>> target = tensor([2, 1, 0, 0])
>>> preds = tensor([2, 1, 0, 1])
>>> metric = MulticlassSpecificity(num_classes=3)
>>> metric(preds, target)
tensor(0.8889)
>>> mcs = MulticlassSpecificity(num_classes=3, average=None)
>>> mcs(preds, target)
tensor([1.0000, 0.6667, 1.0000])
- Example (preds is float tensor):
>>> from torch import tensor
>>> from torchmetrics.classification import MulticlassSpecificity
>>> target = tensor([2, 1, 0, 0])
>>> preds = tensor([[0.16, 0.26, 0.58],
...                 [0.22, 0.61, 0.17],
...                 [0.71, 0.09, 0.20],
...                 [0.05, 0.82, 0.13]])
>>> metric = MulticlassSpecificity(num_classes=3)
>>> metric(preds, target)
tensor(0.8889)
>>> mcs = MulticlassSpecificity(num_classes=3, average=None)
>>> mcs(preds, target)
tensor([1.0000, 0.6667, 1.0000])
- Example (multidim tensors):
>>> from torch import tensor
>>> from torchmetrics.classification import MulticlassSpecificity
>>> target = tensor([[[0, 1], [2, 1], [0, 2]], [[1, 1], [2, 0], [1, 2]]])
>>> preds = tensor([[[0, 2], [2, 0], [0, 1]], [[2, 2], [2, 1], [1, 0]]])
>>> metric = MulticlassSpecificity(num_classes=3, multidim_average='samplewise')
>>> metric(preds, target)
tensor([0.7500, 0.6556])
>>> mcs = MulticlassSpecificity(num_classes=3, multidim_average='samplewise', average=None)
>>> mcs(preds, target)
tensor([[0.7500, 0.7500, 0.7500],
        [0.8000, 0.6667, 0.5000]])
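The sketch below illustrates, in plain Python, the argmax conversion and the four average options for a flat list of samples. It is not the torchmetrics implementation. The function names are hypothetical, and ties in argmax go to the first index. weighted uses the support (the number of occurrences of each class in target) as the weight, as described for the average parameter. top_k and ignore_index are left out.

```python
def argmax(row):
    # Index of the highest score; Python's max keeps the first of tied indices
    return max(range(len(row)), key=lambda i: row[i])


def multiclass_specificity_sketch(preds, target, num_classes, average="macro"):
    """preds: class indices or rows of probabilities/logits; target: class indices."""
    preds = [argmax(p) if isinstance(p, list) else p for p in preds]
    scores, support = [], []
    tn_total = fp_total = 0
    for c in range(num_classes):
        # One-vs-rest counts for class c
        tn = sum(1 for p, t in zip(preds, target) if p != c and t != c)
        fp = sum(1 for p, t in zip(preds, target) if p == c and t != c)
        tn_total += tn
        fp_total += fp
        scores.append(tn / (tn + fp) if tn + fp else 0.0)
        support.append(sum(1 for t in target if t == c))
    if average is None or average == "none":
        return scores
    if average == "micro":
        denom = tn_total + fp_total
        return tn_total / denom if denom else 0.0
    if average == "weighted":
        total = sum(support)
        return sum(s * w for s, w in zip(scores, support)) / total if total else 0.0
    return sum(scores) / num_classes  # "macro"


probs = [[0.16, 0.26, 0.58], [0.22, 0.61, 0.17], [0.71, 0.09, 0.20], [0.05, 0.82, 0.13]]
print(multiclass_specificity_sketch(probs, [2, 1, 0, 0], num_classes=3))  # 0.888...

# multidim_average='samplewise': flatten the extra dimension of each sample first
target = [[[0, 1], [2, 1], [0, 2]], [[1, 1], [2, 0], [1, 2]]]
preds = [[[0, 2], [2, 0], [0, 1]], [[2, 2], [2, 1], [1, 0]]]
per_sample = [
    multiclass_specificity_sketch(
        [v for row in p for v in row], [v for row in t for v in row], num_classes=3
    )
    for p, t in zip(preds, target)
]
print(per_sample)  # roughly [0.75, 0.6556]
```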
- plot(val=None, ax=None)[source]¶
Plot a single or multiple values from the metric.
- Parameters:
val (Union[Tensor, Sequence[Tensor], None]) – Either a single result from calling metric.forward or metric.compute or a list of these results. If no value is provided, will automatically call metric.compute and plot that result.
ax (Optional[Axes]) – A matplotlib axis object. If provided, the plot will be added to that axis.
- Return type:
- Returns:
Figure object and Axes object
- Raises:
ModuleNotFoundError – If matplotlib is not installed
>>> from torch import randint
>>> # Example plotting a single value per class
>>> from torchmetrics.classification import MulticlassSpecificity
>>> metric = MulticlassSpecificity(num_classes=3, average=None)
>>> metric.update(randint(3, (20,)), randint(3, (20,)))
>>> fig_, ax_ = metric.plot()
>>> from torch import randint
>>> # Example plotting multiple values per class
>>> from torchmetrics.classification import MulticlassSpecificity
>>> metric = MulticlassSpecificity(num_classes=3, average=None)
>>> values = []
>>> for _ in range(20):
...     values.append(metric(randint(3, (20,)), randint(3, (20,))))
>>> fig_, ax_ = metric.plot(values)
MultilabelSpecificity¶
- class torchmetrics.classification.MultilabelSpecificity(num_labels, threshold=0.5, average='macro', multidim_average='global', ignore_index=None, validate_args=True, **kwargs)[source]¶
Compute Specificity for multilabel tasks.
\[\text{Specificity} = \frac{\text{TN}}{\text{TN} + \text{FP}}\]
Where \(\text{TN}\) and \(\text{FP}\) represent the number of true negatives and false positives respectively. The metric is only properly defined when \(\text{TN} + \text{FP} \neq 0\). If this case is encountered for any label, the metric for that label will be set to 0 and the overall metric may therefore be affected in turn.
As input to forward and update the metric accepts the following input:
preds (Tensor): An int or float tensor of shape (N, C, ...). If preds is a floating point tensor with values outside the [0,1] range, we consider the input to be logits and will automatically apply sigmoid per element. Additionally, we convert it to an int tensor by thresholding with the value in threshold.
target (Tensor): An int tensor of shape (N, C, ...)
As output to forward and compute the metric returns the following output:
mls (Tensor): The returned shape depends on the average and multidim_average arguments:
If multidim_average is set to global:
If average='micro'/'macro'/'weighted', the output will be a scalar tensor
If average=None/'none', the shape will be (C,)
If multidim_average is set to samplewise:
If average='micro'/'macro'/'weighted', the shape will be (N,)
If average=None/'none', the shape will be (N, C)
- Parameters:
num_labels (int) – Integer specifying the number of labels
threshold (float) – Threshold for transforming probability to binary {0,1} predictions
average (Optional[Literal['micro', 'macro', 'weighted', 'none']]) – Defines the reduction that is applied over labels. Should be one of the following:
micro: Sum statistics over all labels
macro: Calculate statistics for each label and average them
weighted: Calculate statistics for each label and compute the weighted average using their support
"none" or None: Calculate the statistic for each label and apply no reduction
multidim_average (Literal['global', 'samplewise']) – Defines how additional dimensions ... should be handled. Should be one of the following:
global: Additional dimensions are flattened along the batch dimension
samplewise: Statistic will be calculated independently for each sample on the N axis. The statistics in this case are calculated over the additional dimensions.
ignore_index (Optional[int]) – Specifies a target value that is ignored and does not contribute to the metric calculation
validate_args (bool) – bool indicating if input arguments and tensors should be validated for correctness. Set to False for faster computations.
- Example (preds is int tensor):
>>> from torch import tensor
>>> from torchmetrics.classification import MultilabelSpecificity
>>> target = tensor([[0, 1, 0], [1, 0, 1]])
>>> preds = tensor([[0, 0, 1], [1, 0, 1]])
>>> metric = MultilabelSpecificity(num_labels=3)
>>> metric(preds, target)
tensor(0.6667)
>>> mls = MultilabelSpecificity(num_labels=3, average=None)
>>> mls(preds, target)
tensor([1., 1., 0.])
- Example (preds is float tensor):
>>> from torch import tensor
>>> from torchmetrics.classification import MultilabelSpecificity
>>> target = tensor([[0, 1, 0], [1, 0, 1]])
>>> preds = tensor([[0.11, 0.22, 0.84], [0.73, 0.33, 0.92]])
>>> metric = MultilabelSpecificity(num_labels=3)
>>> metric(preds, target)
tensor(0.6667)
>>> mls = MultilabelSpecificity(num_labels=3, average=None)
>>> mls(preds, target)
tensor([1., 1., 0.])
- Example (multidim tensors):
>>> from torch import tensor
>>> from torchmetrics.classification import MultilabelSpecificity
>>> target = tensor([[[0, 1], [1, 0], [0, 1]], [[1, 1], [0, 0], [1, 0]]])
>>> preds = tensor([[[0.59, 0.91], [0.91, 0.99], [0.63, 0.04]],
...                 [[0.38, 0.04], [0.86, 0.780], [0.45, 0.37]]])
>>> metric = MultilabelSpecificity(num_labels=3, multidim_average='samplewise')
>>> metric(preds, target)
tensor([0.0000, 0.3333])
>>> mls = MultilabelSpecificity(num_labels=3, multidim_average='samplewise', average=None)
>>> mls(preds, target)
tensor([[0., 0., 0.],
        [0., 0., 1.]])
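The following is a plain-Python sketch of the per-label computation above, not the torchmetrics implementation. The function name is hypothetical. It assumes preds already hold probabilities or 0/1 integers, so the sigmoid step for logits and ignore_index are left out. It treats a value strictly greater than threshold as a positive prediction. For multidim inputs of shape (N, C, ...), the samplewise mode scores each sample separately after turning its extra dimension into rows of C labels.

```python
def multilabel_specificity_sketch(preds, target, num_labels, threshold=0.5, average="macro"):
    """preds/target: lists of rows, one entry per label (probabilities or 0/1 ints)."""
    scores = []
    tn_total = fp_total = 0
    for label in range(num_labels):
        tn = fp = 0
        for row_p, row_t in zip(preds, target):
            if row_t[label] != 0:
                continue  # only negatives enter TN + FP
            # Assumption of this sketch: strictly greater than threshold means positive
            if row_p[label] > threshold:
                fp += 1
            else:
                tn += 1
        tn_total += tn
        fp_total += fp
        scores.append(tn / (tn + fp) if tn + fp else 0.0)
    if average is None or average == "none":
        return scores
    if average == "micro":
        denom = tn_total + fp_total
        return tn_total / denom if denom else 0.0
    return sum(scores) / num_labels  # "macro"


target = [[0, 1, 0], [1, 0, 1]]
preds = [[0.11, 0.22, 0.84], [0.73, 0.33, 0.92]]
print(multilabel_specificity_sketch(preds, target, num_labels=3, average=None))  # [1.0, 1.0, 0.0]

# samplewise: each sample has shape (C, extra); transpose it into rows of C labels
target_md = [[[0, 1], [1, 0], [0, 1]], [[1, 1], [0, 0], [1, 0]]]
preds_md = [[[0.59, 0.91], [0.91, 0.99], [0.63, 0.04]],
            [[0.38, 0.04], [0.86, 0.780], [0.45, 0.37]]]
per_sample = [
    multilabel_specificity_sketch(list(zip(*p)), list(zip(*t)), num_labels=3, average=None)
    for p, t in zip(preds_md, target_md)
]
print(per_sample)  # [[0.0, 0.0, 0.0], [0.0, 0.0, 1.0]]
```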
- plot(val=None, ax=None)[source]¶
Plot a single or multiple values from the metric.
- Parameters:
val (Union[Tensor, Sequence[Tensor], None]) – Either a single result from calling metric.forward or metric.compute or a list of these results. If no value is provided, will automatically call metric.compute and plot that result.
ax (Optional[Axes]) – A matplotlib axis object. If provided, the plot will be added to that axis.
- Return type:
- Returns:
Figure object and Axes object
- Raises:
ModuleNotFoundError – If matplotlib is not installed
>>> from torch import rand, randint
>>> # Example plotting a single value
>>> from torchmetrics.classification import MultilabelSpecificity
>>> metric = MultilabelSpecificity(num_labels=3)
>>> metric.update(randint(2, (20, 3)), randint(2, (20, 3)))
>>> fig_, ax_ = metric.plot()
>>> from torch import rand, randint
>>> # Example plotting multiple values
>>> from torchmetrics.classification import MultilabelSpecificity
>>> metric = MultilabelSpecificity(num_labels=3)
>>> values = []
>>> for _ in range(10):
...     values.append(metric(randint(2, (20, 3)), randint(2, (20, 3))))
>>> fig_, ax_ = metric.plot(values)
Functional Interface¶
- torchmetrics.functional.specificity(preds, target, task, threshold=0.5, num_classes=None, num_labels=None, average='micro', multidim_average='global', top_k=1, ignore_index=None, validate_args=True)[source]¶
Compute Specificity.
\[\text{Specificity} = \frac{\text{TN}}{\text{TN} + \text{FP}}\]
Where \(\text{TN}\) and \(\text{FP}\) represent the number of true negatives and false positives respectively.
This function is a simple wrapper to get the task specific versions of this metric, which is done by setting the task argument to either 'binary', 'multiclass' or 'multilabel'. See the documentation of binary_specificity(), multiclass_specificity() and multilabel_specificity() for the specific details of how each argument influences the result and for examples.
- Return type:
Tensor
- Legacy Example:
>>> from torch import tensor
>>> from torchmetrics.functional import specificity
>>> preds = tensor([2, 0, 2, 1])
>>> target = tensor([1, 1, 2, 0])
>>> specificity(preds, target, task="multiclass", average='macro', num_classes=3)
tensor(0.6111)
>>> specificity(preds, target, task="multiclass", average='micro', num_classes=3)
tensor(0.6250)
binary_specificity¶
- torchmetrics.functional.classification.binary_specificity(preds, target, threshold=0.5, multidim_average='global', ignore_index=None, validate_args=True)[source]¶
Compute Specificity for binary tasks.
\[\text{Specificity} = \frac{\text{TN}}{\text{TN} + \text{FP}}\]
Where \(\text{TN}\) and \(\text{FP}\) represent the number of true negatives and false positives respectively.
Accepts the following input tensors:
preds (int or float tensor): (N, ...). If preds is a floating point tensor with values outside the [0,1] range, we consider the input to be logits and will automatically apply sigmoid per element. Additionally, we convert it to an int tensor by thresholding with the value in threshold.
target (int tensor): (N, ...)
- Parameters:
preds (Tensor) – Tensor with predictions
target (Tensor) – Tensor with true labels
threshold (float) – Threshold for transforming probability to binary {0,1} predictions
multidim_average (Literal['global', 'samplewise']) – Defines how additional dimensions ... should be handled. Should be one of the following:
global: Additional dimensions are flattened along the batch dimension
samplewise: Statistic will be calculated independently for each sample on the N axis. The statistics in this case are calculated over the additional dimensions.
ignore_index (Optional[int]) – Specifies a target value that is ignored and does not contribute to the metric calculation
validate_args (bool) – bool indicating if input arguments and tensors should be validated for correctness. Set to False for faster computations.
- Return type:
- Returns:
If multidim_average is set to global, the metric returns a scalar value. If multidim_average is set to samplewise, the metric returns an (N,) vector consisting of a scalar value per sample.
- Example (preds is int tensor):
>>> from torch import tensor
>>> from torchmetrics.functional.classification import binary_specificity
>>> target = tensor([0, 1, 0, 1, 0, 1])
>>> preds = tensor([0, 0, 1, 1, 0, 1])
>>> binary_specificity(preds, target)
tensor(0.6667)
- Example (preds is float tensor):
>>> from torch import tensor
>>> from torchmetrics.functional.classification import binary_specificity
>>> target = tensor([0, 1, 0, 1, 0, 1])
>>> preds = tensor([0.11, 0.22, 0.84, 0.73, 0.33, 0.92])
>>> binary_specificity(preds, target)
tensor(0.6667)
- Example (multidim tensors):
>>> from torch import tensor
>>> from torchmetrics.functional.classification import binary_specificity
>>> target = tensor([[[0, 1], [1, 0], [0, 1]], [[1, 1], [0, 0], [1, 0]]])
>>> preds = tensor([[[0.59, 0.91], [0.91, 0.99], [0.63, 0.04]],
...                 [[0.38, 0.04], [0.86, 0.780], [0.45, 0.37]]])
>>> binary_specificity(preds, target, multidim_average='samplewise')
tensor([0.0000, 0.3333])
multiclass_specificity¶
- torchmetrics.functional.classification.multiclass_specificity(preds, target, num_classes, average='macro', top_k=1, multidim_average='global', ignore_index=None, validate_args=True)[source]¶
Compute Specificity for multiclass tasks.
\[\text{Specificity} = \frac{\text{TN}}{\text{TN} + \text{FP}}\]
Where \(\text{TN}\) and \(\text{FP}\) represent the number of true negatives and false positives respectively.
Accepts the following input tensors:
preds: (N, ...) (int tensor) or (N, C, ...) (float tensor). If preds is a floating point tensor, we apply torch.argmax along the C dimension to automatically convert probabilities/logits into an int tensor.
target (int tensor): (N, ...)
- Parameters:
preds (Tensor) – Tensor with predictions
target (Tensor) – Tensor with true labels
num_classes (int) – Integer specifying the number of classes
average (Optional[Literal['micro', 'macro', 'weighted', 'none']]) – Defines the reduction that is applied over classes. Should be one of the following:
micro: Sum statistics over all classes
macro: Calculate statistics for each class and average them
weighted: Calculate statistics for each class and compute the weighted average using their support
"none" or None: Calculate the statistic for each class and apply no reduction
top_k (int) – Number of highest probability or logit score predictions considered to find the correct label. Only works when preds contain probabilities/logits.
multidim_average (Literal['global', 'samplewise']) – Defines how additional dimensions ... should be handled. Should be one of the following:
global: Additional dimensions are flattened along the batch dimension
samplewise: Statistic will be calculated independently for each sample on the N axis. The statistics in this case are calculated over the additional dimensions.
ignore_index (Optional[int]) – Specifies a target value that is ignored and does not contribute to the metric calculation
validate_args (bool) – bool indicating if input arguments and tensors should be validated for correctness. Set to False for faster computations.
- Returns:
If multidim_average is set to global:
If average='micro'/'macro'/'weighted', the output will be a scalar tensor
If average=None/'none', the shape will be (C,)
If multidim_average is set to samplewise:
If average='micro'/'macro'/'weighted', the shape will be (N,)
If average=None/'none', the shape will be (N, C)
- Return type:
The returned shape depends on the average and multidim_average arguments
- Example (preds is int tensor):
>>> from torch import tensor
>>> from torchmetrics.functional.classification import multiclass_specificity
>>> target = tensor([2, 1, 0, 0])
>>> preds = tensor([2, 1, 0, 1])
>>> multiclass_specificity(preds, target, num_classes=3)
tensor(0.8889)
>>> multiclass_specificity(preds, target, num_classes=3, average=None)
tensor([1.0000, 0.6667, 1.0000])
- Example (preds is float tensor):
>>> from torch import tensor
>>> from torchmetrics.functional.classification import multiclass_specificity
>>> target = tensor([2, 1, 0, 0])
>>> preds = tensor([[0.16, 0.26, 0.58],
...                 [0.22, 0.61, 0.17],
...                 [0.71, 0.09, 0.20],
...                 [0.05, 0.82, 0.13]])
>>> multiclass_specificity(preds, target, num_classes=3)
tensor(0.8889)
>>> multiclass_specificity(preds, target, num_classes=3, average=None)
tensor([1.0000, 0.6667, 1.0000])
- Example (multidim tensors):
>>> from torch import tensor
>>> from torchmetrics.functional.classification import multiclass_specificity
>>> target = tensor([[[0, 1], [2, 1], [0, 2]], [[1, 1], [2, 0], [1, 2]]])
>>> preds = tensor([[[0, 2], [2, 0], [0, 1]], [[2, 2], [2, 1], [1, 0]]])
>>> multiclass_specificity(preds, target, num_classes=3, multidim_average='samplewise')
tensor([0.7500, 0.6556])
>>> multiclass_specificity(preds, target, num_classes=3, multidim_average='samplewise', average=None)
tensor([[0.7500, 0.7500, 0.7500],
        [0.8000, 0.6667, 0.5000]])
multilabel_specificity¶
- torchmetrics.functional.classification.multilabel_specificity(preds, target, num_labels, threshold=0.5, average='macro', multidim_average='global', ignore_index=None, validate_args=True)[source]¶
Compute Specificity for multilabel tasks.
\[\text{Specificity} = \frac{\text{TN}}{\text{TN} + \text{FP}}\]
Where \(\text{TN}\) and \(\text{FP}\) represent the number of true negatives and false positives respectively.
Accepts the following input tensors:
preds (int or float tensor): (N, C, ...). If preds is a floating point tensor with values outside the [0,1] range, we consider the input to be logits and will automatically apply sigmoid per element. Additionally, we convert it to an int tensor by thresholding with the value in threshold.
target (int tensor): (N, C, ...)
- Parameters:
preds (Tensor) – Tensor with predictions
target (Tensor) – Tensor with true labels
num_labels (int) – Integer specifying the number of labels
threshold (float) – Threshold for transforming probability to binary {0,1} predictions
average (Optional[Literal['micro', 'macro', 'weighted', 'none']]) – Defines the reduction that is applied over labels. Should be one of the following:
micro: Sum statistics over all labels
macro: Calculate statistics for each label and average them
weighted: Calculate statistics for each label and compute the weighted average using their support
"none" or None: Calculate the statistic for each label and apply no reduction
multidim_average (Literal['global', 'samplewise']) – Defines how additional dimensions ... should be handled. Should be one of the following:
global: Additional dimensions are flattened along the batch dimension
samplewise: Statistic will be calculated independently for each sample on the N axis. The statistics in this case are calculated over the additional dimensions.
ignore_index (Optional[int]) – Specifies a target value that is ignored and does not contribute to the metric calculation
validate_args (bool) – bool indicating if input arguments and tensors should be validated for correctness. Set to False for faster computations.
- Returns:
If multidim_average is set to global:
If average='micro'/'macro'/'weighted', the output will be a scalar tensor
If average=None/'none', the shape will be (C,)
If multidim_average is set to samplewise:
If average='micro'/'macro'/'weighted', the shape will be (N,)
If average=None/'none', the shape will be (N, C)
- Return type:
The returned shape depends on the average and multidim_average arguments
- Example (preds is int tensor):
>>> from torch import tensor
>>> from torchmetrics.functional.classification import multilabel_specificity
>>> target = tensor([[0, 1, 0], [1, 0, 1]])
>>> preds = tensor([[0, 0, 1], [1, 0, 1]])
>>> multilabel_specificity(preds, target, num_labels=3)
tensor(0.6667)
>>> multilabel_specificity(preds, target, num_labels=3, average=None)
tensor([1., 1., 0.])
- Example (preds is float tensor):
>>> from torch import tensor
>>> from torchmetrics.functional.classification import multilabel_specificity
>>> target = tensor([[0, 1, 0], [1, 0, 1]])
>>> preds = tensor([[0.11, 0.22, 0.84], [0.73, 0.33, 0.92]])
>>> multilabel_specificity(preds, target, num_labels=3)
tensor(0.6667)
>>> multilabel_specificity(preds, target, num_labels=3, average=None)
tensor([1., 1., 0.])
- Example (multidim tensors):
>>> from torch import tensor
>>> from torchmetrics.functional.classification import multilabel_specificity
>>> target = tensor([[[0, 1], [1, 0], [0, 1]], [[1, 1], [0, 0], [1, 0]]])
>>> preds = tensor([[[0.59, 0.91], [0.91, 0.99], [0.63, 0.04]],
...                 [[0.38, 0.04], [0.86, 0.780], [0.45, 0.37]]])
>>> multilabel_specificity(preds, target, num_labels=3, multidim_average='samplewise')
tensor([0.0000, 0.3333])
>>> multilabel_specificity(preds, target, num_labels=3, multidim_average='samplewise', average=None)
tensor([[0., 0., 0.],
        [0., 0., 1.]])
Specificity At Sensitivity¶
Module Interface¶
- class torchmetrics.SpecificityAtSensitivity(**kwargs)[source]¶
Compute the highest possible specificity value given the minimum sensitivity threshold provided.
This is done by first calculating the Receiver Operating Characteristic (ROC) curve for different thresholds and then selecting the highest specificity among the thresholds whose sensitivity reaches the given level.
This function is a simple wrapper to get the task specific versions of this metric, which is done by setting the task argument to either 'binary', 'multiclass' or 'multilabel'. See the documentation of BinarySpecificityAtSensitivity, MulticlassSpecificityAtSensitivity and MultilabelSpecificityAtSensitivity for the specific details of how each argument influences the metric and for examples.
BinarySpecificityAtSensitivity¶
- class torchmetrics.classification.BinarySpecificityAtSensitivity(min_sensitivity, thresholds=None, ignore_index=None, validate_args=True, **kwargs)[source]¶
Compute the highest possible specificity value given the minimum sensitivity threshold provided.
This is done by first calculating the Receiver Operating Characteristic (ROC) curve for different thresholds and then selecting the highest specificity among the thresholds whose sensitivity reaches the given level.
Accepts the following input tensors:
preds (float tensor): (N, ...). Preds should be a tensor containing probabilities or logits for each observation. If preds has values outside the [0,1] range we consider the input to be logits and will automatically apply sigmoid per element.
target (int tensor): (N, ...). Target should be a tensor containing ground truth labels, and therefore only contain {0,1} values (except if ignore_index is specified).
Additional dimensions ... will be flattened into the batch dimension.
The implementation supports both a non-binned but accurate version and a binned version that is less accurate but more memory efficient. Setting the thresholds argument to None will activate the non-binned version that uses memory of size \(\mathcal{O}(n_{samples})\) whereas setting the thresholds argument to either an integer, list or a 1d tensor will use a binned version that uses memory of size \(\mathcal{O}(n_{thresholds})\) (constant memory).
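The memory difference comes from what has to be kept between batches: in the binned mode only a fixed table of confusion counts per threshold is needed, whatever the number of samples. The plain-Python sketch below illustrates that idea. The class name `BinnedBinaryCounts` is hypothetical; it assumes probabilities already in [0, 1], treats a sample as predicted positive when `score >= threshold`, and uses a `(0.0, 1e6)` fallback when no threshold reaches the requested sensitivity, which mirrors the `1.0000e+06` values printed in the multiclass examples on this page rather than a documented contract.

```python
from typing import List, Sequence, Tuple


class BinnedBinaryCounts:
    """Hypothetical constant-memory state: [tp, fp, tn, fn] for each fixed threshold."""

    def __init__(self, num_thresholds: int) -> None:
        # `num_thresholds` values linearly spaced from 0 to 1.
        self.thresholds: List[float] = [i / (num_thresholds - 1) for i in range(num_thresholds)]
        # The table size depends only on the number of thresholds, not on the number of samples.
        self.counts: List[List[int]] = [[0, 0, 0, 0] for _ in self.thresholds]

    def update(self, preds: Sequence[float], target: Sequence[int]) -> None:
        """Fold one batch into the running counts; the batch itself is not stored."""
        for p, t in zip(preds, target):
            for row, thr in zip(self.counts, self.thresholds):
                predicted_positive = p >= thr
                if t == 1:
                    row[0 if predicted_positive else 3] += 1  # tp or fn
                else:
                    row[1 if predicted_positive else 2] += 1  # fp or tn

    def specificity_at_sensitivity(self, min_sensitivity: float) -> Tuple[float, float]:
        best_spec, best_thr, found = 0.0, 1e6, False
        # Visit thresholds from high to low so that ties keep the highest threshold.
        for (tp, fp, tn, fn), thr in sorted(zip(self.counts, self.thresholds), key=lambda item: -item[1]):
            sensitivity = tp / (tp + fn) if tp + fn else 0.0
            specificity = tn / (tn + fp) if tn + fp else 0.0
            if sensitivity >= min_sensitivity and (not found or specificity > best_spec):
                best_spec, best_thr, found = specificity, thr, True
        return best_spec, best_thr


state = BinnedBinaryCounts(num_thresholds=5)
state.update([0.0, 0.5], [0, 1])  # first batch
state.update([0.4, 0.1], [1, 1])  # second batch
print(state.specificity_at_sensitivity(min_sensitivity=0.5))  # (1.0, 0.25)
```

Feeding the four samples of the `thresholds=5` example on this page in two batches gives the same `(1.0, 0.25)`, while the state never holds more than five rows.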
- Parameters:
min_sensitivity (float) – float value specifying minimum sensitivity threshold.
thresholds (Union[int, List[float], Tensor, None]) – Can be one of:
If set to None, will use a non-binned approach where thresholds are dynamically calculated from all the data. Most accurate but also most memory consuming approach.
If set to an int (larger than 1), will use that number of thresholds linearly spaced from 0 to 1 as bins for the calculation.
If set to a list of floats, will use the indicated thresholds in the list as bins for the calculation.
If set to a 1d tensor of floats, will use the indicated thresholds in the tensor as bins for the calculation.
validate_args (bool) – bool indicating if input arguments and tensors should be validated for correctness. Set to False for faster computations.
kwargs (Any) – Additional keyword arguments, see Advanced metric settings for more info.
- Returns:
a tuple of 2 tensors containing:
specificity: a scalar tensor with the maximum specificity for the given sensitivity level
threshold: a scalar tensor with the corresponding threshold level
- Return type:
(tuple)
Example
>>> from torchmetrics.classification import BinarySpecificityAtSensitivity
>>> from torch import tensor
>>> preds = tensor([0, 0.5, 0.4, 0.1])
>>> target = tensor([0, 1, 1, 1])
>>> metric = BinarySpecificityAtSensitivity(min_sensitivity=0.5, thresholds=None)
>>> metric(preds, target)
(tensor(1.), tensor(0.4000))
>>> metric = BinarySpecificityAtSensitivity(min_sensitivity=0.5, thresholds=5)
>>> metric(preds, target)
(tensor(1.), tensor(0.2500))
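The threshold sweep behind both calls in the example above can be sketched in plain Python. This is an illustration under stated assumptions, not the library code: scores are probabilities in [0, 1] (no sigmoid step), a sample is predicted positive when `score >= threshold`, `thresholds=None` tries every distinct score while an int `n` tries `n` evenly spaced values in [0, 1], and ties in specificity keep the highest threshold. The function name `specificity_at_sensitivity_sketch` is hypothetical.

```python
from typing import Optional, Sequence, Tuple


def specificity_at_sensitivity_sketch(
    preds: Sequence[float],
    target: Sequence[int],
    min_sensitivity: float,
    thresholds: Optional[int] = None,
) -> Tuple[float, float]:
    """Highest specificity among thresholds whose sensitivity is >= min_sensitivity."""
    if thresholds is None:
        # Exact mode: every distinct score is a candidate threshold.
        candidates = sorted(set(preds), reverse=True)
    else:
        # Binned mode: `thresholds` evenly spaced values, visited from 1 down to 0.
        candidates = [i / (thresholds - 1) for i in range(thresholds - 1, -1, -1)]
    positives = sum(1 for t in target if t == 1)
    negatives = len(target) - positives
    best_spec, best_thr, found = 0.0, 1e6, False  # fallback if nothing qualifies
    for thr in candidates:  # high to low, so ties keep the highest threshold
        tp = sum(1 for p, t in zip(preds, target) if t == 1 and p >= thr)
        tn = sum(1 for p, t in zip(preds, target) if t == 0 and p < thr)
        sensitivity = tp / positives if positives else 0.0
        specificity = tn / negatives if negatives else 0.0
        if sensitivity >= min_sensitivity and (not found or specificity > best_spec):
            best_spec, best_thr, found = specificity, thr, True
    return best_spec, best_thr


preds, target = [0, 0.5, 0.4, 0.1], [0, 1, 1, 1]
print(specificity_at_sensitivity_sketch(preds, target, 0.5))                # (1.0, 0.4)
print(specificity_at_sensitivity_sketch(preds, target, 0.5, thresholds=5))  # (1.0, 0.25)
```

The binned result differs only in the threshold: with five bins the closest candidate that still reaches a sensitivity of 2/3 is 0.25 rather than the exact score 0.4.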
MulticlassSpecificityAtSensitivity¶
- class torchmetrics.classification.MulticlassSpecificityAtSensitivity(num_classes, min_sensitivity, thresholds=None, ignore_index=None, validate_args=True, **kwargs)[source]¶
Compute the highest possible specificity value given the minimum sensitivity threshold provided.
This is done by first calculating the Receiver Operating Characteristic (ROC) curve for different thresholds and then selecting the highest specificity among the thresholds whose sensitivity reaches the given level.
For multiclass the metric is calculated by iteratively treating each class as the positive class and all other classes as the negative, which is referred to as the one-vs-rest approach. One-vs-one is currently not supported by this metric.
Accepts the following input tensors:
preds (float tensor): (N, C, ...). Preds should be a tensor containing probabilities or logits for each observation. If preds has values outside the [0,1] range we consider the input to be logits and will automatically apply softmax per sample.
target (int tensor): (N, ...). Target should be a tensor containing ground truth labels, and therefore only contain values in the [0, n_classes-1] range (except if ignore_index is specified).
Additional dimensions ... will be flattened into the batch dimension.
The implementation supports both a non-binned but accurate version and a binned version that is less accurate but more memory efficient. Setting the thresholds argument to None will activate the non-binned version that uses memory of size \(\mathcal{O}(n_{samples})\) whereas setting the thresholds argument to either an integer, list or a 1d tensor will use a binned version that uses memory of size \(\mathcal{O}(n_{thresholds} \times n_{classes})\) (constant memory).
- Parameters:
num_classes (int) – Integer specifying the number of classes
min_sensitivity (float) – float value specifying minimum sensitivity threshold.
thresholds (Union[int, List[float], Tensor, None]) – Can be one of:
If set to None, will use a non-binned approach where thresholds are dynamically calculated from all the data. Most accurate but also most memory consuming approach.
If set to an int (larger than 1), will use that number of thresholds linearly spaced from 0 to 1 as bins for the calculation.
If set to a list of floats, will use the indicated thresholds in the list as bins for the calculation.
If set to a 1d tensor of floats, will use the indicated thresholds in the tensor as bins for the calculation.
validate_args (bool) – bool indicating if input arguments and tensors should be validated for correctness. Set to False for faster computations.
kwargs (Any) – Additional keyword arguments, see Advanced metric settings for more info.
- Returns:
a tuple of either 2 tensors or 2 lists containing:
specificity: a 1d tensor of size (n_classes, ) with the maximum specificity for the given sensitivity level per class
thresholds: a 1d tensor of size (n_classes, ) with the corresponding threshold level per class
- Return type:
(tuple)
Example
>>> from torchmetrics.classification import MulticlassSpecificityAtSensitivity
>>> from torch import tensor
>>> preds = tensor([[0.75, 0.05, 0.05, 0.05, 0.05],
...                 [0.05, 0.75, 0.05, 0.05, 0.05],
...                 [0.05, 0.05, 0.75, 0.05, 0.05],
...                 [0.05, 0.05, 0.05, 0.75, 0.05]])
>>> target = tensor([0, 1, 3, 2])
>>> metric = MulticlassSpecificityAtSensitivity(num_classes=5, min_sensitivity=0.5, thresholds=None)
>>> metric(preds, target)
(tensor([1., 1., 0., 0., 0.]), tensor([7.5000e-01, 7.5000e-01, 5.0000e-02, 5.0000e-02, 1.0000e+06]))
>>> metric = MulticlassSpecificityAtSensitivity(num_classes=5, min_sensitivity=0.5, thresholds=5)
>>> metric(preds, target)
(tensor([1., 1., 0., 0., 0.]), tensor([7.5000e-01, 7.5000e-01, 0.0000e+00, 0.0000e+00, 1.0000e+06]))
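The one-vs-rest approach can be sketched in plain Python: column `c` of the scores is compared against the indicator `target == c`, and each such binary problem is solved separately. This is an illustrative sketch, not the library code; it assumes scores already in [0, 1] (no softmax step), `score >= threshold` for a positive prediction, ties resolved toward the highest threshold, and a `(0.0, 1e6)` fallback for a class whose sensitivity never reaches the minimum (for example class 4 above, which has no positive samples). The function names are hypothetical.

```python
from typing import List, Optional, Sequence, Tuple


def binary_spec_at_sens(
    scores: Sequence[float], labels: Sequence[int], min_sensitivity: float, thresholds: Optional[int] = None
) -> Tuple[float, float]:
    """Exact (thresholds=None) or binned (int) specificity-at-sensitivity for one 0/1 problem."""
    if thresholds is None:
        candidates = sorted(set(scores), reverse=True)
    else:
        candidates = [i / (thresholds - 1) for i in range(thresholds - 1, -1, -1)]
    positives = sum(labels)
    negatives = len(labels) - positives
    best_spec, best_thr, found = 0.0, 1e6, False
    for thr in candidates:  # high to low: ties keep the highest threshold
        tp = sum(1 for s, y in zip(scores, labels) if y == 1 and s >= thr)
        tn = sum(1 for s, y in zip(scores, labels) if y == 0 and s < thr)
        sensitivity = tp / positives if positives else 0.0  # no positives -> 0.0
        specificity = tn / negatives if negatives else 0.0
        if sensitivity >= min_sensitivity and (not found or specificity > best_spec):
            best_spec, best_thr, found = specificity, thr, True
    return best_spec, best_thr


def multiclass_spec_at_sens(
    preds: Sequence[Sequence[float]], target: Sequence[int], num_classes: int,
    min_sensitivity: float, thresholds: Optional[int] = None,
) -> Tuple[List[float], List[float]]:
    """One-vs-rest: score column c against the indicator target == c."""
    specs, thrs = [], []
    for c in range(num_classes):
        scores = [row[c] for row in preds]
        labels = [int(t == c) for t in target]
        spec, thr = binary_spec_at_sens(scores, labels, min_sensitivity, thresholds)
        specs.append(spec)
        thrs.append(thr)
    return specs, thrs


preds = [[0.75, 0.05, 0.05, 0.05, 0.05],
         [0.05, 0.75, 0.05, 0.05, 0.05],
         [0.05, 0.05, 0.75, 0.05, 0.05],
         [0.05, 0.05, 0.05, 0.75, 0.05]]
print(multiclass_spec_at_sens(preds, [0, 1, 3, 2], num_classes=5, min_sensitivity=0.5))
# ([1.0, 1.0, 0.0, 0.0, 0.0], [0.75, 0.75, 0.05, 0.05, 1000000.0])
```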
MultilabelSpecificityAtSensitivity¶
- class torchmetrics.classification.MultilabelSpecificityAtSensitivity(num_labels, min_sensitivity, thresholds=None, ignore_index=None, validate_args=True, **kwargs)[source]¶
Compute the highest possible specificity value given the minimum sensitivity threshold provided.
This is done by first calculating the Receiver Operating Characteristic (ROC) curve for different thresholds and then selecting the highest specificity among the thresholds whose sensitivity reaches the given level.
Accepts the following input tensors:
preds (float tensor): (N, C, ...). Preds should be a tensor containing probabilities or logits for each observation. If preds has values outside the [0,1] range we consider the input to be logits and will automatically apply sigmoid per element.
target (int tensor): (N, C, ...). Target should be a tensor containing ground truth labels, and therefore only contain {0,1} values (except if ignore_index is specified).
Additional dimensions ... will be flattened into the batch dimension.
The implementation supports both a non-binned but accurate version and a binned version that is less accurate but more memory efficient. Setting the thresholds argument to None will activate the non-binned version that uses memory of size \(\mathcal{O}(n_{samples})\) whereas setting the thresholds argument to either an integer, list or a 1d tensor will use a binned version that uses memory of size \(\mathcal{O}(n_{thresholds} \times n_{labels})\) (constant memory).
- Parameters:
num_labels (int) – Integer specifying the number of labels
min_sensitivity (float) – float value specifying minimum sensitivity threshold.
thresholds (Union[int, List[float], Tensor, None]) – Can be one of:
If set to None, will use a non-binned approach where thresholds are dynamically calculated from all the data. Most accurate but also most memory consuming approach.
If set to an int (larger than 1), will use that number of thresholds linearly spaced from 0 to 1 as bins for the calculation.
If set to a list of floats, will use the indicated thresholds in the list as bins for the calculation.
If set to a 1d tensor of floats, will use the indicated thresholds in the tensor as bins for the calculation.
validate_args (bool) – bool indicating if input arguments and tensors should be validated for correctness. Set to False for faster computations.
kwargs (Any) – Additional keyword arguments, see Advanced metric settings for more info.
- Returns:
a tuple of either 2 tensors or 2 lists containing:
specificity: a 1d tensor of size (n_labels, ) with the maximum specificity for the given sensitivity level per label
thresholds: a 1d tensor of size (n_labels, ) with the corresponding threshold level per label
- Return type:
(tuple)
Example
>>> from torchmetrics.classification import MultilabelSpecificityAtSensitivity
>>> from torch import tensor
>>> preds = tensor([[0.75, 0.05, 0.35],
...                 [0.45, 0.75, 0.05],
...                 [0.05, 0.55, 0.75],
...                 [0.05, 0.65, 0.05]])
>>> target = tensor([[1, 0, 1],
...                  [0, 0, 0],
...                  [0, 1, 1],
...                  [1, 1, 1]])
>>> metric = MultilabelSpecificityAtSensitivity(num_labels=3, min_sensitivity=0.5, thresholds=None)
>>> metric(preds, target)
(tensor([1.0000, 0.5000, 1.0000]), tensor([0.7500, 0.6500, 0.3500]))
>>> metric = MultilabelSpecificityAtSensitivity(num_labels=3, min_sensitivity=0.5, thresholds=5)
>>> metric(preds, target)
(tensor([1.0000, 0.5000, 1.0000]), tensor([0.7500, 0.5000, 0.2500]))
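In the multilabel case every label is an independent binary problem on its own column of `preds` and `target`, so no one-vs-rest conversion is needed. The plain-Python sketch below reproduces both results of the example above under stated assumptions: scores already in [0, 1] (no sigmoid step), `score >= threshold` counts as positive, ties keep the highest threshold, and `(0.0, 1e6)` is returned when no threshold qualifies. The function names are hypothetical and this is not the library implementation.

```python
from typing import List, Optional, Sequence, Tuple


def binary_spec_at_sens(
    scores: Sequence[float], labels: Sequence[int], min_sensitivity: float, thresholds: Optional[int] = None
) -> Tuple[float, float]:
    """Exact (thresholds=None) or binned (int) specificity-at-sensitivity for one 0/1 problem."""
    if thresholds is None:
        candidates = sorted(set(scores), reverse=True)
    else:
        candidates = [i / (thresholds - 1) for i in range(thresholds - 1, -1, -1)]
    positives = sum(labels)
    negatives = len(labels) - positives
    best_spec, best_thr, found = 0.0, 1e6, False
    for thr in candidates:  # high to low: ties keep the highest threshold
        tp = sum(1 for s, y in zip(scores, labels) if y == 1 and s >= thr)
        tn = sum(1 for s, y in zip(scores, labels) if y == 0 and s < thr)
        sensitivity = tp / positives if positives else 0.0
        specificity = tn / negatives if negatives else 0.0
        if sensitivity >= min_sensitivity and (not found or specificity > best_spec):
            best_spec, best_thr, found = specificity, thr, True
    return best_spec, best_thr


def multilabel_spec_at_sens(
    preds: Sequence[Sequence[float]], target: Sequence[Sequence[int]], num_labels: int,
    min_sensitivity: float, thresholds: Optional[int] = None,
) -> Tuple[List[float], List[float]]:
    """Solve one binary problem per label, using column c of both preds and target."""
    results = [
        binary_spec_at_sens([row[c] for row in preds], [row[c] for row in target], min_sensitivity, thresholds)
        for c in range(num_labels)
    ]
    return [spec for spec, _ in results], [thr for _, thr in results]


preds = [[0.75, 0.05, 0.35], [0.45, 0.75, 0.05], [0.05, 0.55, 0.75], [0.05, 0.65, 0.05]]
target = [[1, 0, 1], [0, 0, 0], [0, 1, 1], [1, 1, 1]]
print(multilabel_spec_at_sens(preds, target, num_labels=3, min_sensitivity=0.5))
# ([1.0, 0.5, 1.0], [0.75, 0.65, 0.35])
```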
Functional Interface¶
binary_specificity_at_sensitivity¶
- torchmetrics.functional.classification.binary_specificity_at_sensitivity(preds, target, min_sensitivity, thresholds=None, ignore_index=None, validate_args=True)[source]¶
Compute the highest possible specificity value given the minimum sensitivity level provided for binary tasks.
This is done by first calculating the Receiver Operating Characteristic (ROC) curve for different thresholds and then selecting the highest specificity among the thresholds whose sensitivity reaches the given level.
Accepts the following input tensors:
preds (float tensor): (N, ...). Preds should be a tensor containing probabilities or logits for each observation. If preds has values outside the [0,1] range we consider the input to be logits and will automatically apply sigmoid per element.
target (int tensor): (N, ...). Target should be a tensor containing ground truth labels, and therefore only contain {0,1} values (except if ignore_index is specified).
Additional dimensions ... will be flattened into the batch dimension.
The implementation supports both a non-binned but accurate version and a binned version that is less accurate but more memory efficient. Setting the thresholds argument to None will activate the non-binned version that uses memory of size \(\mathcal{O}(n_{samples})\) whereas setting the thresholds argument to either an integer, list or a 1d tensor will use a binned version that uses memory of size \(\mathcal{O}(n_{thresholds})\) (constant memory).
- Parameters:
preds (Tensor) – Tensor with predictions
target (Tensor) – Tensor with true labels
min_sensitivity (float) – float value specifying minimum sensitivity threshold.
thresholds (Union[int, List[float], Tensor, None]) – Can be one of:
If set to None, will use a non-binned approach where thresholds are dynamically calculated from all the data. Most accurate but also most memory consuming approach.
If set to an int (larger than 1), will use that number of thresholds linearly spaced from 0 to 1 as bins for the calculation.
If set to a list of floats, will use the indicated thresholds in the list as bins for the calculation.
If set to a 1d tensor of floats, will use the indicated thresholds in the tensor as bins for the calculation.
ignore_index (Optional[int]) – Specifies a target value that is ignored and does not contribute to the metric calculation
validate_args (bool) – bool indicating if input arguments and tensors should be validated for correctness. Set to False for faster computations.
- Returns:
a tuple of 2 tensors containing:
specificity: a scalar tensor with the maximum specificity for the given sensitivity level
threshold: a scalar tensor with the corresponding threshold level
- Return type:
(tuple)
Example
>>> import torch
>>> from torchmetrics.functional.classification import binary_specificity_at_sensitivity
>>> preds = torch.tensor([0, 0.5, 0.4, 0.1])
>>> target = torch.tensor([0, 1, 1, 1])
>>> binary_specificity_at_sensitivity(preds, target, min_sensitivity=0.5, thresholds=None)
(tensor(1.), tensor(0.4000))
>>> binary_specificity_at_sensitivity(preds, target, min_sensitivity=0.5, thresholds=5)
(tensor(1.), tensor(0.2500))
multiclass_specificity_at_sensitivity¶
- torchmetrics.functional.classification.multiclass_specificity_at_sensitivity(preds, target, num_classes, min_sensitivity, thresholds=None, ignore_index=None, validate_args=True)[source]¶
Compute the highest possible specificity value given the minimum sensitivity level provided for multiclass tasks.
This is done by first calculating the Receiver Operating Characteristic (ROC) curve for different thresholds and then selecting the highest specificity among the thresholds whose sensitivity reaches the given level.
Accepts the following input tensors:
preds (float tensor): (N, C, ...). Preds should be a tensor containing probabilities or logits for each observation. If preds has values outside the [0,1] range we consider the input to be logits and will automatically apply softmax per sample.
target (int tensor): (N, ...). Target should be a tensor containing ground truth labels, and therefore only contain values in the [0, n_classes-1] range (except if ignore_index is specified).
Additional dimensions ... will be flattened into the batch dimension.
The implementation supports both a non-binned but accurate version and a binned version that is less accurate but more memory efficient. Setting the thresholds argument to None will activate the non-binned version that uses memory of size \(\mathcal{O}(n_{samples})\) whereas setting the thresholds argument to either an integer, list or a 1d tensor will use a binned version that uses memory of size \(\mathcal{O}(n_{thresholds} \times n_{classes})\) (constant memory).
- Parameters:
preds (Tensor) – Tensor with predictions
target (Tensor) – Tensor with true labels
num_classes (int) – Integer specifying the number of classes
min_sensitivity (float) – float value specifying minimum sensitivity threshold.
thresholds (Union[int, List[float], Tensor, None]) – Can be one of:
If set to None, will use a non-binned approach where thresholds are dynamically calculated from all the data. Most accurate but also most memory consuming approach.
If set to an int (larger than 1), will use that number of thresholds linearly spaced from 0 to 1 as bins for the calculation.
If set to a list of floats, will use the indicated thresholds in the list as bins for the calculation.
If set to a 1d tensor of floats, will use the indicated thresholds in the tensor as bins for the calculation.
ignore_index (Optional[int]) – Specifies a target value that is ignored and does not contribute to the metric calculation
validate_args (bool) – bool indicating if input arguments and tensors should be validated for correctness. Set to False for faster computations.
- Returns:
a tuple of either 2 tensors or 2 lists containing:
specificity: a 1d tensor of size (n_classes, ) with the maximum specificity for the given sensitivity level per class
thresholds: a 1d tensor of size (n_classes, ) with the corresponding threshold level per class
- Return type:
(tuple)
Example
>>> import torch
>>> from torchmetrics.functional.classification import multiclass_specificity_at_sensitivity
>>> preds = torch.tensor([[0.75, 0.05, 0.05, 0.05, 0.05],
...                       [0.05, 0.75, 0.05, 0.05, 0.05],
...                       [0.05, 0.05, 0.75, 0.05, 0.05],
...                       [0.05, 0.05, 0.05, 0.75, 0.05]])
>>> target = torch.tensor([0, 1, 3, 2])
>>> multiclass_specificity_at_sensitivity(preds, target, num_classes=5, min_sensitivity=0.5, thresholds=None)
(tensor([1., 1., 0., 0., 0.]), tensor([7.5000e-01, 7.5000e-01, 5.0000e-02, 5.0000e-02, 1.0000e+06]))
>>> multiclass_specificity_at_sensitivity(preds, target, num_classes=5, min_sensitivity=0.5, thresholds=5)
(tensor([1., 1., 0., 0., 0.]), tensor([7.5000e-01, 7.5000e-01, 0.0000e+00, 0.0000e+00, 1.0000e+06]))
multilabel_specificity_at_sensitivity¶
- torchmetrics.functional.classification.multilabel_specificity_at_sensitivity(preds, target, num_labels, min_sensitivity, thresholds=None, ignore_index=None, validate_args=True)[source]¶
Compute the highest possible specificity value given the minimum sensitivity level provided for multilabel tasks.
This is done by first calculating the Receiver Operating Characteristic (ROC) curve for different thresholds and then selecting the highest specificity among the thresholds whose sensitivity reaches the given level.
Accepts the following input tensors:
preds (float tensor): (N, C, ...). Preds should be a tensor containing probabilities or logits for each observation. If preds has values outside the [0,1] range we consider the input to be logits and will automatically apply sigmoid per element.
target (int tensor): (N, C, ...). Target should be a tensor containing ground truth labels, and therefore only contain {0,1} values (except if ignore_index is specified).
Additional dimensions ... will be flattened into the batch dimension.
The implementation supports both a non-binned but accurate version and a binned version that is less accurate but more memory efficient. Setting the thresholds argument to None will activate the non-binned version that uses memory of size \(\mathcal{O}(n_{samples})\) whereas setting the thresholds argument to either an integer, list or a 1d tensor will use a binned version that uses memory of size \(\mathcal{O}(n_{thresholds} \times n_{labels})\) (constant memory).
- Parameters:
preds (Tensor) – Tensor with predictions
target (Tensor) – Tensor with true labels
num_labels (int) – Integer specifying the number of labels
min_sensitivity (float) – float value specifying minimum sensitivity threshold.
thresholds (Union[int, List[float], Tensor, None]) – Can be one of:
If set to None, will use a non-binned approach where thresholds are dynamically calculated from all the data. Most accurate but also most memory consuming approach.
If set to an int (larger than 1), will use that number of thresholds linearly spaced from 0 to 1 as bins for the calculation.
If set to a list of floats, will use the indicated thresholds in the list as bins for the calculation.
If set to a 1d tensor of floats, will use the indicated thresholds in the tensor as bins for the calculation.
ignore_index (Optional[int]) – Specifies a target value that is ignored and does not contribute to the metric calculation
validate_args (bool) – bool indicating if input arguments and tensors should be validated for correctness. Set to False for faster computations.
- Returns:
a tuple of either 2 tensors or 2 lists containing:
specificity: a 1d tensor of size (n_labels, ) with the maximum specificity for the given sensitivity level per label
thresholds: a 1d tensor of size (n_labels, ) with the corresponding threshold level per label
- Return type:
(tuple)
Example
>>> import torch
>>> from torchmetrics.functional.classification import multilabel_specificity_at_sensitivity
>>> preds = torch.tensor([[0.75, 0.05, 0.35],
...                       [0.45, 0.75, 0.05],
...                       [0.05, 0.55, 0.75],
...                       [0.05, 0.65, 0.05]])
>>> target = torch.tensor([[1, 0, 1],
...                        [0, 0, 0],
...                        [0, 1, 1],
...                        [1, 1, 1]])
>>> multilabel_specificity_at_sensitivity(preds, target, num_labels=3, min_sensitivity=0.5, thresholds=None)
(tensor([1.0000, 0.5000, 1.0000]), tensor([0.7500, 0.6500, 0.3500]))
>>> multilabel_specificity_at_sensitivity(preds, target, num_labels=3, min_sensitivity=0.5, thresholds=5)
(tensor([1.0000, 0.5000, 1.0000]), tensor([0.7500, 0.5000, 0.2500]))
Stat Scores¶
Module Interface¶
- class torchmetrics.StatScores(**kwargs)[source]¶
Compute the number of true positives, false positives, true negatives, false negatives and the support.
This function is a simple wrapper to get the task specific versions of this metric, which is done by setting the task argument to either 'binary', 'multiclass' or 'multilabel'. See the documentation of BinaryStatScores, MulticlassStatScores and MultilabelStatScores for the specific details of how each argument influences the metric and for examples.
- Legacy Example:
>>> from torch import tensor
>>> from torchmetrics import StatScores
>>> preds = tensor([1, 0, 2, 1])
>>> target = tensor([1, 1, 2, 0])
>>> stat_scores = StatScores(task="multiclass", num_classes=3, average='micro')
>>> stat_scores(preds, target)
tensor([2, 2, 6, 2, 4])
>>> stat_scores = StatScores(task="multiclass", num_classes=3, average=None)
>>> stat_scores(preds, target)
tensor([[0, 1, 2, 1, 1], [1, 1, 1, 1, 2], [1, 0, 3, 0, 1]])
BinaryStatScores¶
- class torchmetrics.classification.BinaryStatScores(threshold=0.5, multidim_average='global', ignore_index=None, validate_args=True, **kwargs)[source]¶
Compute true positives, false positives, true negatives, false negatives and the support for binary tasks.
Related to Type I and Type II errors.
As input to forward and update the metric accepts the following input:
preds (Tensor): An int or float tensor of shape (N, ...). If preds is a floating point tensor with values outside the [0,1] range we consider the input to be logits and will automatically apply sigmoid per element. Additionally, we convert it to an int tensor by thresholding with the value in threshold.
target (Tensor): An int tensor of shape (N, ...)
As output to forward and compute the metric returns the following output:
bss (Tensor): A tensor of shape (..., 5), where the last dimension corresponds to [tp, fp, tn, fn, sup] (sup stands for support and equals tp + fn). The shape depends on the multidim_average parameter:
If multidim_average is set to global, the shape will be (5,)
If multidim_average is set to samplewise, the shape will be (N, 5)
- Parameters:
threshold (float) – Threshold for transforming probability to binary {0,1} predictions
multidim_average (Literal['global', 'samplewise']) – Defines how additional dimensions ... should be handled. Should be one of the following:
global: Additional dimensions are flattened along the batch dimension
samplewise: Statistic will be calculated independently for each sample on the N axis. The statistics in this case are calculated over the additional dimensions.
ignore_index (Optional[int]) – Specifies a target value that is ignored and does not contribute to the metric calculation
validate_args (bool) – bool indicating if input arguments and tensors should be validated for correctness. Set to False for faster computations.
kwargs (Any) – Additional keyword arguments, see Advanced metric settings for more info.
- Example (preds is int tensor):
>>> from torch import tensor
>>> from torchmetrics.classification import BinaryStatScores
>>> target = tensor([0, 1, 0, 1, 0, 1])
>>> preds = tensor([0, 0, 1, 1, 0, 1])
>>> metric = BinaryStatScores()
>>> metric(preds, target)
tensor([2, 1, 2, 1, 3])
- Example (preds is float tensor):
>>> from torch import tensor
>>> from torchmetrics.classification import BinaryStatScores
>>> target = tensor([0, 1, 0, 1, 0, 1])
>>> preds = tensor([0.11, 0.22, 0.84, 0.73, 0.33, 0.92])
>>> metric = BinaryStatScores()
>>> metric(preds, target)
tensor([2, 1, 2, 1, 3])
- Example (multidim tensors):
>>> from torch import tensor
>>> from torchmetrics.classification import BinaryStatScores
>>> target = tensor([[[0, 1], [1, 0], [0, 1]], [[1, 1], [0, 0], [1, 0]]])
>>> preds = tensor([[[0.59, 0.91], [0.91, 0.99], [0.63, 0.04]],
...                 [[0.38, 0.04], [0.86, 0.780], [0.45, 0.37]]])
>>> metric = BinaryStatScores(multidim_average='samplewise')
>>> metric(preds, target)
tensor([[2, 3, 0, 1, 3], [0, 2, 1, 3, 3]])
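The counting behind `BinaryStatScores` can be sketched in plain Python to show how `[tp, fp, tn, fn, sup]` and the two `multidim_average` modes relate. This is an illustration, not the library code: nested lists stand in for tensors, the logit check looks at all predictions at once, a score strictly above `threshold` is treated as positive (an assumption of this sketch), and `ignore_index` is not handled. The names `_flatten` and `binary_stat_scores_sketch` are hypothetical.

```python
import math
from typing import Any, List


def _flatten(x: Any) -> List[Any]:
    """Flatten arbitrarily nested lists into one flat list."""
    if isinstance(x, (list, tuple)):
        return [v for item in x for v in _flatten(item)]
    return [x]


def binary_stat_scores_sketch(preds: Any, target: Any, threshold: float = 0.5,
                              multidim_average: str = "global") -> List[Any]:
    """Return [tp, fp, tn, fn, sup] globally or one such row per sample."""
    # Decide once, over all predictions, whether the inputs are logits.
    is_logits = any(p < 0 or p > 1 for p in _flatten(preds))

    def count(p_vals: List[float], t_vals: List[int]) -> List[int]:
        tp = fp = tn = fn = 0
        for p, t in zip(p_vals, t_vals):
            if is_logits:
                p = 1 / (1 + math.exp(-p))
            positive = p > threshold
            if positive and t == 1:
                tp += 1
            elif positive:
                fp += 1
            elif t == 0:
                tn += 1
            else:
                fn += 1
        return [tp, fp, tn, fn, tp + fn]  # sup = tp + fn

    if multidim_average == "global":
        # All extra dimensions are flattened into one pool of samples.
        return count(_flatten(preds), _flatten(target))
    if multidim_average == "samplewise":
        # Each sample along the first axis is counted on its own.
        return [count(_flatten(p), _flatten(t)) for p, t in zip(preds, target)]
    raise ValueError(f"unknown multidim_average={multidim_average!r}")


print(binary_stat_scores_sketch([0.11, 0.22, 0.84, 0.73, 0.33, 0.92], [0, 1, 0, 1, 0, 1]))  # [2, 1, 2, 1, 3]
```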
MulticlassStatScores¶
- class torchmetrics.classification.MulticlassStatScores(num_classes, top_k=1, average='macro', multidim_average='global', ignore_index=None, validate_args=True, **kwargs)[source]¶
Computes true positives, false positives, true negatives, false negatives and the support for multiclass tasks.
Related to Type I and Type II errors.
As input to forward and update the metric accepts the following input:
preds (Tensor): An int tensor of shape (N, ...) or float tensor of shape (N, C, ...). If preds is a floating point tensor, we apply torch.argmax along the C dimension to automatically convert probabilities/logits into an int tensor.
target (Tensor): An int tensor of shape (N, ...)
As output to forward and compute the metric returns the following output:
mcss (Tensor): A tensor of shape (..., 5), where the last dimension corresponds to [tp, fp, tn, fn, sup] (sup stands for support and equals tp + fn). The shape depends on the average and multidim_average parameters:
If multidim_average is set to global:
If average='micro'/'macro'/'weighted', the shape will be (5,)
If average=None/'none', the shape will be (C, 5)
If multidim_average is set to samplewise:
If average='micro'/'macro'/'weighted', the shape will be (N, 5)
If average=None/'none', the shape will be (N, C, 5)
- Parameters:
num_classes (int) – Integer specifying the number of classes
average (Optional[Literal['micro', 'macro', 'weighted', 'none']]) – Defines the reduction that is applied over labels. Should be one of the following:
micro: Sum statistics over all labels
macro: Calculate statistics for each label and average them
weighted: Calculate statistics for each label and compute a weighted average using their support
"none" or None: Calculate statistics for each label and apply no reduction
top_k (int) – Number of highest probability or logit score predictions considered to find the correct label. Only works when preds contain probabilities/logits.
multidim_average (Literal['global', 'samplewise']) – Defines how additional dimensions ... should be handled. Should be one of the following:
global: Additional dimensions are flattened along the batch dimension
samplewise: Statistic will be calculated independently for each sample on the N axis. The statistics in this case are calculated over the additional dimensions.
ignore_index (Optional[int]) – Specifies a target value that is ignored and does not contribute to the metric calculation
validate_args (bool) – bool indicating if input arguments and tensors should be validated for correctness. Set to False for faster computations.
kwargs (Any) – Additional keyword arguments, see Advanced metric settings for more info.
- Example (preds is int tensor):
>>> from torch import tensor
>>> from torchmetrics.classification import MulticlassStatScores
>>> target = tensor([2, 1, 0, 0])
>>> preds = tensor([2, 1, 0, 1])
>>> metric = MulticlassStatScores(num_classes=3, average='micro')
>>> metric(preds, target)
tensor([3, 1, 7, 1, 4])
>>> mcss = MulticlassStatScores(num_classes=3, average=None)
>>> mcss(preds, target)
tensor([[1, 0, 2, 1, 2], [1, 1, 2, 0, 1], [1, 0, 3, 0, 1]])
- Example (preds is float tensor):
>>> from torch import tensor
>>> from torchmetrics.classification import MulticlassStatScores
>>> target = tensor([2, 1, 0, 0])
>>> preds = tensor([[0.16, 0.26, 0.58],
...                 [0.22, 0.61, 0.17],
...                 [0.71, 0.09, 0.20],
...                 [0.05, 0.82, 0.13]])
>>> metric = MulticlassStatScores(num_classes=3, average='micro')
>>> metric(preds, target)
tensor([3, 1, 7, 1, 4])
>>> mcss = MulticlassStatScores(num_classes=3, average=None)
>>> mcss(preds, target)
tensor([[1, 0, 2, 1, 2], [1, 1, 2, 0, 1], [1, 0, 3, 0, 1]])
- Example (multidim tensors):
>>> from torch import tensor
>>> from torchmetrics.classification import MulticlassStatScores
>>> target = tensor([[[0, 1], [2, 1], [0, 2]], [[1, 1], [2, 0], [1, 2]]])
>>> preds = tensor([[[0, 2], [2, 0], [0, 1]], [[2, 2], [2, 1], [1, 0]]])
>>> metric = MulticlassStatScores(num_classes=3, multidim_average="samplewise", average='micro')
>>> metric(preds, target)
tensor([[3, 3, 9, 3, 6], [2, 4, 8, 4, 6]])
>>> mcss = MulticlassStatScores(num_classes=3, multidim_average="samplewise", average=None)
>>> mcss(preds, target)
tensor([[[2, 1, 3, 0, 2], [0, 1, 3, 2, 2], [1, 1, 3, 1, 2]], [[0, 1, 4, 1, 1], [1, 1, 2, 2, 3], [1, 2, 2, 1, 2]]])
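A plain-Python sketch of the per-class counting may help relate `average=None` to `average='micro'`: each class gets its own `[tp, fp, tn, fn, sup]` row, and micro averaging simply sums those rows column by column. This is illustrative only: it covers the global case for `(N,)` labels or `(N, C)` scores (reduced with an argmax, first maximum wins), and leaves out `top_k`, `ignore_index`, `'macro'`/`'weighted'` and `samplewise`. The function name is hypothetical.

```python
from typing import List, Optional, Sequence, Union


def multiclass_stat_scores_sketch(
    preds: Sequence, target: Sequence[int], num_classes: int, average: Optional[str] = "micro"
) -> Union[List[int], List[List[int]]]:
    """[tp, fp, tn, fn, sup] per class (average=None) or summed over classes (average='micro')."""
    # (N, C) float scores become class indices via an argmax over C; int labels pass through.
    labels = [list(row).index(max(row)) if isinstance(row, (list, tuple)) else row for row in preds]
    per_class = []
    for c in range(num_classes):
        tp = sum(1 for p, t in zip(labels, target) if p == c and t == c)
        fp = sum(1 for p, t in zip(labels, target) if p == c and t != c)
        fn = sum(1 for p, t in zip(labels, target) if p != c and t == c)
        tn = len(target) - tp - fp - fn
        per_class.append([tp, fp, tn, fn, tp + fn])  # sup = tp + fn
    if average is None or average == "none":
        return per_class
    if average == "micro":
        # Micro averaging adds up the counts of all classes.
        return [sum(column) for column in zip(*per_class)]
    raise ValueError(f"average={average!r} is not covered by this sketch")


print(multiclass_stat_scores_sketch([2, 1, 0, 1], [2, 1, 0, 0], num_classes=3))  # [3, 1, 7, 1, 4]
```

Because every sample is a true negative for each class it neither predicts nor belongs to, the micro `tn` (7 here) grows with the number of classes, which is worth keeping in mind when reading micro-averaged results.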
MultilabelStatScores¶
- class torchmetrics.classification.MultilabelStatScores(num_labels, threshold=0.5, average='macro', multidim_average='global', ignore_index=None, validate_args=True, **kwargs)[source]¶
Compute true positives, false positives, true negatives, false negatives and the support for multilabel tasks.
Related to Type I and Type II errors.
As input to forward and update the metric accepts the following input:
preds (Tensor): An int or float tensor of shape (N, C, ...). If preds is a floating point tensor with values outside the [0,1] range we consider the input to be logits and will automatically apply sigmoid per element. Additionally, we convert it to an int tensor by thresholding with the value in threshold.
target (Tensor): An int tensor of shape (N, C, ...)
As output to forward and compute the metric returns the following output:
mlss (Tensor): A tensor of shape (..., 5), where the last dimension corresponds to [tp, fp, tn, fn, sup] (sup stands for support and equals tp + fn). The shape depends on the average and multidim_average parameters:
If multidim_average is set to global:
If average='micro'/'macro'/'weighted', the shape will be (5,)
If average=None/'none', the shape will be (C, 5)
If multidim_average is set to samplewise:
If average='micro'/'macro'/'weighted', the shape will be (N, 5)
If average=None/'none', the shape will be (N, C, 5)
- Parameters:
num_labels (int) – Integer specifying the number of labels
threshold (float) – Threshold for transforming probability to binary (0,1) predictions
average (Optional[Literal['micro', 'macro', 'weighted', 'none']]) – Defines the reduction that is applied over labels. Should be one of the following:
micro: Sum statistics over all labels
macro: Calculate statistics for each label and average them
weighted: Calculate statistics for each label and compute a weighted average using their support
"none" or None: Calculate statistics for each label and apply no reduction
multidim_average (Literal['global', 'samplewise']) – Defines how additional dimensions ... should be handled. Should be one of the following:
global: Additional dimensions are flattened along the batch dimension
samplewise: Statistics are calculated independently for each sample on the N axis. The statistics in this case are calculated over the additional dimensions.
ignore_index (Optional[int]) – Specifies a target value that is ignored and does not contribute to the metric calculation
validate_args (bool) – Indicates whether input arguments and tensors should be validated for correctness. Set to False for faster computations.
kwargs (Any) – Additional keyword arguments, see Advanced metric settings for more info.
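The four average modes can be read as different reductions over the per-label [tp, fp, tn, fn, sup] rows. The plain-Python sketch below spells them out under the reading given above (macro as an unweighted mean, weighted as a support-weighted mean). reduce_stat_scores is a hypothetical helper, and the zero-support fallback is an assumption rather than documented behavior.

```python
def reduce_stat_scores(rows, average="micro"):
    """Reduce per-label [tp, fp, tn, fn, sup] rows according to `average`."""
    if average is None or average == "none":
        return rows  # no reduction: one row per label
    columns = list(zip(*rows))  # columns[k] holds statistic k for every label
    if average == "micro":
        return [sum(col) for col in columns]
    if average == "macro":
        return [sum(col) / len(rows) for col in columns]
    if average == "weighted":
        support = columns[4]
        total = sum(support)
        if total == 0:
            return [0.0] * len(columns)  # assumption: fallback when no label has support
        return [sum(w * v for w, v in zip(support, col)) / total for col in columns]
    raise ValueError(f"unknown average: {average!r}")


rows = [[1, 0, 1, 0, 1], [0, 0, 1, 1, 1], [1, 1, 0, 0, 1]]
print(reduce_stat_scores(rows, "micro"))  # [2, 1, 2, 1, 3]
print(reduce_stat_scores(rows, "macro"))  # element-wise mean of the three rows
```

With equal supports, as here, macro and weighted coincide; they only diverge when some labels occur more often than others.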
- Example (preds is int tensor):
>>> from torch import tensor
>>> from torchmetrics.classification import MultilabelStatScores
>>> target = tensor([[0, 1, 0], [1, 0, 1]])
>>> preds = tensor([[0, 0, 1], [1, 0, 1]])
>>> metric = MultilabelStatScores(num_labels=3, average='micro')
>>> metric(preds, target)
tensor([2, 1, 2, 1, 3])
>>> mlss = MultilabelStatScores(num_labels=3, average=None)
>>> mlss(preds, target)
tensor([[1, 0, 1, 0, 1], [0, 0, 1, 1, 1], [1, 1, 0, 0, 1]])
- Example (preds is float tensor):
>>> from torch import tensor
>>> from torchmetrics.classification import MultilabelStatScores
>>> target = tensor([[0, 1, 0], [1, 0, 1]])
>>> preds = tensor([[0.11, 0.22, 0.84], [0.73, 0.33, 0.92]])
>>> metric = MultilabelStatScores(num_labels=3, average='micro')
>>> metric(preds, target)
tensor([2, 1, 2, 1, 3])
>>> mlss = MultilabelStatScores(num_labels=3, average=None)
>>> mlss(preds, target)
tensor([[1, 0, 1, 0, 1], [0, 0, 1, 1, 1], [1, 1, 0, 0, 1]])
- Example (multidim tensors):
>>> from torch import tensor
>>> from torchmetrics.classification import MultilabelStatScores
>>> target = tensor([[[0, 1], [1, 0], [0, 1]], [[1, 1], [0, 0], [1, 0]]])
>>> preds = tensor([[[0.59, 0.91], [0.91, 0.99], [0.63, 0.04]],
...                 [[0.38, 0.04], [0.86, 0.780], [0.45, 0.37]]])
>>> metric = MultilabelStatScores(num_labels=3, multidim_average='samplewise', average='micro')
>>> metric(preds, target)
tensor([[2, 3, 0, 1, 3], [0, 2, 1, 3, 3]])
>>> mlss = MultilabelStatScores(num_labels=3, multidim_average='samplewise', average=None)
>>> mlss(preds, target)
tensor([[[1, 1, 0, 0, 1], [1, 1, 0, 0, 1], [0, 1, 0, 1, 1]], [[0, 0, 0, 2, 2], [0, 2, 0, 0, 0], [0, 0, 1, 1, 1]]])
Functional Interface
stat_scores
- torchmetrics.functional.stat_scores(preds, target, task, threshold=0.5, num_classes=None, num_labels=None, average='micro', multidim_average='global', top_k=1, ignore_index=None, validate_args=True)
Compute the number of true positives, false positives, true negatives, false negatives and the support.
This function is a simple wrapper to get the task-specific versions of this metric, which is done by setting the task argument to either 'binary', 'multiclass' or 'multilabel'. See the documentation of binary_stat_scores(), multiclass_stat_scores() and multilabel_stat_scores() for the specific details of how each argument influences the result, and for examples.
- Return type:
Tensor
- Legacy Example:
>>> from torch import tensor
>>> from torchmetrics.functional import stat_scores
>>> preds = tensor([1, 0, 2, 1])
>>> target = tensor([1, 1, 2, 0])
>>> stat_scores(preds, target, task='multiclass', num_classes=3, average='micro')
tensor([2, 2, 6, 2, 4])
>>> stat_scores(preds, target, task='multiclass', num_classes=3, average=None)
tensor([[0, 1, 2, 1, 1], [1, 1, 1, 1, 2], [1, 0, 3, 0, 1]])
binary_stat_scores
- torchmetrics.functional.classification.binary_stat_scores(preds, target, threshold=0.5, multidim_average='global', ignore_index=None, validate_args=True)
Compute the true positives, false positives, true negatives, false negatives and support for binary tasks.
Related to Type I and Type II errors.
Accepts the following input tensors:
preds (int or float tensor): (N, ...). If preds is a floating point tensor with values outside the [0,1] range, we consider the input to be logits and automatically apply sigmoid per element. Additionally, we convert it to an int tensor by thresholding with the value in threshold.
target (int tensor): (N, ...)
- Parameters:
preds (Tensor) – Tensor with predictions
target (Tensor) – Tensor with true labels
threshold (float) – Threshold for transforming probability to binary {0,1} predictions
multidim_average (Literal['global', 'samplewise']) – Defines how additional dimensions ... should be handled. Should be one of the following:
global: Additional dimensions are flattened along the batch dimension
samplewise: Statistics are calculated independently for each sample on the N axis. The statistics in this case are calculated over the additional dimensions.
ignore_index (Optional[int]) – Specifies a target value that is ignored and does not contribute to the metric calculation
validate_args (bool) – Indicates whether input arguments and tensors should be validated for correctness. Set to False for faster computations.
- Return type:
Tensor
- Returns:
The metric returns a tensor of shape (..., 5), where the last dimension corresponds to [tp, fp, tn, fn, sup] (sup stands for support and equals tp + fn). The shape depends on the multidim_average parameter:
If multidim_average is set to global, the shape will be (5,)
If multidim_average is set to samplewise, the shape will be (N, 5)
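The difference between the two multidim_average modes is easiest to see by counting by hand. This plain-Python sketch pools every element into one row for global and produces one row per sample for samplewise, which mirrors the (5,) versus (N, 5) shapes above. The helpers _flatten, _counts and binary_stat_scores_sketch are hypothetical, and the logit/sigmoid step and ignore_index are deliberately omitted (inputs are assumed to be probabilities in [0, 1]).

```python
def _flatten(x):
    """Recursively flatten nested lists into a flat list of scalars."""
    if isinstance(x, (list, tuple)):
        return [v for item in x for v in _flatten(item)]
    return [x]


def _counts(hard, truth):
    """Return [tp, fp, tn, fn, sup] for two equal-length 0/1 sequences."""
    tp = sum(1 for p, t in zip(hard, truth) if p == 1 and t == 1)
    fp = sum(1 for p, t in zip(hard, truth) if p == 1 and t == 0)
    tn = sum(1 for p, t in zip(hard, truth) if p == 0 and t == 0)
    fn = sum(1 for p, t in zip(hard, truth) if p == 0 and t == 1)
    return [tp, fp, tn, fn, tp + fn]


def binary_stat_scores_sketch(preds, target, threshold=0.5, multidim_average="global"):
    """preds/target are lists of N samples; each sample may be nested (extra dims)."""
    hard = [[int(v > threshold) for v in _flatten(s)] for s in preds]
    truth = [_flatten(s) for s in target]
    if multidim_average == "global":
        # Pool the batch dimension and all extra dimensions into one set of elements.
        return _counts([v for s in hard for v in s], [v for s in truth for v in s])
    if multidim_average == "samplewise":
        # One row per sample, counted over that sample's extra dimensions.
        return [_counts(h, t) for h, t in zip(hard, truth)]
    raise ValueError(f"unknown multidim_average: {multidim_average!r}")


target = [[[0, 1], [1, 0], [0, 1]], [[1, 1], [0, 0], [1, 0]]]
preds = [[[0.59, 0.91], [0.91, 0.99], [0.63, 0.04]], [[0.38, 0.04], [0.86, 0.780], [0.45, 0.37]]]
print(binary_stat_scores_sketch(preds, target, multidim_average="samplewise"))
# [[2, 3, 0, 1, 3], [0, 2, 1, 3, 3]]
print(binary_stat_scores_sketch(preds, target))  # [2, 5, 1, 4, 6]
```

The global row is simply the column-wise sum of the samplewise rows, because both modes count the same elements and only differ in where they stop summing.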
- Example (preds is int tensor):
>>> from torch import tensor
>>> from torchmetrics.functional.classification import binary_stat_scores
>>> target = tensor([0, 1, 0, 1, 0, 1])
>>> preds = tensor([0, 0, 1, 1, 0, 1])
>>> binary_stat_scores(preds, target)
tensor([2, 1, 2, 1, 3])
- Example (preds is float tensor):
>>> from torch import tensor
>>> from torchmetrics.functional.classification import binary_stat_scores
>>> target = tensor([0, 1, 0, 1, 0, 1])
>>> preds = tensor([0.11, 0.22, 0.84, 0.73, 0.33, 0.92])
>>> binary_stat_scores(preds, target)
tensor([2, 1, 2, 1, 3])
- Example (multidim tensors):
>>> from torch import tensor
>>> from torchmetrics.functional.classification import binary_stat_scores
>>> target = tensor([[[0, 1], [1, 0], [0, 1]], [[1, 1], [0, 0], [1, 0]]])
>>> preds = tensor([[[0.59, 0.91], [0.91, 0.99], [0.63, 0.04]],
...                 [[0.38, 0.04], [0.86, 0.780], [0.45, 0.37]]])
>>> binary_stat_scores(preds, target, multidim_average='samplewise')
tensor([[2, 3, 0, 1, 3], [0, 2, 1, 3, 3]])
multiclass_stat_scores
- torchmetrics.functional.classification.multiclass_stat_scores(preds, target, num_classes, average='macro', top_k=1, multidim_average='global', ignore_index=None, validate_args=True)
Compute the true positives, false positives, true negatives, false negatives and support for multiclass tasks.
Related to Type I and Type II errors.
Accepts the following input tensors:
preds: (N, ...) (int tensor) or (N, C, ...) (float tensor). If preds is a floating point tensor, we apply torch.argmax along the C dimension to automatically convert probabilities/logits into an int tensor.
target (int tensor): (N, ...)
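The multiclass counts are one-vs-rest: for each class c, a prediction is "positive" when it equals c. The plain-Python sketch below shows this for 1-D targets, including the argmax step for score rows of shape (N, C). multiclass_stat_scores_sketch is a hypothetical helper; top_k, extra dimensions and ignore_index are not modelled.

```python
def multiclass_stat_scores_sketch(preds, target, num_classes):
    """Per-class [tp, fp, tn, fn, sup] rows for 1-D targets (like average=None).

    preds is either a list of class indices or a list of per-class score rows
    of shape (N, C); score rows are reduced to an index with an argmax over C.
    """
    labels = [
        max(range(len(p)), key=p.__getitem__) if isinstance(p, (list, tuple)) else p
        for p in preds
    ]
    rows = []
    for c in range(num_classes):
        # One-vs-rest: class c is "positive", every other class is "negative".
        tp = sum(1 for p, t in zip(labels, target) if p == c and t == c)
        fp = sum(1 for p, t in zip(labels, target) if p == c and t != c)
        fn = sum(1 for p, t in zip(labels, target) if p != c and t == c)
        tn = len(target) - tp - fp - fn
        rows.append([tp, fp, tn, fn, tp + fn])
    return rows


probs = [[0.16, 0.26, 0.58], [0.22, 0.61, 0.17], [0.71, 0.09, 0.20], [0.05, 0.82, 0.13]]
rows = multiclass_stat_scores_sketch(probs, [2, 1, 0, 0], num_classes=3)
print(rows)                            # [[1, 0, 2, 1, 2], [1, 1, 2, 0, 1], [1, 0, 3, 0, 1]]
print([sum(c) for c in zip(*rows)])    # [3, 1, 7, 1, 4], i.e. average='micro'
```

Because every sample is counted once per class, the micro-averaged tn (7 here) can exceed the number of samples.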
- Parameters:
preds (Tensor) – Tensor with predictions
target (Tensor) – Tensor with true labels
num_classes (int) – Integer specifying the number of classes
average (Optional[Literal['micro', 'macro', 'weighted', 'none']]) – Defines the reduction that is applied over labels. Should be one of the following:
micro: Sum statistics over all labels
macro: Calculate statistics for each label and average them
weighted: Calculate statistics for each label and compute a weighted average using their support
"none" or None: Calculate statistics for each label and apply no reduction
top_k (int) – Number of highest probability or logit score predictions considered to find the correct label. Only works when preds contain probabilities/logits.
multidim_average (Literal['global', 'samplewise']) – Defines how additional dimensions ... should be handled. Should be one of the following:
global: Additional dimensions are flattened along the batch dimension
samplewise: Statistics are calculated independently for each sample on the N axis. The statistics in this case are calculated over the additional dimensions.
ignore_index (Optional[int]) – Specifies a target value that is ignored and does not contribute to the metric calculation
validate_args (bool) – Indicates whether input arguments and tensors should be validated for correctness. Set to False for faster computations.
- Return type:
Tensor
- Returns:
The metric returns a tensor of shape (..., 5), where the last dimension corresponds to [tp, fp, tn, fn, sup] (sup stands for support and equals tp + fn). The shape depends on the average and multidim_average parameters:
If multidim_average is set to global:
If average='micro'/'macro'/'weighted', the shape will be (5,)
If average=None/'none', the shape will be (C, 5)
If multidim_average is set to samplewise:
If average='micro'/'macro'/'weighted', the shape will be (N, 5)
If average=None/'none', the shape will be (N, C, 5)
- Example (preds is int tensor):
>>> from torch import tensor
>>> from torchmetrics.functional.classification import multiclass_stat_scores
>>> target = tensor([2, 1, 0, 0])
>>> preds = tensor([2, 1, 0, 1])
>>> multiclass_stat_scores(preds, target, num_classes=3, average='micro')
tensor([3, 1, 7, 1, 4])
>>> multiclass_stat_scores(preds, target, num_classes=3, average=None)
tensor([[1, 0, 2, 1, 2], [1, 1, 2, 0, 1], [1, 0, 3, 0, 1]])
- Example (preds is float tensor):
>>> from torch import tensor
>>> from torchmetrics.functional.classification import multiclass_stat_scores
>>> target = tensor([2, 1, 0, 0])
>>> preds = tensor([[0.16, 0.26, 0.58],
...                 [0.22, 0.61, 0.17],
...                 [0.71, 0.09, 0.20],
...                 [0.05, 0.82, 0.13]])
>>> multiclass_stat_scores(preds, target, num_classes=3, average='micro')
tensor([3, 1, 7, 1, 4])
>>> multiclass_stat_scores(preds, target, num_classes=3, average=None)
tensor([[1, 0, 2, 1, 2], [1, 1, 2, 0, 1], [1, 0, 3, 0, 1]])
- Example (multidim tensors):
>>> from torch import tensor
>>> from torchmetrics.functional.classification import multiclass_stat_scores
>>> target = tensor([[[0, 1], [2, 1], [0, 2]], [[1, 1], [2, 0], [1, 2]]])
>>> preds = tensor([[[0, 2], [2, 0], [0, 1]], [[2, 2], [2, 1], [1, 0]]])
>>> multiclass_stat_scores(preds, target, num_classes=3, multidim_average='samplewise', average='micro')
tensor([[3, 3, 9, 3, 6], [2, 4, 8, 4, 6]])
>>> multiclass_stat_scores(preds, target, num_classes=3, multidim_average='samplewise', average=None)
tensor([[[2, 1, 3, 0, 2], [0, 1, 3, 2, 2], [1, 1, 3, 1, 2]], [[0, 1, 4, 1, 1], [1, 1, 2, 2, 3], [1, 2, 2, 1, 2]]])
multilabel_stat_scores
- torchmetrics.functional.classification.multilabel_stat_scores(preds, target, num_labels, threshold=0.5, average='macro', multidim_average='global', ignore_index=None, validate_args=True)
Compute the true positives, false positives, true negatives, false negatives and support for multilabel tasks.
Related to Type I and Type II errors.
Accepts the following input tensors:
preds (int or float tensor): (N, C, ...). If preds is a floating point tensor with values outside the [0,1] range, we consider the input to be logits and automatically apply sigmoid per element. Additionally, we convert it to an int tensor by thresholding with the value in threshold.
target (int tensor): (N, C, ...)
- Parameters:
preds (Tensor) – Tensor with predictions
target (Tensor) – Tensor with true labels
num_labels (int) – Integer specifying the number of labels
threshold (float) – Threshold for transforming probability to binary (0,1) predictions
average (Optional[Literal['micro', 'macro', 'weighted', 'none']]) – Defines the reduction that is applied over labels. Should be one of the following:
micro: Sum statistics over all labels
macro: Calculate statistics for each label and average them
weighted: Calculate statistics for each label and compute a weighted average using their support
"none" or None: Calculate statistics for each label and apply no reduction
multidim_average (Literal['global', 'samplewise']) – Defines how additional dimensions ... should be handled. Should be one of the following:
global: Additional dimensions are flattened along the batch dimension
samplewise: Statistics are calculated independently for each sample on the N axis. The statistics in this case are calculated over the additional dimensions.
ignore_index (Optional[int]) – Specifies a target value that is ignored and does not contribute to the metric calculation
validate_args (bool) – Indicates whether input arguments and tensors should be validated for correctness. Set to False for faster computations.
- Return type:
Tensor
- Returns:
The metric returns a tensor of shape (..., 5), where the last dimension corresponds to [tp, fp, tn, fn, sup] (sup stands for support and equals tp + fn). The shape depends on the average and multidim_average parameters:
If multidim_average is set to global:
If average='micro'/'macro'/'weighted', the shape will be (5,)
If average=None/'none', the shape will be (C, 5)
If multidim_average is set to samplewise:
If average='micro'/'macro'/'weighted', the shape will be (N, 5)
If average=None/'none', the shape will be (N, C, 5)
- Example (preds is int tensor):
>>> from torch import tensor
>>> from torchmetrics.functional.classification import multilabel_stat_scores
>>> target = tensor([[0, 1, 0], [1, 0, 1]])
>>> preds = tensor([[0, 0, 1], [1, 0, 1]])
>>> multilabel_stat_scores(preds, target, num_labels=3, average='micro')
tensor([2, 1, 2, 1, 3])
>>> multilabel_stat_scores(preds, target, num_labels=3, average=None)
tensor([[1, 0, 1, 0, 1], [0, 0, 1, 1, 1], [1, 1, 0, 0, 1]])
- Example (preds is float tensor):
>>> from torch import tensor
>>> from torchmetrics.functional.classification import multilabel_stat_scores
>>> target = tensor([[0, 1, 0], [1, 0, 1]])
>>> preds = tensor([[0.11, 0.22, 0.84], [0.73, 0.33, 0.92]])
>>> multilabel_stat_scores(preds, target, num_labels=3, average='micro')
tensor([2, 1, 2, 1, 3])
>>> multilabel_stat_scores(preds, target, num_labels=3, average=None)
tensor([[1, 0, 1, 0, 1], [0, 0, 1, 1, 1], [1, 1, 0, 0, 1]])
- Example (multidim tensors):
>>> from torch import tensor
>>> from torchmetrics.functional.classification import multilabel_stat_scores
>>> target = tensor([[[0, 1], [1, 0], [0, 1]], [[1, 1], [0, 0], [1, 0]]])
>>> preds = tensor([[[0.59, 0.91], [0.91, 0.99], [0.63, 0.04]],
...                 [[0.38, 0.04], [0.86, 0.780], [0.45, 0.37]]])
>>> multilabel_stat_scores(preds, target, num_labels=3, multidim_average='samplewise', average='micro')
tensor([[2, 3, 0, 1, 3], [0, 2, 1, 3, 3]])
>>> multilabel_stat_scores(preds, target, num_labels=3, multidim_average='samplewise', average=None)
tensor([[[1, 1, 0, 0, 1], [1, 1, 0, 0, 1], [0, 1, 0, 1, 1]], [[0, 0, 0, 2, 2], [0, 2, 0, 0, 0], [0, 0, 1, 1, 1]]])
Complete Intersection Over Union (cIoU)
Module Interface
- class torchmetrics.detection.ciou.CompleteIntersectionOverUnion(box_format='xyxy', iou_threshold=None, class_metrics=False, respect_labels=True, **kwargs)
Computes Complete Intersection Over Union (CIoU).
As input to forward and update the metric accepts the following input:
preds (List): A list consisting of dictionaries, each containing the key-values below (each dictionary corresponds to a single image). Parameters that should be provided per dict:
boxes (Tensor): float tensor of shape (num_boxes, 4) containing num_boxes detection boxes of the format specified in the constructor. By default, this method expects (xmin, ymin, xmax, ymax) in absolute image coordinates.
labels (Tensor): integer tensor of shape (num_boxes) containing 0-indexed detection classes for the boxes.
target (List): A list consisting of dictionaries, each containing the key-values below (each dictionary corresponds to a single image). Parameters that should be provided per dict:
boxes (Tensor): float tensor of shape (num_boxes, 4) containing num_boxes ground truth boxes of the format specified in the constructor. By default, this method expects (xmin, ymin, xmax, ymax) in absolute image coordinates.
labels (Tensor): integer tensor of shape (num_boxes) containing 0-indexed ground truth classes for the boxes.
As output of forward and compute the metric returns the following output:
ciou_dict: A dictionary containing the following key-values:
- Parameters:
box_format (str) – Input format of given boxes. Supported formats are ['xyxy', 'xywh', 'cxcywh'].
iou_threshold – Optional IoU threshold for evaluation. If set to None, the threshold is ignored.
class_metrics (bool) – Option to enable per-class metrics for IoU. Has a performance impact.
respect_labels (bool) – Ignore values from boxes that do not have the same label as the ground truth box. Otherwise, IoU is computed between all pairs of boxes.
kwargs (Any) – Additional keyword arguments, see Advanced metric settings for more info.
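The three box_format values describe the same rectangle in different ways. A minimal sketch, assuming the common torchvision-style conventions (xywh as top-left corner plus width and height, cxcywh as centre plus width and height); to_xyxy is a hypothetical helper, not part of TorchMetrics. For real tensors, torchvision.ops.box_convert performs this kind of conversion.

```python
def to_xyxy(box, box_format="xyxy"):
    """Convert one box to (x1, y1, x2, y2) corner format.

    Assumed conventions: 'xywh' = top-left corner plus width/height,
    'cxcywh' = box centre plus width/height.
    """
    if box_format == "xyxy":
        return tuple(box)
    if box_format == "xywh":
        x, y, w, h = box
        return (x, y, x + w, y + h)
    if box_format == "cxcywh":
        cx, cy, w, h = box
        return (cx - w / 2, cy - h / 2, cx + w / 2, cy + h / 2)
    raise ValueError(f"unsupported box_format: {box_format!r}")


# The ground truth box used in the example below, written in all three formats.
print(to_xyxy((300.0, 100.0, 315.0, 150.0), "xyxy"))    # (300.0, 100.0, 315.0, 150.0)
print(to_xyxy((300.0, 100.0, 15.0, 50.0), "xywh"))      # (300.0, 100.0, 315.0, 150.0)
print(to_xyxy((307.5, 125.0, 15.0, 50.0), "cxcywh"))    # (300.0, 100.0, 315.0, 150.0)
```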
Example
>>> import torch
>>> from torchmetrics.detection import CompleteIntersectionOverUnion
>>> preds = [
...     {
...         "boxes": torch.tensor([[296.55, 93.96, 314.97, 152.79], [298.55, 98.96, 314.97, 151.79]]),
...         "scores": torch.tensor([0.236, 0.56]),
...         "labels": torch.tensor([4, 5]),
...     }
... ]
>>> target = [
...     {
...         "boxes": torch.tensor([[300.00, 100.00, 315.00, 150.00]]),
...         "labels": torch.tensor([5]),
...     }
... ]
>>> metric = CompleteIntersectionOverUnion()
>>> metric(preds, target)
{'ciou': tensor(0.8611)}
- Raises:
ModuleNotFoundError – If torchvision is not installed with version 0.13.0 or newer.
- plot(val=None, ax=None)
Plot a single or multiple values from the metric.
- Parameters:
val (Union[Tensor, Sequence[Tensor], None]) – Either a single result from calling metric.forward or metric.compute or a list of these results. If no value is provided, metric.compute is called automatically and that result is plotted.
ax (Optional[Axes]) – A matplotlib axis object. If provided, the plot is added to that axis.
- Return type:
- Returns:
Figure object and Axes object
- Raises:
ModuleNotFoundError – If matplotlib is not installed
>>> # Example plotting single value
>>> import torch
>>> from torchmetrics.detection import CompleteIntersectionOverUnion
>>> preds = [
...     {
...         "boxes": torch.tensor([[296.55, 93.96, 314.97, 152.79], [298.55, 98.96, 314.97, 151.79]]),
...         "scores": torch.tensor([0.236, 0.56]),
...         "labels": torch.tensor([4, 5]),
...     }
... ]
>>> target = [
...     {
...         "boxes": torch.tensor([[300.00, 100.00, 315.00, 150.00]]),
...         "labels": torch.tensor([5]),
...     }
... ]
>>> metric = CompleteIntersectionOverUnion()
>>> metric.update(preds, target)
>>> fig_, ax_ = metric.plot()
>>> # Example plotting multiple values
>>> import torch
>>> from torchmetrics.detection import CompleteIntersectionOverUnion
>>> preds = [
...     {
...         "boxes": torch.tensor([[296.55, 93.96, 314.97, 152.79], [298.55, 98.96, 314.97, 151.79]]),
...         "scores": torch.tensor([0.236, 0.56]),
...         "labels": torch.tensor([4, 5]),
...     }
... ]
>>> target = lambda: [
...     {
...         "boxes": torch.tensor([[300.00, 100.00, 315.00, 150.00]]) + torch.randint(-10, 10, (1, 4)),
...         "labels": torch.tensor([5]),
...     }
... ]
>>> metric = CompleteIntersectionOverUnion()
>>> vals = []
>>> for _ in range(20):
...     vals.append(metric(preds, target()))
>>> fig_, ax_ = metric.plot(vals)
Functional Interface
- torchmetrics.functional.detection.ciou.complete_intersection_over_union(preds, target, iou_threshold=None, replacement_val=0, aggregate=True)
Compute Complete Intersection over Union (CIOU) between two sets of boxes.
Both sets of boxes are expected to be in (x1, y1, x2, y2) format with 0 <= x1 < x2 and 0 <= y1 < y2.
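For reference, the standard CIoU definition subtracts two penalties from plain IoU: the squared distance between the box centres normalized by the squared diagonal of the smallest enclosing box, and a weighted aspect-ratio consistency term. The plain-Python sketch below computes it for a single pair of boxes. ciou_pair is a hypothetical helper; we assume the library follows this textbook formula (torchvision.ops.complete_box_iou implements the same definition), and the diagonal values in the examples below are consistent with it.

```python
import math


def ciou_pair(a, b):
    """Complete IoU between two (x1, y1, x2, y2) boxes with x1 < x2 and y1 < y2."""
    ax1, ay1, ax2, ay2 = a
    bx1, by1, bx2, by2 = b
    wa, ha = ax2 - ax1, ay2 - ay1
    wb, hb = bx2 - bx1, by2 - by1

    # Plain IoU: intersection area over union area.
    inter = max(0.0, min(ax2, bx2) - max(ax1, bx1)) * max(0.0, min(ay2, by2) - max(ay1, by1))
    iou = inter / (wa * ha + wb * hb - inter)

    # Distance penalty: squared centre distance over the squared diagonal
    # of the smallest box enclosing both inputs.
    rho2 = ((ax1 + ax2 - bx1 - bx2) / 2) ** 2 + ((ay1 + ay2 - by1 - by2) / 2) ** 2
    c2 = (max(ax2, bx2) - min(ax1, bx1)) ** 2 + (max(ay2, by2) - min(ay1, by1)) ** 2

    # Aspect-ratio penalty v, weighted by alpha (alpha is 0 when aspect ratios match).
    v = (4 / math.pi ** 2) * (math.atan(wa / ha) - math.atan(wb / hb)) ** 2
    alpha = v / ((1 - iou) + v) if v > 0 else 0.0
    return iou - rho2 / c2 - alpha * v


pred = (296.55, 93.96, 314.97, 152.79)
target = (300.00, 100.00, 315.00, 150.00)
print(ciou_pair(pred, target))  # approximately 0.688
```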
- Parameters:
preds (Tensor) – The input tensor containing the predicted bounding boxes.
target (Tensor) – The tensor containing the ground truth.
iou_threshold (Optional[float]) – Optional IoU threshold for evaluation. If set to None, the threshold is ignored.
replacement_val (float) – Value with which to replace values under the threshold.
aggregate (bool) – Return the average value instead of the full matrix of values.
- Return type:
Tensor
- Example:
By default, the CIoU is aggregated as the mean along the diagonal of the CIoU matrix, i.e. over the pairs formed by preds[i] and target[i]:
>>> import torch
>>> from torchmetrics.functional.detection import complete_intersection_over_union
>>> preds = torch.tensor(
...     [
...         [296.55, 93.96, 314.97, 152.79],
...         [328.94, 97.05, 342.49, 122.98],
...         [356.62, 95.47, 372.33, 147.55],
...     ]
... )
>>> target = torch.tensor(
...     [
...         [300.00, 100.00, 315.00, 150.00],
...         [330.00, 100.00, 350.00, 125.00],
...         [350.00, 100.00, 375.00, 150.00],
...     ]
... )
>>> complete_intersection_over_union(preds, target)
tensor(0.5790)
- Example:
By setting aggregate=False, the full matrix of CIoU scores between every predicted box and every target box is returned:
>>> import torch
>>> from torchmetrics.functional.detection import complete_intersection_over_union
>>> preds = torch.tensor(
...     [
...         [296.55, 93.96, 314.97, 152.79],
...         [328.94, 97.05, 342.49, 122.98],
...         [356.62, 95.47, 372.33, 147.55],
...     ]
... )
>>> target = torch.tensor(
...     [
...         [300.00, 100.00, 315.00, 150.00],
...         [330.00, 100.00, 350.00, 125.00],
...         [350.00, 100.00, 375.00, 150.00],
...     ]
... )
>>> complete_intersection_over_union(preds, target, aggregate=False)
tensor([[ 0.6883, -0.2072, -0.3352],
        [-0.2217,  0.4881, -0.1913],
        [-0.3971, -0.1543,  0.5606]])
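The iou_threshold, replacement_val and aggregate options act on the score matrix shown above. A plain-Python sketch of that post-processing, assuming scores strictly below iou_threshold are replaced and that aggregation takes the mean of the diagonal as described in the first example. postprocess_matrix is hypothetical, and returning 0.0 for an empty matrix is an assumption.

```python
def postprocess_matrix(matrix, iou_threshold=None, replacement_val=0.0, aggregate=True):
    """Apply threshold / replacement / aggregation to a score matrix.

    matrix[i][j] is the score between predicted box i and target box j.
    """
    if iou_threshold is not None:
        # Scores under the threshold are replaced by replacement_val.
        matrix = [[v if v >= iou_threshold else replacement_val for v in row] for row in matrix]
    if not aggregate:
        return matrix
    # Aggregate as the mean of the diagonal, pairing preds[i] with target[i].
    n = min(len(matrix), len(matrix[0])) if matrix else 0
    diagonal = [matrix[i][i] for i in range(n)]
    return sum(diagonal) / len(diagonal) if diagonal else 0.0  # assumed empty fallback


ciou = [[0.6883, -0.2072, -0.3352], [-0.2217, 0.4881, -0.1913], [-0.3971, -0.1543, 0.5606]]
print(postprocess_matrix(ciou))                     # approximately 0.579
print(postprocess_matrix(ciou, iou_threshold=0.5))  # approximately 0.416: 0.4881 is replaced by 0.0
```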
Distance Intersection Over Union (dIoU)
Module Interface
- class torchmetrics.detection.diou.DistanceIntersectionOverUnion(box_format='xyxy', iou_threshold=None, class_metrics=False, respect_labels=True, **kwargs)
Computes Distance Intersection Over Union (DIoU).
As input to forward and update the metric accepts the following input:
preds (List): A list consisting of dictionaries, each containing the key-values below (each dictionary corresponds to a single image). Parameters that should be provided per dict:
boxes (Tensor): float tensor of shape (num_boxes, 4) containing num_boxes detection boxes of the format specified in the constructor. By default, this method expects (xmin, ymin, xmax, ymax) in absolute image coordinates.
labels (Tensor): integer tensor of shape (num_boxes) containing 0-indexed detection classes for the boxes.
target (List): A list consisting of dictionaries, each containing the key-values below (each dictionary corresponds to a single image). Parameters that should be provided per dict:
boxes (Tensor): float tensor of shape (num_boxes, 4) containing num_boxes ground truth boxes of the format specified in the constructor. By default, this method expects (xmin, ymin, xmax, ymax) in absolute image coordinates.
labels (Tensor): integer tensor of shape (num_boxes) containing 0-indexed ground truth classes for the boxes.
As output of forward and compute the metric returns the following output:
diou_dict: A dictionary containing the following key-values:
- Parameters:
box_format (str) – Input format of given boxes. Supported formats are ['xyxy', 'xywh', 'cxcywh'].
iou_threshold – Optional IoU threshold for evaluation. If set to None, the threshold is ignored.
class_metrics (bool) – Option to enable per-class metrics for IoU. Has a performance impact.
respect_labels (bool) – Ignore values from boxes that do not have the same label as the ground truth box. Otherwise, IoU is computed between all pairs of boxes.
kwargs (Any) – Additional keyword arguments, see Advanced metric settings for more info.
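One way to read respect_labels, consistent with the example below (where only the prediction labelled 5 matches the single target, and the reported score is that one pair's value): pairs whose labels differ are dropped before averaging. This sketch is an assumption about the behavior rather than the library's code. mean_respecting_labels is hypothetical, the 0.1 entry is a made-up score for the mismatched pair, 0.8611 is the value reported by the example, and the 0.0 fallback when no pair is kept is also assumed.

```python
def mean_respecting_labels(matrix, pred_labels, target_labels, respect_labels=True):
    """Average a (num_preds x num_targets) score matrix, optionally skipping label mismatches."""
    kept = []
    for i, row in enumerate(matrix):
        for j, score in enumerate(row):
            if respect_labels and pred_labels[i] != target_labels[j]:
                continue  # pair ignored: predicted and ground truth labels differ
            kept.append(score)
    return sum(kept) / len(kept) if kept else 0.0


# Two predictions (labels 4 and 5) against one target (label 5).
scores = [[0.1], [0.8611]]  # 0.1 is hypothetical; 0.8611 matches the example's output
print(mean_respecting_labels(scores, [4, 5], [5]))                        # 0.8611
print(mean_respecting_labels(scores, [4, 5], [5], respect_labels=False))  # mean of both pairs
```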
Example
>>> import torch
>>> from torchmetrics.detection import DistanceIntersectionOverUnion
>>> preds = [
...     {
...         "boxes": torch.tensor([[296.55, 93.96, 314.97, 152.79], [298.55, 98.96, 314.97, 151.79]]),
...         "scores": torch.tensor([0.236, 0.56]),
...         "labels": torch.tensor([4, 5]),
...     }
... ]
>>> target = [
...     {
...         "boxes": torch.tensor([[300.00, 100.00, 315.00, 150.00]]),
...         "labels": torch.tensor([5]),
...     }
... ]
>>> metric = DistanceIntersectionOverUnion()
>>> metric(preds, target)
{'diou': tensor(0.8611)}
- Raises:
ModuleNotFoundError – If torchvision is not installed with version 0.13.0 or newer.
- plot(val=None, ax=None)
Plot a single or multiple values from the metric.
- Parameters:
val (Union[Tensor, Sequence[Tensor], None]) – Either a single result from calling metric.forward or metric.compute or a list of these results. If no value is provided, metric.compute is called automatically and that result is plotted.
ax (Optional[Axes]) – A matplotlib axis object. If provided, the plot is added to that axis.
- Return type:
- Returns:
Figure object and Axes object
- Raises:
ModuleNotFoundError – If matplotlib is not installed
>>> # Example plotting single value
>>> import torch
>>> from torchmetrics.detection import DistanceIntersectionOverUnion
>>> preds = [
...     {
...         "boxes": torch.tensor([[296.55, 93.96, 314.97, 152.79], [298.55, 98.96, 314.97, 151.79]]),
...         "scores": torch.tensor([0.236, 0.56]),
...         "labels": torch.tensor([4, 5]),
...     }
... ]
>>> target = [
...     {
...         "boxes": torch.tensor([[300.00, 100.00, 315.00, 150.00]]),
...         "labels": torch.tensor([5]),
...     }
... ]
>>> metric = DistanceIntersectionOverUnion()
>>> metric.update(preds, target)
>>> fig_, ax_ = metric.plot()
>>> # Example plotting multiple values
>>> import torch
>>> from torchmetrics.detection import DistanceIntersectionOverUnion
>>> preds = [
...     {
...         "boxes": torch.tensor([[296.55, 93.96, 314.97, 152.79], [298.55, 98.96, 314.97, 151.79]]),
...         "scores": torch.tensor([0.236, 0.56]),
...         "labels": torch.tensor([4, 5]),
...     }
... ]
>>> target = lambda: [
...     {
...         "boxes": torch.tensor([[300.00, 100.00, 315.00, 150.00]]) + torch.randint(-10, 10, (1, 4)),
...         "labels": torch.tensor([5]),
...     }
... ]
>>> metric = DistanceIntersectionOverUnion()
>>> vals = []
>>> for _ in range(20):
...     vals.append(metric(preds, target()))
>>> fig_, ax_ = metric.plot(vals)
Functional Interface
- torchmetrics.functional.detection.diou.distance_intersection_over_union(preds, target, iou_threshold=None, replacement_val=0, aggregate=True)
Compute Distance Intersection over Union (DIOU) between two sets of boxes.
Both sets of boxes are expected to be in (x1, y1, x2, y2) format with 0 <= x1 < x2 and 0 <= y1 < y2.
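DIoU keeps the centre-distance penalty of CIoU but drops the aspect-ratio term: it is IoU minus the squared distance between box centres divided by the squared diagonal of the smallest enclosing box. That penalty is why non-overlapping pairs get negative scores. A plain-Python sketch for one pair of boxes; diou_pair is a hypothetical helper, and we assume the library follows this standard definition (as torchvision.ops.distance_box_iou does).

```python
def diou_pair(a, b):
    """Distance IoU between two (x1, y1, x2, y2) boxes with x1 < x2 and y1 < y2."""
    ax1, ay1, ax2, ay2 = a
    bx1, by1, bx2, by2 = b
    inter = max(0.0, min(ax2, bx2) - max(ax1, bx1)) * max(0.0, min(ay2, by2) - max(ay1, by1))
    union = (ax2 - ax1) * (ay2 - ay1) + (bx2 - bx1) * (by2 - by1) - inter
    iou = inter / union
    # Squared distance between the two box centres.
    rho2 = ((ax1 + ax2 - bx1 - bx2) / 2) ** 2 + ((ay1 + ay2 - by1 - by2) / 2) ** 2
    # Squared diagonal of the smallest box enclosing both boxes.
    c2 = (max(ax2, bx2) - min(ax1, bx1)) ** 2 + (max(ay2, by2) - min(ay1, by1)) ** 2
    return iou - rho2 / c2


pred = (296.55, 93.96, 314.97, 152.79)
print(diou_pair(pred, (300.00, 100.00, 315.00, 150.00)))  # approximately 0.688 (overlapping pair)
print(diou_pair(pred, (330.00, 100.00, 350.00, 125.00)))  # approximately -0.204 (no overlap)
```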
- Parameters:
preds (Tensor) – The input tensor containing the predicted bounding boxes.
target (Tensor) – The tensor containing the ground truth.
iou_threshold (Optional[float]) – Optional IoU threshold for evaluation. If set to None, the threshold is ignored.
replacement_val (float) – Value with which to replace values under the threshold.
aggregate (bool) – Return the average value instead of the full matrix of values.
- Return type:
Tensor
- Example:
By default, the DIoU is aggregated as the mean along the diagonal of the DIoU matrix, i.e. over the pairs formed by preds[i] and target[i]:
>>> import torch
>>> from torchmetrics.functional.detection import distance_intersection_over_union
>>> preds = torch.tensor(
...     [
...         [296.55, 93.96, 314.97, 152.79],
...         [328.94, 97.05, 342.49, 122.98],
...         [356.62, 95.47, 372.33, 147.55],
...     ]
... )
>>> target = torch.tensor(
...     [
...         [300.00, 100.00, 315.00, 150.00],
...         [330.00, 100.00, 350.00, 125.00],
...         [350.00, 100.00, 375.00, 150.00],
...     ]
... )
>>> distance_intersection_over_union(preds, target)
tensor(0.5793)
- Example:
By setting aggregate=False, the full matrix of DIoU scores between every predicted box and every target box is returned:
>>> import torch
>>> from torchmetrics.functional.detection import distance_intersection_over_union
>>> preds = torch.tensor(
...     [
...         [296.55, 93.96, 314.97, 152.79],
...         [328.94, 97.05, 342.49, 122.98],
...         [356.62, 95.47, 372.33, 147.55],
...     ]
... )
>>> target = torch.tensor(
...     [
...         [300.00, 100.00, 315.00, 150.00],
...         [330.00, 100.00, 350.00, 125.00],
...         [350.00, 100.00, 375.00, 150.00],
...     ]
... )
>>> distance_intersection_over_union(preds, target, aggregate=False)
tensor([[ 0.6883, -0.2043, -0.3351],
        [-0.2214,  0.4886, -0.1913],
        [-0.3971, -0.1510,  0.5609]])
Generalized Intersection Over Union (gIoU)
Module Interface
- class torchmetrics.detection.giou.GeneralizedIntersectionOverUnion(box_format='xyxy', iou_threshold=None, class_metrics=False, respect_labels=True, **kwargs)
Compute Generalized Intersection Over Union (GIoU).
As input to forward and update the metric accepts the following input:
preds (List): A list consisting of dictionaries, each containing the key-values below (each dictionary corresponds to a single image). Parameters that should be provided per dict:
boxes (Tensor): float tensor of shape (num_boxes, 4) containing num_boxes detection boxes of the format specified in the constructor. By default, this method expects (xmin, ymin, xmax, ymax) in absolute image coordinates.
labels (Tensor): integer tensor of shape (num_boxes) containing 0-indexed detection classes for the boxes.
target (List): A list consisting of dictionaries, each containing the key-values below (each dictionary corresponds to a single image). Parameters that should be provided per dict:
boxes (Tensor): float tensor of shape (num_boxes, 4) containing num_boxes ground truth boxes of the format specified in the constructor. By default, this method expects (xmin, ymin, xmax, ymax) in absolute image coordinates.
labels (Tensor): integer tensor of shape (num_boxes) containing 0-indexed ground truth classes for the boxes.
As output of forward and compute the metric returns the following output:
giou_dict: A dictionary containing the following key-values:
- Parameters:
box_format (str) – Input format of given boxes. Supported formats are ['xyxy', 'xywh', 'cxcywh'].
iou_threshold – Optional IoU threshold for evaluation. If set to None, the threshold is ignored.
class_metrics (bool) – Option to enable per-class metrics for IoU. Has a performance impact.
respect_labels (bool) – Ignore values from boxes that do not have the same label as the ground truth box. Otherwise, IoU is computed between all pairs of boxes.
kwargs (Any) – Additional keyword arguments, see Advanced metric settings for more info.
Example
>>> import torch
>>> from torchmetrics.detection import GeneralizedIntersectionOverUnion
>>> preds = [
...     {
...         "boxes": torch.tensor([[296.55, 93.96, 314.97, 152.79], [298.55, 98.96, 314.97, 151.79]]),
...         "scores": torch.tensor([0.236, 0.56]),
...         "labels": torch.tensor([4, 5]),
...     }
... ]
>>> target = [
...     {
...         "boxes": torch.tensor([[300.00, 100.00, 315.00, 150.00]]),
...         "labels": torch.tensor([5]),
...     }
... ]
>>> metric = GeneralizedIntersectionOverUnion()
>>> metric(preds, target)
{'giou': tensor(0.8613)}
- Raises:
ModuleNotFoundError – If torchvision is not installed with version 0.8.0 or newer.
- plot(val=None, ax=None)
Plot a single or multiple values from the metric.
- Parameters:
val (Union[Tensor, Sequence[Tensor], None]) – Either a single result from calling metric.forward or metric.compute or a list of these results. If no value is provided, metric.compute is called automatically and that result is plotted.
ax (Optional[Axes]) – A matplotlib axis object. If provided, the plot is added to that axis.
- Return type:
- Returns:
Figure object and Axes object
- Raises:
ModuleNotFoundError – If matplotlib is not installed
>>> # Example plotting single value
>>> import torch
>>> from torchmetrics.detection import GeneralizedIntersectionOverUnion
>>> preds = [
...     {
...         "boxes": torch.tensor([[296.55, 93.96, 314.97, 152.79], [298.55, 98.96, 314.97, 151.79]]),
...         "scores": torch.tensor([0.236, 0.56]),
...         "labels": torch.tensor([4, 5]),
...     }
... ]
>>> target = [
...     {
...         "boxes": torch.tensor([[300.00, 100.00, 315.00, 150.00]]),
...         "labels": torch.tensor([5]),
...     }
... ]
>>> metric = GeneralizedIntersectionOverUnion()
>>> metric.update(preds, target)
>>> fig_, ax_ = metric.plot()
>>> # Example plotting multiple values
>>> import torch
>>> from torchmetrics.detection import GeneralizedIntersectionOverUnion
>>> preds = [
...     {
...         "boxes": torch.tensor([[296.55, 93.96, 314.97, 152.79], [298.55, 98.96, 314.97, 151.79]]),
...         "scores": torch.tensor([0.236, 0.56]),
...         "labels": torch.tensor([4, 5]),
...     }
... ]
>>> target = lambda: [
...     {
...         "boxes": torch.tensor([[300.00, 100.00, 315.00, 150.00]]) + torch.randint(-10, 10, (1, 4)),
...         "labels": torch.tensor([5]),
...     }
... ]
>>> metric = GeneralizedIntersectionOverUnion()
>>> vals = []
>>> for _ in range(20):
...     vals.append(metric(preds, target()))
>>> fig_, ax_ = metric.plot(vals)
Functional Interface
- torchmetrics.functional.detection.giou.generalized_intersection_over_union(preds, target, iou_threshold=None, replacement_val=0, aggregate=True)
Compute Generalized Intersection over Union (GIOU) between two sets of boxes.
Both sets of boxes are expected to be in (x1, y1, x2, y2) format with 0 <= x1 < x2 and 0 <= y1 < y2.
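GIoU penalizes IoU by the fraction of the smallest enclosing box that is not covered by the union of the two boxes, so disjoint boxes that are far apart score closer to -1. A plain-Python sketch for one pair of boxes; giou_pair is a hypothetical helper, and we assume the library follows this standard definition (as torchvision.ops.generalized_box_iou does).

```python
def giou_pair(a, b):
    """Generalized IoU between two (x1, y1, x2, y2) boxes with x1 < x2 and y1 < y2."""
    ax1, ay1, ax2, ay2 = a
    bx1, by1, bx2, by2 = b
    inter = max(0.0, min(ax2, bx2) - max(ax1, bx1)) * max(0.0, min(ay2, by2) - max(ay1, by1))
    union = (ax2 - ax1) * (ay2 - ay1) + (bx2 - bx1) * (by2 - by1) - inter
    # Area of the smallest box enclosing both inputs.
    enclose = (max(ax2, bx2) - min(ax1, bx1)) * (max(ay2, by2) - min(ay1, by1))
    return inter / union - (enclose - union) / enclose


pred = (296.55, 93.96, 314.97, 152.79)
print(giou_pair(pred, (300.00, 100.00, 315.00, 150.00)))  # approximately 0.690 (overlapping pair)
print(giou_pair(pred, (330.00, 100.00, 350.00, 125.00)))  # approximately -0.496 (no overlap)
```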
- Parameters:
preds (Tensor) – The input tensor containing the predicted bounding boxes.
target (Tensor) – The tensor containing the ground truth.
iou_threshold (Optional[float]) – Optional IoU threshold for evaluation. If set to None, the threshold is ignored.
replacement_val (float) – Value with which to replace values under the threshold.
aggregate (bool) – Return the average value instead of the full matrix of values.
- Return type:
Tensor
- Example::
By default the GIoU is aggregated over the paired boxes, i.e. the mean along the diagonal of the GIoU matrix:
>>> import torch
>>> from torchmetrics.functional.detection import generalized_intersection_over_union
>>> preds = torch.tensor(
...     [
...         [296.55, 93.96, 314.97, 152.79],
...         [328.94, 97.05, 342.49, 122.98],
...         [356.62, 95.47, 372.33, 147.55],
...     ]
... )
>>> target = torch.tensor(
...     [
...         [300.00, 100.00, 315.00, 150.00],
...         [330.00, 100.00, 350.00, 125.00],
...         [350.00, 100.00, 375.00, 150.00],
...     ]
... )
>>> generalized_intersection_over_union(preds, target)
tensor(0.5638)
- Example::
By setting aggregate=False the full GIoU matrix is returned:
>>> import torch
>>> from torchmetrics.functional.detection import generalized_intersection_over_union
>>> preds = torch.tensor(
...     [
...         [296.55, 93.96, 314.97, 152.79],
...         [328.94, 97.05, 342.49, 122.98],
...         [356.62, 95.47, 372.33, 147.55],
...     ]
... )
>>> target = torch.tensor(
...     [
...         [300.00, 100.00, 315.00, 150.00],
...         [330.00, 100.00, 350.00, 125.00],
...         [350.00, 100.00, 375.00, 150.00],
...     ]
... )
>>> generalized_intersection_over_union(preds, target, aggregate=False)
tensor([[ 0.6895, -0.4964, -0.4944],
        [-0.5105,  0.4673, -0.3434],
        [-0.6024, -0.4021,  0.5345]])
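The GIoU values in these examples can be checked by hand. Below is a minimal pure-Python sketch (standard library only; the helper names `area` and `giou` are ours, not TorchMetrics functions) using the usual definition GIoU = IoU - (|C| - |A ∪ B|) / |C|, where C is the smallest box enclosing both boxes A and B. It builds the pairwise matrix that aggregate=False returns and averages its diagonal, as the default aggregate=True does.

```python
def area(box):
    """Area of an (x1, y1, x2, y2) box."""
    return (box[2] - box[0]) * (box[3] - box[1])


def giou(a, b):
    """Generalized IoU of two (x1, y1, x2, y2) boxes."""
    # Intersection rectangle (zero if the boxes do not overlap)
    inter_w = max(0.0, min(a[2], b[2]) - max(a[0], b[0]))
    inter_h = max(0.0, min(a[3], b[3]) - max(a[1], b[1]))
    inter = inter_w * inter_h
    union = area(a) + area(b) - inter
    # Smallest axis-aligned box enclosing both inputs
    enclosing = (max(a[2], b[2]) - min(a[0], b[0])) * (max(a[3], b[3]) - min(a[1], b[1]))
    return inter / union - (enclosing - union) / enclosing


preds = [
    (296.55, 93.96, 314.97, 152.79),
    (328.94, 97.05, 342.49, 122.98),
    (356.62, 95.47, 372.33, 147.55),
]
target = [
    (300.00, 100.00, 315.00, 150.00),
    (330.00, 100.00, 350.00, 125.00),
    (350.00, 100.00, 375.00, 150.00),
]

# Full pairwise matrix (what aggregate=False returns) ...
matrix = [[giou(p, t) for t in target] for p in preds]
# ... and the mean over its diagonal, i.e. over the pairs (preds[i], target[i])
aggregated = sum(matrix[i][i] for i in range(len(preds))) / len(preds)
print([[round(v, 4) for v in row] for row in matrix])
print(round(aggregated, 4))  # 0.5638
```

Unlike plain IoU, GIoU stays informative for disjoint boxes: the off-diagonal entries are negative and become more negative as the enclosing box grows, which is why this matrix has negative values where the IoU matrix of the same boxes has zeros.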
Intersection Over Union (IoU)¶
Module Interface¶
- class torchmetrics.detection.iou.IntersectionOverUnion(box_format='xyxy', iou_threshold=None, class_metrics=False, respect_labels=True, **kwargs)[source]¶
Computes Intersection Over Union (IoU).
As input to forward and update the metric accepts the following input:
preds (List): A list consisting of dictionaries each containing the key-values (each dictionary corresponds to a single image). Parameters that should be provided per dict:
  boxes (Tensor): float tensor of shape (num_boxes, 4) containing num_boxes detection boxes of the format specified in the constructor. By default, this method expects (xmin, ymin, xmax, ymax) in absolute image coordinates.
  labels (IntTensor): integer tensor of shape (num_boxes) containing 0-indexed detection classes for the boxes.
target (List): A list consisting of dictionaries each containing the key-values (each dictionary corresponds to a single image). Parameters that should be provided per dict:
  boxes (Tensor): float tensor of shape (num_boxes, 4) containing num_boxes ground truth boxes of the format specified in the constructor. By default, this method expects (xmin, ymin, xmax, ymax) in absolute image coordinates.
  labels (Tensor): integer tensor of shape (num_boxes) containing 0-indexed ground truth classes for the boxes.
As output of forward and compute the metric returns the following output:
iou_dict: A dictionary containing the following key-values:
  iou (Tensor): the average IoU
  iou/cl_{c} (Tensor): the IoU for class c, e.g. iou/cl_4 (only present when class_metrics=True)
- Parameters:
box_format (str) – Input format of given boxes. Supported formats are [`xyxy`, `xywh`, `cxcywh`].
iou_threshold – Optional IoU threshold for evaluation. If set to None the threshold is ignored.
class_metrics (bool) – Option to enable per-class metrics for IoU. Has a performance impact.
respect_labels (bool) – If True, ignore values from boxes that do not have the same label as the ground truth box. Otherwise, compute IoU between all pairs of boxes.
kwargs (Any) – Additional keyword arguments, see Advanced metric settings for more info.
Example:
>>> import torch
>>> from torchmetrics.detection import IntersectionOverUnion
>>> preds = [
...     {
...         "boxes": torch.tensor([
...             [296.55, 93.96, 314.97, 152.79],
...             [298.55, 98.96, 314.97, 151.79]]),
...         "labels": torch.tensor([4, 5]),
...     }
... ]
>>> target = [
...     {
...         "boxes": torch.tensor([[300.00, 100.00, 315.00, 150.00]]),
...         "labels": torch.tensor([5]),
...     }
... ]
>>> metric = IntersectionOverUnion()
>>> metric(preds, target)
{'iou': tensor(0.8614)}
Example:
The metric can also return the score per class:
>>> import torch
>>> from torchmetrics.detection import IntersectionOverUnion
>>> preds = [
...     {
...         "boxes": torch.tensor([
...             [296.55, 93.96, 314.97, 152.79],
...             [298.55, 98.96, 314.97, 151.79]]),
...         "labels": torch.tensor([4, 5]),
...     }
... ]
>>> target = [
...     {
...         "boxes": torch.tensor([
...             [300.00, 100.00, 315.00, 150.00],
...             [300.00, 100.00, 315.00, 150.00]
...         ]),
...         "labels": torch.tensor([4, 5]),
...     }
... ]
>>> metric = IntersectionOverUnion(class_metrics=True)
>>> metric(preds, target)
{'iou': tensor(0.7756), 'iou/cl_4': tensor(0.6898), 'iou/cl_5': tensor(0.8614)}
- Raises:
ModuleNotFoundError – If torchvision is not installed with version 0.8.0 or newer.
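The two examples above can be reproduced with a short pure-Python sketch. It rests on an assumption inferred from the example outputs rather than on documented behavior: with respect_labels=True, prediction/ground-truth pairs whose labels differ are skipped and the metric averages the IoU of the remaining pairs, while a per-class score averages only the pairs in which both boxes carry that class. The helpers `box_iou` and `mean_iou` are hypothetical.

```python
def box_iou(a, b):
    """IoU of two (xmin, ymin, xmax, ymax) boxes."""
    inter_w = max(0.0, min(a[2], b[2]) - max(a[0], b[0]))
    inter_h = max(0.0, min(a[3], b[3]) - max(a[1], b[1]))
    inter = inter_w * inter_h
    area_a = (a[2] - a[0]) * (a[3] - a[1])
    area_b = (b[2] - b[0]) * (b[3] - b[1])
    return inter / (area_a + area_b - inter)


def mean_iou(pred_boxes, pred_labels, gt_boxes, gt_labels, respect_labels=True, cls=None):
    """Average IoU over prediction/ground-truth pairs (hypothetical helper).

    respect_labels=True skips pairs whose labels differ; cls restricts the
    average to pairs in which both boxes carry that class.
    """
    values = []
    for pb, pl in zip(pred_boxes, pred_labels):
        for gb, gl in zip(gt_boxes, gt_labels):
            if respect_labels and pl != gl:
                continue  # label mismatch: pair is ignored
            if cls is not None and not (pl == gl == cls):
                continue  # pair belongs to another class
            values.append(box_iou(pb, gb))
    return sum(values) / len(values)


pred_boxes = [(296.55, 93.96, 314.97, 152.79), (298.55, 98.96, 314.97, 151.79)]
pred_labels = [4, 5]
gt_box = (300.00, 100.00, 315.00, 150.00)

# First example: one ground-truth box with label 5, so only the second prediction counts
single = mean_iou(pred_boxes, pred_labels, [gt_box], [5])
print(round(single, 4))  # 0.8614

# Second example: the same box annotated once with label 4 and once with label 5
both = mean_iou(pred_boxes, pred_labels, [gt_box, gt_box], [4, 5])
per_class = {c: mean_iou(pred_boxes, pred_labels, [gt_box, gt_box], [4, 5], cls=c) for c in (4, 5)}
print(round(both, 4), {c: round(v, 4) for c, v in per_class.items()})  # 0.7756 {4: 0.6898, 5: 0.8614}
```

Under this reading, the overall value in the second example is simply the mean of the two same-label pairs, (0.6898 + 0.8614) / 2.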
- plot(val=None, ax=None)[source]¶
Plot a single or multiple values from the metric.
- Parameters:
val (Union[Tensor, Sequence[Tensor], None]) – Either a single result from calling metric.forward or metric.compute or a list of these results. If no value is provided, will automatically call metric.compute and plot that result.
ax (Optional[Axes]) – A matplotlib axis object. If provided, the plot is added to that axis.
- Return type:
- Returns:
Figure object and Axes object
- Raises:
ModuleNotFoundError – If matplotlib is not installed
>>> # Example plotting single value
>>> import torch
>>> from torchmetrics.detection import IntersectionOverUnion
>>> preds = [
...     {
...         "boxes": torch.tensor([[296.55, 93.96, 314.97, 152.79], [298.55, 98.96, 314.97, 151.79]]),
...         "scores": torch.tensor([0.236, 0.56]),
...         "labels": torch.tensor([4, 5]),
...     }
... ]
>>> target = [
...     {
...         "boxes": torch.tensor([[300.00, 100.00, 315.00, 150.00]]),
...         "labels": torch.tensor([5]),
...     }
... ]
>>> metric = IntersectionOverUnion()
>>> metric.update(preds, target)
>>> fig_, ax_ = metric.plot()
>>> # Example plotting multiple values
>>> import torch
>>> from torchmetrics.detection import IntersectionOverUnion
>>> preds = [
...     {
...         "boxes": torch.tensor([[296.55, 93.96, 314.97, 152.79], [298.55, 98.96, 314.97, 151.79]]),
...         "scores": torch.tensor([0.236, 0.56]),
...         "labels": torch.tensor([4, 5]),
...     }
... ]
>>> target = lambda: [
...     {
...         "boxes": torch.tensor([[300.00, 100.00, 315.00, 150.00]]) + torch.randint(-10, 10, (1, 4)),
...         "labels": torch.tensor([5]),
...     }
... ]
>>> metric = IntersectionOverUnion()
>>> vals = []
>>> for _ in range(20):
...     vals.append(metric(preds, target()))
>>> fig_, ax_ = metric.plot(vals)
Functional Interface¶
- torchmetrics.functional.detection.iou.intersection_over_union(preds, target, iou_threshold=None, replacement_val=0, aggregate=True)[source]¶
Compute Intersection over Union between two sets of boxes.
Both sets of boxes are expected to be in (x1, y1, x2, y2) format with 0 <= x1 < x2 and 0 <= y1 < y2.
- Parameters:
preds (Tensor) – The input tensor containing the predicted bounding boxes.
target (Tensor) – The tensor containing the ground truth.
iou_threshold (Optional[float]) – Optional IoU threshold for evaluation. If set to None the threshold is ignored.
replacement_val (float) – Value to replace values under the threshold with.
aggregate (bool) – Return the average value instead of the full matrix of values.
- Return type:
Tensor
- Example::
By default the IoU is aggregated over the paired boxes, i.e. the mean along the diagonal of the IoU matrix:
>>> import torch
>>> from torchmetrics.functional.detection import intersection_over_union
>>> preds = torch.tensor(
...     [
...         [296.55, 93.96, 314.97, 152.79],
...         [328.94, 97.05, 342.49, 122.98],
...         [356.62, 95.47, 372.33, 147.55],
...     ]
... )
>>> target = torch.tensor(
...     [
...         [300.00, 100.00, 315.00, 150.00],
...         [330.00, 100.00, 350.00, 125.00],
...         [350.00, 100.00, 375.00, 150.00],
...     ]
... )
>>> intersection_over_union(preds, target)
tensor(0.5879)
- Example::
By setting aggregate=False the full IoU matrix is returned:
>>> import torch
>>> from torchmetrics.functional.detection import intersection_over_union
>>> preds = torch.tensor(
...     [
...         [296.55, 93.96, 314.97, 152.79],
...         [328.94, 97.05, 342.49, 122.98],
...         [356.62, 95.47, 372.33, 147.55],
...     ]
... )
>>> target = torch.tensor(
...     [
...         [300.00, 100.00, 315.00, 150.00],
...         [330.00, 100.00, 350.00, 125.00],
...         [350.00, 100.00, 375.00, 150.00],
...     ]
... )
>>> intersection_over_union(preds, target, aggregate=False)
tensor([[0.6898, 0.0000, 0.0000],
        [0.0000, 0.5086, 0.0000],
        [0.0000, 0.0000, 0.5654]])
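A pure-Python sketch of the functional computation on the same boxes follows: a pairwise IoU matrix, optional thresholding, and the diagonal mean used when aggregate=True. It assumes that "values under the threshold" means entries strictly below iou_threshold and that thresholding happens before aggregation; `box_iou`, `iou_matrix` and `diagonal_mean` are illustrative names, not library functions.

```python
def box_iou(a, b):
    """IoU of two (x1, y1, x2, y2) boxes."""
    inter_w = max(0.0, min(a[2], b[2]) - max(a[0], b[0]))
    inter_h = max(0.0, min(a[3], b[3]) - max(a[1], b[1]))
    inter = inter_w * inter_h
    area_a = (a[2] - a[0]) * (a[3] - a[1])
    area_b = (b[2] - b[0]) * (b[3] - b[1])
    return inter / (area_a + area_b - inter)


def iou_matrix(preds, target, iou_threshold=None, replacement_val=0.0):
    """Pairwise IoU; entries below iou_threshold become replacement_val (assumed semantics)."""
    matrix = [[box_iou(p, t) for t in target] for p in preds]
    if iou_threshold is not None:
        matrix = [[v if v >= iou_threshold else replacement_val for v in row] for row in matrix]
    return matrix


def diagonal_mean(matrix):
    """Mean over the paired boxes (preds[i], target[i])."""
    return sum(matrix[i][i] for i in range(len(matrix))) / len(matrix)


preds = [
    (296.55, 93.96, 314.97, 152.79),
    (328.94, 97.05, 342.49, 122.98),
    (356.62, 95.47, 372.33, 147.55),
]
target = [
    (300.00, 100.00, 315.00, 150.00),
    (330.00, 100.00, 350.00, 125.00),
    (350.00, 100.00, 375.00, 150.00),
]

print(round(diagonal_mean(iou_matrix(preds, target)), 4))  # 0.5879
thresholded = iou_matrix(preds, target, iou_threshold=0.55)
print([round(thresholded[i][i], 4) for i in range(3)])  # [0.6898, 0.0, 0.5654]
```

With a threshold of 0.55 the middle pair (IoU about 0.5086) is replaced by replacement_val, which would pull the aggregated value down accordingly.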
Mean-Average-Precision (mAP)¶
Module Interface¶
- class torchmetrics.detection.mean_ap.MeanAveragePrecision(box_format='xyxy', iou_type='bbox', iou_thresholds=None, rec_thresholds=None, max_detection_thresholds=None, class_metrics=False, extended_summary=False, average='macro', **kwargs)[source]¶
Compute the Mean-Average-Precision (mAP) and Mean-Average-Recall (mAR) for object detection predictions.
\[\text{mAP} = \frac{1}{n} \sum_{i=1}^{n} AP_i\]
where \(AP_i\) is the average precision for class \(i\) and \(n\) is the number of classes. The average precision is defined as the area under the precision-recall curve. For object detection, recall and precision are defined based on the intersection over union (IoU) between the predicted bounding boxes and the ground truth bounding boxes, i.e. if two boxes have an IoU > t (with t being some threshold) they are considered a match and therefore a true positive. The precision is then defined as the number of true positives divided by the number of all detected boxes, and the recall as the number of true positives divided by the number of all ground truth boxes.
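To make the definition concrete, here is a simplified pure-Python sketch of COCO-style average precision for one class in one image, applied to the single-box data of the bbox example below. Detections are sorted by score and greedily matched to unmatched ground-truth boxes whose IoU reaches the threshold; the interpolated precision is then sampled at the recall thresholds [0, 0.01, ..., 1], and AP is averaged over the IoU thresholds [0.5, ..., 0.95]. It deliberately ignores area ranges, max-detection limits, crowd annotations and averaging over classes, all of which the pycocotools backend handles, and the helper names are hypothetical.

```python
def box_iou(a, b):
    """IoU of two (xmin, ymin, xmax, ymax) boxes."""
    inter_w = max(0.0, min(a[2], b[2]) - max(a[0], b[0]))
    inter_h = max(0.0, min(a[3], b[3]) - max(a[1], b[1]))
    inter = inter_w * inter_h
    area_a = (a[2] - a[0]) * (a[3] - a[1])
    area_b = (b[2] - b[0]) * (b[3] - b[1])
    return inter / (area_a + area_b - inter)


def average_precision(detections, gt_boxes, iou_thr, rec_thresholds):
    """Simplified AP for one class in one image; detections are (score, box) pairs."""
    detections = sorted(detections, key=lambda d: d[0], reverse=True)
    matched, is_tp = set(), []
    for _, box in detections:
        best_j, best_iou = None, iou_thr
        for j, gt in enumerate(gt_boxes):
            iou = box_iou(box, gt)
            if j not in matched and iou >= best_iou:  # IoU at or above the threshold
                best_j, best_iou = j, iou
        if best_j is not None:
            matched.add(best_j)
        is_tp.append(best_j is not None)

    # Cumulative precision and recall along the score-sorted detections
    precision, recall, tp = [], [], 0
    for k, hit in enumerate(is_tp, start=1):
        tp += hit
        precision.append(tp / k)
        recall.append(tp / len(gt_boxes))
    # Interpolate: precision at recall r is the best precision at any recall >= r
    for k in range(len(precision) - 2, -1, -1):
        precision[k] = max(precision[k], precision[k + 1])
    # Sample the interpolated curve at each recall threshold (0 if never reached)
    sampled = []
    for r in rec_thresholds:
        idx = next((k for k, rec in enumerate(recall) if rec >= r), None)
        sampled.append(precision[idx] if idx is not None else 0.0)
    return sum(sampled) / len(sampled)


iou_thresholds = [round(0.5 + 0.05 * i, 2) for i in range(10)]  # 0.5, 0.55, ..., 0.95
rec_thresholds = [i / 100 for i in range(101)]  # 0.0, 0.01, ..., 1.0
detections = [(0.536, (258.0, 41.0, 606.0, 285.0))]
gt_boxes = [(214.0, 41.0, 562.0, 285.0)]

aps = [average_precision(detections, gt_boxes, t, rec_thresholds) for t in iou_thresholds]
map_value = sum(aps) / len(aps)
print(round(map_value, 4))  # 0.6
```

The two boxes overlap with an IoU of about 0.776, so they match at the six thresholds 0.5 to 0.75 and miss at the other four. That gives 6/10 = 0.6, the same map value (with map_50 = map_75 = 1) that the bbox example reports.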
As input to forward and update the metric accepts the following input:
preds (List): A list consisting of dictionaries each containing the key-values (each dictionary corresponds to a single image). Parameters that should be provided per dict:
  boxes (Tensor): float tensor of shape (num_boxes, 4) containing num_boxes detection boxes of the format specified in the constructor. By default, this method expects (xmin, ymin, xmax, ymax) in absolute image coordinates, but can be changed using the box_format parameter. Only required when iou_type="bbox".
  scores (Tensor): float tensor of shape (num_boxes) containing detection scores for the boxes.
  labels (Tensor): integer tensor of shape (num_boxes) containing 0-indexed detection classes for the boxes.
  masks (Tensor): boolean tensor of shape (num_boxes, image_height, image_width) containing boolean masks. Only required when iou_type="segm".
target (List): A list consisting of dictionaries each containing the key-values (each dictionary corresponds to a single image). Parameters that should be provided per dict:
  boxes (Tensor): float tensor of shape (num_boxes, 4) containing num_boxes ground truth boxes of the format specified in the constructor. Only required when iou_type="bbox". By default, this method expects (xmin, ymin, xmax, ymax) in absolute image coordinates.
  labels (Tensor): integer tensor of shape (num_boxes) containing 0-indexed ground truth classes for the boxes.
  masks (Tensor): boolean tensor of shape (num_boxes, image_height, image_width) containing boolean masks. Only required when iou_type="segm".
  iscrowd (Tensor): integer tensor of shape (num_boxes) containing 0/1 values indicating whether the bounding box/masks indicate a crowd of objects. Value is optional, and if not provided it will automatically be set to 0.
  area (Tensor): float tensor of shape (num_boxes) containing the area of the object. Value is optional, and if not provided will be automatically calculated based on the bounding box/masks provided. Only affects which samples contribute to the map_small, map_medium, map_large values.
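Whether an object counts towards the small, medium or large values depends on its area. The sketch below assumes the standard COCO area ranges used by pycocotools (small below 32² = 1024 pixels, medium below 96² = 9216 pixels, large above that); it simplifies the handling of areas that sit exactly on a boundary, and the helper names are hypothetical.

```python
SMALL_MAX = 32 ** 2  # 1024 pixels
MEDIUM_MAX = 96 ** 2  # 9216 pixels


def box_area(box):
    """Area of an (xmin, ymin, xmax, ymax) box in pixels."""
    return (box[2] - box[0]) * (box[3] - box[1])


def area_category(area):
    """COCO size bucket of an object (boundary handling simplified)."""
    if area < SMALL_MAX:
        return "small"
    if area < MEDIUM_MAX:
        return "medium"
    return "large"


print(area_category(box_area((214.0, 41.0, 562.0, 285.0))))  # large (348 x 244 box)
print(area_category(4))  # small (4-pixel mask)
print(area_category(50 * 50))  # medium
```

This is consistent with the examples further down: the 348 x 244 ground-truth box of the bbox example only contributes to map_large, and the 4-pixel mask of the segm example only to map_small.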
As output of forward and compute the metric returns the following output:
map_dict: A dictionary containing the following key-values:
  map (Tensor): global mean average precision
  map_small (Tensor): mean average precision for small objects
  map_medium (Tensor): mean average precision for medium objects
  map_large (Tensor): mean average precision for large objects
  mar_1 (Tensor): mean average recall for 1 detection per image
  mar_10 (Tensor): mean average recall for 10 detections per image
  mar_100 (Tensor): mean average recall for 100 detections per image
  mar_small (Tensor): mean average recall for small objects
  mar_medium (Tensor): mean average recall for medium objects
  mar_large (Tensor): mean average recall for large objects
  map_50 (Tensor): mean average precision at IoU=0.50 (-1 if 0.5 is not in the list of IoU thresholds)
  map_75 (Tensor): mean average precision at IoU=0.75 (-1 if 0.75 is not in the list of IoU thresholds)
  map_per_class (Tensor): mean average precision per observed class (-1 if class metrics are disabled)
  mar_100_per_class (Tensor): mean average recall for 100 detections per image per observed class (-1 if class metrics are disabled)
  classes (Tensor): list of all observed classes
For an example of how to use this metric, see the torchmetrics mAP example.
Note
The map score is calculated with @[ IoU=self.iou_thresholds | area=all | max_dets=max_detection_thresholds ]. Caution: if the initialization parameters are changed, the dictionary keys for mAR can change as well.
Note
This metric utilizes the official pycocotools implementation as its backend. This means that the metric requires you to have pycocotools installed. In addition we require torchvision version 0.8.0 or newer. Please install with:
pip install torchmetrics[detection]
- Parameters:
box_format (Literal['xyxy', 'xywh', 'cxcywh']) – Input format of given boxes. Supported formats are:
  'xyxy': boxes are represented via corners, x1, y1 being top left and x2, y2 being bottom right.
  'xywh': boxes are represented via corner, width and height, x1, y1 being top left, w, h being width and height. This is the default format used by pycocotools and all input formats will be converted to this.
  'cxcywh': boxes are represented via center, width and height, cx, cy being center of box, w, h being width and height.
iou_type (Union[Literal['bbox', 'segm'], Tuple[str]]) – Type of input (either masks or bounding-boxes) used for computing IoU. Supported IoU types are "bbox" or "segm" or both as a tuple.
iou_thresholds (Optional[List[float]]) – IoU thresholds for evaluation. If set to None it corresponds to the stepped range [0.5,...,0.95] with step 0.05. Else provide a list of floats.
rec_thresholds (Optional[List[float]]) – Recall thresholds for evaluation. If set to None it corresponds to the stepped range [0,...,1] with step 0.01. Else provide a list of floats.
max_detection_thresholds (Optional[List[int]]) – Thresholds on max detections per image. If set to None will use thresholds [1, 10, 100]. Else, please provide a list of ints.
class_metrics (bool) – Option to enable per-class metrics for mAP and mAR_100. Has a performance impact that scales linearly with the number of classes in the dataset.
extended_summary (bool) – Option to enable extended summary with additional metrics including IoU, precision and recall. The output dictionary will contain the following extra key-values:
  ious: a dictionary containing the IoU values for every image/class combination, e.g. ious[(0,0)] would contain the IoU for image 0 and class 0. Each value is a tensor with shape (n,m) where n is the number of detections and m is the number of ground truth boxes for that image/class combination.
  precision: a tensor of shape (TxRxKxAxM) containing the precision values. Here T is the number of IoU thresholds, R is the number of recall thresholds, K is the number of classes, A is the number of areas and M is the number of max detections per image.
  recall: a tensor of shape (TxKxAxM) containing the recall values. Here T is the number of IoU thresholds, K is the number of classes, A is the number of areas and M is the number of max detections per image.
average (Literal['macro', 'micro']) – Method for averaging scores over labels. Choose between "macro" and "micro". Default is "macro".
kwargs (Any) – Additional keyword arguments, see Advanced metric settings for more info.
- Raises:
ModuleNotFoundError – If pycocotools is not installed
ModuleNotFoundError – If torchvision is not installed or version installed is lower than 0.8.0
ValueError – If box_format is not one of "xyxy", "xywh" or "cxcywh"
ValueError – If iou_type is not one of "bbox" or "segm"
ValueError – If iou_thresholds is not None or a list of floats
ValueError – If rec_thresholds is not None or a list of floats
ValueError – If max_detection_thresholds is not None or a list of ints
ValueError – If class_metrics is not a boolean
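The three box_format conventions differ only in how the four numbers are interpreted. A plain-Python sketch of the conversions follows (the function names are ours); for tensors, torchvision.ops.box_convert performs the same conversions between "xyxy", "xywh" and "cxcywh".

```python
def xywh_to_xyxy(box):
    """(x1, y1, w, h) -> (x1, y1, x2, y2)"""
    x1, y1, w, h = box
    return (x1, y1, x1 + w, y1 + h)


def cxcywh_to_xyxy(box):
    """(cx, cy, w, h) -> (x1, y1, x2, y2)"""
    cx, cy, w, h = box
    return (cx - w / 2, cy - h / 2, cx + w / 2, cy + h / 2)


def xyxy_to_xywh(box):
    """(x1, y1, x2, y2) -> (x1, y1, w, h), the layout pycocotools works with"""
    x1, y1, x2, y2 = box
    return (x1, y1, x2 - x1, y2 - y1)


gt_xyxy = (214.0, 41.0, 562.0, 285.0)  # ground-truth box of the bbox example
print(xyxy_to_xywh(gt_xyxy))  # (214.0, 41.0, 348.0, 244.0)
print(xywh_to_xyxy((214.0, 41.0, 348.0, 244.0)))  # (214.0, 41.0, 562.0, 285.0)
print(cxcywh_to_xyxy((388.0, 163.0, 348.0, 244.0)))  # (214.0, 41.0, 562.0, 285.0)
```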
Example:
Basic example for when iou_type="bbox". In this case the boxes key is required in the input dictionaries, in addition to the scores and labels keys.
>>> from torch import tensor
>>> from torchmetrics.detection import MeanAveragePrecision
>>> preds = [
...     dict(
...         boxes=tensor([[258.0, 41.0, 606.0, 285.0]]),
...         scores=tensor([0.536]),
...         labels=tensor([0]),
...     )
... ]
>>> target = [
...     dict(
...         boxes=tensor([[214.0, 41.0, 562.0, 285.0]]),
...         labels=tensor([0]),
...     )
... ]
>>> metric = MeanAveragePrecision(iou_type="bbox")
>>> metric.update(preds, target)
>>> from pprint import pprint
>>> pprint(metric.compute())
{'classes': tensor(0, dtype=torch.int32),
 'map': tensor(0.6000),
 'map_50': tensor(1.),
 'map_75': tensor(1.),
 'map_large': tensor(0.6000),
 'map_medium': tensor(-1.),
 'map_per_class': tensor(-1.),
 'map_small': tensor(-1.),
 'mar_1': tensor(0.6000),
 'mar_10': tensor(0.6000),
 'mar_100': tensor(0.6000),
 'mar_100_per_class': tensor(-1.),
 'mar_large': tensor(0.6000),
 'mar_medium': tensor(-1.),
 'mar_small': tensor(-1.)}
Example:
Basic example for when iou_type="segm". In this case the masks key is required in the input dictionaries, in addition to the scores and labels keys.
>>> import torch
>>> from torch import tensor
>>> from torchmetrics.detection import MeanAveragePrecision
>>> mask_pred = [
...     [0, 0, 0, 0, 0],
...     [0, 0, 1, 1, 0],
...     [0, 0, 1, 1, 0],
...     [0, 0, 0, 0, 0],
...     [0, 0, 0, 0, 0],
... ]
>>> mask_tgt = [
...     [0, 0, 0, 0, 0],
...     [0, 0, 1, 0, 0],
...     [0, 0, 1, 1, 0],
...     [0, 0, 1, 0, 0],
...     [0, 0, 0, 0, 0],
... ]
>>> preds = [
...     dict(
...         masks=tensor([mask_pred], dtype=torch.bool),
...         scores=tensor([0.536]),
...         labels=tensor([0]),
...     )
... ]
>>> target = [
...     dict(
...         masks=tensor([mask_tgt], dtype=torch.bool),
...         labels=tensor([0]),
...     )
... ]
>>> metric = MeanAveragePrecision(iou_type="segm")
>>> metric.update(preds, target)
>>> from pprint import pprint
>>> pprint(metric.compute())
{'classes': tensor(0, dtype=torch.int32),
 'map': tensor(0.2000),
 'map_50': tensor(1.),
 'map_75': tensor(0.),
 'map_large': tensor(-1.),
 'map_medium': tensor(-1.),
 'map_per_class': tensor(-1.),
 'map_small': tensor(0.2000),
 'mar_1': tensor(0.2000),
 'mar_10': tensor(0.2000),
 'mar_100': tensor(0.2000),
 'mar_100_per_class': tensor(-1.),
 'mar_large': tensor(-1.),
 'mar_medium': tensor(-1.),
 'mar_small': tensor(0.2000)}
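With iou_type="segm" the overlap is measured between masks instead of boxes. A minimal sketch of mask IoU on the two masks of the example above, using nested lists of 0/1 values in place of boolean tensors (`mask_iou` is a hypothetical helper):

```python
mask_pred = [
    [0, 0, 0, 0, 0],
    [0, 0, 1, 1, 0],
    [0, 0, 1, 1, 0],
    [0, 0, 0, 0, 0],
    [0, 0, 0, 0, 0],
]
mask_tgt = [
    [0, 0, 0, 0, 0],
    [0, 0, 1, 0, 0],
    [0, 0, 1, 1, 0],
    [0, 0, 1, 0, 0],
    [0, 0, 0, 0, 0],
]


def mask_iou(a, b):
    """IoU of two binary masks given as nested lists of 0/1 values."""
    pairs = [(x, y) for row_a, row_b in zip(a, b) for x, y in zip(row_a, row_b)]
    intersection = sum(x & y for x, y in pairs)  # pixels set in both masks
    union = sum(x | y for x, y in pairs)  # pixels set in either mask
    return intersection / union


print(mask_iou(mask_pred, mask_tgt))  # 0.6
```

The masks share 3 of the 5 pixels covered by either of them, so their IoU is 0.6; this is the quantity compared against each IoU threshold in the segm case.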
- static coco_to_tm(coco_preds, coco_target, iou_type='bbox')[source]¶
Utility function for converting .json coco format files to the input format of this metric.
The function accepts a file for the predictions and a file for the target in coco format and converts them to a list of dictionaries containing the boxes, labels and scores in the input format of this metric.
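- Parameters:
coco_preds (str) – Path to the json file containing the predictions in coco format.
coco_target (str) – Path to the json file containing the targets in coco format.
iou_type – Type of input used for computing IoU, as for the iou_type argument of the class.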
- Return type:
- Returns:
A tuple containing the predictions and targets in the input format of this metric. Each element of the tuple is a list of dictionaries containing the boxes, labels and scores.
Example
>>> # File formats are defined at https://cocodataset.org/#format-data
>>> # Example files can be found at
>>> # https://github.com/cocodataset/cocoapi/tree/master/results
>>> from torchmetrics.detection import MeanAveragePrecision
>>> preds, target = MeanAveragePrecision.coco_to_tm(
...     "instances_val2014_fakebbox100_results.json.json",
...     "val2014_fake_eval_res.txt.json",
...     iou_type="bbox"
... )
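COCO result files store one entry per detection with image_id, category_id, a bbox in [x, y, width, height] order and a score. The hypothetical helper `coco_results_to_dicts` below sketches the kind of regrouping that coco_to_tm performs: one dictionary per image, with boxes converted to the xyxy layout. It is an illustration on plain lists using the json module, not the library's implementation; whether coco_to_tm itself converts the boxes to xyxy is an assumption here.

```python
import json
from collections import defaultdict


def coco_results_to_dicts(results):
    """Group COCO-style detection results per image (hypothetical helper)."""
    per_image = defaultdict(lambda: {"boxes": [], "scores": [], "labels": []})
    for det in results:
        x, y, w, h = det["bbox"]
        entry = per_image[det["image_id"]]
        entry["boxes"].append([x, y, x + w, y + h])  # COCO xywh -> xyxy
        entry["scores"].append(det["score"])
        entry["labels"].append(det["category_id"])
    # One dictionary per image, ordered by image id
    return [per_image[image_id] for image_id in sorted(per_image)]


results = json.loads(
    '[{"image_id": 1, "category_id": 3, "bbox": [10, 20, 30, 40], "score": 0.9},'
    ' {"image_id": 2, "category_id": 1, "bbox": [0, 0, 5, 5], "score": 0.4},'
    ' {"image_id": 1, "category_id": 1, "bbox": [50, 60, 10, 10], "score": 0.7}]'
)
preds = coco_results_to_dicts(results)
print(preds[0]["boxes"])  # [[10, 20, 40, 60], [50, 60, 60, 70]]
```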
- plot(val=None, ax=None)[source]¶
Plot a single or multiple values from the metric.
- Parameters:
val (Union[Dict[str, Tensor], Sequence[Dict[str, Tensor]], None]) – Either a single result from calling metric.forward or metric.compute or a list of these results. If no value is provided, will automatically call metric.compute and plot that result.
ax (Optional[Axes]) – A matplotlib axis object. If provided, the plot is added to that axis.
- Return type:
- Returns:
Figure object and Axes object
- Raises:
ModuleNotFoundError – If matplotlib is not installed
>>> from torch import tensor
>>> from torchmetrics.detection.mean_ap import MeanAveragePrecision
>>> preds = [dict(
...     boxes=tensor([[258.0, 41.0, 606.0, 285.0]]),
...     scores=tensor([0.536]),
...     labels=tensor([0]),
... )]
>>> target = [dict(
...     boxes=tensor([[214.0, 41.0, 562.0, 285.0]]),
...     labels=tensor([0]),
... )]
>>> metric = MeanAveragePrecision()
>>> metric.update(preds, target)
>>> fig_, ax_ = metric.plot()
>>> # Example plotting multiple values
>>> import torch
>>> from torchmetrics.detection.mean_ap import MeanAveragePrecision
>>> preds = lambda: [dict(
...     boxes=torch.tensor([[258.0, 41.0, 606.0, 285.0]]) + torch.randint(10, (1, 4)),
...     scores=torch.tensor([0.536]) + 0.1 * torch.rand(1),
...     labels=torch.tensor([0]),
... )]
>>> target = [dict(
...     boxes=torch.tensor([[214.0, 41.0, 562.0, 285.0]]),
...     labels=torch.tensor([0]),
... )]
>>> metric = MeanAveragePrecision()
>>> vals = []
>>> for _ in range(20):
...     vals.append(metric(preds(), target))
>>> fig_, ax_ = metric.plot(vals)
- tm_to_coco(name='tm_map_input')[source]¶
Utility function for converting the input for this metric to coco format and saving it to a json file.
This function should be used after calling .update(…) or .forward(…) on all data that should be written to the file, as the input is then internally cached. The function then converts the information to coco format and writes it to json files.
- Parameters:
name (str) – Name of the output file, which will be appended with "_preds.json" and "_target.json"
- Return type:
None
Example
>>> from torch import tensor
>>> from torchmetrics.detection import MeanAveragePrecision
>>> preds = [
...     dict(
...         boxes=tensor([[258.0, 41.0, 606.0, 285.0]]),
...         scores=tensor([0.536]),
...         labels=tensor([0]),
...     )
... ]
>>> target = [
...     dict(
...         boxes=tensor([[214.0, 41.0, 562.0, 285.0]]),
...         labels=tensor([0]),
...     )
... ]
>>> metric = MeanAveragePrecision()
>>> metric.update(preds, target)
>>> metric.tm_to_coco("tm_map_input")
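Conversely, tm_to_coco writes the cached inputs as COCO json. The sketch below illustrates that direction for predictions only: each box is turned from xyxy into a COCO [x, y, width, height] bbox and emitted with its label and score, and the output name follows the "_preds.json" suffix rule described above. The helper name, and numbering images by their position in the list, are assumptions made for illustration.

```python
import json


def dicts_to_coco_results(preds):
    """Flatten per-image prediction dicts into COCO result entries (hypothetical helper)."""
    results = []
    for image_id, pred in enumerate(preds):  # assumed: image ids follow list order
        for box, score, label in zip(pred["boxes"], pred["scores"], pred["labels"]):
            x1, y1, x2, y2 = box
            results.append({
                "image_id": image_id,
                "category_id": label,
                "bbox": [x1, y1, x2 - x1, y2 - y1],  # xyxy -> COCO xywh
                "score": score,
            })
    return results


preds = [{"boxes": [[258.0, 41.0, 606.0, 285.0]], "scores": [0.536], "labels": [0]}]
name = "tm_map_input"
file_name = name + "_preds.json"
payload = json.dumps(dicts_to_coco_results(preds))
print(file_name)  # tm_map_input_preds.json
print(payload)
```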
Modified Panoptic Quality¶
Module Interface¶
- class torchmetrics.detection.ModifiedPanopticQuality(things, stuffs, allow_unknown_preds_category=False, **kwargs)[source]¶
Compute Modified Panoptic Quality for panoptic segmentations.
The metric was introduced in the Seamless Scene Segmentation paper, and is an adaptation of the original Panoptic Quality where the metric for a stuff class is computed as
\[PQ^{\dagger}_c = \frac{IOU_c}{|S_c|}\]
where \(IOU_c\) is the sum of the intersection over union of all matching segments for a given class, and \(|S_c|\) is the overall number of segments in the ground truth for that class.
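The following single-image, pure-Python sketch reproduces the 0.7667 of the example below. It is modeled on panopticapi conventions, which are assumptions here: target pixels whose category is in neither things nor stuffs (such as 255) are void; thing segments match when they share a category and have IoU > 0.5; unmatched predicted segments that are mostly void are ignored. For stuff classes it applies the modified formula, counting every overlapping segment of the same class. It is not the TorchMetrics implementation, and `modified_pq_sketch` is a hypothetical name.

```python
from collections import Counter


def modified_pq_sketch(preds, target, things, stuffs):
    """Modified PQ for one image given as flat lists of (category_id, instance_id) pairs."""
    known = set(things) | set(stuffs)

    def segment(pair):
        category, instance = pair
        if category not in known:
            return None  # assumed: unknown category -> void pixel
        # stuff classes ignore the instance_id
        return (category, 0) if category in stuffs else (category, instance)

    pred_seg = [segment(p) for p in preds]
    tgt_seg = [segment(t) for t in target]
    pred_area = Counter(s for s in pred_seg if s is not None)
    tgt_area = Counter(s for s in tgt_seg if s is not None)
    overlap = Counter(zip(pred_seg, tgt_seg))
    void = Counter(p for p, t in zip(pred_seg, tgt_seg) if t is None)

    iou_sum, tp = Counter(), Counter()
    matched_pred, matched_tgt = set(), set()
    for (p, t), inter in overlap.items():
        if p is None or t is None or p[0] != t[0]:
            continue  # only segments of the same category can match
        iou = inter / (pred_area[p] + tgt_area[t] - inter - void[p])
        if p[0] in stuffs:
            iou_sum[p[0]] += iou  # stuff: every overlapping segment counts
        elif iou > 0.5:  # things: the usual unique IoU > 0.5 match
            iou_sum[p[0]] += iou
            tp[p[0]] += 1
            matched_pred.add(p)
            matched_tgt.add(t)

    # Thing FP/FN; unmatched predictions that are mostly void are ignored
    fp = Counter(p[0] for p in pred_area
                 if p[0] in things and p not in matched_pred and void[p] / pred_area[p] <= 0.5)
    fn = Counter(t[0] for t in tgt_area if t[0] in things and t not in matched_tgt)
    gt_count = Counter(t[0] for t in tgt_area)  # |S_c|: ground-truth segments per class
    scores = [iou_sum[c] / (tp[c] + 0.5 * fp[c] + 0.5 * fn[c])
              for c in things if tp[c] + fp[c] + fn[c] > 0]
    scores += [iou_sum[c] / gt_count[c] for c in stuffs if gt_count[c] > 0]
    return sum(scores) / len(scores)


def flatten(grid):
    return [tuple(pixel) for row in grid for pixel in row]


preds = [[[0, 0], [0, 1], [6, 0], [7, 0], [0, 2], [1, 0]]]
target = [[[0, 1], [0, 0], [6, 0], [7, 0], [6, 0], [255, 0]]]
score = modified_pq_sketch(flatten(preds), flatten(target), things={0, 1}, stuffs={6, 7})
print(round(score, 4))  # 0.7667
```

In that example class 0 scores 2 / (2 + 0.5) = 0.8, stuff class 6 gets IoU 0.5 from its single ground-truth segment and class 7 gets 1.0; class 1 drops out because its only prediction lies on a void pixel, so the mean over the three remaining classes is 0.7667.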
- Parameters:
things (Collection[int]) – Set of category_id for countable things.
stuffs (Collection[int]) – Set of category_id for uncountable stuffs.
allow_unknown_preds_category (bool) – Boolean flag to specify if unknown categories in the predictions are to be ignored in the metric computation or raise an exception when found.
- Raises:
ValueError – If things, stuffs have at least one common category_id.
TypeError – If things, stuffs contain non-integer category_id.
Example
>>> from torch import tensor
>>> from torchmetrics.detection import ModifiedPanopticQuality
>>> preds = tensor([[[0, 0], [0, 1], [6, 0], [7, 0], [0, 2], [1, 0]]])
>>> target = tensor([[[0, 1], [0, 0], [6, 0], [7, 0], [6, 0], [255, 0]]])
>>> pq_modified = ModifiedPanopticQuality(things = {0, 1}, stuffs = {6, 7})
>>> pq_modified(preds, target)
tensor(0.7667, dtype=torch.float64)
- plot(val=None, ax=None)[source]¶
Plot a single or multiple values from the metric.
- Parameters:
val (Union[Tensor, Sequence[Tensor], None]) – Either a single result from calling metric.forward or metric.compute or a list of these results. If no value is provided, will automatically call metric.compute and plot that result.
ax (Optional[Axes]) – A matplotlib axis object. If provided, the plot is added to that axis.
- Return type:
- Returns:
Figure object and Axes object
- Raises:
ModuleNotFoundError – If matplotlib is not installed
>>> from torch import tensor
>>> from torchmetrics.detection import ModifiedPanopticQuality
>>> preds = tensor([[[[6, 0], [0, 0], [6, 0], [6, 0]],
...                  [[0, 0], [0, 0], [6, 0], [0, 1]],
...                  [[0, 0], [0, 0], [6, 0], [0, 1]],
...                  [[0, 0], [7, 0], [6, 0], [1, 0]],
...                  [[0, 0], [7, 0], [7, 0], [7, 0]]]])
>>> target = tensor([[[[6, 0], [0, 1], [6, 0], [0, 1]],
...                   [[0, 1], [0, 1], [6, 0], [0, 1]],
...                   [[0, 1], [0, 1], [6, 0], [1, 0]],
...                   [[0, 1], [7, 0], [1, 0], [1, 0]],
...                   [[0, 1], [7, 0], [7, 0], [7, 0]]]])
>>> metric = ModifiedPanopticQuality(things = {0, 1}, stuffs = {6, 7})
>>> metric.update(preds, target)
>>> fig_, ax_ = metric.plot()
>>> # Example plotting multiple values
>>> from torch import tensor
>>> from torchmetrics.detection import ModifiedPanopticQuality
>>> preds = tensor([[[[6, 0], [0, 0], [6, 0], [6, 0]],
...                  [[0, 0], [0, 0], [6, 0], [0, 1]],
...                  [[0, 0], [0, 0], [6, 0], [0, 1]],
...                  [[0, 0], [7, 0], [6, 0], [1, 0]],
...                  [[0, 0], [7, 0], [7, 0], [7, 0]]]])
>>> target = tensor([[[[6, 0], [0, 1], [6, 0], [0, 1]],
...                   [[0, 1], [0, 1], [6, 0], [0, 1]],
...                   [[0, 1], [0, 1], [6, 0], [1, 0]],
...                   [[0, 1], [7, 0], [1, 0], [1, 0]],
...                   [[0, 1], [7, 0], [7, 0], [7, 0]]]])
>>> metric = ModifiedPanopticQuality(things = {0, 1}, stuffs = {6, 7})
>>> vals = []
>>> for _ in range(20):
...     vals.append(metric(preds, target))
>>> fig_, ax_ = metric.plot(vals)
Functional Interface¶
- torchmetrics.functional.detection.modified_panoptic_quality(preds, target, things, stuffs, allow_unknown_preds_category=False)[source]¶
Compute Modified Panoptic Quality for panoptic segmentations.
The metric was introduced in the Seamless Scene Segmentation paper, and is an adaptation of the original Panoptic Quality where the metric for a stuff class is computed as
\[PQ^{\dagger}_c = \frac{IOU_c}{|S_c|}\]
where \(IOU_c\) is the sum of the intersection over union of all matching segments for a given class, and \(|S_c|\) is the overall number of segments in the ground truth for that class.
- Parameters:
preds (Tensor) – torch tensor with panoptic detection of shape [height, width, 2] containing the pair (category_id, instance_id) for each pixel of the image. If the category_id refers to a stuff, the instance_id is ignored.
target (Tensor) – torch tensor with ground truth of shape [height, width, 2] containing the pair (category_id, instance_id) for each pixel of the image. If the category_id refers to a stuff, the instance_id is ignored.
things (Collection[int]) – Set of category_id for countable things.
stuffs (Collection[int]) – Set of category_id for uncountable stuffs.
allow_unknown_preds_category (bool) – Boolean flag to specify if unknown categories in the predictions are to be ignored in the metric computation or raise an exception when found.
- Raises:
ValueError – If things, stuffs have at least one common category_id.
TypeError – If things, stuffs contain non-integer category_id.
TypeError – If preds or target is not a torch.Tensor.
ValueError – If preds and target have different shapes.
ValueError – If preds has fewer than 3 dimensions.
ValueError – If the final dimension of preds has size != 2.
- Return type:
Tensor
Example
>>> from torch import tensor
>>> from torchmetrics.functional.detection import modified_panoptic_quality
>>> preds = tensor([[[0, 0], [0, 1], [6, 0], [7, 0], [0, 2], [1, 0]]])
>>> target = tensor([[[0, 1], [0, 0], [6, 0], [7, 0], [6, 0], [255, 0]]])
>>> modified_panoptic_quality(preds, target, things = {0, 1}, stuffs = {6, 7})
tensor(0.7667, dtype=torch.float64)
Panoptic Quality¶
Module Interface¶
- class torchmetrics.detection.PanopticQuality(things, stuffs, allow_unknown_preds_category=False, **kwargs)[source]¶
Compute the Panoptic Quality for panoptic segmentations.
\[PQ = \frac{IOU}{TP + 0.5 FP + 0.5 FN}\]
where IOU, TP, FP and FN are respectively the sum of the intersection over union for true positives, the number of true positives, false positives and false negatives. This metric is inspired by the PQ implementation of panopticapi, a standard implementation for the PQ metric for panoptic segmentation.
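The following single-image, pure-Python sketch reproduces the 0.5463 of the example below. It follows panopticapi conventions, which are assumptions here rather than a description of the TorchMetrics code: stuff classes ignore instance_id, segments of the same category match when their IoU is above 0.5 (which makes matches unique), target pixels of unknown categories are void, and unmatched predictions that are mostly void are not counted as false positives. `panoptic_quality_sketch` is a hypothetical name.

```python
from collections import Counter


def panoptic_quality_sketch(preds, target, things, stuffs):
    """PQ for one image given as flat lists of (category_id, instance_id) pairs."""
    known = set(things) | set(stuffs)

    def segment(pair):
        category, instance = pair
        if category not in known:
            return None  # assumed: unknown category -> void pixel
        # stuff classes ignore the instance_id
        return (category, 0) if category in stuffs else (category, instance)

    pred_seg = [segment(p) for p in preds]
    tgt_seg = [segment(t) for t in target]
    pred_area = Counter(s for s in pred_seg if s is not None)
    tgt_area = Counter(s for s in tgt_seg if s is not None)
    overlap = Counter(zip(pred_seg, tgt_seg))
    void = Counter(p for p, t in zip(pred_seg, tgt_seg) if t is None)

    iou_sum, tp = Counter(), Counter()
    matched_pred, matched_tgt = set(), set()
    for (p, t), inter in overlap.items():
        if p is None or t is None or p[0] != t[0]:
            continue  # only segments of the same category can match
        iou = inter / (pred_area[p] + tgt_area[t] - inter - void[p])
        if iou > 0.5:  # IoU > 0.5 guarantees a unique match
            iou_sum[p[0]] += iou
            tp[p[0]] += 1
            matched_pred.add(p)
            matched_tgt.add(t)

    # Unmatched predictions count as FP unless they are mostly void
    fp = Counter(p[0] for p in pred_area
                 if p not in matched_pred and void[p] / pred_area[p] <= 0.5)
    fn = Counter(t[0] for t in tgt_area if t not in matched_tgt)
    scores = [iou_sum[c] / (tp[c] + 0.5 * fp[c] + 0.5 * fn[c])
              for c in known if tp[c] + fp[c] + fn[c] > 0]
    return sum(scores) / len(scores)


def flatten(grid):
    return [tuple(pixel) for row in grid for pixel in row]


preds = [
    [[6, 0], [0, 0], [6, 0], [6, 0]],
    [[0, 0], [0, 0], [6, 0], [0, 1]],
    [[0, 0], [0, 0], [6, 0], [0, 1]],
    [[0, 0], [7, 0], [6, 0], [1, 0]],
    [[0, 0], [7, 0], [7, 0], [7, 0]],
]
target = [
    [[6, 0], [0, 1], [6, 0], [0, 1]],
    [[0, 1], [0, 1], [6, 0], [0, 1]],
    [[0, 1], [0, 1], [6, 0], [1, 0]],
    [[0, 1], [7, 0], [1, 0], [1, 0]],
    [[0, 1], [7, 0], [7, 0], [7, 0]],
]
pq = panoptic_quality_sketch(flatten(preds), flatten(target), things={0, 1}, stuffs={6, 7})
print(round(pq, 4))  # 0.5463
```

Here class 0 scores (7/9) / 1.5, class 1 scores 0 (one unmatched prediction and one unmatched ground-truth segment), stuff class 6 scores 4/6 and class 7 scores 1, which averages to 0.5463.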
- Parameters:
things (Collection[int]) – Set of category_id for countable things.
stuffs (Collection[int]) – Set of category_id for uncountable stuffs.
allow_unknown_preds_category (bool) – Boolean flag to specify if unknown categories in the predictions are to be ignored in the metric computation or raise an exception when found.
- Raises:
ValueError – If things, stuffs have at least one common category_id.
TypeError – If things, stuffs contain non-integer category_id.
Example
>>> from torch import tensor
>>> from torchmetrics.detection import PanopticQuality
>>> preds = tensor([[[[6, 0], [0, 0], [6, 0], [6, 0]],
...                  [[0, 0], [0, 0], [6, 0], [0, 1]],
...                  [[0, 0], [0, 0], [6, 0], [0, 1]],
...                  [[0, 0], [7, 0], [6, 0], [1, 0]],
...                  [[0, 0], [7, 0], [7, 0], [7, 0]]]])
>>> target = tensor([[[[6, 0], [0, 1], [6, 0], [0, 1]],
...                   [[0, 1], [0, 1], [6, 0], [0, 1]],
...                   [[0, 1], [0, 1], [6, 0], [1, 0]],
...                   [[0, 1], [7, 0], [1, 0], [1, 0]],
...                   [[0, 1], [7, 0], [7, 0], [7, 0]]]])
>>> panoptic_quality = PanopticQuality(things = {0, 1}, stuffs = {6, 7})
>>> panoptic_quality(preds, target)
tensor(0.5463, dtype=torch.float64)
- plot(val=None, ax=None)[source]¶
Plot a single or multiple values from the metric.
- Parameters:
val (Union[Tensor, Sequence[Tensor], None]) – Either a single result from calling metric.forward or metric.compute or a list of these results. If no value is provided, will automatically call metric.compute and plot that result.
ax (Optional[Axes]) – A matplotlib axis object. If provided, the plot is added to that axis.
- Return type:
- Returns:
Figure object and Axes object
- Raises:
ModuleNotFoundError – If matplotlib is not installed
>>> from torch import tensor
>>> from torchmetrics.detection import PanopticQuality
>>> preds = tensor([[[[6, 0], [0, 0], [6, 0], [6, 0]],
...                  [[0, 0], [0, 0], [6, 0], [0, 1]],
...                  [[0, 0], [0, 0], [6, 0], [0, 1]],
...                  [[0, 0], [7, 0], [6, 0], [1, 0]],
...                  [[0, 0], [7, 0], [7, 0], [7, 0]]]])
>>> target = tensor([[[[6, 0], [0, 1], [6, 0], [0, 1]],
...                   [[0, 1], [0, 1], [6, 0], [0, 1]],
...                   [[0, 1], [0, 1], [6, 0], [1, 0]],
...                   [[0, 1], [7, 0], [1, 0], [1, 0]],
...                   [[0, 1], [7, 0], [7, 0], [7, 0]]]])
>>> metric = PanopticQuality(things = {0, 1}, stuffs = {6, 7})
>>> metric.update(preds, target)
>>> fig_, ax_ = metric.plot()
>>> # Example plotting multiple values
>>> from torch import tensor
>>> from torchmetrics.detection import PanopticQuality
>>> preds = tensor([[[[6, 0], [0, 0], [6, 0], [6, 0]],
...                  [[0, 0], [0, 0], [6, 0], [0, 1]],
...                  [[0, 0], [0, 0], [6, 0], [0, 1]],
...                  [[0, 0], [7, 0], [6, 0], [1, 0]],
...                  [[0, 0], [7, 0], [7, 0], [7, 0]]]])
>>> target = tensor([[[[6, 0], [0, 1], [6, 0], [0, 1]],
...                   [[0, 1], [0, 1], [6, 0], [0, 1]],
...                   [[0, 1], [0, 1], [6, 0], [1, 0]],
...                   [[0, 1], [7, 0], [1, 0], [1, 0]],
...                   [[0, 1], [7, 0], [7, 0], [7, 0]]]])
>>> metric = PanopticQuality(things = {0, 1}, stuffs = {6, 7})
>>> vals = []
>>> for _ in range(20):
...     vals.append(metric(preds, target))
>>> fig_, ax_ = metric.plot(vals)
Functional Interface¶
- torchmetrics.functional.detection.panoptic_quality(preds, target, things, stuffs, allow_unknown_preds_category=False)[source]¶
Compute Panoptic Quality for panoptic segmentations.
\[PQ = \frac{IOU}{TP + 0.5 FP + 0.5 FN}\]
where IOU, TP, FP and FN are respectively the sum of the intersection over union for true positives, the number of true positives, false positives and false negatives. This metric is inspired by the PQ implementation of panopticapi, a standard implementation for the PQ metric for panoptic segmentation.
- Parameters:
preds (Tensor) – torch tensor with panoptic detection of shape [height, width, 2] containing the pair (category_id, instance_id) for each pixel of the image. If the category_id refers to a stuff, the instance_id is ignored.
target (Tensor) – torch tensor with ground truth of shape [height, width, 2] containing the pair (category_id, instance_id) for each pixel of the image. If the category_id refers to a stuff, the instance_id is ignored.
things (Collection[int]) – Set of category_id for countable things.
stuffs (Collection[int]) – Set of category_id for uncountable stuffs.
allow_unknown_preds_category (bool) – Boolean flag to specify if unknown categories in the predictions are to be ignored in the metric computation or raise an exception when found.
- Raises:
ValueError – If things and stuffs have at least one category_id in common.
TypeError – If things or stuffs contain a non-integer category_id.
TypeError – If preds or target is not a torch.Tensor.
ValueError – If preds and target have different shapes.
ValueError – If preds has fewer than 3 dimensions.
ValueError – If the final dimension of preds has size != 2.
- Return type:
Tensor
Example
>>> from torch import tensor
>>> from torchmetrics.functional.detection import panoptic_quality
>>> preds = tensor([[[[6, 0], [0, 0], [6, 0], [6, 0]],
...                  [[0, 0], [0, 0], [6, 0], [0, 1]],
...                  [[0, 0], [0, 0], [6, 0], [0, 1]],
...                  [[0, 0], [7, 0], [6, 0], [1, 0]],
...                  [[0, 0], [7, 0], [7, 0], [7, 0]]]])
>>> target = tensor([[[[6, 0], [0, 1], [6, 0], [0, 1]],
...                   [[0, 1], [0, 1], [6, 0], [0, 1]],
...                   [[0, 1], [0, 1], [6, 0], [1, 0]],
...                   [[0, 1], [7, 0], [1, 0], [1, 0]],
...                   [[0, 1], [7, 0], [7, 0], [7, 0]]]])
>>> panoptic_quality(preds, target, things={0, 1}, stuffs={6, 7})
tensor(0.5463, dtype=torch.float64)
Error Relative Global Dim. Synthesis (ERGAS)
Module Interface
- class torchmetrics.image.ErrorRelativeGlobalDimensionlessSynthesis(ratio=4, reduction='elementwise_mean', **kwargs)
Calculate the relative dimensionless global error in synthesis (ERGAS).
This metric is used to measure the accuracy of a pan-sharpened image, based on the normalized average error of each band of the resulting image.
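The per-band normalization described above can be illustrated with a small pure-Python sketch. It assumes the commonly cited form ERGAS = 100 · ratio · sqrt(mean over bands of (RMSE_band / mean(target_band))²), with the score multiplied by `ratio`; the helper name `ergas_single_image` and the list-of-bands input layout are hypothetical. Treat it as an illustration of the formula, not as a replacement for the library implementation.

```python
import math
from typing import Sequence


def ergas_single_image(
    preds: Sequence[Sequence[float]],
    target: Sequence[Sequence[float]],
    ratio: float = 4,
) -> float:
    """ERGAS for one image; preds/target are lists of bands, each a flat list of pixels."""
    if len(preds) != len(target):
        raise ValueError("preds and target must have the same number of bands")
    total = 0.0
    for pred_band, target_band in zip(preds, target):
        n = len(target_band)
        # Root-mean-square error of this band.
        rmse = math.sqrt(sum((p - t) ** 2 for p, t in zip(pred_band, target_band)) / n)
        # Normalize by the mean intensity of the reference band.
        band_mean = sum(target_band) / n
        total += (rmse / band_mean) ** 2
    # Average the squared relative errors over bands, then scale.
    return 100 * ratio * math.sqrt(total / len(target))


# Mirrors the documented example: target is preds scaled by 0.75.
preds = [[1.0, 2.0, 3.0, 4.0]]
target = [[0.75 * v for v in band] for band in preds]
print(round(ergas_single_image(preds, target), 2))  # 146.06
```

Because each band's error is divided by that band's mean, ERGAS stays comparable across bands with very different intensity levels; a perfect reconstruction scores 0.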
As input to forward and update the metric accepts the following input:
preds (Tensor): estimated image of shape (N,C,H,W)
target (Tensor): ground truth image of shape (N,C,H,W)
As output of forward and compute the metric returns the following output:
ergas (Tensor): if reduction!='none', a float scalar tensor with the average ERGAS value over samples; otherwise a tensor of shape (N,) with one ERGAS value per sample
- Parameters:
ratio (Union[int, float]) – ratio of high resolution to low resolution
reduction (Literal['elementwise_mean', 'sum', 'none', None]) – a method to reduce the metric score over samples:
'elementwise_mean': takes the mean (default)
'sum': takes the sum
'none' or None: no reduction will be applied
kwargs (Any) – Additional keyword arguments, see Advanced metric settings for more info.
Example
>>> import torch
>>> from torchmetrics.image import ErrorRelativeGlobalDimensionlessSynthesis
>>> preds = torch.rand([16, 1, 16, 16], generator=torch.manual_seed(42))
>>> target = preds * 0.75
>>> ergas = ErrorRelativeGlobalDimensionlessSynthesis()
>>> torch.round(ergas(preds, target))
tensor(154.)
- plot(val=None, ax=None)
Plot a single or multiple values from the metric.
- Parameters:
val (Union[Tensor, Sequence[Tensor], None]) – Either a single result from calling metric.forward or metric.compute or a list of these results. If no value is provided, metric.compute is called automatically and that result is plotted.
ax (Optional[Axes]) – A matplotlib axis object. If provided, the plot is added to that axis.
- Returns:
Figure and Axes object
- Raises:
ModuleNotFoundError – If matplotlib is not installed
>>> # Example plotting a single value
>>> import torch
>>> from torchmetrics.image import ErrorRelativeGlobalDimensionlessSynthesis
>>> preds = torch.rand([16, 1, 16, 16], generator=torch.manual_seed(42))
>>> target = preds * 0.75
>>> metric = ErrorRelativeGlobalDimensionlessSynthesis()
>>> metric.update(preds, target)
>>> fig_, ax_ = metric.plot()
>>> # Example plotting multiple values
>>> import torch
>>> from torchmetrics.image import ErrorRelativeGlobalDimensionlessSynthesis
>>> preds = torch.rand([16, 1, 16, 16], generator=torch.manual_seed(42))
>>> target = preds * 0.75
>>> metric = ErrorRelativeGlobalDimensionlessSynthesis()
>>> values = []
>>> for _ in range(10):
...     values.append(metric(preds, target))
>>> fig_, ax_ = metric.plot(values)
Functional Interface
- torchmetrics.functional.image.error_relative_global_dimensionless_synthesis(preds, target, ratio=4, reduction='elementwise_mean')
Calculate the relative dimensionless global error in synthesis (ERGAS, from the French "Erreur Relative Globale Adimensionnelle de Synthèse").
- Parameters:
preds (Tensor) – estimated image
target (Tensor) – ground truth image
ratio (Union[int, float]) – ratio of high resolution to low resolution
reduction (Literal['elementwise_mean', 'sum', 'none', None]) – a method to reduce the metric score over samples:
'elementwise_mean': takes the mean (default)
'sum': takes the sum
'none' or None: no reduction will be applied
- Return type:
Tensor
- Returns:
Tensor with the ERGAS score
- Raises:
TypeError – If preds and target don’t have the same data type.
ValueError – If preds and target don’t have BxCxHxW shape.
Example
>>> import torch
>>> from torchmetrics.functional.image import error_relative_global_dimensionless_synthesis
>>> gen = torch.manual_seed(42)
>>> preds = torch.rand([16, 1, 16, 16], generator=gen)
>>> target = preds * 0.75
>>> ergds = error_relative_global_dimensionless_synthesis(preds, target)
>>> torch.round(ergds)
tensor(154.)
References
[1] Qian Du; Nicholas H. Younan; Roger King; Vijay P. Shah, “On the Performance Evaluation of Pan-Sharpening Techniques” in IEEE Geoscience and Remote Sensing Letters, vol. 4, no. 4, pp. 518-522, 15 October 2007, doi: 10.1109/LGRS.2007.896328.
Frechet Inception Distance (FID)
Module Interface
- class torchmetrics.image.fid.FrechetInceptionDistance(feature=2048, reset_real_features=True, normalize=False, **kwargs)
Calculate the Fréchet inception distance (FID), which is used to assess the quality of generated images.
\[FID = \|\mu - \mu_w\|^2 + \operatorname{tr}\left(\Sigma + \Sigma_w - 2(\Sigma \Sigma_w)^{\frac{1}{2}}\right)\] where \(\mathcal{N}(\mu, \Sigma)\) is the multivariate normal distribution estimated from Inception v3 (fid ref1) features calculated on real-life images and \(\mathcal{N}(\mu_w, \Sigma_w)\) is the multivariate normal distribution estimated from Inception v3 features calculated on generated (fake) images. The metric was originally proposed in fid ref1.
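The matrix square root in the trace term needs a linear-algebra library, but the structure of FID is easy to see in a special case. The sketch below assumes diagonal covariance matrices: diagonal matrices commute, so the trace term reduces to the sum over dimensions of (σ_i - σ_{w,i})², and FID becomes the squared distance between the means plus the squared distance between the per-dimension standard deviations. The function name `fid_diagonal` and the use of the sample (n - 1) variance are assumptions for illustration; the metric itself uses full covariance matrices of Inception features.

```python
import math
from statistics import fmean, variance
from typing import Sequence


def fid_diagonal(
    real_feats: Sequence[Sequence[float]], fake_feats: Sequence[Sequence[float]]
) -> float:
    """FID between two feature sets, assuming diagonal covariance matrices."""
    dims = len(real_feats[0])
    total = 0.0
    for d in range(dims):
        real_d = [f[d] for f in real_feats]
        fake_d = [f[d] for f in fake_feats]
        # Contribution of this dimension to ||mu - mu_w||^2.
        total += (fmean(real_d) - fmean(fake_d)) ** 2
        # Trace term for diagonal covariances: (sigma - sigma_w)^2.
        total += (math.sqrt(variance(real_d)) - math.sqrt(variance(fake_d))) ** 2
    return total


real = [[0.0, 1.0], [1.0, 3.0], [2.0, 5.0]]
fake = [[3.0, 1.0], [4.0, 3.0], [5.0, 5.0]]
print(fid_diagonal(real, fake))  # 9.0: only the first dimension's mean differs (by 3)
```

In this toy case the spreads match, so the whole distance comes from the mean term; two sets with equal means but different spreads would instead be penalized through the trace term.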
Using the default feature extraction (Inception v3 using the original weights from fid ref2), the input is expected to be mini-batches of 3-channel RGB images of shape (3xHxW). If argument normalize is True, images are expected to be dtype float and have values in the [0,1] range; if normalize is set to False, images are expected to have dtype uint8 and take values in the [0, 255] range. All images will be resized to 299 x 299, which is the size of the original training data. The boolean flag real determines if the images should update the statistics of the real distribution or the fake distribution.
This metric is known to be numerically unstable, so for best results we recommend computing it in torch.float64 (the default is torch.float32), which can be set with the .set_dtype method of the metric.
Note
Using this metric requires torch 1.9 or higher.
Note
Using this metric with the default feature extractor requires that torch-fidelity is installed. Either install as pip install torchmetrics[image] or pip install torch-fidelity.
As input to forward and update the metric accepts the following input:
imgs (Tensor): tensor with images fed to the feature extractor
real (bool): bool indicating if imgs belong to the real or the fake distribution
As output of forward and compute the metric returns the following output:
fid (Tensor): float scalar tensor with mean FID value over samples
- Parameters:
feature (Union[int, Module]) – Either an integer or nn.Module:
an integer will indicate the inceptionv3 feature layer to choose. Can be one of the following: 64, 192, 768, 2048
an nn.Module for using a custom feature extractor. Expects that its forward method returns an (N,d) matrix where N is the batch size and d is the feature size.
reset_real_features (bool) – Whether to also reset the real features. Since in many cases the real dataset does not change, the features can be cached to avoid recomputing them, which is costly. Set this to False if your dataset does not change.
kwargs (Any) – Additional keyword arguments, see Advanced metric settings for more info.
- Raises:
ValueError – If torch version is lower than 1.9
ModuleNotFoundError – If feature is set to an int (default settings) and torch-fidelity is not installed
ValueError – If feature is set to an int not in [64, 192, 768, 2048]
TypeError – If feature is not a str, int or torch.nn.Module
ValueError – If reset_real_features is not a bool
Example
>>> import torch
>>> _ = torch.manual_seed(123)
>>> from torchmetrics.image.fid import FrechetInceptionDistance
>>> fid = FrechetInceptionDistance(feature=64)
>>> # generate two slightly overlapping image intensity distributions
>>> imgs_dist1 = torch.randint(0, 200, (100, 3, 299, 299), dtype=torch.uint8)
>>> imgs_dist2 = torch.randint(100, 255, (100, 3, 299, 299), dtype=torch.uint8)
>>> fid.update(imgs_dist1, real=True)
>>> fid.update(imgs_dist2, real=False)
>>> fid.compute()
tensor(12.7202)
- plot(val=None, ax=None)
Plot a single or multiple values from the metric.
- Parameters:
val (Union[Tensor, Sequence[Tensor], None]) – Either a single result from calling metric.forward or metric.compute or a list of these results. If no value is provided, metric.compute is called automatically and that result is plotted.
ax (Optional[Axes]) – A matplotlib axis object. If provided, the plot is added to that axis.
- Returns:
Figure and Axes object
- Raises:
ModuleNotFoundError – If matplotlib is not installed
>>> # Example plotting a single value
>>> import torch
>>> from torchmetrics.image.fid import FrechetInceptionDistance
>>> imgs_dist1 = torch.randint(0, 200, (100, 3, 299, 299), dtype=torch.uint8)
>>> imgs_dist2 = torch.randint(100, 255, (100, 3, 299, 299), dtype=torch.uint8)
>>> metric = FrechetInceptionDistance(feature=64)
>>> metric.update(imgs_dist1, real=True)
>>> metric.update(imgs_dist2, real=False)
>>> fig_, ax_ = metric.plot()
>>> # Example plotting multiple values
>>> import torch
>>> from torchmetrics.image.fid import FrechetInceptionDistance
>>> imgs_dist1 = lambda: torch.randint(0, 200, (100, 3, 299, 299), dtype=torch.uint8)
>>> imgs_dist2 = lambda: torch.randint(100, 255, (100, 3, 299, 299), dtype=torch.uint8)
>>> metric = FrechetInceptionDistance(feature=64)
>>> values = []
>>> for _ in range(3):
...     metric.update(imgs_dist1(), real=True)
...     metric.update(imgs_dist2(), real=False)
...     values.append(metric.compute())
...     metric.reset()
>>> fig_, ax_ = metric.plot(values)
Image Gradients
Functional Interface
- torchmetrics.functional.image.image_gradients(img)
Compute the gradients of a given image using finite differences.
- Parameters:
img (Tensor) – An (N, C, H, W) input tensor where C is the number of image channels
- Return type:
Tuple[Tensor, Tensor]
- Returns:
Tuple of (dy, dx) with each gradient of shape [N, C, H, W]
- Raises:
RuntimeError – If img is not a 4D tensor.
Example
>>> import torch
>>> from torchmetrics.functional.image import image_gradients
>>> image = torch.arange(0, 1*1*5*5, dtype=torch.float32)
>>> image = torch.reshape(image, (1, 1, 5, 5))
>>> dy, dx = image_gradients(image)
>>> dy[0, 0, :, :]
tensor([[5., 5., 5., 5., 5.],
        [5., 5., 5., 5., 5.],
        [5., 5., 5., 5., 5.],
        [5., 5., 5., 5., 5.],
        [0., 0., 0., 0., 0.]])
Note
The implementation follows the 1-step finite difference method used by the TensorFlow implementation. The values are organized such that the gradient [I(x+1, y) - I(x, y)] is stored at the (x, y) location.
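The 1-step forward differences described in the note can be checked with a dependency-free sketch on a single H x W channel. The helper name `image_gradients_2d` is hypothetical. It leaves the last row of dy and the last column of dx at zero, which matches the zero row in the documented output.

```python
from typing import List, Sequence, Tuple

Grid = List[List[float]]


def image_gradients_2d(img: Sequence[Sequence[float]]) -> Tuple[Grid, Grid]:
    """Forward differences: dy[i][j] = img[i+1][j] - img[i][j], dx[i][j] = img[i][j+1] - img[i][j]."""
    height, width = len(img), len(img[0])
    dy = [[0.0] * width for _ in range(height)]
    dx = [[0.0] * width for _ in range(height)]
    for i in range(height):
        for j in range(width):
            if i + 1 < height:  # the last row has no neighbor below, so it stays 0
                dy[i][j] = img[i + 1][j] - img[i][j]
            if j + 1 < width:  # the last column has no right neighbor, so it stays 0
                dx[i][j] = img[i][j + 1] - img[i][j]
    return dy, dx


# Same 5 x 5 ramp as the example: values 0..24 laid out row by row.
image = [[float(5 * row + col) for col in range(5)] for row in range(5)]
dy, dx = image_gradients_2d(image)
print(dy[0], dy[-1])  # [5.0, 5.0, 5.0, 5.0, 5.0] [0.0, 0.0, 0.0, 0.0, 0.0]
print(dx[0])  # [1.0, 1.0, 1.0, 1.0, 0.0]
```

Moving down one row adds 5 to every pixel of the ramp and moving right adds 1, which is why dy is 5 and dx is 1 everywhere except on the zero-padded border.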
Inception Score
Module Interface
- class torchmetrics.image.inception.InceptionScore(feature='logits_unbiased', splits=10, normalize=False, **kwargs)
Calculate the Inception Score (IS), which is used to assess how realistic generated images are.
\[IS = \exp(\mathbb{E}_x KL(p(y | x) \| p(y)))\] where \(KL(p(y | x) \| p(y))\) is the KL divergence between the conditional distribution \(p(y|x)\) and the marginal distribution \(p(y)\). Both the conditional and marginal distributions are calculated from features extracted from the images. The score is calculated on random splits of the images such that both a mean and a standard deviation of the score are returned. The metric was originally proposed in inception ref1.
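A minimal pure-Python sketch of the formula, assuming the class probabilities p(y|x) for each image have already been computed (in the metric they come from the Inception network). The helper name is hypothetical, and it computes a single score over all rows, whereas the metric additionally divides the images into `splits` groups and reports the mean and standard deviation across them.

```python
import math
from typing import Sequence


def inception_score(probs: Sequence[Sequence[float]]) -> float:
    """exp(E_x KL(p(y|x) || p(y))) for a list of per-image class-probability rows."""
    num_images, num_classes = len(probs), len(probs[0])
    # Marginal p(y): average of the conditional distributions.
    marginal = [sum(row[c] for row in probs) / num_images for c in range(num_classes)]
    mean_kl = 0.0
    for row in probs:
        # 0 * log(0) is treated as 0, so zero-probability classes are skipped.
        kl = sum(p * math.log(p / marginal[c]) for c, p in enumerate(row) if p > 0)
        mean_kl += kl / num_images
    return math.exp(mean_kl)


# Four confident predictions on four different classes: the best possible score is 4.
confident_and_diverse = [[1.0 if i == j else 0.0 for j in range(4)] for i in range(4)]
# Every image gets the same prediction: every KL term is 0, so the score is 1.
identical = [[0.25, 0.75]] * 3
print(round(inception_score(confident_and_diverse), 6), round(inception_score(identical), 6))  # 4.0 1.0
```

The two extremes show what the score rewards: confident per-image predictions (sharp p(y|x)) combined with diversity across images (broad p(y)).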
Using the default feature extraction (Inception v3 using the original weights from inception ref2), the input is expected to be mini-batches of 3-channel RGB images of shape (3xHxW). If argument normalize is True, images are expected to be dtype float and have values in the [0,1] range; if normalize is set to False, images are expected to have dtype uint8 and take values in the [0, 255] range. All images will be resized to 299 x 299, which is the size of the original training data.
Note
Using this metric with the default feature extractor requires that torch-fidelity is installed. Either install as pip install torchmetrics[image] or pip install torch-fidelity.
As input to forward and update the metric accepts the following input:
imgs (Tensor): tensor with images fed to the feature extractor
As output of forward and compute the metric returns the following output:
a tuple of two float scalar tensors: the mean and the standard deviation of the Inception Score over the splits
- Parameters:
feature (Union[str, int, Module]) – Either a str, integer or nn.Module:
a str or integer will indicate the inceptionv3 feature layer to choose. Can be one of the following: ‘logits_unbiased’, 64, 192, 768, 2048
an nn.Module for using a custom feature extractor. Expects that its forward method returns an (N,d) matrix where N is the batch size and d is the feature size.
splits (int) – number of splits the images are divided into when computing the inception score; the mean and standard deviation are taken over these splits
kwargs (Any) – Additional keyword arguments, see Advanced metric settings for more info.
- Raises:
ValueError – If feature is set to a str or int and torch-fidelity is not installed
ValueError – If feature is set to a str or int and not one of ('logits_unbiased', 64, 192, 768, 2048)
TypeError – If feature is not a str, int or torch.nn.Module
Example
>>> import torch
>>> _ = torch.manual_seed(123)
>>> from torchmetrics.image.inception import InceptionScore
>>> inception = InceptionScore()
>>> # generate some images
>>> imgs = torch.randint(0, 255, (100, 3, 299, 299), dtype=torch.uint8)
>>> inception.update(imgs)
>>> inception.compute()
(tensor(1.0544), tensor(0.0117))
- plot(val=None, ax=None)
Plot a single or multiple values from the metric.
- Parameters:
val (Union[Tensor, Sequence[Tensor], None]) – Either a single result from calling metric.forward or metric.compute or a list of these results. If no value is provided, metric.compute is called automatically and that result is plotted.
ax (Optional[Axes]) – A matplotlib axis object. If provided, the plot is added to that axis.
- Returns:
Figure and Axes object
- Raises:
ModuleNotFoundError – If matplotlib is not installed
>>> # Example plotting a single value
>>> import torch
>>> from torchmetrics.image.inception import InceptionScore
>>> metric = InceptionScore()
>>> metric.update(torch.randint(0, 255, (50, 3, 299, 299), dtype=torch.uint8))
>>> fig_, ax_ = metric.plot()  # the returned plot only shows the mean value by default
>>> # Example plotting multiple values
>>> import torch
>>> from torchmetrics.image.inception import InceptionScore
>>> metric = InceptionScore()
>>> values = []
>>> for _ in range(3):
...     # we index by 0 such that only the mean value is plotted
...     values.append(metric(torch.randint(0, 255, (50, 3, 299, 299), dtype=torch.uint8))[0])
>>> fig_, ax_ = metric.plot(values)
Kernel Inception Distance
Module Interface
- class torchmetrics.image.kid.KernelInceptionDistance(feature=2048, subsets=100, subset_size=1000, degree=3, gamma=None, coef=1.0, reset_real_features=True, normalize=False, **kwargs)
Calculate the Kernel Inception Distance (KID), which is used to assess the quality of generated images.
\[KID = MMD(f_{real}, f_{fake})^2\] where \(MMD\) is the maximum mean discrepancy and \(f_{real}, f_{fake}\) are the features extracted from real and fake images, see kid ref1 for more details. In particular, calculating the MMD requires the evaluation of a polynomial kernel function \(k\)
\[k(x,y) = (\gamma \cdot x^T y + coef)^{degree}\] which measures the similarity between two features. In practice the MMD is calculated over a number of subsets to obtain both the mean and the standard deviation of KID.
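The kernel and the MMD estimate can be sketched without dependencies. This sketch assumes the unbiased MMD² estimator for two equally sized feature sets and, when `gamma` is None, a default of 1 / feature size; the function names are hypothetical. The metric additionally repeats the estimate over `subsets` random subsets of `subset_size` samples to obtain the mean and standard deviation.

```python
from typing import Optional, Sequence

Vector = Sequence[float]


def poly_kernel(
    x: Vector, y: Vector, degree: int = 3, gamma: Optional[float] = None, coef: float = 1.0
) -> float:
    """k(x, y) = (gamma * x^T y + coef) ** degree."""
    if gamma is None:
        gamma = 1.0 / len(x)  # assumed default: inverse of the feature size
    return (gamma * sum(a * b for a, b in zip(x, y)) + coef) ** degree


def kid_mmd2(real: Sequence[Vector], fake: Sequence[Vector], **kernel_kwargs) -> float:
    """Unbiased squared MMD between two feature sets of equal size m (m >= 2)."""
    m = len(real)
    if m != len(fake) or m < 2:
        raise ValueError("expected two feature sets of the same size with at least 2 samples each")

    def k(a: Vector, b: Vector) -> float:
        return poly_kernel(a, b, **kernel_kwargs)

    # Within-set sums skip the diagonal (i == j) to keep the estimate unbiased.
    k_xx = sum(k(real[i], real[j]) for i in range(m) for j in range(m) if i != j)
    k_yy = sum(k(fake[i], fake[j]) for i in range(m) for j in range(m) if i != j)
    k_xy = sum(k(a, b) for a in real for b in fake)
    return (k_xx + k_yy) / (m * (m - 1)) - 2 * k_xy / (m * m)


# Linear kernel (degree=1, gamma=1, coef=0) on constant features: equals the squared mean difference.
print(kid_mmd2([[1.0], [1.0]], [[0.0], [0.0]], degree=1, gamma=1.0, coef=0.0))  # 1.0
```

Because the estimator is unbiased rather than non-negative by construction, individual subset estimates can come out slightly below zero when the two distributions are very close.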
Using the default feature extraction (Inception v3 using the original weights from kid ref2), the input is expected to be mini-batches of 3-channel RGB images of shape (3xHxW). If argument normalize is True, images are expected to be dtype float and have values in the [0,1] range; if normalize is set to False, images are expected to have dtype uint8 and take values in the [0, 255] range. All images will be resized to 299 x 299, which is the size of the original training data. The boolean flag real determines if the images should update the statistics of the real distribution or the fake distribution.
Note
Using this metric with the default feature extractor requires that torch-fidelity is installed. Either install as pip install torchmetrics[image] or pip install torch-fidelity.
As input to forward and update the metric accepts the following input:
imgs (Tensor): tensor with images fed to the feature extractor, of shape (N,C,H,W)
real (bool): bool indicating if imgs belong to the real or the fake distribution
As output of forward and compute the metric returns the following output:
kid_mean (Tensor): float scalar tensor with the mean value over subsets
kid_std (Tensor): float scalar tensor with the standard deviation over subsets
- Parameters:
feature (Union[str, int, Module]) – Either a str, integer or nn.Module:
a str or integer will indicate the inceptionv3 feature layer to choose. Can be one of the following: ‘logits_unbiased’, 64, 192, 768, 2048
an nn.Module for using a custom feature extractor. Expects that its forward method returns an (N,d) matrix where N is the batch size and d is the feature size.
subsets (int) – Number of subsets to calculate the mean and standard deviation scores over
subset_size (int) – Number of randomly picked samples in each subset
degree (int) – Degree of the polynomial kernel function
gamma (Optional[float]) – Scale-length of the polynomial kernel. If set to None, it will be automatically set to 1 / feature size
coef (float) – Bias term in the polynomial kernel.
reset_real_features (bool) – Whether to also reset the real features. Since in many cases the real dataset does not change, the features can be cached to avoid recomputing them, which is costly. Set this to False if your dataset does not change.
kwargs (Any) – Additional keyword arguments, see Advanced metric settings for more info.
- Raises:
ValueError – If feature is set to an int (default settings) and torch-fidelity is not installed
ValueError – If feature is set to an int not in (64, 192, 768, 2048)
ValueError – If subsets is not an integer larger than 0
ValueError – If subset_size is not an integer larger than 0
ValueError – If degree is not an integer larger than 0
ValueError – If gamma is neither None nor a float larger than 0
ValueError – If coef is not a float larger than 0
ValueError – If reset_real_features is not a bool
Example
>>> import torch
>>> _ = torch.manual_seed(123)
>>> from torchmetrics.image.kid import KernelInceptionDistance
>>> kid = KernelInceptionDistance(subset_size=50)
>>> # generate two slightly overlapping image intensity distributions
>>> imgs_dist1 = torch.randint(0, 200, (100, 3, 299, 299), dtype=torch.uint8)
>>> imgs_dist2 = torch.randint(100, 255, (100, 3, 299, 299), dtype=torch.uint8)
>>> kid.update(imgs_dist1, real=True)
>>> kid.update(imgs_dist2, real=False)
>>> kid.compute()
(tensor(0.0337), tensor(0.0023))
- plot(val=None, ax=None)
Plot a single or multiple values from the metric.
- Parameters:
val (Union[Tensor, Sequence[Tensor], None]) – Either a single result from calling metric.forward or metric.compute or a list of these results. If no value is provided, metric.compute is called automatically and that result is plotted.
ax (Optional[Axes]) – A matplotlib axis object. If provided, the plot is added to that axis.
- Returns:
Figure and Axes object
- Raises:
ModuleNotFoundError – If matplotlib is not installed
>>> # Example plotting a single value
>>> import torch
>>> from torchmetrics.image.kid import KernelInceptionDistance
>>> imgs_dist1 = torch.randint(0, 200, (30, 3, 299, 299), dtype=torch.uint8)
>>> imgs_dist2 = torch.randint(100, 255, (30, 3, 299, 299), dtype=torch.uint8)
>>> metric = KernelInceptionDistance(subsets=3, subset_size=20)
>>> metric.update(imgs_dist1, real=True)
>>> metric.update(imgs_dist2, real=False)
>>> fig_, ax_ = metric.plot()
>>> # Example plotting multiple values
>>> import torch
>>> from torchmetrics.image.kid import KernelInceptionDistance
>>> imgs_dist1 = lambda: torch.randint(0, 200, (30, 3, 299, 299), dtype=torch.uint8)
>>> imgs_dist2 = lambda: torch.randint(100, 255, (30, 3, 299, 299), dtype=torch.uint8)
>>> metric = KernelInceptionDistance(subsets=3, subset_size=20)
>>> values = []
>>> for _ in range(3):
...     metric.update(imgs_dist1(), real=True)
...     metric.update(imgs_dist2(), real=False)
...     values.append(metric.compute()[0])
...     metric.reset()
>>> fig_, ax_ = metric.plot(values)
Learned Perceptual Image Patch Similarity (LPIPS)
Module Interface
- class torchmetrics.image.lpip.LearnedPerceptualImagePatchSimilarity(net_type='alex', reduction='mean', normalize=False, **kwargs)
The Learned Perceptual Image Patch Similarity (LPIPS) calculates the perceptual similarity between two images.
LPIPS essentially computes the similarity between the activations of two image patches for some pre-defined network. This measure has been shown to match human perception well. A low LPIPS score means that image patches are perceptually similar.
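The comparison of activations can be sketched as follows, assuming the backbone activations have already been extracted. Following the published LPIPS recipe, each spatial position's channel vector is unit-normalized, the squared differences are weighted per channel, averaged over spatial positions and summed over layers. In the metric the activations come from the pretrained `net_type` backbone and the per-channel weights are learned; here both are toy values, and `lpips_from_activations` is a hypothetical name, not a TorchMetrics function.

```python
import math
from typing import List, Sequence

Layer = Sequence[Sequence[float]]  # spatial positions x channels


def _unit_normalize(vec: Sequence[float], eps: float = 1e-10) -> List[float]:
    """Scale a channel vector to (almost) unit L2 norm."""
    norm = math.sqrt(sum(v * v for v in vec))
    return [v / (norm + eps) for v in vec]


def lpips_from_activations(
    acts1: Sequence[Layer],
    acts2: Sequence[Layer],
    weights: Sequence[Sequence[float]],
) -> float:
    """Sum over layers of the spatially averaged, channel-weighted squared difference."""
    total = 0.0
    for layer1, layer2, layer_weights in zip(acts1, acts2, weights):
        layer_sum = 0.0
        for pos1, pos2 in zip(layer1, layer2):
            a, b = _unit_normalize(pos1), _unit_normalize(pos2)
            # Non-negative per-channel weights (learned in the real metric) scale each squared difference.
            layer_sum += sum(w * (x - y) ** 2 for w, x, y in zip(layer_weights, a, b))
        total += layer_sum / len(layer1)  # average over spatial positions
    return total


# One layer, two spatial positions, two channels, toy weights of 1.0.
acts_a = [[[1.0, 0.0], [0.0, 2.0]]]
acts_b = [[[0.0, 1.0], [0.0, 5.0]]]
print(round(lpips_from_activations(acts_a, acts_b, [[1.0, 1.0]]), 6))  # 1.0
```

The second spatial position contributes almost nothing even though the raw activations differ (2.0 vs 5.0): after unit normalization only the direction of the channel vector matters, which is what makes the distance insensitive to activation scale.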
Both input image patches are expected to have shape (N, 3, H, W). The minimum size of H, W depends on the chosen backbone (see net_type arg).
Note
Using this metric requires the lpips package to be installed. Either install as pip install torchmetrics[image] or pip install lpips.
Note
This metric is not scriptable when using torch<1.8. Please update your PyTorch installation if this is an issue.
As input to forward and update the metric accepts the following input:
img1 (Tensor): tensor with images of shape (N, 3, H, W)
img2 (Tensor): tensor with images of shape (N, 3, H, W)
As output of forward and compute the metric returns the following output:
lpips (Tensor): returns float scalar tensor with average LPIPS value over samples
- Parameters:
net_type (Literal['vgg', 'alex', 'squeeze']) – str indicating backbone network type to use. Choose between ‘alex’, ‘vgg’ or ‘squeeze’
reduction (Literal['sum', 'mean']) – str indicating how to reduce over the batch dimension. Choose between ‘sum’ or ‘mean’.
normalize (bool) – by default this is False, meaning that the input is expected to be in the [-1,1] range. If set to True, the input is instead expected to be in the [0,1] range.
kwargs (Any) – Additional keyword arguments, see Advanced metric settings for more info.
- Raises:
ModuleNotFoundError – If lpips package is not installed
ValueError – If net_type is not one of "vgg", "alex" or "squeeze"
ValueError – If reduction is not one of "mean" or "sum"
Example
>>> import torch
>>> _ = torch.manual_seed(123)
>>> from torchmetrics.image.lpip import LearnedPerceptualImagePatchSimilarity
>>> lpips = LearnedPerceptualImagePatchSimilarity(net_type='squeeze')
>>> # LPIPS needs the images to be in the [-1, 1] range.
>>> img1 = (torch.rand(10, 3, 100, 100) * 2) - 1
>>> img2 = (torch.rand(10, 3, 100, 100) * 2) - 1
>>> lpips(img1, img2)
tensor(0.1046, grad_fn=<SqueezeBackward0>)
- plot(val=None, ax=None)
Plot a single or multiple values from the metric.
- Parameters:
val (Union[Tensor, Sequence[Tensor], None]) – Either a single result from calling metric.forward or metric.compute or a list of these results. If no value is provided, metric.compute is called automatically and that result is plotted.
ax (Optional[Axes]) – A matplotlib axis object. If provided, the plot is added to that axis.
- Returns:
Figure and Axes object
- Raises:
ModuleNotFoundError – If matplotlib is not installed
>>> # Example plotting a single value
>>> import torch
>>> from torchmetrics.image.lpip import LearnedPerceptualImagePatchSimilarity
>>> metric = LearnedPerceptualImagePatchSimilarity(net_type='squeeze')
>>> metric.update(torch.rand(10, 3, 100, 100), torch.rand(10, 3, 100, 100))
>>> fig_, ax_ = metric.plot()
>>> # Example plotting multiple values
>>> import torch
>>> from torchmetrics.image.lpip import LearnedPerceptualImagePatchSimilarity
>>> metric = LearnedPerceptualImagePatchSimilarity(net_type='squeeze')
>>> values = []
>>> for _ in range(3):
...     values.append(metric(torch.rand(10, 3, 100, 100), torch.rand(10, 3, 100, 100)))
>>> fig_, ax_ = metric.plot(values)
Functional Interface
- torchmetrics.functional.image.learned_perceptual_image_patch_similarity(img1, img2, net_type='alex', reduction='mean', normalize=False)
The Learned Perceptual Image Patch Similarity (LPIPS) calculates the perceptual similarity between two images.
LPIPS essentially computes the similarity between the activations of two image patches for some pre-defined network. This measure has been shown to match human perception well. A low LPIPS score means that image patches are perceptually similar.
Both input image patches are expected to have shape (N, 3, H, W). The minimum size of H, W depends on the chosen backbone (see net_type arg).
- Parameters:
img1 (Tensor) – first set of images
img2 (Tensor) – second set of images
net_type (Literal['alex', 'vgg', 'squeeze']) – str indicating backbone network type to use. Choose between ‘alex’, ‘vgg’ or ‘squeeze’
reduction (Literal['sum', 'mean']) – str indicating how to reduce over the batch dimension. Choose between ‘sum’ or ‘mean’.
normalize (bool) – by default this is False, meaning that the input is expected to be in the [-1,1] range. If set to True, the input is instead expected to be in the [0,1] range.
- Return type:
Tensor
Example
>>> import torch
>>> _ = torch.manual_seed(123)
>>> from torchmetrics.functional.image.lpips import learned_perceptual_image_patch_similarity
>>> img1 = (torch.rand(10, 3, 100, 100) * 2) - 1
>>> img2 = (torch.rand(10, 3, 100, 100) * 2) - 1
>>> learned_perceptual_image_patch_similarity(img1, img2, net_type='squeeze')
tensor(0.1008, grad_fn=<DivBackward0>)
Memorization-Informed Frechet Inception Distance (MiFID)
Module Interface
- class torchmetrics.image.mifid.MemorizationInformedFrechetInceptionDistance(feature=2048, reset_real_features=True, normalize=False, cosine_distance_eps=0.1, **kwargs)
Calculate Memorization-Informed Frechet Inception Distance (MIFID).
MIFID is an improved variation of the Frechet Inception Distance (FID) that penalizes memorization of the training set by the generator. It is calculated as
\[MIFID = \frac{FID(F_{real}, F_{fake})}{M(F_{real}, F_{fake})}\] where \(FID\) is the normal FID score and \(M\) is the memorization penalty. The memorization penalty essentially corresponds to the average minimum cosine distance between the features of the real and fake distributions.
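The memorization penalty can be sketched in pure Python by following the description above. For every fake feature vector, take the smallest cosine distance (1 - cosine similarity) to any real feature vector and average those minima. If the average is not below `cosine_distance_eps`, replace it with 1 so it has no effect. The FID value is taken as given. The helper names and the small constant guarding against division by zero are assumptions, and details such as zero-vector handling may differ from the library's implementation.

```python
import math
from typing import Sequence

Vector = Sequence[float]


def _cosine_distance(a: Vector, b: Vector) -> float:
    dot = sum(x * y for x, y in zip(a, b))
    return 1.0 - dot / (math.sqrt(sum(x * x for x in a)) * math.sqrt(sum(y * y for y in b)))


def memorization_penalty(
    real: Sequence[Vector], fake: Sequence[Vector], cosine_distance_eps: float = 0.1
) -> float:
    """Average over fake vectors of the minimum cosine distance to the real set, thresholded."""
    mean_min = sum(min(_cosine_distance(f, r) for r in real) for f in fake) / len(fake)
    # Averages at or above the threshold count as "no memorization" (penalty 1).
    return mean_min if mean_min < cosine_distance_eps else 1.0


def mifid(
    fid_value: float, real: Sequence[Vector], fake: Sequence[Vector], cosine_distance_eps: float = 0.1
) -> float:
    guard = 1e-15  # assumed guard against dividing by a zero penalty
    return fid_value / (memorization_penalty(real, fake, cosine_distance_eps) + guard)


real = [[1.0, 0.0]]
print(mifid(5.0, real, [[0.0, 1.0]]))  # far from the real data: penalty 1, MIFID is essentially the FID
print(mifid(5.0, real, [[1.0, 0.1]]))  # near-copy of a real vector: MIFID is inflated
```

Dividing by a small penalty is what makes near-copies of the training data expensive: the closer the generated features sit to real ones, the larger the reported score.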
Using the default feature extraction (Inception v3 using the original weights from fid ref2), the input is expected to be mini-batches of 3-channel RGB images of shape (3 x H x W). If argument normalize is True, images are expected to be dtype float and have values in the [0, 1] range; if normalize is set to False, images are expected to have dtype uint8 and take values in the [0, 255] range. All images will be resized to 299 x 299, which is the size of the original training data. The boolean flag real determines if the images should update the statistics of the real distribution or the fake distribution.
Note
Using this metric requires scipy to be installed. Either install as pip install torchmetrics[image] or pip install scipy.
Note
Using this metric with the default feature extractor requires that torch-fidelity is installed. Either install as pip install torchmetrics[image] or pip install torch-fidelity.
As input to forward and update the metric accepts the following input:
imgs (Tensor): tensor with images fed to the feature extractor
real (bool): bool indicating if imgs belong to the real or the fake distribution
As output of forward and compute the metric returns the following output:
mifid (Tensor): float scalar tensor with mean MIFID value over samples
- Parameters:
feature (Union[int, Module]) – Either an integer or nn.Module:
an integer will indicate the inceptionv3 feature layer to choose. Can be one of the following: 64, 192, 768, 2048
an nn.Module for using a custom feature extractor. Expects that its forward method returns an (N,d) matrix where N is the batch size and d is the feature size.
reset_real_features (bool) – Whether to also reset the real features. Since in many cases the real dataset does not change, the features can be cached to avoid recomputing them, which is costly. Set this to False if your dataset does not change.
cosine_distance_eps (float) – Epsilon value for the cosine distance. If the cosine distance is larger than this value, it is set to 1 and thus ignored in the MIFID calculation.
kwargs (Any) – Additional keyword arguments, see Advanced metric settings for more info.
- Raises:
RuntimeError – If the torch version is less than 1.10
ValueError – If feature is set to an int and torch-fidelity is not installed
ValueError – If feature is set to an int not in [64, 192, 768, 2048]
TypeError – If feature is not a str, int or torch.nn.Module
ValueError – If reset_real_features is not a bool
Example
>>> import torch
>>> _ = torch.manual_seed(42)
>>> from torchmetrics.image.mifid import MemorizationInformedFrechetInceptionDistance
>>> mifid = MemorizationInformedFrechetInceptionDistance(feature=64)
>>> # generate two slightly overlapping image intensity distributions
>>> imgs_dist1 = torch.randint(0, 200, (100, 3, 299, 299), dtype=torch.uint8)
>>> imgs_dist2 = torch.randint(100, 255, (100, 3, 299, 299), dtype=torch.uint8)
>>> mifid.update(imgs_dist1, real=True)
>>> mifid.update(imgs_dist2, real=False)
>>> mifid.compute()
tensor(3003.3691)
- plot(val=None, ax=None)
Plot a single or multiple values from the metric.
- Parameters:
val (Union[Tensor, Sequence[Tensor], None]) – Either a single result from calling metric.forward or metric.compute or a list of these results. If no value is provided, metric.compute is called automatically and that result is plotted.
ax (Optional[Axes]) – A matplotlib axis object. If provided, the plot is added to that axis.
- Returns:
Figure and Axes object
- Raises:
ModuleNotFoundError – If matplotlib is not installed
>>> # Example plotting a single value
>>> import torch
>>> from torchmetrics.image.mifid import MemorizationInformedFrechetInceptionDistance
>>> imgs_dist1 = torch.randint(0, 200, (100, 3, 299, 299), dtype=torch.uint8)
>>> imgs_dist2 = torch.randint(100, 255, (100, 3, 299, 299), dtype=torch.uint8)
>>> metric = MemorizationInformedFrechetInceptionDistance(feature=64)
>>> metric.update(imgs_dist1, real=True)
>>> metric.update(imgs_dist2, real=False)
>>> fig_, ax_ = metric.plot()
>>> # Example plotting multiple values
>>> import torch
>>> from torchmetrics.image.mifid import MemorizationInformedFrechetInceptionDistance
>>> imgs_dist1 = lambda: torch.randint(0, 200, (100, 3, 299, 299), dtype=torch.uint8)
>>> imgs_dist2 = lambda: torch.randint(100, 255, (100, 3, 299, 299), dtype=torch.uint8)
>>> metric = MemorizationInformedFrechetInceptionDistance(feature=64)
>>> values = []
>>> for _ in range(3):
...     metric.update(imgs_dist1(), real=True)
...     metric.update(imgs_dist2(), real=False)
...     values.append(metric.compute())
...     metric.reset()
>>> fig_, ax_ = metric.plot(values)
Multi-Scale SSIM¶
Module Interface¶
- class torchmetrics.image.MultiScaleStructuralSimilarityIndexMeasure(gaussian_kernel=True, kernel_size=11, sigma=1.5, reduction='elementwise_mean', data_range=None, k1=0.01, k2=0.03, betas=(0.0448, 0.2856, 0.3001, 0.2363, 0.1333), normalize='relu', **kwargs)
Compute MultiScaleSSIM, Multi-scale Structural Similarity Index Measure.
This metric is a generalization of the Structural Similarity Index Measure that incorporates image details at different resolution scales.
As input to forward and update the metric accepts the following input:
preds (Tensor): Predictions from model of shape (N,C,H,W)
target (Tensor): Ground truth values of shape (N,C,H,W)
As output of forward and compute the metric returns the following output:
msssim (Tensor): if reduction!='none', returns a float scalar tensor with the average MS-SSIM value over samples; otherwise returns a tensor of shape (N,) with MS-SSIM values per sample
- Parameters:
gaussian_kernel (bool) – If True (default), a gaussian kernel is used; if False, a uniform kernel is used.
kernel_size (Union[int, Sequence[int]]) – Size of the gaussian kernel.
sigma (Union[float, Sequence[float]]) – Standard deviation of the gaussian kernel.
reduction (Literal['elementwise_mean', 'sum', 'none', None]) – A method to reduce metric score over labels:
'elementwise_mean': takes the mean
'sum': takes the sum
'none' or None: no reduction will be applied
data_range (Union[float, Tuple[float, float], None]) – The range of the data. If None, it is determined from the data (max - min). If a tuple is provided, the range is calculated as the difference and the input is clamped between the values.
k1 (float) – Parameter of structural similarity index measure.
k2 (float) – Parameter of structural similarity index measure.
betas (Tuple[float, ...]) – Exponent parameters for the individual similarities and contrast sensitivities returned at the different image resolutions.
normalize (Literal['relu', 'simple', None]) – When MultiScaleStructuralSimilarityIndexMeasure is used as a training loss, normalization is desirable to improve training stability. This normalize argument is out of scope of the original implementation [1], and it is adapted from https://github.com/jorge-pessoa/pytorch-msssim instead.
kwargs (Any) – Additional keyword arguments, see Advanced metric settings for more info.
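The way betas and normalize enter the final score can be sketched in plain Python. This is a simplified illustration, not the library's code: combine_scales is a hypothetical helper. It assumes, as in the original multi-scale formulation, that the contrast-sensitivity values of the finer scales and the full SSIM value of the coarsest scale have already been computed. It only shows the weighted product over scales, with the two normalization options applied before exponentiation.

```python
import math
from typing import Optional, Sequence


def combine_scales(
    cs_values: Sequence[float],
    final_ssim: float,
    betas: Sequence[float] = (0.0448, 0.2856, 0.3001, 0.2363, 0.1333),
    normalize: Optional[str] = "relu",
) -> float:
    """Combine per-scale values into one MS-SSIM score (hypothetical sketch).

    cs_values: contrast-sensitivity values for scales 1..M-1 (finest first).
    final_ssim: full SSIM value at the coarsest scale M.
    """
    values = list(cs_values) + [final_ssim]
    if len(values) != len(betas):
        raise ValueError("expected exactly one value per beta")
    if normalize == "relu":
        # clip negative similarities so the fractional powers stay real
        values = [max(v, 0.0) for v in values]
    elif normalize == "simple":
        # map values from [-1, 1] to [0, 1]
        values = [(v + 1) / 2 for v in values]
    elif normalize is not None:
        raise ValueError("normalize must be 'relu', 'simple' or None")
    # with normalize=None a negative value raised to a fractional beta
    # yields a complex number in Python, which is what normalization avoids
    return math.prod(v**b for v, b in zip(values, betas))
```

With all per-scale values equal to x, the result is roughly x, because the default betas sum to about 1.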
- Returns:
Tensor with Multi-Scale SSIM score
- Raises:
ValueError – If kernel_size is not an int or a Sequence of ints with size 2 or 3.
ValueError – If betas is not a tuple of floats.
ValueError – If normalize is neither None, 'relu' nor 'simple'.
Example
>>> from torchmetrics.image import MultiScaleStructuralSimilarityIndexMeasure
>>> import torch
>>> gen = torch.manual_seed(42)
>>> preds = torch.rand([3, 3, 256, 256], generator=torch.manual_seed(42))
>>> target = preds * 0.75
>>> ms_ssim = MultiScaleStructuralSimilarityIndexMeasure(data_range=1.0)
>>> ms_ssim(preds, target)
tensor(0.9627)
- plot(val=None, ax=None)
Plot a single or multiple values from the metric.
- Parameters:
val (Union[Tensor, Sequence[Tensor], None]) – Either a single result from calling metric.forward or metric.compute, or a list of these results. If no value is provided, metric.compute is called automatically and its result is plotted.
ax (Optional[Axes]) – A matplotlib axis object. If provided, the plot is added to that axis.
- Returns:
Figure and Axes object
- Raises:
ModuleNotFoundError – If matplotlib is not installed
>>> # Example plotting a single value
>>> from torchmetrics.image import MultiScaleStructuralSimilarityIndexMeasure
>>> import torch
>>> preds = torch.rand([3, 3, 256, 256], generator=torch.manual_seed(42))
>>> target = preds * 0.75
>>> metric = MultiScaleStructuralSimilarityIndexMeasure(data_range=1.0)
>>> metric.update(preds, target)
>>> fig_, ax_ = metric.plot()
>>> # Example plotting multiple values
>>> from torchmetrics.image import MultiScaleStructuralSimilarityIndexMeasure
>>> import torch
>>> preds = torch.rand([3, 3, 256, 256], generator=torch.manual_seed(42))
>>> target = preds * 0.75
>>> metric = MultiScaleStructuralSimilarityIndexMeasure(data_range=1.0)
>>> values = []
>>> for _ in range(10):
...     values.append(metric(preds, target))
>>> fig_, ax_ = metric.plot(values)
Functional Interface¶
- torchmetrics.functional.image.multiscale_structural_similarity_index_measure(preds, target, gaussian_kernel=True, sigma=1.5, kernel_size=11, reduction='elementwise_mean', data_range=None, k1=0.01, k2=0.03, betas=(0.0448, 0.2856, 0.3001, 0.2363, 0.1333), normalize='relu')
Compute MultiScaleSSIM, Multi-scale Structural Similarity Index Measure.
This metric is a generalization of the Structural Similarity Index Measure that incorporates image details at different resolution scales.
- Parameters:
preds (Tensor) – Predictions from model of shape [N, C, H, W]
target (Tensor) – Ground truth values of shape [N, C, H, W]
gaussian_kernel (bool) – If True, a gaussian kernel is used; if False, a uniform kernel is used.
sigma (Union[float, Sequence[float]]) – Standard deviation of the gaussian kernel.
kernel_size (Union[int, Sequence[int]]) – Size of the gaussian kernel.
reduction (Literal['elementwise_mean', 'sum', 'none', None]) – A method to reduce metric score over labels:
'elementwise_mean': takes the mean
'sum': takes the sum
'none' or None: no reduction will be applied
data_range (Union[float, Tuple[float, float], None]) – The range of the data. If None, it is determined from the data (max - min). If a tuple is provided, the range is calculated as the difference and the input is clamped between the values.
k1 (float) – Parameter of structural similarity index measure.
k2 (float) – Parameter of structural similarity index measure.
betas (Tuple[float, ...]) – Exponent parameters for the individual similarities and contrast sensitivities returned at the different image resolutions.
normalize (Optional[Literal['relu', 'simple']]) – When MultiScaleSSIM is used as a training loss, normalization is desirable to improve training stability. This normalize argument is out of scope of the original implementation [1], and it is adapted from https://github.com/jorge-pessoa/pytorch-msssim instead.
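The gaussian_kernel, kernel_size and sigma arguments describe the smoothing window used for the local statistics. Below is a minimal 1D sketch of the two window types. It assumes the usual construction: sample exp(-x²/(2σ²)) at integer offsets around the centre tap and normalize to unit sum. The helper names are hypothetical, and the 2D window is taken as the outer product of two 1D windows. The validation mirrors the odd-size and positive-sigma requirements listed under Raises.

```python
import math
from typing import List


def gaussian_kernel_1d(kernel_size: int = 11, sigma: float = 1.5) -> List[float]:
    """Normalized 1D gaussian window centred on the middle tap (sketch)."""
    if kernel_size <= 0 or kernel_size % 2 == 0:
        raise ValueError("kernel_size must be an odd positive number")
    if sigma <= 0:
        raise ValueError("sigma must be a positive number")
    half = (kernel_size - 1) / 2
    weights = [math.exp(-((i - half) ** 2) / (2 * sigma**2)) for i in range(kernel_size)]
    total = sum(weights)
    return [w / total for w in weights]


def uniform_kernel_1d(kernel_size: int = 11) -> List[float]:
    """Box window, the alternative used when gaussian_kernel=False (sketch)."""
    return [1.0 / kernel_size] * kernel_size


def outer(a: List[float], b: List[float]) -> List[List[float]]:
    """Separable 2D window as the outer product of two 1D windows."""
    return [[x * y for y in b] for x in a]
```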
- Returns:
Tensor with Multi-Scale SSIM score
- Raises:
TypeError – If preds and target don’t have the same data type.
ValueError – If preds and target don’t have BxCxHxW shape.
ValueError – If the length of kernel_size or sigma is not 2.
ValueError – If one of the elements of kernel_size is not an odd positive number.
ValueError – If one of the elements of sigma is not a positive number.
Example
>>> import torch
>>> from torchmetrics.functional.image import multiscale_structural_similarity_index_measure
>>> gen = torch.manual_seed(42)
>>> preds = torch.rand([3, 3, 256, 256], generator=gen)
>>> target = preds * 0.75
>>> multiscale_structural_similarity_index_measure(preds, target, data_range=1.0)
tensor(0.9627)
References
[1] Zhou Wang, Eero P. Simoncelli and Alan C. Bovik, “Multi-Scale Structural Similarity for Image Quality Assessment”
Peak Signal-to-Noise Ratio (PSNR)¶
Module Interface¶
- class torchmetrics.image.PeakSignalNoiseRatio(data_range=None, base=10.0, reduction='elementwise_mean', dim=None, **kwargs)
Compute Peak Signal-to-Noise Ratio (PSNR).
\[\text{PSNR}(I, J) = 10 * \log_{10} \left(\frac{\max(I)^2}{\text{MSE}(I, J)}\right)\]
where \(\text{MSE}\) denotes the mean-squared-error function.
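The value tensor(2.5527) in the module example can be traced back to this formula with a hand calculation in plain Python. The sketch assumes that, when data_range is None, the range is taken from the target values (max minus min) and plays the role of max(I). The psnr helper is hypothetical and has no batching, dim or reduction handling.

```python
import math
from typing import Optional, Sequence


def psnr(
    preds: Sequence[float],
    target: Sequence[float],
    data_range: Optional[float] = None,
    base: float = 10.0,
) -> float:
    """PSNR of two flattened signals (sketch: no batching, dim or reduction)."""
    if len(preds) != len(target) or not preds:
        raise ValueError("preds and target must be non-empty and of equal length")
    mse = sum((p - t) ** 2 for p, t in zip(preds, target)) / len(preds)
    if mse == 0:
        return math.inf  # identical signals
    if data_range is None:
        # assumption: infer the range from the target values (max - min)
        data_range = max(target) - min(target)
    # 10 * log_base(data_range ** 2 / MSE)
    return 10 * math.log(data_range**2 / mse, base)


# Same values as the documented example: MSE = 20 / 4 = 5 and data range = 3 - 0 = 3
print(round(psnr([0.0, 1.0, 2.0, 3.0], [3.0, 2.0, 1.0, 0.0]), 4))  # 2.5527
```

The base argument only changes the logarithm, so base=math.e gives 10 · ln(9/5) instead of 10 · log10(9/5).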
As input to forward and update the metric accepts the following input:
preds (Tensor): Predictions from model of shape (N,C,H,W)
target (Tensor): Ground truth values of shape (N,C,H,W)
As output of forward and compute the metric returns the following output:
psnr (Tensor): if reduction!='none', returns a float scalar tensor with the average PSNR value over samples; otherwise returns a tensor of shape (N,) with PSNR values per sample
- Parameters:
data_range (Union[float, Tuple[float, float], None]) – The range of the data. If None, it is determined from the data (max - min). If a tuple is provided, the range is calculated as the difference and the input is clamped between the values. The data_range must be given when dim is not None.
base (float) – A base of a logarithm to use.
reduction (Literal['elementwise_mean', 'sum', 'none', None]) – A method to reduce metric score over labels:
'elementwise_mean': takes the mean (default)
'sum': takes the sum
'none' or None: no reduction will be applied
dim (Union[int, Tuple[int, ...], None]) – Dimensions to reduce PSNR scores over, provided as either an integer or a list of integers. Default is None, meaning scores will be reduced across all dimensions and all batches.
kwargs (Any) – Additional keyword arguments, see Advanced metric settings for more info.
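The interplay of dim and reduction can be illustrated with a small sketch. Each image of a batch is scored separately, as reducing over the channel and spatial dimensions would do, and the per-image scores are then reduced. The example is plain Python and assumes each image is already flattened to a list. It also passes an explicit data_range, which is required whenever dim is set. batch_psnr and reduce_scores are hypothetical names.

```python
import math
from typing import List, Optional, Sequence, Union


def reduce_scores(
    scores: List[float], reduction: Optional[str] = "elementwise_mean"
) -> Union[float, List[float]]:
    """Apply the documented reduction modes to per-sample scores."""
    if reduction == "elementwise_mean":
        return sum(scores) / len(scores)
    if reduction == "sum":
        return sum(scores)
    if reduction in ("none", None):
        return scores
    raise ValueError(f"unknown reduction: {reduction!r}")


def batch_psnr(
    preds: Sequence[Sequence[float]],
    target: Sequence[Sequence[float]],
    data_range: float,
    reduction: Optional[str] = "elementwise_mean",
) -> Union[float, List[float]]:
    """One PSNR value per flattened image, then reduced across the batch."""
    scores = []
    for p, t in zip(preds, target):
        mse = sum((a - b) ** 2 for a, b in zip(p, t)) / len(p)
        scores.append(math.inf if mse == 0 else 10 * math.log10(data_range**2 / mse))
    return reduce_scores(scores, reduction)
```

For preds [[0, 0], [0, 0]] and target [[1, 1], [0.1, 0.1]] with data_range=1.0, the per-image scores are 0 dB and 20 dB, so the mean is 10 dB. A single PSNR of the pooled MSE (0.505) would instead be about 2.97 dB. Averaging per-image PSNR values is therefore not the same as scoring the whole batch at once.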
- Raises:
ValueError – If dim is not None and data_range is not given.
Example
>>> import torch
>>> from torchmetrics.image import PeakSignalNoiseRatio
>>> psnr = PeakSignalNoiseRatio()
>>> preds = torch.tensor([[0.0, 1.0], [2.0, 3.0]])
>>> target = torch.tensor([[3.0, 2.0], [1.0, 0.0]])
>>> psnr(preds, target)
tensor(2.5527)
- plot(val=None, ax=None)
Plot a single or multiple values from the metric.
- Parameters:
val (Union[Tensor, Sequence[Tensor], None]) – Either a single result from calling metric.forward or metric.compute, or a list of these results. If no value is provided, metric.compute is called automatically and its result is plotted.
ax (Optional[Axes]) – A matplotlib axis object. If provided, the plot is added to that axis.
- Returns:
Figure and Axes object
- Raises:
ModuleNotFoundError – If matplotlib is not installed
>>> # Example plotting a single value
>>> import torch
>>> from torchmetrics.image import PeakSignalNoiseRatio
>>> metric = PeakSignalNoiseRatio()
>>> preds = torch.tensor([[0.0, 1.0], [2.0, 3.0]])
>>> target = torch.tensor([[3.0, 2.0], [1.0, 0.0]])
>>> metric.update(preds, target)
>>> fig_, ax_ = metric.plot()
>>> # Example plotting multiple values
>>> import torch
>>> from torchmetrics.image import PeakSignalNoiseRatio
>>> metric = PeakSignalNoiseRatio()
>>> preds = torch.tensor([[0.0, 1.0], [2.0, 3.0]])
>>> target = torch.tensor([[3.0, 2.0], [1.0, 0.0]])
>>> values = []
>>> for _ in range(10):
...     values.append(metric(preds, target))
>>> fig_, ax_ = metric.plot(values)
Functional Interface¶
- torchmetrics.functional.image.peak_signal_noise_ratio(preds, target, data_range=None, base=10.0, reduction='elementwise_mean', dim=None)
Compute the peak signal-to-noise ratio.
- Parameters:
preds (Tensor) – estimated signal
target (Tensor) – ground truth signal
data_range (Union[float, Tuple[float, float], None]) – the range of the data. If None, it is determined from the data (max - min). If a tuple is provided, the range is calculated as the difference and the input is clamped between the values. The data_range must be given when dim is not None.
base (float) – a base of a logarithm to use
reduction (Literal['elementwise_mean', 'sum', 'none', None]) – a method to reduce metric score over labels:
'elementwise_mean': takes the mean (default)
'sum': takes the sum
'none' or None: no reduction will be applied
dim (Union[int, Tuple[int, ...], None]) – Dimensions to reduce PSNR scores over, provided as either an integer or a list of integers. Default is None, meaning scores will be reduced across all dimensions.
- Returns:
Tensor with PSNR score
- Raises:
ValueError – If dim is not None and data_range is not provided.
Example
>>> import torch
>>> from torchmetrics.functional.image import peak_signal_noise_ratio
>>> pred = torch.tensor([[0.0, 1.0], [2.0, 3.0]])
>>> target = torch.tensor([[3.0, 2.0], [1.0, 0.0]])
>>> peak_signal_noise_ratio(pred, target)
tensor(2.5527)
Note
Half precision is only supported on GPU for this metric.
Peak Signal To Noise Ratio With Blocked Effect¶
Module Interface¶
- class torchmetrics.image.PeakSignalNoiseRatioWithBlockedEffect(block_size=8, **kwargs)
Computes Peak Signal to Noise Ratio With Blocked Effect (PSNRB).
\[\text{PSNRB}(I, J) = 10 * \log_{10} \left(\frac{\max(I)^2}{\text{MSE}(I, J)+\text{B}(I, J)}\right)\]
where \(\text{MSE}\) denotes the mean-squared-error function and \(\text{B}\) denotes the blocking effect factor. This metric is a modified version of PSNR that better supports evaluation of images with blocking artifacts, which often occur in compressed images.
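The blocking term can be made concrete with a sketch of one common formulation of the blocking effect factor (BEF). Squared differences between neighbouring pixels that straddle a block boundary are compared with those that do not. Any excess is weighted by log2(block_size) / log2(min(H, W)) and added to the MSE, as in the original PSNR-B formulation. This is a plain-Python illustration on a single-channel image stored as a list of rows. The helper names are hypothetical, and the pair counting and explicit data_range argument are simplifications that may not match the library's internal normalization.

```python
import math
from typing import List

Image = List[List[float]]  # single-channel image: H rows of W pixels


def blocking_effect_factor(img: Image, block_size: int = 8) -> float:
    """Excess squared difference across block boundaries (sketch)."""
    height, width = len(img), len(img[0])
    boundary: List[float] = []
    non_boundary: List[float] = []
    # horizontal neighbours (x, x + 1); a boundary follows every block_size-th column
    for row in img:
        for x in range(width - 1):
            diff = (row[x] - row[x + 1]) ** 2
            (boundary if (x + 1) % block_size == 0 else non_boundary).append(diff)
    # vertical neighbours (y, y + 1)
    for y in range(height - 1):
        for x in range(width):
            diff = (img[y][x] - img[y + 1][x]) ** 2
            (boundary if (y + 1) % block_size == 0 else non_boundary).append(diff)
    d_b = sum(boundary) / len(boundary) if boundary else 0.0
    d_bc = sum(non_boundary) / len(non_boundary) if non_boundary else 0.0
    if d_b <= d_bc:
        return 0.0  # no excess discontinuity at block edges
    eta = math.log2(block_size) / math.log2(min(height, width))
    return eta * (d_b - d_bc)


def psnrb(preds: Image, target: Image, data_range: float, block_size: int = 8) -> float:
    """10 * log10(data_range^2 / (MSE + BEF(preds))) for one image (sketch)."""
    n = len(preds) * len(preds[0])
    mse = sum((p - t) ** 2 for pr, tr in zip(preds, target) for p, t in zip(pr, tr)) / n
    denom = mse + blocking_effect_factor(preds, block_size)
    return math.inf if denom == 0 else 10 * math.log10(data_range**2 / denom)
```

A 16×16 image whose left half is 0 and right half is 1 has a single sharp edge exactly on a block boundary. The sketch gives it a BEF of 0.375, so PSNR-B stays finite even when preds equals target. A smooth image has BEF 0, and PSNR-B reduces to plain PSNR.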
Note
Metric only supports grayscale images. If you have RGB images, please convert them to grayscale first.
As input to forward and update the metric accepts the following input:
preds (Tensor): Predictions from model of shape (N,1,H,W)
target (Tensor): Ground truth values of shape (N,1,H,W)
As output of forward and compute the metric returns the following output:
psnrb (Tensor): float scalar tensor with aggregated PSNRB value
- Parameters:
block_size (int) – Integer indicating the block size
kwargs (Any) – Additional keyword arguments, see Advanced metric settings for more info.
Example
>>> import torch
>>> from torchmetrics.image import PeakSignalNoiseRatioWithBlockedEffect
>>> metric = PeakSignalNoiseRatioWithBlockedEffect()
>>> _ = torch.manual_seed(42)
>>> preds = torch.rand(2, 1, 10, 10)
>>> target = torch.rand(2, 1, 10, 10)
>>> metric(preds, target)
tensor(7.2893)
- plot(val=None, ax=None)
Plot a single or multiple values from the metric.
- Parameters:
val (Union[Tensor, Sequence[Tensor], None]) – Either a single result from calling metric.forward or metric.compute, or a list of these results. If no value is provided, metric.compute is called automatically and its result is plotted.
ax (Optional[Axes]) – A matplotlib axis object. If provided, the plot is added to that axis.
- Returns:
Figure and Axes object
- Raises:
ModuleNotFoundError – If matplotlib is not installed
>>> # Example plotting a single value
>>> import torch
>>> from torchmetrics.image import PeakSignalNoiseRatioWithBlockedEffect
>>> metric = PeakSignalNoiseRatioWithBlockedEffect()
>>> metric.update(torch.rand(2, 1, 10, 10), torch.rand(2, 1, 10, 10))
>>> fig_, ax_ = metric.plot()
>>> # Example plotting multiple values
>>> import torch
>>> from torchmetrics.image import PeakSignalNoiseRatioWithBlockedEffect
>>> metric = PeakSignalNoiseRatioWithBlockedEffect()
>>> values = []
>>> for _ in range(10):
...     values.append(metric(torch.rand(2, 1, 10, 10), torch.rand(2, 1, 10, 10)))
>>> fig_, ax_ = metric.plot(values)
Functional Interface¶
- torchmetrics.functional.image.peak_signal_noise_ratio_with_blocked_effect(preds, target, block_size=8)
Computes Peak Signal to Noise Ratio With Blocked Effect (PSNRB).
\[\text{PSNRB}(I, J) = 10 * \log_{10} \left(\frac{\max(I)^2}{\text{MSE}(I, J)+\text{B}(I, J)}\right)\]
where \(\text{MSE}\) denotes the mean-squared-error function and \(\text{B}\) denotes the blocking effect factor.
- Parameters:
preds (Tensor) – Predictions from model of shape (N,1,H,W)
target (Tensor) – Ground truth values of shape (N,1,H,W)
block_size (int) – Integer indicating the block size
- Returns:
Tensor with PSNRB score
Example
>>> import torch
>>> from torchmetrics.functional.image import peak_signal_noise_ratio_with_blocked_effect
>>> _ = torch.manual_seed(42)
>>> preds = torch.rand(1, 1, 28, 28)
>>> target = torch.rand(1, 1, 28, 28)
>>> peak_signal_noise_ratio_with_blocked_effect(preds, target)
tensor(7.8402)
Perceptual Path Length (PPL)¶
Module Interface¶
- class torchmetrics.image.perceptual_path_length.PerceptualPathLength(num_samples=10000, conditional=False, batch_size=128, interpolation_method='lerp', epsilon=0.0001, resize=64, lower_discard=0.01, upper_discard=0.99, sim_net='vgg', **kwargs)
Computes the perceptual path length (PPL) of a generator model.
The perceptual path length can be used to measure the consistency of interpolation in latent-space models. It is defined as
\[PPL = \mathbb{E}\left[\frac{1}{\epsilon^2} D(G(I(z_1, z_2, t)), G(I(z_1, z_2, t+\epsilon)))\right]\]
where \(G\) is the generator, \(I\) is the interpolation function, \(D\) is a similarity metric, \(z_1\) and \(z_2\) are two sets of latent points, and \(t\) is a parameter between 0 and 1. The metric thus works by interpolating between two sets of latent points and measuring the similarity between the generated images. The expectation is approximated by sampling \(z_1\) and \(z_2\) from the generator and averaging the calculated distances. The similarity metric \(D\) is by default the LPIPS metric, but can be changed by setting the sim_net argument.
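The interpolation function I and the finite-difference term inside the expectation can be sketched in plain Python. The lerp and slerp helpers follow the standard definitions of linear and spherical interpolation; a unit-sphere variant would normalize both latents first. ppl_sample is a hypothetical helper that evaluates a single term of the expectation for a given generator G and distance D. The toy D used here stands in for the LPIPS network and is not what the metric uses.

```python
import math
from typing import Callable, List

Vector = List[float]


def lerp(z1: Vector, z2: Vector, t: float) -> Vector:
    """Linear interpolation between two latent vectors."""
    return [a + (b - a) * t for a, b in zip(z1, z2)]


def slerp(z1: Vector, z2: Vector, t: float) -> Vector:
    """Spherical interpolation along the arc between z1 and z2."""
    norm1 = math.sqrt(sum(a * a for a in z1))
    norm2 = math.sqrt(sum(b * b for b in z2))
    cos_omega = sum(a * b for a, b in zip(z1, z2)) / (norm1 * norm2)
    omega = math.acos(max(-1.0, min(1.0, cos_omega)))
    if omega < 1e-8:  # (nearly) parallel vectors: fall back to lerp
        return lerp(z1, z2, t)
    s1 = math.sin((1 - t) * omega) / math.sin(omega)
    s2 = math.sin(t * omega) / math.sin(omega)
    return [s1 * a + s2 * b for a, b in zip(z1, z2)]


def ppl_sample(
    generator: Callable[[Vector], Vector],
    distance: Callable[[Vector, Vector], float],
    z1: Vector,
    z2: Vector,
    t: float,
    epsilon: float = 1e-4,
    interpolate: Callable[[Vector, Vector, float], Vector] = lerp,
) -> float:
    """One term D(G(I(z1, z2, t)), G(I(z1, z2, t + eps))) / eps^2."""
    img_a = generator(interpolate(z1, z2, t))
    img_b = generator(interpolate(z1, z2, t + epsilon))
    return distance(img_a, img_b) / epsilon**2
```

With an identity "generator" and squared Euclidean distance, every lerp term equals the squared length of z2 - z1, for example 25 for z1 = [0, 0] and z2 = [3, 4]. The division by ε² normalizes away the step size in that way.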
The provided generator model must have a sample method with signature sample(num_samples: int) -> Tensor where the returned tensor has shape (num_samples, z_size). If the generator is conditional, it must also have a num_classes attribute. The forward method of the generator must have signature forward(z: Tensor) -> Tensor if conditional=False, and forward(z: Tensor, labels: Tensor) -> Tensor if conditional=True. The returned tensor should have shape (num_samples, C, H, W) and be scaled to the range [0, 255].
Note
Using this metric with the default feature extractor requires that torchvision is installed. Either install as pip install torchmetrics[image] or pip install torchvision.
As input to forward and update the metric accepts the following input:
generator (Module): Generator model, with specific requirements. See above.
As output of forward and compute the metric returns the following output:
ppl_mean (Tensor): float scalar tensor with mean PPL value over distances
ppl_std (Tensor): float scalar tensor with std PPL value over distances
ppl_raw (Tensor): tensor with the raw PPL distances
- Parameters:
num_samples (int) – Number of samples to use for the PPL computation.
conditional (bool) – Whether the generator is conditional or not (i.e. whether it takes labels as input).
batch_size (int) – Batch size to use for the PPL computation.
interpolation_method (Literal['lerp', 'slerp_any', 'slerp_unit']) – Interpolation method to use. Choose from ‘lerp’, ‘slerp_any’, ‘slerp_unit’.
epsilon (float) – Spacing between the points on the path between latent points.
resize (Optional[int]) – Resize images to this size before computing the similarity between generated images.
lower_discard (Optional[float]) – Lower quantile to discard from the distances, before computing the mean and standard deviation.
upper_discard (Optional[float]) – Upper quantile to discard from the distances, before computing the mean and standard deviation.
sim_net (Union[Module, Literal['alex', 'vgg', 'squeeze']]) – Similarity network to use. Can be a nn.Module or one of ‘alex’, ‘vgg’, ‘squeeze’, where the three latter options correspond to the pretrained networks from the LPIPS paper.
kwargs (Any) – Additional keyword arguments, see Advanced metric settings for more info.
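The lower_discard and upper_discard options trim outlying distances before the mean and standard deviation are taken. The plain-Python sketch below makes two assumptions. Quantiles use "lower" interpolation, meaning the sample at sorted position floor(q · (n - 1)), and values equal to a threshold are kept. trim_distances is a hypothetical helper, and the exact quantile convention inside TorchMetrics may differ.

```python
import math
import statistics
from typing import List, Optional, Tuple


def lower_quantile(values: List[float], q: float) -> float:
    """Quantile with 'lower' interpolation: sorted(values)[floor(q * (n - 1))]."""
    ordered = sorted(values)
    return ordered[math.floor(q * (len(ordered) - 1))]


def trim_distances(
    distances: List[float],
    lower_discard: Optional[float] = 0.01,
    upper_discard: Optional[float] = 0.99,
) -> Tuple[float, float, List[float]]:
    """Drop distances outside the quantile band, then summarize the rest."""
    low = lower_quantile(distances, lower_discard) if lower_discard is not None else min(distances)
    high = lower_quantile(distances, upper_discard) if upper_discard is not None else max(distances)
    kept = [d for d in distances if low <= d <= high]
    # sample (n - 1) standard deviation; needs at least two kept values
    return statistics.mean(kept), statistics.stdev(kept), kept
```

With ten distances and the default bounds, this sketch keeps nine of them and drops the largest.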
- Raises:
ModuleNotFoundError – If torchvision is not installed.
ValueError – If num_samples is not a positive integer.
ValueError – If conditional is not a boolean.
ValueError – If batch_size is not a positive integer.
ValueError – If interpolation_method is not one of ‘lerp’, ‘slerp_any’, ‘slerp_unit’.
ValueError – If epsilon is not a positive float.
ValueError – If resize is not a positive integer.
ValueError – If lower_discard is not a float between 0 and 1 or None.
ValueError – If upper_discard is not a float between 0 and 1 or None.
Example
>>> from torchmetrics.image import PerceptualPathLength
>>> import torch
>>> _ = torch.manual_seed(42)
>>> class DummyGenerator(torch.nn.Module):
...     def __init__(self, z_size) -> None:
...         super().__init__()
...         self.z_size = z_size
...         self.model = torch.nn.Sequential(torch.nn.Linear(z_size, 3*128*128), torch.nn.Sigmoid())
...     def forward(self, z):
...         return 255 * (self.model(z).reshape(-1, 3, 128, 128) + 1)
...     def sample(self, num_samples):
...         return torch.randn(num_samples, self.z_size)
>>> generator = DummyGenerator(2)
>>> ppl = PerceptualPathLength(num_samples=10)
>>> ppl(generator)
(tensor(0.2371), tensor(0.1763), tensor([0.3502, 0.1362, 0.2535, 0.0902, 0.1784, 0.0769, 0.5871, 0.0691, 0.3921]))
- class torchmetrics.image.perceptual_path_length.GeneratorType(*args, **kwargs)
Basic interface for a generator model.
Users can inherit from this class and implement their own generator model. The requirements are that the sample method is implemented and that the num_classes attribute is present when the metric is used with conditional=True.
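The interface requirements can also be written as a structural type. The sketch below uses typing.Protocol only to show which members the metric relies on for an unconditional generator. A conditional one would also need num_classes and a forward that accepts labels. GeneratorLike and ToyGenerator are hypothetical names. Plain lists stand in for tensors so the snippet runs without PyTorch; a real generator would subclass torch.nn.Module (or GeneratorType) and return tensors.

```python
import math
import random
from typing import List, Protocol, runtime_checkable

Matrix = List[List[float]]


@runtime_checkable
class GeneratorLike(Protocol):
    """Members an unconditional generator is expected to provide (sketch)."""

    def sample(self, num_samples: int) -> Matrix: ...

    def __call__(self, z: Matrix) -> Matrix: ...


class ToyGenerator:
    """Hypothetical stand-in; a real model would be a torch.nn.Module."""

    def __init__(self, z_size: int) -> None:
        self.z_size = z_size

    def sample(self, num_samples: int) -> Matrix:
        # latent points of shape (num_samples, z_size)
        return [[random.gauss(0.0, 1.0) for _ in range(self.z_size)] for _ in range(num_samples)]

    def __call__(self, z: Matrix) -> Matrix:
        # squash each latent value into (0, 255), mimicking images scaled to [0, 255]
        return [[255.0 / (1.0 + math.exp(-v)) for v in row] for row in z]
```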
Functional Interface¶
- torchmetrics.functional.image.perceptual_path_length.perceptual_path_length(generator, num_samples=10000, conditional=False, batch_size=64, interpolation_method='lerp', epsilon=0.0001, resize=64, lower_discard=0.01, upper_discard=0.99, sim_net='vgg', device='cpu')
Computes the perceptual path length (PPL) of a generator model.
The perceptual path length can be used to measure the consistency of interpolation in latent-space models. It is defined as
\[PPL = \mathbb{E}\left[\frac{1}{\epsilon^2} D(G(I(z_1, z_2, t)), G(I(z_1, z_2, t+\epsilon)))\right]\]
where \(G\) is the generator, \(I\) is the interpolation function, \(D\) is a similarity metric, \(z_1\) and \(z_2\) are two sets of latent points, and \(t\) is a parameter between 0 and 1. The metric thus works by interpolating between two sets of latent points and measuring the similarity between the generated images. The expectation is approximated by sampling \(z_1\) and \(z_2\) from the generator and averaging the calculated distances. The similarity metric \(D\) is by default the LPIPS metric, but can be changed by setting the sim_net argument.
The provided generator model must have a sample method with signature sample(num_samples: int) -> Tensor where the returned tensor has shape (num_samples, z_size). If the generator is conditional, it must also have a num_classes attribute. The forward method of the generator must have signature forward(z: Tensor) -> Tensor if conditional=False, and forward(z: Tensor, labels: Tensor) -> Tensor if conditional=True. The returned tensor should have shape (num_samples, C, H, W) and be scaled to the range [0, 255].
- Parameters:
generator (GeneratorType) – Generator model, with specific requirements. See above.
num_samples (int) – Number of samples to use for the PPL computation.
conditional (bool) – Whether the generator is conditional or not (i.e. whether it takes labels as input).
batch_size (int) – Batch size to use for the PPL computation.
interpolation_method (Literal['lerp', 'slerp_any', 'slerp_unit']) – Interpolation method to use. Choose from ‘lerp’, ‘slerp_any’, ‘slerp_unit’.
epsilon (float) – Spacing between the points on the path between latent points.
resize (Optional[int]) – Resize images to this size before computing the similarity between generated images.
lower_discard (Optional[float]) – Lower quantile to discard from the distances, before computing the mean and standard deviation.
upper_discard (Optional[float]) – Upper quantile to discard from the distances, before computing the mean and standard deviation.
sim_net (Union[Module, Literal['alex', 'vgg', 'squeeze']]) – Similarity network to use. Can be a nn.Module or one of ‘alex’, ‘vgg’, ‘squeeze’, where the three latter options correspond to the pretrained networks from the LPIPS paper.
device (Union[str, device]) – Device to use for the computation.
- Returns:
A tuple containing the mean, standard deviation and all distances.
Example
>>> from torchmetrics.functional.image import perceptual_path_length
>>> import torch
>>> _ = torch.manual_seed(42)
>>> class DummyGenerator(torch.nn.Module):
...     def __init__(self, z_size) -> None:
...         super().__init__()
...         self.z_size = z_size
...         self.model = torch.nn.Sequential(torch.nn.Linear(z_size, 3*128*128), torch.nn.Sigmoid())
...     def forward(self, z):
...         return 255 * (self.model(z).reshape(-1, 3, 128, 128) + 1)
...     def sample(self, num_samples):
...         return torch.randn(num_samples, self.z_size)
>>> generator = DummyGenerator(2)
>>> perceptual_path_length(generator, num_samples=10)
(tensor(0.1945), tensor(0.1222), tensor([0.0990, 0.4173, 0.1628, 0.3573, 0.1875, 0.0335, 0.1095, 0.1887, 0.1953]))
Relative Average Spectral Error (RASE)¶
Module Interface¶
- class torchmetrics.image.RelativeAverageSpectralError(window_size=8, **kwargs)
Computes Relative Average Spectral Error (RASE).
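As a rough guide to what RASE measures, the global (single-window) form of the index is RASE = 100 / μ · sqrt((1/N) Σ_k RMSE_k²). Here RMSE_k is the error of band k, N is the number of bands and μ is the mean value of the reference image. The plain-Python sketch below evaluates only that global form. TorchMetrics evaluates the statistic over sliding windows of size window_size, so its numbers (such as the value in the example below) will not match this sketch. rase_global is a hypothetical name.

```python
import math
from typing import List, Sequence


def rase_global(preds: Sequence[Sequence[float]], target: Sequence[Sequence[float]]) -> float:
    """Global RASE for images given as a list of bands, each a flat list of pixels."""
    n_bands = len(target)
    squared_rmses: List[float] = []
    for p_band, t_band in zip(preds, target):
        # RMSE_k ** 2 is simply the mean squared error of band k
        squared_rmses.append(sum((p - t) ** 2 for p, t in zip(p_band, t_band)) / len(t_band))
    # mean value over all pixels of all reference bands
    mean_reference = sum(sum(band) for band in target) / sum(len(band) for band in target)
    return 100.0 / mean_reference * math.sqrt(sum(squared_rmses) / n_bands)
```

Because of the division by μ, RASE is expressed as a percentage of the reference image's mean value.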
As input to forward and update the metric accepts the following input:
preds (Tensor): Predictions from model of shape (N,C,H,W)
target (Tensor): Ground truth values of shape (N,C,H,W)
As output of forward and compute the metric returns the following output:
rase (Tensor): returns float scalar tensor with average RASE value over sample
- Parameters:
window_size (int) – Sliding window used for RMSE calculation
kwargs (Dict[str, Any]) – Additional keyword arguments, see Advanced metric settings for more info.
- Returns:
Relative Average Spectral Error (RASE)
Example
>>> import torch
>>> from torchmetrics.image import RelativeAverageSpectralError
>>> g = torch.manual_seed(22)
>>> preds = torch.rand(4, 3, 16, 16)
>>> target = torch.rand(4, 3, 16, 16)
>>> rase = RelativeAverageSpectralError()
>>> rase(preds, target)
tensor(5114.6641)
- Raises:
ValueError – If window_size is not a positive integer.
- plot(val=None, ax=None)
Plot a single or multiple values from the metric.
- Parameters:
val (Union[Tensor, Sequence[Tensor], None]) – Either a single result from calling metric.forward or metric.compute, or a list of these results. If no value is provided, metric.compute is called automatically and its result is plotted.
ax (Optional[Axes]) – A matplotlib axis object. If provided, the plot is added to that axis.
- Returns:
Figure and Axes object
- Raises:
ModuleNotFoundError – If matplotlib is not installed
>>> # Example plotting a single value
>>> import torch
>>> from torchmetrics.image import RelativeAverageSpectralError
>>> metric = RelativeAverageSpectralError()
>>> metric.update(torch.rand(4, 3, 16, 16), torch.rand(4, 3, 16, 16))
>>> fig_, ax_ = metric.plot()
>>> # Example plotting multiple values
>>> import torch
>>> _ = torch.manual_seed(42)
>>> from torchmetrics.image import RelativeAverageSpectralError
>>> metric = RelativeAverageSpectralError()
>>> values = []
>>> for _ in range(10):
...     values.append(metric(torch.rand(4, 3, 16, 16), torch.rand(4, 3, 16, 16)))
>>> fig_, ax_ = metric.plot(values)
Functional Interface¶
- torchmetrics.functional.image.relative_average_spectral_error(preds, target, window_size=8)
Compute Relative Average Spectral Error (RASE).
- Parameters:
preds (Tensor) – Predictions from model of shape (N,C,H,W)
target (Tensor) – Ground truth values of shape (N,C,H,W)
window_size (int) – Sliding window used for RMSE calculation
- Returns:
Relative Average Spectral Error (RASE)
Example
>>> import torch
>>> from torchmetrics.functional.image import relative_average_spectral_error
>>> g = torch.manual_seed(22)
>>> preds = torch.rand(4, 3, 16, 16)
>>> target = torch.rand(4, 3, 16, 16)
>>> relative_average_spectral_error(preds, target)
tensor(5114.6641)
- Raises:
ValueError – If window_size is not a positive integer.
Root Mean Squared Error Using Sliding Window¶
Module Interface¶
- class torchmetrics.image.RootMeanSquaredErrorUsingSlidingWindow(window_size=8, **kwargs)
Computes Root Mean Squared Error (RMSE) using sliding window.
As input to forward and update the metric accepts the following input:
preds (Tensor): Predictions from model of shape (N,C,H,W)
target (Tensor): Ground truth values of shape (N,C,H,W)
As output of forward and compute the metric returns the following output:
rmse_sw (Tensor): returns float scalar tensor with average RMSE-SW value over sample
- Parameters:
window_size (int) – Sliding window used for RMSE calculation
kwargs (Dict[str, Any]) – Additional keyword arguments, see Advanced metric settings for more info.
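The sliding-window idea can be sketched in plain Python for one single-channel image. A window_size × window_size window visits every position where it fits completely, and the RMSE inside each window forms a map that is then averaged. Keeping only fully covered positions is an assumption about border handling, so values may differ slightly from TorchMetrics. rmse_sliding_window is a hypothetical helper that mirrors the return_rmse_map flag of the functional interface.

```python
import math
from typing import List, Tuple, Union

Image = List[List[float]]


def rmse_sliding_window(
    preds: Image, target: Image, window_size: int = 8, return_rmse_map: bool = False
) -> Union[float, Tuple[float, Image]]:
    """Mean of per-window RMSE values over all fully covered window positions."""
    if not isinstance(window_size, int) or window_size < 1:
        raise ValueError("window_size must be a positive integer")
    height, width = len(target), len(target[0])
    if window_size > min(height, width):
        raise ValueError("window_size must not exceed the image size")
    sq_err = [[(p - t) ** 2 for p, t in zip(pr, tr)] for pr, tr in zip(preds, target)]
    rmse_map: Image = []
    for y in range(height - window_size + 1):
        row = []
        for x in range(width - window_size + 1):
            window = [sq_err[y + dy][x + dx] for dy in range(window_size) for dx in range(window_size)]
            row.append(math.sqrt(sum(window) / len(window)))
        rmse_map.append(row)
    score = sum(map(sum, rmse_map)) / (len(rmse_map) * len(rmse_map[0]))
    return (score, rmse_map) if return_rmse_map else score
```

A constant error of 0.5 gives an RMSE of 0.5 in every window, and therefore a score of 0.5. On a 10×10 image with window_size=8, the map is 3×3.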
Example
>>> import torch
>>> from torchmetrics.image import RootMeanSquaredErrorUsingSlidingWindow
>>> g = torch.manual_seed(22)
>>> preds = torch.rand(4, 3, 16, 16)
>>> target = torch.rand(4, 3, 16, 16)
>>> rmse_sw = RootMeanSquaredErrorUsingSlidingWindow()
>>> rmse_sw(preds, target)
tensor(0.3999)
- Raises:
ValueError – If window_size is not a positive integer.
- plot(val=None, ax=None)
Plot a single or multiple values from the metric.
- Parameters:
val (Union[Tensor, Sequence[Tensor], None]) – Either a single result from calling metric.forward or metric.compute, or a list of these results. If no value is provided, metric.compute is called automatically and its result is plotted.
ax (Optional[Axes]) – A matplotlib axis object. If provided, the plot is added to that axis.
- Returns:
Figure and Axes object
- Raises:
ModuleNotFoundError – If matplotlib is not installed
>>> # Example plotting a single value
>>> import torch
>>> from torchmetrics.image import RootMeanSquaredErrorUsingSlidingWindow
>>> metric = RootMeanSquaredErrorUsingSlidingWindow()
>>> metric.update(torch.rand(4, 3, 16, 16), torch.rand(4, 3, 16, 16))
>>> fig_, ax_ = metric.plot()
>>> # Example plotting multiple values
>>> import torch
>>> from torchmetrics.image import RootMeanSquaredErrorUsingSlidingWindow
>>> metric = RootMeanSquaredErrorUsingSlidingWindow()
>>> values = []
>>> for _ in range(10):
...     values.append(metric(torch.rand(4, 3, 16, 16), torch.rand(4, 3, 16, 16)))
>>> fig_, ax_ = metric.plot(values)
Functional Interface¶
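- torchmetrics.functional.image.root_mean_squared_error_using_sliding_window(preds, target, window_size=8, return_rmse_map=False)
Compute Root Mean Squared Error (RMSE) using sliding window.
- Parameters:
preds (Tensor) – Predictions from model of shape (N,C,H,W)
target (Tensor) – Ground truth values of shape (N,C,H,W)
window_size (int) – Sliding window used for RMSE calculation
return_rmse_map (bool) – If True, the RMSE map is returned in addition to the RMSE score
- Returns:
RMSE using sliding window and, optionally, the RMSE map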
Example
>>> import torch
>>> from torchmetrics.functional.image import root_mean_squared_error_using_sliding_window
>>> g = torch.manual_seed(22)
>>> preds = torch.rand(4, 3, 16, 16)
>>> target = torch.rand(4, 3, 16, 16)
>>> root_mean_squared_error_using_sliding_window(preds, target)
tensor(0.3999)
- Raises:
ValueError – If window_size is not a positive integer.
Spectral Angle Mapper¶
Module Interface¶
- class torchmetrics.image.SpectralAngleMapper(reduction='elementwise_mean', **kwargs)
Spectral Angle Mapper determines the spectral similarity between image spectra and reference spectra.
It works by calculating the angle between the spectra, where small angles indicate high similarity and large angles indicate low similarity.
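The per-pixel spectra of preds and target are the vectors of their channel values. The spectral angle between them is arccos(⟨p, t⟩ / (‖p‖ ‖t‖)), in radians. The plain-Python sketch below assumes one list of channel values per pixel and a mean over pixels as the reduction. spectral_angle and mean_spectral_angle are hypothetical helpers.

```python
import math
from typing import Sequence


def spectral_angle(p: Sequence[float], t: Sequence[float]) -> float:
    """Angle in radians between two spectra (channel vectors of one pixel)."""
    dot = sum(a * b for a, b in zip(p, t))
    norm_p = math.sqrt(sum(a * a for a in p))
    norm_t = math.sqrt(sum(b * b for b in t))
    # clamp so rounding error cannot push acos outside its domain
    cosine = max(-1.0, min(1.0, dot / (norm_p * norm_t)))
    return math.acos(cosine)


def mean_spectral_angle(preds: Sequence[Sequence[float]], target: Sequence[Sequence[float]]) -> float:
    """Average spectral angle over pixels ('elementwise_mean'-style reduction)."""
    angles = [spectral_angle(p, t) for p, t in zip(preds, target)]
    return sum(angles) / len(angles)
```

The angle ignores overall brightness: [1, 2, 3] and [2, 4, 6] have an angle of (numerically) zero, while orthogonal spectra give π/2.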
As input to forward and update the metric accepts the following input:
preds (Tensor): Predictions from model of shape (N,C,H,W)
target (Tensor): Ground truth values of shape (N,C,H,W)
As output of forward and compute the metric returns the following output:
sam (Tensor): if reduction!='none', returns a float scalar tensor with the average SAM value over samples; otherwise returns a tensor of shape (N,) with SAM values per sample
- Parameters:
reduction (Literal['elementwise_mean', 'sum', 'none']) – A method to reduce metric score over labels:
'elementwise_mean': takes the mean (default)
'sum': takes the sum
'none' or None: no reduction will be applied
kwargs (Any) – Additional keyword arguments, see Advanced metric settings for more info.
- Returns:
Tensor with SpectralAngleMapper score
Example
>>> import torch
>>> from torchmetrics.image import SpectralAngleMapper
>>> gen = torch.manual_seed(42)
>>> preds = torch.rand([16, 3, 16, 16], generator=gen)
>>> target = torch.rand([16, 3, 16, 16], generator=gen)
>>> sam = SpectralAngleMapper()
>>> sam(preds, target)
tensor(0.5914)
- plot(val=None, ax=None)
Plot a single or multiple values from the metric.
- Parameters:
val (Union[Tensor, Sequence[Tensor], None]) – Either a single result from calling metric.forward or metric.compute, or a list of these results. If no value is provided, metric.compute is called automatically and its result is plotted.
ax (Optional[Axes]) – A matplotlib axis object. If provided, the plot is added to that axis.
- Returns:
Figure and Axes object
- Raises:
ModuleNotFoundError – If matplotlib is not installed
>>> # Example plotting a single value
>>> import torch
>>> from torchmetrics.image import SpectralAngleMapper
>>> gen = torch.manual_seed(42)
>>> preds = torch.rand([16, 3, 16, 16], generator=gen)
>>> target = torch.rand([16, 3, 16, 16], generator=gen)
>>> metric = SpectralAngleMapper()
>>> metric.update(preds, target)
>>> fig_, ax_ = metric.plot()
>>> # Example plotting multiple values
>>> import torch
>>> from torchmetrics.image import SpectralAngleMapper
>>> gen = torch.manual_seed(42)
>>> preds = torch.rand([16, 3, 16, 16], generator=gen)
>>> target = torch.rand([16, 3, 16, 16], generator=gen)
>>> metric = SpectralAngleMapper()
>>> values = []
>>> for _ in range(10):
...     values.append(metric(preds, target))
>>> fig_, ax_ = metric.plot(values)
Functional Interface¶
- torchmetrics.functional.image.spectral_angle_mapper(preds, target, reduction='elementwise_mean')[source]¶
Universal Spectral Angle Mapper.
- Parameters:
- Return type:
- Returns:
Tensor with Spectral Angle Mapper score
- Raises:
TypeError – If preds and target don’t have the same data type.
ValueError – If preds and target don’t have BxCxHxW shape.
Example
>>> import torch
>>> from torchmetrics.functional.image import spectral_angle_mapper
>>> gen = torch.manual_seed(42)
>>> preds = torch.rand([16, 3, 16, 16], generator=gen)
>>> target = torch.rand([16, 3, 16, 16], generator=gen)
>>> spectral_angle_mapper(preds, target)
tensor(0.5914)
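For a single pixel, the spectral angle is the angle between the predicted and target spectral vectors, that is, the C channel values at that pixel. The plain-Python sketch below computes that angle in radians and averages it over pixels. It assumes that the default 'elementwise_mean' reduction averages the per-pixel angles. It illustrates the formula and is not the library's tensor implementation.

```python
import math
from typing import List, Sequence


def spectral_angle(p: Sequence[float], t: Sequence[float]) -> float:
    """Angle in radians between two spectral vectors of equal length."""
    dot = sum(a * b for a, b in zip(p, t))
    norm_p = math.sqrt(sum(a * a for a in p))
    norm_t = math.sqrt(sum(b * b for b in t))
    # clamp guards against rounding pushing the cosine slightly outside [-1, 1]
    cos = max(-1.0, min(1.0, dot / (norm_p * norm_t)))
    return math.acos(cos)


def mean_spectral_angle(pred_pixels: List[List[float]], target_pixels: List[List[float]]) -> float:
    """Average the per-pixel angles, as an 'elementwise_mean'-style reduction."""
    angles = [spectral_angle(p, t) for p, t in zip(pred_pixels, target_pixels)]
    return sum(angles) / len(angles)


# Scaling a spectrum does not change its direction, so this angle is (numerically) zero.
print(spectral_angle([1.0, 2.0, 3.0], [2.0, 4.0, 6.0]))
```

Only the direction of each spectral vector matters, so SAM is insensitive to a uniform brightness scaling of a pixel's spectrum.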
References
[1] Roberta H. Yuhas, Alexander F. H. Goetz and Joe W. Boardman, “Discrimination among semi-arid landscape endmembers using the Spectral Angle Mapper (SAM) algorithm,” in Summaries of the Third Annual JPL Airborne Geoscience Workshop, vol. 1, June 1, 1992.
Spectral Distortion Index¶
Module Interface¶
- class torchmetrics.image.SpectralDistortionIndex(p=1, reduction='elementwise_mean', **kwargs)[source]¶
Compute Spectral Distortion Index (SpectralDistortionIndex), also known as D_lambda.
The metric is used to compare the spectral distortion between two images.
As input to forward and update the metric accepts the following input
preds (Tensor): Low resolution multispectral image of shape (N,C,H,W)
target (Tensor): High resolution fused image of shape (N,C,H,W)
As output of forward and compute the metric returns the following output
sdi (Tensor): if reduction!='none' returns a float scalar tensor with the average SDI value over samples, else returns a tensor of shape (N,) with SDI values per sample
- Parameters:
p (int) – Exponent applied to the spectral differences; larger values emphasize large spectral differences.
reduction (Literal['elementwise_mean', 'sum', 'none']) – a method to reduce the metric score over labels.
'elementwise_mean': takes the mean (default)
'sum': takes the sum
'none': no reduction will be applied
kwargs (Any) – Additional keyword arguments, see Advanced metric settings for more info.
Example
>>> import torch
>>> _ = torch.manual_seed(42)
>>> from torchmetrics.image import SpectralDistortionIndex
>>> preds = torch.rand([16, 3, 16, 16])
>>> target = torch.rand([16, 3, 16, 16])
>>> sdi = SpectralDistortionIndex()
>>> sdi(preds, target)
tensor(0.0234)
- plot(val=None, ax=None)[source]¶
Plot a single or multiple values from the metric.
- Parameters:
val (Union[Tensor, Sequence[Tensor], None]) – Either a single result from calling metric.forward or metric.compute or a list of these results. If no value is provided, will automatically call metric.compute and plot that result.
ax (Optional[Axes]) – A matplotlib axis object. If provided, will add the plot to that axis.
- Return type:
- Returns:
Figure and Axes object
- Raises:
ModuleNotFoundError – If matplotlib is not installed
>>> # Example plotting a single value
>>> import torch
>>> _ = torch.manual_seed(42)
>>> from torchmetrics.image import SpectralDistortionIndex
>>> preds = torch.rand([16, 3, 16, 16])
>>> target = torch.rand([16, 3, 16, 16])
>>> metric = SpectralDistortionIndex()
>>> metric.update(preds, target)
>>> fig_, ax_ = metric.plot()
>>> # Example plotting multiple values
>>> import torch
>>> _ = torch.manual_seed(42)
>>> from torchmetrics.image import SpectralDistortionIndex
>>> preds = torch.rand([16, 3, 16, 16])
>>> target = torch.rand([16, 3, 16, 16])
>>> metric = SpectralDistortionIndex()
>>> values = []
>>> for _ in range(10):
...     values.append(metric(preds, target))
>>> fig_, ax_ = metric.plot(values)
Functional Interface¶
- torchmetrics.functional.image.spectral_distortion_index(preds, target, p=1, reduction='elementwise_mean')[source]¶
Calculate Spectral Distortion Index (SpectralDistortionIndex), also known as D_lambda.
The metric is used to compare the spectral distortion between two images.
- Parameters:
preds (Tensor) – Low resolution multispectral image
target (Tensor) – High resolution fused image
p (int) – Exponent applied to the spectral differences; larger values emphasize large spectral differences.
reduction (Literal['elementwise_mean', 'sum', 'none']) – a method to reduce the metric score over labels.
'elementwise_mean': takes the mean (default)
'sum': takes the sum
'none': no reduction will be applied
- Return type:
- Returns:
Tensor with SpectralDistortionIndex score
- Raises:
TypeError – If preds and target don’t have the same data type.
ValueError – If preds and target don’t have BxCxHxW shape.
ValueError – If p is not a positive integer.
Example
>>> import torch
>>> from torchmetrics.functional.image import spectral_distortion_index
>>> _ = torch.manual_seed(42)
>>> preds = torch.rand([16, 3, 16, 16])
>>> target = torch.rand([16, 3, 16, 16])
>>> spectral_distortion_index(preds, target)
tensor(0.0234)
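D_lambda compares how similar each pair of bands is before fusion and after it. The usual definition in the pansharpening literature uses Q, the Universal Image Quality Index. It averages |Q(ms_i, ms_j) - Q(fused_i, fused_j)|^p over all ordered band pairs i ≠ j and then takes the p-th root. The plain-Python sketch below assumes that definition. For brevity it evaluates Q over each whole band rather than with the Gaussian-windowed index documented in the Universal Image Quality Index section, so its numbers will not match TorchMetrics.

```python
from typing import List


def uqi(x: List[float], y: List[float]) -> float:
    """Universal Image Quality Index over a whole band (one global window)."""
    n = len(x)
    mx, my = sum(x) / n, sum(y) / n
    vx = sum((a - mx) ** 2 for a in x) / n
    vy = sum((b - my) ** 2 for b in y) / n
    cxy = sum((a - mx) * (b - my) for a, b in zip(x, y)) / n
    return 4 * cxy * mx * my / ((vx + vy) * (mx ** 2 + my ** 2))


def d_lambda(ms_bands: List[List[float]], fused_bands: List[List[float]], p: int = 1) -> float:
    """Spectral distortion between two images given as lists of flattened bands."""
    n_bands = len(ms_bands)
    total = 0.0
    for i in range(n_bands):
        for j in range(n_bands):
            if i != j:
                # inter-band similarity before fusion vs. after fusion
                q_ms = uqi(ms_bands[i], ms_bands[j])
                q_fused = uqi(fused_bands[i], fused_bands[j])
                total += abs(q_ms - q_fused) ** p
    return (total / (n_bands * (n_bands - 1))) ** (1 / p)


bands = [[1.0, 2.0, 3.0, 4.0], [2.0, 3.0, 5.0, 4.0], [4.0, 1.0, 2.0, 3.0]]
print(d_lambda(bands, bands))  # identical inter-band structure, so no distortion
```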
Structural Similarity Index Measure (SSIM)¶
Module Interface¶
- class torchmetrics.image.StructuralSimilarityIndexMeasure(gaussian_kernel=True, sigma=1.5, kernel_size=11, reduction='elementwise_mean', data_range=None, k1=0.01, k2=0.03, return_full_image=False, return_contrast_sensitivity=False, **kwargs)[source]¶
Compute Structural Similarity Index Measure (SSIM).
As input to forward and update the metric accepts the following input
As output of forward and compute the metric returns the following output
ssim (Tensor): if reduction!='none' returns a float scalar tensor with the average SSIM value over samples, else returns a tensor of shape (N,) with SSIM values per sample
- Parameters:
preds – estimated image
target – ground truth image
gaussian_kernel (bool) – If True (default), a Gaussian kernel is used; if False, a uniform kernel is used
sigma (Union[float, Sequence[float]]) – Standard deviation of the Gaussian kernel; anisotropic kernels are possible. Ignored if a uniform kernel is used
kernel_size (Union[int, Sequence[int]]) – the size of the uniform kernel; anisotropic kernels are possible. Ignored if a Gaussian kernel is used
reduction (Literal['elementwise_mean', 'sum', 'none', None]) – a method to reduce the metric score over individual batch scores
'elementwise_mean': takes the mean
'sum': takes the sum
'none' or None: no reduction will be applied
data_range (Union[float, Tuple[float, float], None]) – the range of the data. If None, it is determined from the data (max - min). If a tuple is provided, the range is calculated as the difference and the input is clamped between the values.
k1 (float) – Parameter of SSIM.
k2 (float) – Parameter of SSIM.
return_full_image (bool) – If True, the full ssim image is returned as a second argument. Mutually exclusive with return_contrast_sensitivity
return_contrast_sensitivity (bool) – If True, the contrast sensitivity term is returned as a second argument. The luminance term can be obtained with luminance=ssim/contrast. Mutually exclusive with return_full_image
kwargs (Any) – Additional keyword arguments, see Advanced metric settings for more info.
Example
>>> import torch
>>> from torchmetrics.image import StructuralSimilarityIndexMeasure
>>> preds = torch.rand([3, 3, 256, 256])
>>> target = preds * 0.75
>>> ssim = StructuralSimilarityIndexMeasure(data_range=1.0)
>>> ssim(preds, target)
tensor(0.9219)
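SSIM compares local statistics that are weighted by a window. With gaussian_kernel=True the window is a Gaussian with standard deviation sigma, and passing one sigma per axis gives an anisotropic window. The sketch below builds such a normalized, separable window in plain Python. The explicit window size argument is an assumption made for illustration, because the parameter description above states that kernel_size is ignored when a Gaussian kernel is used. The helper names are hypothetical.

```python
import math
from typing import List, Sequence


def gaussian_1d(size: int, sigma: float) -> List[float]:
    """Normalized 1D Gaussian weights on an odd-sized grid centred at zero."""
    if size <= 0 or size % 2 == 0:
        raise ValueError("size must be an odd positive integer")
    half = size // 2
    weights = [math.exp(-0.5 * (i / sigma) ** 2) for i in range(-half, half + 1)]
    total = sum(weights)
    return [w / total for w in weights]


def gaussian_2d(size_hw: Sequence[int], sigma_hw: Sequence[float]) -> List[List[float]]:
    """Separable 2D window: outer product of one 1D kernel per axis (anisotropic if the sigmas differ)."""
    gy = gaussian_1d(size_hw[0], sigma_hw[0])
    gx = gaussian_1d(size_hw[1], sigma_hw[1])
    return [[a * b for b in gx] for a in gy]


# isotropic window with the default sigma of 1.5; the 11 x 11 size is chosen for illustration
window = gaussian_2d((11, 11), (1.5, 1.5))
```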
- plot(val=None, ax=None)[source]¶
Plot a single or multiple values from the metric.
- Parameters:
val (Union[Tensor, Sequence[Tensor], None]) – Either a single result from calling metric.forward or metric.compute or a list of these results. If no value is provided, will automatically call metric.compute and plot that result.
ax (Optional[Axes]) – A matplotlib axis object. If provided, will add the plot to that axis.
- Return type:
- Returns:
Figure and Axes object
- Raises:
ModuleNotFoundError – If matplotlib is not installed
>>> # Example plotting a single value
>>> import torch
>>> from torchmetrics.image import StructuralSimilarityIndexMeasure
>>> preds = torch.rand([3, 3, 256, 256])
>>> target = preds * 0.75
>>> metric = StructuralSimilarityIndexMeasure(data_range=1.0)
>>> metric.update(preds, target)
>>> fig_, ax_ = metric.plot()
>>> # Example plotting multiple values
>>> import torch
>>> from torchmetrics.image import StructuralSimilarityIndexMeasure
>>> preds = torch.rand([3, 3, 256, 256])
>>> target = preds * 0.75
>>> metric = StructuralSimilarityIndexMeasure(data_range=1.0)
>>> values = []
>>> for _ in range(10):
...     values.append(metric(preds, target))
>>> fig_, ax_ = metric.plot(values)
Functional Interface¶
- torchmetrics.functional.image.structural_similarity_index_measure(preds, target, gaussian_kernel=True, sigma=1.5, kernel_size=11, reduction='elementwise_mean', data_range=None, k1=0.01, k2=0.03, return_full_image=False, return_contrast_sensitivity=False)[source]¶
Compute Structural Similarity Index Measure.
- Parameters:
preds (Tensor) – estimated image
target (Tensor) – ground truth image
gaussian_kernel (bool) – If True (default), a Gaussian kernel is used; if False, a uniform kernel is used
sigma (Union[float, Sequence[float]]) – Standard deviation of the Gaussian kernel; anisotropic kernels are possible. Ignored if a uniform kernel is used
kernel_size (Union[int, Sequence[int]]) – the size of the uniform kernel; anisotropic kernels are possible. Ignored if a Gaussian kernel is used
reduction (Literal['elementwise_mean', 'sum', 'none', None]) – a method to reduce the metric score over labels.
'elementwise_mean': takes the mean
'sum': takes the sum
'none' or None: no reduction will be applied
data_range (Union[float, Tuple[float, float], None]) – the range of the data. If None, it is determined from the data (max - min). If a tuple is provided, the range is calculated as the difference and the input is clamped between the values.
k1 (float) – Parameter of SSIM.
k2 (float) – Parameter of SSIM.
return_full_image (bool) – If True, the full ssim image is returned as a second argument. Mutually exclusive with return_contrast_sensitivity
return_contrast_sensitivity (bool) – If True, the contrast sensitivity term is returned as a second argument. The luminance term can be obtained with luminance=ssim/contrast. Mutually exclusive with return_full_image
- Return type:
- Returns:
Tensor with SSIM score
- Raises:
TypeError – If preds and target don’t have the same data type.
ValueError – If preds and target don’t have BxCxHxW shape.
ValueError – If the length of kernel_size or sigma is not 2.
ValueError – If one of the elements of kernel_size is not an odd positive number.
ValueError – If one of the elements of sigma is not a positive number.
Example
>>> import torch
>>> from torchmetrics.functional.image import structural_similarity_index_measure
>>> preds = torch.rand([3, 3, 256, 256])
>>> target = preds * 0.75
>>> structural_similarity_index_measure(preds, target)
tensor(0.9219)
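Let L be the data range, and set C1 = (k1 · L)^2 and C2 = (k2 · L)^2. SSIM for one window is then the product of two terms:
- a luminance term (2μxμy + C1) / (μx^2 + μy^2 + C1)
- a contrast sensitivity term (2σxy + C2) / (σx^2 + σy^2 + C2)

The plain-Python sketch below evaluates both terms on a single global window. The library instead averages over local windows, so the sketch shows the formula and the luminance = ssim / contrast relation mentioned for return_contrast_sensitivity; it does not reproduce TorchMetrics' exact values. The helper name is hypothetical.

```python
from typing import List, Tuple


def ssim_global(x: List[float], y: List[float], data_range: float = 1.0,
                k1: float = 0.01, k2: float = 0.03) -> Tuple[float, float]:
    """Return (ssim, contrast_sensitivity) for two pixel lists using one global window."""
    n = len(x)
    c1 = (k1 * data_range) ** 2
    c2 = (k2 * data_range) ** 2
    mx, my = sum(x) / n, sum(y) / n
    vx = sum((a - mx) ** 2 for a in x) / n
    vy = sum((b - my) ** 2 for b in y) / n
    cxy = sum((a - mx) * (b - my) for a, b in zip(x, y)) / n
    luminance = (2 * mx * my + c1) / (mx ** 2 + my ** 2 + c1)
    contrast_sensitivity = (2 * cxy + c2) / (vx + vy + c2)
    return luminance * contrast_sensitivity, contrast_sensitivity


x = [0.1, 0.5, 0.9, 0.3]
y = [0.75 * v for v in x]  # same pattern as the `target = preds * 0.75` examples
ssim, cs = ssim_global(x, y)
luminance = ssim / cs  # recover the luminance term from the two returned values
```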
Total Variation (TV)¶
Module Interface¶
- class torchmetrics.image.TotalVariation(reduction='sum', **kwargs)[source]¶
Compute Total Variation loss (TV).
As input to forward and update the metric accepts the following input
img (Tensor): A tensor of shape (N, C, H, W) consisting of images
As output of forward and compute the metric returns the following output
tv (Tensor): if reduction!='none' returns a float scalar tensor with the TV value reduced over samples (summed by default), else returns a tensor of shape (N,) with TV values per sample
- Parameters:
reduction (Optional[Literal['mean', 'sum', 'none']]) – a method to reduce the metric score over samples
'mean': takes the mean over samples
'sum': takes the sum over samples
None or 'none': returns the score per sample
kwargs (Any) – Additional keyword arguments, see Advanced metric settings for more info.
- Raises:
ValueError – If reduction is not one of 'sum', 'mean', 'none' or None
Example
>>> import torch
>>> from torchmetrics.image import TotalVariation
>>> _ = torch.manual_seed(42)
>>> tv = TotalVariation()
>>> img = torch.rand(5, 3, 28, 28)
>>> tv(img)
tensor(7546.8018)
- plot(val=None, ax=None)[source]¶
Plot a single or multiple values from the metric.
- Parameters:
val (Union[Tensor, Sequence[Tensor], None]) – Either a single result from calling metric.forward or metric.compute or a list of these results. If no value is provided, will automatically call metric.compute and plot that result.
ax (Optional[Axes]) – A matplotlib axis object. If provided, will add the plot to that axis.
- Return type:
- Returns:
Figure and Axes object
- Raises:
ModuleNotFoundError – If matplotlib is not installed
>>> # Example plotting a single value
>>> import torch
>>> from torchmetrics.image import TotalVariation
>>> metric = TotalVariation()
>>> metric.update(torch.rand(5, 3, 28, 28))
>>> fig_, ax_ = metric.plot()
>>> # Example plotting multiple values
>>> import torch
>>> from torchmetrics.image import TotalVariation
>>> metric = TotalVariation()
>>> values = []
>>> for _ in range(10):
...     values.append(metric(torch.rand(5, 3, 28, 28)))
>>> fig_, ax_ = metric.plot(values)
Functional Interface¶
- torchmetrics.functional.image.total_variation(img, reduction='sum')[source]¶
Compute total variation loss.
- Parameters:
- Return type:
- Returns:
A loss scalar value containing the total variation
- Raises:
ValueError – If reduction is not one of 'sum', 'mean', 'none' or None
RuntimeError – If img is not a 4D tensor
Example
>>> import torch
>>> from torchmetrics.functional.image import total_variation
>>> _ = torch.manual_seed(42)
>>> img = torch.rand(5, 3, 28, 28)
>>> total_variation(img)
tensor(7546.8018)
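The sketch below assumes the anisotropic (L1) form of total variation. For each sample it sums the absolute differences between vertically and horizontally adjacent pixels over all channels, then reduces those per-sample scores over the batch. It is plain Python for illustration, and the nested-list layout and helper names are hypothetical. As a rough consistency check, for pixels drawn uniformly from [0, 1) the expected absolute difference between two pixels is 1/3. A 3 × 28 × 28 image has 2 · 3 · 27 · 28 = 4536 adjacent pairs, so five such images give about 5 · 4536 / 3 = 7560, close to the 7546.8 shown in the examples above.

```python
from typing import List, Optional, Union

Image2D = List[List[float]]  # rows of pixel values for one channel


def total_variation_2d(img: Image2D) -> float:
    """Sum of absolute differences between vertically and horizontally adjacent pixels."""
    h, w = len(img), len(img[0])
    vertical = sum(abs(img[i + 1][j] - img[i][j]) for i in range(h - 1) for j in range(w))
    horizontal = sum(abs(img[i][j + 1] - img[i][j]) for i in range(h) for j in range(w - 1))
    return vertical + horizontal


def total_variation_batch(
    batch: List[List[Image2D]], reduction: Optional[str] = "sum"
) -> Union[float, List[float]]:
    """batch[n][c] is channel c of sample n; per-sample scores are reduced over the batch."""
    per_sample = [sum(total_variation_2d(channel) for channel in sample) for sample in batch]
    if reduction == "sum":
        return sum(per_sample)
    if reduction == "mean":
        return sum(per_sample) / len(per_sample)
    if reduction in ("none", None):
        return per_sample
    raise ValueError(f"Unsupported reduction: {reduction!r}")


checkerboard = [[0.0, 1.0], [1.0, 0.0]]
flat = [[0.5, 0.5], [0.5, 0.5]]
print(total_variation_batch([[checkerboard], [flat]], reduction="none"))  # [4.0, 0.0]
```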
Universal Image Quality Index¶
Module Interface¶
- class torchmetrics.image.UniversalImageQualityIndex(kernel_size=(11, 11), sigma=(1.5, 1.5), reduction='elementwise_mean', **kwargs)[source]¶
Compute Universal Image Quality Index (UniversalImageQualityIndex).
As input to forward and update the metric accepts the following input
preds (Tensor): Predictions from model of shape (N,C,H,W)
target (Tensor): Ground truth values of shape (N,C,H,W)
As output of forward and compute the metric returns the following output
uiqi (Tensor): if reduction!='none' returns a float scalar tensor with the average UIQI value over samples, else returns a tensor of shape (N,) with UIQI values per sample
- Parameters:
sigma (Sequence[float]) – Standard deviation of the Gaussian kernel
reduction (Literal['elementwise_mean', 'sum', 'none', None]) – a method to reduce the metric score over labels.
'elementwise_mean': takes the mean (default)
'sum': takes the sum
'none' or None: no reduction will be applied
kwargs (Any) – Additional keyword arguments, see Advanced metric settings for more info.
- Returns:
Tensor with UniversalImageQualityIndex score
Example
>>> import torch
>>> from torchmetrics.image import UniversalImageQualityIndex
>>> preds = torch.rand([16, 1, 16, 16])
>>> target = preds * 0.75
>>> uqi = UniversalImageQualityIndex()
>>> uqi(preds, target)
tensor(0.9216)
- plot(val=None, ax=None)[source]¶
Plot a single or multiple values from the metric.
- Parameters:
val (Union[Tensor, Sequence[Tensor], None]) – Either a single result from calling metric.forward or metric.compute or a list of these results. If no value is provided, will automatically call metric.compute and plot that result.
ax (Optional[Axes]) – A matplotlib axis object. If provided, will add the plot to that axis.
- Return type:
- Returns:
Figure and Axes object
- Raises:
ModuleNotFoundError – If matplotlib is not installed
>>> # Example plotting a single value
>>> import torch
>>> from torchmetrics.image import UniversalImageQualityIndex
>>> preds = torch.rand([16, 1, 16, 16])
>>> target = preds * 0.75
>>> metric = UniversalImageQualityIndex()
>>> metric.update(preds, target)
>>> fig_, ax_ = metric.plot()
>>> # Example plotting multiple values
>>> import torch
>>> from torchmetrics.image import UniversalImageQualityIndex
>>> preds = torch.rand([16, 1, 16, 16])
>>> target = preds * 0.75
>>> metric = UniversalImageQualityIndex()
>>> values = []
>>> for _ in range(10):
...     values.append(metric(preds, target))
>>> fig_, ax_ = metric.plot(values)
Functional Interface¶
- torchmetrics.functional.image.universal_image_quality_index(preds, target, kernel_size=(11, 11), sigma=(1.5, 1.5), reduction='elementwise_mean')[source]¶
Universal Image Quality Index.
- Parameters:
preds (Tensor) – estimated image
target (Tensor) – ground truth image
sigma (Sequence[float]) – Standard deviation of the Gaussian kernel
reduction (Optional[Literal['elementwise_mean', 'sum', 'none']]) – a method to reduce the metric score over labels.
'elementwise_mean': takes the mean (default)
'sum': takes the sum
'none' or None: no reduction will be applied
- Return type:
- Returns:
Tensor with UniversalImageQualityIndex score
- Raises:
TypeError – If preds and target don’t have the same data type.
ValueError – If preds and target don’t have BxCxHxW shape.
ValueError – If the length of kernel_size or sigma is not 2.
ValueError – If one of the elements of kernel_size is not an odd positive number.
ValueError – If one of the elements of sigma is not a positive number.
Example
>>> import torch
>>> from torchmetrics.functional.image import universal_image_quality_index
>>> preds = torch.rand([16, 1, 16, 16])
>>> target = preds * 0.75
>>> universal_image_quality_index(preds, target)
tensor(0.9216)
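Wang and Bovik [1] write the index as the product of three factors:
- correlation σxy / (σxσy)
- luminance 2μxμy / (μx^2 + μy^2)
- contrast 2σxσy / (σx^2 + σy^2)

The plain-Python sketch below evaluates them on one global window, whereas the library uses Gaussian-weighted local windows; the helper name is hypothetical. For target = preds * 0.75, every window with non-zero mean and variance has correlation 1, and its luminance and contrast factors are both 2 · 0.75 / (1 + 0.75^2) = 0.96. That gives Q = 0.96^2 = 0.9216, consistent with the tensor(0.9216) shown in the examples above.

```python
import math
from typing import List, Tuple


def uiqi_factors(x: List[float], y: List[float]) -> Tuple[float, float, float]:
    """Return the (correlation, luminance, contrast) factors of Q for one global window."""
    n = len(x)
    mx, my = sum(x) / n, sum(y) / n
    vx = sum((a - mx) ** 2 for a in x) / n
    vy = sum((b - my) ** 2 for b in y) / n
    cxy = sum((a - mx) * (b - my) for a, b in zip(x, y)) / n
    sx, sy = math.sqrt(vx), math.sqrt(vy)
    correlation = cxy / (sx * sy)
    luminance = 2 * mx * my / (mx ** 2 + my ** 2)
    contrast = 2 * sx * sy / (vx + vy)
    return correlation, luminance, contrast


preds = [1.0, 2.0, 3.0, 4.0]
target = [0.75 * v for v in preds]
corr, lum, con = uiqi_factors(preds, target)
q = corr * lum * con  # approximately 0.9216
```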
References
[1] Zhou Wang and A. C. Bovik, “A universal image quality index,” in IEEE Signal Processing Letters, vol. 9, no. 3, pp. 81-84, March 2002, doi: 10.1109/97.995823.
[2] Zhou Wang, A. C. Bovik, H. R. Sheikh and E. P. Simoncelli, “Image quality assessment: from error visibility to structural similarity,” in IEEE Transactions on Image Processing, vol. 13, no. 4, pp. 600-612, April 2004, doi: 10.1109/TIP.2003.819861.
Visual Information Fidelity (VIF)¶
Module Interface¶
- class torchmetrics.image.VisualInformationFidelity(sigma_n_sq=2.0, **kwargs)[source]¶
Compute Pixel Based Visual Information Fidelity (VIF).
As input to forward and update the metric accepts the following input
preds (Tensor): Predictions from model of shape (N,C,H,W) with H,W ≥ 41
target (Tensor): Ground truth values of shape (N,C,H,W) with H,W ≥ 41
As output of forward and compute the metric returns the following output
vif-p (Tensor): Tensor with vif-p score
- Parameters:
sigma_n_sq (float) – variance of the visual noise
kwargs (Any) – Additional keyword arguments, see Advanced metric settings for more info.
Example
>>> import torch
>>> _ = torch.manual_seed(42)
>>> from torchmetrics.image import VisualInformationFidelity
>>> preds = torch.randn([32, 3, 41, 41])
>>> target = torch.randn([32, 3, 41, 41])
>>> vif = VisualInformationFidelity()
>>> vif(preds, target)
tensor(0.0032)
Functional Interface¶
- torchmetrics.functional.image.visual_information_fidelity(preds, target, sigma_n_sq=2.0)[source]¶
Compute Pixel Based Visual Information Fidelity (VIF).
- Parameters:
- Return type:
- Returns:
Tensor with vif-p score
- Raises:
ValueError – If data_range is neither a tuple nor a float
CLIP Image Quality Assessment (CLIP-IQA)¶
Module Interface¶
- class torchmetrics.multimodal.CLIPImageQualityAssessment(model_name_or_path='clip_iqa', data_range=1.0, prompts=('quality',), **kwargs)[source]
Calculate CLIP-IQA, which can be used to measure the visual content of images.
The metric is based on the CLIP model, a neural network trained on a variety of (image, text) pairs to generate vector representations of an image and a text that are similar if the image and text are semantically similar.
The metric works by calculating the cosine similarity between user-provided images and pre-defined prompts. The prompts always come in pairs of “positive” and “negative”, such as “Good photo.” and “Bad photo.”. By calculating the similarity between the image embeddings and both the “positive” and “negative” prompt, the metric can determine which prompt the image is more similar to. The metric then returns the probability that the image is more similar to the first prompt than to the second prompt.
- Built-in prompts are:
quality: “Good photo.” vs “Bad photo.”
brightness: “Bright photo.” vs “Dark photo.”
noisiness: “Clean photo.” vs “Noisy photo.”
colorfullness: “Colorful photo.” vs “Dull photo.”
sharpness: “Sharp photo.” vs “Blurry photo.”
contrast: “High contrast photo.” vs “Low contrast photo.”
complexity: “Complex photo.” vs “Simple photo.”
natural: “Natural photo.” vs “Synthetic photo.”
happy: “Happy photo.” vs “Sad photo.”
scary: “Scary photo.” vs “Peaceful photo.”
new: “New photo.” vs “Old photo.”
warm: “Warm photo.” vs “Cold photo.”
real: “Real photo.” vs “Abstract photo.”
beutiful: “Beautiful photo.” vs “Ugly photo.”
lonely: “Lonely photo.” vs “Sociable photo.”
relaxing: “Relaxing photo.” vs “Stressful photo.”
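The similarity-to-probability step described above can be sketched as a two-way softmax over the image's cosine similarities to the positive and the negative prompt. In the plain-Python sketch below, the embeddings are toy vectors standing in for CLIP's image and text encoders. The scale factor of 100 applied before the softmax is an assumption: it matches the logit scale commonly used with CLIP, but it is not a documented parameter of this metric. The helper names are hypothetical.

```python
import math
from typing import Sequence


def cosine(u: Sequence[float], v: Sequence[float]) -> float:
    """Cosine similarity between two vectors."""
    dot = sum(a * b for a, b in zip(u, v))
    return dot / (math.sqrt(sum(a * a for a in u)) * math.sqrt(sum(b * b for b in v)))


def prompt_pair_probability(image_emb: Sequence[float], positive_emb: Sequence[float],
                            negative_emb: Sequence[float], logit_scale: float = 100.0) -> float:
    """Softmax weight of the positive prompt over a (positive, negative) prompt pair."""
    s_pos = logit_scale * cosine(image_emb, positive_emb)
    s_neg = logit_scale * cosine(image_emb, negative_emb)
    # subtract the max before exponentiating for numerical stability
    m = max(s_pos, s_neg)
    e_pos, e_neg = math.exp(s_pos - m), math.exp(s_neg - m)
    return e_pos / (e_pos + e_neg)


# toy embeddings: the image vector points closer to the "positive" prompt vector
print(prompt_pair_probability([0.9, 0.1], [1.0, 0.0], [0.0, 1.0]))
```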
As input to forward and update the metric accepts the following input
images (Tensor): tensor with images fed to the feature extractor, with shape (N,C,H,W)
As output of forward and compute the metric returns the following output
clip_iqa (Tensor or dict of tensors): tensor with the CLIP-IQA score. If a single prompt is provided, a single tensor with shape (N,) is returned. If a list of prompts is provided, a dict of tensors is returned with the prompt as key and the tensor with shape (N,) as value.
- Parameters:
model_name_or_path (Literal['clip_iqa', 'openai/clip-vit-base-patch16', 'openai/clip-vit-base-patch32', 'openai/clip-vit-large-patch14-336', 'openai/clip-vit-large-patch14']) – string indicating the version of the CLIP model to use. Available models are:
“clip_iqa”, the model corresponding to the CLIP-IQA paper.
“openai/clip-vit-base-patch16”
“openai/clip-vit-base-patch32”
“openai/clip-vit-large-patch14-336”
“openai/clip-vit-large-patch14”
data_range (Union[int, float]) – The maximum value of the input tensor. For example, if the input images are in range [0, 255], data_range should be 255. The images are normalized by this value.
prompts (Tuple[Union[str, Tuple[str, str]]]) – A string, tuple of strings or nested tuple of strings. If a single string is provided, it must be one of the available prompts (see above). Otherwise the input is expected to be a tuple, where each element is either a string or a tuple of strings. If a string is provided, it must be one of the available prompts (see above). If a tuple is provided, it must be of length 2, with the first string being the positive prompt and the second string the negative prompt.
kwargs (Any) – Additional keyword arguments, see Advanced metric settings for more info.
Note
If using the default clip_iqa model, the package piq must be installed. Either install with pip install piq or pip install torchmetrics[image].
- Raises:
ModuleNotFoundError – If transformers package is not installed or version is lower than 4.10.0
ValueError – If prompts is a tuple and it is not of length 2
ValueError – If prompts is a string and it is not one of the available prompts
ValueError – If prompts is a list of strings and not all strings are one of the available prompts
Example
Single prompt:
>>> from torchmetrics.multimodal import CLIPImageQualityAssessment
>>> import torch
>>> _ = torch.manual_seed(42)
>>> imgs = torch.randint(255, (2, 3, 224, 224)).float()
>>> metric = CLIPImageQualityAssessment()
>>> metric(imgs)
tensor([0.8894, 0.8902])
Multiple prompts:
>>> from torchmetrics.multimodal import CLIPImageQualityAssessment
>>> import torch
>>> _ = torch.manual_seed(42)
>>> imgs = torch.randint(255, (2, 3, 224, 224)).float()
>>> metric = CLIPImageQualityAssessment(prompts=("quality", "brightness"))
>>> metric(imgs)
{'quality': tensor([0.8894, 0.8902]), 'brightness': tensor([0.5507, 0.5208])}
Custom prompts. Each custom prompt must be a tuple of length 2, with a positive and a negative prompt.
>>> from torchmetrics.multimodal import CLIPImageQualityAssessment
>>> import torch
>>> _ = torch.manual_seed(42)
>>> imgs = torch.randint(255, (2, 3, 224, 224)).float()
>>> metric = CLIPImageQualityAssessment(prompts=(("Super good photo.", "Super bad photo."), "brightness"))
>>> metric(imgs)
{'user_defined_0': tensor([0.9652, 0.9629]), 'brightness': tensor([0.5507, 0.5208])}
- plot(val=None, ax=None)[source]
Plot a single or multiple values from the metric.
- Parameters:
val (Union[Tensor, Sequence[Tensor], None]) – Either a single result from calling metric.forward or metric.compute or a list of these results. If no value is provided, will automatically call metric.compute and plot that result.
ax (Optional[Axes]) – A matplotlib axis object. If provided, will add the plot to that axis.
- Return type:
- Returns:
Figure and Axes object
- Raises:
ModuleNotFoundError – If matplotlib is not installed
>>> # Example plotting a single value
>>> import torch
>>> from torchmetrics.multimodal.clip_iqa import CLIPImageQualityAssessment
>>> metric = CLIPImageQualityAssessment()
>>> metric.update(torch.rand(1, 3, 224, 224))
>>> fig_, ax_ = metric.plot()
>>> # Example plotting multiple values
>>> import torch
>>> from torchmetrics.multimodal.clip_iqa import CLIPImageQualityAssessment
>>> metric = CLIPImageQualityAssessment()
>>> values = []
>>> for _ in range(10):
...     values.append(metric(torch.rand(1, 3, 224, 224)))
>>> fig_, ax_ = metric.plot(values)
Functional Interface¶
- torchmetrics.functional.multimodal.clip_image_quality_assessment(images, model_name_or_path='clip_iqa', data_range=1.0, prompts=('quality',))[source]
Calculate CLIP-IQA, which can be used to measure the visual content of images.
The metric is based on the CLIP model, a neural network trained on a variety of (image, text) pairs to generate vector representations of an image and a text that are similar if the image and text are semantically similar.
The metric works by calculating the cosine similarity between user-provided images and pre-defined prompts. The prompts always come in pairs of “positive” and “negative”, such as “Good photo.” and “Bad photo.”. By calculating the similarity between the image embeddings and both the “positive” and “negative” prompt, the metric can determine which prompt the image is more similar to. The metric then returns the probability that the image is more similar to the first prompt than to the second prompt.
- Built-in prompts are:
quality: “Good photo.” vs “Bad photo.”
brightness: “Bright photo.” vs “Dark photo.”
noisiness: “Clean photo.” vs “Noisy photo.”
colorfullness: “Colorful photo.” vs “Dull photo.”
sharpness: “Sharp photo.” vs “Blurry photo.”
contrast: “High contrast photo.” vs “Low contrast photo.”
complexity: “Complex photo.” vs “Simple photo.”
natural: “Natural photo.” vs “Synthetic photo.”
happy: “Happy photo.” vs “Sad photo.”
scary: “Scary photo.” vs “Peaceful photo.”
new: “New photo.” vs “Old photo.”
warm: “Warm photo.” vs “Cold photo.”
real: “Real photo.” vs “Abstract photo.”
beutiful: “Beautiful photo.” vs “Ugly photo.”
lonely: “Lonely photo.” vs “Sociable photo.”
relaxing: “Relaxing photo.” vs “Stressful photo.”
- Parameters:
images (Tensor) – Either a single [N, C, H, W] tensor or a list of [C, H, W] tensors
model_name_or_path (Literal['clip_iqa', 'openai/clip-vit-base-patch16', 'openai/clip-vit-base-patch32', 'openai/clip-vit-large-patch14-336', 'openai/clip-vit-large-patch14']) – string indicating the version of the CLIP model to use. By default this argument is set to clip_iqa, which corresponds to the model used in the original paper. Other available models are “openai/clip-vit-base-patch16”, “openai/clip-vit-base-patch32”, “openai/clip-vit-large-patch14-336” and “openai/clip-vit-large-patch14”
data_range (Union[int, float]) – The maximum value of the input tensor. For example, if the input images are in range [0, 255], data_range should be 255. The images are normalized by this value.
prompts (Tuple[Union[str, Tuple[str, str]]]) – A string, tuple of strings or nested tuple of strings. If a single string is provided, it must be one of the available prompts (see above). Otherwise the input is expected to be a tuple, where each element is either a string or a tuple of strings. If a string is provided, it must be one of the available prompts (see above). If a tuple is provided, it must be of length 2, with the first string being the positive prompt and the second string the negative prompt.
Note
If using the default clip_iqa model, the package piq must be installed. Either install with pip install piq or pip install torchmetrics[multimodal].
- Return type:
- Returns:
A tensor of shape (N,) if a single prompt is provided. If a list of prompts is provided, a dictionary with the prompts as keys and tensors of shape (N,) as values.
- Raises:
ModuleNotFoundError – If transformers package is not installed or version is lower than 4.10.0
ValueError – If not all images have format [C, H, W]
ValueError – If prompts is a tuple and it is not of length 2
ValueError – If prompts is a string and it is not one of the available prompts
ValueError – If prompts is a list of strings and not all strings are one of the available prompts
Example
Single prompt:
>>> from torchmetrics.functional.multimodal import clip_image_quality_assessment
>>> import torch
>>> _ = torch.manual_seed(42)
>>> imgs = torch.randint(255, (2, 3, 224, 224)).float()
>>> clip_image_quality_assessment(imgs, prompts=("quality",))
tensor([0.8894, 0.8902])
Multiple prompts:
>>> from torchmetrics.functional.multimodal import clip_image_quality_assessment
>>> import torch
>>> _ = torch.manual_seed(42)
>>> imgs = torch.randint(255, (2, 3, 224, 224)).float()
>>> clip_image_quality_assessment(imgs, prompts=("quality", "brightness"))
{'quality': tensor([0.8894, 0.8902]), 'brightness': tensor([0.5507, 0.5208])}
Custom prompts. Each custom prompt must be a tuple of length 2, with a positive and a negative prompt.
>>> from torchmetrics.functional.multimodal import clip_image_quality_assessment
>>> import torch
>>> _ = torch.manual_seed(42)
>>> imgs = torch.randint(255, (2, 3, 224, 224)).float()
>>> clip_image_quality_assessment(imgs, prompts=(("Super good photo.", "Super bad photo."), "brightness"))
{'user_defined_0': tensor([0.9652, 0.9629]), 'brightness': tensor([0.5507, 0.5208])}
CLIP Score¶
Module Interface¶
- class torchmetrics.multimodal.clip_score.CLIPScore(model_name_or_path='openai/clip-vit-large-patch14', **kwargs)[source]¶
Calculate CLIP Score, which is a text-to-image similarity metric.
CLIP Score is a reference-free metric that can be used to evaluate the correlation between a generated caption for an image and the actual content of the image. It has been found to be highly correlated with human judgement. The metric is defined as:
\[\text{CLIPScore}(I, C) = \max(100 \cdot \cos(E_I, E_C), 0)\]
which corresponds to the cosine similarity between the visual CLIP embedding \(E_I\) for an image \(I\) and the textual CLIP embedding \(E_C\) for a caption \(C\). The score is bounded between 0 and 100, and the closer to 100 the better.
Note
Metric is not scriptable
As input to forward and update the metric accepts the following input
images (Tensor or list of tensors): tensor with images fed to the feature extractor. If a single tensor, it should have shape (N, C, H, W). If a list of tensors, each tensor should have shape (C, H, W). C is the number of channels, H and W are the height and width of the image.
text (str or list of str): text to compare with the images, one for each image.
As output of forward and compute the metric returns the following output
clip_score (Tensor): float scalar tensor with the mean CLIP score over samples
- Parameters:
model_name_or_path (Literal['openai/clip-vit-base-patch16', 'openai/clip-vit-base-patch32', 'openai/clip-vit-large-patch14-336', 'openai/clip-vit-large-patch14']) – string indicating the version of the CLIP model to use. Available models are:
“openai/clip-vit-base-patch16”
“openai/clip-vit-base-patch32”
“openai/clip-vit-large-patch14-336”
“openai/clip-vit-large-patch14”
kwargs (Any) – Additional keyword arguments, see Advanced metric settings for more info.
- Raises:
ModuleNotFoundError – If transformers package is not installed or version is lower than 4.10.0
Example
>>> import torch
>>> from torchmetrics.multimodal.clip_score import CLIPScore
>>> metric = CLIPScore(model_name_or_path="openai/clip-vit-base-patch16")
>>> score = metric(torch.randint(255, (3, 224, 224), generator=torch.manual_seed(42)), "a photo of a cat")
>>> score.detach()
tensor(24.4255)
- plot(val=None, ax=None)
Plot a single or multiple values from the metric.
- Parameters:
val (Union[Tensor, Sequence[Tensor], None]) – Either a single result from calling metric.forward or metric.compute, or a list of these results. If no value is provided, metric.compute is called automatically and its result is plotted.
ax (Optional[Axes]) – A matplotlib axis object. If provided, the plot is added to that axis.
- Returns:
Figure and Axes object
- Raises:
ModuleNotFoundError – If matplotlib is not installed
>>> # Example plotting a single value
>>> import torch
>>> from torchmetrics.multimodal.clip_score import CLIPScore
>>> metric = CLIPScore(model_name_or_path="openai/clip-vit-base-patch16")
>>> metric.update(torch.randint(255, (3, 224, 224)), "a photo of a cat")
>>> fig_, ax_ = metric.plot()
>>> # Example plotting multiple values
>>> import torch
>>> from torchmetrics.multimodal.clip_score import CLIPScore
>>> metric = CLIPScore(model_name_or_path="openai/clip-vit-base-patch16")
>>> values = [ ]
>>> for _ in range(10):
...     values.append(metric(torch.randint(255, (3, 224, 224)), "a photo of a cat"))
>>> fig_, ax_ = metric.plot(values)
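- update(images, text)
Update CLIP score on a batch of images and text.
- Parameters:
images (Union[Tensor, List[Tensor]]) – Either a single [N, C, H, W] tensor or a list of [C, H, W] tensors
text (Union[str, List[str]]) – Either a single caption or a list of captions
- Raises:
ValueError – If not all images have format [C, H, W]
ValueError – If the number of images and captions do not match
- Return type:
None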
Functional Interface
- torchmetrics.functional.multimodal.clip_score.clip_score(images, text, model_name_or_path='openai/clip-vit-large-patch14')
Calculate CLIP Score, a text-to-image similarity metric.
CLIP Score is a reference-free metric that can be used to evaluate the correlation between a generated caption for an image and the actual content of the image. It has been found to be highly correlated with human judgement. The metric is defined as:
\[\text{CLIPScore}(I, C) = \max(100 \cdot \cos(E_I, E_C), 0)\]
which corresponds to the cosine similarity between the visual CLIP embedding \(E_I\) for an image \(I\) and the textual CLIP embedding \(E_C\) for a caption \(C\), scaled by 100 and clipped at 0. The score is bounded between 0 and 100, and the closer it is to 100, the better.
Note
This metric is not scriptable.
- Parameters:
images (Union[Tensor, List[Tensor]]) – Either a single [N, C, H, W] tensor or a list of [C, H, W] tensors
text (Union[str, List[str]]) – Either a single caption or a list of captions
model_name_or_path (Literal['openai/clip-vit-base-patch16', 'openai/clip-vit-base-patch32', 'openai/clip-vit-large-patch14-336', 'openai/clip-vit-large-patch14']) – string indicating the version of the CLIP model to use. Available models are "openai/clip-vit-base-patch16", "openai/clip-vit-base-patch32", "openai/clip-vit-large-patch14-336" and "openai/clip-vit-large-patch14".
- Raises:
ModuleNotFoundError – If the transformers package is not installed or its version is lower than 4.10.0
ValueError – If not all images have format [C, H, W]
ValueError – If the number of images and captions do not match
- Return type:
Tensor
Example
>>> import torch
>>> _ = torch.manual_seed(42)
>>> from torchmetrics.functional.multimodal import clip_score
>>> score = clip_score(torch.randint(255, (3, 224, 224)), "a photo of a cat", "openai/clip-vit-base-patch16")
>>> score.detach()
tensor(24.4255)
Cramer’s V
Module Interface
- class torchmetrics.nominal.CramersV(num_classes, bias_correction=True, nan_strategy='replace', nan_replace_value=0.0, **kwargs)
Compute Cramer’s V statistic measuring the association between two categorical (nominal) data series.
\[V = \sqrt{\frac{\chi^2 / n}{\min(r - 1, k - 1)}}\]
where
\[\chi^2 = \sum_{i,j} \frac{\left(n_{ij} - \frac{n_{i.} n_{.j}}{n}\right)^2}{\frac{n_{i.} n_{.j}}{n}}\]
Here \(n_{ij}\) denotes the number of times the pair of values \((A_i, B_j)\) is observed, where \(A_i\) and \(B_j\) are values in preds and target, respectively; \(n_{i.}\) and \(n_{.j}\) are the corresponding row and column totals of the contingency table, \(n\) is the total number of observations, and \(r\) and \(k\) are the numbers of rows and columns of the table (the numbers of categories in preds and target). Cramer’s V is a symmetric coefficient, i.e. \(V(preds, target) = V(target, preds)\), so the order of the input arguments does not matter. The output value lies in [0, 1], with 1 meaning perfect association.
As input to forward and update the metric accepts the following input:
preds (Tensor): Either a 1D or 2D tensor of categorical (nominal) data from the first data series, with shape (batch_size,) or (batch_size, num_classes), respectively.
target (Tensor): Either a 1D or 2D tensor of categorical (nominal) data from the second data series, with shape (batch_size,) or (batch_size, num_classes), respectively.
As output of forward and compute the metric returns the following output:
cramers_v (Tensor): Scalar tensor containing the Cramer’s V statistic.
- Parameters:
num_classes (int) – Integer specifying the number of classes
bias_correction (bool) – Indication of whether to use bias correction.
nan_strategy (Literal['replace', 'drop']) – Indication of whether to replace or drop NaN values
nan_replace_value (Union[int, float, None]) – Value to replace NaNs with when nan_strategy = 'replace'
kwargs (Any) – Additional keyword arguments, see Advanced metric settings for more info.
- Raises:
ValueError – If nan_strategy is not one of 'replace' and 'drop'
ValueError – If nan_strategy is equal to 'replace' and nan_replace_value is not an int or float
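The two nan_strategy options can be illustrated on paired observations. The plain-Python sketch below reflects an assumption about the semantics rather than the library's code: 'replace' substitutes nan_replace_value for every NaN in both series, while 'drop' removes every pair in which either value is NaN, so that the two series stay aligned. The helper name handle_nan is hypothetical.

```python
import math
from typing import List, Tuple, Union


def handle_nan(
    preds: List[float],
    target: List[float],
    nan_strategy: str = "replace",
    nan_replace_value: Union[int, float, None] = 0.0,
) -> Tuple[List[float], List[float]]:
    """Hypothetical helper mimicking the 'replace' / 'drop' options (assumed semantics)."""
    if nan_strategy not in ("replace", "drop"):
        raise ValueError("nan_strategy must be one of 'replace' and 'drop'")
    if nan_strategy == "replace":
        if not isinstance(nan_replace_value, (int, float)):
            raise ValueError("nan_replace_value must be an int or float when nan_strategy='replace'")

        def fill(v: float) -> float:
            # Swap a NaN for the replacement value, keep everything else
            return nan_replace_value if math.isnan(v) else v

        return [fill(p) for p in preds], [fill(t) for t in target]
    # 'drop': keep only positions where both series are valid, so pairs stay aligned
    kept = [(p, t) for p, t in zip(preds, target) if not (math.isnan(p) or math.isnan(t))]
    return [p for p, _ in kept], [t for _, t in kept]


nan = float("nan")
print(handle_nan([0.0, nan, 2.0], [1.0, 1.0, nan], "replace", 9.0))  # ([0.0, 9.0, 2.0], [1.0, 1.0, 9.0])
print(handle_nan([0.0, nan, 2.0], [1.0, 1.0, nan], "drop"))  # ([0.0], [1.0])
```

Dropping whole pairs, rather than NaNs from each series independently, is what keeps the contingency table well defined, since every remaining observation still has a partner.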
Example:
>>> import torch
>>> from torchmetrics.nominal import CramersV
>>> _ = torch.manual_seed(42)
>>> preds = torch.randint(0, 4, (100,))
>>> target = torch.round(preds + torch.randn(100)).clamp(0, 4)
>>> cramers_v = CramersV(num_classes=5)
>>> cramers_v(preds, target)
tensor(0.5284)
- plot(val=None, ax=None)
Plot a single or multiple values from the metric.
- Parameters:
val (Union[Tensor, Sequence[Tensor], None]) – Either a single result from calling metric.forward or metric.compute, or a list of these results. If no value is provided, metric.compute is called automatically and its result is plotted.
ax (Optional[Axes]) – A matplotlib axis object. If provided, the plot is added to that axis.
- Returns:
Figure and Axes object
- Raises:
ModuleNotFoundError – If matplotlib is not installed
>>> # Example plotting a single value
>>> import torch
>>> from torchmetrics.nominal import CramersV
>>> metric = CramersV(num_classes=5)
>>> metric.update(torch.randint(0, 4, (100,)), torch.randint(0, 4, (100,)))
>>> fig_, ax_ = metric.plot()
>>> # Example plotting multiple values
>>> import torch
>>> from torchmetrics.nominal import CramersV
>>> metric = CramersV(num_classes=5)
>>> values = [ ]
>>> for _ in range(10):
...     values.append(metric(torch.randint(0, 4, (100,)), torch.randint(0, 4, (100,))))
>>> fig_, ax_ = metric.plot(values)
Functional Interface
- torchmetrics.functional.nominal.cramers_v(preds, target, bias_correction=True, nan_strategy='replace', nan_replace_value=0.0)
Compute Cramer’s V statistic measuring the association between two categorical (nominal) data series.
\[V = \sqrt{\frac{\chi^2 / n}{\min(r - 1, k - 1)}}\]
where
\[\chi^2 = \sum_{i,j} \frac{\left(n_{ij} - \frac{n_{i.} n_{.j}}{n}\right)^2}{\frac{n_{i.} n_{.j}}{n}}\]
Here \(n_{ij}\) denotes the number of times the pair of values \((A_i, B_j)\) is observed, where \(A_i\) and \(B_j\) are values in preds and target, respectively; \(n_{i.}\) and \(n_{.j}\) are the corresponding row and column totals of the contingency table, \(n\) is the total number of observations, and \(r\) and \(k\) are the numbers of rows and columns of the table. Cramer’s V is a symmetric coefficient, i.e. \(V(preds, target) = V(target, preds)\).
The output value lies in [0, 1], with 1 meaning perfect association.
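To make the formula concrete, the plain-Python sketch below builds the contingency table from two label sequences and evaluates \(\chi^2\) and V directly. The bias_correction branch uses the Bergsma (2013) correction, a common choice that is an assumption here: the text above does not say which correction the library applies. The function name is illustrative and not the library's API.

```python
import math
from collections import Counter
from typing import Hashable, Sequence


def cramers_v(preds: Sequence[Hashable], target: Sequence[Hashable], bias_correction: bool = False) -> float:
    """Cramer's V from two paired sequences of category labels (illustrative sketch)."""
    n = len(preds)
    counts = Counter(zip(preds, target))  # n_ij
    row_tot, col_tot = Counter(preds), Counter(target)  # n_i. and n_.j
    chi2 = 0.0
    for a in row_tot:
        for b in col_tot:
            expected = row_tot[a] * col_tot[b] / n
            chi2 += (counts[(a, b)] - expected) ** 2 / expected
    phi2 = chi2 / n
    r, k = len(row_tot), len(col_tot)
    if min(r, k) < 2:
        raise ValueError("each series needs at least two distinct categories")
    if bias_correction:
        # Assumed Bergsma (2013) correction, not confirmed by the text above
        phi2 = max(0.0, phi2 - (r - 1) * (k - 1) / (n - 1))
        r = r - (r - 1) ** 2 / (n - 1)
        k = k - (k - 1) ** 2 / (n - 1)
    return math.sqrt(phi2 / min(r - 1, k - 1))


print(cramers_v([0, 0, 1, 1], [0, 1, 0, 1]))  # no association -> 0.0
print(cramers_v([0] * 5 + [1] * 5, [0] * 5 + [1] * 5))  # perfect association -> 1.0
```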
- Parameters:
preds (Tensor) – 1D or 2D tensor of categorical (nominal) data:
1D shape: (batch_size,)
2D shape: (batch_size, num_classes)
target (Tensor) – 1D or 2D tensor of categorical (nominal) data:
1D shape: (batch_size,)
2D shape: (batch_size, num_classes)
bias_correction (bool) – Indication of whether to use bias correction.
nan_strategy (Literal['replace', 'drop']) – Indication of whether to replace or drop NaN values
nan_replace_value (Union[int, float, None]) – Value to replace NaNs with when nan_strategy = 'replace'
- Return type:
Tensor
- Returns:
Cramer’s V statistic
Example
>>> import torch
>>> from torchmetrics.functional.nominal import cramers_v
>>> _ = torch.manual_seed(42)
>>> preds = torch.randint(0, 4, (100,))
>>> target = torch.round(preds + torch.randn(100)).clamp(0, 4)
>>> cramers_v(preds, target)
tensor(0.5284)
cramers_v_matrix
- torchmetrics.functional.nominal.cramers_v_matrix(matrix, bias_correction=True, nan_strategy='replace', nan_replace_value=0.0)
Compute Cramer’s V statistic between a set of multiple variables.
This can serve as a convenient tool to compute Cramer’s V statistic for analyses of correlation between categorical variables in your dataset.
- Parameters:
matrix (Tensor) – A tensor of categorical (nominal) data, where:
rows represent data points
columns represent categorical (nominal) features
bias_correction (bool) – Indication of whether to use bias correction.
nan_strategy (Literal['replace', 'drop']) – Indication of whether to replace or drop NaN values
nan_replace_value (Union[int, float, None]) – Value to replace NaNs with when nan_strategy = 'replace'
- Return type:
Tensor
- Returns:
Cramer’s V statistic for a dataset of categorical variables
Example
>>> import torch
>>> from torchmetrics.functional.nominal import cramers_v_matrix
>>> _ = torch.manual_seed(42)
>>> matrix = torch.randint(0, 4, (200, 5))
>>> cramers_v_matrix(matrix)
tensor([[1.0000, 0.0637, 0.0000, 0.0542, 0.1337],
        [0.0637, 1.0000, 0.0000, 0.0000, 0.0000],
        [0.0000, 0.0000, 1.0000, 0.0000, 0.0649],
        [0.0542, 0.0000, 0.0000, 1.0000, 0.1100],
        [0.1337, 0.0000, 0.0649, 0.1100, 1.0000]])
Fleiss Kappa
Module Interface
- class torchmetrics.nominal.FleissKappa(mode='counts', **kwargs)
Calculate Fleiss’ kappa, a statistical measure of inter-rater agreement.
\[\kappa = \frac{\bar{p} - \bar{p_e}}{1 - \bar{p_e}}\]
where \(\bar{p}\) is the mean, over all samples, of the probability that two raters agree and \(\bar{p_e}\) is the agreement probability that would be expected if the raters assigned categories at random. If the raters are in complete agreement, a score of 1 is returned; if there is no agreement among the raters other than what would be expected by chance, a score less than or equal to 0 is returned.
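For the counts mode, the formula can be evaluated directly from the [n_samples, n_categories] table. The plain-Python sketch below follows the standard Fleiss definitions: the per-sample agreement is \(p_i = (\sum_j n_{ij}^2 - n)/(n(n-1))\), \(\bar{p}\) is its mean, and \(\bar{p_e} = \sum_j p_j^2\), where \(p_j\) is the share of all ratings that fell into category j. It assumes every sample is rated by the same number of raters n; the function name is illustrative, not the library's API.

```python
from typing import Sequence


def fleiss_kappa_from_counts(counts: Sequence[Sequence[int]]) -> float:
    """Fleiss' kappa for an [n_samples, n_categories] table of rater counts (illustrative sketch).

    Assumes every sample was rated by the same number of raters.
    """
    n_samples = len(counts)
    n_categories = len(counts[0])
    n_raters = sum(counts[0])
    if any(sum(row) != n_raters for row in counts):
        raise ValueError("every sample must be rated by the same number of raters")
    # p_i: fraction of rater pairs that agree on sample i; p_bar is their mean
    p_i = [(sum(c * c for c in row) - n_raters) / (n_raters * (n_raters - 1)) for row in counts]
    p_bar = sum(p_i) / n_samples
    # p_j: share of all ratings in category j; chance agreement p_e = sum_j p_j^2
    p_j = [sum(row[j] for row in counts) / (n_samples * n_raters) for j in range(n_categories)]
    p_e = sum(p * p for p in p_j)
    return (p_bar - p_e) / (1 - p_e)


# 10 samples rated by 14 raters into 5 categories
table = [
    [0, 0, 0, 0, 14],
    [0, 2, 6, 4, 2],
    [0, 0, 3, 5, 6],
    [0, 3, 9, 2, 0],
    [2, 2, 8, 1, 1],
    [7, 7, 0, 0, 0],
    [3, 2, 6, 3, 0],
    [2, 5, 3, 2, 2],
    [6, 5, 2, 1, 0],
    [0, 2, 2, 3, 7],
]
print(round(fleiss_kappa_from_counts(table), 4))  # 0.2099
```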
As input to forward and update the metric accepts the following input:
ratings (Tensor): Ratings of shape [n_samples, n_categories] or [n_samples, n_categories, n_raters], depending on mode. If mode is counts, ratings must be integer and contain the number of raters that chose each category. If mode is probs, ratings must be floating point and contain the probability/logits that each rater chose each category.
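In probs mode each rater supplies a score per category, which has to be turned into the counts table before the kappa formula applies. One plausible conversion, assumed here and not confirmed by the text above, is to take each rater's highest-scoring category as their vote and count the votes per category. The helper below (a hypothetical name) does this for nested lists of shape [n_samples, n_categories, n_raters].

```python
from typing import List, Sequence


def probs_to_counts(ratings: Sequence[Sequence[Sequence[float]]]) -> List[List[int]]:
    """Convert [n_samples, n_categories, n_raters] scores into [n_samples, n_categories] vote counts."""
    counts = []
    for sample in ratings:
        n_categories = len(sample)
        n_raters = len(sample[0])
        row = [0] * n_categories
        for rater in range(n_raters):
            # The rater's vote is the category with the largest score (first one on ties)
            vote = max(range(n_categories), key=lambda c: sample[c][rater])
            row[vote] += 1
        counts.append(row)
    return counts


# One sample, 3 categories, 2 raters: rater 0 favours category 1, rater 1 favours category 2
ratings = [[[0.1, 0.2], [0.7, 0.1], [0.2, 0.7]]]
print(probs_to_counts(ratings))  # [[0, 1, 1]]
```

Because every rater casts exactly one vote per sample, each row of the result sums to n_raters, which is the shape of table the counts formula expects.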
As output of forward and compute the metric returns the following output:
fleiss_k (Tensor): A float scalar tensor with the calculated Fleiss’ kappa score.
- Parameters:
mode (Literal['counts', 'probs']) – Whether ratings will be provided as counts or probabilities.
kwargs (Any) – Additional keyword arguments, see Advanced metric settings for more info.
Example
>>> # Ratings are provided as counts
>>> import torch
>>> from torchmetrics.nominal import FleissKappa
>>> _ = torch.manual_seed(42)
>>> ratings = torch.randint(0, 10, size=(100, 5)).long()  # 100 samples, 5 categories, 10 raters
>>> metric = FleissKappa(mode='counts')
>>> metric(ratings)
tensor(0.0089)
Example
>>> # Ratings are provided as probabilities
>>> import torch
>>> from torchmetrics.nominal import FleissKappa
>>> _ = torch.manual_seed(42)
>>> ratings = torch.randn(100, 5, 10).softmax(dim=1)  # 100 samples, 5 categories, 10 raters
>>> metric = FleissKappa(mode='probs')
>>> metric(ratings)
tensor(-0.0105)
- plot(val=None, ax=None)
Plot a single or multiple values from the metric.
- Parameters:
val (Union[Tensor, Sequence[Tensor], None]) – Either a single result from calling metric.forward or metric.compute, or a list of these results. If no value is provided, metric.compute is called automatically and its result is plotted.
ax (Optional[Axes]) – A matplotlib axis object. If provided, the plot is added to that axis.
- Returns:
Figure and Axes object
- Raises:
ModuleNotFoundError – If matplotlib is not installed
>>> # Example plotting a single value
>>> import torch
>>> from torchmetrics.nominal import FleissKappa
>>> metric = FleissKappa(mode="probs")
>>> metric.update(torch.randn(100, 5, 10).softmax(dim=1))
>>> fig_, ax_ = metric.plot()
>>> # Example plotting multiple values
>>> import torch
>>> from torchmetrics.nominal import FleissKappa
>>> metric = FleissKappa(mode="probs")
>>> values = [ ]
>>> for _ in range(10):
...     values.append(metric(torch.randn(100, 5, 10).softmax(dim=1)))
>>> fig_, ax_ = metric.plot(values)
Functional Interface
- torchmetrics.functional.nominal.fleiss_kappa(ratings, mode='counts')
Calculate Fleiss’ kappa, a statistical measure of inter-rater agreement.
\[\kappa = \frac{\bar{p} - \bar{p_e}}{1 - \bar{p_e}}\]
where \(\bar{p}\) is the mean, over all samples, of the probability that two raters agree and \(\bar{p_e}\) is the agreement probability that would be expected if the raters assigned categories at random. If the raters are in complete agreement, a score of 1 is returned; if there is no agreement among the raters other than what would be expected by chance, a score less than or equal to 0 is returned.
- Parameters:
ratings (Tensor) – Ratings of shape [n_samples, n_categories] or [n_samples, n_categories, n_raters], depending on mode. If mode is counts, ratings must be integer and contain the number of raters that chose each category. If mode is probs, ratings must be floating point and contain the probability/logits that each rater chose each category.
mode (Literal['counts', 'probs']) – Whether ratings will be provided as counts or probabilities.
- Return type:
Tensor
Example
>>> # Ratings are provided as counts
>>> import torch
>>> from torchmetrics.functional.nominal import fleiss_kappa
>>> _ = torch.manual_seed(42)
>>> ratings = torch.randint(0, 10, size=(100, 5)).long()  # 100 samples, 5 categories, 10 raters
>>> fleiss_kappa(ratings)
tensor(0.0089)
Example
>>> # Ratings are provided as probabilities
>>> import torch
>>> from torchmetrics.functional.nominal import fleiss_kappa
>>> _ = torch.manual_seed(42)
>>> ratings = torch.randn(100, 5, 10).softmax(dim=1)  # 100 samples, 5 categories, 10 raters
>>> fleiss_kappa(ratings, mode='probs')
tensor(-0.0105)
Pearson’s Contingency Coefficient
Module Interface
- class torchmetrics.nominal.PearsonsContingencyCoefficient(num_classes, nan_strategy='replace', nan_replace_value=0.0, **kwargs)
Compute Pearson’s Contingency Coefficient statistic.
This metric measures the association between two categorical (nominal) data series.
\[Pearson = \sqrt{\frac{\chi^2 / n}{1 + \chi^2 / n}}\]
where
\[\chi^2 = \sum_{i,j} \frac{\left(n_{ij} - \frac{n_{i.} n_{.j}}{n}\right)^2}{\frac{n_{i.} n_{.j}}{n}}\]
Here \(n_{ij}\) denotes the number of times the pair of values \((A_i, B_j)\) is observed, where \(A_i\) and \(B_j\) are values in preds and target, respectively; \(n_{i.}\) and \(n_{.j}\) are the corresponding row and column totals of the contingency table and \(n\) is the total number of observations. Pearson’s Contingency Coefficient is a symmetric coefficient, i.e. \(Pearson(preds, target) = Pearson(target, preds)\), so the order of the input arguments does not matter. The output value lies in [0, 1), with larger values meaning stronger association; the coefficient never reaches exactly 1, and its attainable maximum depends on the number of categories.
As input to forward and update the metric accepts the following input:
preds (Tensor): Either a 1D or 2D tensor of categorical (nominal) data from the first data series, with shape (batch_size,) or (batch_size, num_classes), respectively.
target (Tensor): Either a 1D or 2D tensor of categorical (nominal) data from the second data series, with shape (batch_size,) or (batch_size, num_classes), respectively.
As output of forward and compute the metric returns the following output:
pearsons_cc (Tensor): Scalar tensor containing the Pearson’s Contingency Coefficient statistic.
- Parameters:
num_classes (int) – Integer specifying the number of classes
nan_strategy (Literal['replace', 'drop']) – Indication of whether to replace or drop NaN values
nan_replace_value (Union[int, float, None]) – Value to replace NaNs with when nan_strategy = 'replace'
kwargs (Any) – Additional keyword arguments, see Advanced metric settings for more info.
- Raises:
ValueError – If nan_strategy is not one of 'replace' and 'drop'
ValueError – If nan_strategy is equal to 'replace' and nan_replace_value is not an int or float
Example:
>>> import torch
>>> from torchmetrics.nominal import PearsonsContingencyCoefficient
>>> _ = torch.manual_seed(42)
>>> preds = torch.randint(0, 4, (100,))
>>> target = torch.round(preds + torch.randn(100)).clamp(0, 4)
>>> pearsons_contingency_coefficient = PearsonsContingencyCoefficient(num_classes=5)
>>> pearsons_contingency_coefficient(preds, target)
tensor(0.6948)
- plot(val=None, ax=None)
Plot a single or multiple values from the metric.
- Parameters:
val (Union[Tensor, Sequence[Tensor], None]) – Either a single result from calling metric.forward or metric.compute, or a list of these results. If no value is provided, metric.compute is called automatically and its result is plotted.
ax (Optional[Axes]) – A matplotlib axis object. If provided, the plot is added to that axis.
- Returns:
Figure and Axes object
- Raises:
ModuleNotFoundError – If matplotlib is not installed
>>> # Example plotting a single value
>>> import torch
>>> from torchmetrics.nominal import PearsonsContingencyCoefficient
>>> metric = PearsonsContingencyCoefficient(num_classes=5)
>>> metric.update(torch.randint(0, 4, (100,)), torch.randint(0, 4, (100,)))
>>> fig_, ax_ = metric.plot()
>>> # Example plotting multiple values
>>> import torch
>>> from torchmetrics.nominal import PearsonsContingencyCoefficient
>>> metric = PearsonsContingencyCoefficient(num_classes=5)
>>> values = [ ]
>>> for _ in range(10):
...     values.append(metric(torch.randint(0, 4, (100,)), torch.randint(0, 4, (100,))))
>>> fig_, ax_ = metric.plot(values)
Functional Interface
- torchmetrics.functional.nominal.pearsons_contingency_coefficient(preds, target, nan_strategy='replace', nan_replace_value=0.0)
Compute Pearson’s Contingency Coefficient for measuring the association between two categorical data series.
\[Pearson = \sqrt{\frac{\chi^2 / n}{1 + \chi^2 / n}}\]
where
\[\chi^2 = \sum_{i,j} \frac{\left(n_{ij} - \frac{n_{i.} n_{.j}}{n}\right)^2}{\frac{n_{i.} n_{.j}}{n}}\]
Here \(n_{ij}\) denotes the number of times the pair of values \((A_i, B_j)\) is observed, where \(A_i\) and \(B_j\) are values in preds and target, respectively; \(n_{i.}\) and \(n_{.j}\) are the corresponding row and column totals of the contingency table and \(n\) is the total number of observations. Pearson’s Contingency Coefficient is a symmetric coefficient, i.e. \(Pearson(preds, target) = Pearson(target, preds)\).
The output value lies in [0, 1), with larger values meaning stronger association; the coefficient never reaches exactly 1, and its attainable maximum depends on the number of categories.
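A plain-Python sketch of the formula (function name illustrative, not the library's API) also shows why the upper end of the range is never reached. For two perfectly associated binary series, \(\chi^2 / n = 1\) and the coefficient is \(\sqrt{1/2} \approx 0.707\); in general, with m = min(r, k) categories, the maximum is \(\sqrt{(m-1)/m}\).

```python
import math
from collections import Counter
from typing import Hashable, Sequence


def pearsons_contingency_coefficient(preds: Sequence[Hashable], target: Sequence[Hashable]) -> float:
    """Pearson's C = sqrt((chi2 / n) / (1 + chi2 / n)) from two paired label sequences."""
    n = len(preds)
    counts = Counter(zip(preds, target))  # n_ij
    row_tot, col_tot = Counter(preds), Counter(target)  # n_i. and n_.j
    chi2 = 0.0
    for a in row_tot:
        for b in col_tot:
            expected = row_tot[a] * col_tot[b] / n
            chi2 += (counts[(a, b)] - expected) ** 2 / expected
    phi2 = chi2 / n
    return math.sqrt(phi2 / (1 + phi2))


perfect = [0] * 5 + [1] * 5
print(pearsons_contingency_coefficient(perfect, perfect))  # 0.7071..., i.e. sqrt(1/2), not 1
```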
- Parameters:
preds (Tensor) – 1D or 2D tensor of categorical (nominal) data:
1D shape: (batch_size,)
2D shape: (batch_size, num_classes)
target (Tensor) – 1D or 2D tensor of categorical (nominal) data:
1D shape: (batch_size,)
2D shape: (batch_size, num_classes)
nan_strategy (Literal['replace', 'drop']) – Indication of whether to replace or drop NaN values
nan_replace_value (Union[int, float, None]) – Value to replace NaNs with when nan_strategy = 'replace'
- Return type:
Tensor
- Returns:
Pearson’s Contingency Coefficient
Example
>>> import torch
>>> from torchmetrics.functional.nominal import pearsons_contingency_coefficient
>>> _ = torch.manual_seed(42)
>>> preds = torch.randint(0, 4, (100,))
>>> target = torch.round(preds + torch.randn(100)).clamp(0, 4)
>>> pearsons_contingency_coefficient(preds, target)
tensor(0.6948)
pearsons_contingency_coefficient_matrix
- torchmetrics.functional.nominal.pearsons_contingency_coefficient_matrix(matrix, nan_strategy='replace', nan_replace_value=0.0)
Compute Pearson’s Contingency Coefficient statistic between a set of multiple variables.
This can serve as a convenient tool to compute Pearson’s Contingency Coefficient for analyses of correlation between categorical variables in your dataset.
- Parameters:
matrix (Tensor) – A tensor of categorical (nominal) data, where:
rows represent data points
columns represent categorical (nominal) features
nan_strategy (Literal['replace', 'drop']) – Indication of whether to replace or drop NaN values
nan_replace_value (Union[int, float, None]) – Value to replace NaNs with when nan_strategy = 'replace'
- Return type:
Tensor
- Returns:
Pearson’s Contingency Coefficient statistic for a dataset of categorical variables
Example
>>> import torch
>>> from torchmetrics.functional.nominal import pearsons_contingency_coefficient_matrix
>>> _ = torch.manual_seed(42)
>>> matrix = torch.randint(0, 4, (200, 5))
>>> pearsons_contingency_coefficient_matrix(matrix)
tensor([[1.0000, 0.2326, 0.1959, 0.2262, 0.2989],
        [0.2326, 1.0000, 0.1386, 0.1895, 0.1329],
        [0.1959, 0.1386, 1.0000, 0.1840, 0.2335],
        [0.2262, 0.1895, 0.1840, 1.0000, 0.2737],
        [0.2989, 0.1329, 0.2335, 0.2737, 1.0000]])
Theil’s U
Module Interface
- class torchmetrics.nominal.TheilsU(num_classes, nan_strategy='replace', nan_replace_value=0.0, **kwargs)
Compute Theil’s U statistic measuring the association between two categorical (nominal) data series.
\[U(X|Y) = \frac{H(X) - H(X|Y)}{H(X)}\]
where \(H(X)\) is the entropy of variable \(X\) and \(H(X|Y)\) is the conditional entropy of \(X\) given \(Y\). It is also known as the Uncertainty Coefficient. Theil’s U is an asymmetric coefficient, i.e. in general \(TheilsU(preds, target) \neq TheilsU(target, preds)\), so the order of the inputs matters. The output value lies in [0, 1], where 0 means that y has no information about x and 1 means that y has complete information about x.
As input to forward and update the metric accepts the following input:
preds (Tensor): Either 1D or 2D tensor of categorical (nominal) data from the first data series (called X in the above definition) with shape (batch_size,) or (batch_size, num_classes), respectively.
target (Tensor): Either 1D or 2D tensor of categorical (nominal) data from the second data series (called Y in the above definition) with shape (batch_size,) or (batch_size, num_classes), respectively.
As output of forward and compute the metric returns the following output:
theils_u (Tensor): Scalar tensor containing the Theil’s U statistic.
- Parameters:
num_classes (int) – Integer specifying the number of classes
nan_strategy (Literal['replace', 'drop']) – Indication of whether to replace or drop NaN values
nan_replace_value (Union[int, float, None]) – Value to replace NaNs with when nan_strategy = 'replace'
kwargs (Any) – Additional keyword arguments, see Advanced metric settings for more info.
Example:
>>> import torch
>>> from torchmetrics.nominal import TheilsU
>>> _ = torch.manual_seed(42)
>>> preds = torch.randint(10, (10,))
>>> target = torch.randint(10, (10,))
>>> metric = TheilsU(num_classes=10)
>>> metric(preds, target)
tensor(0.8530)
- plot(val=None, ax=None)
Plot a single or multiple values from the metric.
- Parameters:
val (Union[Tensor, Sequence[Tensor], None]) – Either a single result from calling metric.forward or metric.compute, or a list of these results. If no value is provided, metric.compute is called automatically and its result is plotted.
ax (Optional[Axes]) – A matplotlib axis object. If provided, the plot is added to that axis.
- Returns:
Figure and Axes object
- Raises:
ModuleNotFoundError – If matplotlib is not installed
>>> # Example plotting a single value
>>> import torch
>>> from torchmetrics.nominal import TheilsU
>>> metric = TheilsU(num_classes=10)
>>> metric.update(torch.randint(10, (10,)), torch.randint(10, (10,)))
>>> fig_, ax_ = metric.plot()
>>> # Example plotting multiple values
>>> import torch
>>> from torchmetrics.nominal import TheilsU
>>> metric = TheilsU(num_classes=10)
>>> values = [ ]
>>> for _ in range(10):
...     values.append(metric(torch.randint(10, (10,)), torch.randint(10, (10,))))
>>> fig_, ax_ = metric.plot(values)
Functional Interface
- torchmetrics.functional.nominal.theils_u(preds, target, nan_strategy='replace', nan_replace_value=0.0)
Compute Theil’s Uncertainty Coefficient statistic measuring the association between two nominal data series.
\[U(X|Y) = \frac{H(X) - H(X|Y)}{H(X)}\]
where \(H(X)\) is the entropy of variable \(X\) and \(H(X|Y)\) is the conditional entropy of \(X\) given \(Y\).
Theil’s U is an asymmetric coefficient, i.e. in general \(TheilsU(preds, target) \neq TheilsU(target, preds)\).
The output value lies in [0, 1]: 0 means that y has no information about x, while 1 means that y has complete information about x.
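The entropies in the definition can be computed directly from label frequencies. The plain-Python sketch below uses natural logarithms (the base cancels in the ratio); the function names are illustrative, and raising an error when X is constant (so that H(X) = 0) is a choice made for the sketch. The usage lines show the asymmetry: swapping the arguments changes the result.

```python
import math
from collections import Counter
from typing import Hashable, Sequence


def entropy(values: Sequence[Hashable]) -> float:
    """Shannon entropy (natural log) of the empirical distribution of values."""
    n = len(values)
    return -sum((c / n) * math.log(c / n) for c in Counter(values).values())


def theils_u(x: Sequence[Hashable], y: Sequence[Hashable]) -> float:
    """U(X|Y) = (H(X) - H(X|Y)) / H(X) for paired label sequences x and y (illustrative sketch)."""
    n = len(x)
    h_x = entropy(x)
    if h_x == 0:
        raise ValueError("U(X|Y) is undefined when x takes a single value")
    # Conditional entropy H(X|Y) = sum over y of p(y) * H(X | Y = y)
    h_x_given_y = 0.0
    for y_val, y_count in Counter(y).items():
        x_subset = [xi for xi, yi in zip(x, y) if yi == y_val]
        h_x_given_y += (y_count / n) * entropy(x_subset)
    return (h_x - h_x_given_y) / h_x


x = [0, 0, 1, 1]
y = [0, 1, 2, 3]
print(theils_u(x, y))  # 1.0: y fully determines x
print(theils_u(y, x))  # 0.5: x removes only half of the uncertainty about y
```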
- Parameters:
preds (Tensor) – 1D or 2D tensor of categorical (nominal) data:
1D shape: (batch_size,)
2D shape: (batch_size, num_classes)
target (Tensor) – 1D or 2D tensor of categorical (nominal) data:
1D shape: (batch_size,)
2D shape: (batch_size, num_classes)
nan_strategy (Literal['replace', 'drop']) – Indication of whether to replace or drop NaN values
nan_replace_value (Union[int, float, None]) – Value to replace NaNs with when nan_strategy = 'replace'
- Return type:
Tensor
- Returns:
Tensor containing Theil’s U statistic
Example
>>> import torch
>>> from torchmetrics.functional.nominal import theils_u
>>> _ = torch.manual_seed(42)
>>> preds = torch.randint(10, (10,))
>>> target = torch.randint(10, (10,))
>>> theils_u(preds, target)
tensor(0.8530)
theils_u_matrix
- torchmetrics.functional.nominal.theils_u_matrix(matrix, nan_strategy='replace', nan_replace_value=0.0)
Compute Theil’s U statistic between a set of multiple variables.
This can serve as a convenient tool to compute Theil’s U statistic for analyses of correlation between categorical variables in your dataset.
- Parameters:
matrix (Tensor) – A tensor of categorical (nominal) data, where:
rows represent data points
columns represent categorical (nominal) features
nan_strategy (Literal['replace', 'drop']) – Indication of whether to replace or drop NaN values
nan_replace_value (Union[int, float, None]) – Value to replace NaNs with when nan_strategy = 'replace'
- Return type:
Tensor
- Returns:
Theil’s U statistic for a dataset of categorical variables
Example
>>> import torch
>>> from torchmetrics.functional.nominal import theils_u_matrix
>>> _ = torch.manual_seed(42)
>>> matrix = torch.randint(0, 4, (200, 5))
>>> theils_u_matrix(matrix)
tensor([[1.0000, 0.0202, 0.0142, 0.0196, 0.0353],
        [0.0202, 1.0000, 0.0070, 0.0136, 0.0065],
        [0.0143, 0.0070, 1.0000, 0.0125, 0.0206],
        [0.0198, 0.0137, 0.0125, 1.0000, 0.0312],
        [0.0352, 0.0065, 0.0204, 0.0308, 1.0000]])
Tschuprow’s T
Module Interface
- class torchmetrics.nominal.TschuprowsT(num_classes, bias_correction=True, nan_strategy='replace', nan_replace_value=0.0, **kwargs)
Compute Tschuprow’s T statistic measuring the association between two categorical (nominal) data series.
\[T = \sqrt{\frac{\chi^2 / n}{\sqrt{(r - 1)(k - 1)}}}\]
where
\[\chi^2 = \sum_{i,j} \frac{\left(n_{ij} - \frac{n_{i.} n_{.j}}{n}\right)^2}{\frac{n_{i.} n_{.j}}{n}}\]
Here \(n_{ij}\) denotes the number of times the pair of values \((A_i, B_j)\) is observed, where \(A_i\) and \(B_j\) are values in preds and target, respectively; \(n_{i.}\) and \(n_{.j}\) are the corresponding row and column totals of the contingency table, \(n\) is the total number of observations, and \(r\) and \(k\) are the numbers of rows and columns of the table (the numbers of categories in preds and target). Tschuprow’s T is a symmetric coefficient, i.e. \(T(preds, target) = T(target, preds)\), so the order of the input arguments does not matter. The output value lies in [0, 1], with 1 meaning perfect association; a value of 1 is only attainable when \(r = k\).
As input to forward and update the metric accepts the following input:
preds (Tensor): Either a 1D or 2D tensor of categorical (nominal) data from the first data series, with shape (batch_size,) or (batch_size, num_classes), respectively.
target (Tensor): Either a 1D or 2D tensor of categorical (nominal) data from the second data series, with shape (batch_size,) or (batch_size, num_classes), respectively.
As output of forward and compute the metric returns the following output:
tschuprows_t (Tensor): Scalar tensor containing the Tschuprow’s T statistic.
- Parameters:
num_classes (int) – Integer specifying the number of classes
bias_correction (bool) – Indication of whether to use bias correction.
nan_strategy (Literal['replace', 'drop']) – Indication of whether to replace or drop NaN values
nan_replace_value (Union[int, float, None]) – Value to replace NaNs with when nan_strategy = 'replace'
kwargs (Any) – Additional keyword arguments, see Advanced metric settings for more info.
- Raises:
ValueError – If nan_strategy is not one of 'replace' and 'drop'
ValueError – If nan_strategy is equal to 'replace' and nan_replace_value is not an int or float
Example:
>>> import torch
>>> from torchmetrics.nominal import TschuprowsT
>>> _ = torch.manual_seed(42)
>>> preds = torch.randint(0, 4, (100,))
>>> target = torch.round(preds + torch.randn(100)).clamp(0, 4)
>>> tschuprows_t = TschuprowsT(num_classes=5)
>>> tschuprows_t(preds, target)
tensor(0.4930)
- plot(val=None, ax=None)
Plot a single or multiple values from the metric.
- Parameters:
val (Union[Tensor, Sequence[Tensor], None]) – Either a single result from calling metric.forward or metric.compute, or a list of these results. If no value is provided, metric.compute is called automatically and its result is plotted.
ax (Optional[Axes]) – A matplotlib axis object. If provided, the plot is added to that axis.
- Returns:
Figure and Axes object
- Raises:
ModuleNotFoundError – If matplotlib is not installed
>>> # Example plotting a single value
>>> import torch
>>> from torchmetrics.nominal import TschuprowsT
>>> metric = TschuprowsT(num_classes=5)
>>> metric.update(torch.randint(0, 4, (100,)), torch.randint(0, 4, (100,)))
>>> fig_, ax_ = metric.plot()
>>> # Example plotting multiple values
>>> import torch
>>> from torchmetrics.nominal import TschuprowsT
>>> metric = TschuprowsT(num_classes=5)
>>> values = [ ]
>>> for _ in range(10):
...     values.append(metric(torch.randint(0, 4, (100,)), torch.randint(0, 4, (100,))))
>>> fig_, ax_ = metric.plot(values)
Functional Interface
- torchmetrics.functional.nominal.tschuprows_t(preds, target, bias_correction=True, nan_strategy='replace', nan_replace_value=0.0)
Compute Tschuprow’s T statistic measuring the association between two categorical (nominal) data series.
\[T = \sqrt{\frac{\chi^2 / n}{\sqrt{(r - 1)(k - 1)}}}\]
where
\[\chi^2 = \sum_{i,j} \frac{\left(n_{ij} - \frac{n_{i.} n_{.j}}{n}\right)^2}{\frac{n_{i.} n_{.j}}{n}}\]
Here \(n_{ij}\) denotes the number of times the pair of values \((A_i, B_j)\) is observed, where \(A_i\) and \(B_j\) are values in preds and target, respectively; \(n_{i.}\) and \(n_{.j}\) are the corresponding row and column totals of the contingency table, \(n\) is the total number of observations, and \(r\) and \(k\) are the numbers of rows and columns of the table. Tschuprow’s T is a symmetric coefficient, i.e. \(T(preds, target) = T(target, preds)\).
The output value lies in [0, 1], with 1 meaning perfect association; a value of 1 is only attainable when \(r = k\).
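Compared with Cramer’s V, the only change is the denominator \(\sqrt{(r-1)(k-1)}\). The plain-Python sketch below (no bias correction; the function name is illustrative, not the library's API) shows the practical consequence: a perfectly dependent 2-by-4 table gives \(T = 3^{-1/4} \approx 0.76\), whereas Cramer’s V would give 1 for the same data.

```python
import math
from collections import Counter
from typing import Hashable, Sequence


def tschuprows_t(preds: Sequence[Hashable], target: Sequence[Hashable]) -> float:
    """Tschuprow's T (no bias correction) from two paired label sequences (illustrative sketch)."""
    n = len(preds)
    counts = Counter(zip(preds, target))  # n_ij
    row_tot, col_tot = Counter(preds), Counter(target)  # n_i. and n_.j
    chi2 = 0.0
    for a in row_tot:
        for b in col_tot:
            expected = row_tot[a] * col_tot[b] / n
            chi2 += (counts[(a, b)] - expected) ** 2 / expected
    r, k = len(row_tot), len(col_tot)
    return math.sqrt((chi2 / n) / math.sqrt((r - 1) * (k - 1)))


# y is a deterministic function of x, but x has 2 categories and y has 4
x = [0, 0, 0, 0, 1, 1, 1, 1]
y = [0, 0, 1, 1, 2, 2, 3, 3]
print(round(tschuprows_t(x, y), 4))  # 0.7598
```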
- Parameters:
preds (Tensor) – 1D or 2D tensor of categorical (nominal) data:
1D shape: (batch_size,)
2D shape: (batch_size, num_classes)
target (Tensor) – 1D or 2D tensor of categorical (nominal) data:
1D shape: (batch_size,)
2D shape: (batch_size, num_classes)
bias_correction (bool) – Indication of whether to use bias correction.
nan_strategy (Literal['replace', 'drop']) – Indication of whether to replace or drop NaN values
nan_replace_value (Union[int, float, None]) – Value to replace NaNs with when nan_strategy = 'replace'
- Return type:
Tensor
- Returns:
Tschuprow’s T statistic
Example
>>> import torch
>>> from torchmetrics.functional.nominal import tschuprows_t
>>> _ = torch.manual_seed(42)
>>> preds = torch.randint(0, 4, (100,))
>>> target = torch.round(preds + torch.randn(100)).clamp(0, 4)
>>> tschuprows_t(preds, target)
tensor(0.4930)
tschuprows_t_matrix
- torchmetrics.functional.nominal.tschuprows_t_matrix(matrix, bias_correction=True, nan_strategy='replace', nan_replace_value=0.0)
Compute Tschuprow’s T statistic between a set of multiple variables.
This can serve as a convenient tool to compute Tschuprow’s T statistic for analyses of correlation between categorical variables in your dataset.
- Parameters:
matrix (Tensor) – A tensor of categorical (nominal) data, where:
rows represent data points
columns represent categorical (nominal) features
bias_correction (bool) – Indication of whether to use bias correction.
nan_strategy (Literal['replace', 'drop']) – Indication of whether to replace or drop NaN values
nan_replace_value (Union[int, float, None]) – Value to replace NaNs with when nan_strategy = 'replace'
- Return type:
Tensor
- Returns:
Tschuprow’s T statistic for a dataset of categorical variables
Example
>>> import torch
>>> from torchmetrics.functional.nominal import tschuprows_t_matrix
>>> _ = torch.manual_seed(42)
>>> matrix = torch.randint(0, 4, (200, 5))
>>> tschuprows_t_matrix(matrix)
tensor([[1.0000, 0.0637, 0.0000, 0.0542, 0.1337],
        [0.0637, 1.0000, 0.0000, 0.0000, 0.0000],
        [0.0000, 0.0000, 1.0000, 0.0000, 0.0649],
        [0.0542, 0.0000, 0.0000, 1.0000, 0.1100],
        [0.1337, 0.0000, 0.0649, 0.1100, 1.0000]])
Cosine Similarity¶
Functional Interface¶
- torchmetrics.functional.pairwise_cosine_similarity(x, y=None, reduction=None, zero_diagonal=None)[source]¶
Calculate pairwise cosine similarity.
\[s_{cos}(x,y) = \frac{\langle x,y \rangle}{||x|| \cdot ||y||} = \frac{\sum_{d=1}^D x_d \cdot y_d }{\sqrt{\sum_{d=1}^D x_d^2} \cdot \sqrt{\sum_{d=1}^D y_d^2}}\]If both \(x\) and \(y\) are passed in, the calculation will be performed pairwise between the rows of \(x\) and \(y\). If only \(x\) is passed in, the calculation will be performed between the rows of \(x\).
- Parameters:
x (Tensor) – Tensor with shape [N, d]
y (Optional[Tensor]) – Tensor with shape [M, d], optional
reduction (Optional[Literal['mean', 'sum', 'none', None]]) – reduction to apply along the last dimension. Choose between ‘mean’, ‘sum’ (applied along column dimension) or ‘none’, None for no reduction
zero_diagonal (Optional[bool]) – if the diagonal of the similarity matrix should be set to 0. If only \(x\) is given this defaults to True, else if \(y\) is also given it defaults to False
- Return type:
Tensor
- Returns:
A [N,N] matrix of similarities if only x is given, else a [N,M] matrix
Example
>>> import torch
>>> from torchmetrics.functional.pairwise import pairwise_cosine_similarity
>>> x = torch.tensor([[2, 3], [3, 5], [5, 8]], dtype=torch.float32)
>>> y = torch.tensor([[1, 0], [2, 1]], dtype=torch.float32)
>>> pairwise_cosine_similarity(x, y)
tensor([[0.5547, 0.8682],
        [0.5145, 0.8437],
        [0.5300, 0.8533]])
>>> pairwise_cosine_similarity(x)
tensor([[0.0000, 0.9989, 0.9996],
        [0.9989, 0.0000, 0.9998],
        [0.9996, 0.9998, 0.0000]])
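A minimal plain-Python sketch of the row-wise computation, using nested lists instead of tensors. The helper name pairwise_cosine_sketch is hypothetical; it mirrors the documented zero_diagonal default but ignores reduction and does not guard against zero-norm rows.

```python
import math


def pairwise_cosine_sketch(x, y=None, zero_diagonal=None):
    """Pairwise cosine similarity between the rows of x and y (hypothetical helper)."""
    if zero_diagonal is None:
        zero_diagonal = y is None  # documented default: True only when y is omitted
    if y is None:
        y = x

    def norm(v):
        return math.sqrt(sum(c * c for c in v))

    out = [
        [sum(a * b for a, b in zip(xi, yj)) / (norm(xi) * norm(yj)) for yj in y]
        for xi in x
    ]
    if zero_diagonal:
        for i in range(min(len(out), len(out[0]))):
            out[i][i] = 0.0
    return out


x = [[2, 3], [3, 5], [5, 8]]
y = [[1, 0], [2, 1]]
print([[round(v, 4) for v in row] for row in pairwise_cosine_sketch(x, y)])
```

Rounded to four decimals, the result should agree with the first doctest output above.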
Euclidean Distance¶
Functional Interface¶
- torchmetrics.functional.pairwise_euclidean_distance(x, y=None, reduction=None, zero_diagonal=None)[source]¶
Calculate pairwise euclidean distances.
\[d_{euc}(x,y) = ||x - y||_2 = \sqrt{\sum_{d=1}^D (x_d - y_d)^2}\]If both \(x\) and \(y\) are passed in, the calculation will be performed pairwise between the rows of \(x\) and \(y\). If only \(x\) is passed in, the calculation will be performed between the rows of \(x\).
- Parameters:
x (Tensor) – Tensor with shape [N, d]
y (Optional[Tensor]) – Tensor with shape [M, d], optional
reduction (Optional[Literal['mean', 'sum', 'none', None]]) – reduction to apply along the last dimension. Choose between ‘mean’, ‘sum’ (applied along column dimension) or ‘none’, None for no reduction
zero_diagonal (Optional[bool]) – if the diagonal of the distance matrix should be set to 0. If only x is given this defaults to True, else if y is also given it defaults to False
- Return type:
Tensor
- Returns:
A [N,N] matrix of distances if only x is given, else a [N,M] matrix
Example
>>> import torch
>>> from torchmetrics.functional.pairwise import pairwise_euclidean_distance
>>> x = torch.tensor([[2, 3], [3, 5], [5, 8]], dtype=torch.float32)
>>> y = torch.tensor([[1, 0], [2, 1]], dtype=torch.float32)
>>> pairwise_euclidean_distance(x, y)
tensor([[3.1623, 2.0000],
        [5.3852, 4.1231],
        [8.9443, 7.6158]])
>>> pairwise_euclidean_distance(x)
tensor([[0.0000, 2.2361, 5.8310],
        [2.2361, 0.0000, 3.6056],
        [5.8310, 3.6056, 0.0000]])
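Pairwise Euclidean distances are often computed through the expansion \(||x - y||_2^2 = ||x||^2 + ||y||^2 - 2\langle x, y \rangle\), which turns the work into one matrix product. The plain-Python sketch below uses that identity. The helper name is hypothetical, and this is not necessarily how TorchMetrics implements the metric. It clamps small negative values caused by rounding before taking the square root.

```python
import math


def pairwise_euclidean_sketch(x, y=None):
    """Pairwise Euclidean distances via ||x||^2 + ||y||^2 - 2<x, y> (hypothetical helper)."""
    y = x if y is None else y
    sq_x = [sum(v * v for v in row) for row in x]  # squared norms of the rows of x
    sq_y = [sum(v * v for v in row) for row in y]  # squared norms of the rows of y
    out = []
    for i, xi in enumerate(x):
        row = []
        for j, yj in enumerate(y):
            dot = sum(a * b for a, b in zip(xi, yj))
            # Rounding can push the expansion slightly below zero, so clamp it
            row.append(math.sqrt(max(sq_x[i] + sq_y[j] - 2 * dot, 0.0)))
        out.append(row)
    return out


x = [[2, 3], [3, 5], [5, 8]]
y = [[1, 0], [2, 1]]
d = pairwise_euclidean_sketch(x, y)
# Cross-check against the direct definition using math.dist
assert all(abs(d[i][j] - math.dist(x[i], y[j])) < 1e-9 for i in range(3) for j in range(2))
```

The expansion trades a little numerical accuracy for speed on large inputs, which is why the clamp is needed.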
Linear Similarity¶
Functional Interface¶
- torchmetrics.functional.pairwise_linear_similarity(x, y=None, reduction=None, zero_diagonal=None)[source]¶
Calculate pairwise linear similarity.
\[s_{lin}(x,y) = \langle x,y \rangle = \sum_{d=1}^D x_d \cdot y_d\]If both \(x\) and \(y\) are passed in, the calculation will be performed pairwise between the rows of \(x\) and \(y\). If only \(x\) is passed in, the calculation will be performed between the rows of \(x\).
- Parameters:
x (Tensor) – Tensor with shape [N, d]
y (Optional[Tensor]) – Tensor with shape [M, d], optional
reduction (Optional[Literal['mean', 'sum', 'none', None]]) – reduction to apply along the last dimension. Choose between ‘mean’, ‘sum’ (applied along column dimension) or ‘none’, None for no reduction
zero_diagonal (Optional[bool]) – if the diagonal of the similarity matrix should be set to 0. If only x is given this defaults to True, else if y is also given it defaults to False
- Return type:
Tensor
- Returns:
A [N,N] matrix of similarities if only x is given, else a [N,M] matrix
Example
>>> import torch
>>> from torchmetrics.functional.pairwise import pairwise_linear_similarity
>>> x = torch.tensor([[2, 3], [3, 5], [5, 8]], dtype=torch.float32)
>>> y = torch.tensor([[1, 0], [2, 1]], dtype=torch.float32)
>>> pairwise_linear_similarity(x, y)
tensor([[ 2.,  7.],
        [ 3., 11.],
        [ 5., 18.]])
>>> pairwise_linear_similarity(x)
tensor([[ 0., 21., 34.],
        [21.,  0., 55.],
        [34., 55.,  0.]])
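The sketch below computes the dot-product matrix and applies our reading of the reduction argument: the reduction runs along the last dimension, so 'sum' and 'mean' collapse each row of the [N,M] matrix into one value per row of x. The helper name is hypothetical. Treat the reduction semantics as an assumption and compare against the library's own output before relying on them.

```python
def pairwise_linear_sketch(x, y=None, reduction=None, zero_diagonal=None):
    """Pairwise dot products with an optional row-wise reduction (hypothetical helper)."""
    if zero_diagonal is None:
        zero_diagonal = y is None  # documented default: True only when y is omitted
    y = x if y is None else y
    out = [[sum(a * b for a, b in zip(xi, yj)) for yj in y] for xi in x]
    if zero_diagonal:
        for i in range(min(len(out), len(out[0]))):
            out[i][i] = 0
    if reduction == "sum":
        return [sum(row) for row in out]  # one value per row of x
    if reduction == "mean":
        return [sum(row) / len(row) for row in out]
    return out  # 'none' or None: the full matrix


x = [[2, 3], [3, 5], [5, 8]]
y = [[1, 0], [2, 1]]
print(pairwise_linear_sketch(x, y))                   # [[2, 7], [3, 11], [5, 18]]
print(pairwise_linear_sketch(x, y, reduction="sum"))  # [9, 14, 23]
```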
Manhattan Distance¶
Functional Interface¶
- torchmetrics.functional.pairwise_manhattan_distance(x, y=None, reduction=None, zero_diagonal=None)[source]¶
Calculate pairwise manhattan distance.
\[d_{man}(x,y) = ||x-y||_1 = \sum_{d=1}^D |x_d - y_d|\]If both \(x\) and \(y\) are passed in, the calculation will be performed pairwise between the rows of \(x\) and \(y\). If only \(x\) is passed in, the calculation will be performed between the rows of \(x\).
- Parameters:
x (Tensor) – Tensor with shape [N, d]
y (Optional[Tensor]) – Tensor with shape [M, d], optional
reduction (Optional[Literal['mean', 'sum', 'none', None]]) – reduction to apply along the last dimension. Choose between ‘mean’, ‘sum’ (applied along column dimension) or ‘none’, None for no reduction
zero_diagonal (Optional[bool]) – if the diagonal of the distance matrix should be set to 0. If only x is given this defaults to True, else if y is also given it defaults to False
- Return type:
Tensor
- Returns:
A [N,N] matrix of distances if only x is given, else a [N,M] matrix
Example
>>> import torch
>>> from torchmetrics.functional.pairwise import pairwise_manhattan_distance
>>> x = torch.tensor([[2, 3], [3, 5], [5, 8]], dtype=torch.float32)
>>> y = torch.tensor([[1, 0], [2, 1]], dtype=torch.float32)
>>> pairwise_manhattan_distance(x, y)
tensor([[ 4.,  2.],
        [ 7.,  5.],
        [12., 10.]])
>>> pairwise_manhattan_distance(x)
tensor([[0., 3., 8.],
        [3., 0., 5.],
        [8., 5., 0.]])
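For integer-valued rows the Manhattan distance is exact, so a plain-Python sketch reproduces the example values above digit for digit. The helper name is hypothetical. Reduction is omitted, and the diagonal of the self-distance matrix is zero without any special handling.

```python
def pairwise_manhattan_sketch(x, y=None):
    """Pairwise L1 (Manhattan) distances between the rows of x and y (hypothetical helper)."""
    y = x if y is None else y
    return [[sum(abs(a - b) for a, b in zip(xi, yj)) for yj in y] for xi in x]


x = [[2, 3], [3, 5], [5, 8]]
y = [[1, 0], [2, 1]]
print(pairwise_manhattan_sketch(x, y))  # [[4, 2], [7, 5], [12, 10]]
print(pairwise_manhattan_sketch(x))     # [[0, 3, 8], [3, 0, 5], [8, 5, 0]]
```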
Minkowski Distance¶
Functional Interface¶
- torchmetrics.functional.pairwise_minkowski_distance(x, y=None, exponent=2, reduction=None, zero_diagonal=None)[source]¶
Calculate pairwise minkowski distances.
\[d_{minkowski}(x,y,p) = ||x - y||_p = \sqrt[p]{\sum_{d=1}^D |x_d - y_d|^p}\]If both \(x\) and \(y\) are passed in, the calculation will be performed pairwise between the rows of \(x\) and \(y\). If only \(x\) is passed in, the calculation will be performed between the rows of \(x\).
- Parameters:
x (Tensor) – Tensor with shape [N, d]
y (Optional[Tensor]) – Tensor with shape [M, d], optional
exponent (Union[int, float]) – int or float larger than 1, the exponent p to which the absolute difference between the rows of x and y is raised
reduction (Optional[Literal['mean', 'sum', 'none', None]]) – reduction to apply along the last dimension. Choose between ‘mean’, ‘sum’ (applied along column dimension) or ‘none’, None for no reduction
zero_diagonal (Optional[bool]) – if the diagonal of the distance matrix should be set to 0. If only x is given this defaults to True, else if y is also given it defaults to False
- Return type:
Tensor
- Returns:
A [N,N] matrix of distances if only x is given, else a [N,M] matrix
Example
>>> import torch
>>> from torchmetrics.functional.pairwise import pairwise_minkowski_distance
>>> x = torch.tensor([[2, 3], [3, 5], [5, 8]], dtype=torch.float32)
>>> y = torch.tensor([[1, 0], [2, 1]], dtype=torch.float32)
>>> pairwise_minkowski_distance(x, y, exponent=4)
tensor([[3.0092, 2.0000],
        [5.0317, 4.0039],
        [8.1222, 7.0583]])
>>> pairwise_minkowski_distance(x, exponent=4)
tensor([[0.0000, 2.0305, 5.1547],
        [2.0305, 0.0000, 3.1383],
        [5.1547, 3.1383, 0.0000]])
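The Minkowski distance generalizes the Manhattan and Euclidean distances documented above. The plain-Python sketch below (hypothetical helper names, no reduction or zero_diagonal handling) makes that relationship explicit: p = 1 gives the Manhattan distance and p = 2 the Euclidean distance.

```python
def minkowski_sketch(u, v, p=2):
    """Minkowski distance between two equal-length vectors (hypothetical helper, p >= 1)."""
    return sum(abs(a - b) ** p for a, b in zip(u, v)) ** (1.0 / p)


def pairwise_minkowski_sketch(x, y=None, p=2):
    """Apply minkowski_sketch to every (row of x, row of y) pair."""
    y = x if y is None else y
    return [[minkowski_sketch(xi, yj, p) for yj in y] for xi in x]


x = [[2, 3], [3, 5], [5, 8]]
y = [[1, 0], [2, 1]]
print(pairwise_minkowski_sketch(x, y, p=4)[0][0])  # about 3.0092, as in the example above
print(minkowski_sketch([2, 3], [1, 0], p=1))       # 4.0, the Manhattan distance
```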
Concordance Corr. Coef.¶
Module Interface¶
- class torchmetrics.ConcordanceCorrCoef(num_outputs=1, **kwargs)[source]¶
Compute concordance correlation coefficient that measures the agreement between two variables.
\[\rho_c = \frac{2 \rho \sigma_x \sigma_y}{\sigma_x^2 + \sigma_y^2 + (\mu_x - \mu_y)^2}\]where \(\mu_x, \mu_y\) are the means of the two variables, \(\sigma_x^2, \sigma_y^2\) are the corresponding variances and \(\rho\) is the Pearson correlation coefficient between the two variables.
As input to forward and update the metric accepts the following input:
preds (Tensor): either single output float tensor with shape (N,) or multioutput float tensor of shape (N,d)
target (Tensor): either single output float tensor with shape (N,) or multioutput float tensor of shape (N,d)
As output of forward and compute the metric returns the following output:
concordance (Tensor): A scalar float tensor with the concordance coefficient(s) for non-multioutput input or a float tensor with shape (d,) for multioutput input
- Parameters:
num_outputs (int) – Number of outputs in multioutput setting
kwargs (Any) – Additional keyword arguments, see Advanced metric settings for more info.
- Example (single output regression):
>>> from torchmetrics.regression import ConcordanceCorrCoef
>>> from torch import tensor
>>> target = tensor([3, -0.5, 2, 7])
>>> preds = tensor([2.5, 0.0, 2, 8])
>>> concordance = ConcordanceCorrCoef()
>>> concordance(preds, target)
tensor(0.9777)
- Example (multi output regression):
>>> from torchmetrics.regression import ConcordanceCorrCoef
>>> from torch import tensor
>>> target = tensor([[3, -0.5], [2, 7]])
>>> preds = tensor([[2.5, 0.0], [2, 8]])
>>> concordance = ConcordanceCorrCoef(num_outputs=2)
>>> concordance(preds, target)
tensor([0.7273, 0.9887])
- plot(val=None, ax=None)[source]¶
Plot a single or multiple values from the metric.
- Parameters:
val (Union[Tensor, Sequence[Tensor], None]) – Either a single result from calling metric.forward or metric.compute or a list of these results. If no value is provided, will automatically call metric.compute and plot that result.
ax (Optional[Axes]) – A matplotlib axis object. If provided, the plot is added to that axis
- Return type:
- Returns:
Figure and Axes object
- Raises:
ModuleNotFoundError – If matplotlib is not installed
>>> from torch import randn
>>> # Example plotting a single value
>>> from torchmetrics.regression import ConcordanceCorrCoef
>>> metric = ConcordanceCorrCoef()
>>> metric.update(randn(10,), randn(10,))
>>> fig_, ax_ = metric.plot()
>>> from torch import randn
>>> # Example plotting multiple values
>>> from torchmetrics.regression import ConcordanceCorrCoef
>>> metric = ConcordanceCorrCoef()
>>> values = []
>>> for _ in range(10):
...     values.append(metric(randn(10,), randn(10,)))
>>> fig, ax = metric.plot(values)
Functional Interface¶
- torchmetrics.functional.concordance_corrcoef(preds, target)[source]¶
Compute concordance correlation coefficient that measures the agreement between two variables.
\[\rho_c = \frac{2 \rho \sigma_x \sigma_y}{\sigma_x^2 + \sigma_y^2 + (\mu_x - \mu_y)^2}\]where \(\mu_x, \mu_y\) are the means of the two variables, \(\sigma_x^2, \sigma_y^2\) are the corresponding variances and \(\rho\) is the Pearson correlation coefficient between the two variables.
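- Parameters:
preds (Tensor) – either single output float tensor with shape (N,) or multioutput float tensor of shape (N,d)
target (Tensor) – either single output float tensor with shape (N,) or multioutput float tensor of shape (N,d)
- Return type:
Tensor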
- Example (single output regression):
>>> import torch
>>> from torchmetrics.functional.regression import concordance_corrcoef
>>> target = torch.tensor([3, -0.5, 2, 7])
>>> preds = torch.tensor([2.5, 0.0, 2, 8])
>>> concordance_corrcoef(preds, target)
tensor([0.9777])
- Example (multi output regression):
>>> import torch
>>> from torchmetrics.functional.regression import concordance_corrcoef
>>> target = torch.tensor([[3, -0.5], [2, 7]])
>>> preds = torch.tensor([[2.5, 0.0], [2, 8]])
>>> concordance_corrcoef(preds, target)
tensor([0.7273, 0.9887])
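Because \(\rho \sigma_x \sigma_y\) equals the covariance of the two variables, the coefficient needs only means, variances and the covariance. The plain-Python sketch below (hypothetical helper name) assumes unbiased (n - 1) estimators. With them it reproduces the 0.9777 of the single-output example and, applied column by column, the 0.7273 and 0.9887 of the multi-output example; population (n) estimators would give slightly different values.

```python
def concordance_sketch(preds, target):
    """Lin's concordance correlation coefficient for one output (hypothetical helper).

    Assumes unbiased (n - 1) variance and covariance estimates;
    2 * rho * sigma_x * sigma_y is computed directly as 2 * cov(x, y).
    """
    n = len(preds)
    mx = sum(preds) / n
    my = sum(target) / n
    var_x = sum((a - mx) ** 2 for a in preds) / (n - 1)
    var_y = sum((b - my) ** 2 for b in target) / (n - 1)
    cov = sum((a - mx) * (b - my) for a, b in zip(preds, target)) / (n - 1)
    return 2 * cov / (var_x + var_y + (mx - my) ** 2)


print(round(concordance_sketch([2.5, 0.0, 2, 8], [3, -0.5, 2, 7]), 4))
```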
Cosine Similarity¶
Module Interface¶
- class torchmetrics.CosineSimilarity(reduction='sum', **kwargs)[source]¶
Compute the Cosine Similarity.
\[cos_{sim}(x,y) = \frac{x \cdot y}{||x|| \cdot ||y||} = \frac{\sum_{i=1}^n x_i y_i}{\sqrt{\sum_{i=1}^n x_i^2}\sqrt{\sum_{i=1}^n y_i^2}}\]where \(y\) is a tensor of target values, and \(x\) is a tensor of predictions.
As input to forward and update the metric accepts the following input:
preds (Tensor): Predicted float tensor with shape (N,d)
target (Tensor): Ground truth float tensor with shape (N,d)
As output of forward and compute the metric returns the following output:
cosine_similarity (Tensor): A float tensor with the cosine similarity
- Parameters:
reduction (Literal['mean', 'sum', 'none', None]) – how to reduce over the batch dimension using ‘sum’, ‘mean’ or ‘none’ (taking the individual scores)
kwargs (Any) – Additional keyword arguments, see Advanced metric settings for more info.
Example
>>> from torch import tensor
>>> from torchmetrics.regression import CosineSimilarity
>>> target = tensor([[0, 1], [1, 1]])
>>> preds = tensor([[0, 1], [0, 1]])
>>> cosine_similarity = CosineSimilarity(reduction='mean')
>>> cosine_similarity(preds, target)
tensor(0.8536)
- plot(val=None, ax=None)[source]¶
Plot a single or multiple values from the metric.
- Parameters:
val (Union[Tensor, Sequence[Tensor], None]) – Either a single result from calling metric.forward or metric.compute or a list of these results. If no value is provided, will automatically call metric.compute and plot that result.
ax (Optional[Axes]) – A matplotlib axis object. If provided, the plot is added to that axis
- Return type:
- Returns:
Figure and Axes object
- Raises:
ModuleNotFoundError – If matplotlib is not installed
>>> from torch import randn
>>> # Example plotting a single value
>>> from torchmetrics.regression import CosineSimilarity
>>> metric = CosineSimilarity()
>>> metric.update(randn(10,), randn(10,))
>>> fig_, ax_ = metric.plot()
>>> from torch import randn
>>> # Example plotting multiple values
>>> from torchmetrics.regression import CosineSimilarity
>>> metric = CosineSimilarity()
>>> values = []
>>> for _ in range(10):
...     values.append(metric(randn(10,), randn(10,)))
>>> fig, ax = metric.plot(values)
Functional Interface¶
- torchmetrics.functional.cosine_similarity(preds, target, reduction='sum')[source]¶
Compute the Cosine Similarity.
\[cos_{sim}(x,y) = \frac{x \cdot y}{||x|| \cdot ||y||} = \frac{\sum_{i=1}^n x_i y_i}{\sqrt{\sum_{i=1}^n x_i^2}\sqrt{\sum_{i=1}^n y_i^2}}\]where \(y\) is a tensor of target values, and \(x\) is a tensor of predictions.
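- Parameters:
preds (Tensor) – Predicted float tensor with shape (N,d)
target (Tensor) – Ground truth float tensor with shape (N,d)
reduction (Literal['mean', 'sum', 'none', None]) – how to reduce over the batch dimension using ‘sum’, ‘mean’ or ‘none’ (taking the individual scores)
- Return type:
Tensor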
Example
>>> import torch
>>> from torchmetrics.functional.regression import cosine_similarity
>>> target = torch.tensor([[1, 2, 3, 4],
...                        [1, 2, 3, 4]])
>>> preds = torch.tensor([[1, 2, 3, 4],
...                       [-1, -2, -3, -4]])
>>> cosine_similarity(preds, target, 'none')
tensor([ 1.0000, -1.0000])
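Unlike the pairwise variant, this metric compares each prediction row only with its own target row and then reduces over the batch. The plain-Python sketch below (hypothetical helper name, no zero-norm guard) shows the three reduction modes and reproduces the 0.8536 mean from the module example.

```python
import math


def cosine_similarity_sketch(preds, target, reduction="sum"):
    """Row-wise cosine similarity followed by a batch reduction (hypothetical helper)."""
    scores = []
    for p, t in zip(preds, target):
        dot = sum(a * b for a, b in zip(p, t))
        norm_p = math.sqrt(sum(a * a for a in p))
        norm_t = math.sqrt(sum(b * b for b in t))
        scores.append(dot / (norm_p * norm_t))
    if reduction == "sum":
        return sum(scores)
    if reduction == "mean":
        return sum(scores) / len(scores)
    return scores  # 'none' or None: one score per row


preds = [[0, 1], [0, 1]]
target = [[0, 1], [1, 1]]
print(cosine_similarity_sketch(preds, target, "none"))  # per-row scores: 1.0 and about 0.7071
print(cosine_similarity_sketch(preds, target, "mean"))  # about 0.8536
```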
Explained Variance¶
Module Interface¶
- class torchmetrics.ExplainedVariance(multioutput='uniform_average', **kwargs)[source]¶
Compute explained variance.
\[\text{ExplainedVariance} = 1 - \frac{\text{Var}(y - \hat{y})}{\text{Var}(y)}\]Where \(y\) is a tensor of target values, and \(\hat{y}\) is a tensor of predictions.
As input to forward and update the metric accepts the following input:
preds (Tensor): Predictions from model in float tensor with shape (N,) or (N, ...) (multioutput)
target (Tensor): Ground truth values in float tensor with shape (N,) or (N, ...) (multioutput)
As output of forward and compute the metric returns the following output:
explained_variance (Tensor): A tensor with the explained variance(s)
In the case of multioutput, the variances are by default uniformly averaged over the additional dimensions. See the argument multioutput for changing this behavior.
- Parameters:
multioutput (Literal['raw_values', 'uniform_average', 'variance_weighted']) – Defines aggregation in the case of multiple output scores. Can be one of the following strings (default is 'uniform_average'):
'raw_values' returns full set of scores
'uniform_average' scores are uniformly averaged
'variance_weighted' scores are weighted by their individual variances
kwargs (Any) – Additional keyword arguments, see Advanced metric settings for more info.
- Raises:
ValueError – If multioutput is not one of "raw_values", "uniform_average" or "variance_weighted".
Example
>>> from torch import tensor
>>> from torchmetrics.regression import ExplainedVariance
>>> target = tensor([3, -0.5, 2, 7])
>>> preds = tensor([2.5, 0.0, 2, 8])
>>> explained_variance = ExplainedVariance()
>>> explained_variance(preds, target)
tensor(0.9572)
>>> target = tensor([[0.5, 1], [-1, 1], [7, -6]])
>>> preds = tensor([[0, 2], [-1, 2], [8, -5]])
>>> explained_variance = ExplainedVariance(multioutput='raw_values')
>>> explained_variance(preds, target)
tensor([0.9677, 1.0000])
- plot(val=None, ax=None)[source]¶
Plot a single or multiple values from the metric.
- Parameters:
val (Union[Tensor, Sequence[Tensor], None]) – Either a single result from calling metric.forward or metric.compute or a list of these results. If no value is provided, will automatically call metric.compute and plot that result.
ax (Optional[Axes]) – A matplotlib axis object. If provided, the plot is added to that axis
- Return type:
- Returns:
Figure and Axes object
- Raises:
ModuleNotFoundError – If matplotlib is not installed
>>> from torch import randn
>>> # Example plotting a single value
>>> from torchmetrics.regression import ExplainedVariance
>>> metric = ExplainedVariance()
>>> metric.update(randn(10,), randn(10,))
>>> fig_, ax_ = metric.plot()
>>> from torch import randn
>>> # Example plotting multiple values
>>> from torchmetrics.regression import ExplainedVariance
>>> metric = ExplainedVariance()
>>> values = []
>>> for _ in range(10):
...     values.append(metric(randn(10,), randn(10,)))
>>> fig, ax = metric.plot(values)
Functional Interface¶
- torchmetrics.functional.explained_variance(preds, target, multioutput='uniform_average')[source]¶
Compute explained variance.
- Parameters:
preds (Tensor) – estimated labels
target (Tensor) – ground truth labels
multioutput (Literal['raw_values', 'uniform_average', 'variance_weighted']) – Defines aggregation in the case of multiple output scores. Can be one of the following strings (default is 'uniform_average'):
'raw_values' returns full set of scores
'uniform_average' scores are uniformly averaged
'variance_weighted' scores are weighted by their individual variances
- Return type:
Example
>>> import torch
>>> from torchmetrics.functional.regression import explained_variance
>>> target = torch.tensor([3, -0.5, 2, 7])
>>> preds = torch.tensor([2.5, 0.0, 2, 8])
>>> explained_variance(preds, target)
tensor(0.9572)
>>> target = torch.tensor([[0.5, 1], [-1, 1], [7, -6]])
>>> preds = torch.tensor([[0, 2], [-1, 2], [8, -5]])
>>> explained_variance(preds, target, multioutput='raw_values')
tensor([0.9677, 1.0000])
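The formula and the three multioutput modes can be sketched with the standard library. Helper names are hypothetical. The variance_weighted branch follows our reading of "weighted by their individual variances", namely weights equal to each column's target variance. Constant target columns (zero variance) are not handled. Because the same estimator appears in numerator and denominator, population and sample variances give the same ratio.

```python
from statistics import pvariance


def explained_variance_sketch(preds, target):
    """1 - Var(y - y_hat) / Var(y) for a single output (hypothetical helper)."""
    residuals = [t - p for p, t in zip(preds, target)]
    return 1 - pvariance(residuals) / pvariance(target)


def explained_variance_columns(preds, target, multioutput="uniform_average"):
    """Column-wise scores for 2D lists (rows = samples, columns = outputs)."""
    cols_p = list(zip(*preds))
    cols_t = list(zip(*target))
    scores = [explained_variance_sketch(p, t) for p, t in zip(cols_p, cols_t)]
    if multioutput == "raw_values":
        return scores
    if multioutput == "uniform_average":
        return sum(scores) / len(scores)
    # 'variance_weighted' (assumed reading): weight each score by its target variance
    weights = [pvariance(t) for t in cols_t]
    return sum(w * s for w, s in zip(weights, scores)) / sum(weights)


print(explained_variance_sketch([2.5, 0.0, 2, 8], [3, -0.5, 2, 7]))  # about 0.9572
```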
Kendall Rank Corr. Coef.¶
Module Interface¶
- class torchmetrics.KendallRankCorrCoef(variant='b', t_test=False, alternative='two-sided', num_outputs=1, **kwargs)[source]¶
Compute Kendall Rank Correlation Coefficient.
\[\tau_a = \frac{C - D}{C + D}\]where \(C\) is the number of concordant pairs and \(D\) is the number of discordant pairs.
\[\tau_b = \frac{C - D}{\sqrt{(C + D + T_{preds}) \cdot (C + D + T_{target})}}\]where \(C\) and \(D\) are as above, and \(T_{preds}\) and \(T_{target}\) are the numbers of pairs tied only in preds and only in target, respectively.
\[\tau_c = 2 \cdot \frac{C - D}{n^2 \cdot \frac{m - 1}{m}}\]where \(C\) and \(D\) are as above, \(n\) is the total number of observations and \(m\) is the smaller of the numbers of unique values in preds and target.
Definitions according to The Treatment of Ties in Ranking Problems.
As input to forward and update the metric accepts the following input:
preds (Tensor): Sequence of data in float tensor of either shape (N,) or (N,d)
target (Tensor): Sequence of data in float tensor of either shape (N,) or (N,d)
As output of forward and compute the metric returns the following output:
kendall (Tensor): A tensor with the correlation tau statistic and, if t_test=True, the p-value of the corresponding statistical test.
- Parameters:
variant (Literal['a', 'b', 'c']) – Indication of which variant of Kendall’s tau to be used
t_test (bool) – Indication whether to run t-test
alternative (Optional[Literal['two-sided', 'less', 'greater']]) – Alternative hypothesis for t-test. Possible values:
- ‘two-sided’: the rank correlation is nonzero
- ‘less’: the rank correlation is negative (less than zero)
- ‘greater’: the rank correlation is positive (greater than zero)
num_outputs (int) – Number of outputs in multioutput setting
kwargs (Any) – Additional keyword arguments, see Advanced metric settings for more info.
- Raises:
ValueError – If t_test is not of type bool
ValueError – If t_test=True and alternative=None
- Example (single output regression):
>>> from torch import tensor
>>> from torchmetrics.regression import KendallRankCorrCoef
>>> preds = tensor([2.5, 0.0, 2, 8])
>>> target = tensor([3, -0.5, 2, 1])
>>> kendall = KendallRankCorrCoef()
>>> kendall(preds, target)
tensor(0.3333)
- Example (multi output regression):
>>> from torch import tensor
>>> from torchmetrics.regression import KendallRankCorrCoef
>>> preds = tensor([[2.5, 0.0], [2, 8]])
>>> target = tensor([[3, -0.5], [2, 1]])
>>> kendall = KendallRankCorrCoef(num_outputs=2)
>>> kendall(preds, target)
tensor([1., 1.])
- Example (single output regression with t-test):
>>> from torch import tensor
>>> from torchmetrics.regression import KendallRankCorrCoef
>>> preds = tensor([2.5, 0.0, 2, 8])
>>> target = tensor([3, -0.5, 2, 1])
>>> kendall = KendallRankCorrCoef(t_test=True, alternative='two-sided')
>>> kendall(preds, target)
(tensor(0.3333), tensor(0.4969))
- Example (multi output regression with t-test):
>>> from torch import tensor
>>> from torchmetrics.regression import KendallRankCorrCoef
>>> preds = tensor([[2.5, 0.0], [2, 8]])
>>> target = tensor([[3, -0.5], [2, 1]])
>>> kendall = KendallRankCorrCoef(t_test=True, alternative='two-sided', num_outputs=2)
>>> kendall(preds, target)
(tensor([1., 1.]), tensor([nan, nan]))
- plot(val=None, ax=None)[source]¶
Plot a single or multiple values from the metric.
- Parameters:
val (Union[Tensor, Sequence[Tensor], None]) – Either a single result from calling metric.forward or metric.compute or a list of these results. If no value is provided, will automatically call metric.compute and plot that result.
ax (Optional[Axes]) – A matplotlib axis object. If provided, the plot is added to that axis
- Return type:
- Returns:
Figure and Axes object
- Raises:
ModuleNotFoundError – If matplotlib is not installed
>>> from torch import randn
>>> # Example plotting a single value
>>> from torchmetrics.regression import KendallRankCorrCoef
>>> metric = KendallRankCorrCoef()
>>> metric.update(randn(10,), randn(10,))
>>> fig_, ax_ = metric.plot()
>>> from torch import randn
>>> # Example plotting multiple values
>>> from torchmetrics.regression import KendallRankCorrCoef
>>> metric = KendallRankCorrCoef()
>>> values = []
>>> for _ in range(10):
...     values.append(metric(randn(10,), randn(10,)))
>>> fig, ax = metric.plot(values)
Functional Interface¶
- torchmetrics.functional.kendall_rank_corrcoef(preds, target, variant='b', t_test=False, alternative='two-sided')[source]¶
Compute Kendall Rank Correlation Coefficient.
\[\tau_a = \frac{C - D}{C + D}\]where \(C\) is the number of concordant pairs and \(D\) is the number of discordant pairs.
\[\tau_b = \frac{C - D}{\sqrt{(C + D + T_{preds}) \cdot (C + D + T_{target})}}\]where \(C\) and \(D\) are as above, and \(T_{preds}\) and \(T_{target}\) are the numbers of pairs tied only in preds and only in target, respectively.
\[\tau_c = 2 \cdot \frac{C - D}{n^2 \cdot \frac{m - 1}{m}}\]where \(C\) and \(D\) are as above, \(n\) is the total number of observations and \(m\) is the smaller of the numbers of unique values in preds and target.
Definitions according to The Treatment of Ties in Ranking Problems.
- Parameters:
preds (Tensor) – Sequence of data of either shape (N,) or (N,d)
target (Tensor) – Sequence of data of either shape (N,) or (N,d)
variant (Literal['a', 'b', 'c']) – Indication of which variant of Kendall’s tau to be used
t_test (bool) – Indication whether to run t-test
alternative (Optional[Literal['two-sided', 'less', 'greater']]) – Alternative hypothesis for t-test. Possible values:
- ‘two-sided’: the rank correlation is nonzero
- ‘less’: the rank correlation is negative (less than zero)
- ‘greater’: the rank correlation is positive (greater than zero)
- Return type:
- Returns:
Correlation tau statistic and, if t_test=True, the p-value of the corresponding statistical test (asymptotic)
- Raises:
ValueError – If t_test is not of type bool
ValueError – If t_test=True and alternative=None
- Example (single output regression):
>>> import torch
>>> from torchmetrics.functional.regression import kendall_rank_corrcoef
>>> preds = torch.tensor([2.5, 0.0, 2, 8])
>>> target = torch.tensor([3, -0.5, 2, 1])
>>> kendall_rank_corrcoef(preds, target)
tensor(0.3333)
- Example (multi output regression):
>>> import torch
>>> from torchmetrics.functional.regression import kendall_rank_corrcoef
>>> preds = torch.tensor([[2.5, 0.0], [2, 8]])
>>> target = torch.tensor([[3, -0.5], [2, 1]])
>>> kendall_rank_corrcoef(preds, target)
tensor([1., 1.])
- Example (single output regression with t-test):
>>> import torch
>>> from torchmetrics.functional.regression import kendall_rank_corrcoef
>>> preds = torch.tensor([2.5, 0.0, 2, 8])
>>> target = torch.tensor([3, -0.5, 2, 1])
>>> kendall_rank_corrcoef(preds, target, t_test=True, alternative='two-sided')
(tensor(0.3333), tensor(0.4969))
- Example (multi output regression with t-test):
>>> import torch
>>> from torchmetrics.functional.regression import kendall_rank_corrcoef
>>> preds = torch.tensor([[2.5, 0.0], [2, 8]])
>>> target = torch.tensor([[3, -0.5], [2, 1]])
>>> kendall_rank_corrcoef(preds, target, t_test=True, alternative='two-sided')
(tensor([1., 1.]), tensor([nan, nan]))
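The tau-b formula can be checked by brute-force pair counting. The plain-Python sketch below is a hypothetical O(n²) helper, not the library's implementation. It covers only variant 'b' without the t-test and reproduces the 0.3333 of the single-output example (4 concordant and 2 discordant pairs, no ties). Pairs tied in both sequences are skipped, as the definition requires.

```python
import math
from itertools import combinations


def kendall_tau_b_sketch(preds, target):
    """Kendall's tau-b by explicit pair counting, O(n^2) (hypothetical helper)."""
    concordant = discordant = ties_preds_only = ties_target_only = 0
    for i, j in combinations(range(len(preds)), 2):
        dp = preds[i] - preds[j]
        dt = target[i] - target[j]
        if dp == 0 and dt == 0:
            continue  # tied in both sequences: ignored
        if dp == 0:
            ties_preds_only += 1
        elif dt == 0:
            ties_target_only += 1
        elif (dp > 0) == (dt > 0):
            concordant += 1
        else:
            discordant += 1
    c_plus_d = concordant + discordant
    denom = math.sqrt((c_plus_d + ties_preds_only) * (c_plus_d + ties_target_only))
    return (concordant - discordant) / denom


print(kendall_tau_b_sketch([2.5, 0.0, 2, 8], [3, -0.5, 2, 1]))  # 0.3333333333333333
```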
KL Divergence¶
Module Interface¶
- class torchmetrics.KLDivergence(log_prob=False, reduction='mean', **kwargs)[source]¶
Compute the KL divergence.
\[D_{KL}(P||Q) = \sum_{x\in\mathcal{X}} P(x) \log\frac{P(x)}{Q(x)}\]Where \(P\) and \(Q\) are probability distributions: \(P\) usually represents a distribution over data and \(Q\) is often a prior or approximation of \(P\). Note that the KL divergence is not symmetric, i.e. in general \(D_{KL}(P||Q) \neq D_{KL}(Q||P)\).
As input to forward and update the metric accepts the following input:
p (Tensor): a data distribution with shape (N, d)
q (Tensor): prior or approximate distribution with shape (N, d)
As output of forward and compute the metric returns the following output:
kl_divergence (Tensor): A tensor with the KL divergence
- Parameters:
log_prob (bool) – bool indicating if input is log-probabilities or probabilities. If given as probabilities, will normalize to make sure the distributions sum to 1.
reduction (Literal['mean', 'sum', 'none', None]) – Determines how to reduce over the N/batch dimension:
'mean' [default]: Averages score across samples
'sum': Sum score across samples
'none' or None: Returns score per sample
kwargs (Any) – Additional keyword arguments, see Advanced metric settings for more info.
- Raises:
TypeError – If log_prob is not a bool.
ValueError – If reduction is not one of 'mean', 'sum', 'none' or None.
Note
Half precision is only supported on GPU for this metric.
Example
>>> from torch import tensor
>>> from torchmetrics.regression import KLDivergence
>>> p = tensor([[0.36, 0.48, 0.16]])
>>> q = tensor([[1/3, 1/3, 1/3]])
>>> kl_divergence = KLDivergence()
>>> kl_divergence(p, q)
tensor(0.0853)
- plot(val=None, ax=None)[source]¶
Plot a single or multiple values from the metric.
- Parameters:
val (Union[Tensor, Sequence[Tensor], None]) – Either a single result from calling metric.forward or metric.compute or a list of these results. If no value is provided, will automatically call metric.compute and plot that result.
ax (Optional[Axes]) – A matplotlib axis object. If provided, the plot is added to that axis
- Return type:
- Returns:
Figure and Axes object
- Raises:
ModuleNotFoundError – If matplotlib is not installed
>>> from torch import randn
>>> # Example plotting a single value
>>> from torchmetrics.regression import KLDivergence
>>> metric = KLDivergence()
>>> metric.update(randn(10,3).softmax(dim=-1), randn(10,3).softmax(dim=-1))
>>> fig_, ax_ = metric.plot()
>>> from torch import randn
>>> # Example plotting multiple values
>>> from torchmetrics.regression import KLDivergence
>>> metric = KLDivergence()
>>> values = []
>>> for _ in range(10):
...     values.append(metric(randn(10,3).softmax(dim=-1), randn(10,3).softmax(dim=-1)))
>>> fig, ax = metric.plot(values)
Functional Interface¶
- torchmetrics.functional.kl_divergence(p, q, log_prob=False, reduction='mean')[source]¶
Compute KL divergence.
\[D_{KL}(P||Q) = \sum_{x\in\mathcal{X}} P(x) \log\frac{P(x)}{Q(x)}\]Where \(P\) and \(Q\) are probability distributions: \(P\) usually represents a distribution over data and \(Q\) is often a prior or approximation of \(P\). Note that the KL divergence is not symmetric, i.e. in general \(D_{KL}(P||Q) \neq D_{KL}(Q||P)\).
- Parameters:
p (Tensor) – data distribution with shape [N, d]
q (Tensor) – prior or approximate distribution with shape [N, d]
log_prob (bool) – bool indicating if input is log-probabilities or probabilities. If given as probabilities, will normalize to make sure the distributions sum to 1
reduction (Literal['mean', 'sum', 'none', None]) – Determines how to reduce over the N/batch dimension:
'mean' [default]: Averages score across samples
'sum': Sum score across samples
'none' or None: Returns score per sample
- Return type:
Tensor
Example
>>> from torch import tensor
>>> from torchmetrics.functional import kl_divergence
>>> p = tensor([[0.36, 0.48, 0.16]])
>>> q = tensor([[1/3, 1/3, 1/3]])
>>> kl_divergence(p, q)
tensor(0.0853)
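For a single pair of discrete distributions the sum can be written out directly. The plain-Python sketch below (hypothetical helper name) assumes plain probabilities, renormalizes them, uses the natural logarithm and treats terms with \(P(x) = 0\) as zero. With these choices it reproduces the 0.0853 above, and swapping the arguments gives a different value, illustrating the asymmetry.

```python
import math


def kl_divergence_sketch(p, q):
    """KL(P || Q) for one pair of discrete distributions (hypothetical helper).

    Inputs are plain probabilities; each is renormalized to sum to 1 first.
    Terms with P(x) = 0 contribute 0 by convention; Q(x) must be positive
    wherever P(x) is.
    """
    sp, sq = sum(p), sum(q)
    p = [v / sp for v in p]
    q = [v / sq for v in q]
    return sum(pi * math.log(pi / qi) for pi, qi in zip(p, q) if pi > 0)


p = [0.36, 0.48, 0.16]
q = [1 / 3, 1 / 3, 1 / 3]
print(round(kl_divergence_sketch(p, q), 4))
print(round(kl_divergence_sketch(q, p), 4))  # differs: KL is not symmetric
```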
Log Cosh Error¶
Module Interface¶
- class torchmetrics.LogCoshError(num_outputs=1, **kwargs)[source]¶
Compute the LogCosh Error.
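\[\text{LogCoshError} = \frac{1}{N}\sum_{i=1}^{N} \log\left(\frac{\exp(\hat{y}_i - y_i) + \exp(y_i - \hat{y}_i)}{2}\right)\]Where \(y\) is a tensor of target values, \(\hat{y}\) is a tensor of predictions and \(N\) is the number of samples (the average is taken per output in the multioutput case).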
As input to forward and update the metric accepts the following input:
preds (Tensor): Estimated labels with shape (batch_size,) or (batch_size, num_outputs)
target (Tensor): Ground truth labels with shape (batch_size,) or (batch_size, num_outputs)
As output of forward and compute the metric returns the following output:
log_cosh_error (Tensor): A tensor with the log cosh error
- Parameters:
num_outputs (int) – Number of outputs in multioutput setting
kwargs (Any) – Additional keyword arguments, see Advanced metric settings for more info.
- Example (single output regression):
>>> import torch
>>> from torchmetrics.regression import LogCoshError
>>> preds = torch.tensor([3.0, 5.0, 2.5, 7.0])
>>> target = torch.tensor([2.5, 5.0, 4.0, 8.0])
>>> log_cosh_error = LogCoshError()
>>> log_cosh_error(preds, target)
tensor(0.3523)
- Example (multi output regression):
>>> import torch
>>> from torchmetrics.regression import LogCoshError
>>> preds = torch.tensor([[3.0, 5.0, 1.2], [-2.1, 2.5, 7.0]])
>>> target = torch.tensor([[2.5, 5.0, 1.3], [0.3, 4.0, 8.0]])
>>> log_cosh_error = LogCoshError(num_outputs=3)
>>> log_cosh_error(preds, target)
tensor([0.9176, 0.4277, 0.2194])
- plot(val=None, ax=None)[source]¶
Plot a single or multiple values from the metric.
- Parameters:
val (Union[Tensor, Sequence[Tensor], None]) – Either a single result from calling metric.forward or metric.compute or a list of these results. If no value is provided, will automatically call metric.compute and plot that result.
ax (Optional[Axes]) – A matplotlib axis object. If provided, the plot is added to that axis
- Return type:
- Returns:
Figure and Axes object
- Raises:
ModuleNotFoundError – If matplotlib is not installed
>>> from torch import randn
>>> # Example plotting a single value
>>> from torchmetrics.regression import LogCoshError
>>> metric = LogCoshError()
>>> metric.update(randn(10,), randn(10,))
>>> fig_, ax_ = metric.plot()
>>> from torch import randn
>>> # Example plotting multiple values
>>> from torchmetrics.regression import LogCoshError
>>> metric = LogCoshError()
>>> values = []
>>> for _ in range(10):
...     values.append(metric(randn(10,), randn(10,)))
>>> fig, ax = metric.plot(values)
Functional Interface¶
- torchmetrics.functional.log_cosh_error(preds, target)[source]¶
Compute the LogCosh Error.
\[\text{LogCoshError} = \log\left(\frac{\exp(\hat{y} - y) + \exp(y - \hat{y})}{2}\right)\]Where \(y\) is a tensor of target values, and \(\hat{y}\) is a tensor of predictions.
- Parameters:
- Return type:
- Returns:
Tensor with LogCosh error
- Example (single output regression)::
>>> import torch
>>> from torchmetrics.functional.regression import log_cosh_error
>>> preds = torch.tensor([3.0, 5.0, 2.5, 7.0])
>>> target = torch.tensor([2.5, 5.0, 4.0, 8.0])
>>> log_cosh_error(preds, target)
tensor(0.3523)
- Example (multi output regression)::
>>> import torch
>>> from torchmetrics.functional.regression import log_cosh_error
>>> preds = torch.tensor([[3.0, 5.0, 1.2], [-2.1, 2.5, 7.0]])
>>> target = torch.tensor([[2.5, 5.0, 1.3], [0.3, 4.0, 8.0]])
>>> log_cosh_error(preds, target)
tensor([0.9176, 0.4277, 0.2194])
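As a cross-check of the formula, here is a minimal dependency-free sketch in plain Python. It is not the library's implementation: it assumes single-output inputs given as lists of floats, and `log_cosh` and `log_cosh_error_reference` are hypothetical helper names. It uses the identity log(cosh(x)) = |x| + log1p(exp(-2|x|)) - log 2, which avoids the overflow that evaluating cosh directly would hit for large errors.

```python
import math
from typing import Sequence


def log_cosh(x: float) -> float:
    """Numerically stable log(cosh(x)); math.cosh overflows for |x| above ~710."""
    a = abs(x)
    return a + math.log1p(math.exp(-2.0 * a)) - math.log(2.0)


def log_cosh_error_reference(preds: Sequence[float], target: Sequence[float]) -> float:
    """Hypothetical helper: mean of log(cosh(y_hat - y)) over all samples."""
    if not preds or len(preds) != len(target):
        raise ValueError("preds and target must be non-empty and of equal length")
    return sum(log_cosh(p - t) for p, t in zip(preds, target)) / len(preds)


# Same data as the single output example above.
print(round(log_cosh_error_reference([3.0, 5.0, 2.5, 7.0], [2.5, 5.0, 4.0, 8.0]), 4))  # → 0.3523
print(log_cosh(1000.0))  # finite: 1000 - log(2)
```

For large |x| the loss grows roughly like |x| - log 2, which is why log-cosh behaves like a squared error for small residuals and like an absolute error for large ones.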
Mean Absolute Error (MAE)¶
Module Interface¶
- class torchmetrics.MeanAbsoluteError(**kwargs)[source]¶
Compute Mean Absolute Error (MAE).
\[\text{MAE} = \frac{1}{N}\sum_i^N | y_i - \hat{y_i} |\]Where \(y\) is a tensor of target values, and \(\hat{y}\) is a tensor of predictions.
As input to forward and update the metric accepts the following input:
preds (Tensor): Predictions from model
target (Tensor): Ground truth values
As output of forward and compute the metric returns the following output:
mean_absolute_error (Tensor): A tensor with the mean absolute error over the state
- Parameters:
kwargs (Any) – Additional keyword arguments, see Advanced metric settings for more info.
Example
>>> from torch import tensor
>>> from torchmetrics.regression import MeanAbsoluteError
>>> target = tensor([3.0, -0.5, 2.0, 7.0])
>>> preds = tensor([2.5, 0.0, 2.0, 8.0])
>>> mean_absolute_error = MeanAbsoluteError()
>>> mean_absolute_error(preds, target)
tensor(0.5000)
- plot(val=None, ax=None)[source]¶
Plot a single or multiple values from the metric.
- Parameters:
val (Union[Tensor, Sequence[Tensor], None]) – Either a single result from calling metric.forward or metric.compute, or a list of these results. If no value is provided, metric.compute is called automatically and its result is plotted.
ax (Optional[Axes]) – A matplotlib axis object. If provided, the plot is added to that axis.
- Return type:
- Returns:
Figure and Axes object
- Raises:
ModuleNotFoundError – If matplotlib is not installed
>>> from torch import randn
>>> # Example plotting a single value
>>> from torchmetrics.regression import MeanAbsoluteError
>>> metric = MeanAbsoluteError()
>>> metric.update(randn(10,), randn(10,))
>>> fig_, ax_ = metric.plot()
>>> from torch import randn
>>> # Example plotting multiple values
>>> from torchmetrics.regression import MeanAbsoluteError
>>> metric = MeanAbsoluteError()
>>> values = []
>>> for _ in range(10):
...     values.append(metric(randn(10,), randn(10,)))
>>> fig, ax = metric.plot(values)
Functional Interface¶
- torchmetrics.functional.mean_absolute_error(preds, target)[source]¶
Compute mean absolute error.
- Parameters:
- Return type:
- Returns:
Tensor with MAE
Example
>>> import torch
>>> from torchmetrics.functional.regression import mean_absolute_error
>>> x = torch.tensor([0., 1, 2, 3])
>>> y = torch.tensor([0., 1, 2, 2])
>>> mean_absolute_error(x, y)
tensor(0.2500)
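The module interface accumulates state over batches through repeated `update` calls and only reduces it in `compute`. The sketch below mimics that pattern for MAE in plain Python. `RunningMAE` is a hypothetical class, not TorchMetrics code, and it assumes the state is a running sum of absolute errors plus an observation count.

```python
from typing import Sequence


class RunningMAE:
    """Hypothetical metric-like class: accumulate in update(), reduce in compute()."""

    def __init__(self) -> None:
        self.sum_abs_error = 0.0  # running sum of |y_i - y_hat_i|
        self.total = 0            # running number of observations

    def update(self, preds: Sequence[float], target: Sequence[float]) -> None:
        if len(preds) != len(target):
            raise ValueError("preds and target must have the same length")
        self.sum_abs_error += sum(abs(t - p) for p, t in zip(preds, target))
        self.total += len(target)

    def compute(self) -> float:
        if self.total == 0:
            raise RuntimeError("compute() called before any update()")
        return self.sum_abs_error / self.total


# Same data as the module example above, split into two uneven batches.
metric = RunningMAE()
metric.update([2.5], [3.0])
metric.update([0.0, 2.0, 8.0], [-0.5, 2.0, 7.0])
print(metric.compute())  # → 0.5
```

Keeping a sum and a count, rather than averaging per-batch means, gives the exact MAE over all samples even when batches have different sizes, and both states can be combined across processes by simple addition.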
Mean Absolute Percentage Error (MAPE)¶
Module Interface¶
- class torchmetrics.MeanAbsolutePercentageError(**kwargs)[source]¶
Compute Mean Absolute Percentage Error (MAPE).
\[\text{MAPE} = \frac{1}{n}\sum_{i=1}^n\frac{| y_i - \hat{y_i} |}{\max(\epsilon, | y_i |)}\]Where \(y\) is a tensor of target values, and \(\hat{y}\) is a tensor of predictions.
As input to forward and update the metric accepts the following input:
preds (Tensor): Predictions from model
target (Tensor): Ground truth values
As output of forward and compute the metric returns the following output:
mean_abs_percentage_error (Tensor): A tensor with the mean absolute percentage error over state
- Parameters:
kwargs (Any) – Additional keyword arguments, see Advanced metric settings for more info.
Note
MAPE output is a non-negative floating point number. The best result is 0.0. However, bad predictions can lead to arbitrarily large values, especially when some target values are close to 0. This MAPE implementation returns a very large number instead of inf.
Example
>>> from torch import tensor
>>> from torchmetrics.regression import MeanAbsolutePercentageError
>>> target = tensor([1, 10, 1e6])
>>> preds = tensor([0.9, 15, 1.2e6])
>>> mean_abs_percentage_error = MeanAbsolutePercentageError()
>>> mean_abs_percentage_error(preds, target)
tensor(0.2667)
- plot(val=None, ax=None)[source]¶
Plot a single or multiple values from the metric.
- Parameters:
val (Union[Tensor, Sequence[Tensor], None]) – Either a single result from calling metric.forward or metric.compute, or a list of these results. If no value is provided, metric.compute is called automatically and its result is plotted.
ax (Optional[Axes]) – A matplotlib axis object. If provided, the plot is added to that axis.
- Return type:
- Returns:
Figure and Axes object
- Raises:
ModuleNotFoundError – If matplotlib is not installed
>>> from torch import randn
>>> # Example plotting a single value
>>> from torchmetrics.regression import MeanAbsolutePercentageError
>>> metric = MeanAbsolutePercentageError()
>>> metric.update(randn(10,), randn(10,))
>>> fig_, ax_ = metric.plot()
>>> from torch import randn
>>> # Example plotting multiple values
>>> from torchmetrics.regression import MeanAbsolutePercentageError
>>> metric = MeanAbsolutePercentageError()
>>> values = []
>>> for _ in range(10):
...     values.append(metric(randn(10,), randn(10,)))
>>> fig, ax = metric.plot(values)
Functional Interface¶
- torchmetrics.functional.mean_absolute_percentage_error(preds, target)[source]¶
Compute mean absolute percentage error.
- Parameters:
- Return type:
- Returns:
Tensor with MAPE
Note
The epsilon value is taken from scikit-learn’s implementation of MAPE.
Example
>>> import torch
>>> from torchmetrics.functional.regression import mean_absolute_percentage_error
>>> target = torch.tensor([1, 10, 1e6])
>>> preds = torch.tensor([0.9, 15, 1.2e6])
>>> mean_absolute_percentage_error(preds, target)
tensor(0.2667)
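The behavior described in the note on near-zero targets follows from the max(ε, |y_i|) denominator in the MAPE formula. A plain-Python sketch of that formula, using a hypothetical `mape_reference` helper; the value of `EPSILON` below is an assumption, so check the library source for the exact constant:

```python
from typing import Sequence

EPSILON = 1.17e-06  # assumed clamp value for the denominator


def mape_reference(preds: Sequence[float], target: Sequence[float], epsilon: float = EPSILON) -> float:
    """Hypothetical helper: mean of |y - y_hat| / max(epsilon, |y|)."""
    if not target or len(preds) != len(target):
        raise ValueError("preds and target must be non-empty and of equal length")
    total = sum(abs(t - p) / max(epsilon, abs(t)) for p, t in zip(preds, target))
    return total / len(target)


print(round(mape_reference([0.9, 15, 1.2e6], [1, 10, 1e6]), 4))  # → 0.2667
# A zero target does not give inf, but a very large finite value (1 / epsilon here).
print(mape_reference([1.0], [0.0]))
```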
Mean Squared Error (MSE)¶
Module Interface¶
- class torchmetrics.MeanSquaredError(squared=True, num_outputs=1, **kwargs)[source]¶
Compute mean squared error (MSE).
\[\text{MSE} = \frac{1}{N}\sum_i^N(y_i - \hat{y_i})^2\]Where \(y\) is a tensor of target values, and \(\hat{y}\) is a tensor of predictions.
As input to forward and update the metric accepts the following input:
preds (Tensor): Predictions from model
target (Tensor): Ground truth values
As output of forward and compute the metric returns the following output:
mean_squared_error (Tensor): A tensor with the mean squared error
- Parameters:
squared (bool) – If True returns the MSE value, if False returns the RMSE value.
num_outputs (int) – Number of outputs in multioutput setting
kwargs (Any) – Additional keyword arguments, see Advanced metric settings for more info.
- Example::
Single output mse computation:
>>> from torch import tensor
>>> from torchmetrics.regression import MeanSquaredError
>>> target = tensor([2.5, 5.0, 4.0, 8.0])
>>> preds = tensor([3.0, 5.0, 2.5, 7.0])
>>> mean_squared_error = MeanSquaredError()
>>> mean_squared_error(preds, target)
tensor(0.8750)
- Example::
Multioutput mse computation:
>>> from torch import tensor
>>> from torchmetrics.regression import MeanSquaredError
>>> target = tensor([[0.0, 0.0, 0.0], [0.0, 0.0, 0.0]])
>>> preds = tensor([[1.0, 2.0, 3.0], [1.0, 2.0, 3.0]])
>>> mean_squared_error = MeanSquaredError(num_outputs=3)
>>> mean_squared_error(preds, target)
tensor([1., 4., 9.])
- plot(val=None, ax=None)[source]¶
Plot a single or multiple values from the metric.
- Parameters:
val (Union[Tensor, Sequence[Tensor], None]) – Either a single result from calling metric.forward or metric.compute, or a list of these results. If no value is provided, metric.compute is called automatically and its result is plotted.
ax (Optional[Axes]) – A matplotlib axis object. If provided, the plot is added to that axis.
- Return type:
- Returns:
Figure and Axes object
- Raises:
ModuleNotFoundError – If matplotlib is not installed
>>> from torch import randn
>>> # Example plotting a single value
>>> from torchmetrics.regression import MeanSquaredError
>>> metric = MeanSquaredError()
>>> metric.update(randn(10,), randn(10,))
>>> fig_, ax_ = metric.plot()
>>> from torch import randn
>>> # Example plotting multiple values
>>> from torchmetrics.regression import MeanSquaredError
>>> metric = MeanSquaredError()
>>> values = []
>>> for _ in range(10):
...     values.append(metric(randn(10,), randn(10,)))
>>> fig, ax = metric.plot(values)
Functional Interface¶
- torchmetrics.functional.mean_squared_error(preds, target, squared=True, num_outputs=1)[source]¶
Compute mean squared error.
- Parameters:
- Return type:
- Returns:
Tensor with MSE
Example
>>> import torch
>>> from torchmetrics.functional.regression import mean_squared_error
>>> x = torch.tensor([0., 1, 2, 3])
>>> y = torch.tensor([0., 1, 2, 2])
>>> mean_squared_error(x, y)
tensor(0.2500)
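The effect of the `squared` and `num_outputs` arguments can be illustrated with a small plain-Python sketch. The helpers `mse_reference` and `mse_per_output` are hypothetical; rows are samples and columns are outputs, mirroring an `(N, num_outputs)` tensor.

```python
import math
from typing import List, Sequence


def mse_reference(preds: Sequence[float], target: Sequence[float], squared: bool = True) -> float:
    """Hypothetical helper: MSE if squared=True, otherwise its square root (RMSE)."""
    if not target or len(preds) != len(target):
        raise ValueError("preds and target must be non-empty and of equal length")
    value = sum((t - p) ** 2 for p, t in zip(preds, target)) / len(target)
    return value if squared else math.sqrt(value)


def mse_per_output(preds: Sequence[Sequence[float]], target: Sequence[Sequence[float]]) -> List[float]:
    """Hypothetical helper: one MSE per column (output) of row-major data."""
    return [mse_reference(p_col, t_col) for p_col, t_col in zip(zip(*preds), zip(*target))]


print(mse_reference([3.0, 5.0, 2.5, 7.0], [2.5, 5.0, 4.0, 8.0]))  # → 0.875
print(round(mse_reference([3.0, 5.0, 2.5, 7.0], [2.5, 5.0, 4.0, 8.0], squared=False), 4))  # → 0.9354
print(mse_per_output([[1.0, 2.0, 3.0], [1.0, 2.0, 3.0]], [[0.0] * 3, [0.0] * 3]))  # → [1.0, 4.0, 9.0]
```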
Mean Squared Log Error (MSLE)¶
Module Interface¶
- class torchmetrics.MeanSquaredLogError(**kwargs)[source]¶
Compute mean squared logarithmic error (MSLE).
\[\text{MSLE} = \frac{1}{N}\sum_i^N (\log_e(1 + y_i) - \log_e(1 + \hat{y_i}))^2\]Where \(y\) is a tensor of target values, and \(\hat{y}\) is a tensor of predictions.
As input to forward and update the metric accepts the following input:
preds (Tensor): Predictions from model
target (Tensor): Ground truth values
As output of forward and compute the metric returns the following output:
mean_squared_log_error (Tensor): A tensor with the mean squared log error
- Parameters:
kwargs (Any) – Additional keyword arguments, see Advanced metric settings for more info.
Example
>>> from torch import tensor
>>> from torchmetrics.regression import MeanSquaredLogError
>>> target = tensor([2.5, 5, 4, 8])
>>> preds = tensor([3, 5, 2.5, 7])
>>> mean_squared_log_error = MeanSquaredLogError()
>>> mean_squared_log_error(preds, target)
tensor(0.0397)
Note
Half precision is only supported on GPU for this metric.
- plot(val=None, ax=None)[source]¶
Plot a single or multiple values from the metric.
- Parameters:
val (Union[Tensor, Sequence[Tensor], None]) – Either a single result from calling metric.forward or metric.compute, or a list of these results. If no value is provided, metric.compute is called automatically and its result is plotted.
ax (Optional[Axes]) – A matplotlib axis object. If provided, the plot is added to that axis.
- Return type:
- Returns:
Figure and Axes object
- Raises:
ModuleNotFoundError – If matplotlib is not installed
>>> from torch import rand
>>> # Example plotting a single value (MSLE needs inputs > -1, so sample from [0, 1))
>>> from torchmetrics.regression import MeanSquaredLogError
>>> metric = MeanSquaredLogError()
>>> metric.update(rand(10,), rand(10,))
>>> fig_, ax_ = metric.plot()
>>> from torch import rand
>>> # Example plotting multiple values
>>> from torchmetrics.regression import MeanSquaredLogError
>>> metric = MeanSquaredLogError()
>>> values = []
>>> for _ in range(10):
...     values.append(metric(rand(10,), rand(10,)))
>>> fig, ax = metric.plot(values)
Functional Interface¶
- torchmetrics.functional.mean_squared_log_error(preds, target)[source]¶
Compute mean squared log error.
- Parameters:
- Return type:
- Returns:
Tensor with MSLE
Example
>>> import torch
>>> from torchmetrics.functional.regression import mean_squared_log_error
>>> x = torch.tensor([0., 1, 2, 3])
>>> y = torch.tensor([0., 1, 2, 2])
>>> mean_squared_log_error(x, y)
tensor(0.0207)
Note
Half precision is only supported on GPU for this metric.
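A plain-Python sketch of the MSLE formula, using a hypothetical `msle_reference` helper. Because it relies on log(1 + x), it is only defined for values greater than -1, which is why non-negative targets and predictions are the usual use case.

```python
import math
from typing import Sequence


def msle_reference(preds: Sequence[float], target: Sequence[float]) -> float:
    """Hypothetical helper: mean of (log(1 + y) - log(1 + y_hat))^2."""
    if not target or len(preds) != len(target):
        raise ValueError("preds and target must be non-empty and of equal length")
    # math.log1p raises ValueError for inputs <= -1.
    total = sum((math.log1p(t) - math.log1p(p)) ** 2 for p, t in zip(preds, target))
    return total / len(target)


print(round(msle_reference([3, 5, 2.5, 7], [2.5, 5, 4, 8]), 4))  # → 0.0397
print(round(msle_reference([0.0, 1, 2, 3], [0.0, 1, 2, 2]), 4))  # → 0.0207
```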
Minkowski Distance¶
Module Interface¶
- class torchmetrics.MinkowskiDistance(p, **kwargs)[source]¶
Compute Minkowski Distance.
\[d_{\text{Minkowski}} = \left(\sum_{i}^N | y_i - \hat{y_i} |^p\right)^\frac{1}{p}\]
where \(y\) is a tensor of target values, \(\hat{y}\) is a tensor of predictions, and \(p\) is the exponent, an integer or floating-point number (see the p parameter below).
This metric can be seen as a generalized version of the standard Euclidean distance, which corresponds to the Minkowski distance with p=2.
- Parameters:
p (float) – Exponent (int or float larger than 1) to which the absolute difference between preds and target is raised
kwargs (Any) – Additional keyword arguments, see Advanced metric settings for more info.
Example
>>> from torch import tensor
>>> from torchmetrics.regression import MinkowskiDistance
>>> target = tensor([1.0, 2.8, 3.5, 4.5])
>>> preds = tensor([6.1, 2.11, 3.1, 5.6])
>>> minkowski_distance = MinkowskiDistance(3)
>>> minkowski_distance(preds, target)
tensor(5.1220)
- plot(val=None, ax=None)[source]¶
Plot a single or multiple values from the metric.
- Parameters:
val (Union[Tensor, Sequence[Tensor], None]) – Either a single result from calling metric.forward or metric.compute, or a list of these results. If no value is provided, metric.compute is called automatically and its result is plotted.
ax (Optional[Axes]) – A matplotlib axis object. If provided, the plot is added to that axis.
- Return type:
- Returns:
Figure and Axes object
- Raises:
ModuleNotFoundError – If matplotlib is not installed
>>> from torch import randn
>>> # Example plotting a single value
>>> from torchmetrics.regression import MinkowskiDistance
>>> metric = MinkowskiDistance(p=3)
>>> metric.update(randn(10,), randn(10,))
>>> fig_, ax_ = metric.plot()
>>> from torch import randn
>>> # Example plotting multiple values
>>> from torchmetrics.regression import MinkowskiDistance
>>> metric = MinkowskiDistance(p=3)
>>> values = []
>>> for _ in range(10):
...     values.append(metric(randn(10,), randn(10,)))
>>> fig, ax = metric.plot(values)
Functional Interface¶
- torchmetrics.functional.minkowski_distance(preds, targets, p)[source]¶
Compute the Minkowski distance.
\[d_{\text{Minkowski}} = \left(\sum_{i}^N | y_i - \hat{y_i} |^p\right)^\frac{1}{p}\]This metric can be seen as a generalized version of the standard Euclidean distance, which corresponds to the Minkowski distance with p=2.
- Parameters:
- Return type:
- Returns:
Tensor with the Minkowski distance
Example
>>> import torch
>>> from torchmetrics.functional.regression import minkowski_distance
>>> x = torch.tensor([1.0, 2.8, 3.5, 4.5])
>>> y = torch.tensor([6.1, 2.11, 3.1, 5.6])
>>> minkowski_distance(x, y, p=3)
tensor(5.1220)
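The formula (sum the p-th powers of the absolute differences, then take the 1/p-th root) can be checked with a plain-Python sketch. `minkowski_reference` is a hypothetical helper, and its `p >= 1` guard is an assumption made here so that the result is a proper distance. With p=2 it reduces to the Euclidean distance, and with p=1 to the sum of absolute differences.

```python
from typing import Sequence


def minkowski_reference(preds: Sequence[float], target: Sequence[float], p: float) -> float:
    """Hypothetical helper: (sum_i |y_i - y_hat_i|^p)^(1/p)."""
    if p < 1:
        raise ValueError("this sketch only accepts p >= 1")
    if len(preds) != len(target):
        raise ValueError("preds and target must have the same length")
    return sum(abs(t - q) ** p for q, t in zip(preds, target)) ** (1.0 / p)


print(round(minkowski_reference([6.1, 2.11, 3.1, 5.6], [1.0, 2.8, 3.5, 4.5], p=3), 4))  # → 5.122
print(minkowski_reference([0.0, 0.0], [3.0, 4.0], p=2))  # Euclidean distance → 5.0
```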
Pearson Corr. Coef.¶
Module Interface¶
- class torchmetrics.PearsonCorrCoef(num_outputs=1, **kwargs)[source]¶
Compute Pearson Correlation Coefficient.
\[P_{corr}(x,y) = \frac{cov(x,y)}{\sigma_x \sigma_y}\]Where \(y\) is a tensor of target values, and \(x\) is a tensor of predictions.
As input to forward and update the metric accepts the following input:
preds (Tensor): either single output float tensor with shape (N,) or multioutput float tensor of shape (N,d)
target (Tensor): either single output tensor with shape (N,) or multioutput tensor of shape (N,d)
As output of forward and compute the metric returns the following output:
pearson (Tensor): A tensor with the Pearson Correlation Coefficient
- Parameters:
num_outputs (int) – Number of outputs in multioutput setting
kwargs (Any) – Additional keyword arguments, see Advanced metric settings for more info.
- Example (single output regression):
>>> import torch
>>> from torchmetrics.regression import PearsonCorrCoef
>>> target = torch.tensor([3, -0.5, 2, 7])
>>> preds = torch.tensor([2.5, 0.0, 2, 8])
>>> pearson = PearsonCorrCoef()
>>> pearson(preds, target)
tensor(0.9849)
- Example (multi output regression):
>>> import torch
>>> from torchmetrics.regression import PearsonCorrCoef
>>> target = torch.tensor([[3, -0.5], [2, 7]])
>>> preds = torch.tensor([[2.5, 0.0], [2, 8]])
>>> pearson = PearsonCorrCoef(num_outputs=2)
>>> pearson(preds, target)
tensor([1., 1.])
- plot(val=None, ax=None)[source]¶
Plot a single or multiple values from the metric.
- Parameters:
val (Union[Tensor, Sequence[Tensor], None]) – Either a single result from calling metric.forward or metric.compute, or a list of these results. If no value is provided, metric.compute is called automatically and its result is plotted.
ax (Optional[Axes]) – A matplotlib axis object. If provided, the plot is added to that axis.
- Return type:
- Returns:
Figure and Axes object
- Raises:
ModuleNotFoundError – If matplotlib is not installed
>>> from torch import randn
>>> # Example plotting a single value
>>> from torchmetrics.regression import PearsonCorrCoef
>>> metric = PearsonCorrCoef()
>>> metric.update(randn(10,), randn(10,))
>>> fig_, ax_ = metric.plot()
>>> from torch import randn
>>> # Example plotting multiple values
>>> from torchmetrics.regression import PearsonCorrCoef
>>> metric = PearsonCorrCoef()
>>> values = []
>>> for _ in range(10):
...     values.append(metric(randn(10,), randn(10,)))
>>> fig, ax = metric.plot(values)
Functional Interface¶
- torchmetrics.functional.pearson_corrcoef(preds, target)[source]¶
Compute the Pearson correlation coefficient.
- Parameters:
- Return type:
- Example (single output regression):
>>> import torch
>>> from torchmetrics.functional.regression import pearson_corrcoef
>>> target = torch.tensor([3, -0.5, 2, 7])
>>> preds = torch.tensor([2.5, 0.0, 2, 8])
>>> pearson_corrcoef(preds, target)
tensor(0.9849)
- Example (multi output regression):
>>> import torch
>>> from torchmetrics.functional.regression import pearson_corrcoef
>>> target = torch.tensor([[3, -0.5], [2, 7]])
>>> preds = torch.tensor([[2.5, 0.0], [2, 8]])
>>> pearson_corrcoef(preds, target)
tensor([1., 1.])
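The definition cov(x, y) / (σ_x σ_y) can be sketched in plain Python for the single-output case, using a hypothetical `pearson_reference` helper. The 1/n normalization factors of the covariance and of the two standard deviations cancel, so raw sums are enough.

```python
import math
from typing import Sequence


def pearson_reference(x: Sequence[float], y: Sequence[float]) -> float:
    """Hypothetical helper: Pearson correlation of two equal-length sequences."""
    n = len(x)
    if n < 2 or n != len(y):
        raise ValueError("need two equal-length sequences with at least 2 elements")
    mean_x, mean_y = sum(x) / n, sum(y) / n
    cov = sum((a - mean_x) * (b - mean_y) for a, b in zip(x, y))
    ss_x = sum((a - mean_x) ** 2 for a in x)
    ss_y = sum((b - mean_y) ** 2 for b in y)
    if ss_x == 0 or ss_y == 0:
        raise ValueError("correlation is undefined for a constant input")
    return cov / math.sqrt(ss_x * ss_y)


# Same data as the single output example above.
print(round(pearson_reference([2.5, 0.0, 2, 8], [3, -0.5, 2, 7]), 4))  # → 0.9849
```

With only two samples per output, as in the multi output example, any two non-constant columns are perfectly correlated, which is why that example returns exactly 1 for both outputs.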
R2 Score¶
Module Interface¶
- class torchmetrics.R2Score(num_outputs=1, adjusted=0, multioutput='uniform_average', **kwargs)[source]¶
Compute the r2 score, also known as the coefficient of determination.
\[R^2 = 1 - \frac{SS_{res}}{SS_{tot}}\]where \(SS_{res}=\sum_i (y_i - f(x_i))^2\) is the sum of residual squares, and \(SS_{tot}=\sum_i (y_i - \bar{y})^2\) is the total sum of squares. The metric can also calculate the adjusted r2 score, given by
\[R^2_{adj} = 1 - \frac{(1-R^2)(n-1)}{n-k-1}\]where \(n\) is the number of samples and the parameter \(k\) (the number of independent regressors) should be provided as the adjusted argument. The score is only properly defined when \(SS_{tot}\neq 0\), which can fail to hold for (near) constant targets; in this case a score of 0 is returned. The best possible score is 1, which corresponds to the predictions exactly matching the targets. The score has no lower bound: it becomes negative when the predictions fit worse than always predicting the mean of the targets.
As input to forward and update the metric accepts the following input:
preds (Tensor): Predictions from model in float tensor with shape (N,) or (N, M) (multioutput)
target (Tensor): Ground truth values in float tensor with shape (N,) or (N, M) (multioutput)
As output of forward and compute the metric returns the following output:
r2score (Tensor): A tensor with the r2 score(s)
In the multioutput case, the scores are by default uniformly averaged over the additional dimensions. See the multioutput argument to change this behavior.
- Parameters:
num_outputs (int) – Number of outputs in multioutput setting
adjusted (int) – Number of independent regressors for calculating the adjusted r2 score.
multioutput (str) – Defines aggregation in the case of multiple output scores. Can be one of the following strings:
'raw_values' returns the full set of scores
'uniform_average' scores are uniformly averaged
'variance_weighted' scores are weighted by their individual variances
kwargs (Any) – Additional keyword arguments, see Advanced metric settings for more info.
- Raises:
ValueError – If the adjusted parameter is not an integer greater than or equal to 0.
ValueError – If multioutput is not one of "raw_values", "uniform_average" or "variance_weighted".
Example
>>> import torch
>>> from torchmetrics.regression import R2Score
>>> target = torch.tensor([3, -0.5, 2, 7])
>>> preds = torch.tensor([2.5, 0.0, 2, 8])
>>> r2score = R2Score()
>>> r2score(preds, target)
tensor(0.9486)
>>> target = torch.tensor([[0.5, 1], [-1, 1], [7, -6]])
>>> preds = torch.tensor([[0, 2], [-1, 2], [8, -5]])
>>> r2score = R2Score(num_outputs=2, multioutput='raw_values')
>>> r2score(preds, target)
tensor([0.9654, 0.9082])
- plot(val=None, ax=None)[source]¶
Plot a single or multiple values from the metric.
- Parameters:
val (Union[Tensor, Sequence[Tensor], None]) – Either a single result from calling metric.forward or metric.compute, or a list of these results. If no value is provided, metric.compute is called automatically and its result is plotted.
ax (Optional[Axes]) – A matplotlib axis object. If provided, the plot is added to that axis.
- Return type:
- Returns:
Figure and Axes object
- Raises:
ModuleNotFoundError – If matplotlib is not installed
>>> from torch import randn
>>> # Example plotting a single value
>>> from torchmetrics.regression import R2Score
>>> metric = R2Score()
>>> metric.update(randn(10,), randn(10,))
>>> fig_, ax_ = metric.plot()
>>> from torch import randn
>>> # Example plotting multiple values
>>> from torchmetrics.regression import R2Score
>>> metric = R2Score()
>>> values = []
>>> for _ in range(10):
...     values.append(metric(randn(10,), randn(10,)))
>>> fig, ax = metric.plot(values)
Functional Interface¶
- torchmetrics.functional.r2_score(preds, target, adjusted=0, multioutput='uniform_average')[source]¶
Compute the r2 score, also known as the coefficient of determination.
\[R^2 = 1 - \frac{SS_{res}}{SS_{tot}}\]where \(SS_{res}=\sum_i (y_i - f(x_i))^2\) is the sum of residual squares, and \(SS_{tot}=\sum_i (y_i - \bar{y})^2\) is the total sum of squares. The function can also calculate the adjusted r2 score, given by
\[R^2_{adj} = 1 - \frac{(1-R^2)(n-1)}{n-k-1}\]where \(n\) is the number of samples and the parameter \(k\) (the number of independent regressors) should be provided as the adjusted argument.
- Parameters:
preds (Tensor) – estimated labels
target (Tensor) – ground truth labels
adjusted (int) – number of independent regressors for calculating the adjusted r2 score.
multioutput (str) – Defines aggregation in the case of multiple output scores. Can be one of the following strings:
'raw_values' returns the full set of scores
'uniform_average' scores are uniformly averaged
'variance_weighted' scores are weighted by their individual variances
- Raises:
ValueError – If preds and target are not 1D or 2D tensors.
ValueError – If len(preds) is less than 2, since at least 2 samples are needed to calculate the r2 score.
ValueError – If multioutput is not one of raw_values, uniform_average or variance_weighted.
ValueError – If adjusted is not an integer greater than or equal to 0.
- Return type:
Example
>>> import torch
>>> from torchmetrics.functional.regression import r2_score
>>> target = torch.tensor([3, -0.5, 2, 7])
>>> preds = torch.tensor([2.5, 0.0, 2, 8])
>>> r2_score(preds, target)
tensor(0.9486)
>>> target = torch.tensor([[0.5, 1], [-1, 1], [7, -6]])
>>> preds = torch.tensor([[0, 2], [-1, 2], [8, -5]])
>>> r2_score(preds, target, multioutput='raw_values')
tensor([0.9654, 0.9082])
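A plain-Python sketch of the R² and adjusted R² formulas for a single output, using a hypothetical `r2_reference` helper. It follows the documented convention of returning 0 when SS_tot is 0, and treats `adjusted=0` as "no adjustment".

```python
from typing import Sequence


def r2_reference(preds: Sequence[float], target: Sequence[float], adjusted: int = 0) -> float:
    """Hypothetical helper: R^2, optionally adjusted for `adjusted` regressors."""
    n = len(target)
    if n < 2 or len(preds) != n:
        raise ValueError("need at least 2 samples and equal-length inputs")
    if adjusted < 0:
        raise ValueError("adjusted must be >= 0")
    mean_t = sum(target) / n
    ss_res = sum((t - p) ** 2 for p, t in zip(preds, target))
    ss_tot = sum((t - mean_t) ** 2 for t in target)
    if ss_tot == 0:
        return 0.0  # documented convention for constant targets
    r2 = 1 - ss_res / ss_tot
    if adjusted:
        if n - adjusted - 1 <= 0:
            raise ValueError("adjusted r2 needs n - k - 1 > 0")
        r2 = 1 - (1 - r2) * (n - 1) / (n - adjusted - 1)
    return r2


preds, target = [2.5, 0.0, 2, 8], [3, -0.5, 2, 7]
print(round(r2_reference(preds, target), 4))              # → 0.9486
print(round(r2_reference(preds, target, adjusted=1), 4))  # → 0.9229
print(r2_reference([10, 10, 10, 10], target) < 0)         # → True: worse than the mean
```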
Relative Squared Error (RSE)¶
Module Interface¶
- class torchmetrics.RelativeSquaredError(num_outputs=1, squared=True, **kwargs)[source]¶
Computes the relative squared error (RSE).
\[\text{RSE} = \frac{\sum_i^N(y_i - \hat{y_i})^2}{\sum_i^N(y_i - \overline{y})^2}\]Where \(y\) is a tensor of target values with mean \(\overline{y}\), and \(\hat{y}\) is a tensor of predictions.
If num_outputs > 1, the returned value is averaged over all the outputs.
As input to forward and update the metric accepts the following input:
preds (Tensor): Predictions from model in float tensor with shape (N,) or (N, M) (multioutput)
target (Tensor): Ground truth values in float tensor with shape (N,) or (N, M) (multioutput)
As output of forward and compute the metric returns the following output:
rse (Tensor): A tensor with the RSE score(s)
- Parameters:
num_outputs (int) – Number of outputs in multioutput setting
squared (bool) – If True returns the RSE value, if False returns the RRSE value.
kwargs (Any) – Additional keyword arguments, see Advanced metric settings for more info.
Example
>>> import torch
>>> from torchmetrics.regression import RelativeSquaredError
>>> target = torch.tensor([3, -0.5, 2, 7])
>>> preds = torch.tensor([2.5, 0.0, 2, 8])
>>> relative_squared_error = RelativeSquaredError()
>>> relative_squared_error(preds, target)
tensor(0.0514)
- plot(val=None, ax=None)[source]¶
Plot a single or multiple values from the metric.
- Parameters:
val (Union[Tensor, Sequence[Tensor], None]) – Either a single result from calling metric.forward or metric.compute, or a list of these results. If no value is provided, metric.compute is called automatically and its result is plotted.
ax (Optional[Axes]) – A matplotlib axis object. If provided, the plot is added to that axis.
- Return type:
- Returns:
Figure and Axes object
- Raises:
ModuleNotFoundError – If matplotlib is not installed
>>> from torch import randn
>>> # Example plotting a single value
>>> from torchmetrics.regression import RelativeSquaredError
>>> metric = RelativeSquaredError()
>>> metric.update(randn(10,), randn(10,))
>>> fig_, ax_ = metric.plot()
>>> from torch import randn
>>> # Example plotting multiple values
>>> from torchmetrics.regression import RelativeSquaredError
>>> metric = RelativeSquaredError()
>>> values = []
>>> for _ in range(10):
...     values.append(metric(randn(10,), randn(10,)))
>>> fig, ax = metric.plot(values)
Functional Interface¶
- torchmetrics.functional.relative_squared_error(preds, target, squared=True)[source]¶
Computes the relative squared error (RSE).
\[\text{RSE} = \frac{\sum_i^N(y_i - \hat{y_i})^2}{\sum_i^N(y_i - \overline{y})^2}\]Where \(y\) is a tensor of target values with mean \(\overline{y}\), and \(\hat{y}\) is a tensor of predictions.
If preds and targets are 2D tensors, the RSE is averaged over the second dim.
- Parameters:
- Return type:
- Returns:
Tensor with RSE
Example
>>> import torch
>>> from torchmetrics.functional.regression import relative_squared_error
>>> target = torch.tensor([3, -0.5, 2, 7])
>>> preds = torch.tensor([2.5, 0.0, 2, 8])
>>> relative_squared_error(preds, target)
tensor(0.0514)
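The RSE divides the residual sum of squares by the total sum of squares, so for a single output it equals 1 - R² whenever the targets are not constant (compare the R2 Score examples above on the same data: 0.9486 and 0.0514). A plain-Python sketch with a hypothetical `rse_reference` helper; `squared=False` returns the root relative squared error (RRSE):

```python
import math
from typing import Sequence


def rse_reference(preds: Sequence[float], target: Sequence[float], squared: bool = True) -> float:
    """Hypothetical helper: sum (y - y_hat)^2 / sum (y - mean(y))^2."""
    n = len(target)
    if n == 0 or len(preds) != n:
        raise ValueError("preds and target must be non-empty and of equal length")
    mean_t = sum(target) / n
    ss_res = sum((t - p) ** 2 for p, t in zip(preds, target))
    ss_tot = sum((t - mean_t) ** 2 for t in target)
    value = ss_res / ss_tot  # raises ZeroDivisionError for constant targets
    return value if squared else math.sqrt(value)


print(round(rse_reference([2.5, 0.0, 2, 8], [3, -0.5, 2, 7]), 4))                 # → 0.0514
print(round(rse_reference([2.5, 0.0, 2, 8], [3, -0.5, 2, 7], squared=False), 4))  # → 0.2267
```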
Spearman Corr. Coef.¶
Module Interface¶
- class torchmetrics.SpearmanCorrCoef(num_outputs=1, **kwargs)[source]¶
Compute Spearman's rank correlation coefficient.
\[r_s = \frac{cov(rg_x, rg_y)}{\sigma_{rg_x} \sigma_{rg_y}}\]where \(rg_x\) and \(rg_y\) are the ranks associated with the variables \(x\) and \(y\). Spearman's correlation coefficient corresponds to the standard Pearson correlation coefficient calculated on the rank variables.
As input to forward and update the metric accepts the following input:
preds (Tensor): Predictions from model in float tensor with shape (N,) or (N,d)
target (Tensor): Ground truth values in float tensor with shape (N,) or (N,d)
As output of forward and compute the metric returns the following output:
spearman (Tensor): A tensor with the Spearman correlation(s)
- Parameters:
num_outputs (int) – Number of outputs in multioutput setting
kwargs (Any) – Additional keyword arguments, see Advanced metric settings for more info.
- Example (single output regression):
>>> from torch import tensor
>>> from torchmetrics.regression import SpearmanCorrCoef
>>> target = tensor([3, -0.5, 2, 7])
>>> preds = tensor([2.5, 0.0, 2, 8])
>>> spearman = SpearmanCorrCoef()
>>> spearman(preds, target)
tensor(1.0000)
- Example (multi output regression):
>>> from torch import tensor
>>> from torchmetrics.regression import SpearmanCorrCoef
>>> target = tensor([[3, -0.5], [2, 7]])
>>> preds = tensor([[2.5, 0.0], [2, 8]])
>>> spearman = SpearmanCorrCoef(num_outputs=2)
>>> spearman(preds, target)
tensor([1.0000, 1.0000])
- plot(val=None, ax=None)[source]¶
Plot a single or multiple values from the metric.
- Parameters:
val (Union[Tensor, Sequence[Tensor], None]) – Either a single result from calling metric.forward or metric.compute, or a list of these results. If no value is provided, metric.compute is called automatically and its result is plotted.
ax (Optional[Axes]) – A matplotlib axis object. If provided, the plot is added to that axis.
- Return type:
- Returns:
Figure and Axes object
- Raises:
ModuleNotFoundError – If matplotlib is not installed
>>> from torch import randn
>>> # Example plotting a single value
>>> from torchmetrics.regression import SpearmanCorrCoef
>>> metric = SpearmanCorrCoef()
>>> metric.update(randn(10,), randn(10,))
>>> fig_, ax_ = metric.plot()
>>> from torch import randn
>>> # Example plotting multiple values
>>> from torchmetrics.regression import SpearmanCorrCoef
>>> metric = SpearmanCorrCoef()
>>> values = []
>>> for _ in range(10):
...     values.append(metric(randn(10,), randn(10,)))
>>> fig, ax = metric.plot(values)
Functional Interface¶
- torchmetrics.functional.spearman_corrcoef(preds, target)[source]¶
Compute Spearman's rank correlation coefficient.
\[r_s = \frac{cov(rg_x, rg_y)}{\sigma_{rg_x} \sigma_{rg_y}}\]where \(rg_x\) and \(rg_y\) are the ranks associated with the variables \(x\) and \(y\). Spearman's correlation coefficient corresponds to the standard Pearson correlation coefficient calculated on the rank variables.
- Parameters:
- Return type:
- Example (single output regression):
>>> import torch
>>> from torchmetrics.functional.regression import spearman_corrcoef
>>> target = torch.tensor([3, -0.5, 2, 7])
>>> preds = torch.tensor([2.5, 0.0, 2, 8])
>>> spearman_corrcoef(preds, target)
tensor(1.0000)
- Example (multi output regression):
>>> import torch
>>> from torchmetrics.functional.regression import spearman_corrcoef
>>> target = torch.tensor([[3, -0.5], [2, 7]])
>>> preds = torch.tensor([[2.5, 0.0], [2, 8]])
>>> spearman_corrcoef(preds, target)
tensor([1.0000, 1.0000])
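Since Spearman's coefficient is Pearson's coefficient computed on ranks, it can be sketched in plain Python in two steps: rank each input, then correlate the ranks. The helpers below are hypothetical, and giving tied values the average of their ranks is an assumption (it is the common convention). Any strictly increasing relationship scores exactly 1, even when it is not linear.

```python
import math
from typing import List, Sequence


def average_ranks(values: Sequence[float]) -> List[float]:
    """1-based ranks; tied values share the mean of the positions they occupy."""
    order = sorted(range(len(values)), key=lambda i: values[i])
    ranks = [0.0] * len(values)
    i = 0
    while i < len(order):
        j = i
        while j + 1 < len(order) and values[order[j + 1]] == values[order[i]]:
            j += 1
        for k in range(i, j + 1):
            ranks[order[k]] = (i + j) / 2 + 1
        i = j + 1
    return ranks


def pearson_on_lists(x: Sequence[float], y: Sequence[float]) -> float:
    """Hypothetical helper: plain Pearson correlation (no input validation)."""
    n = len(x)
    mean_x, mean_y = sum(x) / n, sum(y) / n
    cov = sum((a - mean_x) * (b - mean_y) for a, b in zip(x, y))
    return cov / math.sqrt(sum((a - mean_x) ** 2 for a in x) * sum((b - mean_y) ** 2 for b in y))


def spearman_reference(preds: Sequence[float], target: Sequence[float]) -> float:
    """Hypothetical helper: Pearson correlation of the average ranks."""
    return pearson_on_lists(average_ranks(preds), average_ranks(target))


print(average_ranks([1.0, 2.0, 2.0, 5.0]))                               # → [1.0, 2.5, 2.5, 4.0]
print(round(spearman_reference([2.5, 0.0, 2, 8], [3, -0.5, 2, 7]), 4))  # → 1.0
```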
Symmetric Mean Absolute Percentage Error (SMAPE)¶
Module Interface¶
- class torchmetrics.SymmetricMeanAbsolutePercentageError(**kwargs)[source]¶
Compute symmetric mean absolute percentage error (SMAPE).
\[\text{SMAPE} = \frac{2}{n}\sum_{i=1}^n\frac{| y_i - \hat{y_i} |}{\max(| y_i | + | \hat{y_i} |, \epsilon)}\]Where \(y\) is a tensor of target values, and \(\hat{y}\) is a tensor of predictions.
As input to forward and update the metric accepts the following input:
preds (Tensor): Predictions from model
target (Tensor): Ground truth values
As output of forward and compute the metric returns the following output:
smape (Tensor): A tensor with a non-negative floating point SMAPE value between 0 and 2
- Parameters:
kwargs (Any) – Additional keyword arguments, see Advanced metric settings for more info.
Example
>>> from torch import tensor
>>> from torchmetrics.regression import SymmetricMeanAbsolutePercentageError
>>> target = tensor([1, 10, 1e6])
>>> preds = tensor([0.9, 15, 1.2e6])
>>> smape = SymmetricMeanAbsolutePercentageError()
>>> smape(preds, target)
tensor(0.2290)
- plot(val=None, ax=None)[source]¶
Plot a single or multiple values from the metric.
- Parameters:
val (Union[Tensor, Sequence[Tensor], None]) – Either a single result from calling metric.forward or metric.compute, or a list of these results. If no value is provided, metric.compute is called automatically and its result is plotted.
ax (Optional[Axes]) – A matplotlib axis object. If provided, the plot is added to that axis.
- Return type:
- Returns:
Figure and Axes object
- Raises:
ModuleNotFoundError – If matplotlib is not installed
>>> from torch import randn
>>> # Example plotting a single value
>>> from torchmetrics.regression import SymmetricMeanAbsolutePercentageError
>>> metric = SymmetricMeanAbsolutePercentageError()
>>> metric.update(randn(10,), randn(10,))
>>> fig_, ax_ = metric.plot()
>>> from torch import randn
>>> # Example plotting multiple values
>>> from torchmetrics.regression import SymmetricMeanAbsolutePercentageError
>>> metric = SymmetricMeanAbsolutePercentageError()
>>> values = []
>>> for _ in range(10):
...     values.append(metric(randn(10,), randn(10,)))
>>> fig, ax = metric.plot(values)
Functional Interface¶
- torchmetrics.functional.symmetric_mean_absolute_percentage_error(preds, target)[source]¶
Compute symmetric mean absolute percentage error (SMAPE).
\[\text{SMAPE} = \frac{2}{n}\sum_{i=1}^n\frac{| y_i - \hat{y_i} |}{\max(| y_i | + | \hat{y_i} |, \epsilon)}\]Where \(y\) is a tensor of target values, and \(\hat{y}\) is a tensor of predictions.
- Parameters:
- Return type:
- Returns:
Tensor with SMAPE.
Example
>>> import torch
>>> from torchmetrics.functional.regression import symmetric_mean_absolute_percentage_error
>>> target = torch.tensor([1, 10, 1e6])
>>> preds = torch.tensor([0.9, 15, 1.2e6])
>>> symmetric_mean_absolute_percentage_error(preds, target)
tensor(0.2290)
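A plain-Python sketch of the SMAPE formula, using a hypothetical `smape_reference` helper; the `EPSILON` value is an assumption. Each term |y - ŷ| / (|y| + |ŷ|) lies in [0, 1], so with the leading factor 2 the score lies in [0, 2]. The upper bound is reached, for example, when every prediction has the opposite sign of its target.

```python
from typing import Sequence

EPSILON = 1.17e-06  # assumed lower clamp for the denominator


def smape_reference(preds: Sequence[float], target: Sequence[float], epsilon: float = EPSILON) -> float:
    """Hypothetical helper: (2 / n) * sum |y - y_hat| / max(|y| + |y_hat|, epsilon)."""
    n = len(target)
    if n == 0 or len(preds) != n:
        raise ValueError("preds and target must be non-empty and of equal length")
    total = sum(abs(t - p) / max(abs(t) + abs(p), epsilon) for p, t in zip(preds, target))
    return 2 * total / n


print(round(smape_reference([0.9, 15, 1.2e6], [1, 10, 1e6]), 4))  # → 0.229
print(smape_reference([-1.0], [1.0]))                              # → 2.0, the upper bound
```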
Tweedie Deviance Score¶
Module Interface¶
- class torchmetrics.TweedieDevianceScore(power=0.0, **kwargs)[source]¶
Compute the Tweedie Deviance Score.
\[\begin{split}deviance\_score(\hat{y},y) = \begin{cases} (\hat{y} - y)^2, & \text{for }p=0\\ 2 * (y * \log(\frac{y}{\hat{y}}) + \hat{y} - y), & \text{for }p=1\\ 2 * (\log(\frac{\hat{y}}{y}) + \frac{y}{\hat{y}} - 1), & \text{for }p=2\\ 2 * (\frac{(\max(y,0))^{2 - p}}{(1 - p)(2 - p)} - \frac{y(\hat{y})^{1 - p}}{1 - p} + \frac{( \hat{y})^{2 - p}}{2 - p}), & \text{otherwise} \end{cases}\end{split}\]where \(y\) is a tensor of target values, \(\hat{y}\) is a tensor of predictions, and \(p\) is the power.
As input to forward and update the metric accepts the following input:
preds (Tensor): Predicted float tensor with shape (N,...)
target (Tensor): Ground truth float tensor with shape (N,...)
As output of forward and compute the metric returns the following output:
deviance_score (Tensor): A tensor with the deviance score
- Parameters:
power (float) – The power p in the formula above, which selects the distribution:
power < 0: Extreme stable distribution. (Requires: preds > 0.)
power = 0: Normal distribution. (Requires: targets and preds can be any real numbers.)
power = 1: Poisson distribution. (Requires: targets >= 0 and preds > 0.)
1 < power < 2: Compound Poisson distribution. (Requires: targets >= 0 and preds > 0.)
power = 2: Gamma distribution. (Requires: targets > 0 and preds > 0.)
power = 3: Inverse Gaussian distribution. (Requires: targets > 0 and preds > 0.)
otherwise: Positive stable distribution. (Requires: targets > 0 and preds > 0.)
kwargs (Any) – Additional keyword arguments, see Advanced metric settings for more info.
Example
>>> import torch >>> from torchmetrics.regression import TweedieDevianceScore >>> targets = torch.tensor([1.0, 2.0, 3.0, 4.0]) >>> preds = torch.tensor([4.0, 3.0, 2.0, 1.0]) >>> deviance_score = TweedieDevianceScore(power=2) >>> deviance_score(preds, targets) tensor(1.2083)
- plot(val=None, ax=None)[source]¶
Plot a single or multiple values from the metric.
- Parameters:
val (Union[Tensor, Sequence[Tensor], None]) – Either a single result from calling metric.forward or metric.compute, or a list of these results. If no value is provided, metric.compute is called automatically and its result is plotted.
ax (Optional[Axes]) – A matplotlib axis object. If provided, the plot is added to that axis.
- Return type:
- Returns:
Figure and Axes object
- Raises:
ModuleNotFoundError – If matplotlib is not installed
>>> from torch import randn >>> # Example plotting a single value >>> from torchmetrics.regression import TweedieDevianceScore >>> metric = TweedieDevianceScore() >>> metric.update(randn(10,), randn(10,)) >>> fig_, ax_ = metric.plot()
>>> from torch import randn >>> # Example plotting multiple values >>> from torchmetrics.regression import TweedieDevianceScore >>> metric = TweedieDevianceScore() >>> values = [] >>> for _ in range(10): ... values.append(metric(randn(10,), randn(10,))) >>> fig, ax = metric.plot(values)
Functional Interface¶
- torchmetrics.functional.tweedie_deviance_score(preds, targets, power=0.0)[source]¶
Compute the Tweedie Deviance Score.
\[\begin{split}deviance\_score(\hat{y},y) = \begin{cases} (\hat{y} - y)^2, & \text{for }p=0\\ 2 * (y * \log(\frac{y}{\hat{y}}) + \hat{y} - y), & \text{for }p=1\\ 2 * (\log(\frac{\hat{y}}{y}) + \frac{y}{\hat{y}} - 1), & \text{for }p=2\\ 2 * (\frac{(\max(y,0))^{2 - p}}{(1 - p)(2 - p)} - \frac{y(\hat{y})^{1 - p}}{1 - p} + \frac{( \hat{y})^{2 - p}}{2 - p}), & \text{otherwise} \end{cases}\end{split}\]where \(y\) is a tensor of target values, \(\hat{y}\) is a tensor of predictions, and \(p\) is the power.
- Parameters:
preds (Tensor) – Predicted tensor with shape (N,...)
targets (Tensor) – Ground truth tensor with shape (N,...)
power (float) – The power p in the formula above, which selects the distribution:
power < 0: Extreme stable distribution. (Requires: preds > 0.)
power = 0: Normal distribution. (Requires: targets and preds can be any real numbers.)
power = 1: Poisson distribution. (Requires: targets >= 0 and preds > 0.)
1 < power < 2: Compound Poisson distribution. (Requires: targets >= 0 and preds > 0.)
power = 2: Gamma distribution. (Requires: targets > 0 and preds > 0.)
power = 3: Inverse Gaussian distribution. (Requires: targets > 0 and preds > 0.)
otherwise: Positive stable distribution. (Requires: targets > 0 and preds > 0.)
- Return type:
Example
>>> import torch >>> from torchmetrics.functional.regression import tweedie_deviance_score >>> targets = torch.tensor([1.0, 2.0, 3.0, 4.0]) >>> preds = torch.tensor([4.0, 3.0, 2.0, 1.0]) >>> tweedie_deviance_score(preds, targets, power=2) tensor(1.2083)
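The case analysis in the Tweedie formula can be checked with a small plain-PyTorch sketch that follows it term by term. This is an illustration, not the library's code. The name `tweedie_deviance_reference` is hypothetical, it skips the input-domain checks listed under power, and its use of `torch.xlogy` for p = 1 (so that a zero target contributes 0 * log 0 = 0) is an assumption on top of the written formula. It reproduces the documented value for power=2.

```python
import torch


def tweedie_deviance_reference(preds: torch.Tensor, targets: torch.Tensor, power: float = 0.0) -> torch.Tensor:
    """Hypothetical reference for the mean Tweedie deviance (no input validation)."""
    preds = preds.to(torch.float64)
    targets = targets.to(torch.float64)
    if power == 0:
        # Normal distribution: plain squared error
        deviance = (preds - targets) ** 2
    elif power == 1:
        # Poisson; xlogy(0, .) == 0, so zero targets are handled
        deviance = 2 * (torch.xlogy(targets, targets / preds) + preds - targets)
    elif power == 2:
        # Gamma distribution
        deviance = 2 * (torch.log(preds / targets) + targets / preds - 1)
    else:
        # General case, one term per fraction in the formula
        term1 = torch.clamp(targets, min=0) ** (2 - power) / ((1 - power) * (2 - power))
        term2 = targets * preds ** (1 - power) / (1 - power)
        term3 = preds ** (2 - power) / (2 - power)
        deviance = 2 * (term1 - term2 + term3)
    # The score is the mean deviance over all elements
    return deviance.mean()


targets = torch.tensor([1.0, 2.0, 3.0, 4.0])
preds = torch.tensor([4.0, 3.0, 2.0, 1.0])
print(f"{tweedie_deviance_reference(preds, targets, power=2).item():.4f}")  # 1.2083
```

The general branch approaches the special cases as p approaches 1 or 2, which is a convenient sanity check when experimenting with non-integer powers.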
Weighted MAPE¶
Module Interface¶
- class torchmetrics.WeightedMeanAbsolutePercentageError(**kwargs)[source]¶
Compute weighted mean absolute percentage error (WMAPE).
The output of the WMAPE metric is a non-negative floating point number, where the optimal value is 0. It is computed as:
\[\text{WMAPE} = \frac{\sum_{t=1}^n | y_t - \hat{y}_t | }{\sum_{t=1}^n |y_t| }\]Where \(y\) is a tensor of target values, and \(\hat{y}\) is a tensor of predictions.
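As input to forward and update the metric accepts the following input:
preds (Tensor): Predictions from the model
target (Tensor): Ground truth float tensor with the same shape as preds
As output of forward and compute the metric returns the following output:
wmape (Tensor): A tensor with the non-negative floating point WMAPE value (0 is optimal; the value is not bounded above by 1)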
- Parameters:
kwargs (Any) – Additional keyword arguments, see Advanced metric settings for more info.
Example
>>> import torch >>> from torchmetrics.regression import WeightedMeanAbsolutePercentageError >>> _ = torch.manual_seed(42) >>> preds = torch.randn(20,) >>> target = torch.randn(20,) >>> wmape = WeightedMeanAbsolutePercentageError() >>> wmape(preds, target) tensor(1.3967)
- plot(val=None, ax=None)[source]¶
Plot a single or multiple values from the metric.
- Parameters:
val (Union[Tensor, Sequence[Tensor], None]) – Either a single result from calling metric.forward or metric.compute, or a list of these results. If no value is provided, metric.compute is called automatically and its result is plotted.
ax (Optional[Axes]) – A matplotlib axis object. If provided, the plot is added to that axis.
- Return type:
- Returns:
Figure and Axes object
- Raises:
ModuleNotFoundError – If matplotlib is not installed
>>> from torch import randn >>> # Example plotting a single value >>> from torchmetrics.regression import WeightedMeanAbsolutePercentageError >>> metric = WeightedMeanAbsolutePercentageError() >>> metric.update(randn(10,), randn(10,)) >>> fig_, ax_ = metric.plot()
>>> from torch import randn >>> # Example plotting multiple values >>> from torchmetrics.regression import WeightedMeanAbsolutePercentageError >>> metric = WeightedMeanAbsolutePercentageError() >>> values = [] >>> for _ in range(10): ... values.append(metric(randn(10,), randn(10,))) >>> fig, ax = metric.plot(values)
Functional Interface¶
- torchmetrics.functional.weighted_mean_absolute_percentage_error(preds, target)[source]¶
Compute weighted mean absolute percentage error (WMAPE).
The output of the WMAPE metric is a non-negative floating point number, where the optimal value is 0. It is computed as:
\[\text{WMAPE} = \frac{\sum_{t=1}^n | y_t - \hat{y}_t | }{\sum_{t=1}^n |y_t| }\]Where \(y\) is a tensor of target values, and \(\hat{y}\) is a tensor of predictions.
- Parameters:
- Return type:
- Returns:
Tensor with WMAPE.
Example
>>> import torch >>> from torchmetrics.functional import weighted_mean_absolute_percentage_error >>> _ = torch.manual_seed(42) >>> preds = torch.randn(20,) >>> target = torch.randn(20,) >>> weighted_mean_absolute_percentage_error(preds, target) tensor(1.3967)
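Because WMAPE is a ratio of two sums rather than a mean of per-sample ratios, accumulating it over batches means keeping both sums and dividing only at the end. The sketch below shows this in plain PyTorch. `wmape_reference` and `StreamingWMAPE` are hypothetical names, the epsilon guard of 1.17e-6 is an assumption, and this is not presented as the library's internal state layout.

```python
import torch


def wmape_reference(preds: torch.Tensor, target: torch.Tensor, eps: float = 1.17e-6) -> torch.Tensor:
    """Hypothetical reference WMAPE: sum |y - y_hat| / sum |y| (eps guard is an assumption)."""
    preds = preds.to(torch.float64)
    target = target.to(torch.float64)
    return (target - preds).abs().sum() / target.abs().sum().clamp(min=eps)


class StreamingWMAPE:
    """Hypothetical accumulator: keeps the numerator and denominator sums, not per-batch ratios."""

    def __init__(self) -> None:
        self.sum_abs_error = torch.tensor(0.0, dtype=torch.float64)
        self.sum_scale = torch.tensor(0.0, dtype=torch.float64)

    def update(self, preds: torch.Tensor, target: torch.Tensor) -> None:
        preds = preds.to(torch.float64)
        target = target.to(torch.float64)
        self.sum_abs_error += (target - preds).abs().sum()
        self.sum_scale += target.abs().sum()

    def compute(self, eps: float = 1.17e-6) -> torch.Tensor:
        return self.sum_abs_error / self.sum_scale.clamp(min=eps)


batches = [
    (torch.tensor([1.0, 2.0]), torch.tensor([2.0, 2.0])),  # error sum 1, scale sum 4
    (torch.tensor([3.0]), torch.tensor([2.0])),  # error sum 1, scale sum 2
]
stream = StreamingWMAPE()
for p, t in batches:
    stream.update(p, t)
print(f"{stream.compute().item():.4f}")  # 0.3333, i.e. 2 / 6, not the 0.375 mean of per-batch values
```

Note that wmape_reference(torch.tensor([3.0]), torch.tensor([1.0])) gives 2.0, which is why the output is not bounded by 1.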
Retrieval Fall-Out¶
Module Interface¶
- class torchmetrics.retrieval.RetrievalFallOut(empty_target_action='pos', ignore_index=None, top_k=None, **kwargs)[source]¶
Compute Fall-out.
Works with binary target data. Accepts float predictions from a model output.
As input to forward and update the metric accepts the following input:
preds (Tensor): A float tensor of shape (N, ...)
target (Tensor): A long or bool tensor of shape (N, ...)
indexes (Tensor): A long tensor of shape (N, ...) which indicates to which query a prediction belongs
As output of forward and compute the metric returns the following output:
fo@k (Tensor): A single-value tensor with the fall-out (at top_k) of the predictions preds w.r.t. the labels target
All indexes, preds and target must have the same dimension and will be flattened at the beginning, so that, for example, a tensor of shape (N, M) is treated as (N * M, ). Predictions are first grouped by indexes, the metric is computed for each query, and the result is the mean over all queries.
- Parameters:
empty_target_action (str) – Specify what to do with queries that do not have at least one negative target. Choose from:
'neg': those queries count as 0.0
'pos': those queries count as 1.0 (default)
'skip': skip those queries; if all queries are skipped, 0.0 is returned
'error': raise a ValueError
ignore_index (Optional[int]) – Ignore predictions where the target is equal to this number.
top_k (Optional[int]) – Consider only the top k elements for each query (default: None, which considers them all)
kwargs (Any) – Additional keyword arguments, see Advanced metric settings for more info.
- Raises:
ValueError – If empty_target_action is not one of error, skip, neg or pos.
ValueError – If ignore_index is not None or an integer.
ValueError – If top_k is not None or not an integer greater than 0.
Example
>>> from torch import tensor >>> from torchmetrics.retrieval import RetrievalFallOut >>> indexes = tensor([0, 0, 0, 1, 1, 1, 1]) >>> preds = tensor([0.2, 0.3, 0.5, 0.1, 0.3, 0.5, 0.2]) >>> target = tensor([False, False, True, False, True, False, True]) >>> fo = RetrievalFallOut(top_k=2) >>> fo(preds, target, indexes=indexes) tensor(0.5000)
- plot(val=None, ax=None)[source]¶
Plot a single or multiple values from the metric.
- Parameters:
val (Union[Tensor, Sequence[Tensor], None]) – Either a single result from calling metric.forward or metric.compute, or a list of these results. If no value is provided, metric.compute is called automatically and its result is plotted.
ax (Optional[Axes]) – A matplotlib axis object. If provided, the plot is added to that axis.
- Return type:
- Returns:
Figure and Axes object
- Raises:
ModuleNotFoundError – If matplotlib is not installed
>>> import torch >>> from torchmetrics.retrieval import RetrievalFallOut >>> # Example plotting a single value >>> metric = RetrievalFallOut() >>> metric.update(torch.rand(10,), torch.randint(2, (10,)), indexes=torch.randint(2,(10,))) >>> fig_, ax_ = metric.plot()
>>> import torch >>> from torchmetrics.retrieval import RetrievalFallOut >>> # Example plotting multiple values >>> metric = RetrievalFallOut() >>> values = [] >>> for _ in range(10): ... values.append(metric(torch.rand(10,), torch.randint(2, (10,)), indexes=torch.randint(2,(10,)))) >>> fig, ax = metric.plot(values)
Functional Interface¶
- torchmetrics.functional.retrieval.retrieval_fall_out(preds, target, top_k=None)[source]¶
Compute the Fall-out for information retrieval, as explained in IR Fall-out.
Fall-out is the fraction of non-relevant documents retrieved among all the non-relevant documents.
preds and target should be of the same shape and live on the same device. If no target is False (that is, there are no non-relevant documents), 0 is returned. target must be either bool or integers and preds must be float, otherwise an error is raised. If you want to measure Fall-out@K, top_k must be a positive integer.
- Parameters:
- Return type:
- Returns:
A single-value tensor with the fall-out (at top_k) of the predictions preds w.r.t. the labels target
- Raises:
ValueError – If top_k parameter is not None or an integer larger than 0
Example
>>> from torch import tensor >>> from torchmetrics.functional import retrieval_fall_out >>> preds = tensor([0.2, 0.3, 0.5]) >>> target = tensor([True, False, True]) >>> retrieval_fall_out(preds, target, top_k=2) tensor(1.)
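Both the functional and the module examples for fall-out can be reproduced with a short plain-PyTorch sketch: a per-query fall-out, plus the "flatten, group by indexes, average over queries" procedure described for the module interface. The helper names are hypothetical, ties in preds are not treated specially, and only the 'neg'/'pos' behaviours of empty_target_action are mimicked through a numeric empty_value ('skip' and 'error' are left out).

```python
from typing import Optional

import torch


def fall_out_single_query(preds: torch.Tensor, target: torch.Tensor, top_k: Optional[int] = None) -> torch.Tensor:
    """Hypothetical Fall-out@k: non-relevant docs in the top k / all non-relevant docs."""
    non_relevant = ~target.bool()
    if non_relevant.sum() == 0:
        # No non-relevant documents at all: return 0, as the functional interface describes
        return torch.tensor(0.0, dtype=torch.float64)
    k = preds.numel() if top_k is None else min(top_k, preds.numel())
    top_idx = torch.topk(preds, k).indices
    return non_relevant[top_idx].sum().double() / non_relevant.sum().double()


def grouped_fall_out(
    preds: torch.Tensor,
    target: torch.Tensor,
    indexes: torch.Tensor,
    top_k: Optional[int] = None,
    empty_value: float = 1.0,
) -> torch.Tensor:
    """Hypothetical module-style computation: flatten, group by query, average per-query scores.

    Queries without any negative target get empty_value (1.0 mimics 'pos', 0.0 mimics 'neg').
    """
    preds, target, indexes = preds.flatten(), target.flatten(), indexes.flatten()
    scores = []
    for query in torch.unique(indexes):
        mask = indexes == query
        if (~target[mask].bool()).sum() == 0:
            scores.append(torch.tensor(empty_value, dtype=torch.float64))
        else:
            scores.append(fall_out_single_query(preds[mask], target[mask], top_k))
    return torch.stack(scores).mean()


print(fall_out_single_query(torch.tensor([0.2, 0.3, 0.5]), torch.tensor([True, False, True]), top_k=2).item())  # 1.0

indexes = torch.tensor([0, 0, 0, 1, 1, 1, 1])
preds = torch.tensor([0.2, 0.3, 0.5, 0.1, 0.3, 0.5, 0.2])
target = torch.tensor([False, False, True, False, True, False, True])
print(grouped_fall_out(preds, target, indexes, top_k=2).item())  # 0.5
```

In the grouped example each query retrieves one of its two non-relevant documents in the top 2, so both per-query scores are 0.5.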
Retrieval Hit Rate¶
Module Interface¶
- class torchmetrics.retrieval.RetrievalHitRate(empty_target_action='neg', ignore_index=None, top_k=None, **kwargs)[source]¶
Compute IR HitRate.
Works with binary target data. Accepts float predictions from a model output.
As input to forward and update the metric accepts the following input:
preds (Tensor): A float tensor of shape (N, ...)
target (Tensor): A long or bool tensor of shape (N, ...)
indexes (Tensor): A long tensor of shape (N, ...) which indicates to which query a prediction belongs
As output of forward and compute the metric returns the following output:
hr@k (Tensor): A single-value tensor with the hit rate (at top_k) of the predictions preds w.r.t. the labels target
All indexes, preds and target must have the same dimension and will be flattened at the beginning, so that, for example, a tensor of shape (N, M) is treated as (N * M, ). Predictions are first grouped by indexes, the metric is computed for each query, and the result is the mean over all queries.
- Parameters:
empty_target_action (str) – Specify what to do with queries that do not have at least one positive target. Choose from:
'neg': those queries count as 0.0 (default)
'pos': those queries count as 1.0
'skip': skip those queries; if all queries are skipped, 0.0 is returned
'error': raise a ValueError
ignore_index (Optional[int]) – Ignore predictions where the target is equal to this number.
top_k (Optional[int]) – Consider only the top k elements for each query (default: None, which considers them all)
kwargs (Any) – Additional keyword arguments, see Advanced metric settings for more info.
- Raises:
ValueError – If empty_target_action is not one of error, skip, neg or pos.
ValueError – If ignore_index is not None or an integer.
ValueError – If top_k is not None or not an integer greater than 0.
Example
>>> from torch import tensor >>> from torchmetrics.retrieval import RetrievalHitRate >>> indexes = tensor([0, 0, 0, 1, 1, 1, 1]) >>> preds = tensor([0.2, 0.3, 0.5, 0.1, 0.3, 0.5, 0.2]) >>> target = tensor([True, False, False, False, True, False, True]) >>> hr2 = RetrievalHitRate(top_k=2) >>> hr2(preds, target, indexes=indexes) tensor(0.5000)
- plot(val=None, ax=None)[source]¶
Plot a single or multiple values from the metric.
- Parameters:
val (Union[Tensor, Sequence[Tensor], None]) – Either a single result from calling metric.forward or metric.compute, or a list of these results. If no value is provided, metric.compute is called automatically and its result is plotted.
ax (Optional[Axes]) – A matplotlib axis object. If provided, the plot is added to that axis.
- Return type:
- Returns:
Figure and Axes object
- Raises:
ModuleNotFoundError – If matplotlib is not installed
>>> import torch >>> from torchmetrics.retrieval import RetrievalHitRate >>> # Example plotting a single value >>> metric = RetrievalHitRate() >>> metric.update(torch.rand(10,), torch.randint(2, (10,)), indexes=torch.randint(2,(10,))) >>> fig_, ax_ = metric.plot()
>>> import torch >>> from torchmetrics.retrieval import RetrievalHitRate >>> # Example plotting multiple values >>> metric = RetrievalHitRate() >>> values = [] >>> for _ in range(10): ... values.append(metric(torch.rand(10,), torch.randint(2, (10,)), indexes=torch.randint(2,(10,)))) >>> fig, ax = metric.plot(values)
Functional Interface¶
- torchmetrics.functional.retrieval.retrieval_hit_rate(preds, target, top_k=None)[source]¶
Compute the hit rate for information retrieval.
The hit rate is 1.0 if there is at least one relevant document among the top k retrieved documents, and 0.0 otherwise.
preds and target should be of the same shape and live on the same device. If no target is True, 0 is returned. target must be either bool or integers and preds must be float, otherwise an error is raised. If you want to measure HitRate@K, top_k must be a positive integer.
- Parameters:
- Return type:
- Returns:
A single-value tensor with the hit rate (at top_k) of the predictions preds w.r.t. the labels target.
- Raises:
ValueError – If top_k parameter is not None or an integer larger than 0
Example
>>> from torch import tensor >>> from torchmetrics.functional.retrieval import retrieval_hit_rate >>> preds = tensor([0.2, 0.3, 0.5]) >>> target = tensor([True, False, True]) >>> retrieval_hit_rate(preds, target, top_k=2) tensor(1.)
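The hit-rate definition above reduces to an "any relevant document in the top k" check. A minimal plain-PyTorch sketch for a single query follows; the helper name is hypothetical and ties in preds are not treated specially. Averaging it over the two queries of the module example (0.0 and 1.0) gives that example's 0.5.

```python
from typing import Optional

import torch


def hit_rate_single_query(preds: torch.Tensor, target: torch.Tensor, top_k: Optional[int] = None) -> torch.Tensor:
    """Hypothetical HitRate@k: 1.0 if any relevant document is in the top k, else 0.0."""
    k = preds.numel() if top_k is None else min(top_k, preds.numel())
    top_idx = torch.topk(preds, k).indices
    # any() is False when no target is True, so that case yields 0.0
    return target.bool()[top_idx].any().double()


preds = torch.tensor([0.2, 0.3, 0.5])
print(hit_rate_single_query(preds, torch.tensor([True, False, True]), top_k=2).item())  # 1.0
print(hit_rate_single_query(preds, torch.tensor([True, False, False]), top_k=2).item())  # 0.0
```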
Retrieval Mean Average Precision (MAP)¶
Module Interface¶
- class torchmetrics.retrieval.RetrievalMAP(empty_target_action='neg', ignore_index=None, top_k=None, **kwargs)[source]¶
Compute Mean Average Precision.
Works with binary target data. Accepts float predictions from a model output.
As input to forward and update the metric accepts the following input:
preds (Tensor): A float tensor of shape (N, ...)
target (Tensor): A long or bool tensor of shape (N, ...)
indexes (Tensor): A long tensor of shape (N, ...) which indicates to which query a prediction belongs
As output of forward and compute the metric returns the following output:
map@k (Tensor): A single-value tensor with the mean average precision (MAP) of the predictions preds w.r.t. the labels target.
All indexes, preds and target must have the same dimension and will be flattened at the beginning, so that, for example, a tensor of shape (N, M) is treated as (N * M, ). Predictions are first grouped by indexes, the metric is computed for each query, and the result is the mean over all queries.
- Parameters:
empty_target_action (str) – Specify what to do with queries that do not have at least one positive target. Choose from:
'neg': those queries count as 0.0 (default)
'pos': those queries count as 1.0
'skip': skip those queries; if all queries are skipped, 0.0 is returned
'error': raise a ValueError
ignore_index (Optional[int]) – Ignore predictions where the target is equal to this number.
top_k (Optional[int]) – Consider only the top k elements for each query (default: None, which considers them all)
kwargs (Any) – Additional keyword arguments, see Advanced metric settings for more info.
- Raises:
ValueError – If empty_target_action is not one of error, skip, neg or pos.
ValueError – If ignore_index is not None or an integer.
ValueError – If top_k is not None or not an integer greater than 0.
Example
>>> from torch import tensor >>> from torchmetrics.retrieval import RetrievalMAP >>> indexes = tensor([0, 0, 0, 1, 1, 1, 1]) >>> preds = tensor([0.2, 0.3, 0.5, 0.1, 0.3, 0.5, 0.2]) >>> target = tensor([False, False, True, False, True, False, True]) >>> rmap = RetrievalMAP() >>> rmap(preds, target, indexes=indexes) tensor(0.7917)
- plot(val=None, ax=None)[source]¶
Plot a single or multiple values from the metric.
- Parameters:
val (Union[Tensor, Sequence[Tensor], None]) – Either a single result from calling metric.forward or metric.compute, or a list of these results. If no value is provided, metric.compute is called automatically and its result is plotted.
ax (Optional[Axes]) – A matplotlib axis object. If provided, the plot is added to that axis.
- Return type:
- Returns:
Figure and Axes object
- Raises:
ModuleNotFoundError – If matplotlib is not installed
>>> import torch >>> from torchmetrics.retrieval import RetrievalMAP >>> # Example plotting a single value >>> metric = RetrievalMAP() >>> metric.update(torch.rand(10,), torch.randint(2, (10,)), indexes=torch.randint(2,(10,))) >>> fig_, ax_ = metric.plot()
>>> import torch >>> from torchmetrics.retrieval import RetrievalMAP >>> # Example plotting multiple values >>> metric = RetrievalMAP() >>> values = [] >>> for _ in range(10): ... values.append(metric(torch.rand(10,), torch.randint(2, (10,)), indexes=torch.randint(2,(10,)))) >>> fig, ax = metric.plot(values)
Functional Interface¶
- torchmetrics.functional.retrieval.retrieval_average_precision(preds, target, top_k=None)[source]¶
Compute average precision (for information retrieval), as explained in IR Average precision.
preds and target should be of the same shape and live on the same device. If no target is True, 0 is returned. target must be either bool or integers and preds must be float, otherwise an error is raised.
- Parameters:
- Return type:
- Returns:
A single-value tensor with the average precision (AP) of the predictions preds w.r.t. the labels target.
- Raises:
ValueError – If top_k is not None or an integer larger than 0.
Example
>>> from torch import tensor >>> from torchmetrics.functional.retrieval import retrieval_average_precision >>> preds = tensor([0.2, 0.3, 0.5]) >>> target = tensor([True, False, True]) >>> retrieval_average_precision(preds, target) tensor(0.8333)
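Average precision is the mean of precision@rank taken at the rank of each relevant document. The plain-PyTorch sketch below is a hypothetical helper, not the library's code; it assumes that with top_k the average runs over the relevant documents found in the top k, and it ignores ties in preds. It reproduces the functional example (0.8333), and averaging its per-query values for the module example (1.0 and 0.5833) gives 0.7917.

```python
from typing import Optional

import torch


def average_precision_single_query(
    preds: torch.Tensor, target: torch.Tensor, top_k: Optional[int] = None
) -> torch.Tensor:
    """Hypothetical AP: mean of precision@rank at the rank of each relevant document."""
    k = preds.numel() if top_k is None else min(top_k, preds.numel())
    # Relevance labels sorted by decreasing score, truncated to the top k
    hits = target.bool()[torch.argsort(preds, descending=True)][:k].double()
    if hits.sum() == 0:
        # No relevant document retrieved: return 0, as documented
        return torch.tensor(0.0, dtype=torch.float64)
    ranks = torch.arange(1, k + 1, dtype=torch.float64)
    precision_at_rank = hits.cumsum(0) / ranks
    # Keep only the ranks where a relevant document appears, then average
    return (precision_at_rank * hits).sum() / hits.sum()


ap = average_precision_single_query(torch.tensor([0.2, 0.3, 0.5]), torch.tensor([True, False, True]))
print(f"{ap.item():.4f}")  # 0.8333
```

Here the ranking is relevant, non-relevant, relevant, so AP = (1/1 + 2/3) / 2.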
Retrieval Mean Reciprocal Rank (MRR)¶
Module Interface¶
- class torchmetrics.retrieval.RetrievalMRR(empty_target_action='neg', ignore_index=None, top_k=None, **kwargs)[source]¶
Compute Mean Reciprocal Rank.
Works with binary target data. Accepts float predictions from a model output.
As input to forward and update the metric accepts the following input:
preds (Tensor): A float tensor of shape (N, ...)
target (Tensor): A long or bool tensor of shape (N, ...)
indexes (Tensor): A long tensor of shape (N, ...) which indicates to which query a prediction belongs
As output of forward and compute the metric returns the following output:
mrr@k (Tensor): A single-value tensor with the mean reciprocal rank (MRR) of the predictions preds w.r.t. the labels target.
All indexes, preds and target must have the same dimension and will be flattened at the beginning, so that, for example, a tensor of shape (N, M) is treated as (N * M, ). Predictions are first grouped by indexes, the metric is computed for each query, and the result is the mean over all queries.
- Parameters:
empty_target_action (str) – Specify what to do with queries that do not have at least one positive target. Choose from:
'neg': those queries count as 0.0 (default)
'pos': those queries count as 1.0
'skip': skip those queries; if all queries are skipped, 0.0 is returned
'error': raise a ValueError
ignore_index (Optional[int]) – Ignore predictions where the target is equal to this number.
top_k (Optional[int]) – Consider only the top k elements for each query (default: None, which considers them all)
kwargs (Any) – Additional keyword arguments, see Advanced metric settings for more info.
- Raises:
ValueError – If empty_target_action is not one of error, skip, neg or pos.
ValueError – If ignore_index is not None or an integer.
ValueError – If top_k is not None or not an integer greater than 0.
Example
>>> from torch import tensor >>> from torchmetrics.retrieval import RetrievalMRR >>> indexes = tensor([0, 0, 0, 1, 1, 1, 1]) >>> preds = tensor([0.2, 0.3, 0.5, 0.1, 0.3, 0.5, 0.2]) >>> target = tensor([False, False, True, False, True, False, True]) >>> mrr = RetrievalMRR() >>> mrr(preds, target, indexes=indexes) tensor(0.7500)
- plot(val=None, ax=None)[source]¶
Plot a single or multiple values from the metric.
- Parameters:
val (Union[Tensor, Sequence[Tensor], None]) – Either a single result from calling metric.forward or metric.compute, or a list of these results. If no value is provided, metric.compute is called automatically and its result is plotted.
ax (Optional[Axes]) – A matplotlib axis object. If provided, the plot is added to that axis.
- Return type:
- Returns:
Figure and Axes object
- Raises:
ModuleNotFoundError – If matplotlib is not installed
>>> import torch >>> from torchmetrics.retrieval import RetrievalMRR >>> # Example plotting a single value >>> metric = RetrievalMRR() >>> metric.update(torch.rand(10,), torch.randint(2, (10,)), indexes=torch.randint(2,(10,))) >>> fig_, ax_ = metric.plot()
>>> import torch >>> from torchmetrics.retrieval import RetrievalMRR >>> # Example plotting multiple values >>> metric = RetrievalMRR() >>> values = [] >>> for _ in range(10): ... values.append(metric(torch.rand(10,), torch.randint(2, (10,)), indexes=torch.randint(2,(10,)))) >>> fig, ax = metric.plot(values)
Functional Interface¶
- torchmetrics.functional.retrieval.retrieval_reciprocal_rank(preds, target, top_k=None)[source]¶
Compute reciprocal rank (for information retrieval). See Mean Reciprocal Rank.
preds and target should be of the same shape and live on the same device. If no target is True, 0 is returned. target must be either bool or integers and preds must be float, otherwise an error is raised.
- Parameters:
- Return type:
- Returns:
A single-value tensor with the reciprocal rank (RR) of the predictions preds w.r.t. the labels target.
- Raises:
ValueError – If top_k is not None or an integer larger than 0.
Example
>>> import torch >>> from torchmetrics.functional.retrieval import retrieval_reciprocal_rank >>> preds = torch.tensor([0.2, 0.3, 0.5]) >>> target = torch.tensor([False, True, False]) >>> retrieval_reciprocal_rank(preds, target) tensor(0.5000)
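The reciprocal rank is 1 / (rank of the first relevant document) in the score ordering. A minimal plain-PyTorch sketch of this definition follows; the helper name is hypothetical and ties in preds are not treated specially. Averaging it over the two queries of the module example (1.0 and 0.5) gives that example's 0.75.

```python
from typing import Optional

import torch


def reciprocal_rank_single_query(
    preds: torch.Tensor, target: torch.Tensor, top_k: Optional[int] = None
) -> torch.Tensor:
    """Hypothetical RR: 1 / (1-based rank of the first relevant document), or 0 if none is found."""
    k = preds.numel() if top_k is None else min(top_k, preds.numel())
    hits = target.bool()[torch.argsort(preds, descending=True)][:k]
    if not hits.any():
        return torch.tensor(0.0, dtype=torch.float64)
    # nonzero() lists positions of relevant documents; the first one gives the rank
    first_rank = torch.nonzero(hits)[0, 0].item() + 1
    return torch.tensor(1.0 / first_rank, dtype=torch.float64)


rr = reciprocal_rank_single_query(torch.tensor([0.2, 0.3, 0.5]), torch.tensor([False, True, False]))
print(rr.item())  # 0.5
```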
Retrieval Normalized DCG¶
Module Interface¶
- class torchmetrics.retrieval.RetrievalNormalizedDCG(empty_target_action='neg', ignore_index=None, top_k=None, **kwargs)[source]¶
Compute Normalized Discounted Cumulative Gain.
Works with binary or positive integer target data. Accepts float predictions from a model output.
As input to forward and update the metric accepts the following input:
preds (Tensor): A float tensor of shape (N, ...)
target (Tensor): A long or bool tensor of shape (N, ...)
indexes (Tensor): A long tensor of shape (N, ...) which indicates to which query a prediction belongs
As output of forward and compute the metric returns the following output:
ndcg@k (Tensor): A single-value tensor with the nDCG of the predictions preds w.r.t. the labels target
All indexes, preds and target must have the same dimension and will be flattened at the beginning, so that, for example, a tensor of shape (N, M) is treated as (N * M, ). Predictions are first grouped by indexes, the metric is computed for each query, and the result is the mean over all queries.
- Parameters:
empty_target_action (str) – Specify what to do with queries that do not have at least one positive target. Choose from:
'neg': those queries count as 0.0 (default)
'pos': those queries count as 1.0
'skip': skip those queries; if all queries are skipped, 0.0 is returned
'error': raise a ValueError
ignore_index (Optional[int]) – Ignore predictions where the target is equal to this number.
top_k (Optional[int]) – Consider only the top k elements for each query (default: None, which considers them all)
kwargs (Any) – Additional keyword arguments, see Advanced metric settings for more info.
- Raises:
ValueError – If empty_target_action is not one of error, skip, neg or pos.
ValueError – If ignore_index is not None or an integer.
ValueError – If top_k is not None or not an integer greater than 0.
Example
>>> from torch import tensor >>> from torchmetrics.retrieval import RetrievalNormalizedDCG >>> indexes = tensor([0, 0, 0, 1, 1, 1, 1]) >>> preds = tensor([0.2, 0.3, 0.5, 0.1, 0.3, 0.5, 0.2]) >>> target = tensor([False, False, True, False, True, False, True]) >>> ndcg = RetrievalNormalizedDCG() >>> ndcg(preds, target, indexes=indexes) tensor(0.8467)
- plot(val=None, ax=None)[source]¶
Plot a single or multiple values from the metric.
- Parameters:
val (Union[Tensor, Sequence[Tensor], None]) – Either a single result from calling metric.forward or metric.compute, or a list of these results. If no value is provided, metric.compute is called automatically and its result is plotted.
ax (Optional[Axes]) – A matplotlib axis object. If provided, the plot is added to that axis.
- Return type:
- Returns:
Figure and Axes object
- Raises:
ModuleNotFoundError – If matplotlib is not installed
>>> import torch >>> from torchmetrics.retrieval import RetrievalNormalizedDCG >>> # Example plotting a single value >>> metric = RetrievalNormalizedDCG() >>> metric.update(torch.rand(10,), torch.randint(2, (10,)), indexes=torch.randint(2,(10,))) >>> fig_, ax_ = metric.plot()
>>> import torch >>> from torchmetrics.retrieval import RetrievalNormalizedDCG >>> # Example plotting multiple values >>> metric = RetrievalNormalizedDCG() >>> values = [] >>> for _ in range(10): ... values.append(metric(torch.rand(10,), torch.randint(2, (10,)), indexes=torch.randint(2,(10,)))) >>> fig, ax = metric.plot(values)
Functional Interface¶
- torchmetrics.functional.retrieval.retrieval_normalized_dcg(preds, target, top_k=None)[source]¶
Compute Normalized Discounted Cumulative Gain (for information retrieval).
preds and target should be of the same shape and live on the same device. target must be either bool or integers and preds must be float, otherwise an error is raised.
- Parameters:
- Return type:
- Returns:
A single-value tensor with the nDCG of the predictions preds w.r.t. the labels target.
- Raises:
ValueError – If top_k parameter is not None or an integer larger than 0
Example
>>> import torch >>> from torchmetrics.functional.retrieval import retrieval_normalized_dcg >>> preds = torch.tensor([.1, .2, .3, 4, 70]) >>> target = torch.tensor([10, 0, 0, 1, 5]) >>> retrieval_normalized_dcg(preds, target) tensor(0.6957)
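nDCG divides the discounted cumulative gain of the predicted ranking by that of the ideal ranking. The plain-PyTorch sketch below assumes linear gains (the target value itself) and a log2(rank + 1) discount, and it ignores ties in preds. These are assumptions, but they reproduce both the functional example (0.6957) and, averaged over queries, the module example (0.8467). The helper names are hypothetical.

```python
from typing import Optional

import torch


def dcg(relevance_in_rank_order: torch.Tensor) -> torch.Tensor:
    """Discounted cumulative gain with linear gains and a log2(rank + 1) discount."""
    ranks = torch.arange(1, relevance_in_rank_order.numel() + 1, dtype=torch.float64)
    return (relevance_in_rank_order.double() / torch.log2(ranks + 1)).sum()


def ndcg_single_query(preds: torch.Tensor, target: torch.Tensor, top_k: Optional[int] = None) -> torch.Tensor:
    """Hypothetical nDCG@k for one query: DCG of predicted order / DCG of ideal order."""
    k = preds.numel() if top_k is None else min(top_k, preds.numel())
    gains = target.double()
    ranked = gains[torch.argsort(preds, descending=True)][:k]
    ideal = torch.sort(gains, descending=True).values[:k]
    ideal_dcg = dcg(ideal)
    if ideal_dcg == 0:
        # No relevant documents: avoid dividing by zero
        return torch.tensor(0.0, dtype=torch.float64)
    return dcg(ranked) / ideal_dcg


score = ndcg_single_query(torch.tensor([.1, .2, .3, 4, 70]), torch.tensor([10, 0, 0, 1, 5]))
print(f"{score.item():.4f}")  # 0.6957
```

In this example the most relevant document (gain 10) is ranked last, which is what pulls the score well below 1.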
Retrieval Precision¶
Module Interface¶
- class torchmetrics.retrieval.RetrievalPrecision(empty_target_action='neg', ignore_index=None, top_k=None, adaptive_k=False, **kwargs)[source]¶
Compute IR Precision.
Works with binary target data. Accepts float predictions from a model output.
As input to forward and update the metric accepts the following input:
preds (Tensor): A float tensor of shape (N, ...)
target (Tensor): A long or bool tensor of shape (N, ...)
indexes (Tensor): A long tensor of shape (N, ...) which indicates to which query a prediction belongs
As output of forward and compute the metric returns the following output:
p@k (Tensor): A single-value tensor with the precision (at top_k) of the predictions preds w.r.t. the labels target
All indexes, preds and target must have the same dimension and will be flattened at the beginning, so that, for example, a tensor of shape (N, M) is treated as (N * M, ). Predictions are first grouped by indexes, the metric is computed for each query, and the result is the mean over all queries.
- Parameters:
empty_target_action (str) – Specify what to do with queries that do not have at least one positive target. Choose from:
'neg': those queries count as 0.0 (default)
'pos': those queries count as 1.0
'skip': skip those queries; if all queries are skipped, 0.0 is returned
'error': raise a ValueError
ignore_index (Optional[int]) – Ignore predictions where the target is equal to this number.
top_k (Optional[int]) – Consider only the top k elements for each query (default: None, which considers them all)
adaptive_k (bool) – Adjust top_k to min(k, number of documents) for each query
kwargs (Any) – Additional keyword arguments, see Advanced metric settings for more info.
- Raises:
ValueError – If empty_target_action is not one of error, skip, neg or pos.
ValueError – If ignore_index is not None or an integer.
ValueError – If top_k is not None or not an integer greater than 0.
ValueError – If adaptive_k is not boolean.
Example
>>> from torch import tensor >>> from torchmetrics.retrieval import RetrievalPrecision >>> indexes = tensor([0, 0, 0, 1, 1, 1, 1]) >>> preds = tensor([0.2, 0.3, 0.5, 0.1, 0.3, 0.5, 0.2]) >>> target = tensor([False, False, True, False, True, False, True]) >>> p2 = RetrievalPrecision(top_k=2) >>> p2(preds, target, indexes=indexes) tensor(0.5000)
- plot(val=None, ax=None)[source]¶
Plot a single or multiple values from the metric.
- Parameters:
val (Union[Tensor, Sequence[Tensor], None]) – Either a single result from calling metric.forward or metric.compute, or a list of these results. If no value is provided, metric.compute is called automatically and its result is plotted.
ax (Optional[Axes]) – A matplotlib axis object. If provided, the plot is added to that axis.
- Return type:
- Returns:
Figure and Axes object
- Raises:
ModuleNotFoundError – If matplotlib is not installed
>>> import torch >>> from torchmetrics.retrieval import RetrievalPrecision >>> # Example plotting a single value >>> metric = RetrievalPrecision() >>> metric.update(torch.rand(10,), torch.randint(2, (10,)), indexes=torch.randint(2,(10,))) >>> fig_, ax_ = metric.plot()
>>> import torch >>> from torchmetrics.retrieval import RetrievalPrecision >>> # Example plotting multiple values >>> metric = RetrievalPrecision() >>> values = [] >>> for _ in range(10): ... values.append(metric(torch.rand(10,), torch.randint(2, (10,)), indexes=torch.randint(2,(10,)))) >>> fig, ax = metric.plot(values)
Functional Interface¶
- torchmetrics.functional.retrieval.retrieval_precision(preds, target, top_k=None, adaptive_k=False)[source]¶
Compute the precision metric for information retrieval.
Precision is the fraction of relevant documents among all the retrieved documents.
preds and target should be of the same shape and live on the same device. If no target is True, 0 is returned. target must be either bool or integers and preds must be float, otherwise an error is raised. If you want to measure Precision@K, top_k must be a positive integer.
- Parameters:
preds (Tensor) – estimated probabilities of each document to be relevant.
target (Tensor) – ground truth about each document being relevant or not.
top_k (Optional[int]) – consider only the top k elements (default: None, which considers them all)
adaptive_k (bool) – adjust k to min(k, number of documents) for each query
- Return type:
- Returns:
A single-value tensor with the precision (at top_k) of the predictions preds w.r.t. the labels target.
- Raises:
ValueError – If top_k is not None or an integer larger than 0.
ValueError – If adaptive_k is not boolean.
Example
>>> from torch import tensor
>>> from torchmetrics.functional.retrieval import retrieval_precision
>>> preds = tensor([0.2, 0.3, 0.5])
>>> target = tensor([True, False, True])
>>> retrieval_precision(preds, target, top_k=2)
tensor(0.5000)
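For intuition, the Precision@K computation for a single query fits in a few lines of plain Python. This is a minimal sketch, not the TorchMetrics source: precision_at_k is a hypothetical helper working on Python lists, and it assumes (consistent with the precision-recall curve example on this page) that without adaptive_k the number of hits is divided by top_k even when a query holds fewer than top_k documents.

```python
from typing import Optional, Sequence


def precision_at_k(
    preds: Sequence[float],
    target: Sequence[bool],
    top_k: Optional[int] = None,
    adaptive_k: bool = False,
) -> float:
    """Hypothetical helper: fraction of relevant documents among the top_k ranked ones."""
    if top_k is None:
        top_k = len(preds)  # None means "consider every document"
    if not isinstance(top_k, int) or top_k <= 0:
        raise ValueError("top_k must be None or an integer larger than 0")
    if adaptive_k:
        top_k = min(top_k, len(preds))  # shrink k for queries with few documents
    if not any(target):
        return 0.0  # no relevant document at all: the score is 0
    # rank documents by decreasing score and count relevant ones among the first top_k
    ranked = sorted(range(len(preds)), key=lambda i: preds[i], reverse=True)
    hits = sum(1 for i in ranked[:top_k] if target[i])
    # assumption: divide by top_k (not by the number of documents actually available)
    return hits / top_k


print(precision_at_k([0.2, 0.3, 0.5], [True, False, True], top_k=2))  # 0.5
```

With top_k=4 on the same three documents the sketch returns 2/4 = 0.5, while adaptive_k=True lowers k to 3 and gives 2/3.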
Precision Recall Curve¶
Module Interface¶
- class torchmetrics.retrieval.RetrievalPrecisionRecallCurve(max_k=None, adaptive_k=False, empty_target_action='neg', ignore_index=None, **kwargs)[source]¶
Compute precision-recall pairs for different k (from 1 to max_k).
In a ranked retrieval context, appropriate sets of retrieved documents are naturally given by the top k retrieved documents. Recall is the fraction of relevant documents retrieved among all the relevant documents. Precision is the fraction of relevant documents among all the retrieved documents. For each such set, precision and recall values can be plotted to give a recall-precision curve.
As input to forward and update the metric accepts the following input:
preds (Tensor): A float tensor of shape (N, ...)
target (Tensor): A long or bool tensor of shape (N, ...)
indexes (Tensor): A long tensor of shape (N, ...) which indicates which query each prediction belongs to
As output of forward and compute the metric returns the following output:
precisions (Tensor): A tensor with the fraction of relevant documents among all the retrieved documents.
recalls (Tensor): A tensor with the fraction of relevant documents retrieved among all the relevant documents.
top_k (Tensor): A tensor with k from 1 to max_k
All indexes, preds and target must have the same dimension and will be flattened at the beginning, so that, for example, a tensor of shape (N, M) is treated as (N * M, ). Predictions are first grouped by indexes, the metric is computed for each query, and the result is the mean over all queries.
- Parameters:
max_k (Optional[int]) – Calculate recall and precision for all possible top k from 1 to max_k (default: None, which considers all possible top k)
adaptive_k (bool) – adjust k to min(k, number of documents) for each query
empty_target_action (str) – Specify what to do with queries that do not have at least one positive target. Choose from:
'neg': those queries count as 0.0 (default)
'pos': those queries count as 1.0
'skip': skip those queries; if all queries are skipped, 0.0 is returned
'error': raise a ValueError
ignore_index (Optional[int]) – Ignore predictions where the target is equal to this number.
kwargs (Any) – Additional keyword arguments, see Advanced metric settings for more info.
- Raises:
ValueError – If empty_target_action is not one of error, skip, neg or pos.
ValueError – If ignore_index is neither None nor an integer.
ValueError – If max_k is neither None nor an integer larger than 0.
Example
>>> from torch import tensor
>>> from torchmetrics.retrieval import RetrievalPrecisionRecallCurve
>>> indexes = tensor([0, 0, 0, 0, 1, 1, 1])
>>> preds = tensor([0.4, 0.01, 0.5, 0.6, 0.2, 0.3, 0.5])
>>> target = tensor([True, False, False, True, True, False, True])
>>> r = RetrievalPrecisionRecallCurve(max_k=4)
>>> precisions, recalls, top_k = r(preds, target, indexes=indexes)
>>> precisions
tensor([1.0000, 0.5000, 0.6667, 0.5000])
>>> recalls
tensor([0.5000, 0.5000, 1.0000, 1.0000])
>>> top_k
tensor([1, 2, 3, 4])
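The per-query grouping can be illustrated for the curve case in plain Python: each query gets its own precision and recall for k = 1..max_k, and the curves are averaged element-wise. This is a hypothetical sketch (pr_curve_single_query and grouped_pr_curve are illustrative names, not TorchMetrics functions). It assumes every query has at least one positive target (roughly the 'error' behaviour), ignores ignore_index, and picks the size of the largest query when max_k is None. On the inputs of the example above it reproduces the documented values.

```python
from collections import defaultdict
from typing import Dict, Hashable, List, Optional, Sequence, Tuple


def pr_curve_single_query(
    preds: Sequence[float], target: Sequence[bool], max_k: int
) -> Tuple[List[float], List[float]]:
    """Precision@k and Recall@k for k = 1..max_k within one query."""
    n_relevant = sum(bool(t) for t in target)
    if n_relevant == 0:
        raise ValueError("this sketch assumes every query has a positive target")
    ranked = sorted(range(len(preds)), key=lambda i: preds[i], reverse=True)
    precisions: List[float] = []
    recalls: List[float] = []
    hits = 0
    for k in range(1, max_k + 1):
        # past the last document no new hit appears, but k keeps growing
        if k <= len(ranked) and target[ranked[k - 1]]:
            hits += 1
        precisions.append(hits / k)
        recalls.append(hits / n_relevant)
    return precisions, recalls


def grouped_pr_curve(
    preds: Sequence[float],
    target: Sequence[bool],
    indexes: Sequence[Hashable],
    max_k: Optional[int] = None,
) -> Tuple[List[float], List[float], List[int]]:
    """Group by query id, compute one curve per query, average the curves."""
    groups: Dict[Hashable, List[Tuple[float, bool]]] = defaultdict(list)
    for p, t, q in zip(preds, target, indexes):
        groups[q].append((p, t))
    if max_k is None:
        max_k = max(len(g) for g in groups.values())  # assumption for the default
    curves = [
        pr_curve_single_query([p for p, _ in g], [t for _, t in g], max_k)
        for g in groups.values()
    ]
    n = len(curves)
    mean_precision = [sum(c[0][k] for c in curves) / n for k in range(max_k)]
    mean_recall = [sum(c[1][k] for c in curves) / n for k in range(max_k)]
    return mean_precision, mean_recall, list(range(1, max_k + 1))


precisions, recalls, top_k = grouped_pr_curve(
    preds=[0.4, 0.01, 0.5, 0.6, 0.2, 0.3, 0.5],
    target=[True, False, False, True, True, False, True],
    indexes=[0, 0, 0, 0, 1, 1, 1],
    max_k=4,
)
print([round(p, 4) for p in precisions])  # [1.0, 0.5, 0.6667, 0.5]
print(recalls)  # [0.5, 0.5, 1.0, 1.0]
```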
- plot(curve=None, ax=None)[source]¶
Plot a single or multiple values from the metric.
- Parameters:
- Return type:
- Returns:
Figure and Axes object
- Raises:
ModuleNotFoundError – If matplotlib is not installed
>>> import torch
>>> from torchmetrics.retrieval import RetrievalPrecisionRecallCurve
>>> # Example plotting a single value
>>> metric = RetrievalPrecisionRecallCurve()
>>> metric.update(torch.rand(10,), torch.randint(2, (10,)), indexes=torch.randint(2, (10,)))
>>> fig_, ax_ = metric.plot()
Functional Interface¶
- torchmetrics.functional.retrieval.retrieval_precision_recall_curve(preds, target, max_k=None, adaptive_k=False)[source]¶
Compute precision-recall pairs for different k (from 1 to max_k).
In a ranked retrieval context, appropriate sets of retrieved documents are naturally given by the top k retrieved documents.
Recall is the fraction of relevant documents retrieved among all the relevant documents. Precision is the fraction of relevant documents among all the retrieved documents.
For each such set, precision and recall values can be plotted to give a recall-precision curve.
preds and target should be of the same shape and live on the same device. If no target is True, 0 is returned. target must be either bool or integers and preds must be float, otherwise an error is raised.
- Parameters:
preds (Tensor) – estimated probabilities of each document to be relevant.
target (Tensor) – ground truth about each document being relevant or not.
max_k (Optional[int]) – Calculate recall and precision for all possible top k from 1 to max_k (default: None, which considers all possible top k)
adaptive_k (bool) – adjust max_k to min(max_k, number of documents) for each query
- Return type:
- Returns:
Tensor with the precision values for each k (at top_k) from 1 to max_k
Tensor with the recall values for each k (at top_k) from 1 to max_k
Tensor with all possible k
- Raises:
ValueError – If max_k is neither None nor an integer larger than 0.
ValueError – If adaptive_k is not boolean.
Example
>>> from torch import tensor
>>> from torchmetrics.functional import retrieval_precision_recall_curve
>>> preds = tensor([0.2, 0.3, 0.5])
>>> target = tensor([True, False, True])
>>> precisions, recalls, top_k = retrieval_precision_recall_curve(preds, target, max_k=2)
>>> precisions
tensor([1.0000, 0.5000])
>>> recalls
tensor([0.5000, 0.5000])
>>> top_k
tensor([1, 2])
Retrieval R-Precision¶
Module Interface¶
- class torchmetrics.retrieval.RetrievalRPrecision(empty_target_action='neg', ignore_index=None, **kwargs)[source]¶
Compute IR R-Precision.
Works with binary target data. Accepts float predictions from a model output.
As input to forward and update the metric accepts the following input:
preds (Tensor): A float tensor of shape (N, ...)
target (Tensor): A long or bool tensor of shape (N, ...)
indexes (Tensor): A long tensor of shape (N, ...) which indicates which query each prediction belongs to
As output of forward and compute the metric returns the following output:
rp (Tensor): A single-value tensor with the r-precision of the predictions preds w.r.t. the labels target.
All indexes, preds and target must have the same dimension and will be flattened at the beginning, so that, for example, a tensor of shape (N, M) is treated as (N * M, ). Predictions are first grouped by indexes, the metric is computed for each query, and the result is the mean over all queries.
- Parameters:
empty_target_action (str) – Specify what to do with queries that do not have at least one positive target. Choose from:
'neg': those queries count as 0.0 (default)
'pos': those queries count as 1.0
'skip': skip those queries; if all queries are skipped, 0.0 is returned
'error': raise a ValueError
ignore_index (Optional[int]) – Ignore predictions where the target is equal to this number.
kwargs (Any) – Additional keyword arguments, see Advanced metric settings for more info.
- Raises:
ValueError – If empty_target_action is not one of error, skip, neg or pos.
ValueError – If ignore_index is neither None nor an integer.
Example
>>> from torch import tensor
>>> from torchmetrics.retrieval import RetrievalRPrecision
>>> indexes = tensor([0, 0, 0, 1, 1, 1, 1])
>>> preds = tensor([0.2, 0.3, 0.5, 0.1, 0.3, 0.5, 0.2])
>>> target = tensor([False, False, True, False, True, False, True])
>>> p2 = RetrievalRPrecision()
>>> p2(preds, target, indexes=indexes)
tensor(0.7500)
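The grouping by indexes and the empty_target_action options can be sketched in plain Python with R-Precision as the per-query metric. The helpers r_precision and grouped_mean are hypothetical names for illustration; they only mirror the behaviour described on this page (mean over queries, 'neg' / 'pos' / 'skip' / 'error' for queries without a positive target) and ignore ignore_index. On the example above the sketch also gives 0.75: query 0 scores 1.0 and query 1 scores 0.5.

```python
from collections import defaultdict
from typing import Callable, Dict, Hashable, List, Sequence, Tuple

MetricFn = Callable[[Sequence[float], Sequence[bool]], float]


def r_precision(preds: Sequence[float], target: Sequence[bool]) -> float:
    """Precision among the top-R documents, where R is the number of relevant ones."""
    n_relevant = sum(bool(t) for t in target)
    ranked = sorted(range(len(preds)), key=lambda i: preds[i], reverse=True)
    hits = sum(1 for i in ranked[:n_relevant] if target[i])
    return hits / n_relevant


def grouped_mean(
    preds: Sequence[float],
    target: Sequence[bool],
    indexes: Sequence[Hashable],
    metric_fn: MetricFn,
    empty_target_action: str = "neg",
) -> float:
    """Hypothetical helper: mean of metric_fn over the queries given by indexes."""
    if empty_target_action not in ("neg", "pos", "skip", "error"):
        raise ValueError(f"unknown empty_target_action: {empty_target_action!r}")
    groups: Dict[Hashable, Tuple[List[float], List[bool]]] = defaultdict(lambda: ([], []))
    for p, t, q in zip(preds, target, indexes):
        groups[q][0].append(p)
        groups[q][1].append(t)
    scores: List[float] = []
    for query_preds, query_target in groups.values():
        if not any(query_target):
            if empty_target_action == "error":
                raise ValueError("found a query without any positive target")
            if empty_target_action == "neg":
                scores.append(0.0)
            elif empty_target_action == "pos":
                scores.append(1.0)
            continue  # 'skip' adds nothing; metric_fn is never called for empty queries
        scores.append(metric_fn(query_preds, query_target))
    return sum(scores) / len(scores) if scores else 0.0  # all skipped -> 0.0


preds = [0.2, 0.3, 0.5, 0.1, 0.3, 0.5, 0.2]
target = [False, False, True, False, True, False, True]
indexes = [0, 0, 0, 1, 1, 1, 1]
print(grouped_mean(preds, target, indexes, r_precision))  # 0.75
```

Adding a third query with no positive target would pull the mean down to 0.5 under 'neg', raise it under 'pos', and leave it at 0.75 under 'skip'.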
- plot(val=None, ax=None)[source]¶
Plot a single or multiple values from the metric.
- Parameters:
val (Union[Tensor, Sequence[Tensor], None]) – Either a single result from calling metric.forward or metric.compute or a list of these results. If no value is provided, metric.compute is called automatically and its result is plotted.
ax (Optional[Axes]) – A matplotlib axis object. If provided, the plot is added to that axis.
- Return type:
- Returns:
Figure and Axes object
- Raises:
ModuleNotFoundError – If matplotlib is not installed
>>> import torch
>>> from torchmetrics.retrieval import RetrievalRPrecision
>>> # Example plotting a single value
>>> metric = RetrievalRPrecision()
>>> metric.update(torch.rand(10,), torch.randint(2, (10,)), indexes=torch.randint(2, (10,)))
>>> fig_, ax_ = metric.plot()
>>> import torch
>>> from torchmetrics.retrieval import RetrievalRPrecision
>>> # Example plotting multiple values
>>> metric = RetrievalRPrecision()
>>> values = []
>>> for _ in range(10):
...     values.append(metric(torch.rand(10,), torch.randint(2, (10,)), indexes=torch.randint(2, (10,))))
>>> fig, ax = metric.plot(values)
Functional Interface¶
- torchmetrics.functional.retrieval.retrieval_r_precision(preds, target)[source]¶
Compute the r-precision metric for information retrieval.
R-Precision is the fraction of relevant documents among the top k retrieved documents, where k is equal to the total number of relevant documents.
preds and target should be of the same shape and live on the same device. If no target is True, 0 is returned. target must be either bool or integers and preds must be float, otherwise an error is raised.
- Parameters:
- Return type:
- Returns:
A single-value tensor with the r-precision of the predictions preds w.r.t. the labels target.
Example
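>>> from torch import tensor
>>> from torchmetrics.functional.retrieval import retrieval_r_precision
>>> preds = tensor([0.2, 0.3, 0.5])
>>> target = tensor([True, False, True])
>>> retrieval_r_precision(preds, target)
tensor(0.5000)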
Retrieval Recall¶
Module Interface¶
- class torchmetrics.retrieval.RetrievalRecall(empty_target_action='neg', ignore_index=None, top_k=None, **kwargs)[source]¶
Compute IR Recall.
Works with binary target data. Accepts float predictions from a model output.
As input to forward and update the metric accepts the following input:
preds (Tensor): A float tensor of shape (N, ...)
target (Tensor): A long or bool tensor of shape (N, ...)
indexes (Tensor): A long tensor of shape (N, ...) which indicates which query each prediction belongs to
As output of forward and compute the metric returns the following output:
r@k (Tensor): A single-value tensor with the recall (at top_k) of the predictions preds w.r.t. the labels target.
All indexes, preds and target must have the same dimension and will be flattened at the beginning, so that, for example, a tensor of shape (N, M) is treated as (N * M, ). Predictions are first grouped by indexes, the metric is computed for each query, and the result is the mean over all queries.
- Parameters:
empty_target_action (str) – Specify what to do with queries that do not have at least one positive target. Choose from:
'neg': those queries count as 0.0 (default)
'pos': those queries count as 1.0
'skip': skip those queries; if all queries are skipped, 0.0 is returned
'error': raise a ValueError
ignore_index (Optional[int]) – Ignore predictions where the target is equal to this number.
top_k (Optional[int]) – Consider only the top k elements for each query (default: None, which considers them all)
kwargs (Any) – Additional keyword arguments, see Advanced metric settings for more info.
- Raises:
ValueError – If empty_target_action is not one of error, skip, neg or pos.
ValueError – If ignore_index is neither None nor an integer.
ValueError – If top_k is neither None nor an integer greater than 0.
Example
>>> from torch import tensor
>>> from torchmetrics.retrieval import RetrievalRecall
>>> indexes = tensor([0, 0, 0, 1, 1, 1, 1])
>>> preds = tensor([0.2, 0.3, 0.5, 0.1, 0.3, 0.5, 0.2])
>>> target = tensor([False, False, True, False, True, False, True])
>>> r2 = RetrievalRecall(top_k=2)
>>> r2(preds, target, indexes=indexes)
tensor(0.7500)
- plot(val=None, ax=None)[source]¶
Plot a single or multiple values from the metric.
- Parameters:
val (Union[Tensor, Sequence[Tensor], None]) – Either a single result from calling metric.forward or metric.compute or a list of these results. If no value is provided, metric.compute is called automatically and its result is plotted.
ax (Optional[Axes]) – A matplotlib axis object. If provided, the plot is added to that axis.
- Return type:
- Returns:
Figure and Axes object
- Raises:
ModuleNotFoundError – If matplotlib is not installed
>>> import torch
>>> from torchmetrics.retrieval import RetrievalRecall
>>> # Example plotting a single value
>>> metric = RetrievalRecall()
>>> metric.update(torch.rand(10,), torch.randint(2, (10,)), indexes=torch.randint(2, (10,)))
>>> fig_, ax_ = metric.plot()
>>> import torch
>>> from torchmetrics.retrieval import RetrievalRecall
>>> # Example plotting multiple values
>>> metric = RetrievalRecall()
>>> values = []
>>> for _ in range(10):
...     values.append(metric(torch.rand(10,), torch.randint(2, (10,)), indexes=torch.randint(2, (10,))))
>>> fig, ax = metric.plot(values)
Functional Interface¶
- torchmetrics.functional.retrieval.retrieval_recall(preds, target, top_k=None)[source]¶
Compute the recall metric for information retrieval.
Recall is the fraction of relevant documents retrieved among all the relevant documents.
preds and target should be of the same shape and live on the same device. If no target is True, 0 is returned. target must be either bool or integers and preds must be float, otherwise an error is raised. If you want to measure Recall@K, top_k must be a positive integer.
- Parameters:
- Return type:
- Returns:
A single-value tensor with the recall (at top_k) of the predictions preds w.r.t. the labels target.
- Raises:
ValueError – If top_k is neither None nor an integer larger than 0.
Example
>>> from torch import tensor
>>> from torchmetrics.functional import retrieval_recall
>>> preds = tensor([0.2, 0.3, 0.5])
>>> target = tensor([True, False, True])
>>> retrieval_recall(preds, target, top_k=2)
tensor(0.5000)
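Recall@K differs from Precision@K only in the denominator: the hits inside the top k are divided by the number of relevant documents of the query rather than by k. A minimal plain-Python sketch, where recall_at_k is a hypothetical helper and not part of TorchMetrics:

```python
from typing import Optional, Sequence


def recall_at_k(
    preds: Sequence[float], target: Sequence[bool], top_k: Optional[int] = None
) -> float:
    """Hypothetical helper: fraction of all relevant documents found in the top_k."""
    if top_k is None:
        top_k = len(preds)  # None means "consider every document"
    if not isinstance(top_k, int) or top_k <= 0:
        raise ValueError("top_k must be None or an integer larger than 0")
    n_relevant = sum(bool(t) for t in target)
    if n_relevant == 0:
        return 0.0  # no relevant document at all: the score is 0
    ranked = sorted(range(len(preds)), key=lambda i: preds[i], reverse=True)
    hits = sum(1 for i in ranked[:top_k] if target[i])
    return hits / n_relevant


print(recall_at_k([0.2, 0.3, 0.5], [True, False, True], top_k=2))  # 0.5
```

Here the two highest scores (0.5 and 0.3) contain one of the two relevant documents, hence 0.5, matching the example above.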
BERT Score¶
Module Interface¶
- class torchmetrics.text.bert.BERTScore(model_name_or_path=None, num_layers=None, all_layers=False, model=None, user_tokenizer=None, user_forward_fn=None, verbose=False, idf=False, device=None, max_length=512, batch_size=64, num_threads=0, return_hash=False, lang='en', rescale_with_baseline=False, baseline_path=None, baseline_url=None, **kwargs)[source]¶
Compute BERTScore (Evaluating Text Generation with BERT) for measuring text similarity.
BERTScore leverages the pre-trained contextual embeddings from BERT and matches words in candidate and reference sentences by cosine similarity. It has been shown to correlate with human judgment on sentence-level and system-level evaluation. Moreover, BERTScore computes precision, recall, and F1 measure, which can be useful for evaluating different language generation tasks. This implementation follows the original implementation from BERT_score.
As input to forward and update the metric accepts the following input:
preds (List): An iterable of predicted sentences
target (List): An iterable of reference sentences
As output of forward and compute the metric returns the following output:
score (Dict): A dictionary containing the keys precision, recall and f1 with corresponding values
- Parameters:
preds – An iterable of predicted sentences.
target – An iterable of target sentences.
model_name_or_path (Optional[str]) – A name or a model path used to load a transformers pretrained model.
num_layers (Optional[int]) – A layer of representation to use.
all_layers (bool) – An indication of whether the representation from all model’s layers should be used. If all_layers=True, the argument num_layers is ignored.
model (Optional[Module]) – A user’s own model. Must be a torch.nn.Module instance.
user_tokenizer (Optional[Any]) – A user’s own tokenizer used with the own model. This must be an instance with the __call__ method. This method must take an iterable of sentences (List[str]) and must return a python dictionary containing “input_ids” and “attention_mask” represented by Tensor. It is up to the user’s model whether “input_ids” is a Tensor of input ids or embedding vectors. This tokenizer must prepend an equivalent of [CLS] token and append an equivalent of [SEP] token as a transformers tokenizer does.
user_forward_fn (Optional[Callable[[Module, Dict[str, Tensor]], Tensor]]) – A user’s own forward function used in combination with model. This function must take model and a python dictionary containing "input_ids" and "attention_mask" represented by Tensor as an input and return the model’s output represented by a single Tensor.
verbose (bool) – An indication of whether a progress bar should be displayed during the embeddings’ calculation.
idf (bool) – An indication of whether normalization using inverse document frequencies should be used.
device (Union[str, device, None]) – A device to be used for calculation.
max_length (int) – A maximum length of input sequences. Sequences longer than max_length are trimmed.
batch_size (int) – A batch size used for model processing.
num_threads (int) – A number of threads to use for a dataloader.
return_hash (bool) – An indication of whether the corresponding hash_code should be returned.
lang (str) – A language of input sentences.
rescale_with_baseline (bool) – An indication of whether bertscore should be rescaled with a pre-computed baseline. When a pretrained model from transformers is used, the corresponding baseline is downloaded from the original bert-score package from BERT_score if available. In other cases, please specify a path to the baseline csv/tsv file, which must follow the formatting of the files from BERT_score.
baseline_path (Optional[str]) – A path to the user’s own local csv/tsv file with the baseline scale.
baseline_url (Optional[str]) – A url path to the user’s own csv/tsv file with the baseline scale.
kwargs (Any) – Additional keyword arguments, see Advanced metric settings for more info.
Example
>>> from pprint import pprint
>>> from torchmetrics.text.bert import BERTScore
>>> preds = ["hello there", "general kenobi"]
>>> target = ["hello there", "master kenobi"]
>>> bertscore = BERTScore()
>>> pprint(bertscore(preds, target))
{'f1': tensor([1.0000, 0.9961]),
 'precision': tensor([1.0000, 0.9961]),
 'recall': tensor([1.0000, 0.9961])}
- plot(val=None, ax=None)[source]¶
Plot a single or multiple values from the metric.
- Parameters:
val (Union[Tensor, Sequence[Tensor], None]) – Either a single result from calling metric.forward or metric.compute or a list of these results. If no value is provided, metric.compute is called automatically and its result is plotted.
ax (Optional[Axes]) – A matplotlib axis object. If provided, the plot is added to that axis.
- Return type:
- Returns:
Figure and Axes object
- Raises:
ModuleNotFoundError – If matplotlib is not installed
>>> # Example plotting a single value
>>> from torchmetrics.text.bert import BERTScore
>>> preds = ["hello there", "general kenobi"]
>>> target = ["hello there", "master kenobi"]
>>> metric = BERTScore()
>>> metric.update(preds, target)
>>> fig_, ax_ = metric.plot()
>>> # Example plotting multiple values
>>> from torch import tensor
>>> from torchmetrics.text.bert import BERTScore
>>> preds = ["hello there", "general kenobi"]
>>> target = ["hello there", "master kenobi"]
>>> metric = BERTScore()
>>> values = []
>>> for _ in range(10):
...     val = metric(preds, target)
...     val = {k: tensor(v).mean() for k, v in val.items()}  # convert into single value per key
...     values.append(val)
>>> fig_, ax_ = metric.plot(values)
Functional Interface¶
- torchmetrics.functional.text.bert.bert_score(preds, target, model_name_or_path=None, num_layers=None, all_layers=False, model=None, user_tokenizer=None, user_forward_fn=None, verbose=False, idf=False, device=None, max_length=512, batch_size=64, num_threads=0, return_hash=False, lang='en', rescale_with_baseline=False, baseline_path=None, baseline_url=None)[source]¶
Compute BERTScore (Evaluating Text Generation with BERT) for text similarity matching.
This metric leverages the pre-trained contextual embeddings from BERT and matches words in candidate and reference sentences by cosine similarity. It has been shown to correlate with human judgment on sentence-level and system-level evaluation. Moreover, BERTScore computes precision, recall, and F1 measure, which can be useful for evaluating different language generation tasks.
This implementation follows the original implementation from BERT_score.
- Parameters:
preds (Union[str, Sequence[str], Dict[str, Tensor]]) – Either an iterable of predicted sentences or a Dict[input_ids, attention_mask].
target (Union[str, Sequence[str], Dict[str, Tensor]]) – Either an iterable of target sentences or a Dict[input_ids, attention_mask].
model_name_or_path (Optional[str]) – A name or a model path used to load a transformers pretrained model.
num_layers (Optional[int]) – A layer of representation to use.
all_layers (bool) – An indication of whether the representation from all model’s layers should be used. If all_layers = True, the argument num_layers is ignored.
user_tokenizer (Optional[Any]) – A user’s own tokenizer used with the own model. This must be an instance with the __call__ method. This method must take an iterable of sentences (List[str]) and must return a python dictionary containing "input_ids" and "attention_mask" represented by Tensor. It is up to the user’s model whether "input_ids" is a Tensor of input ids or embedding vectors. This tokenizer must prepend an equivalent of [CLS] token and append an equivalent of [SEP] token as a transformers tokenizer does.
user_forward_fn (Optional[Callable[[Module, Dict[str, Tensor]], Tensor]]) – A user’s own forward function used in combination with model. This function must take model and a python dictionary containing "input_ids" and "attention_mask" represented by Tensor as an input and return the model’s output represented by a single Tensor.
verbose (bool) – An indication of whether a progress bar should be displayed during the embeddings’ calculation.
idf (bool) – An indication of whether normalization using inverse document frequencies should be used.
device (Union[str, device, None]) – A device to be used for calculation.
max_length (int) – A maximum length of input sequences. Sequences longer than max_length are trimmed.
batch_size (int) – A batch size used for model processing.
num_threads (int) – A number of threads to use for a dataloader.
return_hash (bool) – An indication of whether the corresponding hash_code should be returned.
lang (str) – A language of input sentences. It is used when the scores are rescaled with a baseline.
rescale_with_baseline (bool) – An indication of whether bertscore should be rescaled with a pre-computed baseline. When a pretrained model from transformers is used, the corresponding baseline is downloaded from the original bert-score package from BERT_score if available. In other cases, please specify a path to the baseline csv/tsv file, which must follow the formatting of the files from BERT_score.
baseline_path (Optional[str]) – A path to the user’s own local csv/tsv file with the baseline scale.
baseline_url (Optional[str]) – A url path to the user’s own csv/tsv file with the baseline scale.
- Return type:
- Returns:
Python dictionary containing the keys precision, recall and f1 with corresponding values.
- Raises:
ValueError – If len(preds) != len(target).
ModuleNotFoundError – If the tqdm package is required and not installed.
ModuleNotFoundError – If the transformers package is required and not installed.
ValueError – If num_layers is larger than the number of the model layers.
ValueError – If invalid input is provided.
Example
>>> from pprint import pprint
>>> from torchmetrics.functional.text.bert import bert_score
>>> preds = ["hello there", "general kenobi"]
>>> target = ["hello there", "master kenobi"]
>>> pprint(bert_score(preds, target))
{'f1': tensor([1.0000, 0.9961]),
 'precision': tensor([1.0000, 0.9961]),
 'recall': tensor([1.0000, 0.9961])}
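The central step of BERTScore, greedy matching of token embeddings by cosine similarity, can be illustrated without any model by feeding hand-made vectors. This is a simplified sketch: greedy_match_scores is a hypothetical helper, the 2-d embeddings are toy values rather than BERT outputs, and idf weighting, special-token handling and baseline rescaling are all left out.

```python
import math
from typing import List, Sequence, Tuple

Vector = Sequence[float]


def cosine(u: Vector, v: Vector) -> float:
    """Cosine similarity between two non-zero vectors."""
    dot = sum(a * b for a, b in zip(u, v))
    return dot / (math.sqrt(sum(a * a for a in u)) * math.sqrt(sum(b * b for b in v)))


def greedy_match_scores(cand: List[Vector], ref: List[Vector]) -> Tuple[float, float, float]:
    """Hypothetical helper: precision, recall, F1 from greedy cosine matching."""
    # sim[i][j] is the similarity of candidate token i and reference token j
    sim = [[cosine(c, r) for r in ref] for c in cand]
    # precision: each candidate token is matched to its most similar reference token
    precision = sum(max(row) for row in sim) / len(cand)
    # recall: each reference token is matched to its most similar candidate token
    recall = sum(max(sim[i][j] for i in range(len(cand))) for j in range(len(ref))) / len(ref)
    f1 = 0.0 if precision + recall == 0 else 2 * precision * recall / (precision + recall)
    return precision, recall, f1


# toy "embeddings": the candidate has one extra token unrelated to the reference
candidate = [[1.0, 0.0], [0.0, 1.0]]
reference = [[1.0, 0.0]]
print(greedy_match_scores(candidate, reference))  # precision 0.5, recall 1.0, F1 about 0.667
```

The extra candidate token lowers precision but not recall, which is why BERTScore reports all three numbers.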
BLEU Score¶
Module Interface¶
- class torchmetrics.text.BLEUScore(n_gram=4, smooth=False, weights=None, **kwargs)[source]¶
Calculate BLEU score of machine translated text with one or more references.
As input to forward and update the metric accepts the following input:
preds (Sequence): An iterable of machine translated corpus
target (Sequence): An iterable of iterables of reference corpus
As output of forward and compute the metric returns the following output:
bleu (Tensor): A tensor with the BLEU Score
- Parameters:
n_gram (int) – Gram value ranging from 1 to 4
smooth (bool) – Whether or not to apply smoothing, see Machine Translation Evolution
kwargs (Any) – Additional keyword arguments, see Advanced metric settings for more info.
weights (Optional[Sequence[float]]) – Weights used for unigrams, bigrams, etc. to calculate BLEU score. If not provided, uniform weights are used.
- Raises:
ValueError – If weights is not None and its length is not equal to n_gram.
Example
>>> from torchmetrics.text import BLEUScore
>>> preds = ['the cat is on the mat']
>>> target = [['there is a cat on the mat', 'a cat is on the mat']]
>>> bleu = BLEUScore()
>>> bleu(preds, target)
tensor(0.7598)
- plot(val=None, ax=None)[source]¶
Plot a single or multiple values from the metric.
- Parameters:
val (Union[Tensor, Sequence[Tensor], None]) – Either a single result from calling metric.forward or metric.compute or a list of these results. If no value is provided, metric.compute is called automatically and its result is plotted.
ax (Optional[Axes]) – A matplotlib axis object. If provided, the plot is added to that axis.
- Return type:
- Returns:
Figure and Axes object
- Raises:
ModuleNotFoundError – If matplotlib is not installed
>>> # Example plotting a single value
>>> from torchmetrics.text import BLEUScore
>>> metric = BLEUScore()
>>> preds = ['the cat is on the mat']
>>> target = [['there is a cat on the mat', 'a cat is on the mat']]
>>> metric.update(preds, target)
>>> fig_, ax_ = metric.plot()
>>> # Example plotting multiple values
>>> from torchmetrics.text import BLEUScore
>>> metric = BLEUScore()
>>> preds = ['the cat is on the mat']
>>> target = [['there is a cat on the mat', 'a cat is on the mat']]
>>> values = []
>>> for _ in range(10):
...     values.append(metric(preds, target))
>>> fig_, ax_ = metric.plot(values)
Functional Interface¶
- torchmetrics.functional.text.bleu_score(preds, target, n_gram=4, smooth=False, weights=None)[source]¶
Calculate BLEU score of machine translated text with one or more references.
- Parameters:
preds (Union[str, Sequence[str]]) – An iterable of machine translated corpus
target (Sequence[Union[str, Sequence[str]]]) – An iterable of iterables of reference corpus
n_gram (int) – Gram value ranging from 1 to 4
smooth (bool) – Whether to apply smoothing, see [2]
weights (Optional[Sequence[float]]) – Weights used for unigrams, bigrams, etc. to calculate BLEU score. If not provided, uniform weights are used.
- Return type:
- Returns:
Tensor with BLEU Score
- Raises:
ValueError – If preds and target corpus have different lengths.
ValueError – If weights is not None and its length is not equal to n_gram.
Example
>>> from torchmetrics.functional.text import bleu_score
>>> preds = ['the cat is on the mat']
>>> target = [['there is a cat on the mat', 'a cat is on the mat']]
>>> bleu_score(preds, target)
tensor(0.7598)
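To see where 0.7598 comes from, here is a plain-Python sketch of corpus BLEU: clipped n-gram precisions for n = 1..4, their geometric mean with uniform weights, and a brevity penalty. It is not the TorchMetrics implementation; corpus_bleu_sketch is a hypothetical helper that assumes whitespace tokenization, no smoothing, and uses the reference length closest to the candidate length (breaking ties toward the shorter reference, a choice made only for this sketch). For the example, the clipped precisions are 5/6, 4/5, 3/4 and 2/3, whose geometric mean is (1/3)^(1/4), about 0.7598, and the brevity penalty is 1 because a reference of equal length exists.

```python
import math
from collections import Counter
from typing import List, Sequence


def ngram_counts(tokens: List[str], n: int) -> Counter:
    """Count the n-grams (as tuples) of a token list."""
    return Counter(tuple(tokens[i : i + n]) for i in range(len(tokens) - n + 1))


def corpus_bleu_sketch(
    preds: Sequence[str], target: Sequence[Sequence[str]], n_gram: int = 4
) -> float:
    """Hypothetical helper: corpus BLEU with uniform weights and no smoothing."""
    numerator = [0] * n_gram  # clipped n-gram matches per order
    denominator = [0] * n_gram  # total candidate n-grams per order
    pred_len = ref_len = 0
    for pred, refs in zip(preds, target):
        pred_tokens = pred.split()
        ref_tokens = [ref.split() for ref in refs]
        pred_len += len(pred_tokens)
        # closest reference length; ties go to the shorter reference in this sketch
        ref_len += min(
            (len(r) for r in ref_tokens),
            key=lambda n_tok: (abs(n_tok - len(pred_tokens)), n_tok),
        )
        for n in range(1, n_gram + 1):
            pred_counts = ngram_counts(pred_tokens, n)
            max_ref_counts: Counter = Counter()
            for r in ref_tokens:
                max_ref_counts |= ngram_counts(r, n)  # element-wise max over references
            # clip each candidate n-gram count by its maximum reference count
            numerator[n - 1] += sum((pred_counts & max_ref_counts).values())
            denominator[n - 1] += sum(pred_counts.values())
    if min(numerator) == 0:
        return 0.0  # without smoothing, one order with zero matches gives BLEU = 0
    log_precision = sum(math.log(num / den) for num, den in zip(numerator, denominator)) / n_gram
    brevity_penalty = 1.0 if pred_len > ref_len else math.exp(1 - ref_len / pred_len)
    return brevity_penalty * math.exp(log_precision)


preds = ["the cat is on the mat"]
target = [["there is a cat on the mat", "a cat is on the mat"]]
print(round(corpus_bleu_sketch(preds, target), 4))  # 0.7598
```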
References
[1] BLEU: a Method for Automatic Evaluation of Machine Translation by Papineni, Kishore, Salim Roukos, Todd Ward, and Wei-Jing Zhu BLEU
[2] Automatic Evaluation of Machine Translation Quality Using Longest Common Subsequence and Skip-Bigram Statistics by Chin-Yew Lin and Franz Josef Och Machine Translation Evolution
Char Error Rate¶
Module Interface¶
- class torchmetrics.text.CharErrorRate(**kwargs)[source]¶
Character Error Rate (CER) is a metric of the performance of an automatic speech recognition (ASR) system.
This value indicates the percentage of characters that were incorrectly predicted. The lower the value, the better the performance of the ASR system with a CharErrorRate of 0 being a perfect score. Character error rate can then be computed as:
\[CharErrorRate = \frac{S + D + I}{N} = \frac{S + D + I}{S + D + C}\]
- where:
\(S\) is the number of substitutions,
\(D\) is the number of deletions,
\(I\) is the number of insertions,
\(C\) is the number of correct characters,
\(N\) is the number of characters in the reference (N=S+D+C).
Compute CharErrorRate score of transcribed segments against references.
As input to forward and update the metric accepts the following input:
preds (str): Transcription(s) to score as a string or list of strings
target (str): Reference(s) for each speech input as a string or list of strings
As output of forward and compute the metric returns the following output:
cer (Tensor): A tensor with the Character Error Rate score
- Parameters:
kwargs (Any) – Additional keyword arguments, see Advanced metric settings for more info.
Examples
>>> from torchmetrics.text import CharErrorRate
>>> preds = ["this is the prediction", "there is an other sample"]
>>> target = ["this is the reference", "there is another one"]
>>> cer = CharErrorRate()
>>> cer(preds, target)
tensor(0.3415)
- plot(val=None, ax=None)[source]¶
Plot a single or multiple values from the metric.
- Parameters:
val (Union[Tensor, Sequence[Tensor], None]) – Either a single result from calling metric.forward or metric.compute or a list of these results. If no value is provided, metric.compute is called automatically and its result is plotted.
ax (Optional[Axes]) – A matplotlib axis object. If provided, the plot is added to that axis.
- Return type:
- Returns:
Figure and Axes object
- Raises:
ModuleNotFoundError – If matplotlib is not installed
>>> # Example plotting a single value
>>> from torchmetrics.text import CharErrorRate
>>> metric = CharErrorRate()
>>> preds = ["this is the prediction", "there is an other sample"]
>>> target = ["this is the reference", "there is another one"]
>>> metric.update(preds, target)
>>> fig_, ax_ = metric.plot()
>>> # Example plotting multiple values
>>> from torchmetrics.text import CharErrorRate
>>> metric = CharErrorRate()
>>> preds = ["this is the prediction", "there is an other sample"]
>>> target = ["this is the reference", "there is another one"]
>>> values = []
>>> for _ in range(10):
...     values.append(metric(preds, target))
>>> fig_, ax_ = metric.plot(values)
Functional Interface¶
- torchmetrics.functional.text.char_error_rate(preds, target)[source]¶
Compute the Character Error Rate, used to measure the performance of an automatic speech recognition system.
This value indicates the percentage of characters that were incorrectly predicted. The lower the value, the better the performance of the ASR system with a CER of 0 being a perfect score.
- Parameters:
- Return type:
- Returns:
Character error rate score
Examples
>>> from torchmetrics.functional.text import char_error_rate
>>> preds = ["this is the prediction", "there is an other sample"]
>>> target = ["this is the reference", "there is another one"]
>>> char_error_rate(preds=preds, target=target)
tensor(0.3415)
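The formula S + D + I over N can be checked with a character-level Levenshtein distance: the minimum number of substitutions, deletions and insertions is summed over all pairs and divided by the total number of reference characters. This is a plain-Python sketch under that assumption; levenshtein and char_error_rate_sketch are hypothetical helpers, not TorchMetrics functions. For the example above the two pairs need 8 and 6 edits against 21 and 20 reference characters, so 14 / 41 is about 0.3415.

```python
from typing import List, Union


def levenshtein(a: str, b: str) -> int:
    """Minimum number of character substitutions, deletions and insertions turning a into b."""
    prev = list(range(len(b) + 1))  # distances from the empty prefix of a to each prefix of b
    for i, ca in enumerate(a, start=1):
        curr = [i]
        for j, cb in enumerate(b, start=1):
            curr.append(
                min(
                    prev[j] + 1,  # delete ca
                    curr[j - 1] + 1,  # insert cb
                    prev[j - 1] + (ca != cb),  # substitute (free when characters match)
                )
            )
        prev = curr
    return prev[-1]


def char_error_rate_sketch(
    preds: Union[str, List[str]], target: Union[str, List[str]]
) -> float:
    """Hypothetical helper: total edit operations divided by total reference characters."""
    if isinstance(preds, str):
        preds = [preds]
    if isinstance(target, str):
        target = [target]
    errors = sum(levenshtein(p, t) for p, t in zip(preds, target))  # S + D + I
    total = sum(len(t) for t in target)  # N = S + D + C
    return errors / total


preds = ["this is the prediction", "there is an other sample"]
target = ["this is the reference", "there is another one"]
print(round(char_error_rate_sketch(preds, target), 4))  # 0.3415
```

Summing errors and lengths before dividing weights long references more, which differs from averaging a per-sentence rate.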
ChrF Score¶
Module Interface¶
- class torchmetrics.text.CHRFScore(n_char_order=6, n_word_order=2, beta=2.0, lowercase=False, whitespace=False, return_sentence_level_score=False, **kwargs)[source]¶
Calculate chrf score of machine translated text with one or more references.
This implementation supports both chrF score computation, introduced in chrF score, and chrF++ score computation, introduced in chrF++ score. This implementation follows the implementations from https://github.com/m-popovic/chrF and https://github.com/mjpost/sacrebleu/blob/master/sacrebleu/metrics/chrf.py.
As input to forward and update the metric accepts the following input:
preds (Sequence): An iterable of hypothesis corpus
target (Sequence): An iterable of iterables of reference corpus
As output of forward and compute the metric returns the following output:
chrf (Tensor): If return_sentence_level_score=True, a list of sentence-level chrF/chrF++ scores is returned; otherwise a corpus-level chrF/chrF++ score
- Parameters:
n_char_order (int) – A character n-gram order. If n_char_order=6, the metric refers to the official chrF/chrF++.
n_word_order (int) – A word n-gram order. If n_word_order=2, the metric refers to the official chrF++. If n_word_order=0, the metric is equivalent to the original chrF.
beta (float) – A parameter determining the importance of recall w.r.t. precision. If beta=1, their importance is equal.
lowercase (bool) – An indication whether to enable case-insensitivity.
whitespace (bool) – An indication whether to keep whitespaces during n-gram extraction.
return_sentence_level_score (bool) – An indication whether a sentence-level chrF/chrF++ score is to be returned.
kwargs (Any) – Additional keyword arguments, see Advanced metric settings for more info.
- Raises:
ValueError – If n_char_order is not an integer greater than or equal to 1.
ValueError – If n_word_order is not an integer greater than or equal to 0.
ValueError – If beta is smaller than 0.
Example
>>> from torchmetrics.text import CHRFScore
>>> preds = ['the cat is on the mat']
>>> target = [['there is a cat on the mat', 'a cat is on the mat']]
>>> chrf = CHRFScore()
>>> chrf(preds, target)
tensor(0.8640)
- plot(val=None, ax=None)
Plot a single or multiple values from the metric.
- Parameters:
val (Union[Tensor, Sequence[Tensor], None]) – Either a single result from calling metric.forward or metric.compute, or a list of these results. If no value is provided, metric.compute is called automatically and that result is plotted.
ax (Optional[Axes]) – A matplotlib axis object. If provided, the plot is added to that axis.
- Return type:
- Returns:
Figure and Axes object
- Raises:
ModuleNotFoundError – If matplotlib is not installed
>>> # Example plotting a single value
>>> from torchmetrics.text import CHRFScore
>>> metric = CHRFScore()
>>> preds = ['the cat is on the mat']
>>> target = [['there is a cat on the mat', 'a cat is on the mat']]
>>> metric.update(preds, target)
>>> fig_, ax_ = metric.plot()
>>> # Example plotting multiple values
>>> from torchmetrics.text import CHRFScore
>>> metric = CHRFScore()
>>> preds = ['the cat is on the mat']
>>> target = [['there is a cat on the mat', 'a cat is on the mat']]
>>> values = []
>>> for _ in range(10):
...     values.append(metric(preds, target))
>>> fig_, ax_ = metric.plot(values)
Functional Interface
- torchmetrics.functional.text.chrf_score(preds, target, n_char_order=6, n_word_order=2, beta=2.0, lowercase=False, whitespace=False, return_sentence_level_score=False)
Calculate the chrF score of machine-translated text against one or more references.
This implementation supports both the chrF score introduced in [1] and the chrF++ score introduced in [2]. It follows the implementations from https://github.com/m-popovic/chrF and https://github.com/mjpost/sacrebleu/blob/master/sacrebleu/metrics/chrf.py.
- Parameters:
preds (Union[str, Sequence[str]]) – An iterable of hypothesis corpus.
target (Sequence[Union[str, Sequence[str]]]) – An iterable of iterables of reference corpus.
n_char_order (int) – A character n-gram order. If n_char_order=6, the metric refers to the official chrF/chrF++.
n_word_order (int) – A word n-gram order. If n_word_order=2, the metric refers to the official chrF++. If n_word_order=0, the metric is equivalent to the original chrF.
beta (float) – A parameter determining the importance of recall w.r.t. precision. If beta=1, their importance is equal.
lowercase (bool) – An indication whether to enable case-insensitivity.
whitespace (bool) – An indication whether to keep whitespaces during character n-gram extraction.
return_sentence_level_score (bool) – An indication whether a sentence-level chrF/chrF++ score is to be returned.
- Return type:
- Returns:
A corpus-level chrF/chrF++ score. (Optionally) A list of sentence-level chrF/chrF++ scores if return_sentence_level_score=True.
- Raises:
ValueError – If n_char_order is not an integer greater than or equal to 1.
ValueError – If n_word_order is not an integer greater than or equal to 0.
ValueError – If beta is smaller than 0.
Example
>>> from torchmetrics.functional.text import chrf_score
>>> preds = ['the cat is on the mat']
>>> target = [['there is a cat on the mat', 'a cat is on the mat']]
>>> chrf_score(preds, target)
tensor(0.8640)
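The core idea of chrF can be sketched in a few lines of Python. This is a simplified, hypothetical illustration (the function names are not torchmetrics APIs) that assumes a single reference, character n-grams only (comparable to n_word_order=0), whitespace removed before extraction, and precision and recall averaged over the n-gram orders before being combined into an F-beta score. It is not expected to reproduce the corpus-level values printed above.

```python
from collections import Counter


def char_ngrams(text: str, n: int) -> Counter:
    # Drop all whitespace, then count overlapping character n-grams
    chars = "".join(text.split())
    return Counter(chars[i:i + n] for i in range(len(chars) - n + 1))


def chrf_sketch(hyp: str, ref: str, n_char_order: int = 6, beta: float = 2.0) -> float:
    precisions, recalls = [], []
    for n in range(1, n_char_order + 1):
        hyp_ngrams, ref_ngrams = char_ngrams(hyp, n), char_ngrams(ref, n)
        if not hyp_ngrams or not ref_ngrams:
            continue  # skip orders longer than one of the strings
        matches = sum((hyp_ngrams & ref_ngrams).values())  # clipped n-gram matches
        precisions.append(matches / sum(hyp_ngrams.values()))
        recalls.append(matches / sum(ref_ngrams.values()))
    if not precisions:
        return 0.0
    p = sum(precisions) / len(precisions)
    r = sum(recalls) / len(recalls)
    if p == 0 and r == 0:
        return 0.0
    b2 = beta ** 2
    # beta > 1 weights recall more heavily than precision
    return (1 + b2) * p * r / (b2 * p + r)
```

With the default beta=2.0, recall counts four times as much as precision in the harmonic combination.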
References
[1] Maja Popović, “chrF: character n-gram F-score for automatic MT evaluation”.
[2] Maja Popović, “chrF++: words helping character n-grams”.
Edit Distance
Module Interface
- class torchmetrics.text.EditDistance(substitution_cost=1, reduction='mean', **kwargs)
Calculates the Levenshtein edit distance between two sequences.
The edit distance is the number of characters that need to be substituted, inserted, or deleted to transform the predicted text into the reference text. The lower the distance, the more accurate the model is considered to be.
Implementation is similar to nltk.edit_distance.
As input to forward and update the metric accepts the following input:
preds (Sequence): An iterable of predicted texts (strings)
target (Sequence): An iterable of reference texts (strings)
As output of forward and compute the metric returns the following output:
eed (Tensor): A tensor with the edit distance score. If reduction is set to 'none' or None, this has shape (N, ), where N is the batch size. Otherwise, this is a scalar.
- Parameters:
substitution_cost (int) – The cost of substituting one character for another.
reduction (Optional[Literal['mean', 'sum', 'none']]) – A method to reduce the metric score over samples:
'mean': takes the mean over samples
'sum': takes the sum over samples
None or 'none': returns the score per sample
kwargs (Any) – Additional keyword arguments, see Advanced metric settings for more info.
- Example:
Basic example with two strings. Going from “rain” -> “sain” -> “shin” -> “shine” takes 3 edits:
>>> from torchmetrics.text import EditDistance
>>> metric = EditDistance()
>>> metric(["rain"], ["shine"])
tensor(3.)
- Example:
Basic example with two strings and a substitution cost of 2. Going from “rain” -> “sain” -> “shin” -> “shine” takes 3 edits, two of which are substitutions, for a total cost of 2 + 2 + 1 = 5:
>>> from torchmetrics.text import EditDistance
>>> metric = EditDistance(substitution_cost=2)
>>> metric(["rain"], ["shine"])
tensor(5.)
- Example:
Multiple strings example:
>>> from torchmetrics.text import EditDistance
>>> metric = EditDistance(reduction=None)
>>> metric(["rain", "lnaguaeg"], ["shine", "language"])
tensor([3, 4], dtype=torch.int32)
>>> metric = EditDistance(reduction="mean")
>>> metric(["rain", "lnaguaeg"], ["shine", "language"])
tensor(3.5000)
- plot(val=None, ax=None)
Plot a single or multiple values from the metric.
- Parameters:
val (Union[Tensor, Sequence[Tensor], None]) – Either a single result from calling metric.forward or metric.compute, or a list of these results. If no value is provided, metric.compute is called automatically and that result is plotted.
ax (Optional[Axes]) – A matplotlib axis object. If provided, the plot is added to that axis.
- Return type:
- Returns:
Figure and Axes object
- Raises:
ModuleNotFoundError – If matplotlib is not installed
>>> # Example plotting a single value
>>> from torchmetrics.text import EditDistance
>>> metric = EditDistance()
>>> preds = ["this is the prediction", "there is an other sample"]
>>> target = ["this is the reference", "there is another one"]
>>> metric.update(preds, target)
>>> fig_, ax_ = metric.plot()
>>> # Example plotting multiple values
>>> from torchmetrics.text import EditDistance
>>> metric = EditDistance()
>>> preds = ["this is the prediction", "there is an other sample"]
>>> target = ["this is the reference", "there is another one"]
>>> values = []
>>> for _ in range(10):
...     values.append(metric(preds, target))
>>> fig_, ax_ = metric.plot(values)
Functional Interface
- torchmetrics.functional.text.edit_distance(preds, target, substitution_cost=1, reduction='mean')
Calculates the Levenshtein edit distance between two sequences.
The edit distance is the number of characters that need to be substituted, inserted, or deleted to transform the predicted text into the reference text. The lower the distance, the more accurate the model is considered to be.
Implementation is similar to nltk.edit_distance.
- Parameters:
preds (Union[str, Sequence[str]]) – An iterable of predicted texts (strings).
target (Union[str, Sequence[str]]) – An iterable of reference texts (strings).
substitution_cost (int) – The cost of substituting one character for another.
reduction (Optional[Literal['mean', 'sum', 'none']]) – A method to reduce the metric score over samples:
'mean': takes the mean over samples
'sum': takes the sum over samples
None or 'none': returns the score per sample
- Raises:
ValueError – If preds and target do not have the same length.
ValueError – If preds or target contain non-string values.
- Return type:
- Example:
Basic example with two strings. Going from “rain” -> “sain” -> “shin” -> “shine” takes 3 edits:
>>> from torchmetrics.functional.text import edit_distance
>>> edit_distance(["rain"], ["shine"])
tensor(3.)
- Example:
Basic example with two strings and a substitution cost of 2. Going from “rain” -> “sain” -> “shin” -> “shine” takes 3 edits, two of which are substitutions, for a total cost of 2 + 2 + 1 = 5:
>>> from torchmetrics.functional.text import edit_distance
>>> edit_distance(["rain"], ["shine"], substitution_cost=2)
tensor(5.)
- Example:
Multiple strings example:
>>> from torchmetrics.functional.text import edit_distance
>>> edit_distance(["rain", "lnaguaeg"], ["shine", "language"], reduction=None)
tensor([3, 4], dtype=torch.int32)
>>> edit_distance(["rain", "lnaguaeg"], ["shine", "language"], reduction="mean")
tensor(3.5000)
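The Levenshtein recurrence behind these examples, including the substitution_cost parameter and the three reduction modes, can be sketched in plain Python. This is an illustrative sketch with hypothetical function names rather than the torchmetrics implementation; it compares single characters and returns Python numbers instead of tensors.

```python
from typing import List, Optional, Union


def edit_distance_sketch(pred: str, ref: str, substitution_cost: int = 1) -> int:
    """Levenshtein distance with unit insert/delete cost and a configurable substitution cost."""
    prev = list(range(len(ref) + 1))  # cost of building ref[:j] from an empty prefix
    for i, p in enumerate(pred, start=1):
        curr = [i]  # cost of deleting the first i characters of pred
        for j, r in enumerate(ref, start=1):
            sub = prev[j - 1] + (0 if p == r else substitution_cost)
            curr.append(min(prev[j] + 1, curr[j - 1] + 1, sub))
        prev = curr
    return prev[-1]


def reduce_scores(scores: List[int], reduction: Optional[str] = "mean") -> Union[float, int, List[int]]:
    """Mirror the documented 'mean' / 'sum' / None-or-'none' options."""
    if reduction == "mean":
        return sum(scores) / len(scores)
    if reduction == "sum":
        return sum(scores)
    if reduction in (None, "none"):
        return scores
    raise ValueError(f"Unknown reduction: {reduction!r}")
```

Because a substitution can always be replaced by one deletion plus one insertion, any substitution_cost of 2 or more makes substitutions no cheaper than that pair, which is why “rain” to “shine” costs 5 in the example above.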
Extended Edit Distance
Module Interface
- class torchmetrics.text.ExtendedEditDistance(language='en', return_sentence_level_score=False, alpha=2.0, rho=0.3, deletion=0.2, insertion=1.0, **kwargs)
Compute the extended edit distance (EED) score for strings or lists of strings.
The metric utilises the Levenshtein distance and extends it by adding a jump operation.
As input to forward and update the metric accepts the following input:
preds (Sequence): An iterable of hypothesis corpus
target (Sequence): An iterable of iterables of reference corpus
As output of forward and compute the metric returns the following output:
eed (Tensor): A tensor with the extended edit distance score
- Parameters:
language (Literal['en', 'ja']) – Language used in sentences. Only English (en) and Japanese (ja) are supported for now.
return_sentence_level_score (bool) – An indication of whether the sentence-level EED score is to be returned.
alpha (float) – Optimal jump penalty, i.e. the penalty for jumps between characters.
rho (float) – Coverage cost, i.e. the penalty for repetition of characters.
deletion (float) – Penalty for deletion of a character.
insertion (float) – Penalty for insertion or substitution of a character.
kwargs (Any) – Additional keyword arguments, see Advanced metric settings for more info.
Example
>>> from torchmetrics.text import ExtendedEditDistance
>>> preds = ["this is the prediction", "here is an other sample"]
>>> target = ["this is the reference", "here is another one"]
>>> eed = ExtendedEditDistance()
>>> eed(preds=preds, target=target)
tensor(0.3078)
- plot(val=None, ax=None)
Plot a single or multiple values from the metric.
- Parameters:
val (Union[Tensor, Sequence[Tensor], None]) – Either a single result from calling metric.forward or metric.compute, or a list of these results. If no value is provided, metric.compute is called automatically and that result is plotted.
ax (Optional[Axes]) – A matplotlib axis object. If provided, the plot is added to that axis.
- Return type:
- Returns:
Figure and Axes object
- Raises:
ModuleNotFoundError – If matplotlib is not installed
>>> # Example plotting a single value
>>> from torchmetrics.text import ExtendedEditDistance
>>> metric = ExtendedEditDistance()
>>> preds = ["this is the prediction", "there is an other sample"]
>>> target = ["this is the reference", "there is another one"]
>>> metric.update(preds, target)
>>> fig_, ax_ = metric.plot()
>>> # Example plotting multiple values
>>> from torchmetrics.text import ExtendedEditDistance
>>> metric = ExtendedEditDistance()
>>> preds = ["this is the prediction", "there is an other sample"]
>>> target = ["this is the reference", "there is another one"]
>>> values = []
>>> for _ in range(10):
...     values.append(metric(preds, target))
>>> fig_, ax_ = metric.plot(values)
Functional Interface
- torchmetrics.functional.text.extended_edit_distance(preds, target, language='en', return_sentence_level_score=False, alpha=2.0, rho=0.3, deletion=0.2, insertion=1.0)
Compute the extended edit distance (EED) score [1] for strings or lists of strings.
The metric utilises the Levenshtein distance and extends it by adding a jump operation.
- Parameters:
preds (Union[str, Sequence[str]]) – An iterable of hypothesis corpus.
target (Sequence[Union[str, Sequence[str]]]) – An iterable of iterables of reference corpus.
language (Literal['en', 'ja']) – Language used in sentences. Only English (en) and Japanese (ja) are supported for now. Defaults to en.
return_sentence_level_score (bool) – An indication of whether the sentence-level EED score is to be returned.
alpha (float) – Optimal jump penalty, i.e. the penalty for jumps between characters.
rho (float) – Coverage cost, i.e. the penalty for repetition of characters.
deletion (float) – Penalty for deletion of a character.
insertion (float) – Penalty for insertion or substitution of a character.
- Return type:
- Returns:
Extended edit distance score as a tensor
Example
>>> from torchmetrics.functional.text import extended_edit_distance
>>> preds = ["this is the prediction", "here is an other sample"]
>>> target = ["this is the reference", "here is another one"]
>>> extended_edit_distance(preds=preds, target=target)
tensor(0.3078)
References
[1] P. Stanchev, W. Wang, and H. Ney, “EED: Extended Edit Distance Measure for Machine Translation”, submitted to WMT 2019.
InfoLM
Module Interface
- class torchmetrics.text.infolm.InfoLM(model_name_or_path='bert-base-uncased', temperature=0.25, information_measure='kl_divergence', idf=True, alpha=None, beta=None, device=None, max_length=None, batch_size=64, num_threads=0, verbose=True, return_sentence_level_score=False, **kwargs)
Calculate InfoLM.
InfoLM measures a distance/divergence between the discrete distributions of the predicted and reference sentences using one of the following information measures:
KL divergence
alpha divergence
beta divergence
AB divergence
Rényi divergence
L1 distance
L2 distance
L-infinity distance
Fisher-Rao distance
InfoLM is a family of untrained embedding-based metrics which address some well-known flaws of standard string-based metrics through the use of pre-trained masked language models. This family of metrics is mainly designed for summarization and data-to-text tasks.
The implementation of this metric is fully based on the HuggingFace transformers package.
As input to forward and update the metric accepts the following input:
preds (Sequence): An iterable of hypothesis corpus
target (Sequence): An iterable of reference corpus
As output of forward and compute the metric returns the following output:
infolm (Tensor): If return_sentence_level_score=True, returns a tuple of a tensor with the corpus-level InfoLM score and a list of sentence-level InfoLM scores; otherwise returns a corpus-level InfoLM score
- Parameters:
model_name_or_path (Union[str, PathLike]) – A name or a model path used to load a transformers pretrained model. By default the “bert-base-uncased” model is used.
temperature (float) – A temperature for calibrating language modelling. For more information, please refer to the InfoLM paper.
information_measure (Literal['kl_divergence', 'alpha_divergence', 'beta_divergence', 'ab_divergence', 'renyi_divergence', 'l1_distance', 'l2_distance', 'l_infinity_distance', 'fisher_rao_distance']) – The name of the information measure to be used.
idf (bool) – An indication of whether normalization using inverse document frequencies should be used.
alpha (Optional[float]) – Alpha parameter of the divergence used for alpha, AB and Rényi divergence measures.
beta (Optional[float]) – Beta parameter of the divergence used for beta and AB divergence measures.
device (Union[str, device, None]) – A device to be used for calculation.
max_length (Optional[int]) – A maximum length of input sequences. Sequences longer than max_length are trimmed.
batch_size (int) – A batch size used for model processing.
num_threads (int) – A number of threads to use for a dataloader.
verbose (bool) – An indication of whether a progress bar is to be displayed during the embeddings calculation.
return_sentence_level_score (bool) – An indication whether a sentence-level InfoLM score is to be returned.
Example
>>> from torchmetrics.text.infolm import InfoLM
>>> preds = ['he read the book because he was interested in world history']
>>> target = ['he was interested in world history because he read the book']
>>> infolm = InfoLM('google/bert_uncased_L-2_H-128_A-2', idf=False)
>>> infolm(preds, target)
tensor(-0.1784)
- plot(val=None, ax=None)
Plot a single or multiple values from the metric.
- Parameters:
val (Union[Tensor, Sequence[Tensor], None]) – Either a single result from calling metric.forward or metric.compute, or a list of these results. If no value is provided, metric.compute is called automatically and that result is plotted.
ax (Optional[Axes]) – A matplotlib axis object. If provided, the plot is added to that axis.
- Return type:
- Returns:
Figure and Axes object
- Raises:
ModuleNotFoundError – If matplotlib is not installed
>>> # Example plotting a single value
>>> from torchmetrics.text.infolm import InfoLM
>>> metric = InfoLM('google/bert_uncased_L-2_H-128_A-2', idf=False)
>>> preds = ['he read the book because he was interested in world history']
>>> target = ['he was interested in world history because he read the book']
>>> metric.update(preds, target)
>>> fig_, ax_ = metric.plot()
>>> # Example plotting multiple values
>>> from torchmetrics.text.infolm import InfoLM
>>> metric = InfoLM('google/bert_uncased_L-2_H-128_A-2', idf=False)
>>> preds = ["this is the prediction", "there is an other sample"]
>>> target = ["this is the reference", "there is another one"]
>>> values = []
>>> for _ in range(10):
...     values.append(metric(preds, target))
>>> fig_, ax_ = metric.plot(values)
Functional Interface
- torchmetrics.functional.text.infolm.infolm(preds, target, model_name_or_path='bert-base-uncased', temperature=0.25, information_measure='kl_divergence', idf=True, alpha=None, beta=None, device=None, max_length=None, batch_size=64, num_threads=0, verbose=True, return_sentence_level_score=False)
Calculate InfoLM [1].
InfoLM corresponds to a distance/divergence between the discrete distributions of the predicted and reference sentences using one of the following information measures:
KL divergence
alpha divergence
beta divergence
AB divergence
Rényi divergence
L1 distance
L2 distance
L-infinity distance
Fisher-Rao distance
InfoLM is a family of untrained embedding-based metrics which address some well-known flaws of standard string-based metrics through the use of pre-trained masked language models. This family of metrics is mainly designed for summarization and data-to-text tasks.
If you want to use IDF scaling over the whole dataset, please use the class metric.
The implementation of this metric is fully based on the HuggingFace transformers package.
- Parameters:
preds (Union[str, Sequence[str]]) – An iterable of hypothesis corpus.
target (Union[str, Sequence[str]]) – An iterable of reference corpus.
model_name_or_path (Union[str, PathLike]) – A name or a model path used to load a transformers pretrained model.
temperature (float) – A temperature for calibrating language modelling. For more information, please refer to the InfoLM paper.
information_measure (Literal['kl_divergence', 'alpha_divergence', 'beta_divergence', 'ab_divergence', 'renyi_divergence', 'l1_distance', 'l2_distance', 'l_infinity_distance', 'fisher_rao_distance']) – The name of the information measure to be used.
idf (bool) – An indication of whether normalization using inverse document frequencies should be used.
alpha (Optional[float]) – Alpha parameter of the divergence used for alpha, AB and Rényi divergence measures.
beta (Optional[float]) – Beta parameter of the divergence used for beta and AB divergence measures.
device (Union[str, device, None]) – A device to be used for calculation.
max_length (Optional[int]) – A maximum length of input sequences. Sequences longer than max_length are trimmed.
batch_size (int) – A batch size used for model processing.
num_threads (int) – A number of threads to use for a dataloader.
verbose (bool) – An indication of whether a progress bar is to be displayed during the embeddings calculation.
return_sentence_level_score (bool) – An indication whether a sentence-level InfoLM score is to be returned.
- Return type:
- Returns:
A corpus-level InfoLM score. (Optionally) A list of sentence-level InfoLM scores if return_sentence_level_score=True.
Example
>>> from torchmetrics.functional.text.infolm import infolm
>>> preds = ['he read the book because he was interested in world history']
>>> target = ['he was interested in world history because he read the book']
>>> infolm(preds, target, model_name_or_path='google/bert_uncased_L-2_H-128_A-2', idf=False)
tensor(-0.1784)
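The information measures listed above are standard quantities on discrete distributions. The following hypothetical sketch gives textbook definitions of five of them (KL divergence, L1, L2, L-infinity and Fisher-Rao distance) for two probability vectors. In InfoLM these would be applied to token distributions produced by a masked language model, which the sketch does not attempt to reproduce, and it makes no claim to match torchmetrics' sign or normalization conventions (the example above, for instance, reports a negative score with the default kl_divergence measure). The alpha, beta, AB and Rényi divergences are omitted.

```python
import math
from typing import Sequence

# Each function compares two discrete distributions p and q of equal length,
# assumed non-negative and summing to 1.


def kl_divergence(p: Sequence[float], q: Sequence[float]) -> float:
    """KL(p || q); assumes q > 0 wherever p > 0 (terms with p == 0 contribute 0)."""
    return sum(a * math.log(a / b) for a, b in zip(p, q) if a > 0)


def l1_distance(p: Sequence[float], q: Sequence[float]) -> float:
    return sum(abs(a - b) for a, b in zip(p, q))


def l2_distance(p: Sequence[float], q: Sequence[float]) -> float:
    return math.sqrt(sum((a - b) ** 2 for a, b in zip(p, q)))


def l_infinity_distance(p: Sequence[float], q: Sequence[float]) -> float:
    return max(abs(a - b) for a, b in zip(p, q))


def fisher_rao_distance(p: Sequence[float], q: Sequence[float]) -> float:
    # 2 * arccos of the Bhattacharyya coefficient, clamped to [0, 1] for float safety
    bc = sum(math.sqrt(a * b) for a, b in zip(p, q))
    return 2 * math.acos(min(1.0, max(0.0, bc)))
```

KL divergence is asymmetric and unbounded, while the L-p and Fisher-Rao distances are symmetric; Fisher-Rao is bounded by π, reached for distributions with disjoint support.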
References
[1] Pierre Colombo, Chloé Clavel and Pablo Piantanida, “InfoLM: A New Metric to Evaluate Summarization & Data2Text Generation”.
Match Error Rate
Module Interface
- class torchmetrics.text.MatchErrorRate(**kwargs)
Match Error Rate (MER) is a common metric of the performance of an automatic speech recognition system.
This value indicates the percentage of words that were incorrectly predicted and inserted. The lower the value, the better the performance of the ASR system, with a MatchErrorRate of 0 being a perfect score. Match error rate is computed as:
\[mer = \frac{S + D + I}{N + I} = \frac{S + D + I}{S + D + C + I}\]
where:
\(S\) is the number of substitutions,
\(D\) is the number of deletions,
\(I\) is the number of insertions,
\(C\) is the number of correct words,
\(N\) is the number of words in the reference (\(N=S+D+C\)).
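The formula can be turned into a small corpus-level computation. The sketch below is a hypothetical illustration, not the torchmetrics source: it tokenizes on whitespace, takes the word-level Levenshtein distance as S + D + I, and assumes the denominator N + I can be taken as the length of the longer of the two word sequences, which equals N + I whenever the alignment contains no deletions or no insertions. Under these assumptions it gives 4/9 ≈ 0.4444 on the example inputs used in this section.

```python
from typing import List, Sequence


def word_edit_distance(a: Sequence[str], b: Sequence[str]) -> int:
    """Unit-cost Levenshtein distance over word tokens (S + D + I)."""
    prev = list(range(len(b) + 1))
    for i, x in enumerate(a, start=1):
        curr = [i]
        for j, y in enumerate(b, start=1):
            curr.append(min(prev[j] + 1, curr[j - 1] + 1, prev[j - 1] + (x != y)))
        prev = curr
    return prev[-1]


def match_error_rate_sketch(preds: List[str], target: List[str]) -> float:
    errors, total = 0, 0
    for pred, ref in zip(preds, target):
        p_words, r_words = pred.split(), ref.split()
        errors += word_edit_distance(p_words, r_words)  # S + D + I
        total += max(len(p_words), len(r_words))         # assumed stand-in for N + I
    return errors / total
```

For example, the second sentence pair needs two substitutions (“an” to “another”, “other” to “one”) and one insertion (“sample”), giving 3 errors over a denominator of 5.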
As input to forward and update the metric accepts the following input:
preds (List): Transcription(s) to score as a string or list of strings
target (List): Reference(s) for each speech input as a string or list of strings
As output of forward and compute the metric returns the following output:
mer (Tensor): A tensor with the match error rate
- Parameters:
kwargs (Any) – Additional keyword arguments, see Advanced metric settings for more info.
Examples
>>> from torchmetrics.text import MatchErrorRate
>>> preds = ["this is the prediction", "there is an other sample"]
>>> target = ["this is the reference", "there is another one"]
>>> mer = MatchErrorRate()
>>> mer(preds, target)
tensor(0.4444)
- plot(val=None, ax=None)
Plot a single or multiple values from the metric.
- Parameters:
val (
Union
[Tensor
,Sequence
[Tensor
],None
]) – Either a single result from calling metric.forward or metric.compute or a list of these results. If no value is provided, will automatically call metric.compute and plot that result.ax (
Optional
[Axes
]) – An matplotlib axis object. If provided will add plot to that axis
- Return type:
- Returns:
Figure and Axes object
- Raises:
ModuleNotFoundError – If matplotlib is not installed
>>> # Example plotting a single value
>>> from torchmetrics.text import MatchErrorRate
>>> metric = MatchErrorRate()
>>> preds = ["this is the prediction", "there is an other sample"]
>>> target = ["this is the reference", "there is another one"]
>>> metric.update(preds, target)
>>> fig_, ax_ = metric.plot()
>>> # Example plotting multiple values
>>> from torchmetrics.text import MatchErrorRate
>>> metric = MatchErrorRate()
>>> preds = ["this is the prediction", "there is an other sample"]
>>> target = ["this is the reference", "there is another one"]
>>> values = []
>>> for _ in range(10):
...     values.append(metric(preds, target))
>>> fig_, ax_ = metric.plot(values)
Functional Interface
- torchmetrics.functional.text.match_error_rate(preds, target)
Match error rate is a metric of the performance of an automatic speech recognition system.
This value indicates the percentage of words that were incorrectly predicted and inserted. The lower the value, the better the performance of the ASR system, with a MatchErrorRate of 0 being a perfect score.
- Parameters:
preds (Union[str, List[str]]) – Transcription(s) to score as a string or list of strings.
target (Union[str, List[str]]) – Reference(s) for each speech input as a string or list of strings.
- Return type:
Tensor
- Returns:
Match error rate score
Examples
>>> from torchmetrics.functional.text import match_error_rate
>>> preds = ["this is the prediction", "there is an other sample"]
>>> target = ["this is the reference", "there is another one"]
>>> match_error_rate(preds=preds, target=target)
tensor(0.4444)
Perplexity
Module Interface
- class torchmetrics.text.perplexity.Perplexity(ignore_index=None, **kwargs)
Perplexity measures how well a language model predicts a text sample.
It is computed as the exponential of the average negative log-likelihood per token; equivalently, its base-2 logarithm is the average number of bits per token the model needs to represent the sample.
As input to forward and update the metric accepts the following input:
preds (Tensor): Logits or unnormalized scores assigned to each token in a sequence, with shape [batch_size, seq_len, vocab_size], as output by a language model. Scores are normalized internally using softmax.
target (Tensor): Ground truth values with shape [batch_size, seq_len]
As output of forward and compute the metric returns the following output:
perp (Tensor): A tensor with the perplexity score
- Parameters:
ignore_index (Optional[int]) – Integer specifying a target class to ignore. If given, this class index does not contribute to the returned score.
kwargs (Dict[str, Any]) – Additional keyword arguments, see Advanced metric settings for more info.
Examples
>>> from torchmetrics.text import Perplexity
>>> import torch
>>> gen = torch.manual_seed(42)
>>> preds = torch.rand(2, 8, 5, generator=gen)
>>> target = torch.randint(5, (2, 8), generator=gen)
>>> target[0, 6:] = -100
>>> perp = Perplexity(ignore_index=-100)
>>> perp(preds, target)
tensor(5.8540)
- plot(val=None, ax=None)
Plot a single or multiple values from the metric.
- Parameters:
val (
Union
[Tensor
,Sequence
[Tensor
],None
]) – Either a single result from calling metric.forward or metric.compute or a list of these results. If no value is provided, will automatically call metric.compute and plot that result.ax (
Optional
[Axes
]) – An matplotlib axis object. If provided will add plot to that axis
- Return type:
- Returns:
Figure and Axes object
- Raises:
ModuleNotFoundError – If matplotlib is not installed
>>> # Example plotting a single value
>>> import torch
>>> from torchmetrics.text import Perplexity
>>> metric = Perplexity()
>>> metric.update(torch.rand(2, 8, 5), torch.randint(5, (2, 8)))
>>> fig_, ax_ = metric.plot()
>>> # Example plotting multiple values
>>> import torch
>>> from torchmetrics.text import Perplexity
>>> metric = Perplexity()
>>> values = []
>>> for _ in range(10):
...     values.append(metric(torch.rand(2, 8, 5), torch.randint(5, (2, 8))))
>>> fig_, ax_ = metric.plot(values)
Functional Interface
- torchmetrics.functional.text.perplexity.perplexity(preds, target, ignore_index=None)
Perplexity measures how well a language model predicts a text sample.
This metric is computed as the exponential of the average negative log-likelihood per token; equivalently, its base-2 logarithm is the average number of bits per token the model needs to represent the sample.
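As a sanity check on the definition, perplexity can be computed by hand as the exponential of the mean negative log-likelihood of the target tokens under a softmax over the logits. The sketch below is a plain-Python illustration with a hypothetical function name, using nested lists in place of tensors; it mirrors the documented shapes and ignore_index behaviour but is not the library code.

```python
import math
from typing import List, Optional


def perplexity_sketch(
    logits: List[List[List[float]]],  # [batch_size, seq_len, vocab_size]
    target: List[List[int]],          # [batch_size, seq_len]
    ignore_index: Optional[int] = None,
) -> float:
    total_nll, count = 0.0, 0
    for seq_logits, seq_target in zip(logits, target):
        for token_logits, label in zip(seq_logits, seq_target):
            if ignore_index is not None and label == ignore_index:
                continue  # masked positions do not contribute
            # log-softmax via the log-sum-exp trick for numerical stability
            m = max(token_logits)
            log_z = m + math.log(sum(math.exp(v - m) for v in token_logits))
            total_nll += log_z - token_logits[label]  # -log p(label)
            count += 1
    return math.exp(total_nll / count)
```

A model that assigns uniform scores over a vocabulary of size V has a perplexity of exactly V, while a model that is certain of every correct token approaches 1.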
- Parameters:
preds (Tensor) – Logits or unnormalized scores assigned to each token in a sequence, with shape [batch_size, seq_len, vocab_size], as output by a language model. Scores are normalized internally using softmax.
target (Tensor) – Ground truth values with shape [batch_size, seq_len].
ignore_index (Optional[int]) – Integer specifying a target class to ignore. If given, this class index does not contribute to the returned score.
- Return type:
- Returns:
Perplexity value
Examples
>>> import torch
>>> from torchmetrics.functional.text import perplexity
>>> gen = torch.manual_seed(42)
>>> preds = torch.rand(2, 8, 5, generator=gen)
>>> target = torch.randint(5, (2, 8), generator=gen)
>>> target[0, 6:] = -100
>>> perplexity(preds, target, ignore_index=-100)
tensor(5.8540)
ROUGE Score
Module Interface
- class torchmetrics.text.rouge.ROUGEScore(use_stemmer=False, normalizer=None, tokenizer=None, accumulate='best', rouge_keys=('rouge1', 'rouge2', 'rougeL', 'rougeLsum'), **kwargs)
Calculate ROUGE score, used for automatic summarization.
This implementation should imitate the behaviour of the rouge-score package (Python ROUGE Implementation).
As input to forward and update the metric accepts the following input:
preds (Sequence): An iterable of predicted sentences or a single predicted sentence
target (Sequence): An iterable of target sentences, an iterable of iterables of target sentences, or a single target sentence
As output of forward and compute the metric returns the following output:
rouge (Dict): A dictionary of tensor rouge scores for each input str rouge key
- Parameters:
use_stemmer (bool) – Use Porter stemmer to strip word suffixes to improve matching.
normalizer (Optional[Callable[[str], str]]) – A user’s own normalizer function. If this is None, replacing any non-alphanumeric characters with spaces is the default. This function must take a str and return a str.
tokenizer (Optional[Callable[[str], Sequence[str]]]) – A user’s own tokenizer function. If this is None, splitting by spaces is the default. This function must take a str and return Sequence[str].
accumulate (Literal['avg', 'best']) – Useful in case of multi-reference rouge score:
avg takes the average of all references with respect to predictions
best takes the best fmeasure score obtained between prediction and multiple corresponding references
rouge_keys (Union[str, Tuple[str, ...]]) – A list of rouge types to calculate. Keys that are allowed are rougeL, rougeLsum, and rouge1 through rouge9.
kwargs (Any) – Additional keyword arguments, see Advanced metric settings for more info.
Example
>>> from torchmetrics.text.rouge import ROUGEScore
>>> preds = "My name is John"
>>> target = "Is your name John"
>>> rouge = ROUGEScore()
>>> from pprint import pprint
>>> pprint(rouge(preds, target))
{'rouge1_fmeasure': tensor(0.7500),
 'rouge1_precision': tensor(0.7500),
 'rouge1_recall': tensor(0.7500),
 'rouge2_fmeasure': tensor(0.),
 'rouge2_precision': tensor(0.),
 'rouge2_recall': tensor(0.),
 'rougeL_fmeasure': tensor(0.5000),
 'rougeL_precision': tensor(0.5000),
 'rougeL_recall': tensor(0.5000),
 'rougeLsum_fmeasure': tensor(0.5000),
 'rougeLsum_precision': tensor(0.5000),
 'rougeLsum_recall': tensor(0.5000)}
- Raises:
ValueError – If the Python package nltk is not installed.
ValueError – If any of the rouge_keys does not belong to the allowed set of keys.
- plot(val=None, ax=None)
Plot a single or multiple values from the metric.
- Parameters:
val (
Union
[Tensor
,Sequence
[Tensor
],None
]) – Either a single result from calling metric.forward or metric.compute or a list of these results. If no value is provided, will automatically call metric.compute and plot that result.ax (
Optional
[Axes
]) – An matplotlib axis object. If provided will add plot to that axis
- Return type:
- Returns:
Figure and Axes object
- Raises:
ModuleNotFoundError – If matplotlib is not installed
>>> # Example plotting a single value
>>> from torchmetrics.text.rouge import ROUGEScore
>>> metric = ROUGEScore()
>>> preds = "My name is John"
>>> target = "Is your name John"
>>> metric.update(preds, target)
>>> fig_, ax_ = metric.plot()
>>> # Example plotting multiple values
>>> from torchmetrics.text.rouge import ROUGEScore
>>> metric = ROUGEScore()
>>> preds = "My name is John"
>>> target = "Is your name John"
>>> values = []
>>> for _ in range(10):
...     values.append(metric(preds, target))
>>> fig_, ax_ = metric.plot(values)
Functional Interface¶
- torchmetrics.functional.text.rouge.rouge_score(preds, target, accumulate='best', use_stemmer=False, normalizer=None, tokenizer=None, rouge_keys=('rouge1', 'rouge2', 'rougeL', 'rougeLsum'))[source]¶
Calculate the Rouge Score, used for automatic summarization.
- Parameters:
preds (Union[str, Sequence[str]]) – An iterable of predicted sentences or a single predicted sentence.
target (Union[str, Sequence[str], Sequence[Sequence[str]]]) – An iterable of iterables of target sentences, an iterable of target sentences, or a single target sentence.
accumulate (Literal['avg', 'best']) – Useful in case of multi-reference rouge score.
avg takes the average of all references with respect to predictions.
best takes the best fmeasure score obtained between prediction and multiple corresponding references.
use_stemmer (bool) – Use Porter stemmer to strip word suffixes to improve matching.
normalizer (Optional[Callable[[str], str]]) – A user’s own normalizer function. If this is None, replacing any non-alpha-numeric characters with spaces is the default. This function must take a str and return a str.
tokenizer (Optional[Callable[[str], Sequence[str]]]) – A user’s own tokenizer function. If this is None, splitting by spaces is the default. This function must take a str and return Sequence[str].
rouge_keys (Union[str, Tuple[str, ...]]) – A list of rouge types to calculate. Keys that are allowed are rougeL, rougeLsum, and rouge1 through rouge9.
- Returns:
Python dictionary of rouge scores for each input rouge key.
Example
>>> from torchmetrics.functional.text.rouge import rouge_score
>>> preds = "My name is John"
>>> target = "Is your name John"
>>> from pprint import pprint
>>> pprint(rouge_score(preds, target))
{'rouge1_fmeasure': tensor(0.7500),
 'rouge1_precision': tensor(0.7500),
 'rouge1_recall': tensor(0.7500),
 'rouge2_fmeasure': tensor(0.),
 'rouge2_precision': tensor(0.),
 'rouge2_recall': tensor(0.),
 'rougeL_fmeasure': tensor(0.5000),
 'rougeL_precision': tensor(0.5000),
 'rougeL_recall': tensor(0.5000),
 'rougeLsum_fmeasure': tensor(0.5000),
 'rougeLsum_precision': tensor(0.5000),
 'rougeLsum_recall': tensor(0.5000)}
- Raises:
ModuleNotFoundError – If the Python package nltk is not installed.
ValueError – If any of the rouge_keys does not belong to the allowed set of keys.
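For multi-reference inputs, the accumulate parameter decides how the per-reference scores of one prediction are merged. The sketch below assumes one score dict per reference, each holding precision, recall and fmeasure, and implements the two strategies as the parameter description states them. `accumulate_scores` is a hypothetical helper, and the exact internal aggregation in TorchMetrics may differ.

```python
from statistics import mean
from typing import Dict, List


def accumulate_scores(per_reference: List[Dict[str, float]], accumulate: str = "best") -> Dict[str, float]:
    """Merge one prediction's scores against several references (hypothetical helper)."""
    if accumulate == "best":
        # Keep the score dict of the reference with the highest F-measure
        return max(per_reference, key=lambda scores: scores["fmeasure"])
    if accumulate == "avg":
        # Average every field over all references
        return {key: mean(scores[key] for scores in per_reference) for key in per_reference[0]}
    raise ValueError(f"Unknown accumulate option: {accumulate!r}")


scores = [
    {"precision": 0.75, "recall": 0.75, "fmeasure": 0.75},
    {"precision": 0.25, "recall": 0.50, "fmeasure": 1 / 3},
]
print(accumulate_scores(scores, "best"))  # the first dict
print(accumulate_scores(scores, "avg"))   # precision 0.5, recall 0.625
```

The "best" strategy rewards a prediction that closely matches any single reference. The "avg" strategy penalizes predictions that match only one of several references.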
References
[1] ROUGE: A Package for Automatic Evaluation of Summaries by Chin-Yew Lin. https://aclanthology.org/W04-1013/
Sacre BLEU Score¶
Module Interface¶
- class torchmetrics.text.SacreBLEUScore(n_gram=4, smooth=False, tokenize='13a', lowercase=False, weights=None, **kwargs)[source]¶
Calculate BLEU score of machine translated text with one or more references.
This implementation follows the behaviour of SacreBLEU. The SacreBLEU implementation differs from the NLTK BLEU implementation in tokenization techniques.
As input to forward and update the metric accepts the following input:
preds (Sequence): An iterable of machine translated corpus
target (Sequence): An iterable of iterables of reference corpus
As output of forward and compute the metric returns the following output:
sacre_bleu (Tensor): A tensor with the SacreBLEU Score
- Parameters:
n_gram (int) – Gram value ranged from 1 to 4.
smooth (bool) – Whether to apply smoothing.
tokenize (Literal['none', '13a', 'zh', 'intl', 'char']) – Tokenization technique to be used. Supported tokenization: ['none', '13a', 'zh', 'intl', 'char']
lowercase (bool) – If True, BLEU score over lowercased text is calculated.
weights (Optional[Sequence[float]]) – Weights used for unigrams, bigrams, etc. to calculate BLEU score. If not provided, uniform weights are used.
kwargs (Any) – Additional keyword arguments, see Advanced metric settings for more info.
- Raises:
ValueError – If tokenize is not one of ‘none’, ‘13a’, ‘zh’, ‘intl’ or ‘char’.
ValueError – If tokenize is set to ‘intl’ and regex is not installed.
ValueError – If weights is not None and its length is not equal to n_gram.
Example
>>> from torchmetrics.text import SacreBLEUScore
>>> preds = ['the cat is on the mat']
>>> target = [['there is a cat on the mat', 'a cat is on the mat']]
>>> sacre_bleu = SacreBLEUScore()
>>> sacre_bleu(preds, target)
tensor(0.7598)
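BLEU is built on clipped (modified) n-gram precision: each predicted n-gram is credited at most as often as it appears in any single reference. The sketch below computes these precisions for the example above. It uses plain whitespace tokenization, which agrees with the '13a' tokenizer here only because the sentences contain no punctuation. `ngrams` and `modified_precisions` are hypothetical helpers, not TorchMetrics functions.

```python
from collections import Counter
from fractions import Fraction
from typing import List


def ngrams(tokens: List[str], n: int) -> Counter:
    # Count all contiguous n-grams of length n
    return Counter(tuple(tokens[i:i + n]) for i in range(len(tokens) - n + 1))


def modified_precisions(pred: str, refs: List[str], max_n: int = 4) -> List[Fraction]:
    pred_tok = pred.split()
    ref_toks = [ref.split() for ref in refs]
    precisions = []
    for n in range(1, max_n + 1):
        pred_counts = ngrams(pred_tok, n)
        # Counter union keeps the maximum count of each n-gram over all references
        max_ref_counts = Counter()
        for ref_tok in ref_toks:
            max_ref_counts |= ngrams(ref_tok, n)
        # Counter intersection clips predicted counts to those maxima
        clipped = sum((pred_counts & max_ref_counts).values())
        precisions.append(Fraction(clipped, max(sum(pred_counts.values()), 1)))
    return precisions


preds = "the cat is on the mat"
refs = ["there is a cat on the mat", "a cat is on the mat"]
print(modified_precisions(preds, refs))
# [Fraction(5, 6), Fraction(4, 5), Fraction(3, 4), Fraction(2, 3)]
```

"the" appears twice in the prediction but at most once in any reference ("there" is a different token). The unigram precision is therefore 5/6 rather than 1.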
Additional References:
Automatic Evaluation of Machine Translation Quality Using Longest Common Subsequence and Skip-Bigram Statistics by Chin-Yew Lin and Franz Josef Och Machine Translation Evolution
- plot(val=None, ax=None)[source]¶
Plot a single or multiple values from the metric.
- Parameters:
val (Union[Tensor, Sequence[Tensor], None]) – Either a single result from calling metric.forward or metric.compute or a list of these results. If no value is provided, metric.compute is called automatically and that result is plotted.
ax (Optional[Axes]) – A matplotlib axis object. If provided, the plot is added to that axis.
- Returns:
Figure and Axes object
- Raises:
ModuleNotFoundError – If matplotlib is not installed
>>> # Example plotting a single value
>>> from torchmetrics.text import SacreBLEUScore
>>> metric = SacreBLEUScore()
>>> preds = ['the cat is on the mat']
>>> target = [['there is a cat on the mat', 'a cat is on the mat']]
>>> metric.update(preds, target)
>>> fig_, ax_ = metric.plot()
>>> # Example plotting multiple values
>>> from torchmetrics.text import SacreBLEUScore
>>> metric = SacreBLEUScore()
>>> preds = ['the cat is on the mat']
>>> target = [['there is a cat on the mat', 'a cat is on the mat']]
>>> values = []
>>> for _ in range(10):
...     values.append(metric(preds, target))
>>> fig_, ax_ = metric.plot(values)
Functional Interface¶
- torchmetrics.functional.text.sacre_bleu_score(preds, target, n_gram=4, smooth=False, tokenize='13a', lowercase=False, weights=None)[source]¶
Calculate BLEU score [1] of machine translated text with one or more references.
This implementation follows the behaviour of SacreBLEU [2] implementation from https://github.com/mjpost/sacrebleu.
- Parameters:
preds (Sequence[str]) – An iterable of machine translated corpus.
target (Sequence[Sequence[str]]) – An iterable of iterables of reference corpus.
n_gram (int) – Gram value ranged from 1 to 4.
smooth (bool) – Whether to apply smoothing - see [2].
tokenize (Literal['none', '13a', 'zh', 'intl', 'char']) – Tokenization technique to be used. Supported tokenization: [‘none’, ‘13a’, ‘zh’, ‘intl’, ‘char’]
lowercase (bool) – If True, BLEU score over lowercased text is calculated.
weights (Optional[Sequence[float]]) – Weights used for unigrams, bigrams, etc. to calculate BLEU score. If not provided, uniform weights are used.
- Returns:
Tensor with BLEU Score
- Raises:
ValueError – If preds and target corpus have different lengths.
ValueError – If weights is not None and its length is not equal to n_gram.
Example
>>> from torchmetrics.functional.text import sacre_bleu_score
>>> preds = ['the cat is on the mat']
>>> target = [['there is a cat on the mat', 'a cat is on the mat']]
>>> sacre_bleu_score(preds, target)
tensor(0.7598)
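The final score combines the clipped n-gram precisions as a weighted geometric mean and multiplies it by a brevity penalty. In the example above, the clipped precisions are 5/6, 4/5, 3/4 and 2/3. The prediction has 6 tokens and the closest reference also has 6, so the penalty is 1. The sketch below assumes precomputed precisions, no smoothing, and a brevity penalty of exp(1 - r/c). It mirrors the weights-length check listed under Raises. `combine_bleu` is a hypothetical helper, not part of TorchMetrics.

```python
import math
from typing import Optional, Sequence


def combine_bleu(precisions: Sequence[float], pred_len: int, ref_len: int,
                 weights: Optional[Sequence[float]] = None) -> float:
    """Weighted geometric mean of n-gram precisions times a brevity penalty (no smoothing)."""
    n_gram = len(precisions)
    if weights is None:
        weights = [1.0 / n_gram] * n_gram  # uniform weights when none are given
    if len(weights) != n_gram:
        raise ValueError(f"Expected {n_gram} weights, got {len(weights)}")
    if pred_len == 0 or min(precisions) == 0:
        return 0.0  # without smoothing, a zero precision makes the score zero
    log_mean = sum(w * math.log(p) for w, p in zip(weights, precisions))
    # Brevity penalty: 1 if the prediction is longer than the reference, else exp(1 - r/c)
    brevity_penalty = 1.0 if pred_len > ref_len else math.exp(1 - ref_len / pred_len)
    return brevity_penalty * math.exp(log_mean)


precisions = [5 / 6, 4 / 5, 3 / 4, 2 / 3]
print(round(combine_bleu(precisions, pred_len=6, ref_len=6), 4))  # 0.7598
print(round(combine_bleu(precisions[:2], pred_len=6, ref_len=6, weights=[0.5, 0.5]), 4))  # 0.8165
```

With uniform weights, the four precisions multiply to 1/3, and (1/3)^(1/4) is about 0.7598, consistent with the tensor(0.7598) returned above. The second call shows how a custom weights sequence of length n_gram changes the result.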
References
[1] BLEU: a Method for Automatic Evaluation of Machine Translation by Papineni, Kishore, Salim Roukos, Todd Ward, and Wei-Jing Zhu BLEU
[2] A Call for Clarity in Reporting BLEU Scores by Matt Post.
[3] Automatic Evaluation of Machine Translation Quality Using Longest Common Subsequence and Skip-Bigram Statistics by Chin-Yew Lin and Franz Josef Och Machine Translation Evolution
SQuAD¶
Module Interface¶
- class torchmetrics.text.SQuAD(**kwargs)[source]¶
Calculate SQuAD Metric which is a metric for evaluating question answering models.
This metric corresponds to the scoring script for version 1 of the Stanford Question Answering Dataset (SQuAD).
As input to forward and update the metric accepts the following input:
preds (Dict): A dictionary or list of dictionaries that map id and prediction_text to the respective values.
Example prediction: {"prediction_text": "TorchMetrics is awesome", "id": "123"}
target (Dict): A dictionary or list of dictionaries that contain the answers and id in the SQuAD format.
Example target: {'answers': {'answer_start': [1], 'text': ['This is a test answer']}, 'id': '1'}
Reference SQuAD Format:
{ 'answers': {'answer_start': [1], 'text': ['This is a test text']}, 'context': 'This is a test context.', 'id': '1', 'question': 'Is this a test?', 'title': 'train test' }
As output of forward and compute the metric returns the following output:
squad (Dict): A dictionary containing the F1 score (key: “f1”) and the Exact match score (key: “exact_match”) for the batch.
- Parameters:
kwargs (Any) – Additional keyword arguments, see Advanced metric settings for more info.
Example
>>> from torchmetrics.text import SQuAD
>>> preds = [{"prediction_text": "1976", "id": "56e10a3be3433e1400422b22"}]
>>> target = [{"answers": {"answer_start": [97], "text": ["1976"]}, "id": "56e10a3be3433e1400422b22"}]
>>> squad = SQuAD()
>>> squad(preds, target)
{'exact_match': tensor(100.), 'f1': tensor(100.)}
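The SQuAD v1 scoring script normalizes answers before comparing them and then computes an exact-match score and a token-level F1 per answer. The sketch below assumes the normalization of the official v1.1 evaluation script: lowercase, strip punctuation, drop the articles a/an/the, and collapse whitespace. `normalize_answer`, `exact_match` and `f1` are local, hypothetical names.

```python
import re
import string
from collections import Counter

_PUNCTUATION = set(string.punctuation)


def normalize_answer(text: str) -> str:
    text = text.lower()
    text = "".join(ch for ch in text if ch not in _PUNCTUATION)  # strip punctuation
    text = re.sub(r"\b(a|an|the)\b", " ", text)                  # drop articles
    return " ".join(text.split())                                # collapse whitespace


def exact_match(prediction: str, ground_truth: str) -> float:
    return float(normalize_answer(prediction) == normalize_answer(ground_truth))


def f1(prediction: str, ground_truth: str) -> float:
    pred_tokens = normalize_answer(prediction).split()
    truth_tokens = normalize_answer(ground_truth).split()
    # Multiset overlap of tokens between prediction and ground truth
    num_same = sum((Counter(pred_tokens) & Counter(truth_tokens)).values())
    if num_same == 0:
        return 0.0
    precision = num_same / len(pred_tokens)
    recall = num_same / len(truth_tokens)
    return 2 * precision * recall / (precision + recall)


print(exact_match("1976", "1976"), f1("1976", "1976"))  # 1.0 1.0
print(exact_match("The cat sat", "cat sat down"), round(f1("The cat sat", "cat sat down"), 4))  # 0.0 0.8
```

The metric reports these values as percentages, which is why the example above shows tensor(100.) rather than 1.0.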
- plot(val=None, ax=None)[source]¶
Plot a single or multiple values from the metric.
- Parameters:
val (Union[Tensor, Sequence[Tensor], None]) – Either a single result from calling metric.forward or metric.compute or a list of these results. If no value is provided, metric.compute is called automatically and that result is plotted.
ax (Optional[Axes]) – A matplotlib axis object. If provided, the plot is added to that axis.
- Returns:
Figure and Axes object
- Raises:
ModuleNotFoundError – If matplotlib is not installed
>>> # Example plotting a single value
>>> from torchmetrics.text import SQuAD
>>> metric = SQuAD()
>>> preds = [{"prediction_text": "1976", "id": "56e10a3be3433e1400422b22"}]
>>> target = [{"answers": {"answer_start": [97], "text": ["1976"]}, "id": "56e10a3be3433e1400422b22"}]
>>> metric.update(preds, target)
>>> fig_, ax_ = metric.plot()
>>> # Example plotting multiple values
>>> from torchmetrics.text import SQuAD
>>> metric = SQuAD()
>>> preds = [{"prediction_text": "1976", "id": "56e10a3be3433e1400422b22"}]
>>> target = [{"answers": {"answer_start": [97], "text": ["1976"]}, "id": "56e10a3be3433e1400422b22"}]
>>> values = []
>>> for _ in range(10):
...     values.append(metric(preds, target))
>>> fig_, ax_ = metric.plot(values)
Functional Interface¶
- torchmetrics.functional.text.squad(preds, target)[source]¶
Calculate the SQuAD metric.
- Parameters:
preds (Union[Dict[str, str], List[Dict[str, str]]]) – A dictionary or list of dictionaries that map id and prediction_text to the respective values.
Example prediction:
{"prediction_text": "TorchMetrics is awesome", "id": "123"}
target (Union[Dict[str, Union[str, Dict[str, Union[List[str], List[int]]]]], List[Dict[str, Union[str, Dict[str, Union[List[str], List[int]]]]]]]) – A dictionary or list of dictionaries that contain the answers and id in the SQuAD format.
Example target:
{'answers': {'answer_start': [1], 'text': ['This is a test answer']}, 'id': '1'}
Reference SQuAD Format:
{ 'answers': {'answer_start': [1], 'text': ['This is a test text']}, 'context': 'This is a test context.', 'id': '1', 'question': 'Is this a test?', 'title': 'train test' }
- Returns:
Dictionary containing the F1 score and the Exact match score for the batch.
Example
>>> from torchmetrics.functional.text.squad import squad
>>> preds = [{"prediction_text": "1976", "id": "56e10a3be3433e1400422b22"}]
>>> target = [{"answers": {"answer_start": [97], "text": ["1976"]}, "id": "56e10a3be3433e1400422b22"}]
>>> squad(preds, target)
{'exact_match': tensor(100.), 'f1': tensor(100.)}
- Raises:
KeyError – If the required keys are missing in either predictions or targets.
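Predictions and targets are matched through their id field. The sketch below shows one way such inputs could be joined and aggregated. Each prediction is scored against every gold answer with the same id, the best score is kept, and the batch mean is scaled to a percentage. For brevity it uses a deliberately simplified normalizer (lowercase and whitespace collapse only) and scores exact match only. `squad_exact_match` is a hypothetical helper, not the library's implementation.

```python
from typing import Dict, List


def _simple_norm(text: str) -> str:
    # Simplified normalizer for illustration: lowercase and collapse whitespace only
    return " ".join(text.lower().split())


def squad_exact_match(preds: List[Dict[str, str]], target: List[Dict]) -> float:
    # Index gold answer texts by question id
    answers_by_id = {item["id"]: item["answers"]["text"] for item in target}
    scores = []
    for pred in preds:
        gold_answers = answers_by_id[pred["id"]]  # raises KeyError for an unknown id
        # Keep the best match over all gold answers of this question
        scores.append(max(float(_simple_norm(pred["prediction_text"]) == _simple_norm(gold))
                          for gold in gold_answers))
    # Average over questions and report as a percentage
    return 100.0 * sum(scores) / len(scores)


preds = [{"prediction_text": "1976", "id": "56e10a3be3433e1400422b22"}]
target = [{"answers": {"answer_start": [97], "text": ["1976"]}, "id": "56e10a3be3433e1400422b22"}]
print(squad_exact_match(preds, target))  # 100.0
```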
References
[1] SQuAD: 100,000+ Questions for Machine Comprehension of Text by Pranav Rajpurkar, Jian Zhang, Konstantin Lopyrev, Percy Liang.
Translation Edit Rate (TER)¶
Module Interface¶
- class torchmetrics.text.TranslationEditRate(normalize=False, no_punctuation=False, lowercase=True, asian_support=False, return_sentence_level_score=False, **kwargs)[source]¶
Calculate Translation edit rate (TER) of machine translated text with one or more references.
This implementation follows the one from SacreBleu_ter, which is a near-exact reimplementation of the Tercom algorithm and produces identical results on all “sane” outputs.
As input to forward and update the metric accepts the following input:
preds (Sequence): An iterable of hypothesis corpus
target (Sequence): An iterable of iterables of reference corpus
As output of forward and compute the metric returns the following output:
ter (Tensor): If return_sentence_level_score=True, a corpus-level translation edit rate together with a list of sentence-level translation edit rates; otherwise only the corpus-level translation edit rate.
- Parameters:
normalize (bool) – Whether to apply general tokenization.
no_punctuation (bool) – Whether to remove punctuation from the sentences.
lowercase (bool) – Whether to enable case-insensitivity.
asian_support (bool) – Whether to process Asian characters.
return_sentence_level_score (bool) – Whether to also return sentence-level TER.
kwargs (Any) – Additional keyword arguments, see Advanced metric settings for more info.
Example
>>> from torchmetrics.text import TranslationEditRate
>>> preds = ['the cat is on the mat']
>>> target = [['there is a cat on the mat', 'a cat is on the mat']]
>>> ter = TranslationEditRate()
>>> ter(preds, target)
tensor(0.1538)
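TER divides the fewest edits needed to turn the hypothesis into one of the references by the average reference length. The sketch below is a simplification. It counts only insertions, deletions and substitutions (word-level Levenshtein distance) and omits Tercom's block-shift operation, so in general it can report a higher TER than the real metric. It also assumes lowercasing and whitespace tokenization. `word_edit_distance` and `simple_ter` are hypothetical names.

```python
from typing import List


def word_edit_distance(hyp: List[str], ref: List[str]) -> int:
    # Levenshtein distance over words; each insertion, deletion or substitution costs 1
    prev = list(range(len(ref) + 1))
    for i, h in enumerate(hyp, 1):
        curr = [i] + [0] * len(ref)
        for j, r in enumerate(ref, 1):
            curr[j] = min(prev[j] + 1, curr[j - 1] + 1, prev[j - 1] + (h != r))
        prev = curr
    return prev[-1]


def simple_ter(preds: List[str], targets: List[List[str]]) -> float:
    total_edits, total_ref_len = 0.0, 0.0
    for pred, refs in zip(preds, targets):
        hyp = pred.lower().split()
        ref_toks = [ref.lower().split() for ref in refs]
        # Fewest edits against any single reference ...
        total_edits += min(word_edit_distance(hyp, ref) for ref in ref_toks)
        # ... divided by the average reference length, summed over the corpus
        total_ref_len += sum(len(ref) for ref in ref_toks) / len(ref_toks)
    return total_edits / total_ref_len


preds = ["the cat is on the mat"]
target = [["there is a cat on the mat", "a cat is on the mat"]]
print(round(simple_ter(preds, target), 4))  # 0.1538
```

Here the second reference needs a single substitution ("the" to "a"), and the average reference length is (7 + 6) / 2 = 6.5. This gives 1 / 6.5, about 0.1538, matching the example. No shift can improve on one edit in this case.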
- plot(val=None, ax=None)[source]¶
Plot a single or multiple values from the metric.
- Parameters:
val (Union[Tensor, Sequence[Tensor], None]) – Either a single result from calling metric.forward or metric.compute or a list of these results. If no value is provided, metric.compute is called automatically and that result is plotted.
ax (Optional[Axes]) – A matplotlib axis object. If provided, the plot is added to that axis.
- Returns:
Figure and Axes object
- Raises:
ModuleNotFoundError – If matplotlib is not installed
>>> # Example plotting a single value
>>> from torchmetrics.text import TranslationEditRate
>>> metric = TranslationEditRate()
>>> preds = ['the cat is on the mat']
>>> target = [['there is a cat on the mat', 'a cat is on the mat']]
>>> metric.update(preds, target)
>>> fig_, ax_ = metric.plot()
>>> # Example plotting multiple values
>>> from torchmetrics.text import TranslationEditRate
>>> metric = TranslationEditRate()
>>> preds = ['the cat is on the mat']
>>> target = [['there is a cat on the mat', 'a cat is on the mat']]
>>> values = []
>>> for _ in range(10):
...     values.append(metric(preds, target))
>>> fig_, ax_ = metric.plot(values)
Functional Interface¶
- torchmetrics.functional.text.translation_edit_rate(preds, target, normalize=False, no_punctuation=False, lowercase=True, asian_support=False, return_sentence_level_score=False)[source]¶
Calculate Translation edit rate (TER) of machine translated text with one or more references.
This implementation follows the implementation from https://github.com/mjpost/sacrebleu/blob/master/sacrebleu/metrics/ter.py. The sacrebleu implementation is a near-exact reimplementation of the Tercom algorithm and produces identical results on all “sane” outputs.
- Parameters:
preds (Union[str, Sequence[str]]) – An iterable of hypothesis corpus.
target (Sequence[Union[str, Sequence[str]]]) – An iterable of iterables of reference corpus.
normalize (bool) – Whether to apply general tokenization.
no_punctuation (bool) – Whether to remove punctuation from the sentences.
lowercase (bool) – Whether to enable case-insensitivity.
asian_support (bool) – Whether to process Asian characters.
return_sentence_level_score (bool) – Whether to also return sentence-level TER.
- Returns:
A corpus-level translation edit rate (TER). (Optionally) A list of sentence-level translation_edit_rate (TER) if return_sentence_level_score=True.
Example
>>> from torchmetrics.functional.text import translation_edit_rate
>>> preds = ['the cat is on the mat']
>>> target = [['there is a cat on the mat', 'a cat is on the mat']]
>>> translation_edit_rate(preds, target)
tensor(0.1538)
References
[1] A Study of Translation Edit Rate with Targeted Human Annotation by Mathew Snover, Bonnie Dorr, Richard Schwartz, Linnea Micciulla and John Makhoul TER
Word Error Rate¶
Module Interface¶
- class torchmetrics.text.WordErrorRate(**kwargs)[source]¶
Word error rate (WordErrorRate) is a common metric of the performance of an automatic speech recognition system.
This value indicates the percentage of words that were incorrectly predicted. The lower the value, the better the performance of the ASR system with a WER of 0 being a perfect score. Word error rate can then be computed as:
\[WER = \frac{S + D + I}{N} = \frac{S + D + I}{S + D + C}\]where:
- \(S\) is the number of substitutions,
- \(D\) is the number of deletions,
- \(I\) is the number of insertions,
- \(C\) is the number of correct words,
- \(N\) is the number of words in the reference (\(N=S+D+C\)).
Compute WER score of transcribed segments against references.
As input to forward and update the metric accepts the following input:
preds (List): Transcription(s) to score as a string or list of strings
target (List): Reference(s) for each speech input as a string or list of strings
As output of forward and compute the metric returns the following output:
wer (Tensor): A tensor with the Word Error Rate score
- Parameters:
kwargs (Any) – Additional keyword arguments, see Advanced metric settings for more info.
Examples
>>> from torchmetrics.text import WordErrorRate
>>> preds = ["this is the prediction", "there is an other sample"]
>>> target = ["this is the reference", "there is another one"]
>>> wer = WordErrorRate()
>>> wer(preds, target)
tensor(0.5000)
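In the WER formula, the numerator S + D + I is the word-level Levenshtein distance between prediction and reference. The sketch below assumes whitespace tokenization. It also assumes that errors and reference lengths are summed over the whole corpus before dividing. `word_edit_distance` and `wer` are hypothetical helpers, not the library's internals.

```python
from typing import List, Union


def word_edit_distance(hyp: List[str], ref: List[str]) -> int:
    # Levenshtein distance over words; each insertion, deletion or substitution costs 1
    prev = list(range(len(ref) + 1))
    for i, h in enumerate(hyp, 1):
        curr = [i] + [0] * len(ref)
        for j, r in enumerate(ref, 1):
            curr[j] = min(prev[j] + 1, curr[j - 1] + 1, prev[j - 1] + (h != r))
        prev = curr
    return prev[-1]


def wer(preds: Union[str, List[str]], target: Union[str, List[str]]) -> float:
    preds = [preds] if isinstance(preds, str) else preds
    target = [target] if isinstance(target, str) else target
    # S + D + I summed over all pairs, divided by the total reference length N
    errors = sum(word_edit_distance(p.split(), t.split()) for p, t in zip(preds, target))
    total = sum(len(t.split()) for t in target)
    return errors / total


preds = ["this is the prediction", "there is an other sample"]
target = ["this is the reference", "there is another one"]
print(wer(preds, target))  # 0.5
```

The first pair needs one substitution. The second needs two substitutions and one insertion. That gives 4 errors over 8 reference words, which matches tensor(0.5000).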
- plot(val=None, ax=None)[source]¶
Plot a single or multiple values from the metric.
- Parameters:
val (Union[Tensor, Sequence[Tensor], None]) – Either a single result from calling metric.forward or metric.compute or a list of these results. If no value is provided, metric.compute is called automatically and that result is plotted.
ax (Optional[Axes]) – A matplotlib axis object. If provided, the plot is added to that axis.
- Returns:
Figure and Axes object
- Raises:
ModuleNotFoundError – If matplotlib is not installed
>>> # Example plotting a single value
>>> from torchmetrics.text import WordErrorRate
>>> metric = WordErrorRate()
>>> preds = ["this is the prediction", "there is an other sample"]
>>> target = ["this is the reference", "there is another one"]
>>> metric.update(preds, target)
>>> fig_, ax_ = metric.plot()
>>> # Example plotting multiple values
>>> from torchmetrics.text import WordErrorRate
>>> metric = WordErrorRate()
>>> preds = ["this is the prediction", "there is an other sample"]
>>> target = ["this is the reference", "there is another one"]
>>> values = []
>>> for _ in range(10):
...     values.append(metric(preds, target))
>>> fig_, ax_ = metric.plot(values)
Functional Interface¶
- torchmetrics.functional.text.word_error_rate(preds, target)[source]¶
Word error rate (WordErrorRate) is a common metric of the performance of an automatic speech recognition system.
This value indicates the percentage of words that were incorrectly predicted. The lower the value, the better the performance of the ASR system with a WER of 0 being a perfect score.
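- Parameters:
preds – Transcription(s) to score as a string or list of strings.
target – Reference(s) for each speech input as a string or list of strings.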
- Returns:
Word error rate score
Examples
>>> from torchmetrics.functional.text import word_error_rate
>>> preds = ["this is the prediction", "there is an other sample"]
>>> target = ["this is the reference", "there is another one"]
>>> word_error_rate(preds=preds, target=target)
tensor(0.5000)
Word Info. Lost¶
Module Interface¶
- class torchmetrics.text.WordInfoLost(**kwargs)[source]¶
Word Information Lost (WIL) is a metric of the performance of an automatic speech recognition system.
This value indicates the percentage of words that were incorrectly predicted between a set of ground-truth sentences and a set of hypothesis sentences. The lower the value, the better the performance of the ASR system with a WordInfoLost of 0 being a perfect score. Word Information Lost rate can then be computed as:
\[wil = 1 - \frac{C}{N} \cdot \frac{C}{P}\]where:
\(C\) is the number of correct words,
\(N\) is the number of words in the reference,
\(P\) is the number of words in the prediction.
As input to forward and update the metric accepts the following input:
preds (List): Transcription(s) to score as a string or list of strings
target (List): Reference(s) for each speech input as a string or list of strings
As output of forward and compute the metric returns the following output:
wil (Tensor): A tensor with the Word Information Lost score
- Parameters:
kwargs (Any) – Additional keyword arguments, see Advanced metric settings for more info.
Examples
>>> from torchmetrics.text import WordInfoLost
>>> preds = ["this is the prediction", "there is an other sample"]
>>> target = ["this is the reference", "there is another one"]
>>> wil = WordInfoLost()
>>> wil(preds, target)
tensor(0.6528)
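WIL is 1 - (C/N)·(C/P). To compute C from raw strings, the sketch below assumes that the correct-word count per sentence pair is max(N_i, P_i) minus the word-level edit distance, summed over the corpus. This is an assumption. It equals the true hit count whenever the optimal alignment uses only insertions or only deletions besides substitutions, as in this example. `word_edit_distance` and `wil` are hypothetical helpers.

```python
from typing import List


def word_edit_distance(hyp: List[str], ref: List[str]) -> int:
    # Levenshtein distance over words; each insertion, deletion or substitution costs 1
    prev = list(range(len(ref) + 1))
    for i, h in enumerate(hyp, 1):
        curr = [i] + [0] * len(ref)
        for j, r in enumerate(ref, 1):
            curr[j] = min(prev[j] + 1, curr[j - 1] + 1, prev[j - 1] + (h != r))
        prev = curr
    return prev[-1]


def wil(preds: List[str], target: List[str]) -> float:
    hits, ref_total, pred_total = 0, 0, 0
    for pred, ref in zip(preds, target):
        p_tok, r_tok = pred.split(), ref.split()
        # Assumed hit count for this pair (see the note above)
        hits += max(len(p_tok), len(r_tok)) - word_edit_distance(p_tok, r_tok)
        ref_total += len(r_tok)
        pred_total += len(p_tok)
    return 1 - (hits / ref_total) * (hits / pred_total)


preds = ["this is the prediction", "there is an other sample"]
target = ["this is the reference", "there is another one"]
print(round(wil(preds, target), 4))  # 0.6528
```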
- plot(val=None, ax=None)[source]¶
Plot a single or multiple values from the metric.
- Parameters:
val (Union[Tensor, Sequence[Tensor], None]) – Either a single result from calling metric.forward or metric.compute or a list of these results. If no value is provided, metric.compute is called automatically and that result is plotted.
ax (Optional[Axes]) – A matplotlib axis object. If provided, the plot is added to that axis.
- Returns:
Figure and Axes object
- Raises:
ModuleNotFoundError – If matplotlib is not installed
>>> # Example plotting a single value
>>> from torchmetrics.text import WordInfoLost
>>> metric = WordInfoLost()
>>> preds = ["this is the prediction", "there is an other sample"]
>>> target = ["this is the reference", "there is another one"]
>>> metric.update(preds, target)
>>> fig_, ax_ = metric.plot()
>>> # Example plotting multiple values
>>> from torchmetrics.text import WordInfoLost
>>> metric = WordInfoLost()
>>> preds = ["this is the prediction", "there is an other sample"]
>>> target = ["this is the reference", "there is another one"]
>>> values = []
>>> for _ in range(10):
...     values.append(metric(preds, target))
>>> fig_, ax_ = metric.plot(values)
Functional Interface¶
- torchmetrics.functional.text.word_information_lost(preds, target)[source]¶
Word Information Lost rate is a metric of the performance of an automatic speech recognition system.
This value indicates the percentage of words that were incorrectly predicted. The lower the value, the better the performance of the ASR system with a Word Information Lost rate of 0 being a perfect score.
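- Parameters:
preds – Transcription(s) to score as a string or list of strings.
target – Reference(s) for each speech input as a string or list of strings.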
- Returns:
Word Information Lost rate
Examples
>>> from torchmetrics.functional.text import word_information_lost
>>> preds = ["this is the prediction", "there is an other sample"]
>>> target = ["this is the reference", "there is another one"]
>>> word_information_lost(preds, target)
tensor(0.6528)
Word Info. Preserved¶
Module Interface¶
- class torchmetrics.text.WordInfoPreserved(**kwargs)[source]¶
Word Information Preserved (WIP) is a metric of the performance of an automatic speech recognition system.
This value indicates the percentage of words that were correctly predicted between a set of ground-truth sentences and a set of hypothesis sentences. The higher the value, the better the performance of the ASR system with a WordInfoPreserved of 1 being a perfect score. Word Information Preserved rate can then be computed as:
\[wip = \frac{C}{N} \cdot \frac{C}{P}\]where:
\(C\) is the number of correct words,
\(N\) is the number of words in the reference,
\(P\) is the number of words in the prediction.
As input to forward and update the metric accepts the following input:
preds (List): Transcription(s) to score as a string or list of strings
target (List): Reference(s) for each speech input as a string or list of strings
As output of forward and compute the metric returns the following output:
wip (Tensor): A tensor with the Word Information Preserved score
- Parameters:
kwargs (Any) – Additional keyword arguments, see Advanced metric settings for more info.
Examples
>>> from torchmetrics.text import WordInfoPreserved
>>> preds = ["this is the prediction", "there is an other sample"]
>>> target = ["this is the reference", "there is another one"]
>>> wip = WordInfoPreserved()
>>> wip(preds, target)
tensor(0.3472)
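WIP is the product of the two ratios, (C/N)·(C/P), and WIL is its complement. As a worked check against the example values, assume a word alignment of each sentence pair that gives C = 3 + 2 = 5 correct words. There are also N = 4 + 4 = 8 reference words and P = 4 + 5 = 9 predicted words. The correct-word counts are an assumption obtained by aligning the pairs by hand.

```python
from fractions import Fraction

C, N, P = 5, 8, 9  # correct words, reference words, predicted words (summed over both pairs)
wip = Fraction(C, N) * Fraction(C, P)
wil = 1 - wip
print(wip, round(float(wip), 4))  # 25/72 0.3472
print(wil, round(float(wil), 4))  # 47/72 0.6528
```

Both rounded values agree with the WordInfoPreserved and WordInfoLost examples, which supports the product form of the formula.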
- plot(val=None, ax=None)[source]¶
Plot a single or multiple values from the metric.
- Parameters:
val (Union[Tensor, Sequence[Tensor], None]) – Either a single result from calling metric.forward or metric.compute or a list of these results. If no value is provided, metric.compute is called automatically and that result is plotted.
ax (Optional[Axes]) – A matplotlib axis object. If provided, the plot is added to that axis.
- Returns:
Figure and Axes object
- Raises:
ModuleNotFoundError – If matplotlib is not installed
>>> # Example plotting a single value
>>> from torchmetrics.text import WordInfoPreserved
>>> metric = WordInfoPreserved()
>>> preds = ["this is the prediction", "there is an other sample"]
>>> target = ["this is the reference", "there is another one"]
>>> metric.update(preds, target)
>>> fig_, ax_ = metric.plot()
>>> # Example plotting multiple values
>>> from torchmetrics.text import WordInfoPreserved
>>> metric = WordInfoPreserved()
>>> preds = ["this is the prediction", "there is an other sample"]
>>> target = ["this is the reference", "there is another one"]
>>> values = []
>>> for _ in range(10):
...     values.append(metric(preds, target))
>>> fig_, ax_ = metric.plot(values)
Functional Interface¶
- torchmetrics.functional.text.word_information_preserved(preds, target)[source]¶
Word Information Preserved rate is a metric of the performance of an automatic speech recognition system.
This value indicates the percentage of words that were correctly predicted. The higher the value, the better the performance of the ASR system with a Word Information Preserved rate of 1 being a perfect score.
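- Parameters:
preds – Transcription(s) to score as a string or list of strings.
target – Reference(s) for each speech input as a string or list of strings.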
- Returns:
Word Information preserved rate
Examples
>>> from torchmetrics.functional.text import word_information_preserved
>>> preds = ["this is the prediction", "there is an other sample"]
>>> target = ["this is the reference", "there is another one"]
>>> word_information_preserved(preds, target)
tensor(0.3472)
Bootstrapper¶
Module Interface¶
- class torchmetrics.wrappers.BootStrapper(base_metric, num_bootstraps=10, mean=True, std=True, quantile=None, raw=False, sampling_strategy='poisson', **kwargs)[source]¶
Turn a metric into a bootstrapped metric.
This automates the process of getting confidence intervals for metric values. The wrapper class keeps multiple copies of the same base metric in memory and, whenever update or forward is called, all input tensors are resampled (with replacement) along the first dimension.
- Parameters:
base_metric (Metric) – base metric class to wrap
num_bootstraps (int) – number of copies to make of the base metric for bootstrapping
mean (bool) – if True, return the mean of the bootstraps
std (bool) – if True, return the standard deviation of the bootstraps
quantile (Union[float, Tensor, None]) – if given, returns the quantile of the bootstraps. Can only be used with PyTorch version 1.6 or higher
raw (bool) – if True, return all bootstrapped values
sampling_strategy (str) – Determines how to produce bootstrapped samplings. Either 'poisson' or 'multinomial'. If 'poisson' is chosen, the number of times each sample will be included in the bootstrap will be given by \(n\sim Poisson(\lambda=1)\), which approximates the true bootstrap distribution when the number of samples is large. If 'multinomial' is chosen, we will apply true bootstrapping at the batch level to approximate bootstrapping over the whole dataset.
kwargs (Any) – Additional keyword arguments, see Advanced metric settings for more info.
- Example::
>>> from pprint import pprint
>>> import torch
>>> from torchmetrics.wrappers import BootStrapper
>>> from torchmetrics.classification import MulticlassAccuracy
>>> _ = torch.manual_seed(123)
>>> base_metric = MulticlassAccuracy(num_classes=5, average='micro')
>>> bootstrap = BootStrapper(base_metric, num_bootstraps=20)
>>> bootstrap.update(torch.randint(5, (20,)), torch.randint(5, (20,)))
>>> output = bootstrap.compute()
>>> pprint(output)
{'mean': tensor(0.2205), 'std': tensor(0.0859)}
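The two sampling_strategy options differ in how each bootstrap replica chooses its rows. The pure-Python sketch below illustrates both under stated assumptions. For 'poisson', each row gets an independent Poisson(1) repeat count, drawn here with Knuth's multiplication method. For 'multinomial', n row indices are drawn uniformly with replacement. The mean and sample standard deviation are then taken over the replica scores. `resample_indices`, `bootstrap_stats` and `accuracy` are hypothetical helpers, not the wrapper's internals.

```python
import math
import random
import statistics
from typing import Callable, Dict, List, Sequence


def poisson_one(rng: random.Random) -> int:
    # Knuth's algorithm for a single Poisson(lambda=1) draw
    threshold, k, p = math.exp(-1.0), 0, 1.0
    while True:
        p *= rng.random()
        if p <= threshold:
            return k
        k += 1


def resample_indices(n: int, strategy: str, rng: random.Random) -> List[int]:
    if strategy == "poisson":
        # Row i is repeated poisson_one() times, possibly zero times
        return [i for i in range(n) for _ in range(poisson_one(rng))]
    if strategy == "multinomial":
        # n draws with replacement, every row equally likely
        return rng.choices(range(n), k=n)
    raise ValueError(f"Unknown sampling strategy: {strategy!r}")


def bootstrap_stats(preds: Sequence, target: Sequence, metric_fn: Callable,
                    num_bootstraps: int = 20, strategy: str = "poisson", seed: int = 123) -> Dict[str, float]:
    rng = random.Random(seed)
    values = []
    for _ in range(num_bootstraps):
        idx = resample_indices(len(preds), strategy, rng)
        if not idx:
            continue  # skip an (unlikely) empty Poisson resample
        values.append(metric_fn([preds[i] for i in idx], [target[i] for i in idx]))
    return {"mean": statistics.mean(values), "std": statistics.stdev(values)}


def accuracy(p: Sequence, t: Sequence) -> float:
    return sum(a == b for a, b in zip(p, t)) / len(p)


data_rng = random.Random(0)
preds = [data_rng.randrange(5) for _ in range(20)]
target = [data_rng.randrange(5) for _ in range(20)]
print(bootstrap_stats(preds, target, accuracy))
```

Poisson resampling lets every replica treat each incoming row independently, which suits streaming batches. Multinomial resampling keeps each replica exactly as long as the batch.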
- compute()[source]¶
Compute the bootstrapped metric values.
Always returns a dict of tensors, which can contain the following keys: mean, std, quantile and raw, depending on how the class was initialized.
- forward(*args, **kwargs)[source]¶
Use the original forward method of the base metric class.
- plot(val=None, ax=None)[source]¶
Plot a single or multiple values from the metric.
- Parameters:
val (Union[Tensor, Sequence[Tensor], None]) – Either a single result from calling metric.forward or metric.compute or a list of these results. If no value is provided, metric.compute is called automatically and that result is plotted.
ax (Optional[Axes]) – A matplotlib axis object. If provided, the plot is added to that axis.
- Returns:
Figure and Axes object
- Raises:
ModuleNotFoundError – If matplotlib is not installed
>>> # Example plotting a single value
>>> import torch
>>> from torchmetrics.wrappers import BootStrapper
>>> from torchmetrics.regression import MeanSquaredError
>>> metric = BootStrapper(MeanSquaredError(), num_bootstraps=20)
>>> metric.update(torch.randn(100,), torch.randn(100,))
>>> fig_, ax_ = metric.plot()
>>> # Example plotting multiple values
>>> import torch
>>> from torchmetrics.wrappers import BootStrapper
>>> from torchmetrics.regression import MeanSquaredError
>>> metric = BootStrapper(MeanSquaredError(), num_bootstraps=20)
>>> values = []
>>> for _ in range(3):
...     values.append(metric(torch.randn(100,), torch.randn(100,)))
>>> fig_, ax_ = metric.plot(values)
Classwise Wrapper¶
Module Interface¶
- class torchmetrics.wrappers.ClasswiseWrapper(metric, labels=None, prefix=None, postfix=None)[source]¶
Wrapper metric for altering the output of classification metrics.
This metric works together with classification metrics that return multiple values (one value per class) such that label information can be automatically included in the output.
- Parameters:
metric (Metric) – base metric that should be wrapped. It is assumed that the metric outputs a single tensor that is split along the first dimension.
labels (Optional[List[str]]) – list of strings indicating the different classes.
prefix (Optional[str]) – string that is prepended to the metric names.
postfix (Optional[str]) – string that is appended to the metric names.
- Example::
Basic example where the output of a metric is unwrapped into a dictionary with the class index as keys:
>>> import torch
>>> _ = torch.manual_seed(42)
>>> from torchmetrics.wrappers import ClasswiseWrapper
>>> from torchmetrics.classification import MulticlassAccuracy
>>> metric = ClasswiseWrapper(MulticlassAccuracy(num_classes=3, average=None))
>>> preds = torch.randn(10, 3).softmax(dim=-1)
>>> target = torch.randint(3, (10,))
>>> metric(preds, target)
{'multiclassaccuracy_0': tensor(0.5000), 'multiclassaccuracy_1': tensor(0.7500), 'multiclassaccuracy_2': tensor(0.)}
- Example::
Using custom name via prefix and postfix:
>>> import torch
>>> _ = torch.manual_seed(42)
>>> from torchmetrics.wrappers import ClasswiseWrapper
>>> from torchmetrics.classification import MulticlassAccuracy
>>> metric_pre = ClasswiseWrapper(MulticlassAccuracy(num_classes=3, average=None), prefix="acc-")
>>> metric_post = ClasswiseWrapper(MulticlassAccuracy(num_classes=3, average=None), postfix="-acc")
>>> preds = torch.randn(10, 3).softmax(dim=-1)
>>> target = torch.randint(3, (10,))
>>> metric_pre(preds, target)
{'acc-0': tensor(0.5000), 'acc-1': tensor(0.7500), 'acc-2': tensor(0.)}
>>> metric_post(preds, target)
{'0-acc': tensor(0.5000), '1-acc': tensor(0.7500), '2-acc': tensor(0.)}
- Example::
Providing labels as a list of strings:
>>> import torch
>>> from torchmetrics.wrappers import ClasswiseWrapper
>>> from torchmetrics.classification import MulticlassAccuracy
>>> metric = ClasswiseWrapper(
...    MulticlassAccuracy(num_classes=3, average=None),
...    labels=["horse", "fish", "dog"]
... )
>>> preds = torch.randn(10, 3).softmax(dim=-1)
>>> target = torch.randint(3, (10,))
>>> metric(preds, target)
{'multiclassaccuracy_horse': tensor(0.3333), 'multiclassaccuracy_fish': tensor(0.6667), 'multiclassaccuracy_dog': tensor(0.)}
- Example::
ClasswiseWrapper can also be used in combination with MetricCollection. In this case, everything will be flattened into a single dictionary:
>>> import torch
>>> from torchmetrics import MetricCollection
>>> from torchmetrics.wrappers import ClasswiseWrapper
>>> from torchmetrics.classification import MulticlassAccuracy, MulticlassRecall
>>> labels = ["horse", "fish", "dog"]
>>> metric = MetricCollection(
...     {'multiclassaccuracy': ClasswiseWrapper(MulticlassAccuracy(num_classes=3, average=None), labels),
...     'multiclassrecall': ClasswiseWrapper(MulticlassRecall(num_classes=3, average=None), labels)}
... )
>>> preds = torch.randn(10, 3).softmax(dim=-1)
>>> target = torch.randint(3, (10,))
>>> metric(preds, target)
{'multiclassaccuracy_horse': tensor(0.), 'multiclassaccuracy_fish': tensor(0.3333), 'multiclassaccuracy_dog': tensor(0.4000), 'multiclassrecall_horse': tensor(0.), 'multiclassrecall_fish': tensor(0.3333), 'multiclassrecall_dog': tensor(0.4000)}
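The key layout in the examples above can be modeled in a few lines of plain Python. This is a simplified sketch, not the library's implementation: the helper `classwise_dict` is hypothetical, and how `prefix` and `postfix` combine when both are given is an assumption, since the examples only show them separately.

```python
from typing import Dict, List, Optional, Sequence


def classwise_dict(
    values: Sequence[float],
    metric_name: str,
    labels: Optional[List[str]] = None,
    prefix: Optional[str] = None,
    postfix: Optional[str] = None,
) -> Dict[str, float]:
    """Hypothetical helper mimicking the keys shown in the ClasswiseWrapper examples."""
    # Fall back to the class index when no labels are given, as in the first example.
    names = labels if labels is not None else [str(i) for i in range(len(values))]
    if len(names) != len(values):
        raise ValueError("need exactly one label per class value")
    if prefix is None and postfix is None:
        # Default: lower-cased metric class name joined to the label with "_".
        keys = [f"{metric_name.lower()}_{name}" for name in names]
    else:
        # Assumption: a prefix and/or postfix replaces the default metric name.
        keys = [f"{prefix or ''}{name}{postfix or ''}" for name in names]
    return dict(zip(keys, values))


per_class = [0.5, 0.75, 0.0]  # one value per class, as with MulticlassAccuracy(average=None)
print(classwise_dict(per_class, "MulticlassAccuracy"))
print(classwise_dict(per_class, "MulticlassAccuracy", prefix="acc-"))
print(classwise_dict(per_class, "MulticlassAccuracy", labels=["horse", "fish", "dog"]))
```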
- plot(val=None, ax=None)
Plot a single or multiple values from the metric.
- Parameters:
val (Union[Tensor, Sequence[Tensor], None]) – Either a single result from calling metric.forward or metric.compute or a list of these results. If no value is provided, metric.compute will be called automatically and that result will be plotted.
ax (Optional[Axes]) – A matplotlib axis object. If provided, the plot will be added to that axis.
- Return type:
- Returns:
Figure and Axes object
- Raises:
ModuleNotFoundError – If matplotlib is not installed
>>> # Example plotting a single value
>>> import torch
>>> from torchmetrics.wrappers import ClasswiseWrapper
>>> from torchmetrics.classification import MulticlassAccuracy
>>> metric = ClasswiseWrapper(MulticlassAccuracy(num_classes=3, average=None))
>>> metric.update(torch.randint(3, (20,)), torch.randint(3, (20,)))
>>> fig_, ax_ = metric.plot()
>>> # Example plotting multiple values
>>> import torch
>>> from torchmetrics.wrappers import ClasswiseWrapper
>>> from torchmetrics.classification import MulticlassAccuracy
>>> metric = ClasswiseWrapper(MulticlassAccuracy(num_classes=3, average=None))
>>> values = []
>>> for _ in range(3):
...     values.append(metric(torch.randint(3, (20,)), torch.randint(3, (20,))))
>>> fig_, ax_ = metric.plot(values)
Metric Tracker
Module Interface
- class torchmetrics.wrappers.MetricTracker(metric, maximize=True)
A wrapper class that helps keep track of a metric or metric collection over time.
The wrapper implements the standard .update(), .compute() and .reset() methods, which simply call the corresponding method of the currently tracked metric. In addition, the following methods are provided:
- MetricTracker.n_steps: number of metrics being tracked
- MetricTracker.increment(): initialize a new metric to be tracked
- MetricTracker.compute_all(): get the metric value for all steps
- MetricTracker.best_metric(): return the best value
Out of the box, this wrapper fully supports tracking a single Metric, a MetricCollection, or another MetricWrapper wrapped around a metric. However, multiple layers of nesting, such as a Metric inside a MetricWrapper inside a MetricCollection, are not fully supported; in particular, .best_metric cannot automatically compute the best metric and index for such nested structures.
- Parameters:
metric – instance of a torchmetrics.Metric or torchmetrics.MetricCollection to keep track of at each step.
maximize – either a single bool or a list of bools indicating whether higher metric values are better (True) or lower values are better (False).
- Example (single metric):
>>> import torch
>>> from torchmetrics.wrappers import MetricTracker
>>> from torchmetrics.classification import MulticlassAccuracy
>>> _ = torch.manual_seed(42)
>>> tracker = MetricTracker(MulticlassAccuracy(num_classes=10, average='micro'))
>>> for epoch in range(5):
...     tracker.increment()
...     for batch_idx in range(5):
...         preds, target = torch.randint(10, (100,)), torch.randint(10, (100,))
...         tracker.update(preds, target)
...     print(f"current acc={tracker.compute()}")
current acc=0.1120000034570694
current acc=0.08799999952316284
current acc=0.12600000202655792
current acc=0.07999999821186066
current acc=0.10199999809265137
>>> best_acc, which_epoch = tracker.best_metric(return_step=True)
>>> best_acc
0.1260...
>>> which_epoch
2
>>> tracker.compute_all()
tensor([0.1120, 0.0880, 0.1260, 0.0800, 0.1020])
- Example (multiple metrics using MetricCollection):
>>> import torch
>>> from torchmetrics.wrappers import MetricTracker
>>> from torchmetrics import MetricCollection
>>> from torchmetrics.regression import MeanSquaredError, ExplainedVariance
>>> _ = torch.manual_seed(42)
>>> tracker = MetricTracker(MetricCollection([MeanSquaredError(), ExplainedVariance()]), maximize=[False, True])
>>> for epoch in range(5):
...     tracker.increment()
...     for batch_idx in range(5):
...         preds, target = torch.randn(100), torch.randn(100)
...         tracker.update(preds, target)
...     print(f"current stats={tracker.compute()}")
current stats={'MeanSquaredError': tensor(1.8218), 'ExplainedVariance': tensor(-0.8969)}
current stats={'MeanSquaredError': tensor(2.0268), 'ExplainedVariance': tensor(-1.0206)}
current stats={'MeanSquaredError': tensor(1.9491), 'ExplainedVariance': tensor(-0.8298)}
current stats={'MeanSquaredError': tensor(1.9800), 'ExplainedVariance': tensor(-0.9199)}
current stats={'MeanSquaredError': tensor(2.2481), 'ExplainedVariance': tensor(-1.1622)}
>>> from pprint import pprint
>>> best_res, which_epoch = tracker.best_metric(return_step=True)
>>> pprint(best_res)
{'ExplainedVariance': -0.829..., 'MeanSquaredError': 1.821...}
>>> which_epoch
{'MeanSquaredError': 0, 'ExplainedVariance': 2}
>>> pprint(tracker.compute_all())
{'ExplainedVariance': tensor([-0.8969, -1.0206, -0.8298, -0.9199, -1.1622]),
 'MeanSquaredError': tensor([1.8218, 2.0268, 1.9491, 1.9800, 2.2481])}
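The bookkeeping behind increment(), n_steps and compute_all() can be pictured with a plain-Python stand-in. This is a simplified model under stated assumptions, not the wrapper's code: `ToyTracker` and the `MeanState` metric are hypothetical, each increment() creates a fresh metric, and update/compute are forwarded to the most recent one.

```python
from typing import Callable, List


class MeanState:
    """Hypothetical minimal metric: running mean of all values passed to update()."""

    def __init__(self) -> None:
        self.total = 0.0
        self.count = 0

    def update(self, values: List[float]) -> None:
        self.total += sum(values)
        self.count += len(values)

    def compute(self) -> float:
        return self.total / self.count


class ToyTracker:
    """Hypothetical stand-in for MetricTracker: one fresh metric per increment()."""

    def __init__(self, metric_factory: Callable[[], MeanState]) -> None:
        self._factory = metric_factory
        self._steps: List[MeanState] = []

    @property
    def n_steps(self) -> int:
        return len(self._steps)

    def increment(self) -> None:
        # Start tracking a new, empty copy of the metric (e.g. at the start of an epoch).
        self._steps.append(self._factory())

    def _current(self) -> MeanState:
        if not self._steps:
            raise ValueError("call increment() before update() or compute()")
        return self._steps[-1]

    def update(self, values: List[float]) -> None:
        self._current().update(values)

    def compute(self) -> float:
        return self._current().compute()

    def compute_all(self) -> List[float]:
        if not self._steps:
            raise ValueError("call increment() before compute_all()")
        return [step.compute() for step in self._steps]


tracker = ToyTracker(MeanState)
for epoch_values in ([1.0, 3.0], [4.0, 4.0], [0.0, 2.0]):
    tracker.increment()
    tracker.update(epoch_values)
print(tracker.n_steps, tracker.compute_all())  # 3 [2.0, 4.0, 1.0]
```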
- best_metric(return_step=False)
Return the best metric value out of all tracked steps (the highest if maximize=True, the lowest if maximize=False).
- Parameters:
return_step (bool) – If True, will also return the step with the best metric value.
- Return type:
Union[None, float, Tuple[float, int], Tuple[None, None], Dict[str, Optional[float]], Tuple[Dict[str, Optional[float]], Dict[str, Optional[int]]]]
- Returns:
Either a single value or a tuple, depending on the value of return_step and the object being tracked.
- If a single metric is being tracked and return_step=False, a single value will be returned.
- If a single metric is being tracked and return_step=True, a 2-element tuple will be returned, where the first value is the optimal value and the second value is the corresponding optimal step.
- If a metric collection is being tracked and return_step=False, a single dict will be returned, where the keys correspond to the different metrics of the collection and the values are the optimal metric values.
- If a metric collection is being tracked and return_step=True, a 2-element tuple of dicts will be returned, both keyed by the metrics of the collection; the values of the first dict are the optimal values and the values of the second dict are the optimal steps.
In addition, the value in all cases may be None if the underlying metric does not have a properly defined way of being optimal, or if a nested structure of metrics is being tracked.
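The selection rule described above (best value plus its step, with maximize deciding the direction, and one maximize flag per metric for a collection) can be sketched in plain Python. The functions `best_of` and `best_of_collection` are hypothetical, work on already-computed per-step values, and ignore the None cases for nested structures. The numbers reuse the compute_all() outputs from the MetricTracker examples.

```python
from typing import Dict, List, Sequence, Tuple, Union


def best_of(values: Sequence[float], maximize: bool = True) -> Tuple[float, int]:
    """Return (best value, step index) for one metric's per-step values."""
    pick = max if maximize else min
    step = pick(range(len(values)), key=lambda i: values[i])
    return values[step], step


def best_of_collection(
    values: Dict[str, Sequence[float]], maximize: Union[bool, List[bool]] = True
) -> Tuple[Dict[str, float], Dict[str, int]]:
    """Per-metric best values and best steps; maximize is one bool or one bool per metric."""
    flags = maximize if isinstance(maximize, list) else [maximize] * len(values)
    best: Dict[str, float] = {}
    steps: Dict[str, int] = {}
    for (name, series), flag in zip(values.items(), flags):
        best[name], steps[name] = best_of(series, flag)
    return best, steps


acc = [0.1120, 0.0880, 0.1260, 0.0800, 0.1020]
print(best_of(acc))  # (0.126, 2)

stats = {
    "MeanSquaredError": [1.8218, 2.0268, 1.9491, 1.9800, 2.2481],
    "ExplainedVariance": [-0.8969, -1.0206, -0.8298, -0.9199, -1.1622],
}
print(best_of_collection(stats, maximize=[False, True]))
```

As in the documented example, the lowest MeanSquaredError is found at step 0 and the highest ExplainedVariance at step 2.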
- compute_all()
Compute the metric value for all tracked metrics.
- Return type:
- Returns:
By default, will try stacking the results from all increments into a single tensor if the tracked base object is a single metric. If a metric collection is provided, a dict of stacked tensors will be returned. If the stacking process fails, a list of the computed results will be returned.
- Raises:
ValueError – If self.increment has not been called before this method is called.
- increment()
Create a new instance of the input metric that will be updated next.
- Return type:
- plot(val=None, ax=None)
Plot a single or multiple values from the metric.
- Parameters:
val (Union[Tensor, Sequence[Tensor], None]) – Either a single result from calling metric.forward or metric.compute or a list of these results. If no value is provided, metric.compute will be called automatically and that result will be plotted.
ax (Optional[Axes]) – A matplotlib axis object. If provided, the plot will be added to that axis.
- Return type:
- Returns:
Figure and Axes object
- Raises:
ModuleNotFoundError – If matplotlib is not installed
>>> # Example plotting a single value
>>> import torch
>>> from torchmetrics.wrappers import MetricTracker
>>> from torchmetrics.classification import BinaryAccuracy
>>> tracker = MetricTracker(BinaryAccuracy())
>>> for epoch in range(5):
...     tracker.increment()
...     for batch_idx in range(5):
...         tracker.update(torch.randint(2, (10,)), torch.randint(2, (10,)))
>>> fig_, ax_ = tracker.plot()  # plot all epochs
Min / Max
Module Interface
- class torchmetrics.wrappers.MinMaxMetric(base_metric, **kwargs)
Wrapper metric that tracks both the minimum and maximum of a scalar/tensor across an experiment.
The min/max value will be updated each time .compute is called.
- Parameters:
base_metric (Metric) – The metric whose minimum and maximum values should be tracked.
kwargs (Any) – Additional keyword arguments, see Advanced metric settings for more info.
- Raises:
ValueError – If the base_metric argument is not an instance of a subclass of torchmetrics.Metric
- Example::
>>> import torch
>>> from torchmetrics.wrappers import MinMaxMetric
>>> from torchmetrics.classification import BinaryAccuracy
>>> from pprint import pprint
>>> base_metric = BinaryAccuracy()
>>> minmax_metric = MinMaxMetric(base_metric)
>>> preds_1 = torch.Tensor([[0.1, 0.9], [0.2, 0.8]])
>>> preds_2 = torch.Tensor([[0.9, 0.1], [0.2, 0.8]])
>>> labels = torch.Tensor([[0, 1], [0, 1]]).long()
>>> pprint(minmax_metric(preds_1, labels))
{'max': tensor(1.), 'min': tensor(1.), 'raw': tensor(1.)}
>>> pprint(minmax_metric.compute())
{'max': tensor(1.), 'min': tensor(1.), 'raw': tensor(1.)}
>>> minmax_metric.update(preds_2, labels)
>>> pprint(minmax_metric.compute())
{'max': tensor(1.), 'min': tensor(0.7500), 'raw': tensor(0.7500)}
- compute()
Compute the underlying metric as well as max and min values for this metric.
Returns a dictionary that consists of the computed value (raw), as well as the minimum (min) and maximum (max) values.
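The update rule can be sketched in plain Python: every compute() compares the new raw value with the extremes seen so far. `ToyMinMax` is hypothetical. It takes the base metric's raw value as an argument instead of wrapping a Metric, and starting from plus and minus infinity is an assumption. The raw values below are the ones from the MinMaxMetric example (1.0, 1.0, then 0.75).

```python
import math
from typing import Dict


class ToyMinMax:
    """Hypothetical model: remember the extremes of every value seen by compute()."""

    def __init__(self) -> None:
        self.min_val = math.inf
        self.max_val = -math.inf

    def compute(self, raw: float) -> Dict[str, float]:
        # `raw` stands in for the base metric's compute() result at this point in time.
        self.min_val = min(self.min_val, raw)
        self.max_val = max(self.max_val, raw)
        return {"raw": raw, "min": self.min_val, "max": self.max_val}


minmax = ToyMinMax()
for raw in (1.0, 1.0, 0.75):
    result = minmax.compute(raw)
print(result)  # {'raw': 0.75, 'min': 0.75, 'max': 1.0}
```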
- forward(*args, **kwargs)
Use the original forward method of the base metric class.
- Return type:
- plot(val=None, ax=None)
Plot a single or multiple values from the metric.
- Parameters:
val (Union[Tensor, Sequence[Tensor], None]) – Either a single result from calling metric.forward or metric.compute or a list of these results. If no value is provided, metric.compute will be called automatically and that result will be plotted.
ax (Optional[Axes]) – A matplotlib axis object. If provided, the plot will be added to that axis.
- Return type:
- Returns:
Figure and Axes object
- Raises:
ModuleNotFoundError – If matplotlib is not installed
>>> # Example plotting a single value
>>> import torch
>>> from torchmetrics.wrappers import MinMaxMetric
>>> from torchmetrics.classification import BinaryAccuracy
>>> metric = MinMaxMetric(BinaryAccuracy())
>>> metric.update(torch.randint(2, (20,)), torch.randint(2, (20,)))
>>> fig_, ax_ = metric.plot()
>>> # Example plotting multiple values
>>> import torch
>>> from torchmetrics.wrappers import MinMaxMetric
>>> from torchmetrics.classification import BinaryAccuracy
>>> metric = MinMaxMetric(BinaryAccuracy())
>>> values = []
>>> for _ in range(3):
...     values.append(metric(torch.randint(2, (20,)), torch.randint(2, (20,))))
>>> fig_, ax_ = metric.plot(values)
Multi-output Wrapper
Module Interface
- class torchmetrics.wrappers.MultioutputWrapper(base_metric, num_outputs, output_dim=-1, remove_nans=True, squeeze_outputs=True)
Wrap a base metric to enable it to support multiple outputs.
Several torchmetrics metrics, such as SpearmanCorrCoef, lack support for multioutput mode. This class wraps such metrics to support computing one metric per output. Unlike specific torchmetrics metrics, it doesn't support any aggregation across outputs. This means that if you set num_outputs to 2, .compute() will return a Tensor of dimension (2, ...), where ... represents the dimensions the metric returns when not wrapped.
In addition to enabling multioutput support for metrics that lack it, this class also supports, albeit in a crude fashion, dealing with missing labels (or other data). When remove_nans is True (the default), the class removes NaN-containing "rows" upon each update, separately for each output. For example, suppose a user uses MultioutputWrapper to wrap torchmetrics.regression.r2.R2Score with 2 outputs, one of which occasionally has missing labels. Because NaN rows are removed on a per-output basis, a row with a missing label in one output is dropped only for that output, and the other output still receives it.
- Parameters:
base_metric (Metric) – Metric being wrapped.
num_outputs (int) – Expected dimensionality of the output dimension. This parameter is used to determine the number of distinct metrics that need to be tracked.
output_dim (int) – Dimension on which output is expected. Note that while this provides some flexibility, the output dimension must be the same for all inputs to update. This applies even for metrics such as Accuracy where the labels can have a different number of dimensions than the predictions. This can be worked around if the output dimension can be set to -1 for both, even if -1 corresponds to different dimensions in different inputs.
remove_nans (bool) – Whether to remove rows containing NaNs from the values passed through to each underlying metric. Proper operation requires all tensors passed to update to have dimension (N, ...), where N represents the length of the batch or dataset being passed in.
squeeze_outputs (bool) – If True, will squeeze the 1-item dimensions left after index_select is applied. This is sometimes unnecessary but harmless for metrics such as R2Score, but useful for certain classification metrics that can't handle additional 1-item dimensions.
Example
>>> # Mimic R2Score in `multioutput`, `raw_values` mode:
>>> import torch
>>> from torchmetrics.wrappers import MultioutputWrapper
>>> from torchmetrics.regression import R2Score
>>> target = torch.tensor([[0.5, 1], [-1, 1], [7, -6]])
>>> preds = torch.tensor([[0, 2], [-1, 2], [8, -5]])
>>> r2score = MultioutputWrapper(R2Score(), 2)
>>> r2score(preds, target)
tensor([0.9654, 0.9082])
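What the wrapper does for R2Score can be modeled in plain Python: split the inputs by output column, drop NaN rows separately for each output, and compute one score per output. This is a hypothetical sketch (`r2_score` and `multioutput_r2` are not torchmetrics functions) that assumes the output dimension is the last one (the default output_dim=-1); it reproduces the values of the example above.

```python
import math
from typing import List, Sequence


def r2_score(preds: Sequence[float], target: Sequence[float]) -> float:
    """Plain coefficient of determination: 1 - SS_res / SS_tot."""
    mean_t = sum(target) / len(target)
    ss_res = sum((t - p) ** 2 for p, t in zip(preds, target))
    ss_tot = sum((t - mean_t) ** 2 for t in target)
    return 1.0 - ss_res / ss_tot


def multioutput_r2(
    preds: List[List[float]], target: List[List[float]], remove_nans: bool = True
) -> List[float]:
    """One R2 per output column, dropping NaN rows separately for each output."""
    scores = []
    for j in range(len(preds[0])):
        rows = [(p[j], t[j]) for p, t in zip(preds, target)]
        if remove_nans:
            # A row is dropped for output j only if that output's pred or target is NaN.
            rows = [(p, t) for p, t in rows if not (math.isnan(p) or math.isnan(t))]
        col_preds, col_target = zip(*rows)
        scores.append(r2_score(col_preds, col_target))
    return scores


target = [[0.5, 1.0], [-1.0, 1.0], [7.0, -6.0]]
preds = [[0.0, 2.0], [-1.0, 2.0], [8.0, -5.0]]
print([round(s, 4) for s in multioutput_r2(preds, target)])  # [0.9654, 0.9082]
```

Adding a row whose first target is NaN leaves the first score unchanged while the second output still uses that row.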
- forward(*args, **kwargs)
Call the underlying forward methods and aggregate the results if they're non-null.
We override this method to ensure that state variables get copied over on the underlying metrics.
- Return type:
- plot(val=None, ax=None)
Plot a single or multiple values from the metric.
- Parameters:
val (Union[Tensor, Sequence[Tensor], None]) – Either a single result from calling metric.forward or metric.compute or a list of these results. If no value is provided, metric.compute will be called automatically and that result will be plotted.
ax (Optional[Axes]) – A matplotlib axis object. If provided, the plot will be added to that axis.
- Return type:
- Returns:
Figure and Axes object
- Raises:
ModuleNotFoundError – If matplotlib is not installed
>>> # Example plotting a single value
>>> import torch
>>> from torchmetrics.wrappers import MultioutputWrapper
>>> from torchmetrics.regression import R2Score
>>> metric = MultioutputWrapper(R2Score(), 2)
>>> metric.update(torch.randn(20, 2), torch.randn(20, 2))
>>> fig_, ax_ = metric.plot()
>>> # Example plotting multiple values
>>> import torch
>>> from torchmetrics.wrappers import MultioutputWrapper
>>> from torchmetrics.regression import R2Score
>>> metric = MultioutputWrapper(R2Score(), 2)
>>> values = []
>>> for _ in range(3):
...     values.append(metric(torch.randn(20, 2), torch.randn(20, 2)))
>>> fig_, ax_ = metric.plot(values)
Multi-task Wrapper
Module Interface
- class torchmetrics.wrappers.MultitaskWrapper(task_metrics)
Wrapper class for computing different metrics on different tasks in the context of multitask learning.
In multitask learning, different tasks require different metrics to be evaluated. This wrapper allows for easy evaluation in such cases by supporting multiple predictions and targets through a dictionary. Note that only metrics whose update signature follows the standard preds, target are supported.
- Parameters:
task_metrics (Dict[str, Union[Metric, MetricCollection]]) – Dictionary associating each task to a Metric or a MetricCollection. The keys of the dictionary represent the names of the tasks, and the values represent the metrics to use for each task.
- Raises:
- Example (with a single metric per task):
>>> import torch
>>> from torchmetrics.wrappers import MultitaskWrapper
>>> from torchmetrics.regression import MeanSquaredError
>>> from torchmetrics.classification import BinaryAccuracy
>>>
>>> classification_target = torch.tensor([0, 1, 0])
>>> regression_target = torch.tensor([2.5, 5.0, 4.0])
>>> targets = {"Classification": classification_target, "Regression": regression_target}
>>>
>>> classification_preds = torch.tensor([0, 0, 1])
>>> regression_preds = torch.tensor([3.0, 5.0, 2.5])
>>> preds = {"Classification": classification_preds, "Regression": regression_preds}
>>>
>>> metrics = MultitaskWrapper({
...     "Classification": BinaryAccuracy(),
...     "Regression": MeanSquaredError()
... })
>>> metrics.update(preds, targets)
>>> metrics.compute()
{'Classification': tensor(0.3333), 'Regression': tensor(0.8333)}
- Example (with several metrics per task):
>>> import torch
>>> from torchmetrics import MetricCollection
>>> from torchmetrics.wrappers import MultitaskWrapper
>>> from torchmetrics.regression import MeanSquaredError, MeanAbsoluteError
>>> from torchmetrics.classification import BinaryAccuracy, BinaryF1Score
>>>
>>> classification_target = torch.tensor([0, 1, 0])
>>> regression_target = torch.tensor([2.5, 5.0, 4.0])
>>> targets = {"Classification": classification_target, "Regression": regression_target}
>>>
>>> classification_preds = torch.tensor([0, 0, 1])
>>> regression_preds = torch.tensor([3.0, 5.0, 2.5])
>>> preds = {"Classification": classification_preds, "Regression": regression_preds}
>>>
>>> metrics = MultitaskWrapper({
...     "Classification": MetricCollection(BinaryAccuracy(), BinaryF1Score()),
...     "Regression": MetricCollection(MeanSquaredError(), MeanAbsoluteError())
... })
>>> metrics.update(preds, targets)
>>> metrics.compute()
{'Classification': {'BinaryAccuracy': tensor(0.3333), 'BinaryF1Score': tensor(0.)}, 'Regression': {'MeanSquaredError': tensor(0.8333), 'MeanAbsoluteError': tensor(0.6667)}}
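The routing done by the wrapper (each task's preds/target pair goes to that task's metric, and results are keyed by task name) can be sketched with plain functions. Everything here is hypothetical and torch-free: `evaluate_tasks`, `accuracy` and `mean_squared_error` are illustrative stand-ins, computed in one shot rather than accumulated over update() calls. The inputs reuse the data of the examples above.

```python
from typing import Callable, Dict, Sequence

MetricFn = Callable[[Sequence[float], Sequence[float]], float]


def accuracy(preds: Sequence[float], target: Sequence[float]) -> float:
    return sum(p == t for p, t in zip(preds, target)) / len(target)


def mean_squared_error(preds: Sequence[float], target: Sequence[float]) -> float:
    return sum((p - t) ** 2 for p, t in zip(preds, target)) / len(target)


def evaluate_tasks(
    task_metrics: Dict[str, MetricFn],
    task_preds: Dict[str, Sequence[float]],
    task_targets: Dict[str, Sequence[float]],
) -> Dict[str, float]:
    """Route each task's (preds, target) pair to that task's metric, keyed by task name."""
    if not (task_metrics.keys() == task_preds.keys() == task_targets.keys()):
        raise ValueError("metrics, preds and targets must share the same task names")
    return {name: fn(task_preds[name], task_targets[name]) for name, fn in task_metrics.items()}


results = evaluate_tasks(
    {"Classification": accuracy, "Regression": mean_squared_error},
    {"Classification": [0, 0, 1], "Regression": [3.0, 5.0, 2.5]},
    {"Classification": [0, 1, 0], "Regression": [2.5, 5.0, 4.0]},
)
print(results)  # roughly 0.3333 for Classification and 0.8333 for Regression
```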
- forward(task_preds, task_targets)
Call the underlying forward methods for all tasks and return the result as a dictionary.
- plot(val=None, axes=None)
Plot a single or multiple values from the metric.
All tasks' results are plotted on individual axes.
- Parameters:
val (Union[Dict, Sequence[Dict], None]) – Either a single result from calling metric.forward or metric.compute or a list of these results. If no value is provided, metric.compute will be called automatically and that result will be plotted.
axes (Optional[Sequence[Axes]]) – Sequence of matplotlib axis objects. If provided, the plots will be added to the provided axis objects. If not provided, they will be created.
- Return type:
- Returns:
Sequence of tuples with Figure and Axes object for each task.
>>> # Example plotting a single value
>>> import torch
>>> from torchmetrics.wrappers import MultitaskWrapper
>>> from torchmetrics.regression import MeanSquaredError
>>> from torchmetrics.classification import BinaryAccuracy
>>>
>>> classification_target = torch.tensor([0, 1, 0])
>>> regression_target = torch.tensor([2.5, 5.0, 4.0])
>>> targets = {"Classification": classification_target, "Regression": regression_target}
>>>
>>> classification_preds = torch.tensor([0, 0, 1])
>>> regression_preds = torch.tensor([3.0, 5.0, 2.5])
>>> preds = {"Classification": classification_preds, "Regression": regression_preds}
>>>
>>> metrics = MultitaskWrapper({
...     "Classification": BinaryAccuracy(),
...     "Regression": MeanSquaredError()
... })
>>> metrics.update(preds, targets)
>>> value = metrics.compute()
>>> fig_, ax_ = metrics.plot(value)
>>> # Example plotting multiple values
>>> import torch
>>> from torchmetrics.wrappers import MultitaskWrapper
>>> from torchmetrics.regression import MeanSquaredError
>>> from torchmetrics.classification import BinaryAccuracy
>>>
>>> classification_target = torch.tensor([0, 1, 0])
>>> regression_target = torch.tensor([2.5, 5.0, 4.0])
>>> targets = {"Classification": classification_target, "Regression": regression_target}
>>>
>>> classification_preds = torch.tensor([0, 0, 1])
>>> regression_preds = torch.tensor([3.0, 5.0, 2.5])
>>> preds = {"Classification": classification_preds, "Regression": regression_preds}
>>>
>>> metrics = MultitaskWrapper({
...     "Classification": BinaryAccuracy(),
...     "Regression": MeanSquaredError()
... })
>>> values = []
>>> for _ in range(10):
...     values.append(metrics(preds, targets))
>>> fig_, ax_ = metrics.plot(values)
Running
Module Interface
- class torchmetrics.wrappers.Running(base_metric, window=5)
Running wrapper for metrics.
Using this wrapper allows for calculating metrics over a running window of values, instead of the whole history of values. This is beneficial when you want to get a better estimate of the metric during training and don't want to wait for the whole training to finish to get epoch-level estimates.
The running window is defined by the window argument. The window has a fixed size, and this wrapper stores a duplicate of the underlying metric state for each value in the window. Thus, memory usage increases linearly with window size; use accordingly. Also note that the running wrapper only works with metrics that have full_state_update set to False.
Importantly, the wrapper does not alter the value returned by the forward method of the underlying metric. Thus, forward will still return the value on the current batch. To get the running value, call compute instead.
- Parameters:
base_metric – The metric to wrap.
window – The size of the running window.
Example
# Single metric
>>> from torch import tensor
>>> from torchmetrics.wrappers import Running
>>> from torchmetrics.aggregation import SumMetric
>>> metric = Running(SumMetric(), window=3)
>>> for i in range(6):
...     current_val = metric(tensor([i]))
...     running_val = metric.compute()
...     total_val = tensor(sum(list(range(i+1))))  # value we would get from compute without running
...     print(f"{current_val=}, {running_val=}, {total_val=}")
current_val=tensor(0.), running_val=tensor(0.), total_val=tensor(0)
current_val=tensor(1.), running_val=tensor(1.), total_val=tensor(1)
current_val=tensor(2.), running_val=tensor(3.), total_val=tensor(3)
current_val=tensor(3.), running_val=tensor(6.), total_val=tensor(6)
current_val=tensor(4.), running_val=tensor(9.), total_val=tensor(10)
current_val=tensor(5.), running_val=tensor(12.), total_val=tensor(15)
Example
# Metric collection
>>> from torch import tensor
>>> from torchmetrics.wrappers import Running
>>> from torchmetrics import MetricCollection
>>> from torchmetrics.aggregation import SumMetric, MeanMetric
>>> # note that Running is wrapped inside the collection, not the other way around
>>> metric = MetricCollection({"sum": Running(SumMetric(), 3), "mean": Running(MeanMetric(), 3)})
>>> for i in range(6):
...     current_val = metric(tensor([i]))
...     running_val = metric.compute()
...     print(f"{current_val=}, {running_val=}")
current_val={'mean': tensor(0.), 'sum': tensor(0.)}, running_val={'mean': tensor(0.), 'sum': tensor(0.)}
current_val={'mean': tensor(1.), 'sum': tensor(1.)}, running_val={'mean': tensor(0.5000), 'sum': tensor(1.)}
current_val={'mean': tensor(2.), 'sum': tensor(2.)}, running_val={'mean': tensor(1.), 'sum': tensor(3.)}
current_val={'mean': tensor(3.), 'sum': tensor(3.)}, running_val={'mean': tensor(2.), 'sum': tensor(6.)}
current_val={'mean': tensor(4.), 'sum': tensor(4.)}, running_val={'mean': tensor(3.), 'sum': tensor(9.)}
current_val={'mean': tensor(5.), 'sum': tensor(5.)}, running_val={'mean': tensor(4.), 'sum': tensor(12.)}
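The mechanism described for Running (one stored state per update, at most `window` of them, merged again at compute time) can be modeled without torch. `RunningMean` is a hypothetical class; representing a mean's state as a (sum, count) pair and merging by adding both is an assumption about how such states combine, not a description of the wrapper's internals. It reproduces the running means from the collection example.

```python
from collections import deque
from typing import Deque, List, Tuple


class RunningMean:
    """Hypothetical model of a windowed mean: keep one (sum, count) state per update."""

    def __init__(self, window: int = 5) -> None:
        # deque(maxlen=...) silently drops the oldest state once the window is full.
        self.states: Deque[Tuple[float, int]] = deque(maxlen=window)

    def forward(self, batch: List[float]) -> float:
        state = (sum(batch), len(batch))
        self.states.append(state)
        # Like the wrapper's forward, return the value for the current batch only.
        return state[0] / state[1]

    def compute(self) -> float:
        # Merge the stored states by adding up sums and counts across the window.
        total = sum(s for s, _ in self.states)
        count = sum(c for _, c in self.states)
        return total / count


metric = RunningMean(window=3)
running = []
for i in range(6):
    metric.forward([float(i)])
    running.append(metric.compute())
print(running)  # [0.0, 0.5, 1.0, 2.0, 3.0, 4.0]
```

Memory grows with `window` because one state is kept per update, which matches the caveat in the class description.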
- forward(*args, **kwargs)
Forward input to the underlying metric and save state afterwards.
- Return type:
- plot(val=None, ax=None)
Plot a single or multiple values from the metric.
- Parameters:
val (Union[Tensor, Sequence[Tensor], None]) – Either a single result from calling metric.forward or metric.compute or a list of these results. If no value is provided, metric.compute will be called automatically and that result will be plotted.
ax (Optional[Axes]) – A matplotlib axis object. If provided, the plot will be added to that axis.
- Return type:
- Returns:
Figure and Axes object
- Raises:
ModuleNotFoundError – If matplotlib is not installed
>>> # Example plotting a single value
>>> import torch
>>> from torchmetrics.wrappers import Running
>>> from torchmetrics.aggregation import SumMetric
>>> metric = Running(SumMetric(), 2)
>>> metric.update(torch.randn(20, 2))
>>> fig_, ax_ = metric.plot()
>>> # Example plotting multiple values
>>> import torch
>>> from torchmetrics.wrappers import Running
>>> from torchmetrics.aggregation import SumMetric
>>> metric = Running(SumMetric(), 2)
>>> values = []
>>> for _ in range(3):
...     values.append(metric(torch.randn(20, 2)))
>>> fig_, ax_ = metric.plot(values)
torchmetrics.Metric
The base Metric class is an abstract base class that is used as the building block for all other Module metrics.
- class torchmetrics.Metric(**kwargs)
Base class for all metrics present in the Metrics API.
This class is inherited by all metrics and implements the following functionality:
1. Handles the transfer of metric states to the correct device
2. Handles the synchronization of metric states across processes
The three core methods of the base class are
* add_state()
* forward()
* reset()
which should almost never be overwritten by child classes. Instead, the following methods should be overwritten:
* update()
* compute()
- Parameters:
kwargs (Any) – additional keyword arguments, see Advanced metric settings for more info.
compute_on_cpu: If metric state should be stored on CPU during computations. Only works for list states.
dist_sync_on_step: If metric state should synchronize on forward(). Default is False.
process_group: The process group on which the synchronization is called. Default is the world.
dist_sync_fn: Function that performs the allgather operation on the metric state. Default is a custom implementation that calls torch.distributed.all_gather internally.
distributed_available_fn: Function that checks if the distributed backend is available. Defaults to a check of torch.distributed.is_available() and torch.distributed.is_initialized().
sync_on_compute: If metric state should synchronize when compute is called. Default is True.
compute_with_cache: If results from compute should be cached. Default is False.
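The division of labor described above (the base class owns state registration and resetting, while a subclass only implements update and compute) can be illustrated without torch. `ToyMetric` and `ToyAccuracy` are hypothetical single-process stand-ins: they omit device transfer, distributed synchronization and dist_reduce_fx, and they are not how torchmetrics.Metric is implemented.

```python
import copy
from typing import Any, Dict, List


class ToyMetric:
    """Hypothetical, framework-free model of the add_state / update / compute / reset pattern."""

    def __init__(self) -> None:
        self._defaults: Dict[str, Any] = {}

    def add_state(self, name: str, default: Any) -> None:
        # Remember the default so reset() can restore it, and expose the state as an attribute.
        self._defaults[name] = copy.deepcopy(default)
        setattr(self, name, copy.deepcopy(default))

    def reset(self) -> None:
        for name, default in self._defaults.items():
            setattr(self, name, copy.deepcopy(default))


class ToyAccuracy(ToyMetric):
    def __init__(self) -> None:
        super().__init__()
        self.add_state("correct", 0)
        self.add_state("total", 0)

    def update(self, preds: List[int], target: List[int]) -> None:
        # Only accumulate sufficient statistics; the final value is derived in compute().
        self.correct += sum(p == t for p, t in zip(preds, target))
        self.total += len(target)

    def compute(self) -> float:
        return self.correct / self.total


acc = ToyAccuracy()
acc.update([0, 1, 1], [0, 1, 0])
acc.update([1, 1], [1, 1])
print(acc.compute())  # 0.8
acc.reset()
print(acc.correct, acc.total)  # 0 0
```

With the real base class, a subclass registers its states the same way through add_state in its constructor (for example with a tensor default and dist_reduce_fx="sum"), and update and compute keep the same roles.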
- add_state(name, default, dist_reduce_fx=None, persistent=False)
Add metric state variable. Only used by subclasses.
Metric state variables are either a Tensor or an empty list, which can be appended to by the metric. Each state variable must have a unique name associated with it. State variables are accessible as attributes of the metric, i.e., if name is "my_state" then its value can be accessed from an instance metric as metric.my_state. Metric states behave like buffers and parameters of Module, as they are also updated when .to() is called. Unlike parameters and buffers, metric states are not saved in the module's state_dict by default.
- Parameters:
name (str) – The name of the state variable. The variable will then be accessible at self.name.
default (Union[list, Tensor]) – Default value of the state; can either be a Tensor or an empty list. The state will be reset to this value when self.reset() is called.
dist_reduce_fx (Optional) – Function to reduce state across multiple processes in distributed mode. If value is "sum", "mean", "cat", "min" or "max" we will use torch.sum, torch.mean, torch.cat, torch.min and torch.max respectively, each with argument dim=0. Note that the "cat" reduction only makes sense if the state is a list, and not a tensor. The user can also pass a custom function in this parameter.
persistent (Optional) – whether the state will be saved as part of the module's state_dict. Default is False.
- Return type:
Note
Setting dist_reduce_fx to None will return the metric state synchronized across different processes. However, no reduction function will be applied to the synchronized metric state. The metric states are synced as follows:
- If the metric state is a Tensor, the synced value will be a Tensor stacked across the process dimension. The original Tensor metric state retains its dimensions, hence the synchronized output will be of shape (num_process, ...).
- If the metric state is a list, the synced value will be a list containing the combined elements from all processes.
Note
When passing a custom function to dist_reduce_fx, expect the synchronized metric state to follow the format discussed in the above note.
- Raises:
ValueError – If default is not a tensor or an empty list.
ValueError – If dist_reduce_fx is not callable or one of "mean", "sum", "cat", "min", "max" or None.
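To make these notes concrete, here is a torch-free emulation of what each dist_reduce_fx option does to one state gathered from two processes. It is a hypothetical sketch: states are plain Python lists standing in for 1-D tensors or list states, the "processes" are simply the entries of an outer list, and `sync_state` is not a torchmetrics function.

```python
from typing import Any, Callable, Dict, List, Optional, Union

# One state per simulated process; a flat list stands in for a 1-D tensor or a list state.
State = List[float]

# Hypothetical equivalents of the named reductions, applied across processes (like dim=0).
REDUCTIONS: Dict[str, Callable[[List[State]], Any]] = {
    "sum": lambda stacked: [sum(col) for col in zip(*stacked)],
    "mean": lambda stacked: [sum(col) / len(col) for col in zip(*stacked)],
    "min": lambda stacked: [min(col) for col in zip(*stacked)],
    "max": lambda stacked: [max(col) for col in zip(*stacked)],
    "cat": lambda stacked: [x for state in stacked for x in state],
}


def sync_state(per_process: List[State], dist_reduce_fx: Optional[Union[str, Callable]] = None) -> Any:
    """Emulate gathering one state from every process and then reducing it."""
    if dist_reduce_fx is None:
        # No reduction: the gathered copies stay stacked along a leading process dimension.
        return [list(state) for state in per_process]
    if isinstance(dist_reduce_fx, str):
        if dist_reduce_fx not in REDUCTIONS:
            raise ValueError(f"unsupported reduction: {dist_reduce_fx!r}")
        return REDUCTIONS[dist_reduce_fx](per_process)
    if callable(dist_reduce_fx):
        # A custom function receives the same stacked structure as the None case.
        return dist_reduce_fx([list(state) for state in per_process])
    raise ValueError("dist_reduce_fx must be a known name, a callable or None")


two_processes = [[1.0, 2.0], [3.0, 4.0]]
for fx in (None, "sum", "mean", "cat"):
    print(fx, sync_state(two_processes, fx))
```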
- abstract compute()
Override this method to compute the final metric value.
This method will automatically synchronize state variables when running in a distributed backend.
- Return type:
- double()
Override default and prevent dtype casting.
Please use Metric.set_dtype() instead.
- Return type:
- float()
Override default and prevent dtype casting.
Please use Metric.set_dtype() instead.
- Return type:
- forward(*args, **kwargs)
Aggregate and evaluate batch input directly.
Serves the dual purpose of both computing the metric on the current batch of inputs and adding the batch statistics to the overall accumulating metric state. Input arguments are exactly the same as those of the corresponding update method. The returned output has exactly the same form as the output of compute.
- Parameters:
- Return type:
- Returns:
The output of the compute method evaluated on the current batch.
- Raises:
TorchMetricsUserError – If the metric is already synced and forward is called again.
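The dual behavior of forward (return the value on the current batch while also folding that batch into the overall state) can be modeled in plain Python. `BatchAndGlobalAccuracy` is a hypothetical class; it shows only the observable behavior and not the internal strategy torchmetrics uses to achieve it.

```python
from typing import List, Tuple


class BatchAndGlobalAccuracy:
    """Hypothetical model of forward(): return the batch value and accumulate the batch too."""

    def __init__(self) -> None:
        self.correct = 0
        self.total = 0

    @staticmethod
    def _batch_stats(preds: List[int], target: List[int]) -> Tuple[int, int]:
        return sum(p == t for p, t in zip(preds, target)), len(target)

    def update(self, preds: List[int], target: List[int]) -> None:
        correct, total = self._batch_stats(preds, target)
        self.correct += correct
        self.total += total

    def compute(self) -> float:
        return self.correct / self.total

    def forward(self, preds: List[int], target: List[int]) -> float:
        correct, total = self._batch_stats(preds, target)
        # Add this batch to the overall accumulated state...
        self.correct += correct
        self.total += total
        # ...but return the value computed on this batch only.
        return correct / total


metric = BatchAndGlobalAccuracy()
batch_values = [
    metric.forward([1, 1, 0, 0], [1, 1, 1, 1]),  # 2 of 4 correct
    metric.forward([1, 1], [1, 1]),  # 2 of 2 correct
]
print(batch_values, metric.compute())  # per-batch values, then the accumulated 4 / 6
```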
- half()
Override default and prevent dtype casting.
Please use Metric.set_dtype() instead.
- Return type:
- persistent(mode=False)
Change, after initialization, whether metric states should be saved to the state_dict.
- Return type:
- set_dtype(dst_type)
Transfer all metric states to a specific dtype. Special version of the standard type method.
- state_dict(destination=None, prefix='', keep_vars=False)
Get the current state of the metric as a dictionary.
- Parameters:
destination (Optional[Dict[str, Any]]) – Optional dictionary; if provided, the state of the module will be updated into the dict and the same object is returned. Otherwise, an OrderedDict will be created and returned.
prefix (str) – optional string, a prefix added to parameter and buffer names to compose the keys in state_dict.
keep_vars (bool) – by default the Tensors returned in the state dict are detached from autograd. If set to True, detaching will not be performed.
- Return type:
- sync(dist_sync_fn=None, process_group=None, should_sync=True, distributed_available=None)[source]¶
Sync function for manually controlling when metric states should be synced across processes.
- Parameters:
dist_sync_fn (Optional[Callable]) – Function to be used to perform state synchronization.
process_group (Optional[Any]) – Specify the process group on which synchronization is called. Default: None (which selects the entire world).
should_sync (bool) – Whether to apply state synchronization. This will have an impact only when running in a distributed setting.
distributed_available (Optional[Callable]) – Function to determine if we are running inside a distributed setting.
- Raises:
TorchMetricsUserError – If the metric is already synced and sync is called again.
- sync_context(dist_sync_fn=None, process_group=None, should_sync=True, should_unsync=True, distributed_available=None)[source]¶
Context manager to synchronize states.
This context manager is used in a distributed setting and makes sure that the local cache states are restored after yielding the synchronized state.
- Parameters:
dist_sync_fn (Optional[Callable]) – Function to be used to perform state synchronization.
process_group (Optional[Any]) – Specify the process group on which synchronization is called. Default: None (which selects the entire world).
should_sync (bool) – Whether to apply state synchronization. This will have an impact only when running in a distributed setting.
should_unsync (bool) – Whether to restore the cache state so that the metrics can continue to be accumulated.
distributed_available (Optional[Callable]) – Function to determine if we are running inside a distributed setting.
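The cache/sync/restore cycle can be sketched in plain Python, under the assumption that "synchronization" means summing each state across processes. The processes are simulated by a list of per-process dictionaries; toy_sync_context and its arguments are hypothetical and do not correspond to TorchMetrics internals.

```python
from contextlib import contextmanager
from typing import Dict, Iterator, List


@contextmanager
def toy_sync_context(
    state: Dict[str, float],
    other_ranks: List[Dict[str, float]],
    should_sync: bool = True,
    should_unsync: bool = True,
) -> Iterator[Dict[str, float]]:
    """Hypothetical stand-in for sync_context using a sum reduction."""
    local_cache = dict(state)  # cache the local (per-process) state
    if should_sync:
        for key in state:
            # Simulated all-reduce: sum this key over every "process".
            state[key] = local_cache[key] + sum(rank[key] for rank in other_ranks)
    try:
        yield state  # the caller sees the synchronized state
    finally:
        if should_sync and should_unsync:
            # Restore the local state so accumulation can continue.
            state.clear()
            state.update(local_cache)


local = {"correct": 3.0, "total": 4.0}
others = [{"correct": 5.0, "total": 6.0}]
with toy_sync_context(local, others) as synced:
    print(synced["correct"] / synced["total"])  # global accuracy: 0.8
print(local)  # local state restored: {'correct': 3.0, 'total': 4.0}
```

With should_unsync=False the synchronized values stay in place after the block, which is what you would want if no further accumulation is planned.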
- type(dst_type)[source]¶
Override default and prevent dtype casting.
Please use Metric.set_dtype() instead.
- unsync(should_unsync=True)[source]¶
Unsync function for manually controlling when metric states should be reverted to their local states.
- abstract update(*_, **__)[source]¶
Override this method to update the state variables of your metric class.
torchmetrics.utilities.data¶
select_topk¶
- torchmetrics.utilities.data.select_topk(prob_tensor, topk=1, dim=1)[source]¶
Convert a probability tensor to binary by selecting the top-k highest entries.
- Returns:
A binary tensor of the same shape as the input tensor, of type torch.int32
Example
>>> import torch
>>> from torchmetrics.utilities.data import select_topk
>>> x = torch.tensor([[1.1, 2.0, 3.0], [2.0, 1.0, 0.5]])
>>> select_topk(x, topk=2)
tensor([[0, 1, 1],
        [1, 1, 0]], dtype=torch.int32)
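The same top-k binarization can be written in plain Python for a 2-D input compared along each row (dim=1 in the example above). This is only a sketch of the semantics: select_topk_py is a hypothetical helper, and its tie-breaking follows sorted(), which may differ from the tensor-based implementation.

```python
from typing import List


def select_topk_py(rows: List[List[float]], topk: int = 1) -> List[List[int]]:
    """Hypothetical list-based version of select_topk for 2-D input, dim=1."""
    out = []
    for row in rows:
        # Indices of the `topk` largest entries in this row.
        top = sorted(range(len(row)), key=lambda i: row[i], reverse=True)[:topk]
        out.append([1 if i in top else 0 for i in range(len(row))])
    return out


print(select_topk_py([[1.1, 2.0, 3.0], [2.0, 1.0, 0.5]], topk=2))
# [[0, 1, 1], [1, 1, 0]]
```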
to_categorical¶
to_onehot¶
- torchmetrics.utilities.data.to_onehot(label_tensor, num_classes=None)[source]¶
Convert a dense label tensor to one-hot format.
- Returns:
A sparse label tensor with shape [N, C, d1, d2, …]
Example
>>> import torch
>>> from torchmetrics.utilities.data import to_onehot
>>> x = torch.tensor([1, 2, 3])
>>> to_onehot(x)
tensor([[0, 1, 0, 0],
        [0, 0, 1, 0],
        [0, 0, 0, 1]])
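For 1-D labels the conversion reduces to the plain-Python sketch below, assuming (as the example above suggests) that num_classes is inferred as the largest label plus one when it is omitted. to_onehot_py is a hypothetical helper; unlike the real function it does not handle the extra dimensions d1, d2.

```python
from typing import List, Optional


def to_onehot_py(labels: List[int], num_classes: Optional[int] = None) -> List[List[int]]:
    """Hypothetical 1-D version of to_onehot, returning shape [N, C]."""
    if num_classes is None:
        # Assumption: infer C from the largest label, as in the example above.
        num_classes = max(labels) + 1
    return [[1 if c == label else 0 for c in range(num_classes)] for label in labels]


print(to_onehot_py([1, 2, 3]))
# [[0, 1, 0, 0], [0, 0, 1, 0], [0, 0, 0, 1]]
```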
torchmetrics.utilities.exceptions¶
TorchMetricsUserError¶
TorchMetricsUserWarning¶
TorchMetrics Governance¶
This document describes governance processes we follow in developing TorchMetrics.
Persons of Interest¶
Leads¶
Nicki Skafte (skaftenicki)
Jirka Borovec (Borda)
Justus Schock (justusschock)
Core Maintainers¶
Daniel Stancl (stancld)
Luca Di Liello (lucadiliello)
Changsheng Quan (quancs)
Alumni¶
Ananya Harsh Jha (ananyahjha93)
Teddy Koker (teddykoker)
Maxim Grechkin (maximsch2)
Releases¶
We release a new minor version (e.g., 0.5.0) every few months and bugfix releases if needed. The minor versions contain new features, API changes, deprecations, removals, potential backward-incompatible changes and also all previous bugfixes included in any bugfix release. With every release, we publish a changelog where we list additions, removals, changed functionality and fixes.
Project Management and Decision Making¶
The decision of what goes into a release is governed by the staff contributors and leaders of TorchMetrics development. Whenever possible, discussion happens publicly on GitHub and includes the whole community. When a consensus is reached, staff and core contributors assign milestones and labels to the issue and/or pull request and start tracking the development. It is possible that priorities change over time.
Commits are added to the project exclusively through pull requests on GitHub, and anyone in the community is welcome to review them. However, reviews submitted by code owners carry more weight, and the approval of code owners is required before a pull request can be merged. Additional requirements may apply case by case.
API Evolution¶
TorchMetrics development is driven by research and best practices in the rapidly developing field of AI and machine learning. Change is inevitable, and when it happens, the TorchMetrics team is committed to minimizing user friction and maximizing ease of transition from one version to the next. We take backward compatibility and reproducibility very seriously.
For API removal, renaming or other forms of backward-incompatible changes, the procedure is:
A deprecation process is initiated at version X, producing warning messages at runtime and in the documentation.
Calls to the deprecated API remain unchanged in their function during the deprecation phase.
One minor version later, at version X+1, the breaking change takes effect.
The “X+1” rule is a recommendation and not a strict requirement. Longer deprecation cycles may apply for some cases.
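As an illustration of the first two steps, a deprecated name can be kept as a thin alias that warns at runtime but otherwise behaves exactly like its replacement. The function names below are hypothetical, and the sketch uses only the standard-library warnings module; TorchMetrics may rely on its own helpers for this.

```python
import warnings


def new_metric(x: float) -> float:
    """Hypothetical new API."""
    return 2 * x


def old_metric(x: float) -> float:
    """Hypothetical deprecated alias, scheduled for removal in version X+1."""
    warnings.warn(
        "`old_metric` is deprecated since vX and will be removed in vX+1;"
        " use `new_metric` instead.",
        FutureWarning,
        stacklevel=2,
    )
    # Behaviour stays unchanged during the deprecation phase.
    return new_metric(x)


with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    assert old_metric(3.0) == new_metric(3.0)
print(caught[0].category.__name__)  # FutureWarning
```

FutureWarning is chosen here because Python shows it to end users by default, whereas DeprecationWarning is ignored by default unless triggered from code in __main__.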
Contributor Covenant Code of Conduct¶
Our Pledge¶
In the interest of fostering an open and welcoming environment, we as contributors and maintainers pledge to making participation in our project and our community a harassment-free experience for everyone, regardless of age, body size, disability, ethnicity, sex characteristics, gender identity and expression, level of experience, education, socio-economic status, nationality, personal appearance, race, religion, or sexual identity and orientation.
Our Standards¶
Examples of behavior that contributes to creating a positive environment include:
Using welcoming and inclusive language
Being respectful of differing viewpoints and experiences
Gracefully accepting constructive criticism
Focusing on what is best for the community
Showing empathy towards other community members
Examples of unacceptable behavior by participants include:
The use of sexualized language or imagery and unwelcome sexual attention or advances
Trolling, insulting/derogatory comments, and personal or political attacks
Public or private harassment
Publishing others’ private information, such as a physical or electronic address, without explicit permission
Other conduct which could reasonably be considered inappropriate in a professional setting
Our Responsibilities¶
Project maintainers are responsible for clarifying the standards of acceptable behavior and are expected to take appropriate and fair corrective action in response to any instances of unacceptable behavior.
Project maintainers have the right and responsibility to remove, edit, or reject comments, commits, code, wiki edits, issues, and other contributions that are not aligned to this Code of Conduct, or to ban temporarily or permanently any contributor for other behaviors that they deem inappropriate, threatening, offensive, or harmful.
Scope¶
This Code of Conduct applies both within project spaces and in public spaces when an individual is representing the project or its community. Examples of representing a project or community include using an official project e-mail address, posting via an official social media account, or acting as an appointed representative at an online or offline event. Representation of a project may be further defined and clarified by project maintainers.
Enforcement¶
Instances of abusive, harassing, or otherwise unacceptable behavior may be reported by contacting the project team at waf2107@columbia.edu. All complaints will be reviewed and investigated and will result in a response that is deemed necessary and appropriate to the circumstances. The project team is obligated to maintain confidentiality with regard to the reporter of an incident. Further details of specific enforcement policies may be posted separately.
Project maintainers who do not follow or enforce the Code of Conduct in good faith may face temporary or permanent repercussions as determined by other members of the project’s leadership.
Attribution¶
This Code of Conduct is adapted from the Contributor Covenant, version 1.4, available at https://www.contributor-covenant.org/version/1/4/code-of-conduct.html
For answers to common questions about this code of conduct, see https://www.contributor-covenant.org/faq
Contributing¶
Welcome to the TorchMetrics community! We’re building the largest collection of native PyTorch metrics, with the goal of reducing boilerplate and increasing reproducibility.
Contribution Types¶
We are always looking for help implementing new features or fixing bugs.
Bug Fixes:¶
If you find a bug, please submit a GitHub issue.
Make sure the title explains the issue.
Describe your setup, what you are trying to do, expected vs. actual behaviour. Please add configs and code samples.
Add details on how to reproduce the issue: a minimal test case is always best, and a Colab notebook is also great. Note that the sample code should be minimal and, if it needs data, use publicly available data.
Try to fix it or recommend a solution. We highly recommend using a test-driven approach:
Convert your minimal code example to a unit/integration test with assert on expected results.
Start by debugging the issue… You can run just this particular test in your IDE and draft a fix.
Verify that your test case fails on the master branch and only passes with the fix applied.
Submit a PR!
Note: even if you do not find the solution, sending a PR with a test covering the issue is a valid contribution, and we can help you or finish it with you :]
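For example, a minimal reproduction can be turned into pytest-style tests like the sketch below. Both mean_abs_error and the bug it guards against (an empty batch raising ZeroDivisionError) are hypothetical and only show the shape of such a test. Without the early return, the first test fails on the unfixed code, which is exactly what you want to verify before submitting the fix.

```python
from typing import List


def mean_abs_error(preds: List[float], target: List[float]) -> float:
    """Hypothetical metric under test, already including the fix."""
    if not preds:
        # The fix: an empty batch returns 0.0 instead of dividing by zero.
        return 0.0
    return sum(abs(p - t) for p, t in zip(preds, target)) / len(preds)


def test_mean_abs_error_empty_batch() -> None:
    # Minimal reproduction from the issue, turned into an assertion.
    assert mean_abs_error([], []) == 0.0


def test_mean_abs_error_known_value() -> None:
    # Regression guard: the usual case must keep working.
    assert mean_abs_error([1.0, 2.0], [2.0, 4.0]) == 1.5
```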
New Features:¶
Submit a GitHub issue describing the motivation for the feature (adding a use case or an example is helpful).
Let’s discuss to determine the feature scope.
Submit a PR! We recommend a test-driven approach to adding new features as well:
Write a test for the functionality you want to add.
Write the functional code until the test passes.
Add/update the relevant tests!
This PR is a good example of adding a new metric
Test cases:¶
Want to keep TorchMetrics healthy? Love seeing those green tests? So do we! How do we keep it that way? We write tests! We value test contributions even more than new features. One of the core values of TorchMetrics is that our users can trust our metric implementations. We can only guarantee this if our metrics are well tested.
Guidelines¶
Development scripts¶
The following make targets can be run from the project root (Unix only):
make clean
cleans the repo from temp/generated files
make docs
builds the documentation under docs/build/html
make test
runs all the project’s tests with coverage
Original code¶
All added or edited code must be the contributor’s own original work.
If you use a third-party implementation, all such blocks/functions/modules must be properly referenced and, if
possible, also approved by the code’s author. For example: This code is inspired from http://...
.
If you add new dependencies, make sure that they are compatible with the actual TorchMetrics license
(i.e. dependencies should be at least as permissive as the TorchMetrics license).
Coding Style¶
Use f-strings for output formatting (except for logging, where we stay with lazy logging.info("Hello %s!", name)).
You can use pre-commit to make sure your code style is correct.
Documentation¶
We use Sphinx with the Napoleon extension and follow the Google docstring style with type annotations.
See the following short example of a sample function taking one positional integer and an optional float:
from typing import Optional


def my_func(param_a: int, param_b: Optional[float] = None) -> str:
    """Sample function.

    Args:
        param_a: first parameter
        param_b: second parameter

    Return:
        sum of both numbers, as a string

    Example:
        Sample doctest example...
        >>> my_func(1, 2)
        '3'

    .. note:: If you want to add something.
    """
    p = param_b if param_b else 0
    return str(param_a + p)
When updating the docs, make sure to build them locally first and visually inspect the HTML files in the browser for formatting errors. In certain cases, a missing blank line or a wrong indent can lead to a broken layout. Run
make docs
and open docs/build/html/index.html in your browser.
Notes:
You need to have LaTeX installed for rendering math equations. You can, for example, install TeXLive by doing one of the following:
on Ubuntu (Linux), run apt-get install texlive, or otherwise follow the instructions on the TeXLive website
use the RTD docker image
with the class meta used by PL, you need to use Python 3.7 or higher
When you send a PR the continuous integration will run tests and build the docs.
Testing¶
Local: Testing your work locally will help you speed up the process since it allows you to focus on particular (failing) test cases. To set up a local development environment, install both local and test dependencies:
python -m pip install -r requirements/test.txt
python -m pip install pre-commit
You can run the full test suite in your terminal via this make target:
make test
# or natively
python -m pytest torchmetrics tests
Note: if your computer has neither multiple GPUs nor a TPU, the tests that require them are skipped.
GitHub Actions: For convenience, you can also use GitHub Actions on your own fork, which will be triggered with each commit. This is useful if you do not test against all required dependency versions locally.
Changelog¶
All notable changes to this project will be documented in this file.
The format is based on Keep a Changelog, and this project adheres to Semantic Versioning.
Note: we move fast, but we still preserve backward compatibility across 0.1 versions (one feature release).
[1.1.2] - 2023-09-11¶
[1.1.2] - Fixed¶
Fixed tie breaking in ndcg metric (#2031)
Fixed bug in BootStrapper when very few samples were evaluated that could lead to a crash (#2052)
Fixed bug when creating multiple plots that led to not all plots being shown (#2060)
Fixed performance issues in RecallAtFixedPrecision for large batch sizes (#2042)
Fixed bug related to MetricCollection used with custom metrics that have prefix/postfix attributes (#2070)
[1.1.1] - 2023-08-29¶
[1.1.1] - Added¶
Added average argument to MeanAveragePrecision (#2018)
[1.1.1] - Fixed¶
Fixed bug in PearsonCorrCoef when it is updated on single samples at a time (#2019)
Fixed support for pixelwise MSE (#2017)
Fixed bug in MetricCollection when used with multiple metrics that return dicts with the same keys (#2027)
Fixed bug in detection intersection metrics when class_metrics=True resulting in wrong values (#1924)
Fixed missing attributes higher_is_better, is_differentiable for some metrics (#2028)
[1.1.0] - 2023-08-22¶
[1.1.0] - Added¶
Added source aggregated signal-to-distortion ratio (SA-SDR) metric (#1882)
Added VisualInformationFidelity to image package (#1830)
Added EditDistance to text package (#1906)
Added top_k argument to RetrievalMRR in retrieval package (#1961)
Added support for evaluating "segm" and "bbox" detection in MeanAveragePrecision at the same time (#1928)
Added PerceptualPathLength to image package (#1939)
Added support for multioutput evaluation in MeanSquaredError (#1937)
Added argument extended_summary to MeanAveragePrecision such that precision, recall, iou can be easily returned (#1983)
Added warning to ClipScore if long captions are detected and truncated (#2001)
Added CLIPImageQualityAssessment to multimodal package (#1931)
Added new property metric_state to all metrics for users to investigate currently stored tensors in memory (#2006)
[1.0.3] - 2023-08-08¶
[1.0.3] - Added¶
Added warning to MeanAveragePrecision if too many detections are observed (#1978)
[1.0.3] - Fixed¶
[1.0.2] - 2023-08-02¶
[1.0.2] - Added¶
Added warning to PearsonCorrCoeff if input has a very small variance for its given dtype (#1926)
[1.0.2] - Changed¶
Changed all non-task specific classification metrics to be true subtypes of Metric (#1963)
[1.0.2] - Fixed¶
Fixed bug in CalibrationError where calculations for double precision input were performed in float precision (#1919)
Fixed bug related to the prefix/postfix arguments in MetricCollection and ClasswiseWrapper being duplicated (#1918)
Fixed missing AUC score when plotting classification metrics that support the score argument (#1948)
[1.0.1] - 2023-07-13¶
[1.0.1] - Fixed¶
Fixed corner case when using MetricCollection together with aggregation metrics (#1896)
Fixed the use of max_fpr in AUROC metric when only one class is present (#1895)
Fixed bug related to empty predictions for IntersectionOverUnion metric (#1892)
Fixed bug related to MeanMetric and broadcasting of weights when NaNs are present (#1898)
Fixed bug related to expected input format of pycoco in MeanAveragePrecision (#1913)
[1.0.0] - 2023-07-04¶
[1.0.0] - Added¶
Added prefix and postfix arguments to ClasswiseWrapper (#1866)
Added speech-to-reverberation modulation energy ratio (SRMR) metric (#1792, #1872)
Added new global arg compute_with_cache to control caching behaviour after compute method (#1754)
Added ComplexScaleInvariantSignalNoiseRatio for audio package (#1785)
Added Running wrapper for calculating running statistics (#1752)
Added RelativeAverageSpectralError and RootMeanSquaredErrorUsingSlidingWindow to image package (#816)
Added support for SpecificityAtSensitivity Metric (#1432)
Added support for plotting of metrics through .plot() method (#1328, #1481, #1480, #1490, #1581, #1585, #1593, #1600, #1605, #1610, #1609, #1621, #1624, #1623, #1638, #1631, #1650, #1639, #1660, #1682, #1786)
Added support for plotting of audio metrics through .plot() method (#1434)
Added classes to output from MAP metric (#1419)
Added Binary group fairness metrics to classification package (#1404)
Added MinkowskiDistance to regression package (#1362)
Added pairwise_minkowski_distance to pairwise package (#1362)
Added PSNRB metric (#1421)
Added ClassificationTask Enum and use in metrics (#1479)
Added ignore_index option to exact_match metric (#1540)
Added parameter top_k to RetrievalMAP (#1501)
Added support for deterministic evaluation on GPU for metrics that use torch.cumsum operator (#1499)
Added support for plotting of aggregation metrics through .plot() method (#1485)
Added support for Python 3.11 (#1612)
Added support for auto clamping of input for metrics that use the data_range argument (#1606)
Added ModifiedPanopticQuality metric to detection package (#1627)
Added PrecisionAtFixedRecall metric to classification package (#1683)
Added multiple metrics to detection package (#1284)
IntersectionOverUnion
GeneralizedIntersectionOverUnion
CompleteIntersectionOverUnion
DistanceIntersectionOverUnion
Added MultitaskWrapper to wrapper package (#1762)
Added RelativeSquaredError metric to regression package (#1765)
Added MemorizationInformedFrechetInceptionDistance metric to image package (#1580)
[1.0.0] - Changed¶
Changed permutation_invariant_training to allow using a 'permutation-wise' metric function (#1794)
Changed update_count and update_called from private to public methods (#1370)
Raise exception for invalid kwargs in Metric base class (#1427)
Extend EnumStr raising ValueError for invalid value (#1479)
Improve speed and memory consumption of binned PrecisionRecallCurve with large number of samples (#1493)
Changed __iter__ method from raising NotImplementedError to TypeError by setting to None (#1538)
FID metric will now raise an error if too few samples are provided (#1655)
Allowed FID with torch.float64 (#1628)
Changed LPIPS implementation to no longer rely on third-party package (#1575)
Changed FID matrix square root calculation from scipy to torch (#1708)
Changed calculation in PearsonCorrCoeff to be more robust in certain cases (#1729)
Changed MeanAveragePrecision to pycocotools backend (#1832)
[1.0.0] - Deprecated¶
[1.0.0] - Removed¶
Support for Python 3.7 (#1640)
[1.0.0] - Fixed¶
Fixed support in MetricTracker for MultioutputWrapper and nested structures (#1608)
Fixed restrictive check in PearsonCorrCoef (#1649)
Fixed integration with jsonargparse and LightningCLI (#1651)
Fixed corner case in calibration error for zero confidence input (#1648)
Fix precision-recall curve based computations for float target (#1642)
Fixed missing kwarg squeeze in MultiOutputWrapper (#1675)
Fixed padding removal for 3d input in MSSSIM (#1674)
Fixed max_det_threshold in MAP detection (#1712)
Fixed states being saved in metrics that use register_buffer (#1728)
Fixed states not being correctly synced and device transferred in MeanAveragePrecision for iou_type="segm" (#1763)
Fixed use of prefix and postfix in nested MetricCollection (#1773)
Fixed ax plotting logging in MetricCollection (#1783)
Fixed lookup for punkt sources being downloaded in RougeScore (#1789)
Fixed integration with lightning for CompositionalMetric (#1761)
Fixed several bugs in SpectralDistortionIndex metric (#1808)
Fixed bug for corner cases in MatthewsCorrCoef (#1812, #1863)
Fixed support for half precision in PearsonCorrCoef (#1819)
Fixed number of bugs related to average="macro" in classification metrics (#1821)
Fixed off-by-one issue when ignore_index = num_classes + 1 in Multiclass-jaccard (#1860)
[0.11.4] - 2023-03-10¶
[0.11.4] - Fixed¶
Fixed evaluation of R2Score with near constant target (#1576)
Fixed dtype conversion when metric is submodule (#1583)
Fixed bug related to top_k>1 and ignore_index!=None in StatScores based metrics (#1589)
Fixed corner case for PearsonCorrCoef when running in ddp mode but only on single device (#1587)
Fixed overflow error for specific cases in MAP when big areas are calculated (#1607)
[0.11.3] - 2023-02-28¶
[0.11.3] - Fixed¶
[0.11.2] - 2023-02-21¶
[0.11.2] - Fixed¶
[0.11.1] - 2023-01-30¶
[0.11.1] - Fixed¶
Fixed type checking on the maximize parameter at the initialization of MetricTracker (#1428)
Fixed mixed precision autocast for SSIM metric (#1454)
Fixed checking for nltk.punkt in RougeScore if a machine is not online (#1456)
Fixed wrongly reset method in MultioutputWrapper (#1460)
Fixed dtype checking in PrecisionRecallCurve for target tensor (#1457)
[0.11.0] - 2022-11-30¶
[0.11.0] - Added¶
Added MulticlassExactMatch to classification metrics (#1343)
Added TotalVariation to image package (#978)
Added CLIPScore to new multimodal package (#1314)
Added regression metrics:
Added new nominal metrics:
Added option to pass distributed_available_fn to metrics to allow checks for custom communication backend for making dist_sync_fn actually useful (#1301)
Added normalize argument to Inception, FID, KID metrics (#1246)
[0.11.0] - Changed¶
[0.11.0] - Removed¶
[0.11.0] - Fixed¶
Fixed precision bug in pairwise_euclidean_distance (#1352)
[0.10.3] - 2022-11-16¶
[0.10.3] - Fixed¶
[0.10.2] - 2022-10-31¶
[0.10.2] - Changed¶
Changed in-place operation to out-of-place operation in pairwise_cosine_similarity (#1288)
[0.10.2] - Fixed¶
Fixed high memory usage for certain classification metrics when average='micro' (#1286)
Fixed precision problems when structural_similarity_index_measure was used with autocast (#1291)
Fixed slow performance for confusion matrix based metrics (#1302)
Fixed restrictive dtype checking in spearman_corrcoef when used with autocast (#1303)
[0.10.1] - 2022-10-21¶
[0.10.1] - Fixed¶
[0.10.0] - 2022-10-04¶
[0.10.0] - Added¶
Added a new NLP metric InfoLM (#915)
Added Perplexity metric (#922)
Added ConcordanceCorrCoef metric to regression package (#1201)
Added argument normalize to LPIPS metric (#1216)
Added support for multiprocessing of batches in PESQ metric (#1227)
Added support for multioutput in PearsonCorrCoef and SpearmanCorrCoef (#1200)
[0.10.0] - Changed¶
Classification refactor (#1054, #1143, #1145, #1151, #1159, #1163, #1167, #1175, #1189, #1197, #1215, #1195)
Changed update in FID metric to be done in online fashion to save memory (#1199)
Improved performance of retrieval metrics (#1242)
Changed SSIM and MSSSIM update to be online to reduce memory usage (#1231)
[0.10.0] - Deprecated¶
Deprecated BinnedAveragePrecision, BinnedPrecisionRecallCurve, BinnedRecallAtFixedPrecision (#1163)
BinnedAveragePrecision -> use AveragePrecision with thresholds arg
BinnedPrecisionRecallCurve -> use AveragePrecisionRecallCurve with thresholds arg
BinnedRecallAtFixedPrecision -> use RecallAtFixedPrecision with thresholds arg
Renamed and refactored LabelRankingAveragePrecision, LabelRankingLoss and CoverageError (#1167)
LabelRankingAveragePrecision -> MultilabelRankingAveragePrecision
LabelRankingLoss -> MultilabelRankingLoss
CoverageError -> MultilabelCoverageError
Deprecated KLDivergence and AUC from classification package (#1189)
KLDivergence moved to regression package
Instead of AUC use torchmetrics.utils.compute.auc
[0.10.0] - Fixed¶
[0.9.3] - 2022-08-22¶
[0.9.3] - Added¶
Added global option sync_on_compute to disable automatic synchronization when compute is called (#1107)
[0.9.3] - Fixed¶
[0.9.2] - 2022-06-29¶
[0.9.2] - Fixed¶
Fixed mAP calculation for areas with 0 predictions (#1080)
Fixed bug where avg precision state and auroc state were not merged when using MetricCollections (#1086)
Skip box conversion if no boxes are present in MeanAveragePrecision (#1097)
Fixed inconsistency in docs and code when setting average="none" in AveragePrecision metric (#1116)
[0.9.1] - 2022-06-08¶
[0.9.1] - Added¶
[0.9.1] - Fixed¶
[0.9.0] - 2022-05-30¶
[0.9.0] - Added¶
Added RetrievalPrecisionRecallCurve and RetrievalRecallAtFixedPrecision to retrieval package (#951)
Added class property full_state_update that determines whether forward should call update once or twice (#984, #1033)
Added support for nested metric collections (#1003)
Added Dice to classification package (#1021)
Added support to segmentation type segm as IOU for mean average precision (#822)
[0.9.0] - Changed¶
Renamed reduction argument to average in Jaccard score and added additional options (#874)
[0.9.0] - Removed¶
[0.9.0] - Fixed¶
Fixed non-empty state dict for a few metrics (#1012)
Fixed bug when comparing states while finding compute groups (#1022)
Fixed torch.double support in stat score metrics (#1023)
Fixed FID calculation for non-equal size real and fake input (#1028)
Fixed case where KLDivergence could output NaN (#1030)
Fixed deterministic for PyTorch<1.8 (#1035)
Fixed default value for mdmc_average in Accuracy (#1036)
Fixed missing copy of property when using compute groups in MetricCollection (#1052)
[0.8.2] - 2022-05-06¶
[0.8.2] - Fixed¶
[0.8.1] - 2022-04-27¶
[0.8.1] - Changed¶
Reimplemented the signal_distortion_ratio metric, which removed the absolute requirement of fast-bss-eval (#964)
[0.8.1] - Fixed¶
[0.8.0] - 2022-04-14¶
[0.8.0] - Added¶
Added WeightedMeanAbsolutePercentageError to regression package (#948)
Added new classification metrics:
Added new image metric:
Added support for MetricCollection in MetricTracker (#718)
Added support for 3D image and uniform kernel in StructuralSimilarityIndexMeasure (#818)
Added smart update of MetricCollection (#709)
Added ClasswiseWrapper for better logging of classification metrics with multiple output values (#832)
Added **kwargs argument for passing additional arguments to base class (#833)
Added negative ignore_index for the Accuracy metric (#362)
Added adaptive_k for the RetrievalPrecision metric (#910)
Added reset_real_features argument to image quality assessment metrics (#722)
Added new keyword argument compute_on_cpu to all metrics (#867)
[0.8.0] - Changed¶
Made num_classes in jaccard_index a required argument (#853, #914)
Added normalizer, tokenizer to ROUGE metric (#838)
Improved shape checking of permutation_invariant_training (#864)
Allowed reduction None (#891)
MetricTracker.best_metric will now give a warning when computing on metrics that do not have a best (#913)
[0.8.0] - Deprecated¶
[0.8.0] - Removed¶
Removed support for versions of PyTorch Lightning lower than v1.5 (#788)
Removed deprecated functions and warnings in Text (#773)
WER and functional.wer
Removed deprecated functions and warnings in Image (#796)
SSIM and functional.ssim
PSNR and functional.psnr
Removed deprecated functions and warnings in classification and regression (#806)
FBeta and functional.fbeta
F1 and functional.f1
Hinge and functional.hinge
IoU and functional.iou
MatthewsCorrcoef
PearsonCorrcoef
SpearmanCorrcoef
Removed deprecated functions and warnings in detection and pairwise (#804)
MAP and functional.pairwise.manhatten
Removed deprecated functions and warnings in Audio (#805)
PESQ and functional.audio.pesq
PIT and functional.audio.pit
SDR and functional.audio.sdr and functional.audio.si_sdr
SNR and functional.audio.snr and functional.audio.si_snr
STOI and functional.audio.stoi
Removed unused get_num_classes from torchmetrics.utilities.data (#914)
[0.8.0] - Fixed¶
[0.7.3] - 2022-03-23¶
[0.7.3] - Fixed¶
Fixed unsafe log operation in TweedieDeviace for power=1 (#847)
Fixed bug in MAP metric related to either no ground truth or no predictions (#884)
Fixed ConfusionMatrix, AUROC and AveragePrecision on GPU when running in deterministic mode (#900)
Fixed NaN or Inf results returned by signal_distortion_ratio (#899)
Fixed memory leak when using update method with tensor where requires_grad=True (#902)
[0.7.2] - 2022-02-10¶
[0.7.2] - Fixed¶
Minor patches in JOSS paper.
[0.7.1] - 2022-02-03¶
[0.7.1] - Changed¶
[0.7.1] - Fixed¶
[0.7.0] - 2022-01-17¶
[0.7.0] - Added¶
Added NLP metrics:
Added MultiScaleSSIM into image metrics (#679)
Added Signal to Distortion Ratio (SDR) to audio package (#565)
Added MinMaxMetric to wrappers (#556)
Added ignore_index to retrieval metrics (#676)
Added support for multi references in ROUGEScore (#680)
Added a default VSCode devcontainer configuration (#621)
[0.7.0] - Changed¶
Scalar metrics will now consistently have additional dimensions squeezed (#622)
Metrics having third party dependencies removed from global import (#463)
Untokenized input for BLEUScore stays consistent with all the other text metrics (#640)
Arguments reordered for TER, BLEUScore, SacreBLEUScore, CHRFScore: they now expect predictions first and target second (#696)
Changed dtype of metric state from torch.float to torch.long in ConfusionMatrix to accommodate larger values (#715)
Unify preds, target input argument’s naming across all text metrics (#723, #727)
bert, bleu, chrf, sacre_bleu, wip, wil, cer, ter, wer, mer, rouge, squad
[0.7.0] - Deprecated¶
Renamed IoU -> Jaccard Index (#662)
Renamed text WER metric (#714)
functional.wer -> functional.word_error_rate
WER -> WordErrorRate
Renamed correlation coefficient classes: (#710)
MatthewsCorrcoef -> MatthewsCorrCoef
PearsonCorrcoef -> PearsonCorrCoef
SpearmanCorrcoef -> SpearmanCorrCoef
Renamed audio STOI metric: (#753, #758)
audio.STOI to audio.ShortTimeObjectiveIntelligibility
functional.audio.stoi to functional.audio.short_time_objective_intelligibility
Renamed audio PESQ metrics: (#751)
functional.audio.pesq -> functional.audio.perceptual_evaluation_speech_quality
audio.PESQ -> audio.PerceptualEvaluationSpeechQuality
Renamed audio SDR metrics: (#711)
functional.sdr -> functional.signal_distortion_ratio
functional.si_sdr -> functional.scale_invariant_signal_distortion_ratio
SDR -> SignalDistortionRatio
SI_SDR -> ScaleInvariantSignalDistortionRatio
Renamed audio SNR metrics: (#712)
functional.snr -> functional.signal_noise_ratio
functional.si_snr -> functional.scale_invariant_signal_noise_ratio
SNR -> SignalNoiseRatio
SI_SNR -> ScaleInvariantSignalNoiseRatio
Renamed F-score metrics: (#731, #740)
functional.f1 -> functional.f1_score
F1 -> F1Score
functional.fbeta -> functional.fbeta_score
FBeta -> FBetaScore
Renamed Hinge metric: (#734)
functional.hinge -> functional.hinge_loss
Hinge -> HingeLoss
Renamed image PSNR metrics (#732)
functional.psnr -> functional.peak_signal_noise_ratio
PSNR -> PeakSignalNoiseRatio
Renamed audio PIT metric: (#737)
functional.pit -> functional.permutation_invariant_training
PIT -> PermutationInvariantTraining
Renamed image SSIM metric: (#747)
functional.ssim -> functional.structural_similarity_index_measure
SSIM -> StructuralSimilarityIndexMeasure
Renamed detection MAP to MeanAveragePrecision metric (#754)
Renamed Fidelity & LPIPS image metric: (#752)
image.FID -> image.FrechetInceptionDistance
image.KID -> image.KernelInceptionDistance
image.LPIPS -> image.LearnedPerceptualImagePatchSimilarity
[0.7.0] - Removed¶
[0.7.0] - Fixed¶
Fixed MetricCollection kwargs filtering when no kwargs are present in update signature (#707)
[0.6.2] - 2021-12-15¶
[0.6.2] - Fixed¶
[0.6.1] - 2021-12-06¶
[0.6.1] - Changed¶
[0.6.1] - Fixed¶
[0.6.0] - 2021-10-28¶
[0.6.0] - Added¶
Added audio metrics:
Added Information retrieval metrics:
Added NLP metrics:
Added other metrics:
Added MAP (mean average precision) metric to new detection package (#467)
Added support for float targets in nDCG metric (#437)
Added average argument to AveragePrecision metric for reducing multi-label and multi-class problems (#477)
Added MultioutputWrapper (#510)
Added metric sweeping:
Added simple aggregation metrics: SumMetric, MeanMetric, CatMetric, MinMetric, MaxMetric (#506)
Added pairwise submodule with metrics (#553)
pairwise_cosine_similarity
pairwise_euclidean_distance
pairwise_linear_similarity
pairwise_manhattan_distance
[0.6.0] - Changed¶
AveragePrecision will now output the macro average by default for multilabel and multiclass problems (#477)
half, double, float will no longer change the dtype of the metric states. Use metric.set_dtype instead (#493)
Renamed AverageMeter to MeanMetric (#506)
Changed is_differentiable from property to a constant attribute (#551)
ROC and AUROC will no longer throw an error when either the positive or negative class is missing. Instead, they return a score of 0 and give a warning
[0.6.0] - Deprecated¶
Deprecated functional.self_supervised.embedding_similarity in favour of new pairwise submodule
[0.6.0] - Removed¶
Removed dtype property (#493)
[0.6.0] - Fixed¶
Fixed bug in F1 with average='macro' and ignore_index!=None (#495)
Fixed bug in pit by using the returned first result to initialize device and type (#533)
Fixed SSIM metric using too much memory (#539)
Fixed bug where device property was not properly updated when metric was a child of a module (#542)
[0.5.1] - 2021-08-30¶
[0.5.1] - Added¶
[0.5.1] - Changed¶
Added support for float targets in nDCG metric (#437)
[0.5.1] - Removed¶
[0.5.1] - Fixed¶
Fixed ranking of samples in SpearmanCorrCoef metric (#448)
Fixed bug where compositional metrics were unable to sync because of type mismatch (#454)
Fixed metric hashing (#478)
Fixed BootStrapper metrics not working on GPU (#462)
Fixed the semantic ordering of kernel height and width in SSIM metric (#474)
[0.5.0] - 2021-08-09¶
[0.5.0] - Added¶
Added Text-related (NLP) metrics:
Added MetricTracker wrapper metric for keeping track of the same metric over multiple epochs (#238)
Added other metrics:
Added support in nDCG metric for targets with values larger than 1 (#349)
Added support for negative targets in nDCG metric (#378)
Added None as reduction option in CosineSimilarity metric (#400)
Allowed passing labels in (n_samples, n_classes) to AveragePrecision (#386)
[0.5.0] - Changed¶
Moved psnr and ssim from functional.regression.* to functional.image.* (#382)
Moved image_gradient from functional.image_gradients to functional.image.gradients (#381)
Moved R2Score from regression.r2score to regression.r2 (#371)
Pearson metric now only stores 6 statistics instead of all predictions and targets (#380)
Use torch.argmax instead of torch.topk when k=1 for better performance (#419)
Moved check for number of samples in R2 score to support single sample updating (#426)
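The Pearson change (#380) works because the correlation coefficient can be recovered from a fixed number of running statistics, so memory no longer grows with the number of samples. The sketch below is a minimal pure-Python illustration under stated assumptions: the class name StreamingPearson is hypothetical, and the six sums it keeps (count, Σx, Σy, Σx², Σy², Σxy) are one possible choice, not necessarily the six statistics TorchMetrics stores.

```python
import math


class StreamingPearson:
    """Hypothetical streaming Pearson correlation (not the TorchMetrics class).

    Keeps six running sums instead of every prediction/target pair.
    """

    def __init__(self) -> None:
        self.n = 0
        self.sum_x = 0.0
        self.sum_y = 0.0
        self.sum_xx = 0.0
        self.sum_yy = 0.0
        self.sum_xy = 0.0

    def update(self, preds, target) -> None:
        # Accumulate one batch; state size is constant however many batches arrive.
        for x, y in zip(preds, target):
            self.n += 1
            self.sum_x += x
            self.sum_y += y
            self.sum_xx += x * x
            self.sum_yy += y * y
            self.sum_xy += x * y

    def compute(self) -> float:
        # n-scaled covariance and variances derived from the running sums.
        cov = self.sum_xy - self.sum_x * self.sum_y / self.n
        var_x = self.sum_xx - self.sum_x ** 2 / self.n
        var_y = self.sum_yy - self.sum_y ** 2 / self.n
        return cov / math.sqrt(var_x * var_y)
```

Because every statistic is a plain sum, states from different batches (or devices) can be combined by adding them. Raw sums of squares can lose precision on large or badly scaled inputs, so a production implementation would typically prefer a numerically stabler update.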
[0.5.0] - Deprecated¶
[0.5.0] - Removed¶
Removed restriction that threshold has to be in (0,1) range to support logit input (#351, #401)
Removed restriction that preds could not be bigger than num_classes to support logit input (#357)
Removed modules regression.psnr and regression.ssim (#382)
Removed (#379):
function functional.mean_relative_error
num_thresholds argument in BinnedPrecisionRecallCurve
[0.5.0] - Fixed¶
Fixed bug where classification metrics with average='macro' would lead to a wrong result if a class was missing (#303)
Fixed weighted, multi-class AUROC computation to allow for 0 observations of some class, as contribution to final AUROC is 0 (#376)
Fixed that _forward_cache and _computed attributes are also moved to the correct device if metric is moved (#413)
Fixed calculation in IoU metric when using ignore_index argument (#328)
[0.4.1] - 2021-07-05¶
[0.4.1] - Changed¶
[0.4.1] - Fixed¶
Fixed DDP by adding is_sync logic to Metric (#339)
[0.4.0] - 2021-06-29¶
[0.4.0] - Added¶
Added Image-related metrics:
Added Audio metrics: SNR, SI_SDR, SI_SNR (#292)
Added other metrics:
Added add_metrics method to MetricCollection for adding additional metrics after initialization (#221)
Added pre-gather reduction in the case of dist_reduce_fx="cat" to reduce communication cost (#217)
Added better error message for AUROC when num_classes is not provided for multiclass input (#244)
Added support for unnormalized scores (e.g. logits) in Accuracy, Precision, Recall, FBeta, F1, StatScore, Hamming, ConfusionMatrix metrics (#200)
Added squared argument to MeanSquaredError for computing RMSE (#249)
Added is_differentiable property to ConfusionMatrix, F1, FBeta, Hamming, Hinge, IOU, MatthewsCorrcoef, Precision, Recall, PrecisionRecallCurve, ROC, StatScores (#253)
Added sync and sync_context methods for manually controlling when metric states are synced (#302)
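The squared argument (#249) lets one accumulated state yield either MSE or RMSE: the metric only needs the sum of squared errors and a sample count, and RMSE is the square root of their ratio. A minimal sketch of that idea; RunningMSE is a hypothetical stand-in, not the MeanSquaredError class.

```python
import math


class RunningMSE:
    """Hypothetical mean squared error accumulator with a `squared` flag."""

    def __init__(self, squared: bool = True) -> None:
        self.squared = squared
        self.sum_squared_error = 0.0
        self.total = 0

    def update(self, preds, target) -> None:
        # Only two scalars are stored, regardless of how many batches are seen.
        for p, t in zip(preds, target):
            self.sum_squared_error += (p - t) ** 2
            self.total += 1

    def compute(self) -> float:
        mse = self.sum_squared_error / self.total
        # squared=False returns the root of the mean squared error (RMSE).
        return mse if self.squared else math.sqrt(mse)
```

For example, predictions [2.0, 2.0] against targets [0.0, 0.0] give an MSE of 4.0 and, with squared=False, an RMSE of 2.0.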
[0.4.0] - Changed¶
Forward cache is reset when reset method is called (#260)
Improved per-class metric handling for imbalanced datasets for precision, recall, precision_recall, fbeta, f1, accuracy, and specificity (#204)
Decorated MetricCollection forward with torch.jit.unused (#307)
Renamed thresholds argument to binned metrics for manually controlling the thresholds (#322)
[0.4.0] - Deprecated¶
[0.4.0] - Removed¶
Removed argument is_multiclass (#319)
[0.4.0] - Fixed¶
[0.3.2] - 2021-05-10¶
[0.3.2] - Added¶
[0.3.2] - Changed¶
[0.3.2] - Removed¶
Removed numpy as direct dependency (#212)
[0.3.2] - Fixed¶
Fixed auc calculation and added tests (#197)
Fixed loading persisted metric states using load_state_dict() (#202)
Fixed PSNR not working with DDP (#214)
Fixed metric calculation with unequal batch sizes (#220)
Fixed metric concatenation for list states for zero-dim input (#229)
Fixed numerical instability in AUROC metric for large input (#230)
[0.3.1] - 2021-04-21¶
[0.3.0] - 2021-04-20¶
[0.3.0] - Added¶
Added BootStrapper to easily calculate confidence intervals for metrics (#101)
Added Binned metrics (#128)
Added metrics for Information Retrieval (PL^5032):
Added other metrics:
Added average='micro' as an option in AUROC for multilabel problems (#110)
Added multilabel support to ROC metric (#114)
Added AverageMeter for ad-hoc averages of values (#138)
Added prefix argument to MetricCollection (#70)
Added __getitem__ as metric arithmetic operation (#142)
Added property is_differentiable to metrics and tests for differentiability (#154)
Added support for average, ignore_index and mdmc_average in Accuracy metric (#166)
Added postfix arg to MetricCollection (#188)
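BootStrapper (#101) estimates the uncertainty of a metric by re-evaluating it on resampled data. As a rough illustration of the general technique only, here is a plain percentile bootstrap over a finished set of predictions; bootstrap_ci and its parameters are hypothetical, and this is neither the BootStrapper API nor necessarily its resampling scheme.

```python
import random
from statistics import mean, stdev


def bootstrap_ci(preds, target, metric_fn, num_bootstraps=1000, alpha=0.05, seed=0):
    """Hypothetical percentile bootstrap for metric_fn(preds, target).

    Returns (mean, standard deviation, (lower, upper)) of the resampled scores.
    """
    rng = random.Random(seed)  # seeded so results are reproducible
    n = len(preds)
    scores = []
    for _ in range(num_bootstraps):
        # Draw n indices with replacement and re-evaluate the metric on them.
        idx = [rng.randrange(n) for _ in range(n)]
        scores.append(metric_fn([preds[i] for i in idx], [target[i] for i in idx]))
    scores.sort()
    lower = scores[int((alpha / 2) * num_bootstraps)]
    upper = scores[int((1 - alpha / 2) * num_bootstraps) - 1]
    return mean(scores), stdev(scores), (lower, upper)
```

Any callable taking (preds, target) works as metric_fn, for example a simple accuracy such as lambda p, t: sum(a == b for a, b in zip(p, t)) / len(p).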
[0.3.0] - Changed¶
Changed ExplainedVariance from storing all preds/targets to tracking 5 statistics (#68)
Changed behaviour of confusionmatrix for multilabel data to better match multilabel_confusion_matrix from sklearn (#134)
Updated FBeta arguments (#111)
Changed reset method to use detach.clone() instead of deepcopy when resetting to default (#163)
Metrics passed as dict to MetricCollection will now always be in deterministic order (#173)
Allowed passing metrics to MetricCollection as arguments (#176)
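The ExplainedVariance change (#68) relies on explained variance, 1 - Var(y - ŷ) / Var(y), being computable from running sums alone. A minimal sketch, assuming these five statistics: count, Σe, Σe², Σy and Σy², where e = y - ŷ. The class name is hypothetical and the choice of statistics is an assumption, not taken from the TorchMetrics source.

```python
class StreamingExplainedVariance:
    """Hypothetical streaming explained variance built on five running sums."""

    def __init__(self) -> None:
        self.n = 0
        self.sum_error = 0.0
        self.sum_squared_error = 0.0
        self.sum_target = 0.0
        self.sum_squared_target = 0.0

    def update(self, preds, target) -> None:
        for p, t in zip(preds, target):
            e = t - p  # residual for this sample
            self.n += 1
            self.sum_error += e
            self.sum_squared_error += e * e
            self.sum_target += t
            self.sum_squared_target += t * t

    def compute(self) -> float:
        # Population variances from the first and second moments.
        var_error = self.sum_squared_error / self.n - (self.sum_error / self.n) ** 2
        var_target = self.sum_squared_target / self.n - (self.sum_target / self.n) ** 2
        return 1.0 - var_error / var_target
```

Splitting the same data into different batches leaves the five sums, and therefore the result, unchanged, which is what makes this kind of state cheap to accumulate and to sync.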
[0.3.0] - Deprecated¶
Renamed argument is_multiclass -> multiclass (#162)
[0.3.0] - Removed¶
Pruned remaining deprecated code (#92)
[0.3.0] - Fixed¶
[0.2.0] - 2021-03-12¶
[0.2.0] - Changed¶
[0.2.0] - Removed¶
[0.1.0] - 2021-02-22¶
Added Top-k accuracy to the Accuracy metric for (multi-dimensional) multi-class inputs via the top_k parameter (PL^4838)
Added subset accuracy to the Accuracy metric for multi-label or multi-dimensional multi-class inputs via the subset_accuracy parameter (PL^4838)
Added HammingDistance metric to compute the Hamming distance (loss) (PL^4838)
Added StatScores metric to compute the number of true positives, false positives, true negatives and false negatives (PL^4839)
Added R2Score metric (PL^5241)
Added MetricCollection (PL^4318)
Added .clone() method to metrics (PL^4318)
Added IoU class interface (PL^4704)
The Recall and Precision metrics (and their functional counterparts recall and precision) can now be generalized to Recall@K and Precision@K with the top_k parameter (PL^4842)
Added compositional metrics (PL^5464)
Added AUC/AUROC class interface (PL^5479)
Added QuantizationAwareTraining callback (PL^5706)
Added ConfusionMatrix class interface (PL^4348)
Added multiclass AUROC metric (PL^4236)
Added PrecisionRecallCurve, ROC, AveragePrecision class metrics (PL^4549)
Classification metrics overhaul (PL^4837)
Added F1 class metric (PL^4656)
Added metrics aggregation in Horovod and fixed early stopping (PL^3775)
Added persistent(mode) method to metrics to enable or disable adding metric states to state_dict (PL^4482)
Added unification of regression metrics (PL^4166)
Added persistent flag to Metric.add_state (PL^4195)
Added classification metrics (PL^4043)
Added EMB similarity (PL^3349)
Added SSIM metrics (PL^2671)
Added BLEU metrics (PL^2535)