Date: 2025-07-21  Author: 游乐小编
In image processing, a keypoint is essentially a kind of feature: an abstract description of a fixed region or of spatial relationships, capturing the combination and context within a certain neighborhood. It is not merely a point or a position; it also encodes how that position relates to its surrounding neighborhood. The goal of keypoint detection is to have a computer locate the coordinates of these points in an image. As a fundamental task in computer vision, keypoint detection is of vital importance to higher-level tasks such as recognition and classification.
This tutorial is currently implemented on the 2.0-rc release of Paddle and will be upgraded as the 2.0 series is released.
Keypoint detection methods fall broadly into two types. One solves the problem by directly regressing the coordinates; the other models keypoints as heatmaps and recovers the keypoint locations by regressing the heatmap distribution as a per-pixel classification task. Both are simply different means to the same end: finding where these points lie in the image and how they relate to their surroundings.
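To make the distinction concrete, here is a small NumPy sketch of how keypoints would be read out under each formulation. The arrays are random and stand in for network outputs, purely for illustration:

import numpy as np

num_keypoints = 15
img_size = 96

# Approach 1: coordinate regression -- the network directly outputs
# 2 * num_keypoints values, interpreted as normalized (x, y) pairs.
regression_out = np.random.rand(num_keypoints * 2)            # fake network output in [0, 1]
coords = regression_out.reshape(num_keypoints, 2) * img_size  # recover pixel coordinates

# Approach 2: heatmap regression -- the network outputs one score map per
# keypoint; each keypoint is read off as the argmax of its map.
heatmaps = np.random.rand(num_keypoints, img_size, img_size)  # fake score maps
ys, xs = np.unravel_index(
    heatmaps.reshape(num_keypoints, -1).argmax(axis=1),
    (img_size, img_size))
print(coords.shape, xs.shape, ys.shape)  # (15, 2) (15,) (15,)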
Face keypoint detection is a successful application of keypoint detection. This example briefly shows how to implement it with the PaddlePaddle open-source framework, using the first approach: coordinate regression. We will use the new Paddle 2.0 APIs, whose integrated training interface makes it easy to train the model and run predictions.
Import the necessary modules and confirm your Paddle version. If you are in a CPU environment, install the CPU build of Paddle 2.0 and pass the corresponding device to paddle.set_device().
In [1]
import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
import os
import paddle
from paddle.io import Dataset
from paddle.vision.transforms import transforms
from paddle.vision.models import resnet18
from paddle.nn import functional as F

print(paddle.__version__)

# device = paddle.set_device('gpu')
device = paddle.set_device('cpu')  # if use static graph, do not set
paddle.disable_static(device)
2.0.0-rc0
This example uses the dataset from Kaggle's facial keypoints detection challenge, available at: https://www.kaggle.com/c/facial-keypoints-detection
The dataset packs the face images and annotation data into CSV files, which we read with pandas. The dataset contains the following files:
training.csv: the face images and keypoint coordinates used for training.
test.csv: the face images used for testing, without annotated keypoint coordinates.
IdLookupTable.csv: the names corresponding to the keypoint locations in the test set.
Each image is 96 x 96 pixels, and there are 15 keypoints to detect in total.
In [3]
!unzip -o data/data60/test.zip -d data/data60/
!unzip -o data/data60/training.zip -d data/data60
unzip: cannot find or open data/data60/test.zip, data/data60/test.zip.zip or data/data60/test.zip.ZIP.
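Before wrapping the CSVs in a Dataset, it can help to peek at the raw file with pandas. A minimal sketch, assuming the unzip step above has produced data/data60/training.csv:

import numpy as np
import pandas as pd

df = pd.read_csv('data/data60/training.csv')
print(df.shape)        # rows x 31 columns: 30 coordinate columns + 1 'Image' column
print(df.columns[:3])  # the first few keypoint coordinate names
# Each 'Image' cell is a space-separated string of 96*96 grayscale values.
pixels = np.array(df['Image'].iloc[0].split(), dtype='float32')
print(pixels.reshape(96, 96).shape)  # (96, 96)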
PaddlePaddle's unified data loading scheme is Dataset (dataset definition) + DataLoader (multi-process data loading).
First we define the dataset. This mainly means implementing a new Dataset class that inherits from the parent class paddle.io.Dataset and implements its two abstract methods, __getitem__ and __len__:
In [4]
Train_Dir = 'data/data60/training.csv'
Test_Dir = 'data/data60/test.csv'
lookid_dir = 'data/data60/IdLookupTable.csv'

class ImgTransforms(object):
    """
    Image preprocessing tool: expand the image from (96, 96) to (96, 96, 3),
    then convert the layout from HWC to CHW.
    """
    def __init__(self, fmt):
        self.format = fmt

    def __call__(self, img):
        if len(img.shape) == 2:
            img = np.expand_dims(img, axis=2)
        img = img.transpose(self.format)
        if img.shape[0] == 1:
            img = np.repeat(img, 3, axis=0)
        return img

class FaceDataset(Dataset):
    def __init__(self, data_path, mode='train', val_split=0.2):
        self.mode = mode
        assert self.mode in ['train', 'val', 'test'], \
            "mode should be 'train', 'val' or 'test', but got {}".format(self.mode)
        self.data_source = pd.read_csv(data_path)
        # Clean the data. Many samples annotate only some of the keypoints,
        # and there are two possible strategies:
        # 1. Copy unannotated positions forward from the previous sample:
        # self.data_source.fillna(method='ffill', inplace=True)
        # 2. Drop any sample that contains unannotated keypoints:
        self.data_source.dropna(how="any", inplace=True)
        self.data_label_all = self.data_source.drop('Image', axis=1)

        # Split into training and validation sets
        if self.mode in ['train', 'val']:
            np.random.seed(43)
            data_len = len(self.data_source)
            # Random split
            shuffled_indices = np.random.permutation(data_len)
            # Sequential split
            # shuffled_indices = np.arange(data_len)
            self.shuffled_indices = shuffled_indices
            val_set_size = int(data_len * val_split)
            if self.mode == 'val':
                val_indices = shuffled_indices[:val_set_size]
                self.data_img = self.data_source.reindex().iloc[val_indices]
                self.data_label = self.data_label_all.reindex().iloc[val_indices]
            elif self.mode == 'train':
                train_indices = shuffled_indices[val_set_size:]
                self.data_img = self.data_source.reindex().iloc[train_indices]
                self.data_label = self.data_label_all.reindex().iloc[train_indices]
        elif self.mode == 'test':
            self.data_img = self.data_source
            self.data_label = self.data_label_all
        self.transforms = transforms.Compose([
            ImgTransforms((2, 0, 1))
        ])

    # Return one sample and its label per iteration
    def __getitem__(self, idx):
        img = self.data_img['Image'].iloc[idx].split(' ')
        img = ['0' if x == '' else x for x in img]
        img = np.array(img, dtype='float32').reshape(96, 96)
        img = self.transforms(img)
        label = np.array(self.data_label.iloc[idx, :], dtype='float32') / 96
        return img, label

    # Return the total size of the dataset
    def __len__(self):
        return len(self.data_img)

# Training and validation datasets
train_dataset = FaceDataset(Train_Dir, mode='train')
val_dataset = FaceDataset(Train_Dir, mode='val')
# Test dataset
test_dataset = FaceDataset(Test_Dir, mode='test')
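The model.fit call used later builds its data loader internally, but the same Dataset also plugs directly into paddle.io.DataLoader. A minimal sketch of manual batch iteration (the batch size of 32 here is arbitrary):

from paddle.io import DataLoader

train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
for batch_img, batch_label in train_loader:
    print(batch_img.shape)    # [32, 3, 96, 96]
    print(batch_label.shape)  # [32, 30]
    break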
With the Dataset implemented, let's check that it behaves as expected. Since a Dataset is an iterable class, we read samples from it in a for loop and display them with matplotlib. The keypoint coordinates were normalized inside the dataset, so here we multiply by the image size to recover the original scale and draw the points onto the image with scatter.
In [6]
def plot_sample(x, y, axis):
    img = x.reshape(96, 96)
    axis.imshow(img, cmap='gray')
    axis.scatter(y[0::2], y[1::2], marker='x', s=10, color='b')

fig = plt.figure(figsize=(10, 7))
fig.subplots_adjust(
    left=0, right=1, bottom=0, top=1, hspace=0.05, wspace=0.05)

# Show 16 random samples
for i in range(16):
    axis = fig.add_subplot(4, 4, i+1, xticks=[], yticks=[])
    idx = np.random.randint(train_dataset.__len__())
    # print(idx)
    img, label = train_dataset[idx]
    label = label * 96
    plot_sample(img[0], label, axis)
plt.show()
(figure: a 4 x 4 grid of random training samples with their annotated keypoints)
Here we use the resnet18 network defined in paddle.vision.models. On the ImageNet classification task it classifies images into 1000 categories, so we append fully connected layers after the model to map the 1000-dimensional output vector down to 30 dimensions, corresponding to the x and y coordinates of the 15 keypoints.
In [2]
class FaceNet(paddle.nn.Layer):
    def __init__(self, num_keypoints, pretrained=False):
        super(FaceNet, self).__init__()
        self.backbone = resnet18(pretrained)
        self.outLayer1 = paddle.nn.Sequential(
            paddle.nn.Linear(1000, 512),
            paddle.nn.ReLU(),
            paddle.nn.Dropout(0.1))
        self.outLayer2 = paddle.nn.Linear(512, num_keypoints * 2)

    def forward(self, inputs):
        out = self.backbone(inputs)
        out = self.outLayer1(out)
        out = self.outLayer2(out)
        return out
Call Paddle's summary interface to visualize the assembled model, which makes it easy to inspect and confirm the model structure and parameter counts.
In [3]
from paddle.static import InputSpec

paddle.disable_static()
num_keypoints = 15
model = paddle.Model(FaceNet(num_keypoints))
# Input size -- batch_size: 3, channels: 3, width: 96, height: 96
model.summary((3, 3, 96, 96))
-------------------------------------------------------------------------------
   Layer (type)         Input Shape          Output Shape         Param #
===============================================================================
     Conv2D-1        [[3, 3, 96, 96]]      [3, 64, 48, 48]         9,408
   BatchNorm2D-1     [[3, 64, 48, 48]]     [3, 64, 48, 48]          256
      ReLU-1         [[3, 64, 48, 48]]     [3, 64, 48, 48]           0
    MaxPool2D-1      [[3, 64, 48, 48]]     [3, 64, 24, 24]           0
     Conv2D-2        [[3, 64, 24, 24]]     [3, 64, 24, 24]        36,864
   BatchNorm2D-2     [[3, 64, 24, 24]]     [3, 64, 24, 24]          256
      ReLU-2         [[3, 64, 24, 24]]     [3, 64, 24, 24]           0
     Conv2D-3        [[3, 64, 24, 24]]     [3, 64, 24, 24]        36,864
   BatchNorm2D-3     [[3, 64, 24, 24]]     [3, 64, 24, 24]          256
   BasicBlock-1      [[3, 64, 24, 24]]     [3, 64, 24, 24]           0
     Conv2D-4        [[3, 64, 24, 24]]     [3, 64, 24, 24]        36,864
   BatchNorm2D-4     [[3, 64, 24, 24]]     [3, 64, 24, 24]          256
      ReLU-3         [[3, 64, 24, 24]]     [3, 64, 24, 24]           0
     Conv2D-5        [[3, 64, 24, 24]]     [3, 64, 24, 24]        36,864
   BatchNorm2D-5     [[3, 64, 24, 24]]     [3, 64, 24, 24]          256
   BasicBlock-2      [[3, 64, 24, 24]]     [3, 64, 24, 24]           0
     Conv2D-7        [[3, 64, 24, 24]]     [3, 128, 12, 12]       73,728
   BatchNorm2D-7     [[3, 128, 12, 12]]    [3, 128, 12, 12]         512
      ReLU-4         [[3, 128, 12, 12]]    [3, 128, 12, 12]          0
     Conv2D-8        [[3, 128, 12, 12]]    [3, 128, 12, 12]      147,456
   BatchNorm2D-8     [[3, 128, 12, 12]]    [3, 128, 12, 12]         512
     Conv2D-6        [[3, 64, 24, 24]]     [3, 128, 12, 12]        8,192
   BatchNorm2D-6     [[3, 128, 12, 12]]    [3, 128, 12, 12]         512
   BasicBlock-3      [[3, 64, 24, 24]]     [3, 128, 12, 12]          0
     Conv2D-9        [[3, 128, 12, 12]]    [3, 128, 12, 12]      147,456
   BatchNorm2D-9     [[3, 128, 12, 12]]    [3, 128, 12, 12]         512
      ReLU-5         [[3, 128, 12, 12]]    [3, 128, 12, 12]          0
    Conv2D-10        [[3, 128, 12, 12]]    [3, 128, 12, 12]      147,456
  BatchNorm2D-10     [[3, 128, 12, 12]]    [3, 128, 12, 12]         512
   BasicBlock-4      [[3, 128, 12, 12]]    [3, 128, 12, 12]          0
    Conv2D-12        [[3, 128, 12, 12]]     [3, 256, 6, 6]       294,912
  BatchNorm2D-12      [[3, 256, 6, 6]]      [3, 256, 6, 6]        1,024
      ReLU-6          [[3, 256, 6, 6]]      [3, 256, 6, 6]          0
    Conv2D-13         [[3, 256, 6, 6]]      [3, 256, 6, 6]       589,824
  BatchNorm2D-13      [[3, 256, 6, 6]]      [3, 256, 6, 6]        1,024
    Conv2D-11        [[3, 128, 12, 12]]     [3, 256, 6, 6]        32,768
  BatchNorm2D-11      [[3, 256, 6, 6]]      [3, 256, 6, 6]        1,024
   BasicBlock-5      [[3, 128, 12, 12]]     [3, 256, 6, 6]          0
    Conv2D-14         [[3, 256, 6, 6]]      [3, 256, 6, 6]       589,824
  BatchNorm2D-14      [[3, 256, 6, 6]]      [3, 256, 6, 6]        1,024
      ReLU-7          [[3, 256, 6, 6]]      [3, 256, 6, 6]          0
    Conv2D-15         [[3, 256, 6, 6]]      [3, 256, 6, 6]       589,824
  BatchNorm2D-15      [[3, 256, 6, 6]]      [3, 256, 6, 6]        1,024
   BasicBlock-6       [[3, 256, 6, 6]]      [3, 256, 6, 6]          0
    Conv2D-17         [[3, 256, 6, 6]]      [3, 512, 3, 3]      1,179,648
  BatchNorm2D-17      [[3, 512, 3, 3]]      [3, 512, 3, 3]        2,048
      ReLU-8          [[3, 512, 3, 3]]      [3, 512, 3, 3]          0
    Conv2D-18         [[3, 512, 3, 3]]      [3, 512, 3, 3]      2,359,296
  BatchNorm2D-18      [[3, 512, 3, 3]]      [3, 512, 3, 3]        2,048
    Conv2D-16         [[3, 256, 6, 6]]      [3, 512, 3, 3]       131,072
  BatchNorm2D-16      [[3, 512, 3, 3]]      [3, 512, 3, 3]        2,048
   BasicBlock-7       [[3, 256, 6, 6]]      [3, 512, 3, 3]          0
    Conv2D-19         [[3, 512, 3, 3]]      [3, 512, 3, 3]      2,359,296
  BatchNorm2D-19      [[3, 512, 3, 3]]      [3, 512, 3, 3]        2,048
      ReLU-9          [[3, 512, 3, 3]]      [3, 512, 3, 3]          0
    Conv2D-20         [[3, 512, 3, 3]]      [3, 512, 3, 3]      2,359,296
  BatchNorm2D-20      [[3, 512, 3, 3]]      [3, 512, 3, 3]        2,048
   BasicBlock-8       [[3, 512, 3, 3]]      [3, 512, 3, 3]          0
AdaptiveAvgPool2D-1   [[3, 512, 3, 3]]      [3, 512, 1, 1]          0
     Linear-1            [[3, 512]]            [3, 1000]         513,000
     ResNet-1        [[3, 3, 96, 96]]         [3, 1000]             0
     Linear-2           [[3, 1000]]            [3, 512]          512,512
      ReLU-10            [[3, 512]]            [3, 512]             0
     Dropout-1           [[3, 512]]            [3, 512]             0
     Linear-3            [[3, 512]]             [3, 30]          15,390
===============================================================================
Total params: 12,227,014
Trainable params: 12,207,814
Non-trainable params: 19,200
-------------------------------------------------------------------------------
Input size (MB): 0.32
Forward/backward pass size (MB): 31.52
Params size (MB): 46.64
Estimated Total Size (MB): 78.48
-------------------------------------------------------------------------------
{'total_params': 12227014, 'trainable_params': 12207814}
This task regresses coordinates, so we compute the loss with the Mean Squared Error loss function paddle.nn.MSELoss(). In Paddle 2.0, the loss functions under nn are wrapped as callable classes. We train directly with the paddle.Model APIs; all we need to define are the dataset, the network model, and the loss function.
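As a quick illustration of the callable loss class, the toy tensors below stand in for a batch of predictions and normalized labels:

import paddle

loss_fn = paddle.nn.MSELoss()
pred = paddle.rand([4, 30])   # fake network output: 4 samples, 30 coordinates
label = paddle.rand([4, 30])  # fake normalized ground-truth coordinates
loss = loss_fn(pred, label)   # mean((pred - label)^2) over all elements
print(loss.numpy())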
Instantiate a Model from the network, then use the prepare interface to configure the optimizer, loss function, and metrics for training. Once this initial configuration is done, call fit to start the training process; fit only needs the previously defined training dataset, validation dataset, number of epochs (Epoch), and batch size (batch_size).
In [9]
model = paddle.Model(FaceNet(num_keypoints=15))
optim = paddle.optimizer.Adam(learning_rate=1e-3,
                              parameters=model.parameters())
model.prepare(optim, paddle.nn.MSELoss())
model.fit(train_dataset, val_dataset, epochs=60, batch_size=256)
Epoch 1/60
step 7/7 - loss: 0.2203 - 570ms/step
Eval begin...
step 2/2 - loss: 0.2003 - 476ms/step
Eval samples: 428
Epoch 2/60
step 7/7 - loss: 0.1293 - 574ms/step
Eval begin...
step 2/2 - loss: 0.1363 - 454ms/step
Eval samples: 428
Epoch 3/60
step 7/7 - loss: 0.0499 - 799ms/step
Eval begin...
step 2/2 - loss: 0.0530 - 455ms/step
Eval samples: 428
...
Epoch 59/60
step 7/7 - loss: 7.7562e-04 - 570ms/step
Eval begin...
step 2/2 - loss: 4.1318e-04 - 456ms/step
Eval samples: 428
Epoch 60/60
step 7/7 - loss: 7.3029e-04 - 858ms/step
Eval begin...
step 2/2 - loss: 4.2197e-04 - 454ms/step
Eval samples: 428
To better examine the predictions, we visualize both the validation-set results against the annotations and the predictions on the unannotated test set.
Red keypoints are the network's predictions; green keypoints are the annotated ground truth.
In [10]
result = model.predict(val_dataset, batch_size=1)
Predict begin...
step 428/428 [==============================] - 17ms/step
Predict samples: 428
In [28]
def plot_sample(x, y, axis, gt=None):
    img = x.reshape(96, 96)
    axis.imshow(img, cmap='gray')
    axis.scatter(y[0::2], y[1::2], marker='x', s=10, color='r')
    # Comparing a numpy array to [] with != triggers a DeprecationWarning,
    # so test for the presence of ground truth explicitly instead.
    if gt is not None:
        axis.scatter(gt[0::2], gt[1::2], marker='x', s=10, color='lime')

fig = plt.figure(figsize=(10, 7))
fig.subplots_adjust(
    left=0, right=1, bottom=0, top=1, hspace=0.05, wspace=0.05)

for i in range(16):
    axis = fig.add_subplot(4, 4, i+1, xticks=[], yticks=[])
    idx = np.random.randint(val_dataset.__len__())
    img, gt_label = val_dataset[idx]
    gt_label = gt_label * 96
    label_pred = result[0][idx].reshape(-1)
    label_pred = label_pred * 96
    plot_sample(img[0], label_pred, axis, gt_label)
plt.show()
(figure: 16 validation samples; red = predicted keypoints, green = ground truth)
result = model.predict(test_dataset, batch_size=1)
Predict begin...
step 1783/1783 [==============================] - 17ms/step
Predict samples: 1783
In [25]
fig = plt.figure(figsize=(10, 7))
fig.subplots_adjust(
    left=0, right=1, bottom=0, top=1, hspace=0.05, wspace=0.05)

for i in range(16):
    axis = fig.add_subplot(4, 4, i+1, xticks=[], yticks=[])
    idx = np.random.randint(test_dataset.__len__())
    img, _ = test_dataset[idx]
    label_pred = result[0][idx].reshape(-1)
    label_pred = label_pred * 96
    plot_sample(img[0], label_pred, axis)
plt.show()
(figure: 16 test samples with predicted keypoints in red)
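Although not shown in the original run, the test-set predictions can be turned into a Kaggle submission via IdLookupTable.csv. The sketch below assumes the table has the competition's usual RowId / ImageId / FeatureName columns, with ImageId as a 1-based index into test.csv, and that the FeatureName values match the coordinate column names of the training data:

import numpy as np
import pandas as pd

# Flatten the per-sample predictions (batch_size=1) into an (N, 30) array
# and map the normalized outputs back to pixel coordinates.
preds = np.concatenate([r.reshape(-1, 30) for r in result[0]]) * 96
preds = preds.clip(0, 96)

# Assumed: FeatureName values match the training CSV's coordinate columns.
feature_names = list(train_dataset.data_label_all.columns)
lookup = pd.read_csv(lookid_dir)
locations = [
    preds[row.ImageId - 1, feature_names.index(row.FeatureName)]
    for row in lookup.itertuples()
]
pd.DataFrame({'RowId': lookup['RowId'], 'Location': locations}) \
  .to_csv('submission.csv', index=False)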