inferno.trainers package
Subpackages
- inferno.trainers.callbacks package
Submodules
inferno.trainers.basic module
class inferno.trainers.basic.Trainer(model=None)
Bases: object
A basic trainer.
Given a torch model, this class encapsulates the training and validation loops, checkpoint creation, logging, CPU <-> GPU transfers, and data-loader management.
In addition, this class interacts with the callback engine (found at inferno.trainers.callbacks.base.CallbackEngine), which manages callbacks at certain preset events.
Notes
Logging is implemented as a special callback, in the sense that it’s jointly managed by this class and the callback engine. This is primarily because general callbacks are not intended to be serializable, but not being able to serialize the logger is a nuisance.
DYNAMIC_STATES = {'learning_rate': 'current_learning_rate'}
INF_STRINGS = {'infty', 'inf', 'infinity'}
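A minimal end-to-end sketch of the intended workflow (my_model, train_loader and validate_loader are placeholders you would supply; the criterion and optimizer are looked up by name as described in build_criterion and build_optimizer below):

    from inferno.trainers.basic import Trainer

    # my_model is any torch.nn.Module; the loaders are torch DataLoaders.
    trainer = Trainer(my_model)
    trainer.build_criterion('CrossEntropyLoss')   # resolved in torch.nn
    trainer.build_optimizer('Adam', lr=1e-3)      # resolved in torch.optim
    trainer.bind_loader('train', train_loader)
    trainer.bind_loader('validate', validate_loader)
    trainer.validate_every((1, 'epoch'))
    trainer.set_max_num_epochs(10)
    trainer.fit()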
bind_loader(name, loader, num_inputs=None, num_targets=1)
Bind a data loader to the trainer.
Parameters: - name ({'train', 'validate', 'test'}) – Name of the loader, i.e. what it should be used for.
Parameters: - name ({'train', 'validate', 'test'}) – Name of the loader, i.e. what it should be used for.
- loader (torch.utils.data.DataLoader) – DataLoader object.
- num_inputs (int) – Number of input tensors from the loader.
- num_targets (int) – Number of target tensors from the loader.
Returns: self
Return type: Trainer
Raises: - KeyError – if name is invalid.
- TypeError – if loader is not a DataLoader instance.
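For example, with standard torch DataLoaders (the dataset objects are placeholders):

    from torch.utils.data import DataLoader

    # train_dataset and validate_dataset are assumed to be torch Dataset objects.
    trainer.bind_loader('train', DataLoader(train_dataset, batch_size=16, shuffle=True))
    trainer.bind_loader('validate', DataLoader(validate_dataset, batch_size=16))

    # A loader that yields (input_1, input_2, target) per batch could be bound as:
    # trainer.bind_loader('train', multi_input_loader, num_inputs=2, num_targets=1)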
bind_model(model)
Binds a model to the trainer. Equivalent to setting model.
Parameters: model (torch.nn.Module) – Model to bind. Returns: self. Return type: Trainer
build_criterion(method, **kwargs)
Builds the loss criterion for training.
Parameters: - method (str or callable or torch.nn.Module) – Name of the criterion when str, criterion class when callable, or a torch.nn.Module instance. If a name is provided, this method looks for the criterion in torch.nn.
- kwargs (dict) – Keyword arguments to the criterion class’ constructor if applicable.
Returns: self.
Return type: Trainer
Raises: - AssertionError – if criterion is not found.
- NotImplementedError – if method is neither a str nor a callable.
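Any of the three accepted forms works; for instance (keyword arguments are forwarded to the criterion's constructor):

    import torch.nn as nn

    trainer.build_criterion('CrossEntropyLoss')             # by name, found in torch.nn
    trainer.build_criterion(nn.CrossEntropyLoss)            # by class
    trainer.build_criterion(nn.NLLLoss(ignore_index=-100))  # as a module instance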
build_logger(logger=None, log_directory=None, **kwargs)
Build the logger.
Parameters: - logger (inferno.trainers.callbacks.logging.base.Logger or str or type) – Must either be a Logger object or the name of a logger or the class of a logger.
- log_directory (str) – Path to the directory where the log files are to be stored.
- kwargs (dict) – Keyword arguments to the logger class.
Returns: self
Return type: Trainer
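A sketch, assuming the TensorboardLogger shipped with inferno; its import path and default constructor are assumptions here, not stated in this section:

    # Assumption: inferno provides a TensorboardLogger at this path.
    from inferno.trainers.callbacks.logging.tensorboard import TensorboardLogger

    trainer.build_logger(TensorboardLogger(), log_directory='logs/')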
build_metric(method, **kwargs)
Builds the metric for evaluation.
Parameters: - method (callable or str) – Name of the metric when string, metric class or a callable object when callable. If a name is provided, this method looks for the metric in inferno.extensions.metrics.
- kwargs (dict) – Keyword arguments to the metric class’ constructor, if applicable.
Returns: self.
Return type: Trainer
Raises: AssertionError – if the metric is not found.
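For instance, assuming inferno.extensions.metrics provides a CategoricalError metric (the name is an assumption here):

    trainer.build_metric('CategoricalError')   # resolved in inferno.extensions.metrics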
build_optimizer(method, param_groups=None, **kwargs)
Builds the optimizer for training.
Parameters: - method (str or callable or torch.optim.Optimizer) – Name of the optimizer when str, handle to the optimizer class when callable, or a torch.optim.Optimizer instance. If a name is provided, this method looks for the optimizer in torch.optim module first and in inferno.extensions.optimizers second.
- param_groups (list of dict) – Specifies the parameter group. Defaults to model.parameters() if None.
- kwargs (dict) – Keyword arguments to the optimizer.
Returns: self.
Return type: Trainer
Raises: - AssertionError – if optimizer is not found.
- NotImplementedError – if method is not str or callable.
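For example, by name with optimizer keyword arguments, or with explicit parameter groups (my_model and its submodules are placeholders):

    # By name: resolved in torch.optim first, then inferno.extensions.optimizers.
    trainer.build_optimizer('Adam', lr=1e-3, weight_decay=1e-4)

    # With explicit parameter groups, e.g. a smaller learning rate for the backbone:
    param_groups = [
        {'params': my_model.backbone.parameters(), 'lr': 1e-4},
        {'params': my_model.head.parameters(), 'lr': 1e-3},
    ]
    trainer.build_optimizer('SGD', param_groups, lr=1e-3, momentum=0.9)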
build_validation_criterion(method, **kwargs)
Builds the loss criterion for validation.
Parameters: - method (str or callable or torch.nn.Module) – Name of the criterion when str, criterion class when callable, or a torch.nn.Module instance. If a name is provided, this method looks for the criterion in torch.nn.
- kwargs (dict) – Keyword arguments to the criterion class’ constructor if applicable.
Returns: self.
Return type: Trainer
Raises: - AssertionError – if criterion is not found.
- NotImplementedError – if method is neither a str nor a callable.
callbacks – Gets the callback engine.
console – Get the current console.
criterion – Gets the loss criterion.
criterion_is_defined
cuda(devices=None, base_device=None)
Train on the GPU.
Parameters: - devices (list) – Specify the ordinals of the devices to use for data-parallel training.
- base_device ({'cpu', 'cuda'}) – When using data-parallel training, specify where the result tensors are collected. If ‘cuda’, the results are collected in devices[0].
Returns: self
Return type: Trainer
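For instance, training on the default GPU or data-parallel over two devices:

    trainer.cuda()                 # use the default CUDA device
    trainer.cuda(devices=[0, 1])   # data-parallel over GPUs 0 and 1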
current_learning_rate
dtype
epoch_count
evaluate_metric_every(frequency)
Set frequency of metric evaluation during training (and not during validation).
Parameters: frequency (inferno.utils.train_utils.Frequency or str or tuple or list or int) – Metric evaluation frequency. If str, it could be (say) ‘10 iterations’ or ‘1 epoch’. If tuple (or list), it could be (10, ‘iterations’) or (1, ‘epoch’). If int (say 10), it’s interpreted as (10, ‘iterations’). Returns: self Return type: Trainer
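The accepted frequency forms are interchangeable; for example:

    trainer.evaluate_metric_every('10 iterations')
    trainer.evaluate_metric_every((1, 'epoch'))
    trainer.evaluate_metric_every(10)   # shorthand for (10, 'iterations')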
evaluate_metric_now
evaluating_metric_every
fetch_next_batch(from_loader='train', restart_exhausted_generators=True, update_batch_count=True, update_epoch_count_if_generator_exhausted=True)
fit(max_num_iterations=None, max_num_epochs=None)
Fit model.
Parameters: - max_num_iterations (int or float or str) – (Optional) Maximum number of training iterations. Overrides the value set by Trainer.set_max_num_iterations. If float, it should equal numpy.inf. If str, it should be one of {‘inf’, ‘infinity’, ‘infty’}.
- max_num_epochs (int or float or str) – (Optional) Maximum number of training epochs. Overrides the value set by Trainer.set_max_num_epochs. If float, it should equal numpy.inf. If str, it should be one of {‘inf’, ‘infinity’, ‘infty’}.
Returns: self
Return type: Trainer
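For example, training to an epoch budget set at call time, or without a bound (until interrupted):

    trainer.fit(max_num_epochs=10)
    trainer.fit(max_num_iterations='inf')   # equivalently numpy.inf, 'infty' or 'infinity'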
get_current_learning_rate()
Gets the current learning rate.
Returns: List of learning rates if there are multiple parameter groups, or a float if there’s just one.
Return type: list or float
iteration_count
load(from_directory=None, best=False, filename=None)
Load the trainer from checkpoint.
Parameters: - from_directory (str) – Path to the directory where the checkpoint is located. The filename should be ‘checkpoint.pytorch’ if best=False, or ‘best_checkpoint.pytorch’ if best=True.
- best (bool) – Whether to load the best checkpoint. The filename in from_directory should be ‘best_checkpoint.pytorch’.
- filename (str) – Overrides the default filename.
Returns: self
Return type: Trainer
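A sketch of resuming from a checkpoint directory; the path is a placeholder, and the checkpoints are assumed to have been written by this trainer's save machinery:

    # Latest checkpoint under the default filename 'checkpoint.pytorch':
    trainer = Trainer().load(from_directory='checkpoints/')
    # Best checkpoint ('best_checkpoint.pytorch'):
    trainer = Trainer().load(from_directory='checkpoints/', best=True)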
log_directory – Gets the log directory.
logger – Gets the logger.
metric – Gets the evaluation metric.
metric_is_defined – Checks if the metric is defined.
model – Gets the model.
model_is_defined
optimizer – Gets the optimizer.
optimizer_is_defined
pickle_module
register_callback(callback, trigger='auto', **callback_kwargs)
Registers a callback with the internal callback engine.
Parameters: - callback (type or callable) – Callback to register.
- trigger (str) – Specify the event that triggers the callback. Leave at ‘auto’ to have the callback engine figure out the triggers. See the inferno.trainers.callbacks.base.CallbackEngine documentation for more on this.
- callback_kwargs (dict) – If callback is a type, initialize an instance with these keywords to the __init__ method.
Returns: self.
Return type: Trainer
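A hedged sketch: the Callback base class and the end_of_training_iteration hook name are assumptions about inferno's callback API (only CallbackEngine is named in this section). With trigger='auto', the engine is expected to infer the trigger from the hook methods the callback defines:

    # Assumption: a Callback base class lives alongside CallbackEngine, and hook
    # methods are named after events such as 'end_of_training_iteration'.
    from inferno.trainers.callbacks.base import Callback

    class IterationPrinter(Callback):
        def end_of_training_iteration(self, **_):
            print("finished a training iteration")

    trainer.register_callback(IterationPrinter)   # a type: instantiated with callback_kwargs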
retain_graph
save_at_best_validation_score(yes=True)
Sets whether to save when the validation score is the best seen.
save_directory
save_every(frequency, to_directory=None, checkpoint_filename=None, best_checkpoint_filename=None)
Set checkpoint creation frequency.
Parameters: - frequency (inferno.utils.train_utils.Frequency or tuple or str) – Checkpoint creation frequency. Examples: ‘100 iterations’ or ‘1 epochs’.
- to_directory (str) – Directory where the checkpoints are to be created.
- checkpoint_filename (str) – Name of the checkpoint file.
- best_checkpoint_filename (str) – Name of the best checkpoint file.
Returns: self.
Return type: Trainer
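For example, checkpointing every 1000 iterations and keeping a separate copy whenever the validation score improves:

    trainer.save_every((1000, 'iterations'), to_directory='checkpoints/')
    trainer.save_at_best_validation_score()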
save_now
save_to_directory(to_directory=None, checkpoint_filename=None, best_checkpoint_filename=None)
saving_every – Gets the frequency at which checkpoints are made.
set_log_directory(log_directory)
Set the directory where the log files are to be stored.
Parameters: log_directory (str) – Directory where the log files are to be stored. Returns: self Return type: Trainer
set_max_num_epochs(max_num_epochs)
Set the maximum number of training epochs.
Parameters: max_num_epochs (int or float or str) – Maximum number of training epochs. If float, it should equal numpy.inf. If str, it should be one of {‘inf’, ‘infinity’, ‘infty’}. Returns: self Return type: Trainer
set_max_num_iterations(max_num_iterations)
Set the maximum number of training iterations.
Parameters: max_num_iterations (int or float or str) – Maximum number of training iterations. If float, it should equal numpy.inf. If str, it should be one of {‘inf’, ‘infinity’, ‘infty’}. Returns: self Return type: Trainer
set_precision(dtype)
Set training precision.
Parameters: dtype ({'double', 'float', 'half'}) – Training precision. Returns: self Return type: Trainer
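Since these setters return self, configuration calls chain; for example:

    trainer.set_precision('half').set_max_num_iterations(100000).set_max_num_epochs('inf')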
target_batch_dim
train_loader
validate_every(frequency, for_num_iterations=None)
Set validation frequency.
Parameters: - frequency (inferno.utils.train_utils.Frequency or str or tuple or list or int) – Validation frequency. If str, it could be (say) ‘10 iterations’ or ‘1 epoch’. If tuple (or list), it could be (10, ‘iterations’) or (1, ‘epoch’). If int (say 10), it’s interpreted as (10, ‘iterations’).
- for_num_iterations (int) – Number of iterations to validate for. If not set, the model is validated on the entire dataset (i.e. till the data loader is exhausted).
Returns: self
Return type: Trainer
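For example, validating every 500 training iterations but only on 100 validation batches each time, or over the full validation set once per epoch:

    trainer.validate_every('500 iterations', for_num_iterations=100)
    trainer.validate_every((1, 'epoch'))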
validate_for(num_iterations=None, loader_name='validate')
Validate for a given number of iterations (if num_iterations is not None) or over the entire validation data set.
Parameters: - num_iterations (int) – Number of iterations to validate for. To validate on the entire dataset, leave this as None.
- loader_name (str) – Name of the data loader to use for validation. ‘validate’ is the obvious default.
Returns: self.
Return type: Trainer
validate_loader
validate_now
validating_every
validation_criterion
validation_criterion_is_defined
Module contents
class inferno.trainers.Trainer(model=None)
Bases: object
The documentation of inferno.trainers.Trainer is identical to that of inferno.trainers.basic.Trainer above; see that entry for the full method and attribute listing.