torchmetrics.Metric

The base Metric class is an abstract base class that is used as the building block for all other Module metrics.

class torchmetrics.Metric(**kwargs)[source]

Base class for all metrics present in the Metrics API.

Implements add_state(), forward(), reset() and a few other things to handle distributed synchronization and per-step metric computation.

Override update() and compute() functions to implement your own metric. Use add_state() to register metric state variables which keep track of state on each call of update() and are synchronized across processes when compute() is called.

Note

Metric state variables can either be a Tensor or an empty list which can be used to store Tensors.

Note

Different metrics only override update() and not forward(). A call to update() is valid, but it won’t return the metric value at the current step. A call to forward() automatically calls update() and also returns the metric value at the current step.

Parameters

kwargs (Any) –

additional keyword arguments, see Advanced metric settings for more info.

  • compute_on_cpu: If metric state should be stored on CPU during computations. Only works for list states.

  • dist_sync_on_step: If metric state should synchronize on forward(). Default is False

  • process_group: The process group on which the synchronization is called. Default is the world.

  • dist_sync_fn: function that performs the allgather operation on the metric state. Default is a custom implementation that calls torch.distributed.all_gather internally.

  • distributed_available_fn: function that checks if the distributed backend is available. Defaults to a check of torch.distributed.is_available() and torch.distributed.is_initialized().

  • sync_on_compute: If metric state should synchronize when compute is called. Default is True

  • compute_with_cache: If results from compute should be cached. Default is False

Initializes internal Module state, shared by both nn.Module and ScriptModule.

add_state(name, default, dist_reduce_fx=None, persistent=False)[source]

Add metric state variable. Only used by subclasses.

Parameters
  • name (str) – The name of the state variable. The variable will then be accessible at self.name.

  • default (Union[list, Tensor]) – Default value of the state; can either be a Tensor or an empty list. The state will be reset to this value when self.reset() is called.

  • dist_reduce_fx (Optional) – Function to reduce state across multiple processes in distributed mode. If the value is "sum", "mean", "cat", "min" or "max", torch.sum, torch.mean, torch.cat, torch.min and torch.max are used respectively, each with argument dim=0. Note that the "cat" reduction only makes sense if the state is a list, and not a tensor. The user can also pass a custom function in this parameter.

  • persistent (Optional) – whether the state will be saved as part of the module's state_dict. Default is False.

Note

Setting dist_reduce_fx to None will return the metric state synchronized across different processes. However, there won’t be any reduction function applied to the synchronized metric state.

The metric states would be synced as follows

  • If the metric state is a Tensor, the synced value will be a Tensor stacked across the process dimension. The original Tensor metric state retains its dimensions, so the synchronized output will be of shape (num_process, ...).

  • If the metric state is a list, the synced value will be a list containing the combined elements from all processes.

Note

When passing a custom function to dist_reduce_fx, expect the synchronized metric state to follow the format discussed in the above note.
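The synchronized formats described in the notes above can be imitated without any distributed backend. The sketch below only simulates the per-process states with plain torch calls; the variable names and the max_over_processes function are hypothetical, and nothing here goes through torch.distributed.

```python
import torch

# Simulated local Tensor states from three processes, each of shape (2,).
local_states = [
    torch.tensor([1.0, 2.0]),
    torch.tensor([3.0, 4.0]),
    torch.tensor([5.0, 6.0]),
]

# With dist_reduce_fx=None, a Tensor state is described as being stacked
# along a new leading process dimension: shape (num_process, ...).
synced = torch.stack(local_states)

# The "sum" shortcut is described as torch.sum with dim=0.
summed = torch.sum(synced, dim=0)


def max_over_processes(stacked):
    # Hypothetical custom reduction: it receives the stacked tensor and
    # reduces over the process dimension.
    return stacked.max(dim=0).values


custom = max_over_processes(synced)

# Simulated list states: the synced value is described as one list that
# contains the elements from all processes.
local_lists = [[torch.tensor([1.0])], [torch.tensor([2.0]), torch.tensor([3.0])]]
combined = [t for per_process in local_lists for t in per_process]
```

A custom function passed to dist_reduce_fx should therefore expect a leading process dimension for Tensor states and a flat list for list states.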

Raises
  • ValueError – If default is not a tensor or an empty list.

  • ValueError – If dist_reduce_fx is not callable or one of "mean", "sum", "cat", "min", "max", None.

Return type

None

clone()[source]

Make a copy of the metric.

Return type

Metric

abstract compute()[source]

Override this method to compute the final metric value.

This method will automatically synchronize state variables when running in a distributed backend.

Return type

Any

double()[source]

Override default and prevent dtype casting.

Please use metric.set_dtype(dtype) instead.

Return type

Metric

float()[source]

Override default and prevent dtype casting.

Please use metric.set_dtype(dtype) instead.

Return type

Metric

forward(*args, **kwargs)[source]

Aggregate and evaluate batch input directly.

Serves the dual purpose of computing the metric on the current batch of inputs and adding the batch statistics to the overall accumulating metric state. Input arguments are exactly the same as for the corresponding update method. The returned output is exactly the same as the output of compute.

Return type

Any

half()[source]

Override default and prevent dtype casting.

Please use metric.set_dtype(dtype) instead.

Return type

Metric

persistent(mode=False)[source]

Change post-init if metric states should be saved to its state_dict.

Return type

None

plot(*_, **__)[source]

Override this method to plot the metric value.

Return type

Any

reset()[source]

Reset metric state variables to their default value.

Return type

None

set_dtype(dst_type)[source]

Transfer all metric state to a specific dtype. Special version of the standard type method.

Parameters

dst_type (type or string) – the desired type.

Return type

Metric

state_dict(destination=None, prefix='', keep_vars=False)[source]

Get the current state of the metric as a dictionary.

Parameters
  • destination (Optional[Dict[str, Any]]) – Optional dictionary. If provided, the state of the module is written into it and the same object is returned. Otherwise, an OrderedDict is created and returned.

  • prefix (str) – optional string, a prefix added to parameter and buffer names to compose the keys in state_dict.

  • keep_vars (bool) – by default, the Tensors returned in the state dict are detached from autograd. If set to True, detaching will not be performed.

Return type

Dict[str, Any]

sync(dist_sync_fn=None, process_group=None, should_sync=True, distributed_available=None)[source]

Sync function for manually controlling when metric states should be synced across processes.

Parameters
  • dist_sync_fn (Optional[Callable]) – Function to be used to perform state synchronization

  • process_group (Optional[Any]) – Specify the process group on which synchronization is called. default: None (which selects the entire world)

  • should_sync (bool) – Whether to apply state synchronization. This will have an impact only when running in a distributed setting.

  • distributed_available (Optional[Callable]) – Function to determine if we are running inside a distributed setting

Return type

None

sync_context(dist_sync_fn=None, process_group=None, should_sync=True, should_unsync=True, distributed_available=None)[source]

Context manager to synchronize states.

This context manager is used in a distributed setting and makes sure that the local cache states are restored after yielding the synchronized state.

Parameters
  • dist_sync_fn (Optional[Callable]) – Function to be used to perform state synchronization

  • process_group (Optional[Any]) – Specify the process group on which synchronization is called. default: None (which selects the entire world)

  • should_sync (bool) – Whether to apply state synchronization. This will have an impact only when running in a distributed setting.

  • should_unsync (bool) – Whether to restore the cache state so that the metrics can continue to be accumulated.

  • distributed_available (Optional[Callable]) – Function to determine if we are running inside a distributed setting

Return type

Generator

type(dst_type)[source]

Override default and prevent dtype casting.

Please use metric.set_dtype(dtype) instead.

Return type

Metric

unsync(should_unsync=True)[source]

Unsync function for manually controlling when metric states should be reverted back to their local states.

Parameters

should_unsync (bool) – Whether to perform unsync

Return type

None

abstract update(*_, **__)[source]

Override this method to update the state variables of your metric class.

Return type

None

property device: torch.device

Return the device of the metric.

Return type

device

property update_called: bool

Returns True if update or forward has been called since initialization or last reset.

Return type

bool

property update_count: int

Get the number of times update and/or forward has been called since initialization or last reset.

Return type

int
