TorchMetrics documentation

TorchMetrics is a collection of machine learning metrics for distributed, scalable PyTorch models, together with an easy-to-use API for creating custom metrics. It offers the following benefits:

  • Optimized for and compatible with distributed training

  • A standardized interface to increase reproducibility

  • Reduces boilerplate

  • Rigorously tested

  • Automatic accumulation over batches

  • Automatic synchronization between multiple devices

You can use TorchMetrics in any PyTorch model, or within PyTorch Lightning to enjoy additional features:

  • Module metrics are automatically placed on the correct device, so your data will always be on the same device as your metrics.

  • Native support for logging metrics in Lightning to reduce even more boilerplate.

Using TorchMetrics

Module metrics

import torch
import torchmetrics

# initialize metric
metric = torchmetrics.Accuracy()

n_batches = 10
for i in range(n_batches):
    # simulate a classification problem
    preds = torch.randn(10, 5).softmax(dim=-1)
    target = torch.randint(5, (10,))
    # metric on current batch
    acc = metric(preds, target)
    print(f"Accuracy on batch {i}: {acc}")

# metric on all batches using custom accumulation
acc = metric.compute()
print(f"Accuracy on all data: {acc}")

Module metric usage remains the same when using multiple GPUs or multiple nodes.

Functional metrics

import torch
import torchmetrics

# simulate a classification problem
preds = torch.randn(10, 5).softmax(dim=-1)
target = torch.randint(5, (10,))

acc = torchmetrics.functional.accuracy(preds, target)

Implementing a metric

import torch
from torchmetrics import Metric


class MyAccuracy(Metric):
    def __init__(self, dist_sync_on_step=False):
        # call `self.add_state` for every internal state that is needed for the metric's computations
        # dist_reduce_fx indicates the function that should be used to reduce
        # state from multiple processes
        super().__init__(dist_sync_on_step=dist_sync_on_step)

        self.add_state("correct", default=torch.tensor(0), dist_reduce_fx="sum")
        self.add_state("total", default=torch.tensor(0), dist_reduce_fx="sum")

    def update(self, preds: torch.Tensor, target: torch.Tensor):
        # update metric states
        # NOTE: `_input_format` is not defined in this snippet; it stands for a
        # user-supplied helper that brings `preds` and `target` into a common format
        preds, target = self._input_format(preds, target)
        assert preds.shape == target.shape

        self.correct += torch.sum(preds == target)
        self.total += target.numel()

    def compute(self):
        # compute final result
        return self.correct.float() / self.total

© Copyright (c) 2020-2021, PyTorchLightning et al. Revision caf7b0bc.
