Implementing a Metric

To implement your own custom metric, subclass the base Metric class and implement the following methods:

  • __init__(): Register each state variable by calling self.add_state(...).

  • update(): Any code needed to update the state given any inputs to the metric.

  • compute(): Computes a final value from the state of the metric.

To make a custom metric work with DDP, all you need to do is call add_state() correctly. Calling reset() restores every state variable added with add_state() to its default value.

To see how metric states are synchronized across distributed processes, refer to add_state() docs from the base Metric class.

Example implementation:

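import torch
from torchmetrics import Metric


class MyAccuracy(Metric):
    def __init__(self, dist_sync_on_step=False):
        # Initialize the base class before registering any state.
        super().__init__(dist_sync_on_step=dist_sync_on_step)

        # Both states are summed across processes when synchronized.
        self.add_state("correct", default=torch.tensor(0), dist_reduce_fx="sum")
        self.add_state("total", default=torch.tensor(0), dist_reduce_fx="sum")

    def update(self, preds: torch.Tensor, target: torch.Tensor):
        # _input_format is not defined in this snippet; it is assumed to be a
        # helper you provide that converts both inputs to a common format.
        preds, target = self._input_format(preds, target)
        assert preds.shape == target.shape

        self.correct += torch.sum(preds == target)
        self.total += target.numel()

    def compute(self):
        # Accuracy is the fraction of correct predictions seen so far.
        return self.correct.float() / self.total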

Internal implementation details

This section briefly describes how metrics work internally. We encourage looking at the source code for more info. Internally, Lightning wraps the user-defined update() and compute() methods. We do this to automatically synchronize and reduce metric states across multiple devices. More precisely, calling update() does the following internally:

  1. Clears computed cache.

  2. Calls user-defined update().

Similarly, calling compute() does the following internally:

  1. Syncs metric states between processes.

  2. Reduces the gathered metric states.

  3. Calls the user-defined compute() method on the gathered metric states.

  4. Caches the computed result.

From a user’s standpoint this has one important side effect: computed results are cached. No matter how many times compute() is called in succession, it returns the same result. The cache is only emptied on the next call to update().
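The caching behaviour can be illustrated with a small pure-Python stand-in. This is only a sketch of the idea, not the library's implementation: the class CachingMetricSketch, its _computed attribute and the compute_calls counter are hypothetical names, and the synchronization and reduction steps are left out.

```python
import functools


class CachingMetricSketch:
    """Toy stand-in for a metric; NOT the library's implementation."""

    def __init__(self):
        self.total = 0          # the only piece of "metric state"
        self.compute_calls = 0  # counts runs of the user-defined compute
        self._computed = None   # cached result; None means "no cache"
        # Wrap the user-defined methods on this instance.
        self.update = self._wrap_update(self.update)
        self.compute = self._wrap_compute(self.compute)

    def _wrap_update(self, update):
        @functools.wraps(update)
        def wrapped(*args, **kwargs):
            self._computed = None            # 1. clear the computed cache
            return update(*args, **kwargs)   # 2. call the user-defined update
        return wrapped

    def _wrap_compute(self, compute):
        @functools.wraps(compute)
        def wrapped(*args, **kwargs):
            if self._computed is not None:
                return self._computed        # reuse the cached result
            # Syncing and reducing across processes would happen here.
            self._computed = compute(*args, **kwargs)
            return self._computed
        return wrapped

    def update(self, value):
        self.total += value

    def compute(self):
        self.compute_calls += 1
        return self.total


metric = CachingMetricSketch()
metric.update(2)
metric.update(3)
first, second = metric.compute(), metric.compute()  # user compute runs once
metric.update(1)                                    # empties the cache
third = metric.compute()                            # user compute runs again
```

In this sketch the second compute() call returns the cached value without running the user-defined method, and only the following update() forces a recomputation.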

forward serves the dual purpose of both returning the metric on the current data and updating the internal metric state for accumulating over multiple batches. The forward() method achieves this by combining calls to update and compute in the following way (assuming metric is initialized with compute_on_step=True):

  1. Calls update() to update the global metric state (for accumulation over multiple batches)

  2. Caches the global state.

  3. Calls reset() to clear global metric state.

  4. Calls update() to update local metric state.

  5. Calls compute() to calculate metric for current batch.

  6. Restores the global state.

As a consequence, the user-defined update() is called twice during a single forward call: once to update the global statistics and once to get the batch statistics.
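The six steps of forward() can be sketched with a toy running-mean metric in plain Python. The class ForwardSketch, its state dictionary and the update_calls counter are hypothetical and only mirror the order of operations described above; the real base class handles tensors, devices and synchronization as well.

```python
import copy


class ForwardSketch:
    """Toy running-mean metric; a stand-in, NOT the library's code."""

    def __init__(self):
        self.update_calls = 0
        self.reset()

    def reset(self):
        self.state = {"sum": 0.0, "count": 0}

    def update(self, values):
        self.update_calls += 1
        self.state["sum"] += sum(values)
        self.state["count"] += len(values)

    def compute(self):
        return self.state["sum"] / self.state["count"]

    def forward(self, values):
        self.update(values)                       # 1. update the global state
        global_state = copy.deepcopy(self.state)  # 2. cache the global state
        self.reset()                              # 3. clear the state
        self.update(values)                       # 4. update batch-local state
        batch_value = self.compute()              # 5. metric for this batch
        self.state = global_state                 # 6. restore the global state
        return batch_value


metric = ForwardSketch()
batch_1 = metric.forward([1.0, 3.0])  # mean of the first batch only
batch_2 = metric.forward([5.0])       # mean of the second batch only
overall = metric.compute()            # mean over everything seen so far
```

Each forward() call in the sketch runs update() twice, which is why update_calls ends up at four after two batches.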

class torchmetrics.Metric(compute_on_step=True, dist_sync_on_step=False, process_group=None, dist_sync_fn=None)

Base class for all metrics present in the Metrics API.

Implements add_state(), forward(), reset() and a few other things to handle distributed synchronization and per-step metric computation.

Override update() and compute() functions to implement your own metric. Use add_state() to register metric state variables which keep track of state on each call of update() and are synchronized across processes when compute() is called.


Metric state variables can either be torch.Tensors or an empty list which can be used to store torch.Tensors.


Different metrics only override update() and not forward(). A call to update() is valid, but it won’t return the metric value at the current step. A call to forward() automatically calls update() and also returns the metric value at the current step.

  • compute_on_step (bool) – Forward only calls update() and returns None if this is set to False. default: True

  • dist_sync_on_step (bool) – Synchronize metric state across processes at each forward() before returning the value at the step.

  • process_group (Optional[Any]) – Specify the process group on which synchronization is called. default: None (which selects the entire world)

  • dist_sync_fn (Optional[Callable]) – Callback that performs the allgather operation on the metric state. When None, DDP will be used to perform the allgather.

Initializes internal Module state, shared by both nn.Module and ScriptModule.

add_state(name, default, dist_reduce_fx=None, persistent=False)

Adds metric state variable. Only used by subclasses.

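  • name (str) – The name of the state variable. The variable will then be accessible as an attribute of the metric under that name (for example, self.correct in the example implementation above).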

  • default (Union[list, Tensor]) – Default value of the state; can either be a torch.Tensor or an empty list. The state will be reset to this value when self.reset() is called.

  • dist_reduce_fx (Optional) – Function to reduce state across multiple processes in distributed mode. If value is "sum", "mean", or "cat", we will use torch.sum, torch.mean, and torch.cat respectively, each with argument dim=0. Note that the "cat" reduction only makes sense if the state is a list, and not a tensor. The user can also pass a custom function in this parameter.

  • persistent (Optional) – Whether the state will be saved as part of the module's state_dict. Default is False.


Setting dist_reduce_fx to None will return the metric state synchronized across different processes. However, there won’t be any reduction function applied to the synchronized metric state.

The metric states are synced as follows:

  • If the metric state is a torch.Tensor, the synced value will be a torch.Tensor stacked across the process dimension. The original torch.Tensor metric state retains its dimensions, and hence the synchronized output will be of shape (num_process, ...).

  • If the metric state is a list, the synced value will be a list containing the combined elements from all processes.


When passing a custom function to dist_reduce_fx, expect the synchronized metric state to follow the format discussed in the above note.
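To make the expected input of a custom dist_reduce_fx concrete, the following sketch simulates the gathered state on a single machine. The helper simulate_sync and the reduction elementwise_max are hypothetical; no real process group is involved, and the per-process states are simply passed in as a Python list.

```python
import torch


def simulate_sync(per_process_states, dist_reduce_fx=None):
    """Mimic gathering one metric state from several processes.

    per_process_states holds one entry per simulated process: either a
    tensor, or a list of tensors.
    """
    if isinstance(per_process_states[0], torch.Tensor):
        # Tensor state: stack along a new leading process dimension.
        gathered = torch.stack(per_process_states)
    else:
        # List state: combine the elements from all processes.
        gathered = [t for state in per_process_states for t in state]
    if dist_reduce_fx is None:
        return gathered  # synchronized, but not reduced
    return dist_reduce_fx(gathered)


def elementwise_max(stacked):
    # A custom reduction receives the stacked (num_process, ...) tensor.
    return stacked.max(dim=0).values


# Two simulated processes, each holding a tensor state of shape (3,).
states = [torch.tensor([1.0, 5.0, 2.0]), torch.tensor([4.0, 0.0, 2.0])]
unreduced = simulate_sync(states)                                # shape (2, 3)
reduced = simulate_sync(states, dist_reduce_fx=elementwise_max)  # shape (3,)

# Two simulated processes, each holding a list state.
list_states = [[torch.tensor(1)], [torch.tensor(2), torch.tensor(3)]]
combined = simulate_sync(list_states)                            # 3 elements
```

A custom function written against the stacked layout, like elementwise_max here, reduces over dimension 0 to get back a tensor with the shape of the original state.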

  • ValueError – If default is not a tensor or an empty list.

  • ValueError – If dist_reduce_fx is not callable or one of "mean", "sum", "cat", None.



Make a copy of the metric.


abstract compute()

Override this method to compute the final metric value from state variables synchronized across the distributed backend.



Moves all model parameters and buffers to the CPU.



Moves all model parameters and buffers to the GPU.


  • device (Union[device, int, None]) – If specified, all parameters will be copied to that device.



Casts all floating point parameters and buffers to double datatype.



Casts all floating point parameters and buffers to float datatype.


forward(*args, **kwargs)

Automatically calls update().

Returns the metric value over inputs if compute_on_step is True.



Casts all floating point parameters and buffers to half datatype.



Method for post-init to change if metric states should be saved to its state_dict.



This method automatically resets the metric state variables to their default value.


state_dict(destination=None, prefix='', keep_vars=False)

Returns a dictionary containing a whole state of the module.

Both parameters and persistent buffers (e.g. running averages) are included. Keys are corresponding parameter and buffer names.


a dictionary containing a whole state of the module



>>> module.state_dict().keys()
['bias', 'weight']

sync(dist_sync_fn=None, process_group=None, should_sync=True, distributed_available=<function jit_distributed_available>)

Sync function for manually controlling when metric states should be synced across processes.

  • dist_sync_fn (Optional[Callable]) – Function to be used to perform states synchronization

  • process_group (Optional[Any]) – Specify the process group on which synchronization is called. default: None (which selects the entire world)

  • should_sync (bool) – Whether to apply state synchronization. This only has an effect when running in a distributed setting.

  • distributed_available (Optional[Callable]) – Function to determine if we are running inside a distributed setting


sync_context(dist_sync_fn=None, process_group=None, should_sync=True, should_unsync=True, distributed_available=<function jit_distributed_available>)

Context manager to synchronize the states between processes when running in a distributed setting and restore the local cache states after yielding.

  • dist_sync_fn (Optional[Callable]) – Function to be used to perform states synchronization

  • process_group (Optional[Any]) – Specify the process group on which synchronization is called. default: None (which selects the entire world)

  • should_sync (bool) – Whether to apply state synchronization. This only has an effect when running in a distributed setting.

  • should_unsync (bool) – Whether to restore the cache state so that the metrics can continue to be accumulated.

  • distributed_available (Optional[Callable]) – Function to determine if we are running inside a distributed setting
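The sync-then-restore pattern behind this context manager can be sketched in plain Python. Everything below is a simplified, hypothetical model: SyncSketch keeps a single integer state, the values "gathered" from other processes are passed in by hand, and only the should_sync and should_unsync flags from the signature above are modelled.

```python
from contextlib import contextmanager


class SyncSketch:
    """Toy metric for one simulated process; NOT the library's code."""

    def __init__(self, local_total, other_totals):
        self.total = local_total
        # Stand-in for values that would be gathered from other processes.
        self._other_totals = other_totals
        self._local_cache = None

    def sync(self):
        self._local_cache = self.total  # remember the local state
        self.total = self.total + sum(self._other_totals)

    def unsync(self):
        self.total = self._local_cache  # revert to the local state
        self._local_cache = None

    @contextmanager
    def sync_context(self, should_sync=True, should_unsync=True):
        if should_sync:
            self.sync()
        try:
            yield
        finally:
            if should_sync and should_unsync:
                self.unsync()


metric = SyncSketch(local_total=3, other_totals=[4, 5])
with metric.sync_context():
    synced_total = metric.total  # combined value, visible inside the block
local_total = metric.total       # local value again after the block
```

Restoring the local state on exit is what lets accumulation continue afterwards without counting the other processes' contributions twice.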


to(*args, **kwargs)

Moves and/or casts the parameters and buffers.

Works similar to but also updates the metrics device and dtype properties



Casts all parameters and buffers to dst_type.


  • dst_type (type or string) – The desired type.



Unsync function for manually controlling when metric states should be reverted back to their local states.


  • should_unsync (bool) – Whether to perform unsync.


abstract update(*_, **__)

Override this method to update the state variables of your metric class.


property device: torch.device

Return the device of the metric.


property dtype: torch.dtype

Return the dtype of the metric.


Contributing your metric to Torchmetrics

Want to contribute the metric you have implemented? Great, we are always open to adding more metrics to torchmetrics as long as they serve a general purpose. However, to keep all our metrics consistent, we request that the implementation and tests get formatted in the following way:

  1. Start by reading our contribution guidelines.

  2. First implement the functional backend. This takes care of all the logic that goes into the metric. The code should be put into a single file placed under torchmetrics/functional/"domain"/"new_metric".py where domain is the type of metric (classification, regression, nlp etc.) and new_metric is the name of the metric. This file should contain the following three functions:

      1. _new_metric_update(...): everything that has to do with type/shape checking and all logic required before distributed syncing needs to go here.

      2. _new_metric_compute(...): all remaining logic.

      3. new_metric(...): essentially wraps the _update and _compute private functions into one public function that makes up the functional interface for the metric.


The functional accuracy metric is a great example of this division of logic.

  3. In a corresponding file placed in torchmetrics/"domain"/"new_metric".py create the module interface:

      1. Create a new module metric by subclassing torchmetrics.Metric.

      2. In the __init__ of the module, call self.add_state for as many metric states as are needed for the metric to properly accumulate metric statistics.

      3. The module interface should essentially call the private _new_metric_update(...) in its update method and, similarly, the _new_metric_compute(...) function in its compute method. No logic should really be implemented in the module interface. We do this to avoid maintaining duplicate code.


The module Accuracy metric that corresponds to the above functional example showcases these steps.

  4. Remember to add bindings to the different relevant __init__ files.

  5. Testing is key to keeping torchmetrics trustworthy, which is why we have a very rigid testing protocol. In most cases we require the metric to be tested against some other common framework (sklearn, scipy etc.).

      1. Create a testing file in tests/"domain"/test_"new_metric".py. Only one file is needed as it is intended to test both the functional and module interface.

      2. In that file, start by defining a number of test inputs that your metric should be evaluated on.

      3. Create a test class, class NewMetric(MetricTester), that inherits from tests.helpers.testers.MetricTester. This test class should essentially implement the test_"new_metric"_class and test_"new_metric"_fn methods that respectively test the module interface and the functional interface.

      4. The test class should be parameterized (using @pytest.mark.parametrize) by the different test inputs defined initially. Additionally, the test_"new_metric"_class method should also be parameterized with a ddp parameter such that it gets tested in a distributed setting. If your metric has additional parameters, then make sure to also parameterize these such that different combinations of inputs and parameters get tested.

      5. (optional) If your metric raises any exception, please add tests that showcase this.


The test file for accuracy metric shows how to implement such tests.
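As an illustration of the functional backend described in the steps above, here is a minimal sketch of the three-function split for a mean-absolute-error style metric. The names _my_mae_update, _my_mae_compute and my_mae are hypothetical and are not part of torchmetrics; the choice of metric is only an assumption made to keep the example short.

```python
import torch


def _my_mae_update(preds, target):
    """Shape checking plus all logic needed before distributed syncing."""
    if preds.shape != target.shape:
        raise ValueError("preds and target must have the same shape")
    sum_abs_error = torch.sum(torch.abs(preds - target))
    n_obs = target.numel()
    return sum_abs_error, n_obs


def _my_mae_compute(sum_abs_error, n_obs):
    """All remaining logic: turn accumulated statistics into the value."""
    return sum_abs_error / n_obs


def my_mae(preds, target):
    """Public functional interface wrapping the two private helpers."""
    sum_abs_error, n_obs = _my_mae_update(preds, target)
    return _my_mae_compute(sum_abs_error, n_obs)


# One-shot use of the public function.
single = my_mae(torch.tensor([1.0, 2.0, 3.0]), torch.tensor([2.0, 2.0, 5.0]))

# Accumulating over batches with the private helpers, which is the pattern a
# module metric's update() and compute() would follow.
batches = [
    (torch.tensor([1.0, 2.0]), torch.tensor([2.0, 2.0])),
    (torch.tensor([3.0]), torch.tensor([5.0])),
]
total_error, total_obs = torch.tensor(0.0), 0
for preds, target in batches:
    err, n = _my_mae_update(preds, target)
    total_error += err
    total_obs += n
accumulated = _my_mae_compute(total_error, total_obs)
```

Because the update helper returns plain sufficient statistics, the same two functions serve both the one-shot functional call and batch-wise accumulation, so neither interface has to duplicate logic.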

If you can only figure out some of the steps, do not be afraid to send a PR. We would much rather receive working metrics that are not formatted exactly like our codebase than receive none at all. Formatting can always be applied later, and we will gladly guide you and/or help implement the remaining parts :]
