torchmetrics.Metric
The base Metric class is an abstract base class that is used as the building block for all other Module metrics.
- class torchmetrics.Metric(compute_on_step=None, **kwargs)
Base class for all metrics present in the Metrics API.
Implements add_state(), forward(), reset() and a few other things to handle distributed synchronization and per-step metric computation.
Override update() and compute() functions to implement your own metric. Use add_state() to register metric state variables which keep track of state on each call of update() and are synchronized across processes when compute() is called.
Note
Metric state variables can either be torch.Tensors or an empty list, which can be used to store torch.Tensors.
Note
Individual metrics override only update(), not forward(). A call to update() is valid, but it won’t return the metric value at the current step. A call to forward() automatically calls update() and also returns the metric value at the current step.
- Parameters
compute_on_step (Optional[bool]) – Forward only calls update() and returns None if this is set to False.
Deprecated since version v0.8: Argument has no use anymore and will be removed in v0.9.
dist_sync_on_step – Synchronize metric state across processes at each forward() before returning the value at the step.
Deprecated since version v0.8: Argument is deprecated and will be removed in v0.9 in favour of passing it in as a keyword argument.
process_group – Specify the process group on which synchronization is called. Default is None, which selects the entire world.
Deprecated since version v0.8: Argument is deprecated and will be removed in v0.9 in favour of passing it in as a keyword argument.
dist_sync_fn – Callback that performs the allgather operation on the metric state. When None, DDP will be used to perform the allgather.
Deprecated since version v0.8: Argument is deprecated and will be removed in v0.9 in favour of passing it in as a keyword argument.
kwargs – Additional keyword arguments, see Advanced metric settings for more info.
- compute_on_cpu: If metric state should be stored on CPU during computations. Only works for list states.
- dist_sync_on_step: If metric state should synchronize on forward()
- process_group: The process group on which the synchronization is called
- dist_sync_fn: Function that performs the allgather operation on the metric state
- add_state(name, default, dist_reduce_fx=None, persistent=False)
Adds a metric state variable. Only used by subclasses.
- Parameters
name (str) – The name of the state variable. The variable will then be accessible at self.name.
default (Union[list, Tensor]) – Default value of the state; can either be a torch.Tensor or an empty list. The state will be reset to this value when self.reset() is called.
dist_reduce_fx (Optional) – Function to reduce state across multiple processes in distributed mode. If the value is "sum", "mean", "cat", "min" or "max", we will use torch.sum, torch.mean, torch.cat, torch.min and torch.max respectively, each with argument dim=0. Note that the "cat" reduction only makes sense if the state is a list, and not a tensor. The user can also pass a custom function in this parameter.
persistent (Optional) – Whether the state will be saved as part of the module's state_dict. Default is False.
Note
Setting dist_reduce_fx to None will return the metric state synchronized across different processes. However, there won’t be any reduction function applied to the synchronized metric state.
The metric states would be synced as follows:
- If the metric state is a torch.Tensor, the synced value will be a torch.Tensor stacked across the process dimension. The original torch.Tensor metric state retains its dimensions, and hence the synchronized output will be of shape (num_process, ...).
- If the metric state is a list, the synced value will be a list containing the combined elements from all processes.
Note
When passing a custom function to dist_reduce_fx, expect the synchronized metric state to follow the format discussed in the above note.
- Raises
ValueError – If default is not a tensor or an empty list.
ValueError – If dist_reduce_fx is neither callable nor one of "mean", "sum", "cat", "min", "max", None.
- Return type
- abstract compute()
Override this method to compute the final metric value from state variables synchronized across the distributed backend.
- Return type
- double()
Overrides the default method and prevents dtype casting.
Please use metric.set_dtype(dtype) instead.
- Return type
- float()
Overrides the default method and prevents dtype casting.
Please use metric.set_dtype(dtype) instead.
- Return type
- forward(*args, **kwargs)
Automatically calls update(). Returns the metric value over inputs if compute_on_step is True.
- Return type
- half()
Overrides the default method and prevents dtype casting.
Please use metric.set_dtype(dtype) instead.
- Return type
- persistent(mode=False)
Post-init method to change whether metric states should be saved to its state_dict.
- Return type
- reset()
This method automatically resets the metric state variables to their default value.
- Return type
- set_dtype(dst_type)
Special version of type() for transferring all metric states to a specific dtype.
- Parameters
dst_type (Union[str, dtype]) – The desired type, given as a type or a string.
- Return type
- state_dict(destination=None, prefix='', keep_vars=False)
Returns a dictionary containing a whole state of the module.
Both parameters and persistent buffers (e.g. running averages) are included. Keys are corresponding parameter and buffer names. Parameters and buffers set to None are not included. Metric states registered with persistent=True are included as well.
- Returns
a dictionary containing a whole state of the module
- Return type
Example:
>>> module.state_dict().keys()
['bias', 'weight']
- sync(dist_sync_fn=None, process_group=None, should_sync=True, distributed_available=<function jit_distributed_available>)
Sync function for manually controlling when metric states should be synced across processes.
- Parameters
dist_sync_fn (Optional[Callable]) – Function to be used to perform state synchronization.
process_group (Optional[Any]) – Specify the process group on which synchronization is called. Default: None (which selects the entire world).
should_sync (bool) – Whether to apply state synchronization. This will have an impact only when running in a distributed setting.
distributed_available (Optional[Callable]) – Function to determine if we are running inside a distributed setting.
- Return type
- sync_context(dist_sync_fn=None, process_group=None, should_sync=True, should_unsync=True, distributed_available=<function jit_distributed_available>)
Context manager to synchronize the states between processes when running in a distributed setting and restore the local cache states after yielding.
- Parameters
dist_sync_fn (Optional[Callable]) – Function to be used to perform state synchronization.
process_group (Optional[Any]) – Specify the process group on which synchronization is called. Default: None (which selects the entire world).
should_sync (bool) – Whether to apply state synchronization. This will have an impact only when running in a distributed setting.
should_unsync (bool) – Whether to restore the cache state so that the metrics can continue to be accumulated.
distributed_available (Optional[Callable]) – Function to determine if we are running inside a distributed setting.
- Return type
- type(dst_type)
Overrides the default method and prevents dtype casting.
Please use metric.set_dtype(dtype) instead.
- Return type
- unsync(should_unsync=True)
Unsync function for manually controlling when metrics states should be reverted back to their local states.
- abstract update(*_, **__)
Override this method to update the state variables of your metric class.
- Return type
- property device: torch.device
Return the device of the metric.
- Return type
device