Implementing a Metric

While we strive to include as many metrics as possible in torchmetrics, we cannot include them all. We have therefore made it easy to implement your own metric and, if you wish, contribute it to torchmetrics. This page guides you through the process. If you are afterwards interested in contributing your metric to torchmetrics, please read the contribution guidelines and see the section Contributing your metric to TorchMetrics.

Base interface

To implement your own custom metric, subclass the base Metric class and implement the following methods:

  • __init__(): Each state variable should be registered by calling self.add_state(...).

  • update(): Any code needed to update the state given any inputs to the metric.

  • compute(): Computes a final value from the state of the metric.

We provide the remaining interface, such as reset(), which correctly resets all metric states that have been added using add_state. You should therefore not implement reset() yourself, except in the rare case where not all state variables should be reset to their default value. Adding metric states with add_state also ensures that states are correctly synchronized in distributed settings (DDP). To see how metric states are synchronized across distributed processes, refer to the add_state() docs of the base Metric class.

Below is a basic implementation of a custom accuracy metric. In the __init__ method we add the metric states correct and total, which will be used to accumulate the number of correct predictions and the total number of predictions, respectively. In the update method we update the metric states based on the inputs to the metric. Finally, in the compute method we compute the final metric value based on the metric states.

import torch
from torch import Tensor
from torchmetrics import Metric

class MyAccuracy(Metric):
    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        # running counts, summed across processes when synchronized
        self.add_state("correct", default=torch.tensor(0), dist_reduce_fx="sum")
        self.add_state("total", default=torch.tensor(0), dist_reduce_fx="sum")

    def update(self, preds: Tensor, target: Tensor) -> None:
        # preds and target are assumed to already hold class labels;
        # add any input formatting your metric needs here
        if preds.shape != target.shape:
            raise ValueError("preds and target must have the same shape")

        self.correct += torch.sum(preds == target)
        self.total += target.numel()

    def compute(self) -> Tensor:
        return self.correct.float() / self.total

A few important things to note:

  • The dist_reduce_fx argument to add_state specifies how the metric states should be reduced across processes in distributed settings. In this case we use "sum" to sum the metric states across processes. A couple of built-in options are available: "sum", "mean", "cat", "min" or "max", but a custom reduction is also supported.

  • In update we do not return anything but instead update the metric states in-place.

  • In compute when running in distributed mode, the states would have been synced before the compute method is called. Thus self.correct and self.total will contain the sum of the metric states across all processes.

Working with list states

When initializing metric states with add_state, the default argument can either be a single tensor (as in the example above) or an empty list. Most metrics only require a single tensor to accumulate the metric states, but for metrics that need access to the individual batch states, it can be useful to use a list of tensors. In the following example we show how to implement Spearman correlation, which requires access to the individual batch states because we need to calculate the rank of the predictions and targets.

import torch
from torch import Tensor
from torchmetrics import Metric
from torchmetrics.utilities import dim_zero_cat
# private ranking helper; this import path is an assumption based on the torchmetrics source tree
from torchmetrics.functional.regression.spearman import _rank_data

class MySpearmanCorrCoef(Metric):
    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        self.add_state("preds", default=[], dist_reduce_fx="cat")
        self.add_state("target", default=[], dist_reduce_fx="cat")

    def update(self, preds: Tensor, target: Tensor) -> None:
        self.preds.append(preds)
        self.target.append(target)

    def compute(self) -> Tensor:
        # small constant guarding against division by zero (assumed value)
        eps = 1e-6
        # parse inputs
        preds = dim_zero_cat(self.preds)
        target = dim_zero_cat(self.target)
        # some intermediate computation...
        r_preds, r_target = _rank_data(preds), _rank_data(target)
        preds_diff = r_preds - r_preds.mean(0)
        target_diff = r_target - r_target.mean(0)
        cov = (preds_diff * target_diff).mean(0)
        preds_std = torch.sqrt((preds_diff * preds_diff).mean(0))
        target_std = torch.sqrt((target_diff * target_diff).mean(0))
        # finalize the computations
        corrcoef = cov / (preds_std * target_std + eps)
        return torch.clamp(corrcoef, -1.0, 1.0)

A few important things to note for this example:

  • When working with list states, the dist_reduce_fx argument to add_state should be set to "cat" to concatenate the lists of tensors across processes.

  • When working with list states, the update(...) method should append the batch states to the list.

  • In the compute method, list states behave a bit differently depending on whether you are running in distributed mode or not. In non-distributed mode the list state will be a list of tensors, while in distributed mode the list has already been concatenated into a single tensor. For this reason, we recommend always using the dim_zero_cat helper function, which standardizes the list state to a single concatenated tensor regardless of the mode.

  • Calling the reset method will clear the list state, deleting any values inserted into it. For this reason, care must be taken when referencing list states. If you require the values after your metric is reset, you must first copy the attribute to another object (e.g. using copy.deepcopy).

Metric attributes

When done implementing your own metric, there are a few properties and attributes that you may want to set to add additional functionality. The three attributes to consider are: is_differentiable, higher_is_better and full_state_update. Note that none of them are strictly required to be set for the metric to work.

from typing import Optional

from torchmetrics import Metric

class MyMetric(Metric):
    # Set to True if the metric is differentiable, else set to False
    is_differentiable: Optional[bool] = None

    # Set to True if the metric reaches its optimal value when the metric is maximized.
    # Set to False if it does so when the metric is minimized.
    higher_is_better: Optional[bool] = True

    # Set to True if the metric during 'update' requires access to the global metric
    # state for its calculations. If not, setting this to False indicates that all
    # batch states are independent and we will optimize the runtime of 'forward'
    full_state_update: bool = True

Plot interface

From torchmetrics v1.0.0 onwards, we also support plotting of metrics through the .plot() method. By default this method raises NotImplementedError, but it can be implemented by the user to provide a custom plot for the metric. For any metric that returns a simple scalar tensor, or a dict of scalar tensors, the internal ._plot method can be used; it provides the common plotting functionality for most metrics in torchmetrics.

from typing import Optional, Sequence, Union

from torch import Tensor
from torchmetrics import Metric
from torchmetrics.utilities.plot import _AX_TYPE, _PLOT_OUT_TYPE

class MyMetric(Metric):
    # set these attributes if you want to use the internal ._plot method
    # bounds are automatically added to the generated plot
    plot_lower_bound: Optional[float] = None
    plot_upper_bound: Optional[float] = None

    def plot(
        self, val: Optional[Union[Tensor, Sequence[Tensor]]] = None, ax: Optional[_AX_TYPE] = None
    ) -> _PLOT_OUT_TYPE:
        return self._plot(val, ax)

If the metric returns a more complex output, a custom implementation of the plot method is required. For more details on the plotting API, see the plotting documentation page.

Internal implementation details

This section briefly describes how metrics work internally. We encourage looking at the source code for more info. Internally, TorchMetrics wraps the user-defined update() and compute() methods. We do this to automatically synchronize and reduce metric states across multiple devices. More precisely, calling update() does the following internally:

  1. Clears computed cache.

  2. Calls user-defined update().

Similarly, calling compute() does the following internally:

  1. Syncs metric states between processes.

  2. Reduces the gathered metric states.

  3. Calls the user-defined compute() method on the gathered metric states.

  4. Caches the computed result.

From a user’s standpoint this has one important side-effect: computed results are cached. This means that no matter how many times compute is called in succession, it will continue to return the same result. The cache is only cleared on the next call to update.

forward serves the dual purpose of both returning the metric on the current data and updating the internal metric state for accumulating over multiple batches. The forward() method achieves this by combining calls to update, compute and reset. Depending on the class property full_state_update, forward can behave in two ways:

  1. If full_state_update is True it indicates that the metric during update requires access to the full metric state, and we therefore need to do two calls to update to ensure that the metric is calculated correctly:

    1. Calls update() to update the global metric state (for accumulation over multiple batches)

    2. Caches the global state.

    3. Calls reset() to clear global metric state.

    4. Calls update() to update local metric state.

    5. Calls compute() to calculate metric for current batch.

    6. Restores the global state.

  2. If full_state_update is False (default) the metric state of one batch is completely independent of the state of other batches, which means that we only need to call update once.

    1. Caches the global state.

    2. Calls reset() to reset the metric to its default state.

    3. Calls update() to update the state with local batch statistics.

    4. Calls compute() to calculate the metric for the current batch.

    5. Reduces the global state and batch state into a single state that becomes the new global state.

If implementing your own metric, we recommend trying out the metric with full_state_update class property set to both True and False. If the results are equal, then setting it to False will usually give the best performance.

class torchmetrics.Metric(**kwargs)

Base class for all metrics present in the Metrics API.

This class is inherited by all metrics and implements the following functionality:

  1. Handles the transfer of metric states to the correct device.

  2. Handles the synchronization of metric states across processes.

The three core methods of the base class are

  • add_state()

  • forward()

  • reset()

which should almost never be overwritten by child classes. Instead, the following methods should be overwritten:

  • update()

  • compute()

Parameters:

kwargs (Any) –

additional keyword arguments, see Advanced metric settings for more info.

  • compute_on_cpu: If metric state should be stored on CPU during computations. Only works for list states.

  • dist_sync_on_step: If metric state should synchronize on forward(). Default is False

  • process_group: The process group on which the synchronization is called. Default is the world.

  • dist_sync_fn: Function that performs the all-gather operation on the metric state. Default is a custom implementation that calls torch.distributed.all_gather internally.

  • distributed_available_fn: Function that checks if the distributed backend is available. Defaults to a check of torch.distributed.is_available() and torch.distributed.is_initialized().

  • sync_on_compute: If metric state should synchronize when compute is called. Default is True

  • compute_with_cache: If results from compute should be cached. Default is True

add_state(name, default, dist_reduce_fx=None, persistent=False)

Add metric state variable. Only used by subclasses.

Metric state variables are either a Tensor or an empty list, which can be appended to by the metric. Each state variable must have a unique name associated with it. State variables are accessible as attributes of the metric, i.e., if name is "my_state" then its value can be accessed from an instance metric as metric.my_state. Metric states behave like buffers and parameters of Module, as they are also updated when .to() is called. Unlike parameters and buffers, metric states are not by default saved in the module's state_dict.

Parameters:
  • name (str) – The name of the state variable. The variable will then be accessible at self.name.

  • default (Union[list, Tensor]) – Default value of the state; can either be a Tensor or an empty list. The state will be reset to this value when self.reset() is called.

  • dist_reduce_fx (Optional) – Function to reduce state across multiple processes in distributed mode. If the value is "sum", "mean", "cat", "min" or "max" we will use torch.sum, torch.mean, torch.cat, torch.min and torch.max respectively, each with argument dim=0. Note that the "cat" reduction only makes sense if the state is a list, and not a tensor. The user can also pass a custom function in this parameter.

  • persistent (Optional) – whether the state will be saved as part of the module's state_dict. Default is False.

Return type:

None

Note

Setting dist_reduce_fx to None will return the metric state synchronized across different processes. However, there won’t be any reduction function applied to the synchronized metric state.

The metric states would be synced as follows

  • If the metric state is a Tensor, the synced value will be a Tensor stacked across the process dimension. The original Tensor metric state retains its dimensions, and hence the synchronized output will be of shape (num_process, ...).

  • If the metric state is a list, the synced value will be a list containing the combined elements from all processes.

Note

When passing a custom function to dist_reduce_fx, expect the synchronized metric state to follow the format discussed in the above note.
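The format described in these notes can be imitated in a single process with plain PyTorch. The sketch below only simulates what a reduction function receives for a tensor state; it does not run any distributed code, the per-process values are made up, and value_range is a hypothetical custom reduction.

```python
import torch

# made-up states from three processes, each holding a tensor of shape (2,)
per_process_states = [
    torch.tensor([1.0, 2.0]),
    torch.tensor([3.0, 4.0]),
    torch.tensor([5.0, 6.0]),
]

# with dist_reduce_fx=None, a tensor state arrives stacked: (num_process, ...)
synced = torch.stack(per_process_states)

# the built-in "sum" option corresponds to torch.sum over the process dimension
summed = torch.sum(synced, dim=0)


def value_range(stacked: torch.Tensor) -> torch.Tensor:
    """Hypothetical custom reduction: spread of each entry across processes."""
    return stacked.max(dim=0).values - stacked.min(dim=0).values


spread = value_range(synced)
```

A function like value_range could be passed as dist_reduce_fx, since it maps the stacked (num_process, ...) tensor back to the shape of a single state.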

Note

The values inserted into a list state are deleted whenever reset() is called. This allows device memory to be automatically reallocated, but may produce unexpected effects when referencing list states. To retain such values after reset() is called, you must first copy them to another object.

Raises:
  • ValueError – If default is not a tensor or an empty list.

  • ValueError – If dist_reduce_fx is not callable or one of "mean", "sum", "cat", "min", "max" or None.

clone()

Make a copy of the metric.

Return type:

Metric

abstract compute()

Override this method to compute the final metric value.

This method will automatically synchronize state variables when running in distributed backend.

Return type:

Any

double()

Override default and prevent dtype casting.

Please use Metric.set_dtype() instead.

Return type:

Metric

float()

Override default and prevent dtype casting.

Please use Metric.set_dtype() instead.

Return type:

Metric

forward(*args, **kwargs)

Aggregate and evaluate batch input directly.

Serves the dual purpose of both computing the metric on the current batch of inputs and adding the batch statistics to the overall accumulating metric state. Input arguments are exactly the same as for the corresponding update method. The returned output is exactly the same as the output of compute.

Parameters:
  • args (Any) – Any arguments as required by the metric update method.

  • kwargs (Any) – Any keyword arguments as required by the metric update method.

Return type:

Any

Returns:

The output of the compute method evaluated on the current batch.

Raises:

TorchMetricsUserError – If the metric is already synced and forward is called again.

half()

Override default and prevent dtype casting.

Please use Metric.set_dtype() instead.

Return type:

Metric

persistent(mode=False)

Change, after initialization, whether metric states should be saved to its state_dict.

Return type:

None

plot(*_, **__)

Override this method to plot the metric value.

Return type:

Any

reset()

Reset metric state variables to their default value.

Return type:

None

set_dtype(dst_type)

Transfer all metric states to a specific dtype. Special version of the standard type method.

Parameters:

dst_type (Union[str, dtype]) – the desired type as string or dtype object

Return type:

Metric

state_dict(destination=None, prefix='', keep_vars=False)

Get the current state of the metric as a dictionary.

Parameters:
  • destination (Optional[Dict[str, Any]]) – Optional dictionary, that if provided, the state of module will be updated into the dict and the same object is returned. Otherwise, an OrderedDict will be created and returned.

  • prefix (str) – optional string, a prefix added to parameter and buffer names to compose the keys in state_dict.

  • keep_vars (bool) – by default the Tensor returned in the state dict are detached from autograd. If set to True, detaching will not be performed.

Return type:

Dict[str, Any]

sync(dist_sync_fn=None, process_group=None, should_sync=True, distributed_available=None)

Sync function for manually controlling when metric states should be synced across processes.

Parameters:
  • dist_sync_fn (Optional[Callable]) – Function to be used to perform states synchronization

  • process_group (Optional[Any]) – Specify the process group on which synchronization is called. default: None (which selects the entire world)

  • should_sync (bool) – Whether to apply state synchronization. This will have an impact only when running in a distributed setting.

  • distributed_available (Optional[Callable]) – Function to determine if we are running inside a distributed setting

Raises:

TorchMetricsUserError – If the metric is already synced and sync is called again.

Return type:

None

sync_context(dist_sync_fn=None, process_group=None, should_sync=True, should_unsync=True, distributed_available=None)

Context manager to synchronize states.

This context manager is used in distributed setting and makes sure that the local cache states are restored after yielding the synchronized state.

Parameters:
  • dist_sync_fn (Optional[Callable]) – Function to be used to perform states synchronization

  • process_group (Optional[Any]) – Specify the process group on which synchronization is called. default: None (which selects the entire world)

  • should_sync (bool) – Whether to apply state synchronization. This will have an impact only when running in a distributed setting.

  • should_unsync (bool) – Whether to restore the cache state so that the metrics can continue to be accumulated.

  • distributed_available (Optional[Callable]) – Function to determine if we are running inside a distributed setting

Return type:

Generator

type(dst_type)

Override default and prevent dtype casting.

Please use Metric.set_dtype() instead.

Return type:

Metric

unsync(should_unsync=True)

Unsync function for manually controlling when metric states should be reverted back to their local states.

Parameters:

should_unsync (bool) – Whether to perform unsync

Return type:

None

abstract update(*_, **__)

Override this method to update the state variables of your metric class.

Return type:

None

property device: device

Return the device of the metric.

property dtype: dtype

Return the default dtype of the metric.

property metric_state: Dict[str, Union[List[Tensor], Tensor]]

Get the current state of the metric.

property update_called: bool

Returns True if update or forward has been called since initialization or last reset.

property update_count: int

Get the number of times update and/or forward has been called since initialization or last reset.

Contributing your metric to TorchMetrics

Want to contribute the metric you have implemented? Great, we are always open to adding more metrics to torchmetrics as long as they serve a general purpose. However, to keep all our metrics consistent we request that the implementation and tests are formatted in the following way:

  1. Start by reading our contribution guidelines.

  2. First implement the functional backend. This takes care of all the logic that goes into the metric. The code should be put into a single file placed under src/torchmetrics/functional/"domain"/"new_metric".py where domain is the type of metric (classification, regression, text etc.) and new_metric is the name of the metric. In this file, there should be the following three functions:

    1. _new_metric_update(...): everything that has to do with type/shape checking and all logic required before distributed syncing needs to go here.

    2. _new_metric_compute(...): all remaining logic.

    3. new_metric(...): essentially wraps the _update and _compute private functions into one public function that makes up the functional interface for the metric.

Note

The functional mean squared error metric is a great example of this division of logic.

  3. In a corresponding file placed in src/torchmetrics/"domain"/"new_metric".py create the module interface:

    1. Create a new module metric by subclassing torchmetrics.Metric.

    2. In the __init__ of the module, call self.add_state for as many metric states as are needed for the metric to properly accumulate metric statistics.

    3. The module interface should essentially call the private _new_metric_update(...) in its update method and similarly the _new_metric_compute(...) function in its compute. No logic should really be implemented in the module interface. We do this to not have duplicate code to maintain.

Note

The module MeanSquaredError metric that corresponds to the above functional example showcases these steps.

  4. Remember to add bindings to the different relevant __init__ files.

  5. Testing is key to keeping torchmetrics trustworthy. This is why we have a very rigid testing protocol. This means that in most cases we require the metric to be tested against some other common framework (sklearn, scipy etc).

    1. Create a testing file in tests/unittests/"domain"/test_"new_metric".py. Only one file is needed as it is intended to test both the functional and module interface.

    2. In that file, start by defining a number of test inputs that your metric should be evaluated on.

    3. Create a test class, class NewMetric(MetricTester), that inherits from tests.helpers.testers.MetricTester. This test class should essentially implement the test_"new_metric"_class and test_"new_metric"_fn methods that respectively test the module interface and the functional interface.

    4. The test class should be parameterized (using @pytest.mark.parametrize) by the different test inputs defined initially. Additionally, the test_"new_metric"_class method should also be parameterized with a ddp parameter such that it gets tested in a distributed setting. If your metric has additional parameters, then make sure to also parameterize these such that different combinations of inputs and parameters get tested.

    5. (optional) If your metric raises any exception, please add tests that showcase this.

Note

The test file for MSE metric shows how to implement such tests.

If you can only figure out some of the steps, do not be afraid to send a PR. We would much rather receive working metrics that are not formatted exactly like our codebase than not receive any. Formatting can always be applied. We will gladly guide and/or help implement the remaining parts :]