Train Side Loss UNet Example

In this example, a UNet with side supervision and an auxiliary loss is implemented.

Imports needed for this example

import torch
import torch.nn as nn
from inferno.io.box.binary_blobs import get_binary_blob_loaders
from inferno.trainers.basic import Trainer

from inferno.extensions.layers.convolutional import  Conv2D
from inferno.extensions.layers.building_blocks import ResBlock
from inferno.extensions.layers import ResBlockUNet
from inferno.utils.torch_utils import unwrap
from inferno.utils.python_utils import ensure_dir
import pylab

To create a UNet with side loss, we create a new nn.Module class which has a ResBlockUNet as a member. The ResBlockUNet is configured such that the result of the bottom convolution and all the results of the up-stream convolutions are returned as (side-)outputs. A 1x1 convolution is used to give the side outputs the right number of out_channels, and upsampling is used to resize all side outputs to the full resolution of the input. These side predictions are returned by our MySideLossUNet. Furthermore, all side predictions are concatenated (together with the last output of the UNet) and fed through two more residual blocks to make the final prediction.

class MySideLossUNet(nn.Module):
    """UNet with side supervision.

    Wraps a ``ResBlockUNet`` configured to return its bottom and up-stream
    feature maps as side outputs. Each side output is mapped to
    ``out_channels`` channels by a 1x1 convolution and upsampled by a
    per-output scale factor. All side predictions are concatenated with the
    last unet output and refined by two residual blocks into the final
    prediction.

    Args:
        in_channels: Number of channels of the input tensor.
        out_channels: Number of channels of every prediction (side and final).
        depth: Depth of the wrapped UNet.
    """

    def __init__(self, in_channels, out_channels, depth=3):
        super(MySideLossUNet, self).__init__()

        self.depth = depth
        self.unet = ResBlockUNet(in_channels=in_channels, out_channels=in_channels*2,
                                 dim=2, unet_kwargs=dict(depth=depth),
                                 side_out_parts=['bottom', 'up'])

        # number of channels of each side output, as reported by the unet
        self.n_channels_per_output = self.unet.n_channels_per_output

        # One head per side output: a 1x1 conv giving the right number of
        # channels, followed by an upsampling to the input resolution.
        # The scale factor starts at 2**depth for the first (bottom) output
        # and halves for every following output; once it reaches 1 no
        # upsampling is needed.
        upscale_factor = 2**self.depth
        conv_and_scale = []
        for n_channels in self.n_channels_per_output:
            conv = Conv2D(in_channels=n_channels, out_channels=out_channels, kernel_size=1)
            if upscale_factor > 1:
                upsample = nn.Upsample(scale_factor=upscale_factor)
                conv_and_scale.append(nn.Sequential(conv, upsample))
            else:
                conv_and_scale.append(conv)
            upscale_factor //= 2

        self.conv_and_scale = nn.ModuleList(conv_and_scale)

        # Number of channels after concatenating every side prediction with
        # the last unet output (see forward). Derived from the unet's reported
        # outputs instead of assuming exactly depth + 1 side outputs and a
        # last output with in_channels*2 channels.
        self.n_channels_combined = (len(self.n_channels_per_output) * out_channels
                                    + self.n_channels_per_output[-1])

        self.final_block = nn.Sequential(
            ResBlock(dim=2, in_channels=self.n_channels_combined, out_channels=self.n_channels_combined),
            ResBlock(in_channels=self.n_channels_combined, out_channels=out_channels,
                     dim=2, activated=False),
        )

    def forward(self, input):
        """Compute all side predictions and the final prediction.

        Args:
            input: Input batch forwarded to the wrapped unet.

        Returns:
            tuple: one prediction per unet side output (in unet output order),
            followed by the final prediction as the last element.

        Raises:
            RuntimeError: if the unet returns a different number of outputs
                than it announced via ``n_channels_per_output``.
        """
        outs = self.unet(input)
        if len(outs) != len(self.conv_and_scale):
            raise RuntimeError(
                "unet returned %d outputs, expected %d"
                % (len(outs), len(self.conv_and_scale)))

        # map every unet output to `out_channels` channels at full resolution;
        # these are the side predictions
        preds = tuple(head(out) for head, out in zip(self.conv_and_scale, outs))

        # concat side output predictions with the main (last) output of the unet
        combined = torch.cat(preds + (outs[-1],), 1)

        final_res = self.final_block(combined)

        # side predictions first, final prediction last
        return preds + (final_res,)

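We use a custom loss function which applies CrossEntropyLoss to all side outputs. The side outputs are weighted exponentially (the weight doubles from one output to the next: 1, 2, 4, ...), so the final prediction gets the largest weight, and the weighted losses are summed into a single value.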

class MySideLoss(nn.Module):
    """Weighted sum of cross-entropy losses over a sequence of predictions.

    Every prediction is compared against the same target with
    ``nn.CrossEntropyLoss``. The i-th prediction (0-based) is weighted by
    ``2**i``, so later predictions dominate the total.
    """

    def __init__(self):
        super(MySideLoss, self).__init__()
        # reduction='mean' is the non-deprecated spelling of reduce=True
        self.criterion = nn.CrossEntropyLoss(reduction='mean')

    def forward(self, predictions, target):
        """Return the weighted sum of per-prediction cross-entropy losses.

        Args:
            predictions: Iterable of logit tensors, each valid input for
                ``nn.CrossEntropyLoss`` together with ``target``.
            target: Class-index target shared by all predictions.

        Returns:
            Scalar loss tensor.

        Raises:
            ValueError: if ``predictions`` is empty.
        """
        total = None
        weight = 1.0
        for prediction in predictions:
            term = self.criterion(prediction, target) * weight
            # out-of-place addition keeps the autograd graph free of in-place ops
            total = term if total is None else total + term
            weight *= 2
        if total is None:
            raise ValueError("MySideLoss requires at least one prediction")
        return total

Training boilerplate (see Trainer Example)

LOG_DIRECTORY = ensure_dir('log')
SAVE_DIRECTORY = ensure_dir('save')
DATASET_DIRECTORY = ensure_dir('dataset')


# Only move the trainer to the GPU when one is actually available; a
# hard-coded True makes the example crash on CPU-only machines.
USE_CUDA = torch.cuda.is_available()

# Build a residual unet with side outputs; its final residual block is
# not activated (activated=False in MySideLossUNet)
sl_unet = MySideLossUNet(in_channels=5, out_channels=2)

# lift the single-channel input to the 5 channels expected by the unet
model = nn.Sequential(
    ResBlock(dim=2, in_channels=1, out_channels=5),
    sl_unet
)
train_loader, test_loader, validate_loader = get_binary_blob_loaders(
    train_batch_size=3,
    length=512, # <= size of the images
    gaussian_noise_sigma=1.5 # <= how noisy the images are
)

# Build trainer
trainer = Trainer(model)
trainer.build_criterion(MySideLoss())
trainer.build_optimizer('Adam')
trainer.validate_every((10, 'epochs'))
#trainer.save_every((10, 'epochs'))
#trainer.save_to_directory(SAVE_DIRECTORY)
trainer.set_max_num_epochs(40)

# Bind loaders
(trainer
    .bind_loader('train', train_loader)
    .bind_loader('validate', validate_loader))

if USE_CUDA:
    trainer.cuda()

# Go!
trainer.fit()

Out:

[+][2018-08-15 16:29:37.317833] [PROGRESS] Training iteration 0 (batch 0 of epoch 0).
[+][2018-08-15 16:29:38.308397] [PROGRESS] Training iteration 1 (batch 1 of epoch 0).
[+][2018-08-15 16:29:38.596276] [PROGRESS] Training iteration 2 (batch 2 of epoch 0).
[+][2018-08-15 16:29:38.875276] [PROGRESS] Training iteration 3 (batch 3 of epoch 0).
[+][2018-08-15 16:29:39.140556] [PROGRESS] Training iteration 4 (batch 4 of epoch 0).
[+][2018-08-15 16:29:39.414860] [PROGRESS] Training iteration 5 (batch 5 of epoch 0).
[+][2018-08-15 16:29:39.681612] [PROGRESS] Training iteration 6 (batch 6 of epoch 0).
[+][2018-08-15 16:29:39.880760] [PROGRESS] Training iteration 7 (batch 7 of epoch 0).
[+][2018-08-15 16:29:40.376422] [PROGRESS] Training iteration 8 (batch 1 of epoch 1).
[+][2018-08-15 16:29:40.653050] [PROGRESS] Training iteration 9 (batch 2 of epoch 1).
[+][2018-08-15 16:29:40.928012] [PROGRESS] Training iteration 10 (batch 3 of epoch 1).
[+][2018-08-15 16:29:41.196862] [PROGRESS] Training iteration 11 (batch 4 of epoch 1).
[+][2018-08-15 16:29:41.461534] [PROGRESS] Training iteration 12 (batch 5 of epoch 1).
[+][2018-08-15 16:29:41.733103] [PROGRESS] Training iteration 13 (batch 6 of epoch 1).
[+][2018-08-15 16:29:41.932649] [PROGRESS] Training iteration 14 (batch 7 of epoch 1).
[+][2018-08-15 16:29:42.405599] [PROGRESS] Training iteration 15 (batch 1 of epoch 2).
[+][2018-08-15 16:29:42.685087] [PROGRESS] Training iteration 16 (batch 2 of epoch 2).
[+][2018-08-15 16:29:42.955658] [PROGRESS] Training iteration 17 (batch 3 of epoch 2).
[+][2018-08-15 16:29:43.227472] [PROGRESS] Training iteration 18 (batch 4 of epoch 2).
[+][2018-08-15 16:29:43.495791] [PROGRESS] Training iteration 19 (batch 5 of epoch 2).
[+][2018-08-15 16:29:43.768417] [PROGRESS] Training iteration 20 (batch 6 of epoch 2).
[+][2018-08-15 16:29:43.972916] [PROGRESS] Training iteration 21 (batch 7 of epoch 2).
[+][2018-08-15 16:29:44.448398] [PROGRESS] Training iteration 22 (batch 1 of epoch 3).
[+][2018-08-15 16:29:44.730122] [PROGRESS] Training iteration 23 (batch 2 of epoch 3).
[+][2018-08-15 16:29:45.000676] [PROGRESS] Training iteration 24 (batch 3 of epoch 3).
[+][2018-08-15 16:29:45.276138] [PROGRESS] Training iteration 25 (batch 4 of epoch 3).
[+][2018-08-15 16:29:45.547395] [PROGRESS] Training iteration 26 (batch 5 of epoch 3).
[+][2018-08-15 16:29:45.816132] [PROGRESS] Training iteration 27 (batch 6 of epoch 3).
[+][2018-08-15 16:29:46.022989] [PROGRESS] Training iteration 28 (batch 7 of epoch 3).
[+][2018-08-15 16:29:46.489366] [PROGRESS] Training iteration 29 (batch 1 of epoch 4).
[+][2018-08-15 16:29:46.774053] [PROGRESS] Training iteration 30 (batch 2 of epoch 4).
[+][2018-08-15 16:29:47.048494] [PROGRESS] Training iteration 31 (batch 3 of epoch 4).
[+][2018-08-15 16:29:47.316794] [PROGRESS] Training iteration 32 (batch 4 of epoch 4).
[+][2018-08-15 16:29:47.588553] [PROGRESS] Training iteration 33 (batch 5 of epoch 4).
[+][2018-08-15 16:29:47.857390] [PROGRESS] Training iteration 34 (batch 6 of epoch 4).
[+][2018-08-15 16:29:48.054701] [PROGRESS] Training iteration 35 (batch 7 of epoch 4).
[+][2018-08-15 16:29:48.519348] [PROGRESS] Training iteration 36 (batch 1 of epoch 5).
[+][2018-08-15 16:29:48.798081] [PROGRESS] Training iteration 37 (batch 2 of epoch 5).
[+][2018-08-15 16:29:49.070647] [PROGRESS] Training iteration 38 (batch 3 of epoch 5).
[+][2018-08-15 16:29:49.343103] [PROGRESS] Training iteration 39 (batch 4 of epoch 5).
[+][2018-08-15 16:29:49.620415] [PROGRESS] Training iteration 40 (batch 5 of epoch 5).
[+][2018-08-15 16:29:49.891498] [PROGRESS] Training iteration 41 (batch 6 of epoch 5).
[+][2018-08-15 16:29:50.091783] [PROGRESS] Training iteration 42 (batch 7 of epoch 5).
[+][2018-08-15 16:29:50.554059] [PROGRESS] Training iteration 43 (batch 1 of epoch 6).
[+][2018-08-15 16:29:50.838354] [PROGRESS] Training iteration 44 (batch 2 of epoch 6).
[+][2018-08-15 16:29:51.114530] [PROGRESS] Training iteration 45 (batch 3 of epoch 6).
[+][2018-08-15 16:29:51.385805] [PROGRESS] Training iteration 46 (batch 4 of epoch 6).
[+][2018-08-15 16:29:51.658837] [PROGRESS] Training iteration 47 (batch 5 of epoch 6).
[+][2018-08-15 16:29:51.931726] [PROGRESS] Training iteration 48 (batch 6 of epoch 6).
[+][2018-08-15 16:29:52.132255] [PROGRESS] Training iteration 49 (batch 7 of epoch 6).
[+][2018-08-15 16:29:52.621253] [PROGRESS] Training iteration 50 (batch 1 of epoch 7).
[+][2018-08-15 16:29:52.907941] [PROGRESS] Training iteration 51 (batch 2 of epoch 7).
[+][2018-08-15 16:29:53.193912] [PROGRESS] Training iteration 52 (batch 3 of epoch 7).
[+][2018-08-15 16:29:53.477169] [PROGRESS] Training iteration 53 (batch 4 of epoch 7).
[+][2018-08-15 16:29:53.756949] [PROGRESS] Training iteration 54 (batch 5 of epoch 7).
[+][2018-08-15 16:29:54.031631] [PROGRESS] Training iteration 55 (batch 6 of epoch 7).
[+][2018-08-15 16:29:54.245049] [PROGRESS] Training iteration 56 (batch 7 of epoch 7).
[+][2018-08-15 16:29:54.718408] [PROGRESS] Training iteration 57 (batch 1 of epoch 8).
[+][2018-08-15 16:29:55.006328] [PROGRESS] Training iteration 58 (batch 2 of epoch 8).
[+][2018-08-15 16:29:55.285414] [PROGRESS] Training iteration 59 (batch 3 of epoch 8).
[+][2018-08-15 16:29:55.558634] [PROGRESS] Training iteration 60 (batch 4 of epoch 8).
[+][2018-08-15 16:29:55.827223] [PROGRESS] Training iteration 61 (batch 5 of epoch 8).
[+][2018-08-15 16:29:56.097151] [PROGRESS] Training iteration 62 (batch 6 of epoch 8).
[+][2018-08-15 16:29:56.297308] [PROGRESS] Training iteration 63 (batch 7 of epoch 8).
[+][2018-08-15 16:29:56.772380] [PROGRESS] Training iteration 64 (batch 1 of epoch 9).
[+][2018-08-15 16:29:57.059602] [PROGRESS] Training iteration 65 (batch 2 of epoch 9).
[+][2018-08-15 16:29:57.331440] [PROGRESS] Training iteration 66 (batch 3 of epoch 9).
[+][2018-08-15 16:29:57.606866] [PROGRESS] Training iteration 67 (batch 4 of epoch 9).
[+][2018-08-15 16:29:57.879106] [PROGRESS] Training iteration 68 (batch 5 of epoch 9).
[+][2018-08-15 16:29:58.150403] [PROGRESS] Training iteration 69 (batch 6 of epoch 9).
[+][2018-08-15 16:29:58.366411] [PROGRESS] Training iteration 70 (batch 7 of epoch 9).
[+][2018-08-15 16:29:58.853279] [INFO    ] Breaking to validate.
[+][2018-08-15 16:29:58.853767] [INFO    ] Validating.
[+][2018-08-15 16:29:58.958861] [PROGRESS] Validating iteration 0.
[+][2018-08-15 16:29:59.013858] [PROGRESS] Validating iteration 1.
[+][2018-08-15 16:29:59.071769] [PROGRESS] Validating iteration 2.
[+][2018-08-15 16:29:59.129745] [PROGRESS] Validating iteration 3.
[+][2018-08-15 16:29:59.188124] [PROGRESS] Validating iteration 4.
[+][2018-08-15 16:29:59.247074] [PROGRESS] Validating iteration 5.
[+][2018-08-15 16:29:59.304987] [PROGRESS] Validating iteration 6.
[+][2018-08-15 16:29:59.361782] [PROGRESS] Validating iteration 7.
[+][2018-08-15 16:29:59.418838] [PROGRESS] Validating iteration 8.
[+][2018-08-15 16:29:59.476005] [PROGRESS] Validating iteration 9.
[+][2018-08-15 16:29:59.533727] [PROGRESS] Validating iteration 10.
[+][2018-08-15 16:29:59.594306] [PROGRESS] Validating iteration 11.
[+][2018-08-15 16:29:59.650635] [PROGRESS] Validating iteration 12.
[+][2018-08-15 16:29:59.706891] [PROGRESS] Validating iteration 13.
[+][2018-08-15 16:29:59.763655] [PROGRESS] Validating iteration 14.
[+][2018-08-15 16:29:59.820176] [PROGRESS] Validating iteration 15.
[+][2018-08-15 16:29:59.876687] [PROGRESS] Validating iteration 16.
[+][2018-08-15 16:29:59.933766] [PROGRESS] Validating iteration 17.
[+][2018-08-15 16:29:59.990322] [PROGRESS] Validating iteration 18.
[+][2018-08-15 16:30:00.052946] [PROGRESS] Validating iteration 19.
[+][2018-08-15 16:30:00.077641] [INFO    ] validate generator exhausted, breaking.
[+][2018-08-15 16:30:00.077729] [INFO    ] Done validating. Logging results...
[+][2018-08-15 16:30:00.077791] [INFO    ] Validation loss: 7.947246861457825; validation error: None
[+][2018-08-15 16:30:00.078529] [PROGRESS] Training iteration 0 (batch 1 of epoch 10).
[+][2018-08-15 16:30:00.360171] [PROGRESS] Training iteration 1 (batch 2 of epoch 10).
[+][2018-08-15 16:30:00.644102] [PROGRESS] Training iteration 2 (batch 3 of epoch 10).
[+][2018-08-15 16:30:00.921537] [PROGRESS] Training iteration 3 (batch 4 of epoch 10).
[+][2018-08-15 16:30:01.207870] [PROGRESS] Training iteration 4 (batch 5 of epoch 10).
[+][2018-08-15 16:30:01.476774] [PROGRESS] Training iteration 5 (batch 6 of epoch 10).
[+][2018-08-15 16:30:01.676789] [PROGRESS] Training iteration 6 (batch 7 of epoch 10).
[+][2018-08-15 16:30:02.166358] [PROGRESS] Training iteration 7 (batch 1 of epoch 11).
[+][2018-08-15 16:30:02.466828] [PROGRESS] Training iteration 8 (batch 2 of epoch 11).
[+][2018-08-15 16:30:02.751446] [PROGRESS] Training iteration 9 (batch 3 of epoch 11).
[+][2018-08-15 16:30:03.028844] [PROGRESS] Training iteration 10 (batch 4 of epoch 11).
[+][2018-08-15 16:30:03.311449] [PROGRESS] Training iteration 11 (batch 5 of epoch 11).
[+][2018-08-15 16:30:03.596218] [PROGRESS] Training iteration 12 (batch 6 of epoch 11).
[+][2018-08-15 16:30:03.808256] [PROGRESS] Training iteration 13 (batch 7 of epoch 11).
[+][2018-08-15 16:30:04.312191] [PROGRESS] Training iteration 14 (batch 1 of epoch 12).
[+][2018-08-15 16:30:04.609859] [PROGRESS] Training iteration 15 (batch 2 of epoch 12).
[+][2018-08-15 16:30:04.900997] [PROGRESS] Training iteration 16 (batch 3 of epoch 12).
[+][2018-08-15 16:30:05.174522] [PROGRESS] Training iteration 17 (batch 4 of epoch 12).
[+][2018-08-15 16:30:05.453524] [PROGRESS] Training iteration 18 (batch 5 of epoch 12).
[+][2018-08-15 16:30:05.736041] [PROGRESS] Training iteration 19 (batch 6 of epoch 12).
[+][2018-08-15 16:30:05.943708] [PROGRESS] Training iteration 20 (batch 7 of epoch 12).
[+][2018-08-15 16:30:06.421500] [PROGRESS] Training iteration 21 (batch 1 of epoch 13).
[+][2018-08-15 16:30:06.713911] [PROGRESS] Training iteration 22 (batch 2 of epoch 13).
[+][2018-08-15 16:30:06.998103] [PROGRESS] Training iteration 23 (batch 3 of epoch 13).
[+][2018-08-15 16:30:07.280678] [PROGRESS] Training iteration 24 (batch 4 of epoch 13).
[+][2018-08-15 16:30:07.564899] [PROGRESS] Training iteration 25 (batch 5 of epoch 13).
[+][2018-08-15 16:30:07.850231] [PROGRESS] Training iteration 26 (batch 6 of epoch 13).
[+][2018-08-15 16:30:08.066939] [PROGRESS] Training iteration 27 (batch 7 of epoch 13).
[+][2018-08-15 16:30:08.568652] [PROGRESS] Training iteration 28 (batch 1 of epoch 14).
[+][2018-08-15 16:30:08.856433] [PROGRESS] Training iteration 29 (batch 2 of epoch 14).
[+][2018-08-15 16:30:09.136466] [PROGRESS] Training iteration 30 (batch 3 of epoch 14).
[+][2018-08-15 16:30:09.406506] [PROGRESS] Training iteration 31 (batch 4 of epoch 14).
[+][2018-08-15 16:30:09.678043] [PROGRESS] Training iteration 32 (batch 5 of epoch 14).
[+][2018-08-15 16:30:09.951426] [PROGRESS] Training iteration 33 (batch 6 of epoch 14).
[+][2018-08-15 16:30:10.153718] [PROGRESS] Training iteration 34 (batch 7 of epoch 14).
[+][2018-08-15 16:30:10.636910] [PROGRESS] Training iteration 35 (batch 1 of epoch 15).
[+][2018-08-15 16:30:10.922165] [PROGRESS] Training iteration 36 (batch 2 of epoch 15).
[+][2018-08-15 16:30:11.198678] [PROGRESS] Training iteration 37 (batch 3 of epoch 15).
[+][2018-08-15 16:30:11.490502] [PROGRESS] Training iteration 38 (batch 4 of epoch 15).
[+][2018-08-15 16:30:11.772660] [PROGRESS] Training iteration 39 (batch 5 of epoch 15).
[+][2018-08-15 16:30:12.055936] [PROGRESS] Training iteration 40 (batch 6 of epoch 15).
[+][2018-08-15 16:30:12.259801] [PROGRESS] Training iteration 41 (batch 7 of epoch 15).
[+][2018-08-15 16:30:12.735688] [PROGRESS] Training iteration 42 (batch 1 of epoch 16).
[+][2018-08-15 16:30:13.016669] [PROGRESS] Training iteration 43 (batch 2 of epoch 16).
[+][2018-08-15 16:30:13.298964] [PROGRESS] Training iteration 44 (batch 3 of epoch 16).
[+][2018-08-15 16:30:13.575538] [PROGRESS] Training iteration 45 (batch 4 of epoch 16).
[+][2018-08-15 16:30:13.846381] [PROGRESS] Training iteration 46 (batch 5 of epoch 16).
[+][2018-08-15 16:30:14.120780] [PROGRESS] Training iteration 47 (batch 6 of epoch 16).
[+][2018-08-15 16:30:14.325160] [PROGRESS] Training iteration 48 (batch 7 of epoch 16).
[+][2018-08-15 16:30:14.805172] [PROGRESS] Training iteration 49 (batch 1 of epoch 17).
[+][2018-08-15 16:30:15.092996] [PROGRESS] Training iteration 50 (batch 2 of epoch 17).
[+][2018-08-15 16:30:15.366249] [PROGRESS] Training iteration 51 (batch 3 of epoch 17).
[+][2018-08-15 16:30:15.637710] [PROGRESS] Training iteration 52 (batch 4 of epoch 17).
[+][2018-08-15 16:30:15.914322] [PROGRESS] Training iteration 53 (batch 5 of epoch 17).
[+][2018-08-15 16:30:16.185274] [PROGRESS] Training iteration 54 (batch 6 of epoch 17).
[+][2018-08-15 16:30:16.392048] [PROGRESS] Training iteration 55 (batch 7 of epoch 17).
[+][2018-08-15 16:30:16.884281] [PROGRESS] Training iteration 56 (batch 1 of epoch 18).
[+][2018-08-15 16:30:17.183986] [PROGRESS] Training iteration 57 (batch 2 of epoch 18).
[+][2018-08-15 16:30:17.474442] [PROGRESS] Training iteration 58 (batch 3 of epoch 18).
[+][2018-08-15 16:30:17.760488] [PROGRESS] Training iteration 59 (batch 4 of epoch 18).
[+][2018-08-15 16:30:18.037754] [PROGRESS] Training iteration 60 (batch 5 of epoch 18).
[+][2018-08-15 16:30:18.305494] [PROGRESS] Training iteration 61 (batch 6 of epoch 18).
[+][2018-08-15 16:30:18.518877] [PROGRESS] Training iteration 62 (batch 7 of epoch 18).
[+][2018-08-15 16:30:19.012578] [PROGRESS] Training iteration 63 (batch 1 of epoch 19).
[+][2018-08-15 16:30:19.302636] [PROGRESS] Training iteration 64 (batch 2 of epoch 19).
[+][2018-08-15 16:30:19.577104] [PROGRESS] Training iteration 65 (batch 3 of epoch 19).
[+][2018-08-15 16:30:19.853837] [PROGRESS] Training iteration 66 (batch 4 of epoch 19).
[+][2018-08-15 16:30:20.134190] [PROGRESS] Training iteration 67 (batch 5 of epoch 19).
[+][2018-08-15 16:30:20.418572] [PROGRESS] Training iteration 68 (batch 6 of epoch 19).
[+][2018-08-15 16:30:20.638621] [PROGRESS] Training iteration 69 (batch 7 of epoch 19).
[+][2018-08-15 16:30:21.147010] [INFO    ] Breaking to validate.
[+][2018-08-15 16:30:21.147418] [INFO    ] Validating.
[+][2018-08-15 16:30:21.239198] [PROGRESS] Validating iteration 0.
[+][2018-08-15 16:30:21.300351] [PROGRESS] Validating iteration 1.
[+][2018-08-15 16:30:21.362514] [PROGRESS] Validating iteration 2.
[+][2018-08-15 16:30:21.425657] [PROGRESS] Validating iteration 3.
[+][2018-08-15 16:30:21.485280] [PROGRESS] Validating iteration 4.
[+][2018-08-15 16:30:21.545027] [PROGRESS] Validating iteration 5.
[+][2018-08-15 16:30:21.604092] [PROGRESS] Validating iteration 6.
[+][2018-08-15 16:30:21.662858] [PROGRESS] Validating iteration 7.
[+][2018-08-15 16:30:21.718249] [PROGRESS] Validating iteration 8.
[+][2018-08-15 16:30:21.777795] [PROGRESS] Validating iteration 9.
[+][2018-08-15 16:30:21.838286] [PROGRESS] Validating iteration 10.
[+][2018-08-15 16:30:21.897501] [PROGRESS] Validating iteration 11.
[+][2018-08-15 16:30:21.957506] [PROGRESS] Validating iteration 12.
[+][2018-08-15 16:30:22.016603] [PROGRESS] Validating iteration 13.
[+][2018-08-15 16:30:22.075588] [PROGRESS] Validating iteration 14.
[+][2018-08-15 16:30:22.135487] [PROGRESS] Validating iteration 15.
[+][2018-08-15 16:30:22.193858] [PROGRESS] Validating iteration 16.
[+][2018-08-15 16:30:22.253880] [PROGRESS] Validating iteration 17.
[+][2018-08-15 16:30:22.311913] [PROGRESS] Validating iteration 18.
[+][2018-08-15 16:30:22.371897] [PROGRESS] Validating iteration 19.
[+][2018-08-15 16:30:22.396741] [INFO    ] validate generator exhausted, breaking.
[+][2018-08-15 16:30:22.396829] [INFO    ] Done validating. Logging results...
[+][2018-08-15 16:30:22.396892] [INFO    ] Validation loss: 10.703891134262085; validation error: None
[+][2018-08-15 16:30:22.397610] [PROGRESS] Training iteration 0 (batch 1 of epoch 20).
[+][2018-08-15 16:30:22.675747] [PROGRESS] Training iteration 1 (batch 2 of epoch 20).
[+][2018-08-15 16:30:22.956413] [PROGRESS] Training iteration 2 (batch 3 of epoch 20).
[+][2018-08-15 16:30:23.232252] [PROGRESS] Training iteration 3 (batch 4 of epoch 20).
[+][2018-08-15 16:30:23.498622] [PROGRESS] Training iteration 4 (batch 5 of epoch 20).
[+][2018-08-15 16:30:23.780016] [PROGRESS] Training iteration 5 (batch 6 of epoch 20).
[+][2018-08-15 16:30:23.979127] [PROGRESS] Training iteration 6 (batch 7 of epoch 20).
[+][2018-08-15 16:30:24.457883] [PROGRESS] Training iteration 7 (batch 1 of epoch 21).
[+][2018-08-15 16:30:24.749453] [PROGRESS] Training iteration 8 (batch 2 of epoch 21).
[+][2018-08-15 16:30:25.032240] [PROGRESS] Training iteration 9 (batch 3 of epoch 21).
[+][2018-08-15 16:30:25.308305] [PROGRESS] Training iteration 10 (batch 4 of epoch 21).
[+][2018-08-15 16:30:25.597316] [PROGRESS] Training iteration 11 (batch 5 of epoch 21).
[+][2018-08-15 16:30:25.877275] [PROGRESS] Training iteration 12 (batch 6 of epoch 21).
[+][2018-08-15 16:30:26.078678] [PROGRESS] Training iteration 13 (batch 7 of epoch 21).
[+][2018-08-15 16:30:26.560681] [PROGRESS] Training iteration 14 (batch 1 of epoch 22).
[+][2018-08-15 16:30:26.845325] [PROGRESS] Training iteration 15 (batch 2 of epoch 22).
[+][2018-08-15 16:30:27.124711] [PROGRESS] Training iteration 16 (batch 3 of epoch 22).
[+][2018-08-15 16:30:27.395158] [PROGRESS] Training iteration 17 (batch 4 of epoch 22).
[+][2018-08-15 16:30:27.666084] [PROGRESS] Training iteration 18 (batch 5 of epoch 22).
[+][2018-08-15 16:30:27.941787] [PROGRESS] Training iteration 19 (batch 6 of epoch 22).
[+][2018-08-15 16:30:28.141813] [PROGRESS] Training iteration 20 (batch 7 of epoch 22).
[+][2018-08-15 16:30:28.616893] [PROGRESS] Training iteration 21 (batch 1 of epoch 23).
[+][2018-08-15 16:30:28.898152] [PROGRESS] Training iteration 22 (batch 2 of epoch 23).
[+][2018-08-15 16:30:29.173405] [PROGRESS] Training iteration 23 (batch 3 of epoch 23).
[+][2018-08-15 16:30:29.450420] [PROGRESS] Training iteration 24 (batch 4 of epoch 23).
[+][2018-08-15 16:30:29.723646] [PROGRESS] Training iteration 25 (batch 5 of epoch 23).
[+][2018-08-15 16:30:29.995602] [PROGRESS] Training iteration 26 (batch 6 of epoch 23).
[+][2018-08-15 16:30:30.206753] [PROGRESS] Training iteration 27 (batch 7 of epoch 23).
[+][2018-08-15 16:30:30.704573] [PROGRESS] Training iteration 28 (batch 1 of epoch 24).
[+][2018-08-15 16:30:30.985471] [PROGRESS] Training iteration 29 (batch 2 of epoch 24).
[+][2018-08-15 16:30:31.262106] [PROGRESS] Training iteration 30 (batch 3 of epoch 24).
[+][2018-08-15 16:30:31.536989] [PROGRESS] Training iteration 31 (batch 4 of epoch 24).
[+][2018-08-15 16:30:31.803998] [PROGRESS] Training iteration 32 (batch 5 of epoch 24).
[+][2018-08-15 16:30:32.077281] [PROGRESS] Training iteration 33 (batch 6 of epoch 24).
[+][2018-08-15 16:30:32.272983] [PROGRESS] Training iteration 34 (batch 7 of epoch 24).
[+][2018-08-15 16:30:32.762518] [PROGRESS] Training iteration 35 (batch 1 of epoch 25).
[+][2018-08-15 16:30:33.048994] [PROGRESS] Training iteration 36 (batch 2 of epoch 25).
[+][2018-08-15 16:30:33.324238] [PROGRESS] Training iteration 37 (batch 3 of epoch 25).
[+][2018-08-15 16:30:33.595402] [PROGRESS] Training iteration 38 (batch 4 of epoch 25).
[+][2018-08-15 16:30:33.874961] [PROGRESS] Training iteration 39 (batch 5 of epoch 25).
[+][2018-08-15 16:30:34.151006] [PROGRESS] Training iteration 40 (batch 6 of epoch 25).
[+][2018-08-15 16:30:34.361310] [PROGRESS] Training iteration 41 (batch 7 of epoch 25).
[+][2018-08-15 16:30:34.852101] [PROGRESS] Training iteration 42 (batch 1 of epoch 26).
[+][2018-08-15 16:30:35.141097] [PROGRESS] Training iteration 43 (batch 2 of epoch 26).
[+][2018-08-15 16:30:35.418699] [PROGRESS] Training iteration 44 (batch 3 of epoch 26).
[+][2018-08-15 16:30:35.689636] [PROGRESS] Training iteration 45 (batch 4 of epoch 26).
[+][2018-08-15 16:30:35.965857] [PROGRESS] Training iteration 46 (batch 5 of epoch 26).
[+][2018-08-15 16:30:36.239555] [PROGRESS] Training iteration 47 (batch 6 of epoch 26).
[+][2018-08-15 16:30:36.440579] [PROGRESS] Training iteration 48 (batch 7 of epoch 26).
[+][2018-08-15 16:30:36.915958] [PROGRESS] Training iteration 49 (batch 1 of epoch 27).
[+][2018-08-15 16:30:37.195972] [PROGRESS] Training iteration 50 (batch 2 of epoch 27).
[+][2018-08-15 16:30:37.474195] [PROGRESS] Training iteration 51 (batch 3 of epoch 27).
[+][2018-08-15 16:30:37.751738] [PROGRESS] Training iteration 52 (batch 4 of epoch 27).
[+][2018-08-15 16:30:38.027939] [PROGRESS] Training iteration 53 (batch 5 of epoch 27).
[+][2018-08-15 16:30:38.302136] [PROGRESS] Training iteration 54 (batch 6 of epoch 27).
[+][2018-08-15 16:30:38.505357] [PROGRESS] Training iteration 55 (batch 7 of epoch 27).
[+][2018-08-15 16:30:38.982864] [PROGRESS] Training iteration 56 (batch 1 of epoch 28).
[+][2018-08-15 16:30:39.270670] [PROGRESS] Training iteration 57 (batch 2 of epoch 28).
[+][2018-08-15 16:30:39.549456] [PROGRESS] Training iteration 58 (batch 3 of epoch 28).
[+][2018-08-15 16:30:39.818721] [PROGRESS] Training iteration 59 (batch 4 of epoch 28).
[+][2018-08-15 16:30:40.094805] [PROGRESS] Training iteration 60 (batch 5 of epoch 28).
[+][2018-08-15 16:30:40.366020] [PROGRESS] Training iteration 61 (batch 6 of epoch 28).
[+][2018-08-15 16:30:40.563122] [PROGRESS] Training iteration 62 (batch 7 of epoch 28).
[+][2018-08-15 16:30:41.043610] [PROGRESS] Training iteration 63 (batch 1 of epoch 29).
[+][2018-08-15 16:30:41.323521] [PROGRESS] Training iteration 64 (batch 2 of epoch 29).
[+][2018-08-15 16:30:41.600251] [PROGRESS] Training iteration 65 (batch 3 of epoch 29).
[+][2018-08-15 16:30:41.877148] [PROGRESS] Training iteration 66 (batch 4 of epoch 29).
[+][2018-08-15 16:30:42.147088] [PROGRESS] Training iteration 67 (batch 5 of epoch 29).
[+][2018-08-15 16:30:42.419988] [PROGRESS] Training iteration 68 (batch 6 of epoch 29).
[+][2018-08-15 16:30:42.618749] [PROGRESS] Training iteration 69 (batch 7 of epoch 29).
[+][2018-08-15 16:30:43.092523] [INFO    ] Breaking to validate.
[+][2018-08-15 16:30:43.092905] [INFO    ] Validating.
[+][2018-08-15 16:30:43.177080] [PROGRESS] Validating iteration 0.
[+][2018-08-15 16:30:43.233738] [PROGRESS] Validating iteration 1.
[+][2018-08-15 16:30:43.293254] [PROGRESS] Validating iteration 2.
[+][2018-08-15 16:30:43.351630] [PROGRESS] Validating iteration 3.
[+][2018-08-15 16:30:43.411640] [PROGRESS] Validating iteration 4.
[+][2018-08-15 16:30:43.470835] [PROGRESS] Validating iteration 5.
[+][2018-08-15 16:30:43.530343] [PROGRESS] Validating iteration 6.
[+][2018-08-15 16:30:43.589593] [PROGRESS] Validating iteration 7.
[+][2018-08-15 16:30:43.648099] [PROGRESS] Validating iteration 8.
[+][2018-08-15 16:30:43.707435] [PROGRESS] Validating iteration 9.
[+][2018-08-15 16:30:43.765853] [PROGRESS] Validating iteration 10.
[+][2018-08-15 16:30:43.824564] [PROGRESS] Validating iteration 11.
[+][2018-08-15 16:30:43.883227] [PROGRESS] Validating iteration 12.
[+][2018-08-15 16:30:43.942550] [PROGRESS] Validating iteration 13.
[+][2018-08-15 16:30:44.001409] [PROGRESS] Validating iteration 14.
[+][2018-08-15 16:30:44.060214] [PROGRESS] Validating iteration 15.
[+][2018-08-15 16:30:44.118989] [PROGRESS] Validating iteration 16.
[+][2018-08-15 16:30:44.178665] [PROGRESS] Validating iteration 17.
[+][2018-08-15 16:30:44.237520] [PROGRESS] Validating iteration 18.
[+][2018-08-15 16:30:44.297032] [PROGRESS] Validating iteration 19.
[+][2018-08-15 16:30:44.321767] [INFO    ] validate generator exhausted, breaking.
[+][2018-08-15 16:30:44.321856] [INFO    ] Done validating. Logging results...
[+][2018-08-15 16:30:44.321918] [INFO    ] Validation loss: 6.559897494316101; validation error: None
[+][2018-08-15 16:30:44.322665] [PROGRESS] Training iteration 0 (batch 1 of epoch 30).
[+][2018-08-15 16:30:44.603570] [PROGRESS] Training iteration 1 (batch 2 of epoch 30).
[+][2018-08-15 16:30:44.883641] [PROGRESS] Training iteration 2 (batch 3 of epoch 30).
[+][2018-08-15 16:30:45.154331] [PROGRESS] Training iteration 3 (batch 4 of epoch 30).
[+][2018-08-15 16:30:45.430701] [PROGRESS] Training iteration 4 (batch 5 of epoch 30).
[+][2018-08-15 16:30:45.701694] [PROGRESS] Training iteration 5 (batch 6 of epoch 30).
[+][2018-08-15 16:30:45.896465] [PROGRESS] Training iteration 6 (batch 7 of epoch 30).
[+][2018-08-15 16:30:46.384790] [PROGRESS] Training iteration 7 (batch 1 of epoch 31).
[+][2018-08-15 16:30:46.672287] [PROGRESS] Training iteration 8 (batch 2 of epoch 31).
[+][2018-08-15 16:30:46.954580] [PROGRESS] Training iteration 9 (batch 3 of epoch 31).
[+][2018-08-15 16:30:47.227587] [PROGRESS] Training iteration 10 (batch 4 of epoch 31).
[+][2018-08-15 16:30:47.496988] [PROGRESS] Training iteration 11 (batch 5 of epoch 31).
[+][2018-08-15 16:30:47.772454] [PROGRESS] Training iteration 12 (batch 6 of epoch 31).
[+][2018-08-15 16:30:47.972633] [PROGRESS] Training iteration 13 (batch 7 of epoch 31).
[+][2018-08-15 16:30:48.457489] [PROGRESS] Training iteration 14 (batch 1 of epoch 32).
[+][2018-08-15 16:30:48.745144] [PROGRESS] Training iteration 15 (batch 2 of epoch 32).
[+][2018-08-15 16:30:49.036661] [PROGRESS] Training iteration 16 (batch 3 of epoch 32).
[+][2018-08-15 16:30:49.313462] [PROGRESS] Training iteration 17 (batch 4 of epoch 32).
[+][2018-08-15 16:30:49.609893] [PROGRESS] Training iteration 18 (batch 5 of epoch 32).
[+][2018-08-15 16:30:49.893871] [PROGRESS] Training iteration 19 (batch 6 of epoch 32).
[+][2018-08-15 16:30:50.096663] [PROGRESS] Training iteration 20 (batch 7 of epoch 32).
[+][2018-08-15 16:30:50.580131] [PROGRESS] Training iteration 21 (batch 1 of epoch 33).
[+][2018-08-15 16:30:50.868954] [PROGRESS] Training iteration 22 (batch 2 of epoch 33).
[+][2018-08-15 16:30:51.165526] [PROGRESS] Training iteration 23 (batch 3 of epoch 33).
[+][2018-08-15 16:30:51.457654] [PROGRESS] Training iteration 24 (batch 4 of epoch 33).
[+][2018-08-15 16:30:51.730160] [PROGRESS] Training iteration 25 (batch 5 of epoch 33).
[+][2018-08-15 16:30:52.003223] [PROGRESS] Training iteration 26 (batch 6 of epoch 33).
[+][2018-08-15 16:30:52.204919] [PROGRESS] Training iteration 27 (batch 7 of epoch 33).
[+][2018-08-15 16:30:52.688028] [PROGRESS] Training iteration 28 (batch 1 of epoch 34).
[+][2018-08-15 16:30:52.983445] [PROGRESS] Training iteration 29 (batch 2 of epoch 34).
[+][2018-08-15 16:30:53.264524] [PROGRESS] Training iteration 30 (batch 3 of epoch 34).
[+][2018-08-15 16:30:53.548673] [PROGRESS] Training iteration 31 (batch 4 of epoch 34).
[+][2018-08-15 16:30:53.825932] [PROGRESS] Training iteration 32 (batch 5 of epoch 34).
[+][2018-08-15 16:30:54.106304] [PROGRESS] Training iteration 33 (batch 6 of epoch 34).
[+][2018-08-15 16:30:54.318925] [PROGRESS] Training iteration 34 (batch 7 of epoch 34).
[+][2018-08-15 16:30:54.830594] [PROGRESS] Training iteration 35 (batch 1 of epoch 35).
[+][2018-08-15 16:30:55.144652] [PROGRESS] Training iteration 36 (batch 2 of epoch 35).
[+][2018-08-15 16:30:55.444593] [PROGRESS] Training iteration 37 (batch 3 of epoch 35).
[+][2018-08-15 16:30:55.736421] [PROGRESS] Training iteration 38 (batch 4 of epoch 35).
[+][2018-08-15 16:30:56.011066] [PROGRESS] Training iteration 39 (batch 5 of epoch 35).
[+][2018-08-15 16:30:56.291601] [PROGRESS] Training iteration 40 (batch 6 of epoch 35).
[+][2018-08-15 16:30:56.501087] [PROGRESS] Training iteration 41 (batch 7 of epoch 35).
[+][2018-08-15 16:30:57.036513] [PROGRESS] Training iteration 42 (batch 1 of epoch 36).
[+][2018-08-15 16:30:57.348684] [PROGRESS] Training iteration 43 (batch 2 of epoch 36).
[+][2018-08-15 16:30:57.653222] [PROGRESS] Training iteration 44 (batch 3 of epoch 36).
[+][2018-08-15 16:30:57.941635] [PROGRESS] Training iteration 45 (batch 4 of epoch 36).
[+][2018-08-15 16:30:58.225886] [PROGRESS] Training iteration 46 (batch 5 of epoch 36).
[+][2018-08-15 16:30:58.517621] [PROGRESS] Training iteration 47 (batch 6 of epoch 36).
[+][2018-08-15 16:30:58.727220] [PROGRESS] Training iteration 48 (batch 7 of epoch 36).
[+][2018-08-15 16:30:59.278551] [PROGRESS] Training iteration 49 (batch 1 of epoch 37).
[+][2018-08-15 16:30:59.589882] [PROGRESS] Training iteration 50 (batch 2 of epoch 37).
[+][2018-08-15 16:30:59.885314] [PROGRESS] Training iteration 51 (batch 3 of epoch 37).
[+][2018-08-15 16:31:00.178470] [PROGRESS] Training iteration 52 (batch 4 of epoch 37).
[+][2018-08-15 16:31:00.467785] [PROGRESS] Training iteration 53 (batch 5 of epoch 37).
[+][2018-08-15 16:31:00.751458] [PROGRESS] Training iteration 54 (batch 6 of epoch 37).
[+][2018-08-15 16:31:00.964319] [PROGRESS] Training iteration 55 (batch 7 of epoch 37).
[+][2018-08-15 16:31:01.472615] [PROGRESS] Training iteration 56 (batch 1 of epoch 38).
[+][2018-08-15 16:31:01.778460] [PROGRESS] Training iteration 57 (batch 2 of epoch 38).
[+][2018-08-15 16:31:02.071263] [PROGRESS] Training iteration 58 (batch 3 of epoch 38).
[+][2018-08-15 16:31:02.360600] [PROGRESS] Training iteration 59 (batch 4 of epoch 38).
[+][2018-08-15 16:31:02.641660] [PROGRESS] Training iteration 60 (batch 5 of epoch 38).
[+][2018-08-15 16:31:02.927147] [PROGRESS] Training iteration 61 (batch 6 of epoch 38).
[+][2018-08-15 16:31:03.142380] [PROGRESS] Training iteration 62 (batch 7 of epoch 38).
[+][2018-08-15 16:31:03.645700] [PROGRESS] Training iteration 63 (batch 1 of epoch 39).
[+][2018-08-15 16:31:03.948420] [PROGRESS] Training iteration 64 (batch 2 of epoch 39).
[+][2018-08-15 16:31:04.241279] [PROGRESS] Training iteration 65 (batch 3 of epoch 39).
[+][2018-08-15 16:31:04.532734] [PROGRESS] Training iteration 66 (batch 4 of epoch 39).
[+][2018-08-15 16:31:04.831638] [PROGRESS] Training iteration 67 (batch 5 of epoch 39).
[+][2018-08-15 16:31:05.115673] [PROGRESS] Training iteration 68 (batch 6 of epoch 39).
[+][2018-08-15 16:31:05.325786] [PROGRESS] Training iteration 69 (batch 7 of epoch 39).
[+][2018-08-15 16:31:05.852046] [INFO    ] Breaking to validate.
[+][2018-08-15 16:31:05.852484] [INFO    ] Validating.
[+][2018-08-15 16:31:05.956212] [PROGRESS] Validating iteration 0.
[+][2018-08-15 16:31:06.013258] [PROGRESS] Validating iteration 1.
[+][2018-08-15 16:31:06.067717] [PROGRESS] Validating iteration 2.
[+][2018-08-15 16:31:06.122801] [PROGRESS] Validating iteration 3.
[+][2018-08-15 16:31:06.176166] [PROGRESS] Validating iteration 4.
[+][2018-08-15 16:31:06.230145] [PROGRESS] Validating iteration 5.
[+][2018-08-15 16:31:06.283495] [PROGRESS] Validating iteration 6.
[+][2018-08-15 16:31:06.336984] [PROGRESS] Validating iteration 7.
[+][2018-08-15 16:31:06.395517] [PROGRESS] Validating iteration 8.
[+][2018-08-15 16:31:06.453625] [PROGRESS] Validating iteration 9.
[+][2018-08-15 16:31:06.508579] [PROGRESS] Validating iteration 10.
[+][2018-08-15 16:31:06.561314] [PROGRESS] Validating iteration 11.
[+][2018-08-15 16:31:06.612523] [PROGRESS] Validating iteration 12.
[+][2018-08-15 16:31:06.665358] [PROGRESS] Validating iteration 13.
[+][2018-08-15 16:31:06.717080] [PROGRESS] Validating iteration 14.
[+][2018-08-15 16:31:06.769893] [PROGRESS] Validating iteration 15.
[+][2018-08-15 16:31:06.822899] [PROGRESS] Validating iteration 16.
[+][2018-08-15 16:31:06.876488] [PROGRESS] Validating iteration 17.
[+][2018-08-15 16:31:06.927405] [PROGRESS] Validating iteration 18.
[+][2018-08-15 16:31:06.980757] [PROGRESS] Validating iteration 19.
[+][2018-08-15 16:31:07.005419] [INFO    ] validate generator exhausted, breaking.
[+][2018-08-15 16:31:07.005523] [INFO    ] Done validating. Logging results...
[+][2018-08-15 16:31:07.005614] [INFO    ] Validation loss: 6.0453955173492435; validation error: None
[+][2018-08-15 16:31:07.005700] [INFO    ] Exceeded max number of iterations / epochs, breaking.

Predict with the trained network and visualize the results

# predict:
# Optionally restore the best checkpoint saved during training instead of
# using the weights from the last training iteration:
# trainer.load(best=True)
trainer.bind_loader('train', train_loader)
trainer.bind_loader('validate', validate_loader)
trainer.eval_mode()

if USE_CUDA:
    trainer.cuda()

# Number of subplot columns per figure. The number of rows is derived from
# the number of predictions, so deeper networks (which return more side
# outputs) still fit on the grid instead of overflowing a fixed 2x4 layout.
N_PLOT_COLS = 4

# look at an example
for img, target in test_loader:
    if USE_CUDA:
        img = img.cuda()

    # Inference only: disable autograd so no computation graph is built.
    with torch.no_grad():
        preds = trainer.apply_model(img)
        # softmax along axis 1 of each prediction (the channel axis that
        # pred[b, 1, ...] indexes below)
        preds = [nn.functional.softmax(pred, dim=1) for pred in preds]

    preds = [unwrap(pred, as_numpy=True, to_cpu=True) for pred in preds]
    img = unwrap(img, as_numpy=True, to_cpu=True)
    target = unwrap(target, as_numpy=True, to_cpu=True)

    # one panel for the image, one for the ground truth, one per prediction
    n_plots = len(preds) + 2
    n_rows = (n_plots + N_PLOT_COLS - 1) // N_PLOT_COLS  # ceiling division
    batch_size = preds[0].shape[0]

    for b in range(batch_size):

        fig = pylab.figure()

        ax1 = fig.add_subplot(n_rows, N_PLOT_COLS, 1)
        ax1.set_title('image')
        ax1.imshow(img[b, 0, ...])

        ax2 = fig.add_subplot(n_rows, N_PLOT_COLS, 2)
        ax2.set_title('ground truth')
        ax2.imshow(target[b, ...])

        for i, pred in enumerate(preds):
            axn = fig.add_subplot(n_rows, N_PLOT_COLS, 3 + i)
            # softmax probability of channel 1
            axn.imshow(pred[b, 1, ...])

            # the model returns the side predictions first and the
            # combined prediction last
            if i + 1 < len(preds):
                axn.set_title('side prediction %d' % i)
            else:
                axn.set_title('combined prediction')

        pylab.show()

    # only visualize the first test batch
    break
../_images/sphx_glr_plot_train_side_loss_unet_001.png

Total running time of the script: (1 minute 31.868 seconds)

Gallery generated by Sphinx-Gallery