本文共 11390 字,大约阅读时间需要 37 分钟。
下面是全部的代码:
import os

import numpy as np
import scipy.misc as m  # NOTE(review): `m` is not referenced anywhere in this listing
import torch
from PIL import Image
from torch.utils import data

from dataloaders.utils import recursive_glob, decode_segmap
from mypath import Path


class CityscapesSegmentation(data.Dataset):
    """Cityscapes dataset yielding ``{'image': ..., 'label': ...}`` samples.

    Images are looked up under ``<root>/leftImg8bit/<split>`` and the matching
    ``*gtFine_labelIds.png`` annotations under ``<root>/gtFine/<split>``.
    Raw label ids are remapped to 19 training ids (0-18); every other id
    becomes ``ignore_index`` (255).

    Args:
        root: Dataset root directory.
        split: Sub-directory name of the split to load (e.g. ``"train"``).
        transform: Optional callable applied to the sample dict; used for
            augmentation and conversion to tensors.

    Raises:
        FileNotFoundError: If no ``.png`` image is found for ``split``.
    """

    def __init__(self, root=Path.db_root_dir('cityscapes'), split="train", transform=None):
        self.root = root
        self.split = split
        self.transform = transform
        self.files = {}
        self.n_classes = 19

        self.images_base = os.path.join(self.root, 'leftImg8bit', self.split)
        self.annotations_base = os.path.join(self.root, 'gtFine', self.split)

        self.files[split] = recursive_glob(rootdir=self.images_base, suffix='.png')

        # Raw label ids that are not trained on (16 entries).
        self.void_classes = [0, 1, 2, 3, 4, 5, 6, 9, 10, 14, 15, 16, 18, 29, 30, -1]
        # Raw label ids that are trained on (19 entries); list position == training id.
        self.valid_classes = [7, 8, 11, 12, 13, 17, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 31, 32, 33]
        # 20 names: 'unlabelled' followed by the 19 trained classes.
        self.class_names = ['unlabelled', 'road', 'sidewalk', 'building', 'wall', 'fence',
                            'pole', 'traffic_light', 'traffic_sign', 'vegetation', 'terrain',
                            'sky', 'person', 'rider', 'car', 'truck', 'bus', 'train',
                            'motorcycle', 'bicycle']

        self.ignore_index = 255
        self.class_map = dict(zip(self.valid_classes, range(self.n_classes)))

        if not self.files[split]:
            # FileNotFoundError is still an Exception, so existing handlers keep working.
            raise FileNotFoundError("No files for split=[%s] found in %s" % (split, self.images_base))

        print("Found %d %s images" % (len(self.files[split]), split))

    def __len__(self):
        """Return the number of images found for the configured split."""
        return len(self.files[self.split])

    def __getitem__(self, index):
        """Load one image/label pair and run the optional transform on it.

        Returns:
            dict with keys ``'image'`` (RGB PIL image) and ``'label'`` (PIL
            image of encoded ids), or whatever ``self.transform`` returns
            for that dict.
        """
        img_path = self.files[self.split][index].rstrip()
        # The image's parent directory is reused as the annotation sub-directory
        # (the city name, per the original comment).
        city = img_path.split(os.sep)[-2]
        # Drop the last 15 characters of the file name -- assumes it ends in
        # 'leftImg8bit.png' -- and append the label-id suffix instead.
        lbl_name = os.path.basename(img_path)[:-15] + 'gtFine_labelIds.png'
        lbl_path = os.path.join(self.annotations_base, city, lbl_name)

        # Context managers release the file handles as soon as the pixel data
        # has been copied out.
        with Image.open(img_path) as img_file:
            _img = img_file.convert('RGB')
        with Image.open(lbl_path) as lbl_file:
            _tmp = np.array(lbl_file, dtype=np.uint8)

        _tmp = self.encode_segmap(_tmp)
        _target = Image.fromarray(_tmp)

        sample = {'image': _img, 'label': _target}

        if self.transform:
            # Data augmentation and conversion to torch tensors happen here.
            sample = self.transform(sample)
        return sample

    def encode_segmap(self, mask):
        """Remap raw label ids in ``mask`` to training ids 0-18, in place.

        Every id that is not listed in ``valid_classes`` -- the void classes
        as well as any unexpected value -- is set to ``ignore_index``.

        Args:
            mask: Integer array of raw label ids; it is overwritten.

        Returns:
            The same array object, now holding only 0-18 and ``ignore_index``.
        """
        # Compare against an untouched snapshot so that a freshly written
        # training id can never be mistaken for a raw id that is still to be
        # remapped; the result no longer depends on the iteration order.
        raw_ids = mask.copy()
        mask[...] = self.ignore_index
        for raw_id, train_id in self.class_map.items():
            mask[raw_ids == raw_id] = train_id
        return mask


if __name__ == '__main__':
    from dataloaders import custom_transforms as tr
    from dataloaders.utils import decode_segmap
    from torch.utils.data import DataLoader
    from torchvision import transforms
    import matplotlib.pyplot as plt  # used to display the images

    composed_transforms_tr = transforms.Compose([
        tr.RandomHorizontalFlip(),
        tr.RandomScale((0.5, 0.75)),
        tr.RandomCrop((512, 1024)),
        tr.RandomRotate(5),
        tr.ToTensor()])

    cityscapes_train = CityscapesSegmentation(split='train',
                                              transform=composed_transforms_tr)

    dataloader = DataLoader(cityscapes_train, batch_size=2, shuffle=True, num_workers=2)

    for ii, sample in enumerate(dataloader):
        for jj in range(sample["image"].size()[0]):
            img = sample['image'].numpy()  # torch -> numpy, n x 3 x h x w
            gt = sample['label'].numpy()  # torch -> numpy, n x 1 x h x w
            tmp = np.array(gt[jj]).astype(np.uint8)  # 1 x h x w
            tmp = np.squeeze(tmp, axis=0)  # drop the singleton channel axis -> h x w
            segmap = decode_segmap(tmp, dataset='cityscapes')
            img_tmp = np.transpose(img[jj], axes=[1, 2, 0]).astype(np.uint8)  # h x w x 3
            plt.figure()
            plt.title('display')
            plt.subplot(211)
            plt.imshow(img_tmp)
            plt.subplot(212)
            plt.imshow(segmap)

        if ii == 1:
            # Two batches are enough for a visual check.
            break

    plt.show(block=True)
图片路径的读取方式可以参考下面这一行:
# Excerpt from CityscapesSegmentation.__init__: store the image paths of this split.
# recursive_glob comes from dataloaders.utils (not shown here) -- presumably it walks
# rootdir recursively and returns every path ending in `suffix`; TODO confirm.
self.files[split] = recursive_glob(rootdir=self.images_base, suffix='.png')
使用的数据变换(增强)组合为:
# Augmentation pipeline passed to the training split; Compose runs the transforms in the listed order.
# The first four steps work on the PIL image/label pair, and the final ToTensor step does the array/tensor conversion.
composed_transforms_tr = transforms.Compose([ tr.RandomHorizontalFlip(), tr.RandomScale((0.5, 0.75)), tr.RandomCrop((512, 1024)), tr.RandomRotate(5), tr.ToTensor()])
上面关于图像变换或者说增强的实现代码如下:
上面的前四个变换都保持了原图和标签的type为PIL.PngImagePlugin.PngImageFile,这些图的像素数值大小和类型(uint8)不发生改变,结构也没有变化(原图为h x w x 3,标签图为h x w)
class RandomHorizontalFlip(object):
    """Mirror image and label left-to-right together, with probability 0.5."""

    def __call__(self, sample):
        img = sample['image']
        mask = sample['label']
        # A single random draw decides for both, so image and label stay aligned.
        if random.random() < 0.5:
            img = img.transpose(Image.FLIP_LEFT_RIGHT)
            mask = mask.transpose(Image.FLIP_LEFT_RIGHT)

        return {'image': img, 'label': mask}


class RandomScale(object):
    """Resize image and label by one random factor drawn from ``limit``.

    Args:
        limit: ``(low, high)`` bounds passed to ``random.uniform``.
    """

    def __init__(self, limit):
        self.limit = limit

    def __call__(self, sample):
        img = sample['image']
        mask = sample['label']
        assert img.size == mask.size

        scale = random.uniform(self.limit[0], self.limit[1])
        w = int(scale * img.size[0])
        h = int(scale * img.size[1])

        # Bilinear for the picture; nearest-neighbour for the label so that
        # no new (interpolated) label values are created.
        img = img.resize((w, h), Image.BILINEAR)
        mask = mask.resize((w, h), Image.NEAREST)

        return {'image': img, 'label': mask}


class RandomCrop(object):
    """Cut the same random window out of image and label.

    Args:
        size: Target size; a single number gives a square crop, otherwise a
            ``(height, width)`` pair.
        padding: Border in pixels added (filled with 0) on every side before
            cropping; 0 disables padding.
    """

    def __init__(self, size, padding=0):
        if isinstance(size, numbers.Number):
            self.size = (int(size), int(size))
        else:
            self.size = size  # h, w
        self.padding = padding

    def __call__(self, sample):
        img, mask = sample['image'], sample['label']

        if self.padding > 0:
            img = ImageOps.expand(img, border=self.padding, fill=0)
            mask = ImageOps.expand(mask, border=self.padding, fill=0)

        assert img.size == mask.size
        w, h = img.size
        th, tw = self.size  # target size

        # Already the requested size: nothing to do.
        if w == tw and h == th:
            return {'image': img, 'label': mask}

        # Input smaller than the target in some dimension: resize to the
        # target size instead of cropping.
        if w < tw or h < th:
            img = img.resize((tw, th), Image.BILINEAR)
            mask = mask.resize((tw, th), Image.NEAREST)
            return {'image': img, 'label': mask}

        # Random top-left corner such that the window fits inside the input.
        x1 = random.randint(0, w - tw)
        y1 = random.randint(0, h - th)
        img = img.crop((x1, y1, x1 + tw, y1 + th))
        mask = mask.crop((x1, y1, x1 + tw, y1 + th))

        return {'image': img, 'label': mask}


class RandomRotate(object):
    """Rotate image and label by one random angle in ``[-degree, degree)``.

    Args:
        degree: Maximum absolute rotation angle in degrees.
    """

    def __init__(self, degree):
        self.degree = degree

    def __call__(self, sample):
        img = sample['image']
        mask = sample['label']
        rotate_degree = random.random() * 2 * self.degree - self.degree
        img = img.rotate(rotate_degree, Image.BILINEAR)
        mask = mask.rotate(rotate_degree, Image.NEAREST)

        return {'image': img, 'label': mask}


class ToTensor(object):
    """Convert the image and label of a sample to float tensors.

    Returns a dict whose ``'image'`` is a ``C x H x W`` float tensor and whose
    ``'label'`` is a ``1 x H x W`` float tensor.
    """

    def __call__(self, sample):
        # Swap the colour axis because
        #   numpy image: H x W x C
        #   torch image: C x H x W
        img = np.array(sample['image']).astype(np.float32).transpose((2, 0, 1))
        mask = np.expand_dims(np.array(sample['label']).astype(np.float32), -1).transpose((2, 0, 1))
        # NOTE(review): this folds every pixel marked 255 into label 0. If 255
        # is the ignore label and 0 a real class, those pixels end up labelled
        # as class 0 -- confirm this is intended.
        mask[mask == 255] = 0

        # Both entries become tensors; leaving the image as an ndarray would
        # contradict the class docstring.
        img = torch.from_numpy(img).float()
        mask = torch.from_numpy(mask).float()

        return {'image': img, 'label': mask}
直到第五个也就是最后一个变换(ToTensor),才对原图做类型转换:首先从PIL.PngImagePlugin.PngImageFile变为numpy数组,同时数据类型从uint8变为float32;然后维度从(h x w x c)调整为(c x h x w);最后从numpy数组变为torch的tensor,并强制数据类型为torch.FloatTensor。这样,就将原图转变为一个可以输入后面深度学习网络的tensor了。
与此相对,标签图也是先从PIL.PngImagePlugin.PngImageFile变为numpy数组,同时数据类型从uint8变为float32;然后维度从(h x w)增加一维得到(h x w x 1),接着调整为(1 x h x w);之后对mask里面的数值进行处理:值为255的像素全部被重置为0,所以mask里面的值现在只有0-18这些数字了;最后从numpy数组变为torch的tensor,并强制数据类型为torch.FloatTensor。这样,就将标签图转变为一个可以输入后面深度学习网络的tensor了。
对上面的两个tensor的重新变成图像的代码如下:
# Excerpt of the demo loop: turn the batched tensors back into displayable images.
for ii, sample in enumerate(dataloader):
    for jj in range(sample["image"].size()[0]):
        img = sample['image'].numpy()  # torch -> numpy, n x 3 x h x w
        gt = sample['label'].numpy()  # torch -> numpy, n x 1 x h x w
        tmp = np.array(gt[jj]).astype(np.uint8)  # 1 x h x w
        tmp = np.squeeze(tmp, axis=0)  # drop the singleton channel axis -> h x w
        segmap = decode_segmap(tmp, dataset='cityscapes')
        img_tmp = np.transpose(img[jj], axes=[1, 2, 0]).astype(np.uint8)  # h x w x 3
        plt.figure()
        plt.title('display')
        plt.subplot(211)
        plt.imshow(img_tmp)
        plt.subplot(212)
        plt.imshow(segmap)

    if ii == 1:
        # Stop after the second batch.
        break

plt.show(block=True)
里面的标签图(h x w)解码代码如下:
属于同一类的像素赋予该类对应的RGB数值,然后把R、G、B三个通道合并成一张彩色图
# Call site, quoted from the demo loop; tmp.shape = h x w.
segmap = decode_segmap(tmp, dataset='cityscapes')


def decode_segmap(label_mask, dataset, plot=False):
    """Turn a map of class labels into a colour image.

    Args:
        label_mask (np.ndarray): an (M, N) array of integer class labels,
            one per spatial location.
        dataset (str): ``'pascal'`` or ``'cityscapes'``; selects the palette
            and the number of classes.
        plot (bool, optional): show the colour image in a figure instead of
            returning it.

    Returns:
        np.ndarray or None: the (M, N, 3) colour image scaled by 1/255, or
        ``None`` when ``plot`` is true.

    Raises:
        NotImplementedError: for any other ``dataset`` value.
    """
    if dataset == 'cityscapes':
        n_classes, label_colours = 19, get_cityscapes_labels()
    elif dataset == 'pascal':
        n_classes, label_colours = 21, get_pascal_labels()
    else:
        raise NotImplementedError

    # One working copy of the label map per colour channel (R, G, B); labels
    # outside range(n_classes) keep their own value in every channel.
    planes = [label_mask.copy() for _ in range(3)]
    for class_id in range(n_classes):
        hits = label_mask == class_id
        for channel, plane in enumerate(planes):
            plane[hits] = label_colours[class_id, channel]

    n_rows, n_cols = label_mask.shape[0], label_mask.shape[1]
    rgb = np.zeros((n_rows, n_cols, 3))  # h x w x 3
    for channel, plane in enumerate(planes):
        rgb[:, :, channel] = plane / 255.0

    if not plot:
        return rgb
    plt.imshow(rgb)
    plt.show()
下面是label_colours,也就是各类别对应颜色的代码;详情可以参考cityscapes的标签颜色对照表:
def get_cityscapes_labels():
    """Return the colour table used to draw Cityscapes label maps.

    Returns:
        np.ndarray with dimensions (19, 3): one RGB row per class index
        0-18. There is no row for an unlabelled/background class.
    """
    return np.array([
        [128, 64, 128],
        [244, 35, 232],
        [70, 70, 70],
        [102, 102, 156],
        [190, 153, 153],
        [153, 153, 153],
        [250, 170, 30],
        [220, 220, 0],
        [107, 142, 35],
        [152, 251, 152],
        [70, 130, 180],  # sky: the official Cityscapes colour is (70, 130, 180), not (0, 130, 180)
        [220, 20, 60],
        [255, 0, 0],
        [0, 0, 142],
        [0, 0, 70],
        [0, 60, 100],
        [0, 80, 100],
        [0, 0, 230],
        [119, 11, 32]])


def get_pascal_labels():
    """Load the mapping that associates pascal classes with label colors

    Returns:
        np.ndarray with dimensions (21, 3)
    """
    return np.asarray([
        [0, 0, 0], [128, 0, 0], [0, 128, 0], [128, 128, 0],
        [0, 0, 128], [128, 0, 128], [0, 128, 128], [128, 128, 128],
        [64, 0, 0], [192, 0, 0], [64, 128, 0], [192, 128, 0],
        [64, 0, 128], [192, 0, 128], [64, 128, 128], [192, 128, 128],
        [0, 64, 0], [128, 64, 0], [0, 192, 0], [128, 192, 0],
        [0, 64, 128]])
转载地址:https://blog.csdn.net/zz2230633069/article/details/84668984 如侵犯您的版权,请留言回复原文章的地址,我们会给您删除此文章,给您带来不便请您谅解!