.. note::
    :class: sphx-glr-download-link-note

    Click :ref:`here <sphx_glr_download_auto_examples_trainer.py>` to download the full example code

.. rst-class:: sphx-glr-example-title

.. _sphx_glr_auto_examples_trainer.py:


Trainer Example
================================

This example illustrates how to use the ``Trainer`` class.


.. code-block:: python

    import torch
    import torch.nn as nn
    from inferno.io.box.cifar import get_cifar10_loaders
    from inferno.trainers.basic import Trainer
    from inferno.trainers.callbacks.logging.tensorboard import TensorboardLogger
    from inferno.extensions.layers import ConvELU2D
    from inferno.extensions.layers import Flatten
    from inferno.utils.python_utils import ensure_dir


Change the directories below to suit your needs.


.. code-block:: python

    LOG_DIRECTORY = ensure_dir('log')
    SAVE_DIRECTORY = ensure_dir('save')
    DATASET_DIRECTORY = ensure_dir('dataset')


Specify whether to download the CIFAR-10 dataset and whether to train on the GPU.


.. code-block:: python

    DOWNLOAD_CIFAR = True
    # Fall back to the CPU when no CUDA device is available.
    USE_CUDA = torch.cuda.is_available()


Build the PyTorch model.


.. code-block:: python

    model = nn.Sequential(
        ConvELU2D(in_channels=3, out_channels=256, kernel_size=3),
        nn.MaxPool2d(kernel_size=2, stride=2),
        ConvELU2D(in_channels=256, out_channels=256, kernel_size=3),
        nn.MaxPool2d(kernel_size=2, stride=2),
        ConvELU2D(in_channels=256, out_channels=256, kernel_size=3),
        nn.MaxPool2d(kernel_size=2, stride=2),
        Flatten(),
        # Three 2x2 max-poolings shrink the 32x32 CIFAR-10 images to 4x4.
        nn.Linear(in_features=(256 * 4 * 4), out_features=10),
        # No softmax layer here: CrossEntropyLoss applies log-softmax itself
        # and expects raw, unnormalized scores (logits).
    )


Build the data loaders.


.. code-block:: python

    train_loader, validate_loader = get_cifar10_loaders(DATASET_DIRECTORY,
                                                        download=DOWNLOAD_CIFAR)


Build the trainer.


.. code-block:: python

    trainer = Trainer(model)
    trainer.build_criterion('CrossEntropyLoss')
    trainer.build_metric('CategoricalError')
    trainer.build_optimizer('Adam')
    # Validate every 2 epochs, checkpoint every 5 epochs, stop after 10 epochs.
    trainer.validate_every((2, 'epochs'))
    trainer.save_every((5, 'epochs'))
    trainer.save_to_directory(SAVE_DIRECTORY)
    trainer.set_max_num_epochs(10)
    # Log scalars to TensorBoard every iteration; skip image logging.
    trainer.build_logger(TensorboardLogger(log_scalars_every=(1, 'iteration'),
                                           log_images_every='never'),
                         log_directory=LOG_DIRECTORY)


Bind the loaders.


.. code-block:: python

    trainer.bind_loader('train', train_loader)
    trainer.bind_loader('validate', validate_loader)


Activate CUDA.


.. code-block:: python

    if USE_CUDA:
        trainer.cuda()


Fit the model.


.. code-block:: python

    trainer.fit()

**Total running time of the script:** ( 0 minutes 0.000 seconds)


.. _sphx_glr_download_auto_examples_trainer.py:


.. only:: html

    .. container:: sphx-glr-footer
        :class: sphx-glr-footer-example

        .. container:: sphx-glr-download

            :download:`Download Python source code: trainer.py <trainer.py>`

        .. container:: sphx-glr-download

            :download:`Download Jupyter notebook: trainer.ipynb <trainer.ipynb>`


.. only:: html

    .. rst-class:: sphx-glr-signature

    Gallery generated by Sphinx-Gallery
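
The ``in_features=(256 * 4 * 4)`` of the final ``nn.Linear`` layer follows from the spatial size of a 32x32 CIFAR-10 image after the three convolution/pooling stages. The sketch below reproduces that arithmetic with the standard output-size formula used by ``torch.nn.Conv2d`` and ``torch.nn.MaxPool2d`` (floor mode). It assumes that ``ConvELU2D`` pads a 3x3 convolution by 1 pixel so that the spatial size is preserved; the helper names ``output_size`` and ``flattened_features`` are hypothetical and not part of inferno.

.. code-block:: python

    def output_size(size, kernel_size, stride=1, padding=0, dilation=1):
        """Output size of Conv2d / MaxPool2d along one spatial axis (floor mode)."""
        return (size + 2 * padding - dilation * (kernel_size - 1) - 1) // stride + 1


    def flattened_features(image_size=32, channels=256, num_stages=3):
        """Hypothetical helper: number of features Flatten() passes to nn.Linear."""
        size = image_size
        for _ in range(num_stages):
            # Assumption: ConvELU2D uses padding 1 for a 3x3 kernel ("same" size).
            size = output_size(size, kernel_size=3, padding=1)
            # nn.MaxPool2d(kernel_size=2, stride=2) halves the spatial size.
            size = output_size(size, kernel_size=2, stride=2)
        return channels * size * size


    print(flattened_features())  # 256 * 4 * 4

With unpadded 3x3 convolutions the same formula would give sizes 32, 30, 15, 13, 6, 4, 2, so the linear layer would need ``256 * 2 * 2`` inputs instead; the padding assumption is what makes ``256 * 4 * 4`` correct.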
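
``torch.nn.CrossEntropyLoss``, which ``build_criterion('CrossEntropyLoss')`` selects, combines log-softmax and negative log-likelihood, so it expects raw scores (logits) from the model. Putting an ``nn.Softmax`` layer in front of it normalizes twice: the loss is then computed on probabilities in [0, 1], which keeps it well above zero even for confident, correct predictions and weakens the gradients. The pure-Python sketch below (no PyTorch required; ``softmax`` and ``cross_entropy`` are illustrative helpers, not library functions) compares both variants on a single three-class sample.

.. code-block:: python

    import math


    def softmax(scores):
        """Numerically stable softmax over a list of scores."""
        shift = max(scores)
        exps = [math.exp(s - shift) for s in scores]
        total = sum(exps)
        return [e / total for e in exps]


    def cross_entropy(scores, target):
        """Loss as CrossEntropyLoss defines it: -log(softmax(scores)[target])."""
        return -math.log(softmax(scores)[target])


    logits = [10.0, 0.0, 0.0]  # a confident, correct prediction for class 0

    loss_on_logits = cross_entropy(logits, target=0)
    # Same prediction, but with an extra softmax applied before the loss.
    loss_on_probabilities = cross_entropy(softmax(logits), target=0)

    print(f"loss on logits:        {loss_on_logits:.5f}")
    print(f"loss on probabilities: {loss_on_probabilities:.5f}")

With the 10 CIFAR-10 classes, the double-softmax loss can never fall below log(1 + 9/e), roughly 1.46, even for a perfect one-hot probability vector.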