inferno.extensions.criteria package
Submodules
inferno.extensions.criteria.core module
class inferno.extensions.criteria.core.Criteria(*criteria)
Bases: torch.nn.modules.module.Module
Aggregates multiple criteria into a single criterion.
forward(prediction, target)
Defines the computation performed at every call.
Should be overridden by all subclasses.
Note
Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.
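As a rough illustration of what aggregating criteria can look like, here is a minimal sketch in plain PyTorch. It does not use inferno itself. `SummedCriteria` is a hypothetical name, and the choices to feed every criterion the same (prediction, target) pair and to combine the results by summation are assumptions, since this page does not state how `Criteria` combines its members.

```python
import torch
import torch.nn as nn


class SummedCriteria(nn.Module):
    """Hypothetical stand-in: sums the outputs of several criteria."""

    def __init__(self, *criteria):
        super().__init__()
        # nn.ModuleList registers the criteria as submodules.
        self.criteria = nn.ModuleList(criteria)

    def forward(self, prediction, target):
        # Assumption: every criterion sees the same pair; results are added.
        losses = [criterion(prediction, target) for criterion in self.criteria]
        return torch.stack(losses).sum()


torch.manual_seed(0)
prediction = torch.randn(4, 3)
target = torch.randn(4, 3)
combined = SummedCriteria(nn.MSELoss(), nn.L1Loss())
# Call the module instance rather than .forward() so that hooks run.
loss = combined(prediction, target)
```

Calling `combined(...)` rather than `combined.forward(...)` follows the note above: only the former runs any registered hooks.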
class inferno.extensions.criteria.core.As2DCriterion(criterion)
Bases: torch.nn.modules.module.Module
Makes a given criterion applicable to (N, C, H, W) prediction and (N, H, W) target tensors, provided it is applicable to (N, C) prediction and (N,) target tensors.
forward(prediction, target)
Defines the computation performed at every call.
Should be overridden by all subclasses.
Note
Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.
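One way to obtain the behaviour described for `As2DCriterion` is to fold the spatial axes into the batch axis before applying the wrapped criterion. The sketch below assumes that approach. `apply_as_2d` is a hypothetical helper written in plain PyTorch, not part of inferno, and the actual implementation may differ.

```python
import torch
import torch.nn as nn


def apply_as_2d(criterion, prediction, target):
    """Hypothetical helper: flatten spatial axes, then apply `criterion`."""
    num_channels = prediction.size(1)
    # (N, C, H, W) -> (N, H, W, C) -> (N * H * W, C)
    flat_prediction = prediction.permute(0, 2, 3, 1).reshape(-1, num_channels)
    # (N, H, W) -> (N * H * W,)
    flat_target = target.reshape(-1)
    return criterion(flat_prediction, flat_target)


torch.manual_seed(0)
prediction = torch.randn(2, 5, 4, 4)      # (N, C, H, W) class scores
target = torch.randint(0, 5, (2, 4, 4))   # (N, H, W) class indices
loss = apply_as_2d(nn.CrossEntropyLoss(), prediction, target)
```

The permute step matters: reshaping (N, C, H, W) directly to (N * H * W, C) would mix channel and spatial entries instead of keeping one row of class scores per pixel.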
inferno.extensions.criteria.elementwise_measures module
class inferno.extensions.criteria.elementwise_measures.WeightedMSELoss(positive_class_weight=1.0, positive_class_value=1.0, size_average=True)
Bases: torch.nn.modules.module.Module
NEGATIVE_CLASS_WEIGHT = 1.0
forward(input, target)
Defines the computation performed at every call.
Should be overridden by all subclasses.
Note
Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.
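`WeightedMSELoss` carries no description on this page, so the following is only a guess at the intent suggested by its parameter names. It assumes that elements whose target equals `positive_class_value` are weighted by `positive_class_weight`, that all other elements are weighted by `NEGATIVE_CLASS_WEIGHT`, and that `size_average` switches between mean and sum. `weighted_mse` is a hypothetical function, not the library's code.

```python
import torch


def weighted_mse(input, target, positive_class_weight=1.0,
                 positive_class_value=1.0, size_average=True):
    """Hypothetical sketch of a class-weighted squared error."""
    negative_class_weight = 1.0  # mirrors NEGATIVE_CLASS_WEIGHT above
    positive_mask = target.eq(positive_class_value)
    # Assumption: one weight per element, chosen by the target's class.
    weights = torch.where(
        positive_mask,
        torch.full_like(target, positive_class_weight),
        torch.full_like(target, negative_class_weight),
    )
    weighted_squared_error = weights * (input - target) ** 2
    if size_average:
        return weighted_squared_error.mean()
    return weighted_squared_error.sum()


input = torch.tensor([0.0, 0.0, 1.0, 1.0])
target = torch.tensor([1.0, 0.0, 1.0, 0.0])
loss = weighted_mse(input, target, positive_class_weight=3.0)
```

In this toy case the missed positive contributes three times as much as the false positive, which is the usual reason for weighting a rare positive class more heavily.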
inferno.extensions.criteria.regularized module
class inferno.extensions.criteria.regularized.RegularizedLoss(criterion, *args, **kwargs)
Bases: torch.nn.modules.module.Module
Wrap a criterion. Collect regularization losses from model and combine with wrapped criterion.
forward(*args, trainer=None, model=None, **kwargs)
Defines the computation performed at every call.
Should be overridden by all subclasses.
Note
Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.
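The wrapping pattern described for `RegularizedLoss` can be sketched in plain PyTorch as follows. Everything here is hypothetical: the `regularization_loss()` method, the `L2PenalizedLinear` layer, and the `RegularizedSketch` class are invented for illustration, and this page does not say how inferno actually collects regularization terms from the model.

```python
import torch
import torch.nn as nn


class L2PenalizedLinear(nn.Linear):
    """Hypothetical layer that reports its own regularization term."""

    def regularization_loss(self):
        return 1e-2 * self.weight.pow(2).sum()


class RegularizedSketch(nn.Module):
    """Hypothetical wrapper: base criterion plus collected penalties."""

    def __init__(self, criterion):
        super().__init__()
        self.criterion = criterion

    def forward(self, prediction, target, model=None):
        loss = self.criterion(prediction, target)
        if model is not None:
            # Assumption: collect a penalty from every submodule offering one.
            for module in model.modules():
                if hasattr(module, "regularization_loss"):
                    loss = loss + module.regularization_loss()
        return loss


torch.manual_seed(0)
model = nn.Sequential(L2PenalizedLinear(4, 8), nn.ReLU(), nn.Linear(8, 1))
inputs, target = torch.randn(16, 4), torch.randn(16, 1)
prediction = model(inputs)
criterion = RegularizedSketch(nn.MSELoss())
plain = criterion(prediction, target)
regularized = criterion(prediction, target, model=model)
```

Passing the model as a keyword argument loosely mirrors the `model=None` parameter in the signature above; without it, the sketch falls back to the wrapped criterion alone.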
class inferno.extensions.criteria.regularized.RegularizedCrossEntropyLoss(*args, **kwargs)
Bases: inferno.extensions.criteria.regularized.RegularizedLoss
class inferno.extensions.criteria.regularized.RegularizedBCEWithLogitsLoss(*args, **kwargs)
Bases: inferno.extensions.criteria.regularized.RegularizedLoss
class inferno.extensions.criteria.regularized.RegularizedBCELoss(*args, **kwargs)
Bases: inferno.extensions.criteria.regularized.RegularizedLoss
class inferno.extensions.criteria.regularized.RegularizedMSELoss(*args, **kwargs)
Bases: inferno.extensions.criteria.regularized.RegularizedLoss
class inferno.extensions.criteria.regularized.RegularizedNLLLoss(*args, **kwargs)
Bases: inferno.extensions.criteria.regularized.RegularizedLoss
inferno.extensions.criteria.set_similarity_measures module
class inferno.extensions.criteria.set_similarity_measures.SorensenDiceLoss(weight=None, channelwise=True, eps=1e-06)
Bases: torch.nn.modules.module.Module
Computes a scalar loss that, when minimized, maximizes the Sorensen-Dice similarity between the input and the target. Both inputs and targets must satisfy input_or_target.size(1) == num_channels.
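For orientation, a channelwise Dice-based loss can be sketched as below. This is one common formulation and not necessarily the one inferno uses. The squared terms in the denominator, the `1 - dice` offset, and the averaging over channels are all assumptions, the `weight` argument is omitted, and `sorensen_dice_loss` is a hypothetical name.

```python
import torch


def sorensen_dice_loss(input, target, eps=1e-6):
    """Hypothetical sketch: 1 minus the mean per-channel Dice score."""
    num_channels = input.size(1)
    # Move channels first and flatten the rest: (N, C, ...) -> (C, N * ...)
    flat_input = input.transpose(0, 1).reshape(num_channels, -1)
    flat_target = target.transpose(0, 1).reshape(num_channels, -1)
    numerator = 2.0 * (flat_input * flat_target).sum(dim=1)
    denominator = (flat_input ** 2).sum(dim=1) + (flat_target ** 2).sum(dim=1)
    # eps guards against division by zero for empty channels.
    dice_per_channel = numerator / denominator.clamp(min=eps)
    return 1.0 - dice_per_channel.mean()


target = torch.zeros(1, 2, 4, 4)
target[:, 0, :2] = 1.0   # channel 0: top half is foreground
target[:, 1, 2:] = 1.0   # channel 1: bottom half is foreground
perfect = sorensen_dice_loss(target.clone(), target)   # full overlap
disjoint = sorensen_dice_loss(1.0 - target, target)    # no overlap
```

Under these assumptions the loss is close to 0 for full overlap and close to 1 for no overlap, so minimizing it pushes the Dice similarity up.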
class inferno.extensions.criteria.set_similarity_measures.GeneralizedDiceLoss(weight=None, channelwise=False, eps=1e-06)
Bases: torch.nn.modules.module.Module
Computes the scalar Generalized Dice Loss defined in https://arxiv.org/abs/1707.03237
This version works for multiple classes and expects predictions for every class (e.g. softmax output) and one-hot targets for every class.
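The Generalized Dice Loss from the referenced paper weights each class by the inverse of its squared target volume. The sketch below follows that definition in plain PyTorch. It is not inferno's implementation: `generalized_dice_loss` is a hypothetical name, the way `eps` is applied is an assumption, and the `weight` and `channelwise` arguments are ignored.

```python
import torch


def generalized_dice_loss(prediction, target, eps=1e-6):
    """Sketch of the Generalized Dice Loss (arXiv:1707.03237)."""
    num_classes = prediction.size(1)
    # (N, C, ...) -> (C, N * ...): one row of elements per class
    flat_prediction = prediction.transpose(0, 1).reshape(num_classes, -1)
    flat_target = target.transpose(0, 1).reshape(num_classes, -1)
    # Class weights: inverse squared volume of each class in the target.
    class_weights = 1.0 / flat_target.sum(dim=1).pow(2).clamp(min=eps)
    intersection = (flat_prediction * flat_target).sum(dim=1)
    total = (flat_prediction + flat_target).sum(dim=1)
    numerator = 2.0 * (class_weights * intersection).sum()
    denominator = (class_weights * total).sum()
    return 1.0 - numerator / denominator.clamp(min=eps)


labels = torch.tensor([[0, 0, 1, 2], [2, 1, 1, 0]])     # (N, W) class indices
one_hot = torch.nn.functional.one_hot(labels, 3)        # (N, W, C)
target = one_hot.permute(0, 2, 1).float()               # (N, C, W)
logits = torch.zeros(2, 3, 4)                           # uninformative scores
uniform = generalized_dice_loss(torch.softmax(logits, dim=1), target)
perfect = generalized_dice_loss(target.clone(), target)
```

The inverse-squared-volume weights keep small classes from being swamped by large ones, which is the motivation given in the paper for using this loss on imbalanced segmentations.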