torchmetrics.Metric¶
The base Metric class is an abstract base class that is used as the building block for all other Module metrics.
- class torchmetrics.Metric(**kwargs)[source]
Base class for all metrics present in the Metrics API.
Implements add_state(), forward(), reset() and a few other things to handle distributed synchronization and per-step metric computation.
Override update() and compute() to implement your own metric. Use add_state() to register metric state variables, which keep track of state on each call of update() and are synchronized across processes when compute() is called. A minimal subclass sketch is given after the parameter list below.
Note
Metric state variables can either be torch.Tensors or an empty list, which can be used to store torch.Tensors.
Note
Different metrics only override update() and not forward(). A call to update() is valid, but it won't return the metric value at the current step. A call to forward() automatically calls update() and also returns the metric value at the current step.
- Parameters
additional keyword arguments, see Advanced metric settings for more info:
- compute_on_cpu: If metric state should be stored on CPU during computations. Only works for list states.
- dist_sync_on_step: If metric state should synchronize on forward(). Default is False.
- process_group: The process group on which the synchronization is called. Default is the world.
- dist_sync_fn: Function that performs the allgather operation on the metric state. Default is a custom implementation that calls torch.distributed.all_gather internally.
- sync_on_compute: If metric state should synchronize when compute is called. Default is True.
- Initializes internal Module state, shared by both nn.Module and ScriptModule.
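Example (a minimal sketch of a custom metric; the class name MyAccuracy and its state names are hypothetical, not part of the library):
import torch
from torchmetrics import Metric

class MyAccuracy(Metric):
    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        # tensor states, reduced with torch.sum across processes
        self.add_state("correct", default=torch.tensor(0), dist_reduce_fx="sum")
        self.add_state("total", default=torch.tensor(0), dist_reduce_fx="sum")

    def update(self, preds, target):
        # accumulate the batch statistics into the registered states
        self.correct += (preds == target).sum()
        self.total += target.numel()

    def compute(self):
        # final value from the accumulated (and, in DDP, synchronized) states
        return self.correct.float() / self.total

metric = MyAccuracy(dist_sync_on_step=False)  # advanced settings are passed as keyword arguments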
- add_state(name, default, dist_reduce_fx=None, persistent=False)[source]
Adds metric state variable. Only used by subclasses.
- Parameters
name¶ (str) – The name of the state variable. The variable will then be accessible at self.name.
default¶ (Union[list, Tensor]) – Default value of the state; can either be a torch.Tensor or an empty list. The state will be reset to this value when self.reset() is called.
dist_reduce_fx¶ (Optional) – Function to reduce state across multiple processes in distributed mode. If value is "sum", "mean", "cat", "min" or "max" we will use torch.sum, torch.mean, torch.cat, torch.min and torch.max respectively, each with argument dim=0. Note that the "cat" reduction only makes sense if the state is a list, and not a tensor. The user can also pass a custom function in this parameter.
persistent¶ (Optional) – whether the state will be saved as part of the module's state_dict. Default is False.
Note
Setting dist_reduce_fx to None will return the metric state synchronized across different processes. However, there won't be any reduction function applied to the synchronized metric state. The metric states would be synced as follows:
- If the metric state is a torch.Tensor, the synced value will be a stacked torch.Tensor across the process dimension; the original torch.Tensor metric state retains its dimensions, so the synchronized output will be of shape (num_process, ...).
- If the metric state is a list, the synced value will be a list containing the combined elements from all processes.
Note
When passing a custom function to dist_reduce_fx, expect the synchronized metric state to follow the format discussed in the above note.
- Raises
ValueError – If default is not a tensor or an empty list.
ValueError – If dist_reduce_fx is not callable or one of "mean", "sum", "cat", None.
- Return type
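Example (a hedged sketch of the different dist_reduce_fx choices; the class CollectPreds and its state names are illustrative only):
import torch
from torchmetrics import Metric

class CollectPreds(Metric):
    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        # list state: update() appends tensors; with "cat" the lists gathered from
        # all processes are concatenated along dim=0 when the state is synced
        self.add_state("preds", default=[], dist_reduce_fx="cat")
        # tensor state, reduced with torch.sum(dim=0) across processes
        self.add_state("count", default=torch.tensor(0), dist_reduce_fx="sum")

    def update(self, preds):
        self.preds.append(preds)
        self.count += preds.numel()

    def compute(self):
        # locally the state is still a list; after a "cat" sync it is a single tensor
        return torch.cat(self.preds) if isinstance(self.preds, list) else self.preds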
- abstract compute()[source]
Override this method to compute the final metric value from state variables synchronized across the distributed backend.
- Return type
- double()[source]
Overrides the default behavior and prevents dtype casting.
Please use metric.set_dtype(dtype) instead.
- Return type
- float()[source]
Overrides the default behavior and prevents dtype casting.
Please use metric.set_dtype(dtype) instead.
- Return type
- forward(*args, **kwargs)[source]
forward serves the dual purpose of computing the metric on the current batch of inputs and adding the batch statistics to the overall accumulating metric state.
Input arguments are the exact same as the corresponding update method. The returned output is the exact same as the output of compute.
- Return type
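Example (a hedged sketch of the per-batch / per-epoch flow, reusing the hypothetical MyAccuracy subclass from the sketch above):
import torch

metric = MyAccuracy()
for _ in range(10):                    # batches of one epoch
    preds = torch.randint(0, 2, (8,))
    target = torch.randint(0, 2, (8,))
    batch_acc = metric(preds, target)  # forward(): calls update() and returns the value for this batch
epoch_acc = metric.compute()           # value accumulated over all forward()/update() calls
metric.reset()                         # restore the states to their defaults before the next epoch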
- half()[source]
Overrides the default behavior and prevents dtype casting.
Please use metric.set_dtype(dtype) instead.
- Return type
- persistent(mode=False)[source]
Method for post-init to change whether metric states should be saved to its state_dict.
- Return type
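Example (a hedged sketch, again reusing the hypothetical MyAccuracy subclass):
metric = MyAccuracy()
print(list(metric.state_dict()))  # metric states are absent by default (persistent=False)
metric.persistent(True)           # opt the states into the module's state_dict
print(list(metric.state_dict()))  # now includes 'correct' and 'total'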
- reset()[source]
This method automatically resets the metric state variables to their default value.
- Return type
- set_dtype(dst_type)[source]
Special version of type for transferring all metric states to a specific dtype.
- Parameters
dst_type¶ (Union[str, dtype]) – the desired type
- Return type
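Example (a hedged sketch; since double(), float() and half() are intentionally no-ops here, dtype changes go through set_dtype()):
import torch

metric = MyAccuracy()            # hypothetical subclass from the earlier sketch
metric.set_dtype(torch.float64)  # move all metric states to float64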
- state_dict(destination=None, prefix='', keep_vars=False)[source]
Returns a dictionary containing a whole state of the module.
Both parameters and persistent buffers (e.g. running averages) are included. Keys are corresponding parameter and buffer names. Parameters and buffers set to None are not included.
Warning
Currently state_dict() also accepts positional arguments for destination, prefix and keep_vars in order. However, this is being deprecated and keyword arguments will be enforced in future releases.
Warning
Please avoid the use of argument destination as it is not designed for end-users.
- Parameters
destination¶ (dict, optional) – If provided, the state of module will be updated into the dict and the same object is returned. Otherwise, an OrderedDict will be created and returned. Default: None.
prefix¶ (str, optional) – a prefix added to parameter and buffer names to compose the keys in state_dict. Default: ''.
keep_vars¶ (bool, optional) – by default the Tensors returned in the state dict are detached from autograd. If it's set to True, detaching will not be performed. Default: False.
- Returns
a dictionary containing a whole state of the module
- Return type
Example:
>>> module.state_dict().keys()
['bias', 'weight']
- sync(dist_sync_fn=None, process_group=None, should_sync=True, distributed_available=<function jit_distributed_available>)[source]
Sync function for manually controlling when metrics states should be synced across processes.
- Parameters
dist_sync_fn¶ (Optional[Callable]) – Function to be used to perform states synchronization
process_group¶ (Optional[Any]) – Specify the process group on which synchronization is called. Default: None (which selects the entire world)
should_sync¶ (bool) – Whether to apply state synchronization. This will have an impact only when running in a distributed setting.
distributed_available¶ (Optional[Callable]) – Function to determine if we are running inside a distributed setting
- Return type
- sync_context(dist_sync_fn=None, process_group=None, should_sync=True, should_unsync=True, distributed_available=<function jit_distributed_available>)[source]
Context manager to synchronize the states between processes when running in a distributed setting and restore the local cache states after yielding.
- Parameters
dist_sync_fn¶ (Optional[Callable]) – Function to be used to perform states synchronization
process_group¶ (Optional[Any]) – Specify the process group on which synchronization is called. Default: None (which selects the entire world)
should_sync¶ (bool) – Whether to apply state synchronization. This will have an impact only when running in a distributed setting.
should_unsync¶ (bool) – Whether to restore the cache state so that the metrics can continue to be accumulated.
distributed_available¶ (Optional[Callable]) – Function to determine if we are running inside a distributed setting
- Return type
- type(dst_type)[source]
Overrides the default behavior and prevents dtype casting.
Please use metric.set_dtype(dtype) instead.
- Return type
- unsync(should_unsync=True)[source]
Unsync function for manually controlling when metrics states should be reverted back to their local states.
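Example (a hedged sketch of manual synchronization control, reusing the hypothetical MyAccuracy subclass; the calls only gather anything when torch.distributed is initialized, otherwise they are no-ops):
import torch

metric = MyAccuracy()
metric.update(torch.tensor([1, 0, 1]), torch.tensor([1, 1, 1]))

# context manager: states are synced on entry and the local cache is restored on exit
with metric.sync_context():
    global_correct = metric.correct  # globally reduced state inside the context

# equivalent manual control
metric.sync()                        # gather (and reduce) the states across processes
global_correct = metric.correct
metric.unsync()                      # revert to the purely local states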
- abstract update(*_, **__)[source]
Override this method to update the state variables of your metric class.
- Return type
- property device: torch.device[source]
Return the device of the metric.
- Return type