Feature Sharing

Module Interface

class torchmetrics.wrappers.FeatureShare(metrics, max_cache_size=None)

Specialized metric collection that facilitates sharing features between metrics.

Certain metrics rely on an expensive underlying neural network for feature extraction when computing the metric. This wrapper lets multiple metrics share that feature extraction, which can save considerable time and memory. It does this in two ways. First, it makes a single instance of the network that all the metrics share. Second, it caches the network's input-output pairs, so subsequent calls to the network with the same input return the cached output instead of recomputing it.
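The following is a minimal pure-Python sketch of the two ideas in the paragraph above: one shared network instance, plus a bounded cache of input-output pairs. `expensive_network`, `ToyMetric`, and the use of `functools.lru_cache` are illustrative assumptions, not torchmetrics internals. Real inputs would be image tensors rather than tuples of ints.

```python
from functools import lru_cache
from typing import Callable, List, Tuple

Batch = Tuple[int, ...]
Features = Tuple[float, ...]

calls = 0  # how many times the "expensive" network actually ran


def expensive_network(batch: Batch) -> Features:
    """Hypothetical stand-in for a costly feature extractor (e.g. an Inception network)."""
    global calls
    calls += 1
    return tuple(0.5 * x for x in batch)


# One shared instance, wrapped in a bounded LRU cache keyed by the input.
# maxsize loosely plays the role of max_cache_size (an assumption of this sketch).
shared_network = lru_cache(maxsize=2)(expensive_network)


class ToyMetric:
    """Hypothetical metric that only stores the features it receives."""

    def __init__(self, network: Callable[[Batch], Features]) -> None:
        self.network = network
        self.features: List[Features] = []

    def update(self, batch: Batch) -> None:
        self.features.append(self.network(batch))


metrics = {"a": ToyMetric(shared_network), "b": ToyMetric(shared_network)}
batch = (1, 2, 3)
for metric in metrics.values():
    metric.update(batch)  # only the first update actually runs the network

print(calls)                        # 1
print(shared_network.cache_info())  # CacheInfo(hits=1, misses=1, maxsize=2, currsize=1)
```

Because the cache is keyed on the input, the second metric receives the stored features for the same batch rather than triggering another forward pass. With a smaller `maxsize`, older entries would be evicted, and the network could run again for inputs it has already seen.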

Parameters:
  • metrics (Union[Metric, Sequence[Metric], Dict[str, Metric]]) –

    One of the following:

    • list or tuple (sequence): if metrics are passed in as a list or tuple, each metric's class name is used as its key in the output dict. Therefore, two metrics of the same class cannot be combined this way.

    • dict: if metrics are passed in as a dict, each key in the dict is used as the corresponding key in the output dict. Use this format to combine multiple instances of the same metric with different parameters. Note that the keys in the output dict will be sorted alphabetically.

  • max_cache_size (Optional[int]) – maximum number of input-output pairs to cache per metric. By default, this is None, which means the cache size is set to the number of metrics in the collection. As a result, all features are cached and shared across all metrics for each batch.
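The key rules for the `metrics` argument can be illustrated with a small pure-Python sketch. `MetricA`, `MetricB`, and `output_keys` are hypothetical names used only for illustration. The actual wrapper builds its keys internally, and the exact error it raises for duplicate class names is not specified here.

```python
from typing import Dict, List, Sequence, Union


class MetricA:
    """Hypothetical placeholder metric class."""


class MetricB:
    """Hypothetical placeholder metric class."""


def output_keys(metrics: Union[Sequence[object], Dict[str, object]]) -> List[str]:
    """Hypothetical helper mirroring the documented key rules (not torchmetrics code)."""
    if isinstance(metrics, dict):
        # dict input: the user's keys are used, sorted alphabetically
        return sorted(metrics)
    # list/tuple input: each metric's class name becomes its key
    keys = [type(m).__name__ for m in metrics]
    if len(set(keys)) != len(keys):
        # two instances of one class would map to the same key
        raise ValueError("duplicate metric class names; pass a dict instead")
    return keys


print(output_keys([MetricA(), MetricB()]))                        # ['MetricA', 'MetricB']
print(output_keys({"z_large": MetricA(), "a_small": MetricA()}))  # ['a_small', 'z_large']
```

The dict form is the one to use for two configurations of the same metric class, since the list form would produce colliding keys.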

Example:
>>> import torch
>>> _ = torch.manual_seed(42)
>>> from torchmetrics.wrappers import FeatureShare
>>> from torchmetrics.image import FrechetInceptionDistance, KernelInceptionDistance
>>> # initialize the metrics
>>> fs = FeatureShare([FrechetInceptionDistance(), KernelInceptionDistance(subset_size=10, subsets=2)])
>>> # update metric
>>> fs.update(torch.randint(255, (50, 3, 64, 64), dtype=torch.uint8), real=True)
>>> fs.update(torch.randint(255, (50, 3, 64, 64), dtype=torch.uint8), real=False)
>>> # compute metric
>>> fs.compute()
{'FrechetInceptionDistance': tensor(15.1700), 'KernelInceptionDistance': (tensor(-0.0012), tensor(0.0014))}