# Source code for inferno.io.box.camvid

# Adapted from felixgwu's PR here:
# https://github.com/felixgwu/vision/blob/cf491d301f62ae9c77ff7250fb7def5cd55ec963/torchvision/datasets/camvid.py

import os
import torch
import torch.utils.data as data
import numpy as np
from PIL import Image
from torchvision.datasets.folder import default_loader
from ...utils.exceptions import assert_
from ..transform.base import Compose
from ..transform.generic import Normalize, NormalizeRange, Cast, AsTorchBatch, Label2OneHot
from ..transform.image import \
    RandomSizedCrop, RandomGammaCorrection, RandomFlip, Scale, PILImage2NumPyArray

try:
    from torchvision.datasets.folder import is_image_file
except ImportError:
    # Fallback for torchvision builds whose `folder` module does not export
    # `is_image_file`: rebuild it from the extension helpers imported below.
    from torchvision.datasets.folder import IMG_EXTENSIONS, has_file_allowed_extension


    def is_image_file(filename):
        """Return the result of checking `filename` against `IMG_EXTENSIONS`.

        Thin wrapper that delegates to torchvision's
        `has_file_allowed_extension(filename, IMG_EXTENSIONS)`.
        """
        return has_file_allowed_extension(filename, IMG_EXTENSIONS)

# Human-readable class names; the position in the list is the class index that
# the weight and colour tables below are aligned with.
# NOTE(review): 'Pedestrain' looks like a typo for 'Pedestrian'. It is left
# unchanged because the string is runtime data - confirm before renaming.
CAMVID_CLASSES = ['Sky',
                  'Building',
                  'Column-Pole',
                  'Road',
                  'Sidewalk',
                  'Tree',
                  'Sign-Symbol',
                  'Fence',
                  'Car',
                  'Pedestrain',
                  'Bicyclist',
                  'Void']

# Per-class weights obtained with median frequency balancing, as used in the
# SegNet paper:
# https://arxiv.org/pdf/1511.00561.pdf
# The numbers were generated by:
# https://github.com/yandex/segnet-torch/blob/master/datasets/camvid-gen.lua
# The last entry (the position of 'Void' in CAMVID_CLASSES) is 0.
CAMVID_CLASS_WEIGHTS = [0.58872014284134,
                        0.51052379608154,
                        2.6966278553009,
                        0.45021694898605,
                        1.1785038709641,
                        0.77028578519821,
                        2.4782588481903,
                        2.5273461341858,
                        1.0122526884079,
                        3.2375309467316,
                        4.1312313079834,
                        0]
# Mean and standard deviation handed to `Normalize` in `get_camvid_loaders`.
# Presumably one value per colour channel of images scaled to [0, 1] - TODO confirm.
CAMVID_MEAN = [0.41189489566336, 0.4251328133025, 0.4326707089857]
CAMVID_STD = [0.27413549931506, 0.28506257482912, 0.28284674400252]

# Three-component colour for each class index; `label_to_pil_image` paints
# every pixel of class `i` with entry `i` of this table.
CAMVID_CLASS_COLORS = [
    (128, 128, 128),
    (128, 0, 0),
    (192, 192, 128),
    (128, 64, 128),
    (0, 0, 192),
    (128, 128, 0),
    (192, 128, 128),
    (64, 64, 128),
    (64, 0, 128),
    (64, 64, 0),
    (0, 128, 192),
    (0, 0, 0),
]


def make_dataset(dir):
    """Collect the paths of all image files found below ``dir``.

    The tree is walked recursively. Directories are visited in sorted order
    and the file names inside each directory are sorted as well, so the
    returned list does not depend on the order in which the file system
    happens to list directory entries.

    Parameters
    ----------
    dir : str
        Directory to search. The name shadows the builtin but is kept so that
        existing keyword callers keep working.

    Returns
    -------
    list of str
        ``os.path.join(root, fname)`` for every file name accepted by
        ``is_image_file``. Empty when nothing matches; ``os.walk`` also
        yields nothing for a directory it cannot read.
    """
    images = []
    for root, _, fnames in sorted(os.walk(dir)):
        # `os.walk` reports file names in arbitrary order; sort them so that
        # dataset indices are reproducible across machines.
        images.extend(os.path.join(root, fname)
                      for fname in sorted(fnames)
                      if is_image_file(fname))
    return images
def label_to_long_tensor(pic):
    """Convert a single-channel, one-byte-per-pixel label image to a long tensor.

    Parameters
    ----------
    pic : PIL.Image.Image
        Label image. Only ``pic.size`` (``(width, height)``) and
        ``pic.tobytes()`` are used, and the byte buffer must hold exactly
        ``width * height`` bytes.

    Returns
    -------
    torch.LongTensor
        Tensor of shape ``(height, width)`` holding the raw byte value of
        every pixel. Both axes are always kept, including for images that are
        a single pixel wide or high.

    Raises
    ------
    RuntimeError
        If the buffer size does not match ``width * height`` (e.g. for a
        multi-channel image).
    """
    width, height = pic.size
    # `np.frombuffer` on `bytes` yields a read-only array; copy it so the
    # tensor owns writable memory.
    raw = np.frombuffer(pic.tobytes(), dtype=np.uint8).copy()
    # The buffer is row-major, so a plain view gives the (height, width) layout.
    return torch.from_numpy(raw).view(height, width).long()
def label_to_pil_image(label):
    """Render a class-index map as a colour image using ``CAMVID_CLASS_COLORS``.

    Parameters
    ----------
    label : torch.Tensor
        2-D tensor of shape ``(height, width)`` whose entries are class
        indices.

    Returns
    -------
    PIL.Image.Image
        Image built from a ``(height, width, 3)`` 8-bit array in which every
        pixel of class ``i`` carries ``CAMVID_CLASS_COLORS[i]``. Pixels whose
        value matches no entry of the table stay ``(0, 0, 0)``.
    """
    height, width = label.size(0), label.size(1)
    # Start from an all-zero canvas with one plane per colour component.
    colored_label = torch.zeros(3, height, width).byte()
    for class_index, color in enumerate(CAMVID_CLASS_COLORS):
        # The mask has the same 2-D shape as each colour plane it is applied to.
        mask = label.eq(class_index)
        for channel, intensity in enumerate(color):
            colored_label[channel].masked_fill_(mask, intensity)
    # Channels-first -> channels-last, the layout `Image.fromarray` expects.
    npimg = np.transpose(colored_label.numpy(), (1, 2, 0))
    return Image.fromarray(npimg)
class CamVid(data.Dataset):
    """CamVid segmentation dataset read from a directory tree.

    Images are collected from ``<root>/<split>``; the label image belonging to
    an image is opened from the same relative location below
    ``<root>/<split>annot``.

    Parameters
    ----------
    root : str
        Directory containing the split folders.
    split : str
        Any key of ``SPLIT_NAME_MAPPING``; aliases such as ``'training'`` or
        ``'validation'`` are mapped to the folder names ``'train'``, ``'val'``
        and ``'test'``.
    image_transform : callable, optional
        Applied to the loaded image only.
    label_transform : callable, optional
        Applied to the opened label image only.
    joint_transform : callable, optional
        Called as ``joint_transform(image, label)`` after the individual
        transforms; its result is unpacked as the new ``(image, label)`` pair.
    download : bool
        When true, ``download()`` is called, which currently raises
        ``NotImplementedError``.
    loader : callable
        Maps an image path to the loaded image.

    Notes
    -----
    An unknown ``split`` is reported through ``assert_`` with ``KeyError`` as
    the requested exception type.
    """
    SPLIT_NAME_MAPPING = {'train': 'train',
                          'training': 'train',
                          'validate': 'val',
                          'val': 'val',
                          'validation': 'val',
                          'test': 'test',
                          'testing': 'test'}
    # Dataset statistics
    CLASS_WEIGHTS = CAMVID_CLASS_WEIGHTS
    CLASSES = CAMVID_CLASSES
    MEAN = CAMVID_MEAN
    STD = CAMVID_STD

    def __init__(self, root, split='train',
                 image_transform=None, label_transform=None, joint_transform=None,
                 download=False, loader=default_loader):
        # Validate
        assert_(split in self.SPLIT_NAME_MAPPING,
                "`split` must be one of {}".format(set(self.SPLIT_NAME_MAPPING.keys())),
                KeyError)
        # Root directory and split
        self.root_directory = root
        self.split = self.SPLIT_NAME_MAPPING.get(split)
        # Utils
        self.image_loader = loader
        # Transforms
        self.image_transform = image_transform
        self.label_transform = label_transform
        self.joint_transform = joint_transform
        # For when we implement download:
        if download:
            self.download()
        # Make dataset with paths to the image
        self.image_paths = make_dataset(os.path.join(self.root_directory, self.split))

    def _label_path(self, image_path):
        """Return the annotation path that mirrors ``image_path``.

        The path is rebuilt from the image's location relative to the split
        directory, so a split name that also occurs in the root directory or
        in a file name (e.g. ``train_001.png``) is left untouched.
        """
        split_directory = os.path.join(self.root_directory, self.split)
        relative_path = os.path.relpath(image_path, split_directory)
        return os.path.join(self.root_directory, self.split + 'annot', relative_path)

    def __getitem__(self, index):
        """Load and transform the ``(image, label)`` pair at ``index``."""
        path = self.image_paths[index]
        image = self.image_loader(path)
        label = Image.open(self._label_path(path))
        # Apply transforms
        if self.image_transform is not None:
            image = self.image_transform(image)
        if self.label_transform is not None:
            label = self.label_transform(label)
        if self.joint_transform is not None:
            image, label = self.joint_transform(image, label)
        return image, label

    def __len__(self):
        """Return the number of images found for the split."""
        return len(self.image_paths)

    def download(self):
        """Placeholder for automatic download; always raises ``NotImplementedError``."""
        # TODO: please download the dataset from
        # https://github.com/alexgkendall/SegNet-Tutorial/tree/master/CamVid
        raise NotImplementedError
# noinspection PyTypeChecker
def get_camvid_loaders(root_directory, image_shape=(360, 480), labels_as_onehot=False,
                       train_batch_size=1, validate_batch_size=1, test_batch_size=1,
                       num_workers=2):
    """Build data loaders for the train, validate and test splits of CamVid.

    The same image, label and joint transform pipelines (including the
    ``Random*`` transforms) are handed to all three datasets, and every loader
    is created with ``shuffle=True`` and ``pin_memory=True``.

    Parameters
    ----------
    root_directory : str
        Passed to ``CamVid`` as the dataset root.
    image_shape : tuple
        ``output_image_shape`` of the two ``Scale`` transforms that resize
        image and label after the random crop.
    labels_as_onehot : bool
        If true, labels go through ``Label2OneHot`` (with
        ``len(CAMVID_CLASS_WEIGHTS)`` classes) and are cast to float;
        otherwise they are cast to long.
    train_batch_size, validate_batch_size, test_batch_size : int
        Batch size of the respective loader.
    num_workers : int
        ``num_workers`` of every loader.

    Returns
    -------
    tuple
        ``(train_loader, validate_loader, test_loader)``.
    """
    # Transforms that act on the raw image only
    image_transforms = Compose(
        PILImage2NumPyArray(),
        NormalizeRange(),
        RandomGammaCorrection(),
        Normalize(mean=CAMVID_MEAN, std=CAMVID_STD),
    )
    # Transforms that act on the label only
    label_transforms = PILImage2NumPyArray()
    # Transforms that act on image (index 0) and label (index 1) together
    joint_transforms = Compose(
        RandomSizedCrop(ratio_between=(0.6, 1.0), preserve_aspect_ratio=True),
        # Bring the raw image back to the requested shape
        Scale(output_image_shape=image_shape, interpolation_order=3, apply_to=[0]),
        # Bring the segmentation back to the requested shape
        # (order 0, i.e. without interpolation)
        Scale(output_image_shape=image_shape, interpolation_order=0, apply_to=[1]),
        RandomFlip(allow_ud_flips=False),
        # Raw image becomes float
        Cast('float', apply_to=[0]),
    )
    if not labels_as_onehot:
        # Label image becomes long
        joint_transforms.add(Cast('long', apply_to=[1]))
    else:
        # See cityscapes loader to understand why this is here.
        (joint_transforms
         .add(Label2OneHot(num_classes=len(CAMVID_CLASS_WEIGHTS), dtype='bool', apply_to=[1]))
         .add(Cast('float', apply_to=[1])))
    # Batchify
    joint_transforms.add(AsTorchBatch(2, add_channel_axis_if_necessary=False))

    # Datasets first, in train / validate / test order ...
    split_names = ('train', 'validate', 'test')
    datasets = [CamVid(root_directory, split=split_name,
                       image_transform=image_transforms,
                       label_transform=label_transforms,
                       joint_transform=joint_transforms)
                for split_name in split_names]

    # ... then one loader per dataset, each with its own batch size.
    batch_sizes = (train_batch_size, validate_batch_size, test_batch_size)
    train_loader, validate_loader, test_loader = [
        data.DataLoader(dataset, batch_size=batch_size, shuffle=True,
                        num_workers=num_workers, pin_memory=True)
        for dataset, batch_size in zip(datasets, batch_sizes)
    ]
    return train_loader, validate_loader, test_loader