.. note::
    :class: sphx-glr-download-link-note

    Click :ref:`here <sphx_glr_download_auto_examples_regularized_mnist.py>` to download the full example code

.. rst-class:: sphx-glr-example-title

.. _sphx_glr_auto_examples_regularized_mnist.py:


Regularized MNIST Example
================================

This example demonstrates how to add and log arbitrary regularization losses; in this case,
L2 activity regularization and L1 weight regularization.

- Add a ``_losses`` dictionary, mapping loss names to loss values, to any module
- Use a criterion from ``inferno.extensions.criteria.regularized`` (here,
  ``RegularizedCrossEntropyLoss``) that will collect those losses and add them to the main loss
- Call ``observe_training_and_validation_states`` on the trainer's logger (``trainer.logger``)
  to log the losses as well



.. code-block:: python


    import argparse
    import sys

    import torch
    import torch.nn as nn
    from torchvision import datasets, transforms

    from inferno.extensions.layers.reshape import Flatten
    from inferno.trainers.basic import Trainer
    from inferno.trainers.callbacks.logging.tensorboard import TensorboardLogger


    class RegularizedLinear(nn.Linear):
        def __init__(self, *args, ar_weight=1e-3, l1_weight=1e-3, **kwargs):
            super(RegularizedLinear, self).__init__(*args, **kwargs)
            self.ar_weight = ar_weight
            self.l1_weight = l1_weight
            self._losses = {}

        def forward(self, input):
            output = super(RegularizedLinear, self).forward(input)
            self._losses['activity_regularization'] = (output * output).sum() * self.ar_weight
            self._losses['l1_weight_regularization'] = torch.abs(self.weight).sum() * self.l1_weight
            return output


    def model_fn():
        return nn.Sequential(
            Flatten(),
            RegularizedLinear(in_features=784, out_features=256),
            nn.LeakyReLU(),
            RegularizedLinear(in_features=256, out_features=128),
            nn.LeakyReLU(),
            RegularizedLinear(in_features=128, out_features=10)
        )


    def mnist_data_loaders(args):
        kwargs = {'num_workers': 1, 'pin_memory': True} if args.cuda else {}
        train_loader = torch.utils.data.DataLoader(
            datasets.MNIST('./data', train=True, download=True,
                           transform=transforms.Compose([
                               transforms.ToTensor(),
                               transforms.Normalize((0.1307,), (0.3081,))
                           ])),
            batch_size=args.batch_size, shuffle=True, **kwargs)
        test_loader = torch.utils.data.DataLoader(
            datasets.MNIST('./data', train=False, transform=transforms.Compose([
                transforms.ToTensor(),
                transforms.Normalize((0.1307,), (0.3081,))
            ])),
            batch_size=args.test_batch_size, shuffle=True, **kwargs)
        return train_loader, test_loader


    def train_model(args):
        model = model_fn()
        train_loader, validate_loader = mnist_data_loaders(args)

        # Build trainer
        trainer = Trainer(model) \
            .build_criterion('RegularizedCrossEntropyLoss') \
            .build_metric('CategoricalError') \
            .build_optimizer('Adam') \
            .validate_every((1, 'epochs')) \
            .save_every((1, 'epochs')) \
            .save_to_directory(args.save_directory) \
            .set_max_num_epochs(args.epochs) \
            .build_logger(TensorboardLogger(log_scalars_every=(1, 'iteration'),
                                            log_images_every='never'),
                          log_directory=args.save_directory)

        # Record regularization losses
        trainer.logger.observe_training_and_validation_states([
            'main_loss',
            'total_regularization_loss',
            'activity_regularization',
            'l1_weight_regularization'
        ])

        # Bind loaders
        trainer \
            .bind_loader('train', train_loader) \
            .bind_loader('validate', validate_loader)

        if args.cuda:
            trainer.cuda()

        # Go!
        trainer.fit()


    def main(argv):
        # Training settings
        parser = argparse.ArgumentParser(description='PyTorch MNIST Example')
        parser.add_argument('--batch-size', type=int, default=64, metavar='N',
                            help='input batch size for training (default: 64)')
        parser.add_argument('--save-directory', type=str, default='output/mnist/v1',
                            help='output directory')
        parser.add_argument('--test-batch-size', type=int, default=1000, metavar='N',
                            help='input batch size for testing (default: 1000)')
        parser.add_argument('--epochs', type=int, default=20, metavar='N',
                            help='number of epochs to train (default: 20)')
        parser.add_argument('--no-cuda', action='store_true', default=False,
                            help='disables CUDA training')
        args = parser.parse_args(argv)
        args.cuda = not args.no_cuda and torch.cuda.is_available()
        train_model(args)


    if __name__ == '__main__':
        main(sys.argv[1:])

**Total running time of the script:** ( 0 minutes 0.000 seconds)


.. _sphx_glr_download_auto_examples_regularized_mnist.py:


.. only:: html

    .. container:: sphx-glr-footer
        :class: sphx-glr-footer-example

        .. container:: sphx-glr-download

            :download:`Download Python source code: regularized_mnist.py <regularized_mnist.py>`

        .. container:: sphx-glr-download

            :download:`Download Jupyter notebook: regularized_mnist.ipynb <regularized_mnist.ipynb>`


.. only:: html

    .. rst-class:: sphx-glr-signature

        `Gallery generated by Sphinx-Gallery <https://sphinx-gallery.readthedocs.io>`_