inferno.extensions.criteria package

Submodules

inferno.extensions.criteria.core module

class inferno.extensions.criteria.core.Criteria(*criteria)

Bases: torch.nn.modules.module.Module

Aggregates multiple criteria into one.

forward(prediction, target)

Defines the computation performed at every call.

Should be overridden by all subclasses.

Note

Although the recipe for the forward pass is defined within this function, call the Module instance itself rather than forward() directly: calling the instance runs any registered hooks, whereas calling forward() silently skips them.
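A minimal sketch of what aggregating criteria can look like, assuming (this page does not say so) that forward receives one prediction and one target per wrapped criterion and that the individual losses are summed. SumOfCriteria is a hypothetical name, not inferno's Criteria implementation.

```python
import torch
import torch.nn as nn


class SumOfCriteria(nn.Module):
    """Hypothetical sketch: one criterion per (prediction, target) pair, losses summed."""

    def __init__(self, *criteria):
        super().__init__()
        # nn.ModuleList registers the criteria as submodules, so .to(device) reaches them.
        self.criteria = nn.ModuleList(criteria)

    def forward(self, predictions, targets):
        # Assumption: predictions and targets are sequences aligned with the criteria.
        if not (len(predictions) == len(targets) == len(self.criteria)):
            raise ValueError("need one prediction and one target per criterion")
        losses = [criterion(p, t)
                  for criterion, p, t in zip(self.criteria, predictions, targets)]
        # Aggregate the scalar losses into a single scalar by summation.
        return torch.stack(losses).sum()


criterion = SumOfCriteria(nn.MSELoss(), nn.L1Loss())
loss = criterion([torch.zeros(4), torch.zeros(4)],
                 [torch.ones(4), torch.full((4,), 2.0)])  # MSE 1.0 + L1 2.0
```

Registering the criteria in an nn.ModuleList rather than a plain list is a choice of the sketch: criteria that carry buffers (for example class weights) then move with the aggregate when it is sent to a device.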

class inferno.extensions.criteria.core.As2DCriterion(criterion)

Bases: torch.nn.modules.module.Module

Makes a criterion that accepts (N, C) prediction and (N,) target tensors applicable to (N, C, H, W) prediction and (N, H, W) target tensors.
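The reshaping behind this can be sketched as follows, assuming a channels-first prediction whose dimension 1 holds the C scores per pixel: every pixel is treated as one sample. FlattenSpatial2D is a hypothetical name, not inferno's code; comparing against nn.CrossEntropyLoss, which accepts (N, C, H, W) inputs natively, only serves to check the sketch.

```python
import torch
import torch.nn as nn


class FlattenSpatial2D(nn.Module):
    """Hypothetical sketch: lift an (N, C)/(N,) criterion to (N, C, H, W)/(N, H, W) inputs."""

    def __init__(self, criterion):
        super().__init__()
        self.criterion = criterion

    def forward(self, prediction, target):
        if prediction.dim() != 4 or target.dim() != 3:
            raise ValueError("expected (N, C, H, W) prediction and (N, H, W) target")
        num_channels = prediction.size(1)
        # (N, C, H, W) -> (N, H, W, C) -> (N*H*W, C): one row of scores per pixel.
        prediction = prediction.permute(0, 2, 3, 1).reshape(-1, num_channels)
        # (N, H, W) -> (N*H*W,), flattened in the same N, H, W order as the rows above.
        target = target.reshape(-1)
        return self.criterion(prediction, target)


torch.manual_seed(0)
pred = torch.randn(2, 5, 3, 4)
tgt = torch.randint(0, 5, (2, 3, 4))
loss = FlattenSpatial2D(nn.CrossEntropyLoss())(pred, tgt)
```

reshape is used instead of view because the permuted tensor is no longer contiguous.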

forward(prediction, target)

inferno.extensions.criteria.elementwise_measures module

class inferno.extensions.criteria.elementwise_measures.WeightedMSELoss(positive_class_weight=1.0, positive_class_value=1.0, size_average=True)

Bases: torch.nn.modules.module.Module

NEGATIVE_CLASS_WEIGHT = 1.0
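This page lists only the signature and the class attribute NEGATIVE_CLASS_WEIGHT = 1.0. One plausible reading, stated as an assumption rather than documented behavior: target elements equal to positive_class_value have their squared error weighted by positive_class_weight, all other elements by NEGATIVE_CLASS_WEIGHT, and size_average selects the mean over the sum. weighted_mse is a hypothetical helper sketching that reading.

```python
import torch

NEGATIVE_CLASS_WEIGHT = 1.0  # mirrors the class attribute listed above


def weighted_mse(prediction, target, positive_class_weight=1.0,
                 positive_class_value=1.0, size_average=True):
    """Hypothetical sketch of a class-weighted MSE (assumed semantics, not inferno's code)."""
    # 1.0 where the target holds the positive class value, 0.0 elsewhere.
    positive_mask = (target == positive_class_value).to(prediction.dtype)
    # Positive elements get positive_class_weight, all others NEGATIVE_CLASS_WEIGHT.
    weights = NEGATIVE_CLASS_WEIGHT + positive_mask * (
        positive_class_weight - NEGATIVE_CLASS_WEIGHT)
    weighted_sq_error = weights * (prediction - target) ** 2
    # size_average=True averages over all elements; otherwise the errors are summed.
    return weighted_sq_error.mean() if size_average else weighted_sq_error.sum()


p = torch.tensor([0.0, 0.0, 0.5])
t = torch.tensor([1.0, 0.0, 0.0])
loss = weighted_mse(p, t, positive_class_weight=3.0, size_average=False)  # 3*1 + 0 + 0.25
```

With positive_class_weight=1.0 every weight equals 1 and the sketch reduces to the plain mean squared error.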
forward(input, target)

inferno.extensions.criteria.regularized module

class inferno.extensions.criteria.regularized.RegularizedLoss(criterion, *args, **kwargs)

Bases: torch.nn.modules.module.Module

Wraps a criterion, collects the regularization losses from the model, and combines them with the wrapped criterion's loss.
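How inferno's models expose their regularization losses is not described on this page, so the sketch below swaps in a plain L2 penalty on all model parameters as a stand-in for them. It keeps the keyword-only model argument of forward and omits trainer; L2RegularizedLoss and weight_decay are hypothetical names.

```python
import torch
import torch.nn as nn


class L2RegularizedLoss(nn.Module):
    """Hypothetical sketch: wrapped criterion plus an L2 penalty on the model's parameters."""

    def __init__(self, criterion, weight_decay=1e-4):
        super().__init__()
        self.criterion = criterion
        self.weight_decay = weight_decay  # hypothetical parameter, not part of RegularizedLoss

    def forward(self, *args, model=None, **kwargs):
        main_loss = self.criterion(*args, **kwargs)
        if model is None:
            # Without a model there is nothing to regularize.
            return main_loss
        # Stand-in regularization term: sum of squared parameters, scaled by weight_decay.
        penalty = sum(p.pow(2).sum() for p in model.parameters())
        return main_loss + self.weight_decay * penalty


model = nn.Linear(2, 1, bias=False)
loss_fn = L2RegularizedLoss(nn.MSELoss(), weight_decay=0.5)
x = torch.zeros(3, 1)
loss = loss_fn(x, x, model=model)  # MSE term is 0, so only the penalty remains
```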

forward(*args, trainer=None, model=None, **kwargs)

class inferno.extensions.criteria.regularized.RegularizedCrossEntropyLoss(*args, **kwargs)

Bases: inferno.extensions.criteria.regularized.RegularizedLoss

class inferno.extensions.criteria.regularized.RegularizedBCEWithLogitsLoss(*args, **kwargs)

Bases: inferno.extensions.criteria.regularized.RegularizedLoss

class inferno.extensions.criteria.regularized.RegularizedBCELoss(*args, **kwargs)

Bases: inferno.extensions.criteria.regularized.RegularizedLoss

class inferno.extensions.criteria.regularized.RegularizedMSELoss(*args, **kwargs)

Bases: inferno.extensions.criteria.regularized.RegularizedLoss

class inferno.extensions.criteria.regularized.RegularizedNLLLoss(*args, **kwargs)

Bases: inferno.extensions.criteria.regularized.RegularizedLoss

inferno.extensions.criteria.set_similarity_measures module

class inferno.extensions.criteria.set_similarity_measures.SorensenDiceLoss(weight=None, channelwise=True, eps=1e-06)

Bases: torch.nn.modules.module.Module

Computes a scalar loss which, when minimized, maximizes the Sorensen-Dice similarity between the input and the target. Both input and target must have num_channels entries along dimension 1, i.e. input.size(1) == target.size(1) == num_channels.

forward(input, target)

  • input: torch.FloatTensor or torch.cuda.FloatTensor
  • target: torch.FloatTensor or torch.cuda.FloatTensor

Expected shape of the inputs: (batch_size, nb_channels, …)
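One common formulation of a soft Sorensen-Dice loss, sketched under assumptions: the denominator uses sums of squares, the loss is the negated Dice coefficient (so it reaches -1 per channel at perfect overlap), and with channelwise=True the per-channel losses are optionally scaled by weight and then summed. The function name is hypothetical and inferno's exact formula may differ.

```python
import torch


def sorensen_dice_loss(prediction, target, weight=None, channelwise=True, eps=1e-6):
    """Hypothetical sketch of a soft Dice loss; returns -Dice so minimizing maximizes overlap."""
    if not channelwise:
        # Treat the whole tensor as one set.
        numerator = (prediction * target).sum()
        denominator = (prediction * prediction).sum() + (target * target).sum()
        return -2.0 * numerator / denominator.clamp(min=eps)
    num_channels = prediction.size(1)
    # (batch_size, nb_channels, ...) -> (nb_channels, everything else): one row per channel.
    prediction = prediction.transpose(0, 1).reshape(num_channels, -1)
    target = target.transpose(0, 1).reshape(num_channels, -1)
    numerator = (prediction * target).sum(-1)
    denominator = (prediction * prediction).sum(-1) + (target * target).sum(-1)
    channelwise_loss = -2.0 * numerator / denominator.clamp(min=eps)
    if weight is not None:
        channelwise_loss = weight * channelwise_loss  # one weight per channel
    return channelwise_loss.sum()


x = torch.ones(2, 3, 4, 4)
loss = sorensen_dice_loss(x, x)  # perfect overlap in each of the 3 channels
```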

class inferno.extensions.criteria.set_similarity_measures.GeneralizedDiceLoss(weight=None, channelwise=False, eps=1e-06)

Bases: torch.nn.modules.module.Module

Computes the scalar Generalized Dice Loss defined in https://arxiv.org/abs/1707.03237

This version works for multiple classes and expects predictions for every class (e.g. softmax output) and one-hot targets for every class.

forward(input, target)

  • input: torch.FloatTensor or torch.cuda.FloatTensor
  • target: torch.FloatTensor or torch.cuda.FloatTensor

Expected shape of the inputs:
  • if not channelwise: (batch_size, nb_classes, …)
  • if channelwise: (batch_size, nb_channels, nb_classes, …)
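The linked paper weights each class l by its inverse squared target volume, w_l = 1 / (sum_n r_ln)^2, and defines GDL = 1 - 2 * (sum_l w_l sum_n r_ln p_ln) / (sum_l w_l sum_n (r_ln + p_ln)), where r holds the one-hot targets and p the predicted class probabilities. A minimal sketch for the non-channelwise layout (batch_size, nb_classes, …) follows; the eps clamping is an assumption, and this page does not state whether inferno returns exactly 1 - score.

```python
import torch


def generalized_dice_loss(prediction, target, eps=1e-6):
    """Hypothetical sketch of the Generalized Dice Loss (non-channelwise layout only)."""
    num_classes = prediction.size(1)
    # (batch_size, nb_classes, ...) -> (nb_classes, everything else)
    p = prediction.transpose(0, 1).reshape(num_classes, -1)
    r = target.transpose(0, 1).reshape(num_classes, -1)
    # Class weights w_l = 1 / (sum_n r_ln)^2; clamped so absent classes do not divide by 0.
    class_volume = r.sum(-1)
    w = 1.0 / (class_volume * class_volume).clamp(min=eps)
    intersection = (w * (p * r).sum(-1)).sum()
    union = (w * (p + r).sum(-1)).sum()
    return 1.0 - 2.0 * intersection / union.clamp(min=eps)


labels = torch.tensor([[0, 1, 2, 1]])  # one sample with 4 pixels and 3 classes
onehot = torch.nn.functional.one_hot(labels, 3).permute(0, 2, 1).float()
loss = generalized_dice_loss(onehot, onehot)  # perfect prediction gives 0
```

The inverse-volume weights are the point of the method: they keep classes that cover few pixels from being swamped by large ones in highly unbalanced segmentations.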

Module contents