基于卷积神经网络VGG实现水果分类识别
本案例使用对水果数据集进行分类识别,案例详细的讲解了数据读取和预处理,模型介绍,训练,优化,评估,预测,部署这一完整流程,同时提供带有详细注释的代码。

基于卷积神经网络VGG实现水果分类识别
一. 前言
随着人们生活质量的提高,世界各地的水果逐渐进入到大家的生活中,相较于人们日常的大众水果,可能会出现一些人们不认识的新品种,这个时候就需要对这一部分水果进行识别分类。
二. 模型介绍
本案例中我们使用VGG网络进行水果识别,首先我们来了解一下VGG模型。 VGG是当前最流行的CNN模型之一,2014年由Simonyan和Zisserman发表在ICLR 2015会议上的论文《Very Deep Convolutional Networks For Large-scale Image Recognition》提出,其命名来源于论文作者所在的实验室Visual Geometry Group。VGG设计了一种大小为3x3的小尺寸卷积核和池化层组成的基础模块,通过堆叠上述基础模块构造出深度卷积神经网络,该网络在图像分类领域取得了不错的效果,在大型分类数据集ILSVRC上,VGG模型仅有6.8% 的top-5 test error 。VGG模型一经推出就很受研究者们的欢迎,因为其网络结构的设计合理,总体结构简明,且可以适用于多个领域。VGG的设计为后续研究者设计模型结构提供了思路。
下图是VGG-16的网络结构示意图,一共包含13层卷积和3层全连接层。VGG网络使用3×3的卷积层和池化层组成的基础模块来提取特征,三层全连接层放在网络的最后组成分类器,最后一层全连接层的输出即为分类的预测。 在VGG中每层卷积将使用ReLU作为激活函数,在全连接层之后添加dropout来抑制过拟合。使用小的卷积核能够有效地减少参数的个数,使得训练和测试变得更加有效。比如如果我们想要得到感受野为5的特征图,最直接的方法是使用5×5的卷积层,但是我们也可以使用两层3×3卷积层达到同样的效果,并且只需要更少的参数。另外由于卷积核比较小,我们可以堆叠更多的卷积层,提取到更多的图片信息,来提高图像分类的准确率。VGG模型的成功证明了增加网络的深度,可以更好的学习图像中的特征模式,达到更高的分类准确率。
想了解更多关于VGG的知识可以点击了解详细
三. 数据处理
In [1]# 数据集进行解压# ! unzip -oq data/data137852/fruits.zip登录后复制 In [2]
import osimport randomimport jsonimport paddleimport sysimport numpy as npfrom PIL import Imageimport matplotlib.pyplot as plt# 定义公共变量name_dict = {"apple": 0, "banana": 1, "grape": 2, "orange": 3, "pear": 4}data_root_path = "fruits/" # 数据集目录test_file_path = data_root_path + "test.txt" # 测试集文件路径train_file_path = data_root_path + "train.txt" # 测试集文件name_data_list = {} # 记录每个类别图片 key:名称 value:路径列表def save_train_test_file(path, name): # 将图片添加到字典 if name not in name_data_list: # 该类别水果不在字典中 img_list = [] img_list.append(path) # 路径存入列表 name_data_list[name] = img_list # 列表存入字典 else: name_data_list[name].append(path) # 直接添加到列表# 遍历每个子目录,将图片路径存入字典dirs = os.listdir(data_root_path) # 列出数据集下的子目录for d in dirs: full_path = data_root_path + d # 子目录完整路径 if os.path.isdir(full_path): # 如果是目录 imgs = os.listdir(full_path) # 列出子目录下的图片 for img in imgs: img_full_path = full_path + "/" + img # 图片路径 save_train_test_file(img_full_path, d) # 添加到字典 else: # 文件 pass# 划分训练集、测试集with open(test_file_path, "w") as f: passwith open(train_file_path, "w") as f: pass# 遍历字典for name, img_list in name_data_list.items(): i = 0 num = len(img_list) # 取出样本数量 print("%s: %d张图像" % (name, num)) for img in img_list: # 拼接一行 line = "%s\t%d\n" % (img, name_dict[name]) if i % 10 == 0: # 写入测试集 with open(test_file_path, "a") as f: f.write(line) # 存入文件 else: # 写入训练集 with open(train_file_path, "a") as f: f.write(line) # 存入文件 i += 1print("数据预处理完成.")登录后复制 /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages/matplotlib/__init__.py:107: DeprecationWarning: Using or importing the ABCs from 'collections' instead of from 'collections.abc' is deprecated, and in 3.8 it will stop working from collections import MutableMapping/opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages/matplotlib/rcsetup.py:20: DeprecationWarning: Using or importing the ABCs from 'collections' instead of from 'collections.abc' is deprecated, and in 3.8 it will stop working from collections import Iterable, Mapping/opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages/matplotlib/colors.py:53: DeprecationWarning: Using or importing the ABCs from 'collections' instead of from 'collections.abc' is deprecated, and in 3.8 it will stop working from collections import Sized登录后复制
apple: 288张图像banana: 275张图像orange: 276张图像grape: 216张图像pear: 251张图像数据预处理完成.登录后复制 In [3]
from paddle.io import Dataset# 定义数据读取器class dataset(Dataset): def __init__(self, data_path, mode='train'): """ 数据读取器 :param data_path: 数据集所在路径 :param mode: train or eval """ super().__init__() self.data_path = data_path self.img_paths = [] self.labels = [] if mode == 'train': with open(os.path.join(self.data_path, "train.txt"), "r", encoding="utf-8") as f: self.info = f.readlines() for img_info in self.info: img_path, label = img_info.strip().split('\t') self.img_paths.append(img_path) self.labels.append(int(label)) else: with open(os.path.join(self.data_path, "test.txt"), "r", encoding="utf-8") as f: self.info = f.readlines() for img_info in self.info: img_path, label = img_info.strip().split('\t') self.img_paths.append(img_path) self.labels.append(int(label)) def __getitem__(self, index): """ 获取一组数据 :param index: 文件索引号 :return: """ # 第一步打开图像文件并获取label值 img_path = self.img_paths[index] img = Image.open(img_path) if img.mode != 'RGB': img = img.convert('RGB') img = img.resize((224, 224), Image.BILINEAR) #img = rand_flip_image(img) img = np.array(img).astype('float32') img = img.transpose((2, 0, 1)) / 255 label = self.labels[index] label = np.array([label], dtype="int64") return img, label def print_sample(self, index: int = 0): print("文件名", self.img_paths[index], "\t标签值", self.labels[index]) def __len__(self): return len(self.img_paths)登录后复制 In [13]#训练数据加载train_dataset = dataset('fruits',mode='train')train_loader = paddle.io.DataLoader(train_dataset, batch_size=32, shuffle=True)#评估数据加载eval_dataset = dataset('fruits',mode='eval')eval_loader = paddle.io.DataLoader(eval_dataset, batch_size = 8, shuffle=False)print("数据的预处理和加载完成!")登录后复制 数据的预处理和加载完成!登录后复制
四. 模型搭建
4.1 定义卷积池化网络
In [5]# 定义卷积池化网络class ConvPool(paddle.nn.Layer): def __init__(self, num_channels, num_filters, filter_size, pool_size, pool_stride, groups, conv_stride=1, conv_padding=1, ): super(ConvPool, self).__init__() # groups代表卷积层的数量 for i in range(groups): self.add_sublayer( #添加子层实例 'bb_%d' % i, paddle.nn.Conv2D( # layer in_channels=num_channels, #通道数 out_channels=num_filters, #卷积核个数 kernel_size=filter_size, #卷积核大小 stride=conv_stride, #步长 padding = conv_padding, #padding ) ) self.add_sublayer( 'relu%d' % i, paddle.nn.ReLU() ) num_channels = num_filters self.add_sublayer( 'Maxpool', paddle.nn.MaxPool2D( kernel_size=pool_size, #池化核大小 stride=pool_stride #池化步长 ) ) def forward(self, inputs): x = inputs for prefix, sub_layer in self.named_children(): # print(prefix,sub_layer) x = sub_layer(x) return x登录后复制
4.2 搭建VGG网络
In [6]# VGG网络class VGGNet(paddle.nn.Layer): def __init__(self): super(VGGNet, self).__init__() # 5个卷积池化操作 self.convpool01 = ConvPool( 3, 64, 3, 2, 2, 2) #3:通道数,64:卷积核个数,3:卷积核大小,2:池化核大小,2:池化步长,2:连续卷积个数 self.convpool02 = ConvPool( 64, 128, 3, 2, 2, 2) self.convpool03 = ConvPool( 128, 256, 3, 2, 2, 3) self.convpool04 = ConvPool( 256, 512, 3, 2, 2, 3) self.convpool05 = ConvPool( 512, 512, 3, 2, 2, 3) self.pool_5_shape = 512 * 7* 7 # 三个全连接层 self.fc01 = paddle.nn.Linear(self.pool_5_shape, 4096) self.drop1 = paddle.nn.Dropout(p=0.5) self.fc02 = paddle.nn.Linear(4096, 4096) self.drop2 = paddle.nn.Dropout(p=0.5) self.fc03 = paddle.nn.Linear(4096, train_parameters['class_dim']) def forward(self, inputs, label=None): # print('input_shape:', inputs.shape) #[8, 3, 224, 224] """前向计算""" out = self.convpool01(inputs) # print('convpool01_shape:', out.shape) #[8, 64, 112, 112] out = self.convpool02(out) # print('convpool02_shape:', out.shape) #[8, 128, 56, 56] out = self.convpool03(out) # print('convpool03_shape:', out.shape) #[8, 256, 28, 28] out = self.convpool04(out) # print('convpool04_shape:', out.shape) #[8, 512, 14, 14] out = self.convpool05(out) # print('convpool05_shape:', out.shape) #[8, 512, 7, 7] out = paddle.reshape(out, shape=[-1, 512*7*7]) out = self.fc01(out) out = self.drop1(out) out = self.fc02(out) out = self.drop2(out) out = self.fc03(out) if label is not None: acc = paddle.metric.accuracy(input=out, label=label) return out, acc else: return out登录后复制 4.3 参数配置
In [8]train_parameters = { "train_list_path": "fruits/train.txt", #train.txt路径 "eval_list_path": "fruits/test.txt", #eval.txt路径 "class_dim": 5, #分类数}# 参数配置,要保留之前数据集准备阶段配置的参数,所以使用update更新字典train_parameters.update({ "input_size": [3, 224, 224], #输入图片的shape "num_epochs": 35, #训练轮数 "skip_steps": 10, #训练时输出日志的间隔 "save_steps": 100, #训练时保存模型参数的间隔 "learning_strategy": { #优化函数相关的配置 "lr": 0.0001 #超参数学习率 }, "checkpoints": "/home/aistudio/work/checkpoints" #保存的路径})登录后复制 4.4 模型训练
In [ ]model = VGGNet()model.train()# 配置loss函数cross_entropy = paddle.nn.CrossEntropyLoss()# 配置参数优化器optimizer = paddle.optimizer.Adam(learning_rate=train_parameters['learning_strategy']['lr'], parameters=model.parameters()) steps = 0Iters, total_loss, total_acc = [], [], []for epo in range(train_parameters['num_epochs']): for _, data in enumerate(train_loader()): steps += 1 x_data = data[0] y_data = data[1] predicts, acc = model(x_data, y_data) loss = cross_entropy(predicts, y_data) loss.backward() optimizer.step() optimizer.clear_grad() if steps % train_parameters["skip_steps"] == 0: Iters.append(steps) total_loss.append(loss.numpy()[0]) total_acc.append(acc.numpy()[0]) #打印中间过程 print('epo: {}, step: {}, loss is: {}, acc is: {}'\ .format(epo, steps, loss.numpy(), acc.numpy())) #保存模型参数 if steps % train_parameters["save_steps"] == 0: save_path = train_parameters["checkpoints"]+"/"+"save_dir_" + str(steps) + '.pdparams' print('save model to: ' + save_path) paddle.save(model.state_dict(),save_path)paddle.save(model.state_dict(),train_parameters["checkpoints"]+"/"+"save_dir_final.pdparams")登录后复制 4.5 绘制loss和acc图像
In [11]def draw_process(title,color,iters,data,label): plt.title(title, fontsize=24) plt.xlabel("iter", fontsize=20) plt.ylabel(label, fontsize=20) plt.plot(iters, data,color=color,label=label) plt.legend() plt.grid() plt.show()draw_process("trainning loss","red",Iters,total_loss,"trainning loss")draw_process("trainning acc","green",Iters,total_acc,"trainning acc")登录后复制 登录后复制登录后复制登录后复制登录后复制登录后复制登录后复制登录后复制
登录后复制登录后复制登录后复制登录后复制登录后复制登录后复制登录后复制
五. 模型评估
In [12]model__state_dict = paddle.load('work/checkpoints/save_dir_final.pdparams') # 使用保存的最后一个模型model_eval = VGGNet()model_eval.set_state_dict(model__state_dict) model_eval.eval()accs = []# 开始评估for _, data in enumerate(eval_loader()): x_data = data[0] y_data = data[1] predicts = model_eval(x_data) acc = paddle.metric.accuracy(predicts, y_data) accs.append(acc.numpy()[0])print('模型的准确率为:',np.mean(accs))登录后复制 模型的准确率为: 0.9558824登录后复制
六. 模型预测
In [19]def load_image(img_path): img = Image.open(img_path) if img.mode != 'RGB': img = img.convert('RGB') img = img.resize((224, 224), Image.BILINEAR) img = np.array(img).astype('float32') img = img.transpose((2, 0, 1)) / 255 # HWC to CHW 及归一化 return imglabel_dic = {0:"apple", 1:"banana", 2:"grape", 3:"orange", 4:"pear"}登录后复制 In [21]import time# 加载训练过程保存的最后一个模型model__state_dict = paddle.load('work/checkpoints/save_dir_final.pdparams')model_predict = VGGNet()model_predict.set_state_dict(model__state_dict) model_predict.eval()infer_imgs_path = os.listdir("predict")# 预测图片for infer_img_path in infer_imgs_path: infer_img = load_image("predict/"+infer_img_path) infer_img = infer_img[np.newaxis,:, : ,:] #reshape(-1,3,224,224) infer_img = paddle.to_tensor(infer_img) result = model_predict(infer_img) lab = np.argmax(result.numpy()) # print(lab) print("样本: {},被预测为:{}".format(infer_img_path,label_dic[lab])) img = Image.open("predict/"+infer_img_path) plt.imshow(img) plt.axis('off') plt.show() sys.stdout.flush() time.sleep(0.5)登录后复制 1样本: banana.webp,被预测为:banana登录后复制
登录后复制登录后复制登录后复制登录后复制登录后复制登录后复制登录后复制
4样本: pear.webp,被预测为:pear登录后复制
登录后复制登录后复制登录后复制登录后复制登录后复制登录后复制登录后复制
3样本: orange.webp,被预测为:orange登录后复制
登录后复制登录后复制登录后复制登录后复制登录后复制登录后复制登录后复制
2样本: grape.webp,被预测为:grape登录后复制
登录后复制登录后复制登录后复制登录后复制登录后复制登录后复制登录后复制
0样本: apple.webp,被预测为:apple登录后复制
登录后复制登录后复制登录后复制登录后复制登录后复制登录后复制登录后复制
七. 总结
该模型训练过程中选择的优化器是Adam优化器,训练的精度达到了要求,但是也可以选择其他优化器,例如AdamW进行比较,选取最优的。对于超参数学习率来说,该模型采用的是固定常数的学习率,也可以使用具有线性变化的学习率进行训练,有可能会获得更好的模型精度。在合理范围内,增大batch_size会提高显存的利用率,提高大矩阵乘法的并行化效率,减少每个epoch需要训练的迭代次数。
相关攻略
Pywinrm 通过Windows远程管理(WinRM)协议,让Python能够像操作本地一样执行远程Windows命令,真正打通了跨平台管理的最后一公里。 在混合IT环境中,Linux机器管理Wi
早些时候,聊过 Python 领域那场惊心动魄的供应链攻击。当时我就感叹,虽然我们 JavaScript 开发者对这类套路烂熟于心,但亲眼目睹这种规模的“投毒”还是头一次。 早些时候,聊过 Pyth
Toga 是 BeeWare 家族的核心成员,号称“写一次,跑遍所有平台”,而且用的是系统原生控件,不是那种一看就是网页套壳的界面 。 写了这么多年 Python,你是不是也想过:要是能一套代码跑
异常处理的核心:让错误在正确的地方被有效处理。正确的地方,就是别在底层就把异常吞了,也别在顶层还抛裸奔的 Exception。 异常处理写得好,半夜不用起来改 bug。1 你是不是也这么干过?tr
1 Skills机制概述 提起OpenClaw的Skills机制,不少人可能会把它想象成传统意义上的可执行插件。其实,它的内涵要更精妙一些。 简单说,Skills本质上是一套基于提示驱动的能力扩展机制。它并不是一个可以独立“跑”起来的程序模块,而是通过一份结构化描述文件(核心就是那个SKILL m
热门专题
热门推荐
加密货币行业翘首以盼的监管里程碑,终于有了实质性进展。美国证券交易委员会(SEC)主席保罗·阿特金斯(Paul Atkins)近日证实,那份允许加密项目在早期获得注册豁免权的“安全港”框架提案,已经正式送抵白宫,进入了最终审查阶段。 在范德堡大学与区块链协会联合举办的数字资产峰会上,阿特金斯透露了这
微策略Strategy报告:第一季录得144 6亿美元浮亏 再斥资约3 3亿美元买进4871枚比特币 市场震荡的威力有多大?看看Strategy的最新季报就明白了。根据其最新向美国证管会(SEC)提交的8-K报告,受市场剧烈波动影响,这家公司所持的比特币在第一季度录得了一笔惊人的数字——144 6亿
稳定币巨头Tether的动向,向来是加密世界的风向标。这不,它向Web3基础设施的版图扩张,又迈出了关键一步。公司执行长Paolo Ardoino在社交平台X上透露,其工程团队正在全力“烹制”一个新项目——去中心化搜索引擎 “Hypersearch”。这个消息一出,立刻引发了行业的广泛猜想。 采用D
基地位于Coinbase旗下以太坊Layer2网络Base的Seamless Protocol,日前正式宣告了服务的终结。这个曾经吸引了超过20万用户的原生DeFi借贷协议,在运营不到三年后,终究没能跑赢时间。它主打的核心产品是Integrated Leverage Markets(ILMs)——一
PAAL代币揭秘:深度解析Web3社区治理的核心钥匙 在去中心化自治组织的浪潮中,谁真正掌握了项目的话语权?PAAL代币提供了一套系统化的答案。它不仅是生态内流转的价值媒介,更是开启链上治理大门的核心凭证。通过持有并质押PAAL代币,用户能够对协议升级、资金分配乃至战略方向等关键事务投出决定性的一票





