Hinge Loss
Module Interface
- class torchmetrics.HingeLoss(task: Literal['binary', 'multiclass'], num_classes: Optional[int] = None, squared: bool = False, multiclass_mode: Optional[Literal['crammer-singer', 'one-vs-all']] = 'crammer-singer', ignore_index: Optional[int] = None, validate_args: bool = True, **kwargs: Any)
Computes the mean Hinge loss typically used for Support Vector Machines (SVMs).
This function is a simple wrapper that returns the task-specific version of this metric, selected by setting the task argument to either 'binary' or 'multiclass'. See the documentation of BinaryHingeLoss and MulticlassHingeLoss for the specific details of each argument's influence and for examples.
- Legacy Example:
>>> import torch
>>> from torchmetrics import HingeLoss
>>> target = torch.tensor([0, 1, 1])
>>> preds = torch.tensor([0.5, 0.7, 0.1])
>>> hinge = HingeLoss(task="binary")
>>> hinge(preds, target)
tensor(0.9000)
>>> target = torch.tensor([0, 1, 2])
>>> preds = torch.tensor([[-1.0, 0.9, 0.2], [0.5, -1.1, 0.8], [2.2, -0.5, 0.3]])
>>> hinge = HingeLoss(task="multiclass", num_classes=3)
>>> hinge(preds, target)
tensor(1.5551)
>>> target = torch.tensor([0, 1, 2])
>>> preds = torch.tensor([[-1.0, 0.9, 0.2], [0.5, -1.1, 0.8], [2.2, -0.5, 0.3]])
>>> hinge = HingeLoss(task="multiclass", num_classes=3, multiclass_mode="one-vs-all")
>>> hinge(preds, target)
tensor([1.3743, 1.1945, 1.2359])
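The wrapper itself does no computation; it only picks the task-specific metric. The plain-Python sketch below mimics that routing to show which inputs each branch needs. The helper name resolve_hinge_metric and the use of ValueError with these messages are hypothetical, and the sketch returns class names rather than constructing the real torchmetrics objects.

```python
from typing import Optional


def resolve_hinge_metric(task: str, num_classes: Optional[int] = None) -> str:
    """Hypothetical helper: name the task-specific metric that a
    HingeLoss(task=...) wrapper would delegate to."""
    if task == "binary":
        # num_classes is not needed for the binary case
        return "BinaryHingeLoss"
    if task == "multiclass":
        # MulticlassHingeLoss takes num_classes as a required argument
        if not isinstance(num_classes, int):
            raise ValueError("num_classes must be an int when task='multiclass'")
        return "MulticlassHingeLoss"
    raise ValueError(f"Unsupported task {task!r}; expected 'binary' or 'multiclass'")


print(resolve_hinge_metric("binary"))                     # BinaryHingeLoss
print(resolve_hinge_metric("multiclass", num_classes=3))  # MulticlassHingeLoss
```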
BinaryHingeLoss
- class torchmetrics.classification.BinaryHingeLoss(squared=False, ignore_index=None, validate_args=True, **kwargs)
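Computes the mean Hinge loss typically used for Support Vector Machines (SVMs) for binary tasks. It is defined as:
$$\text{Hinge loss} = \max(0, 1 - y \times \hat{y})$$
where $y \in \{-1, 1\}$ is the target, and $\hat{y} \in \mathbb{R}$ is the prediction. Targets supplied as {0, 1} labels are interpreted as $y = -1$ for 0 and $y = 1$ for 1.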
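To make the formula concrete, here is a plain-Python sketch that evaluates it directly on the inputs from the BinaryHingeLoss example. It is an illustrative reference, not the torchmetrics implementation: binary_hinge_reference is a hypothetical name, and the sketch skips the automatic sigmoid, ignore_index handling and argument validation.

```python
def binary_hinge_reference(preds, target, squared=False):
    """Sketch of mean(max(0, 1 - y * y_hat)) over a batch.

    Assumes preds are already probabilities/scores (no sigmoid step) and
    target holds 0/1 labels, which are mapped to y in {-1, +1}.
    """
    losses = []
    for y_hat, label in zip(preds, target):
        y = 1.0 if label == 1 else -1.0    # map {0, 1} -> {-1, +1}
        loss = max(0.0, 1.0 - y * y_hat)   # per-sample hinge loss
        losses.append(loss * loss if squared else loss)
    return sum(losses) / len(losses)       # mean over the batch


preds = [0.25, 0.25, 0.55, 0.75, 0.75]
target = [0, 0, 1, 1, 1]
print(round(binary_hinge_reference(preds, target), 4))                # 0.69
print(round(binary_hinge_reference(preds, target, squared=True), 4))  # 0.6905
```

The two printed values agree with tensor(0.6900) and tensor(0.6905) in the example further down.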
As input to forward and update the metric accepts the following input:
- preds (Tensor): A float tensor of shape (N, ...). Preds should be a tensor containing probabilities or logits for each observation. If preds has values outside the [0, 1] range, the input is treated as logits and sigmoid is applied automatically per element.
- target (Tensor): An int tensor of shape (N, ...). Target should be a tensor containing ground truth labels, and therefore only contain {0, 1} values (except if ignore_index is specified). The value 1 always encodes the positive class.
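The logits rule above is decided for the input as a whole: a single value outside [0, 1] means every element is passed through sigmoid. The sketch below restates that rule on a plain list; maybe_apply_sigmoid is a hypothetical helper, not a torchmetrics function.

```python
import math


def maybe_apply_sigmoid(preds):
    """Sketch of the documented rule: if any value lies outside [0, 1],
    treat the whole input as logits and apply sigmoid element-wise."""
    if all(0.0 <= p <= 1.0 for p in preds):
        return list(preds)  # already probabilities, left unchanged
    return [1.0 / (1.0 + math.exp(-p)) for p in preds]


print(maybe_apply_sigmoid([0.2, 0.9]))        # [0.2, 0.9]
print(maybe_apply_sigmoid([-2.0, 0.0, 3.0]))  # every element passed through sigmoid
```

Note that with mixed input such as [0.5, 1.5], the in-range 0.5 is also transformed, because the check looks at the whole input rather than at each value.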
Note
Additional dimensions ... will be flattened into the batch dimension.
As output to forward and compute the metric returns the following output:
- bhl (Tensor): A tensor containing the hinge loss.
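Because the module interface separates update from compute, one instance can see several batches before it reports a value. The sketch below illustrates that pattern, assuming the metric keeps a running sum of per-sample losses plus a running sample count and divides only in compute(). RunningBinaryHinge is a hypothetical plain-Python class, not the torchmetrics Metric base class, and it omits sigmoid, ignore_index and validation.

```python
class RunningBinaryHinge:
    """Hypothetical stateful sketch: accumulate per-sample hinge losses over
    several update() calls, then average once in compute()."""

    def __init__(self, squared=False):
        self.squared = squared
        self.loss_sum = 0.0  # running sum of per-sample losses
        self.total = 0       # running number of samples seen

    def update(self, preds, target):
        for y_hat, label in zip(preds, target):
            y = 1.0 if label == 1 else -1.0  # map {0, 1} -> {-1, +1}
            loss = max(0.0, 1.0 - y * y_hat)
            self.loss_sum += loss * loss if self.squared else loss
            self.total += 1

    def compute(self):
        return self.loss_sum / self.total


metric = RunningBinaryHinge()
metric.update([0.25, 0.25], [0, 0])           # first batch
metric.update([0.55, 0.75, 0.75], [1, 1, 1])  # second batch
print(round(metric.compute(), 4))             # 0.69
```

Since the division happens once at the end, every sample carries equal weight and the result matches a single call on all five samples; averaging the two batch means instead would give (1.25 + 0.3167) / 2 ≈ 0.783.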
- Parameters
  - squared (bool) – If True, computes the squared hinge loss. Otherwise, computes the regular hinge loss.
  - ignore_index (Optional[int]) – Specifies a target value that is ignored and does not contribute to the metric calculation.
  - validate_args (bool) – Whether input arguments and tensors should be validated for correctness. Set to False for faster computations.
  - kwargs (Any) – Additional keyword arguments; see Advanced metric settings for more info.
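The effect of squared and ignore_index can be sketched together: ignored samples are dropped before averaging, so they change neither the loss sum nor the sample count. The helper binary_hinge_with_ignore and the choice of -1 as the ignore value are hypothetical illustrations, not the library's code.

```python
def binary_hinge_with_ignore(preds, target, squared=False, ignore_index=None):
    """Hypothetical sketch: drop samples whose target equals ignore_index,
    then average the (optionally squared) hinge loss over the rest."""
    kept = [(p, t) for p, t in zip(preds, target)
            if ignore_index is None or t != ignore_index]
    if not kept:
        raise ValueError("no samples left after applying ignore_index")
    losses = []
    for y_hat, label in kept:
        y = 1.0 if label == 1 else -1.0  # map {0, 1} -> {-1, +1}
        loss = max(0.0, 1.0 - y * y_hat)
        losses.append(loss * loss if squared else loss)
    return sum(losses) / len(losses)


preds = [0.25, 0.25, 0.55, 0.75, 0.75, 0.9]
target = [0, 0, 1, 1, 1, -1]  # -1 marks the sample to ignore
print(round(binary_hinge_with_ignore(preds, target, ignore_index=-1), 4))                # 0.69
print(round(binary_hinge_with_ignore(preds, target, squared=True, ignore_index=-1), 4))  # 0.6905
```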
Example
>>> import torch
>>> from torchmetrics.classification import BinaryHingeLoss
>>> preds = torch.tensor([0.25, 0.25, 0.55, 0.75, 0.75])
>>> target = torch.tensor([0, 0, 1, 1, 1])
>>> bhl = BinaryHingeLoss()
>>> bhl(preds, target)
tensor(0.6900)
>>> bhl = BinaryHingeLoss(squared=True)
>>> bhl(preds, target)
tensor(0.6905)
MulticlassHingeLoss
- class torchmetrics.classification.MulticlassHingeLoss(num_classes, squared=False, multiclass_mode='crammer-singer', ignore_index=None, validate_args=True, **kwargs)
Computes the mean Hinge loss typically used for Support Vector Machines (SVMs) for multiclass tasks.
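The metric can be computed in one of two ways. Either the definition by Crammer and Singer is used:
$$\text{Hinge loss} = \max\left(0, 1 - \hat{y}_y + \max_{i \ne y} (\hat{y}_i)\right)$$
where $y \in \{0, \ldots, C - 1\}$ is the target class (where $C$ is the number of classes), and $\hat{y} \in \mathbb{R}^C$ is the predicted output per class. Alternatively, the metric can be computed with a one-vs-all approach, where each class is scored against all other classes in a binary fashion, yielding one loss value per class.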
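Both modes can be written out in a few lines of plain Python and checked against the MulticlassHingeLoss example further down. This is a hedged sketch, not the library's implementation: crammer_singer and one_vs_all are hypothetical names, and the sketch assumes inputs already in [0, 1] (no softmax) with no ignore_index or validation.

```python
def crammer_singer(preds, target, squared=False):
    """Sketch of mean(max(0, 1 - y_hat[y] + max_{i != y} y_hat[i]))."""
    losses = []
    for row, y in zip(preds, target):
        best_other = max(v for i, v in enumerate(row) if i != y)  # strongest wrong class
        loss = max(0.0, 1.0 - row[y] + best_other)
        losses.append(loss * loss if squared else loss)
    return sum(losses) / len(losses)


def one_vs_all(preds, target, squared=False):
    """Sketch of the one-vs-all variant: a binary hinge loss per class column,
    using +1 for the true class and -1 for every other class."""
    per_class = []
    for c in range(len(preds[0])):
        losses = []
        for row, y in zip(preds, target):
            sign = 1.0 if y == c else -1.0
            loss = max(0.0, 1.0 - sign * row[c])
            losses.append(loss * loss if squared else loss)
        per_class.append(sum(losses) / len(losses))  # one value per class
    return per_class


preds = [[0.25, 0.20, 0.55],
         [0.55, 0.05, 0.40],
         [0.10, 0.30, 0.60],
         [0.90, 0.05, 0.05]]
target = [0, 1, 2, 0]
print(round(crammer_singer(preds, target), 4))               # 0.9125
print(round(crammer_singer(preds, target, squared=True), 4))  # 1.1131
print([round(v, 4) for v in one_vs_all(preds, target)])      # [0.875, 1.125, 1.1]
```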
As input to forward and update the metric accepts the following input:
- preds (Tensor): A float tensor of shape (N, C, ...). Preds should be a tensor containing probabilities or logits for each observation. If preds has values outside the [0, 1] range, the input is treated as logits and softmax is applied automatically per sample.
- target (Tensor): An int tensor of shape (N, ...). Target should be a tensor containing ground truth labels, and therefore only contain values in the [0, num_classes - 1] range (except if ignore_index is specified).
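For multiclass input, "per sample" means softmax is taken across the C class scores of each row. The sketch below restates the rule on nested lists, using the logits from the legacy multiclass example; maybe_apply_softmax is a hypothetical helper, and subtracting the row maximum is a standard numerical-stability choice of this sketch, not something the text above specifies.

```python
import math


def maybe_apply_softmax(preds):
    """Sketch of the documented rule for (N, C) input: if any score lies
    outside [0, 1], treat the whole input as logits and apply softmax
    to each sample's row of class scores."""
    if all(0.0 <= v <= 1.0 for row in preds for v in row):
        return [list(row) for row in preds]  # already probabilities
    out = []
    for row in preds:
        m = max(row)  # subtract the row max for numerical stability
        exps = [math.exp(v - m) for v in row]
        total = sum(exps)
        out.append([e / total for e in exps])
    return out


probs = maybe_apply_softmax([[-1.0, 0.9, 0.2], [0.5, -1.1, 0.8]])
print([round(sum(row), 6) for row in probs])  # [1.0, 1.0]: each row now sums to one
```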
Note
Additional dimensions ... will be flattened into the batch dimension.
As output to forward and compute the metric returns the following output:
- mchl (Tensor): A tensor containing the multi-class hinge loss: a scalar for 'crammer-singer', or one value per class for 'one-vs-all'.
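Flattening extra dimensions into the batch dimension turns every (sample, position) pair into its own row of C class scores. The sketch below shows this for one extra dimension D on nested lists; flatten_extra_dims is a hypothetical helper, and the sample-major row order is an assumption of this sketch.

```python
def flatten_extra_dims(preds, target):
    """Hypothetical sketch for inputs with one extra dimension:
    preds has shape (N, C, D) and target has shape (N, D).
    Returns preds of shape (N * D, C) and target of shape (N * D,)."""
    flat_preds, flat_target = [], []
    for sample_scores, sample_labels in zip(preds, target):
        for d in range(len(sample_labels)):
            # gather the C class scores that belong to position d
            flat_preds.append([class_scores[d] for class_scores in sample_scores])
            flat_target.append(sample_labels[d])
    return flat_preds, flat_target


# N=1 sample, C=2 classes, D=3 positions
preds = [[[0.9, 0.2, 0.4],    # class 0 score at each position
          [0.1, 0.8, 0.6]]]   # class 1 score at each position
target = [[0, 1, 1]]
rows, labels = flatten_extra_dims(preds, target)
print(rows)    # [[0.9, 0.1], [0.2, 0.8], [0.4, 0.6]]
print(labels)  # [0, 1, 1]
```

After this step, each row is scored exactly like an ordinary sample, so a (1, 2, 3) input counts as three samples in the mean.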
- Parameters
  - num_classes (int) – Integer specifying the number of classes.
  - squared (bool) – If True, computes the squared hinge loss. Otherwise, computes the regular hinge loss.
  - multiclass_mode (Literal['crammer-singer', 'one-vs-all']) – Determines how to compute the metric.
  - ignore_index (Optional[int]) – Specifies a target value that is ignored and does not contribute to the metric calculation.
  - validate_args (bool) – Whether input arguments and tensors should be validated for correctness. Set to False for faster computations.
  - kwargs (Any) – Additional keyword arguments; see Advanced metric settings for more info.
Example
>>> import torch
>>> from torchmetrics.classification import MulticlassHingeLoss
>>> preds = torch.tensor([[0.25, 0.20, 0.55],
...                       [0.55, 0.05, 0.40],
...                       [0.10, 0.30, 0.60],
...                       [0.90, 0.05, 0.05]])
>>> target = torch.tensor([0, 1, 2, 0])
>>> mchl = MulticlassHingeLoss(num_classes=3)
>>> mchl(preds, target)
tensor(0.9125)
>>> mchl = MulticlassHingeLoss(num_classes=3, squared=True)
>>> mchl(preds, target)
tensor(1.1131)
>>> mchl = MulticlassHingeLoss(num_classes=3, multiclass_mode='one-vs-all')
>>> mchl(preds, target)
tensor([0.8750, 1.1250, 1.1000])
Functional Interface
- torchmetrics.functional.hinge_loss(preds, target, task, num_classes=None, squared=False, multiclass_mode='crammer-singer', ignore_index=None, validate_args=True)
Computes the mean Hinge loss typically used for Support Vector Machines (SVMs).
This function is a simple wrapper that returns the task-specific version of this metric, selected by setting the task argument to either 'binary' or 'multiclass'. See the documentation of binary_hinge_loss() and multiclass_hinge_loss() for the specific details of each argument's influence and for examples.
- Legacy Example:
>>> import torch
>>> from torchmetrics.functional import hinge_loss
>>> target = torch.tensor([0, 1, 1])
>>> preds = torch.tensor([0.5, 0.7, 0.1])
>>> hinge_loss(preds, target, task="binary")
tensor(0.9000)
>>> target = torch.tensor([0, 1, 2])
>>> preds = torch.tensor([[-1.0, 0.9, 0.2], [0.5, -1.1, 0.8], [2.2, -0.5, 0.3]])
>>> hinge_loss(preds, target, task="multiclass", num_classes=3)
tensor(1.5551)
>>> target = torch.tensor([0, 1, 2])
>>> preds = torch.tensor([[-1.0, 0.9, 0.2], [0.5, -1.1, 0.8], [2.2, -0.5, 0.3]])
>>> hinge_loss(preds, target, task="multiclass", num_classes=3, multiclass_mode="one-vs-all")
tensor([1.3743, 1.1945, 1.2359])
- Return type: Tensor
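As a hand check of the binary legacy example above (targets [0, 1, 1], predictions [0.5, 0.7, 0.1]): the targets map to y = (-1, 1, 1), and the predictions already lie in [0, 1], so no sigmoid is applied before the hinge formula.

```latex
\begin{aligned}
\ell_1 &= \max(0,\ 1 - (-1)(0.5)) = 1.5\\
\ell_2 &= \max(0,\ 1 - (1)(0.7)) = 0.3\\
\ell_3 &= \max(0,\ 1 - (1)(0.1)) = 0.9\\
\text{Hinge loss} &= \tfrac{1}{3}\,(1.5 + 0.3 + 0.9) = 0.9
\end{aligned}
```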
binary_hinge_loss
- torchmetrics.functional.classification.binary_hinge_loss(preds, target, squared=False, ignore_index=None, validate_args=False)
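Computes the mean Hinge loss typically used for Support Vector Machines (SVMs) for binary tasks. It is defined as:
$$\text{Hinge loss} = \max(0, 1 - y \times \hat{y})$$
where $y \in \{-1, 1\}$ is the target, and $\hat{y} \in \mathbb{R}$ is the prediction. Targets supplied as {0, 1} labels are interpreted as $y = -1$ for 0 and $y = 1$ for 1.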
Accepts the following input tensors:
- preds (float tensor): (N, ...). Preds should be a tensor containing probabilities or logits for each observation. If preds has values outside the [0, 1] range, the input is treated as logits and sigmoid is applied automatically per element.
- target (int tensor): (N, ...). Target should be a tensor containing ground truth labels, and therefore only contain {0, 1} values (except if ignore_index is specified). The value 1 always encodes the positive class.
Additional dimensions ... will be flattened into the batch dimension.
- Parameters
  - squared (bool) – If True, computes the squared hinge loss. Otherwise, computes the regular hinge loss.
  - ignore_index (Optional[int]) – Specifies a target value that is ignored and does not contribute to the metric calculation.
  - validate_args (bool) – Whether input arguments and tensors should be validated for correctness. Set to False for faster computations.
Example
>>> import torch
>>> from torchmetrics.functional.classification import binary_hinge_loss
>>> preds = torch.tensor([0.25, 0.25, 0.55, 0.75, 0.75])
>>> target = torch.tensor([0, 0, 1, 1, 1])
>>> binary_hinge_loss(preds, target)
tensor(0.6900)
>>> binary_hinge_loss(preds, target, squared=True)
tensor(0.6905)
- Return type: Tensor
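Working the binary example above by hand shows why the regular and squared results are so close. The per-sample losses are listed first, then both means are computed from them:

```latex
\begin{aligned}
\ell &= (1.25,\ 1.25,\ 0.45,\ 0.25,\ 0.25)\\
\tfrac{1}{5}\textstyle\sum_i \ell_i &= \tfrac{3.45}{5} = 0.69\\
\tfrac{1}{5}\textstyle\sum_i \ell_i^2 &= \tfrac{1.5625 + 1.5625 + 0.2025 + 0.0625 + 0.0625}{5} = \tfrac{3.4525}{5} = 0.6905
\end{aligned}
```

Squaring enlarges the two losses above 1 (a total increase of 0.625) and shrinks the three below 1 (a total decrease of 0.6225), so the changes almost cancel in this particular batch.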
multiclass_hinge_loss
- torchmetrics.functional.classification.multiclass_hinge_loss(preds, target, num_classes, squared=False, multiclass_mode='crammer-singer', ignore_index=None, validate_args=False)
Computes the mean Hinge loss typically used for Support Vector Machines (SVMs) for multiclass tasks.
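The metric can be computed in one of two ways. Either the definition by Crammer and Singer is used:
$$\text{Hinge loss} = \max\left(0, 1 - \hat{y}_y + \max_{i \ne y} (\hat{y}_i)\right)$$
where $y \in \{0, \ldots, C - 1\}$ is the target class (where $C$ is the number of classes), and $\hat{y} \in \mathbb{R}^C$ is the predicted output per class. Alternatively, the metric can be computed with a one-vs-all approach, where each class is scored against all other classes in a binary fashion, yielding one loss value per class.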
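For the inputs used in the example below (four samples, three classes, targets 0, 1, 2, 0), the Crammer-Singer loss works out per sample as follows. The inputs already lie in [0, 1], so no softmax is applied; each line subtracts the true-class score and adds the largest wrong-class score.

```latex
\begin{aligned}
\ell_1 &= \max(0,\ 1 - 0.25 + 0.55) = 1.30\\
\ell_2 &= \max(0,\ 1 - 0.05 + 0.55) = 1.50\\
\ell_3 &= \max(0,\ 1 - 0.60 + 0.30) = 0.70\\
\ell_4 &= \max(0,\ 1 - 0.90 + 0.05) = 0.15\\
\text{Hinge loss} &= \tfrac{1}{4}\,(1.30 + 1.50 + 0.70 + 0.15) = 0.9125
\end{aligned}
```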
Accepts the following input tensors:
- preds (float tensor): (N, C, ...). Preds should be a tensor containing probabilities or logits for each observation. If preds has values outside the [0, 1] range, the input is treated as logits and softmax is applied automatically per sample.
- target (int tensor): (N, ...). Target should be a tensor containing ground truth labels, and therefore only contain values in the [0, num_classes - 1] range (except if ignore_index is specified).
Additional dimensions ... will be flattened into the batch dimension.
- Parameters
  - num_classes (int) – Integer specifying the number of classes.
  - squared (bool) – If True, computes the squared hinge loss. Otherwise, computes the regular hinge loss.
  - multiclass_mode (Literal['crammer-singer', 'one-vs-all']) – Determines how to compute the metric.
  - ignore_index (Optional[int]) – Specifies a target value that is ignored and does not contribute to the metric calculation.
  - validate_args (bool) – Whether input arguments and tensors should be validated for correctness. Set to False for faster computations.
Example
>>> import torch
>>> from torchmetrics.functional.classification import multiclass_hinge_loss
>>> preds = torch.tensor([[0.25, 0.20, 0.55],
...                       [0.55, 0.05, 0.40],
...                       [0.10, 0.30, 0.60],
...                       [0.90, 0.05, 0.05]])
>>> target = torch.tensor([0, 1, 2, 0])
>>> multiclass_hinge_loss(preds, target, num_classes=3)
tensor(0.9125)
>>> multiclass_hinge_loss(preds, target, num_classes=3, squared=True)
tensor(1.1131)
>>> multiclass_hinge_loss(preds, target, num_classes=3, multiclass_mode='one-vs-all')
tensor([0.8750, 1.1250, 1.1000])
- Return type: Tensor
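The one-vs-all result of the example above can be reproduced column by column. Each class uses +1 for samples of that class and -1 for all others; every bracketed term is positive here, so the clamp at zero never activates.

```latex
\begin{aligned}
\text{class 0: } & \tfrac{1}{4}\big[(1 - 0.25) + (1 + 0.55) + (1 + 0.10) + (1 - 0.90)\big] = \tfrac{3.50}{4} = 0.875\\
\text{class 1: } & \tfrac{1}{4}\big[(1 + 0.20) + (1 - 0.05) + (1 + 0.30) + (1 + 0.05)\big] = \tfrac{4.50}{4} = 1.125\\
\text{class 2: } & \tfrac{1}{4}\big[(1 + 0.55) + (1 + 0.40) + (1 - 0.60) + (1 + 0.05)\big] = \tfrac{4.40}{4} = 1.100
\end{aligned}
```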