Binned Precision Recall Curve
Module Interface
- class torchmetrics.BinnedPrecisionRecallCurve(num_classes, thresholds=100, compute_on_step=None, **kwargs)
Computes precision-recall pairs for different thresholds. Works for both binary and multiclass problems. In the multiclass case, the values are calculated with a one-vs-the-rest approach.
Computation is performed in constant memory by computing precision and recall for a fixed number of buckets/thresholds (given by thresholds; when it is an int, the thresholds are evenly distributed between 0 and 1). Forward accepts:
preds (float tensor): (N, ...) (binary) or (N, C, ...) (multiclass) tensor with probabilities, where C is the number of classes.
target (long tensor): (N, ...) or (N, C, ...) with integer labels.
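To make the constant-memory claim concrete, here is a minimal, framework-free sketch of the counting idea for a single binary column. It assumes that a sample counts as predicted positive at threshold t when its score is at least t (this matches the binary example, where the score 0 is counted as positive at threshold 0). The class name BinnedCounts is hypothetical and not part of torchmetrics; the point is that only three counters per threshold are kept, so batches can be streamed through update() without storing past predictions.

```python
from typing import List, Sequence


class BinnedCounts:
    """Hypothetical streaming counter for one binary (one-vs-rest) column.

    Memory is O(len(thresholds)): three integer counters per threshold,
    no matter how many samples are passed to update().
    """

    def __init__(self, thresholds: Sequence[float]) -> None:
        self.thresholds = list(thresholds)
        self.tps: List[int] = [0] * len(self.thresholds)
        self.fps: List[int] = [0] * len(self.thresholds)
        self.fns: List[int] = [0] * len(self.thresholds)

    def update(self, preds: Sequence[float], target: Sequence[int]) -> None:
        # Assumption: a sample is a positive prediction at threshold t when pred >= t.
        for i, t in enumerate(self.thresholds):
            for p, y in zip(preds, target):
                predicted = p >= t
                if predicted and y == 1:
                    self.tps[i] += 1
                elif predicted and y != 1:
                    self.fps[i] += 1
                elif not predicted and y == 1:
                    self.fns[i] += 1


# Five evenly spaced thresholds: [0.0, 0.25, 0.5, 0.75, 1.0]
thresholds = [i / 4 for i in range(5)]

# Feed the data in two batches ...
streamed = BinnedCounts(thresholds)
streamed.update([0.0, 0.1], [0, 1])
streamed.update([0.8, 0.4], [1, 0])

# ... or all at once; the accumulated counters are identical.
at_once = BinnedCounts(thresholds)
at_once.update([0.0, 0.1, 0.8, 0.4], [0, 1, 1, 0])
print(streamed.tps, at_once.tps)
```

Because the counters are plain sums, updating batch by batch gives the same result as a single update over all samples, which is what makes accumulation over many batches cheap.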
- Parameters
num_classes (int) – integer with number of classes. For binary, set to 1.
thresholds (Union[int, Tensor, List[float]]) – list or tensor with specific thresholds, or a number of bins from linear sampling. Using more thresholds leads to a more detailed curve and more accurate estimates, but computation is slower and consumes more memory.
compute_on_step (Optional[bool]) – Forward only calls update() and returns None if this is set to False.
Deprecated since version v0.8: Argument has no use anymore and will be removed in v0.9.
kwargs (Dict[str, Any]) – Additional keyword arguments, see Advanced metric settings for more info.
- Raises
ValueError – If thresholds is not an int, list or tensor.
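The accepted forms of thresholds can be pictured with a small normalization helper. This is a hypothetical sketch (resolve_thresholds is not a torchmetrics function). It assumes an int n means n evenly spaced values from 0 to 1 inclusive, consistent with thresholds=5 producing 0.0, 0.25, 0.5, 0.75, 1.0 in the binary example, and it leaves out the Tensor option so the code runs without PyTorch.

```python
from typing import List, Union


def resolve_thresholds(thresholds: Union[int, List[float]]) -> List[float]:
    """Hypothetical helper mirroring the documented `thresholds` options."""
    if isinstance(thresholds, int):
        # n evenly spaced values from 0 to 1 inclusive (n == 1 gives [0.0]).
        n = thresholds
        return [i / (n - 1) for i in range(n)] if n > 1 else [0.0] * n
    if isinstance(thresholds, list):
        # Explicit threshold values are used as given.
        return [float(t) for t in thresholds]
    # Any other type is rejected, like the ValueError documented above.
    raise ValueError("Expected `thresholds` to be an int or a list of floats")


print(resolve_thresholds(5))  # [0.0, 0.25, 0.5, 0.75, 1.0]
print(resolve_thresholds([0.2, 0.7]))  # [0.2, 0.7]
```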
- Example (binary case):
>>> import torch
>>> from torchmetrics import BinnedPrecisionRecallCurve
>>> pred = torch.tensor([0, 0.1, 0.8, 0.4])
>>> target = torch.tensor([0, 1, 1, 0])
>>> pr_curve = BinnedPrecisionRecallCurve(num_classes=1, thresholds=5)
>>> precision, recall, thresholds = pr_curve(pred, target)
>>> precision
tensor([0.5000, 0.5000, 1.0000, 1.0000, 1.0000, 1.0000])
>>> recall
tensor([1.0000, 0.5000, 0.5000, 0.5000, 0.0000, 0.0000])
>>> thresholds
tensor([0.0000, 0.2500, 0.5000, 0.7500, 1.0000])
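The binary example's numbers can be reproduced by hand. The sketch below is an illustration, not the library's code. It assumes a smoothing constant EPS = 1e-6 in both ratios and that a closing point (precision 1, recall 0) is appended after the last threshold, which matches the six precision and recall values returned for five thresholds.

```python
from typing import List, Sequence, Tuple

EPS = 1e-6  # assumed smoothing constant


def binned_pr_curve(
    preds: Sequence[float], target: Sequence[int], thresholds: Sequence[float]
) -> Tuple[List[float], List[float]]:
    precision: List[float] = []
    recall: List[float] = []
    for t in thresholds:
        # Counts at threshold t; a score >= t is a positive prediction.
        tp = sum(1 for p, y in zip(preds, target) if p >= t and y == 1)
        fp = sum(1 for p, y in zip(preds, target) if p >= t and y != 1)
        fn = sum(1 for p, y in zip(preds, target) if p < t and y == 1)
        # EPS keeps both ratios defined; with no positive predictions precision is 1.
        precision.append((tp + EPS) / (tp + fp + EPS))
        recall.append(tp / (tp + fn + EPS))
    # Close the curve at (precision=1, recall=0): len(thresholds) + 1 points in total.
    precision.append(1.0)
    recall.append(0.0)
    return precision, recall


preds = [0.0, 0.1, 0.8, 0.4]
target = [0, 1, 1, 0]
thresholds = [0.0, 0.25, 0.5, 0.75, 1.0]  # the values thresholds=5 produces
precision, recall = binned_pr_curve(preds, target, thresholds)
print([round(p, 4) for p in precision])  # [0.5, 0.5, 1.0, 1.0, 1.0, 1.0]
print([round(r, 4) for r in recall])  # [1.0, 0.5, 0.5, 0.5, 0.0, 0.0]
```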
- Example (multiclass case):
>>> import torch
>>> from torchmetrics import BinnedPrecisionRecallCurve
>>> pred = torch.tensor([[0.75, 0.05, 0.05, 0.05, 0.05],
...                      [0.05, 0.75, 0.05, 0.05, 0.05],
...                      [0.05, 0.05, 0.75, 0.05, 0.05],
...                      [0.05, 0.05, 0.05, 0.75, 0.05]])
>>> target = torch.tensor([0, 1, 3, 2])
>>> pr_curve = BinnedPrecisionRecallCurve(num_classes=5, thresholds=3)
>>> precision, recall, thresholds = pr_curve(pred, target)
>>> precision
[tensor([0.2500, 1.0000, 1.0000, 1.0000]), tensor([0.2500, 1.0000, 1.0000, 1.0000]), tensor([2.5000e-01, 1.0000e-06, 1.0000e+00, 1.0000e+00]), tensor([2.5000e-01, 1.0000e-06, 1.0000e+00, 1.0000e+00]), tensor([2.5000e-07, 1.0000e+00, 1.0000e+00, 1.0000e+00])]
>>> recall
[tensor([1.0000, 1.0000, 0.0000, 0.0000]), tensor([1.0000, 1.0000, 0.0000, 0.0000]), tensor([1.0000, 0.0000, 0.0000, 0.0000]), tensor([1.0000, 0.0000, 0.0000, 0.0000]), tensor([0., 0., 0., 0.])]
>>> thresholds
[tensor([0.0000, 0.5000, 1.0000]), tensor([0.0000, 0.5000, 1.0000]), tensor([0.0000, 0.5000, 1.0000]), tensor([0.0000, 0.5000, 1.0000]), tensor([0.0000, 0.5000, 1.0000])]
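For multiclass input, the one-vs-the-rest approach treats each class column as its own binary problem against a one-hot encoded target. The hypothetical one_vs_rest_curves helper below sketches that reduction in plain Python under assumed details (scores compared with >=, EPS = 1e-6, a closing (1, 0) point per class); applied to the multiclass example it gives per-class values that agree with the printed precision and recall lists up to rounding.

```python
from typing import List, Sequence, Tuple

EPS = 1e-6  # assumed smoothing constant


def one_hot(target: Sequence[int], num_classes: int) -> List[List[int]]:
    # Row i has a 1 in column target[i] and 0 elsewhere.
    return [[1 if y == c else 0 for c in range(num_classes)] for y in target]


def one_vs_rest_curves(
    preds: Sequence[Sequence[float]],
    target: Sequence[int],
    num_classes: int,
    thresholds: Sequence[float],
) -> Tuple[List[List[float]], List[List[float]]]:
    onehot = one_hot(target, num_classes)
    precisions: List[List[float]] = []
    recalls: List[List[float]] = []
    for c in range(num_classes):
        # Class c versus all other classes: its score column and 0/1 labels.
        scores = [row[c] for row in preds]
        labels = [row[c] for row in onehot]
        prec: List[float] = []
        rec: List[float] = []
        for t in thresholds:
            tp = sum(1 for s, y in zip(scores, labels) if s >= t and y == 1)
            fp = sum(1 for s, y in zip(scores, labels) if s >= t and y == 0)
            fn = sum(1 for s, y in zip(scores, labels) if s < t and y == 1)
            prec.append((tp + EPS) / (tp + fp + EPS))
            rec.append(tp / (tp + fn + EPS))
        # Each per-class curve ends at (precision=1, recall=0).
        prec.append(1.0)
        rec.append(0.0)
        precisions.append(prec)
        recalls.append(rec)
    return precisions, recalls


preds = [
    [0.75, 0.05, 0.05, 0.05, 0.05],
    [0.05, 0.75, 0.05, 0.05, 0.05],
    [0.05, 0.05, 0.75, 0.05, 0.05],
    [0.05, 0.05, 0.05, 0.75, 0.05],
]
prec, rec = one_vs_rest_curves(preds, [0, 1, 3, 2], num_classes=5, thresholds=[0.0, 0.5, 1.0])
```

Class 4 never occurs in the target, so its recall is 0 everywhere and its precision at threshold 0 is only about EPS / 4, the 2.5000e-07 entry in the example.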
- compute()
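Returns a tuple (precision, recall, thresholds). For binary problems (num_classes=1), each element is a float tensor, and precision and recall contain one more entry than thresholds, ending with a precision of 1 and a recall of 0. For multiclass problems, each element is a list with one tensor per class.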
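The values returned by compute() can be written out per class c and threshold t_i, i = 1, ..., T, where s_{n,c} is the score of sample n for class c and y_n its label. These formulas are a reconstruction from the example outputs, not quoted from the library: they assume a small smoothing constant ε (about 10^-6, consistent with the 1.0000e-06 and 2.5000e-07 entries in the multiclass example) and that a sample counts as predicted positive when its score is at least t_i.

```latex
\begin{aligned}
TP_{c,i} &= \#\{n : s_{n,c} \ge t_i,\ y_n = c\}, \quad
FP_{c,i} = \#\{n : s_{n,c} \ge t_i,\ y_n \ne c\}, \quad
FN_{c,i} = \#\{n : s_{n,c} < t_i,\ y_n = c\}, \\
\mathrm{precision}_{c,i} &= \frac{TP_{c,i} + \varepsilon}{TP_{c,i} + FP_{c,i} + \varepsilon}, \qquad
\mathrm{recall}_{c,i} = \frac{TP_{c,i}}{TP_{c,i} + FN_{c,i} + \varepsilon}, \qquad i = 1, \dots, T, \\
\mathrm{precision}_{c,T+1} &= 1, \qquad \mathrm{recall}_{c,T+1} = 0.
\end{aligned}
```

With no predicted positives (TP = FP = 0) the ε terms make precision equal 1, which is why the entries at threshold 1.0 are 1.0000 in both examples.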