Concatenation

Module Interface

class torchmetrics.aggregation.CatMetric(nan_strategy='warn', **kwargs)

Concatenate a stream of values.

As input to forward and update, the metric accepts the following input:

  • value (float or Tensor): a single float or a tensor of float values with arbitrary shape (...,).

As output of forward and compute, the metric returns the following output:

  • agg (Tensor): float tensor containing the concatenated values of all inputs received (not a scalar; see the example below)
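To make the stream semantics concrete, here is a minimal pure-Python sketch of what update and compute do. It is not the torchmetrics implementation: the class name CatSketch is hypothetical, and it assumes each input is either a single number or a flat (1-D) sequence of numbers, whereas CatMetric itself stores torch tensors of arbitrary shape.

```python
from typing import List, Sequence, Union

Number = Union[int, float]


class CatSketch:
    """Hypothetical pure-Python stand-in for CatMetric (not the torchmetrics code).

    Assumption: inputs are a single number or a flat (1-D) sequence of numbers.
    """

    def __init__(self) -> None:
        # One stored chunk per non-empty update() call.
        self._chunks: List[List[float]] = []

    def update(self, value: Union[Number, Sequence[Number]]) -> None:
        # Promote a bare number to a one-element chunk; cast everything to float.
        if isinstance(value, (int, float)):
            chunk = [float(value)]
        else:
            chunk = [float(v) for v in value]
        if chunk:  # the sketch skips empty inputs
            self._chunks.append(chunk)

    def compute(self) -> List[float]:
        # Concatenate all stored chunks in arrival order.
        return [v for chunk in self._chunks for v in chunk]


metric = CatSketch()
metric.update(1)
metric.update([2, 3])
print(metric.compute())  # [1.0, 2.0, 3.0]
```

Each update call appends one chunk and compute joins the chunks in arrival order, which is why the sketch gives the same values as the doctest example for the same inputs.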

Parameters:
  • nan_strategy (Union[str, float]) – how NaN values in the input are handled. Options:
      - 'error': raise a RuntimeError if any NaN values are encountered
      - 'warn': issue a warning if any NaN values are encountered, then remove them and continue
      - 'ignore': silently remove all NaN values
      - a float: replace (impute) any NaN values with this value

  • kwargs (Any) – Additional keyword arguments, see Advanced metric settings for more info.

Raises:

ValueError – If nan_strategy is not one of 'error', 'warn', 'ignore', or a float
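The nan_strategy options and the ValueError check can be sketched in plain Python as follows. This is a hypothetical helper (apply_nan_strategy is not part of torchmetrics) that operates on a list of floats. It assumes that 'warn' removes NaN values after warning, as 'ignore' does, and it validates the strategy on every call, whereas CatMetric raises the ValueError from its constructor.

```python
import math
import warnings
from typing import List, Union


def apply_nan_strategy(values: List[float], nan_strategy: Union[str, float] = "warn") -> List[float]:
    """Hypothetical sketch of the documented nan_strategy options (not library code)."""
    # Validation: only the three strings or a float are accepted.
    if isinstance(nan_strategy, str):
        if nan_strategy not in ("error", "warn", "ignore"):
            raise ValueError(f"Invalid nan_strategy: {nan_strategy!r}")
    elif not isinstance(nan_strategy, float):
        raise ValueError(f"Invalid nan_strategy: {nan_strategy!r}")

    # Nothing to do when the input has no NaN values.
    if not any(math.isnan(v) for v in values):
        return list(values)

    if nan_strategy == "error":
        raise RuntimeError("Encountered NaN values")
    if nan_strategy in ("warn", "ignore"):
        if nan_strategy == "warn":
            warnings.warn("Encountered NaN values; they will be removed", UserWarning)
        # Both 'warn' and 'ignore' drop the NaN entries in this sketch.
        return [v for v in values if not math.isnan(v)]
    # Remaining case: nan_strategy is a float used to impute NaN entries.
    return [nan_strategy if math.isnan(v) else v for v in values]


nan = float("nan")
print(apply_nan_strategy([1.0, nan, 3.0], "ignore"))  # [1.0, 3.0]
print(apply_nan_strategy([1.0, nan, 3.0], 0.0))       # [1.0, 0.0, 3.0]
```

Note that an int such as 0 is rejected by this sketch because only float imputation values pass the check; pass 0.0 instead.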

Example

>>> from torch import tensor
>>> from torchmetrics.aggregation import CatMetric
>>> metric = CatMetric()
>>> metric.update(1)
>>> metric.update(tensor([2, 3]))
>>> metric.compute()
tensor([1., 2., 3.])