torchmetrics.Metric

The base Metric class is an abstract base class that is used as the building block for all other Module metrics.

class torchmetrics.Metric(compute_on_step=None, **kwargs)[source]

Base class for all metrics present in the Metrics API.

Implements add_state(), forward(), reset() and a few other things to handle distributed synchronization and per-step metric computation.

Override update() and compute() functions to implement your own metric. Use add_state() to register metric state variables which keep track of state on each call of update() and are synchronized across processes when compute() is called.

Note

Metric state variables can either be torch.Tensor objects or an empty list which can be used to store torch.Tensor objects.

Note

Different metrics only override update() and not forward(). A call to update() is valid, but it won’t return the metric value at the current step. A call to forward() automatically calls update() and also returns the metric value at the current step.

Parameters
  • compute_on_step (Optional[bool]) –

    Forward only calls update() and returns None if this is set to False.

    Deprecated since version v0.8: Argument has no use anymore and will be removed v0.9.

  • dist_sync_on_step

    Synchronize metric state across processes at each forward() before returning the value at the step.

    Deprecated since version v0.8: Argument is deprecated and will be removed in v0.9 in favour of instead passing it in as keyword argument.

  • process_group

    Specify the process group on which synchronization is called. Default is None, which selects the entire world.

    Deprecated since version v0.8: Argument is deprecated and will be removed in v0.9 in favour of instead passing it in as keyword argument.

  • dist_sync_fn

    Callback that performs the allgather operation on the metric state. When None, DDP will be used to perform the allgather.

    Deprecated since version v0.8: Argument is deprecated and will be removed in v0.9 in favour of instead passing it in as keyword argument.

  • kwargs (Dict[str, Any]) –

    additional keyword arguments, see Advanced metric settings for more info.

    • compute_on_cpu: If metric state should be stored on CPU during computations. Only works for list states.

    • dist_sync_on_step: If metric state should synchronize on forward()

    • process_group: The process group on which the synchronization is called

    • dist_sync_fn: Function that performs the allgather operation on the metric state

Initializes internal Module state, shared by both nn.Module and ScriptModule.

add_state(name, default, dist_reduce_fx=None, persistent=False)[source]

Adds metric state variable. Only used by subclasses.

Parameters
  • name (str) – The name of the state variable. The variable will then be accessible at self.name.

  • default (Union[list, Tensor]) – Default value of the state; can either be a torch.Tensor or an empty list. The state will be reset to this value when self.reset() is called.

  • dist_reduce_fx (Optional) – Function to reduce state across multiple processes in distributed mode. If value is "sum", "mean", "cat", "min" or "max" we will use torch.sum, torch.mean, torch.cat, torch.min and torch.max respectively, each with argument dim=0. Note that the "cat" reduction only makes sense if the state is a list, and not a tensor. The user can also pass a custom function in this parameter.

  • persistent (Optional) – whether the state will be saved as part of the modules state_dict. Default is False.

Note

Setting dist_reduce_fx to None will return the metric state synchronized across different processes. However, there won’t be any reduction function applied to the synchronized metric state.

The metric states would be synced as follows:

  • If the metric state is a torch.Tensor, the synced value will be a stacked torch.Tensor across the process dimension. The original torch.Tensor metric state retains its dimensions, so the synchronized output will be of shape (num_process, ...).

  • If the metric state is a list, the synced value will be a list containing the combined elements from all processes.

Note

When passing a custom function to dist_reduce_fx, expect the synchronized metric state to follow the format discussed in the above note.

Raises
  • ValueError – If default is not a tensor or an empty list.

  • ValueError – If dist_reduce_fx is not callable or one of "mean", "sum", "cat", None.

Return type

None

clone()[source]

Make a copy of the metric.

Return type

Metric

abstract compute()[source]

Override this method to compute the final metric value from state variables synchronized across the distributed backend.

Return type

Any

double()[source]

Overrides the default method to prevent dtype casting.

Please use metric.set_dtype(dtype) instead.

Return type

Metric

float()[source]

Overrides the default method to prevent dtype casting.

Please use metric.set_dtype(dtype) instead.

Return type

Metric

forward(*args, **kwargs)[source]

Automatically calls update() and returns the metric value computed over the current inputs.

Return type

Any

half()[source]

Overrides the default method to prevent dtype casting.

Please use metric.set_dtype(dtype) instead.

Return type

Metric

persistent(mode=False)[source]

Post-init method to change whether metric states should be saved in the module's state_dict.

Return type

None

reset()[source]

Resets the metric state variables to their default values.

Return type

None

set_dtype(dst_type)[source]

Special version of type() for transferring all metric states to a specific dtype.

Parameters

dst_type (Union[str, torch.dtype]) – the desired dtype

Return type

Metric

state_dict(destination=None, prefix='', keep_vars=False)[source]

Returns a dictionary containing a whole state of the module.

Both parameters and persistent buffers (e.g. running averages) are included. Keys are corresponding parameter and buffer names. Parameters and buffers set to None are not included.

Returns

a dictionary containing a whole state of the module

Return type

dict

Example:

>>> module.state_dict().keys()
odict_keys(['bias', 'weight'])

sync(dist_sync_fn=None, process_group=None, should_sync=True, distributed_available=<function jit_distributed_available>)[source]

Sync function for manually controlling when metrics states should be synced across processes.

Parameters
  • dist_sync_fn (Optional[Callable]) – Function to be used to perform states synchronization

  • process_group (Optional[Any]) – Specify the process group on which synchronization is called. default: None (which selects the entire world)

  • should_sync (bool) – Whether to perform state synchronization. This has an effect only when running in a distributed setting.

  • distributed_available (Optional[Callable]) – Function to determine if we are running inside a distributed setting

Return type

None

sync_context(dist_sync_fn=None, process_group=None, should_sync=True, should_unsync=True, distributed_available=<function jit_distributed_available>)[source]

Context manager to synchronize the states between processes when running in a distributed setting and restore the local cache states after yielding.

Parameters
  • dist_sync_fn (Optional[Callable]) – Function to be used to perform states synchronization

  • process_group (Optional[Any]) – Specify the process group on which synchronization is called. default: None (which selects the entire world)

  • should_sync (bool) – Whether to perform state synchronization. This has an effect only when running in a distributed setting.

  • should_unsync (bool) – Whether to restore the cache state so that the metrics can continue to be accumulated.

  • distributed_available (Optional[Callable]) – Function to determine if we are running inside a distributed setting

Return type

Generator

type(dst_type)[source]

Overrides the default method to prevent dtype casting.

Please use metric.set_dtype(dtype) instead.

Return type

Metric

unsync(should_unsync=True)[source]

Unsync function for manually controlling when synchronized metric states should be reverted back to their local states.

Parameters

should_unsync (bool) – Whether to perform unsync

Return type

None

abstract update(*_, **__)[source]

Override this method to update the state variables of your metric class.

Return type

None

property device: torch.device[source]

Return the device of the metric.

Return type

device