Quick Start

TorchMetrics is a collection of 25+ PyTorch metrics implementations and an easy-to-use API to create custom metrics. It offers:

  • A standardized interface to increase reproducibility

  • Reduces boilerplate

  • Distributed-training compatible

  • Rigorously tested

  • Automatic accumulation over batches

  • Automatic synchronization between multiple devices

You can use TorchMetrics in any PyTorch model, or within PyTorch Lightning to enjoy additional features:

  • Module metrics are automatically placed on the correct device, meaning that your data will always be on the same device as your metrics.

  • Native support for logging metrics in Lightning to reduce even more boilerplate (a short sketch follows below).
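
A minimal sketch of the Lightning integration (the LitClassifier module, its layer sizes, and the optimizer are made up for illustration; the key pattern is updating the metric and logging the metric object itself):

import torch
import torchmetrics
import pytorch_lightning as pl

class LitClassifier(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(32, 5)
        # the metric is an nn.Module, so Lightning moves it to the right device
        self.train_acc = torchmetrics.Accuracy()

    def training_step(self, batch, batch_idx):
        x, y = batch
        logits = self.layer(x)
        loss = torch.nn.functional.cross_entropy(logits, y)
        # update the metric with the current batch
        self.train_acc(logits.softmax(dim=-1), y)
        # logging the metric object lets Lightning handle compute() and reset()
        self.log("train_acc", self.train_acc, on_step=True, on_epoch=True)
        return loss

    def configure_optimizers(self):
        return torch.optim.SGD(self.parameters(), lr=0.01)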

Install

You can install TorchMetrics using pip or conda:

pip install torchmetrics
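
Or with conda, via the conda-forge channel:

conda install -c conda-forge torchmetrics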

Using TorchMetrics

Functional metrics

Similar to torch.nn, most metrics have both a class-based and a functional version. The functional versions implement the basic operations required for computing each metric: they are simple Python functions that take torch.Tensor objects as input and return the corresponding metric as a torch.Tensor. The code snippet below shows a simple example for calculating the accuracy using the functional interface:

import torch
# import our library
import torchmetrics

# simulate a classification problem with 10 samples and 5 classes
preds = torch.randn(10, 5).softmax(dim=-1)  # per-sample class probabilities
target = torch.randint(5, (10,))            # integer ground-truth labels

acc = torchmetrics.functional.accuracy(preds, target)

Module metrics

Nearly all functional metrics have a corresponding class-based metric that calls its functional counterpart underneath. The class-based metrics are characterized by having one or more internal metric states (similar to the parameters of a PyTorch module) that allow them to offer additional functionality:

  • Accumulation of multiple batches

  • Automatic synchronization between multiple devices

  • Metric arithmetic (a sketch follows after the example below)

The code below shows how to use the class-based interface:

import torch
# import our library
import torchmetrics

# initialize metric
metric = torchmetrics.Accuracy()

n_batches = 10
for i in range(n_batches):
    # simulate a classification problem
    preds = torch.randn(10, 5).softmax(dim=-1)
    target = torch.randint(5, (10,))
    # metric on current batch
    acc = metric(preds, target)
    print(f"Accuracy on batch {i}: {acc}")

# metric on all batches using custom accumulation
acc = metric.compute()
print(f"Accuracy on all data: {acc}")

# Resetting the internal state so the metric is ready for new data
metric.reset()
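
Metric arithmetic composes existing metrics into new ones. A minimal sketch, assuming the default Precision and Recall settings (averaging the two here is purely illustrative):

import torch
import torchmetrics

precision = torchmetrics.Precision()
recall = torchmetrics.Recall()
# arithmetic on metric objects builds a composed metric
# that updates both operands together
composed = (precision + recall) / 2

preds = torch.randn(10, 5).softmax(dim=-1)
target = torch.randint(5, (10,))
composed.update(preds, target)
print(composed.compute())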

Implementing your own metric

Implementing your own metric is as easy as subclassing torch.nn.Module. Simply subclass Metric and do the following (a sketch follows the steps):

  1. Implement __init__, where you call self.add_state for every internal state that is needed for the metric's computations

  2. Implement the update method, where all the logic necessary for updating metric states goes

  3. Implement the compute method, where the final metric computation happens
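
Putting the three steps together, a minimal sketch of a custom accuracy metric (the MyAccuracy class is illustrative, not the library's own example):

import torch
from torchmetrics import Metric

class MyAccuracy(Metric):
    def __init__(self):
        super().__init__()
        # add_state registers the states and handles device placement,
        # resetting, and cross-device synchronization for you
        self.add_state("correct", default=torch.tensor(0), dist_reduce_fx="sum")
        self.add_state("total", default=torch.tensor(0), dist_reduce_fx="sum")

    def update(self, preds: torch.Tensor, target: torch.Tensor):
        # turn per-class probabilities into predicted class labels
        preds = preds.argmax(dim=-1)
        self.correct += (preds == target).sum()
        self.total += target.numel()

    def compute(self):
        # compute the final value from the accumulated states
        return self.correct.float() / self.total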

For practical examples and more info about implementing a metric, please see this page.
