# KL Divergence

## Module Interface

class torchmetrics.KLDivergence(log_prob=False, reduction='mean', **kwargs)

Compute the KL divergence.

$$D_{KL}(P||Q) = \sum_{x\in\mathcal{X}} P(x) \log\frac{P(x)}{Q(x)}$$

Here $P$ and $Q$ are probability distributions: $P$ usually represents a distribution over data and $Q$ is often a prior or an approximation of $P$, and $\log$ is the natural logarithm. Note that the KL divergence is not symmetric, i.e. $D_{KL}(P||Q) \neq D_{KL}(Q||P)$ in general, so it is not a true distance metric.
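To make the formula concrete, the sketch below evaluates it directly in plain Python on the same `p` and `q` used in the examples further down, and then evaluates the reversed direction to show the asymmetry. The helper `kl_divergence_py` is hypothetical and exists only for illustration; it is not part of torchmetrics.

```python
import math


def kl_divergence_py(p, q):
    """D_KL(P || Q) for two discrete distributions given as equal-length lists.

    Hypothetical illustration helper: uses the natural log and treats
    terms with P(x) = 0 as contributing 0, by the usual convention.
    """
    if len(p) != len(q):
        raise ValueError("p and q must have the same length")
    total = 0.0
    for p_x, q_x in zip(p, q):
        if p_x > 0:  # 0 * log(0 / q) is taken to be 0
            total += p_x * math.log(p_x / q_x)
    return total


p = [0.36, 0.48, 0.16]
q = [1 / 3, 1 / 3, 1 / 3]

forward = kl_divergence_py(p, q)  # D_KL(P || Q)
reverse = kl_divergence_py(q, p)  # D_KL(Q || P)
print(f"{forward:.4f}")  # 0.0853
print(f"{reverse:.4f}")  # 0.0975
```

The forward value matches the `tensor(0.0853)` returned in the examples, while swapping the arguments gives a different number.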

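As input to forward and update, the metric accepts the following input:

• p (Tensor) – data distribution with shape [N, d]

• q (Tensor) – prior or approximate distribution with shape [N, d]

As output of forward and compute, the metric returns the following output:

• kl_divergence (Tensor) – the KL divergence, reduced according to reduction (a scalar for 'mean' and 'sum', shape [N] for 'none' or None)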

Parameters:
• log_prob (bool) – whether the inputs are log-probabilities (True) or probabilities (False). If given as probabilities, the inputs are normalized so that each distribution sums to 1.

• reduction (Literal['mean', 'sum', 'none', None]) –

Determines how to reduce over the N/batch dimension:

• 'mean' [default]: Averages the scores across samples

• 'sum': Sums the scores across samples

• 'none' or None: Returns one score per sample

• kwargs (Any) – Additional keyword arguments, see Advanced metric settings for more info.
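The sketch below mirrors, in plain Python, how the two parameters are described above: with `log_prob=False` each row is rescaled to sum to 1 before the divergence is taken, with `log_prob=True` the rows are read as log-probabilities, and the per-sample scores are then reduced over the N/batch dimension. The function name `kl_batch` is hypothetical, and the sketch assumes inputs are lists of rows with shape [N, d]; it illustrates the documented behavior and is not the torchmetrics implementation.

```python
import math


def kl_batch(p_rows, q_rows, log_prob=False, reduction="mean"):
    """Hypothetical pure-Python sketch of a batched KL divergence.

    p_rows, q_rows: lists of N rows, each holding d values (shape [N, d]).
    """
    if reduction not in ("mean", "sum", "none", None):
        raise ValueError(f"unsupported reduction: {reduction!r}")
    scores = []
    for p, q in zip(p_rows, q_rows):
        if log_prob:
            # Rows are log-probabilities: sum_x exp(log P(x)) * (log P(x) - log Q(x)).
            score = sum(math.exp(lp) * (lp - lq) for lp, lq in zip(p, q))
        else:
            # Rows are probabilities (possibly unnormalized): rescale each to sum to 1.
            p_sum, q_sum = sum(p), sum(q)
            p = [v / p_sum for v in p]
            q = [v / q_sum for v in q]
            # Terms with P(x) = 0 contribute nothing by convention.
            score = sum(pv * math.log(pv / qv) for pv, qv in zip(p, q) if pv > 0)
        scores.append(score)
    if reduction == "sum":
        return sum(scores)
    if reduction == "mean":
        return sum(scores) / len(scores)
    return scores  # 'none' or None: one score per sample


# Unnormalized counts are rescaled, so they match the probabilities [0.36, 0.48, 0.16].
print(round(kl_batch([[36, 48, 16]], [[1, 1, 1]]), 4))  # 0.0853

# The same distributions passed as log-probabilities with log_prob=True.
log_p = [[math.log(v) for v in (0.36, 0.48, 0.16)]]
log_q = [[math.log(1 / 3)] * 3]
print(round(kl_batch(log_p, log_q, log_prob=True), 4))  # 0.0853
```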

Raises:
• TypeError – If log_prob is not a bool.

• ValueError – If reduction is not one of 'mean', 'sum', 'none' or None.

Note

Half precision is only supported on GPU for this metric.

Example

>>> from torch import tensor
>>> from torchmetrics.regression import KLDivergence
>>> p = tensor([[0.36, 0.48, 0.16]])
>>> q = tensor([[1/3, 1/3, 1/3]])
>>> kl_divergence = KLDivergence()
>>> kl_divergence(p, q)
tensor(0.0853)

plot(val=None, ax=None)

Plot a single or multiple values from the metric.

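Parameters:
• val – either a single result from forward or compute, or a list of such results. If not provided, compute is called and its result is plotted.

• ax – a matplotlib Axes object. If provided, the plot is added to that axis.
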
Returns:

Figure and Axes object

Raises:

ModuleNotFoundError – If matplotlib is not installed

>>> from torch import randn
>>> # Example plotting a single value
>>> from torchmetrics.regression import KLDivergence
>>> metric = KLDivergence()
>>> metric.update(randn(10,3).softmax(dim=-1), randn(10,3).softmax(dim=-1))
>>> fig_, ax_ = metric.plot()

>>> from torch import randn
>>> # Example plotting multiple values
>>> from torchmetrics.regression import KLDivergence
>>> metric = KLDivergence()
>>> values = []
>>> for _ in range(10):
...     values.append(metric(randn(10,3).softmax(dim=-1), randn(10,3).softmax(dim=-1)))
>>> fig, ax = metric.plot(values)


## Functional Interface

torchmetrics.functional.kl_divergence(p, q, log_prob=False, reduction='mean')

Compute KL divergence.

$$D_{KL}(P||Q) = \sum_{x\in\mathcal{X}} P(x) \log\frac{P(x)}{Q(x)}$$

Here $P$ and $Q$ are probability distributions: $P$ usually represents a distribution over data and $Q$ is often a prior or an approximation of $P$, and $\log$ is the natural logarithm. Note that the KL divergence is not symmetric, i.e. $D_{KL}(P||Q) \neq D_{KL}(Q||P)$ in general, so it is not a true distance metric.

Parameters:
• p (Tensor) – data distribution with shape [N, d]

• q (Tensor) – prior or approximate distribution with shape [N, d]

• log_prob (bool) – whether the inputs are log-probabilities (True) or probabilities (False). If given as probabilities, the inputs are normalized so that each distribution sums to 1.

• reduction (Literal['mean', 'sum', 'none', None]) –

Determines how to reduce over the N/batch dimension:

• 'mean' [default]: Averages the scores across samples

• 'sum': Sums the scores across samples

• 'none' or None: Returns one score per sample

Return type:

Tensor

Example

>>> from torch import tensor
>>> from torchmetrics.functional import kl_divergence
>>> p = tensor([[0.36, 0.48, 0.16]])
>>> q = tensor([[1/3, 1/3, 1/3]])
>>> kl_divergence(p, q)
tensor(0.0853)


© Copyright Copyright (c) 2020-2023, Lightning-AI et al... Revision b57bb6d3.

Built with Sphinx using a theme provided by Read the Docs.