自定义数据集-Pokemon Go_完整项目_CodingPark编程公园
发布日期:2021-06-29 15:45:57 浏览次数:2 分类:技术文章

本文共 13256 字,大约阅读时间需要 44 分钟。

在这里插入图片描述

import visdom # 开外网 1:pip install visdom 2:sudo python -m visdom.server ✅

visdom 是一个基于 web 的可视化工具:运行前需先启动 visdom 服务器(python -m visdom.server),然后在浏览器中查看程序发送的图像、曲线和文本。

import os, glob ✅

os:os 模块在运维工作中是很常用的一个模块,它提供与操作系统交互的接口(例如调用系统命令、操作文件和路径),并且可以跨平台使用。

1.os.listdir():返回输入路径下所有文件和文件夹的名称列表
经过os.listdir()操作之后,pokemon文件夹下的所有条目都会作为list的元素返回:如果是文件夹,只返回文件夹的名字;如果是文件,则返回带扩展名的文件名(.py/.csv/.docx 等)。注意返回顺序是任意的,所以代码中用 sorted() 排序,保证每次运行得到相同的类别编号。
2.os.path.join():拼接待操作对象
经过.os.path.join()操作之后,会将两个路径进行拼接
3.os.path.isdir():判断输入路径是否为目录或文件夹
一般而言这个操作会和os.path.join()结合起来使用:先拼接路径,再进行判断。在自建数据集的时候这样用比较方便,返回值为 True 或 False。
链接: .

glob

glob是python自带的一个文件操作相关模块,用它可以查找符合自己目的的文件,类似于Windows下的文件搜索。它支持通配符操作,有“*”、“?”、“[]”这三个通配符:“*”代表0个或者多个字符;“?”代表一个字符;“[]”匹配指定范围内的字符,如[0-9]匹配数字。主要有 glob() 和 iglob() 两个方法(iglob() 与 glob() 用法相同,但返回迭代器而不是列表),下面介绍 glob()。
1.glob方法:
glob模块的主要方法就是glob,该方法返回所有匹配的文件路径列表(list)。该方法需要一个参数,用来指定匹配的路径字符串(可以是绝对路径,也可以是相对路径)。返回结果只包括当前目录里匹配的文件,不包括子文件夹里的文件。
比如:
import glob
glob.glob('*.txt')            # 获取当前路径下所有的 txt 文件,返回一个 list,如 QQ.txt、44.txt
glob.glob('glob_?.png')       # 获取当前路径下所有形如 glob_?.png 的文件(? 恰好匹配一个字符),如 glob_1.png、glob_q.png
glob.glob('glob_[0-9].png')   # 获取当前路径下下划线后面是 0-9 中一个数字的 png 文件,返回一个 list
glob.glob('glob_[0-9].*')     # 获取当前路径下文件名为 glob_(0-9 范围内的一个数字)、扩展名任意的所有文件

random

random() 方法返回随机生成的一个实数,它在[0,1)范围内。本项目实际用到的是 random.shuffle(),它会就地随机打乱列表的顺序,这里用于在写入 CSV 之前打乱图片路径。

CSV

csv文件格式是一种通用的电子表格和数据库导入导出格式

————————pokemon.py—自定义数据集文件————————

import torch
import os, glob
import random, csv
from torch.utils.data import Dataset, DataLoader  # base class for all custom datasets
from torchvision import transforms
from PIL import Image


# Channel statistics used by Normalize in __getitem__ and undone by denormalize;
# kept in one place so the two cannot drift apart.
_IMAGENET_MEAN = [0.485, 0.456, 0.406]
_IMAGENET_STD = [0.229, 0.224, 0.225]


class Pokemon(Dataset):
    """Image dataset with one sub-directory of ``root`` per class.

    Image paths and labels are cached in ``root/images.csv`` (created on first
    use, in shuffled order) and split 60% / 20% / 20% into train / val / test.
    """

    def __init__(self, root, resize, mode):
        """
        :param root: dataset directory; every sub-directory name is a class name.
        :param resize: side length (pixels) of the square output images.
        :param mode: 'train' selects the first 60% of the CSV rows, 'val' the
            next 20%; any other value selects the remaining 20% (test split).
        """
        super(Pokemon, self).__init__()
        self.root = root
        self.resize = resize
        self.mode = mode

        # Class name -> integer label, e.g. {'bulbasaur': 0, 'charmander': 1, ...}.
        # os.listdir order is arbitrary, so sort it to keep labels stable between
        # runs; plain files in root are skipped, only directories are classes.
        self.name2label = {}
        for name in sorted(os.listdir(os.path.join(root))):
            if not os.path.isdir(os.path.join(root, name)):
                continue
            self.name2label[name] = len(self.name2label)

        self.images, self.labels = self.load_csv('images.csv')

        n = len(self.images)
        if mode == 'train':  # 60%
            start, end = 0, int(0.6 * n)
        elif mode == 'val':  # 20%
            start, end = int(0.6 * n), int(0.8 * n)
        else:  # 20% (test)
            start, end = int(0.8 * n), n
        self.images = self.images[start:end]
        self.labels = self.labels[start:end]

    def load_csv(self, filename):
        """Return ``(image_paths, labels)`` read from ``root/filename``.

        If the CSV does not exist yet it is created first: every *.png, *.jpg
        and *.jpeg file in each class directory is collected, the list is
        shuffled, and one ``path,label`` row is written per image.
        """
        csv_path = os.path.join(self.root, filename)
        if not os.path.exists(csv_path):
            images = []
            for name in self.name2label:
                # e.g. 'pokemon/mewtwo/00001.png'
                for pattern in ('*.png', '*.jpg', '*.jpeg'):
                    images += glob.glob(os.path.join(self.root, name, pattern))
            print(len(images), images)

            random.shuffle(images)
            with open(csv_path, mode='w', newline='') as f:
                writer = csv.writer(f)
                for img in images:
                    # The class name is the image's parent directory.
                    name = img.split(os.sep)[-2]
                    writer.writerow([img, self.name2label[name]])
                print('writen into csv file:', filename)

        images, labels = [], []
        # newline='' is what the csv module expects for files it parses.
        with open(csv_path, newline='') as f:
            for img, label in csv.reader(f):
                images.append(img)
                labels.append(int(label))
        assert len(images) == len(labels)
        return images, labels

    def __len__(self):
        """Number of samples in the selected split."""
        return len(self.images)

    def denormalize(self, x_hat):
        """Undo the Normalize step: x = x_hat * std + mean.

        mean/std are reshaped from [3] to [3, 1, 1] so they broadcast over the
        spatial dimensions of a [c, h, w] image (and of a [b, c, h, w] batch,
        as used by main()).
        """
        mean = torch.tensor(_IMAGENET_MEAN).unsqueeze(1).unsqueeze(1)
        std = torch.tensor(_IMAGENET_STD).unsqueeze(1).unsqueeze(1)
        return x_hat * std + mean

    def __getitem__(self, idx):
        """Return ``(image_tensor, label_tensor)`` for sample ``idx``.

        The image is resized to 1.25 * resize, center-cropped to ``resize``,
        converted to a tensor and normalized. Random rotation is applied only
        in 'train' mode so validation and test samples are deterministic.
        """
        img_path, label = self.images[idx], self.labels[idx]

        steps = [
            lambda x: Image.open(x).convert('RGB'),  # path string -> RGB image
            transforms.Resize((int(self.resize * 1.25), int(self.resize * 1.25))),
        ]
        if self.mode == 'train':
            steps.append(transforms.RandomRotation(15))  # training-only augmentation
        steps += [
            transforms.CenterCrop(self.resize),
            transforms.ToTensor(),
            transforms.Normalize(mean=_IMAGENET_MEAN, std=_IMAGENET_STD),
        ]

        img = transforms.Compose(steps)(img_path)
        return img, torch.tensor(label)


def main():
    """Show one sample and then batches of the training split in visdom."""
    import visdom  # requires `pip install visdom` and a running `python -m visdom.server`
    import time

    viz = visdom.Visdom()

    db = Pokemon('pokemon', 224, 'train')
    x, y = next(iter(db))
    print('sample:', x.shape, y.shape, y)
    viz.image(db.denormalize(x), win='sample_x', opts=dict(title='sample_x'))

    loader = DataLoader(db, batch_size=32, shuffle=True, num_workers=8)  # 8 worker processes
    for x, label in loader:
        viz.images(db.denormalize(x), nrow=8, win='batch', opts=dict(title='batch'))
        viz.text(str(label.numpy()), win='label', opts=dict(title='batch-label'))
        time.sleep(10)


if __name__ == '__main__':
    main()

————————pokemon_rapid.py—快速生成自定义数据集文件————————

import torch
import os
import glob
import random
import csv
from torch.utils.data import Dataset, DataLoader  # Dataset: base class for custom datasets
from torchvision import transforms
from PIL import Image


def main():
    """Preview the 'pokemon' image folder in visdom via torchvision's ImageFolder.

    Needs a running visdom server (pip install visdom; python -m visdom.server).
    Every batch of 32 images (resized to 64x64) is printed, shown in the
    'batch' window with its labels in the 'label' window, and the loop then
    pauses for 10 seconds.
    """
    import visdom
    import time
    import torchvision

    board = visdom.Visdom()

    preprocess = transforms.Compose([
        transforms.Resize((64, 64)),
        transforms.ToTensor(),
    ])
    folder = torchvision.datasets.ImageFolder(root='pokemon', transform=preprocess)
    batches = DataLoader(folder, batch_size=32, shuffle=True)
    print(folder.class_to_idx)

    for images, targets in batches:
        print(images)
        print(images.shape)
        # `win` names the visdom window: reusing the same id overwrites that
        # window, so the images and the labels go to different ids.
        board.images(images, nrow=8, win='batch', opts={'title': 'batch'})
        board.text(str(targets.numpy()), win='label', opts={'title': 'batch-y'})
        time.sleep(10)


if __name__ == '__main__':
    main()

————————resnet.py—构建网络文件————————

import torch
from torch import nn
from torch.nn import functional as F


class ResBlk(nn.Module):
    """Residual block: two 3x3 conv + BatchNorm layers plus a shortcut connection."""

    def __init__(self, ch_in, ch_out, stride=1):
        """
        :param ch_in: number of input channels.
        :param ch_out: number of output channels.
        :param stride: stride of the first 3x3 conv (spatial downsampling factor).
        """
        super(ResBlk, self).__init__()
        self.conv1 = nn.Conv2d(ch_in, ch_out, kernel_size=3, stride=stride, padding=1)
        self.bn1 = nn.BatchNorm2d(ch_out)
        self.conv2 = nn.Conv2d(ch_out, ch_out, kernel_size=3, stride=1, padding=1)
        self.bn2 = nn.BatchNorm2d(ch_out)

        # Identity shortcut by default. A 1x1 projection is required whenever
        # the main path changes the channel count OR the spatial size
        # (stride != 1); otherwise the element-wise add in forward() would see
        # mismatched shapes.
        self.extra = nn.Sequential()
        if ch_out != ch_in or stride != 1:
            # [b, ch_in, h, w] => [b, ch_out, h', w']
            self.extra = nn.Sequential(
                nn.Conv2d(ch_in, ch_out, kernel_size=1, stride=stride),
                nn.BatchNorm2d(ch_out),
            )

    def forward(self, x):
        """
        :param x: [b, ch_in, h, w]
        :return: [b, ch_out, h', w'] where h', w' are reduced by ``stride``
        """
        out = F.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))
        # Shortcut (identity or 1x1 projection) followed by element-wise add.
        out = self.extra(x) + out
        return F.relu(out)


class ResNet18(nn.Module):
    """Small ResNet-style classifier: a stem conv followed by four ResBlk stages.

    Example: ``ResNet18(5)(torch.randn(2, 3, 224, 224))`` has shape [2, 5].
    """

    def __init__(self, num_class):
        """
        :param num_class: number of output logits.
        """
        super(ResNet18, self).__init__()
        # Spatial sizes in the comments below are for a 224x224 input.
        self.conv1 = nn.Sequential(
            nn.Conv2d(3, 16, kernel_size=3, stride=3, padding=0),
            nn.BatchNorm2d(16),
        )  # [b, 3, 224, 224] => [b, 16, 74, 74]
        self.blk1 = ResBlk(16, 32, stride=3)    # => [b, 32, 25, 25]
        self.blk2 = ResBlk(32, 64, stride=3)    # => [b, 64, 9, 9]
        self.blk3 = ResBlk(64, 128, stride=2)   # => [b, 128, 5, 5]
        self.blk4 = ResBlk(128, 256, stride=2)  # => [b, 256, 3, 3]
        # Pool to a fixed 3x3 grid so inputs other than 224x224 also fit the
        # classifier. For 224x224 inputs the map is already 3x3 and this is an
        # identity; it has no parameters, so saved state_dicts stay compatible.
        self.pool = nn.AdaptiveAvgPool2d((3, 3))
        self.outlayer = nn.Linear(256 * 3 * 3, num_class)

    def forward(self, x):
        """
        :param x: image batch [b, 3, h, w]
        :return: logits [b, num_class]
        """
        x = F.relu(self.conv1(x))
        x = self.blk1(x)
        x = self.blk2(x)
        x = self.blk3(x)
        x = self.blk4(x)
        x = self.pool(x)  # [b, 256, 3, 3]
        x = x.view(x.size(0), -1)
        return self.outlayer(x)

————————train_scratch.py—主文件————————

import torch
from torch import optim, nn
import visdom
import torchvision
import time
from pokemon import Pokemon
from torch.utils.data import DataLoader
from resnet import ResNet18

batchsz = 32
lr = 0.001
epochs = 10

torch.manual_seed(1234)

# Load the three splits of the dataset.
train_db = Pokemon('pokemon', 224, mode='train')
trian_loader = DataLoader(train_db, batch_size=batchsz, shuffle=True, num_workers=4)
val_db = Pokemon('pokemon', 224, mode='val')
val_loader = DataLoader(val_db, batch_size=batchsz, num_workers=2)
test_db = Pokemon('pokemon', 224, mode='test')
test_loader = DataLoader(test_db, batch_size=batchsz, num_workers=4)

viz = visdom.Visdom()  # visdom client


def evaluate(model, loader):
    """Return the classification accuracy of ``model`` over ``loader`` (validation)."""
    model.eval()
    print('----------val  运行----------')
    correct = 0
    total = len(loader.dataset)
    with torch.no_grad():
        for x, label in loader:
            pred = model(x).argmax(dim=1)
            correct += torch.eq(pred, label).sum().float().item()
    return correct / total


def evaluate_test(model, loader):
    """Return test accuracy, showing each batch and its predictions in visdom.

    Uses ``loader.dataset`` (not a global) for the class map and for
    denormalizing images, so it reports on whatever loader it is given.
    """
    model.eval()
    dataset = loader.dataset
    print('----------Test  visdom显示----------')
    print()
    print('NameList --->', dataset.name2label)
    print()
    correct = 0
    total = len(dataset)
    for x, label in loader:
        with torch.no_grad():
            pred = model(x).argmax(dim=1)
        viz.images(dataset.denormalize(x), nrow=8, win='test_sample_img', opts=dict(title='test_sample_img'))
        viz.text(str(pred.numpy()), win='test_sample_text', opts=dict(title='test_sample_text'))
        time.sleep(10)  # leave each batch on screen long enough to inspect
        correct += torch.eq(pred, label).sum().float().item()
    print('----------Test  visdom显示完毕----------')
    return correct / total


def main():
    """Preview the training data, train ResNet18, keep the best checkpoint, then test it.

    The weights with the highest validation accuracy are saved to 'best.mdl'
    and reloaded before the final test pass.
    """
    print('----------Train  训练----------')
    print()
    print('----------Train  visdom显示----------')
    # Show every training batch once in visdom before training starts.
    for x, label in trian_loader:
        viz.images(train_db.denormalize(x), nrow=8, win='trian_sample_img', opts=dict(title='trian_sample_img'))
        viz.text(str(label.numpy()), win='trian_sample_text', opts=dict(title='trian_sample_text'))
        time.sleep(2)
    print()
    print('----------Train  visdom显示完毕----------')

    model = ResNet18(5)  # 5 output classes
    optimizer = optim.Adam(model.parameters(), lr=lr)
    criteon = nn.CrossEntropyLoss()

    # Start below any reachable accuracy so the first epoch always writes
    # 'best.mdl'; otherwise, if validation accuracy never exceeds 0, no
    # checkpoint exists and torch.load('best.mdl') below fails.
    best_acc, best_epoch = -1.0, 0
    global_step = 0
    # Seed both plots; viz.line takes (Y, X).
    viz.line([0], [-1], win='loss', opts=dict(title='loss'))
    viz.line([0], [-1], win='val_acc', opts=dict(title='val_acc'))

    for epoch in range(epochs):
        model.train()
        for x, label in trian_loader:
            logits = model(x)
            loss = criteon(logits, label)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            viz.line([loss.item()], [global_step], win='loss', update='append')
            global_step += 1

        val_acc = evaluate(model, val_loader)
        # Plot validation accuracy every epoch, not only new bests, so drops
        # (e.g. from overfitting) stay visible.
        viz.line([val_acc], [global_step], win='val_acc', update='append')
        if val_acc > best_acc:
            best_epoch = epoch
            best_acc = val_acc
            torch.save(model.state_dict(), 'best.mdl')  # keep the best model so far

    print()
    print('best acc', best_acc, 'best_epoch', best_epoch)

    model.load_state_dict(torch.load('best.mdl'))
    test_acc = evaluate_test(model, test_loader)
    print()
    print('test_acc', test_acc)


if __name__ == '__main__':
    main()

效果展示:

在这里插入图片描述

在这里插入图片描述

转载地址:https://codingpark.blog.csdn.net/article/details/105400479 如侵犯您的版权,请留言回复原文章的地址,我们会给您删除此文章,给您带来不便请您谅解!

上一篇:LDA主题模型_完整项目_CodingPark编程公园
下一篇:CNN-ResNet_完整项目_CodingPark编程公园

发表评论

最新留言

不错!
[***.144.177.141]2024年04月08日 18时07分07秒