时间:2025-07-22 作者:游乐小编
本文介绍CVPR2024论文提出的小样本学习模型FRN,其将分类问题归为特征重构问题,以闭合解形式从支持样本回归查询样本特征,性能与效率更优。文中展示了基于PaddlePaddle复现的FRN在mini-ImageNet上的精度,还介绍了数据集、环境依赖、快速开始步骤、代码结构及模型信息等内容。
论文Few-Shot Classification with Feature Map Reconstruction Networks是顶会CVPR2024上发表的一种小样本学习经典方法。该方法在小样本学习的benchmark上依然具有最佳的性能指标,是该领域的重要方法。
FRN将小样本分类问题归结为潜在空间中的特征重构问题。作者认为,通过支持样本重构查询样本特征的能力,决定了查询样本的所属类别。作者在小样本学习中引入了一种新的机制,以闭合解的形式从支持样本特征直接向查询样本特征做回归,无需引入新的模块或者大规模的训练参数。上述方法得到的模型(FRN),相比先前的其他方法,无论在计算效率上还是性能表现上都更有优势。FRN在四个细粒度数据集上展现出实质性提升。在通用的粗粒度数据集mini-ImageNet和tiered-ImageNet上,也达到了SOTA指标。
下图展示了FRN的基本工作流程。
基于paddlepaddle深度学习框架,对文献算法进行复现后,本项目在mini-ImageNet上达到的测试精度,如下表所示。
模型训练包括了两个过程,首先是模型预训练,按照典型分类网络的训练过程,将整个训练集送入backbone进行训练;然后是微调过程,按照episode training的训练范式,配置为20-Way 5-Shot方式进行微调训练。这两个训练过程的训练超参数设置如下:
(1)预训练过程
(2)微调训练过程
miniImageNet数据集节选自ImageNet数据集。 DeepMind团队首次将miniImageNet数据集用于小样本学习研究,从此miniImageNet成为了元学习和小样本领域的基准数据集。 关于该数据集的介绍可以参考https://blog.csdn.net/wangkaidehao/article/details/105531837
miniImageNet是由Oriol Vinyals等在Matching Networks 中首次提出的,该文献是小样本分类任务的开山制作,也是本次复现论文关于该数据集的参考文献。在Matching Networks中, 作者提出对ImageNet中的类别和样本进行抽取(参见其Appendix B),形成了一个数据子集,将其命名为miniImageNet。 划分方法,作者仅给出了一个文本文件进行说明。 Vinyals在文中指明了miniImageNet图片尺寸为84x84。因此,后续小样本领域的研究者,均是基于原始图像,在代码中进行预处理, 将图像缩放到84x84的规格。
至于如何缩放到84x84,本领域研究者各有各的方法,通常与研究者的个人理解相关,但一般对实验结果影响不大。本次文献论文原文,未能给出 miniImageNet的具体实现方法,本项目即参考领域内较为通用的预处理方法进行处理。
数据集大小:miniImageNet包含100类共60000张彩色图片,其中每类有600个样本。 mini-imagenet一共有2.86GB数据格式:|- miniImagenet| |- images/| | |- n0153282900000005.jpg | | |- n0153282900000006.jpg| | |- …| |- train.csv| |- test.csv| |- val.csv登录后复制
数据集链接:miniImagenet
硬件:
x86 cpuNVIDIA GPU框架:
PaddlePaddle = 2.4其他依赖项:
numpy==1.19.3tqdm==4.59.0Pillow==8.3.1!unzip -n -d ./data/ ./data/data105646/mini-imagenet-sxc.zip
In [ ]%cd /home/aistudio/!unzip -n -d ./data/ ./data/data105646/mini-imagenet-sxc.zip登录后复制 In [ ]
%cd /home/aistudio/work/!unzip -o frn.zip登录后复制 In [ ]
# 生成json文件!cp write_miniImagenet_filelist.py /home/aistudio/data/mini-imagenet-sxc/%cd /home/aistudio/data/mini-imagenet-sxc/!python write_miniImagenet_filelist.py登录后复制
python pretrain.py --dataset mini_imagenet --data_path /home/aistudio/data/mini-imagenet-sxc --method stl_frn --lr 1e-1 --gamma 1e-1 --epoch 350 --milestones 200 300 --batch_size 512 --val_n_episode 600 --image_size 84 --model ResNet12 --n_shot 1 --n_query 15 --gpu登录后复制
模型开始训练,运行完毕后,训练log和模型参数保存在./checkpoints/mini_imagenet/ResNet12_stl_frn_pretrain/目录下,分别是:
best_model.pdparams # 最优模型参数文件output.log # 训练LOG信息登录后复制登录后复制
训练完成后,可将上述文件手动保存到其他目录下,避免被后续训练操作覆盖。
In [ ]%cd /home/aistudio/work!python pretrain.py --dataset mini_imagenet --data_path /home/aistudio/data/mini-imagenet-sxc --method stl_frn --lr 1e-1 --gamma 1e-1 --epoch 350 --milestones 200 300 --batch_size 512 --val_n_episode 600 --image_size 84 --model ResNet12 --n_shot 1 --n_query 15 --gpu登录后复制
python meta_train.py --dataset mini_imagenet --data_path /home/aistudio/data/mini-imagenet-sxc --method meta_frn --lr 1e-3 --gamma 1e-1 --epoch 150 --train_n_episode 1000 --val_n_episode 600 --milestones 70 120 --image_size 84 --model ResNet12 --train_n_way 20 --val_n_way 5 --n_shot 5 --n_query 15 --gpu --pretrain_path ./checkpoints/mini_imagenet/ResNet12_stl_frn_pretrain/best_model.pdparams登录后复制
模型开始训练,运行完毕后,训练log和模型参数保存在./checkpoints/mini_imagenet/ResNet12_meta_frn_20way_5shot_metatrain/目录下,分别是:
best_model.pdparams # 最优模型参数文件output.log # 训练LOG信息登录后复制登录后复制
训练完成后,可将上述文件手动保存到其他目录下,避免被后续训练操作覆盖。
In [ ]%cd /home/aistudio/work!python meta_train.py --dataset mini_imagenet --data_path /home/aistudio/data/mini-imagenet-sxc --method meta_frn --lr 1e-3 --gamma 1e-1 --epoch 150 --train_n_episode 1000 --val_n_episode 600 --milestones 70 120 --image_size 84 --model ResNet12 --train_n_way 20 --val_n_way 5 --n_shot 5 --n_query 15 --gpu --pretrain_path ./checkpoints/mini_imagenet/ResNet12_stl_frn_pretrain/best_model.pdparams登录后复制
python test.py --dataset mini_imagenet --data_path /home/aistudio/data/mini-imagenet-sxc --model ResNet12 --method meta_frn --image_size 84 --gpu --n_shot 1 --model_path ./checkpoints/mini_imagenet/ResNet12_meta_frn_20way_5shot_metatrain/best_model.pdparams --test_task_nums 1 --test_n_episode 600登录后复制
用于评估模型在小样本任务下的精度。
In [ ]# 5-Way 1-Shot评估%cd /home/aistudio/work!python test.py --dataset mini_imagenet --data_path /home/aistudio/data/mini-imagenet-sxc --model ResNet12 --method meta_frn --image_size 84 --gpu --n_shot 1 --model_path ./checkpoints/mini_imagenet/ResNet12_meta_frn_20way_5shot_metatrain/best_model.pdparams --test_task_nums 1 --test_n_episode 600登录后复制 In [ ]
# 5-Way 5-Shot评估%cd /home/aistudio/work!python test.py --dataset mini_imagenet --data_path /home/aistudio/data/mini-imagenet-sxc --model ResNet12 --method meta_frn --image_size 84 --gpu --n_shot 5 --model_path ./checkpoints/mini_imagenet/ResNet12_meta_frn_20way_5shot_metatrain/best_model.pdparams --test_task_nums 1 --test_n_episode 600登录后复制
├── data # 数据处理相关│ ├── datamgr.py # data manager模块│ ├── dataset.py # data set模块├── methods # 模型相关│ ├── FRN.py # FRN核心算法├── network # backbone│ ├── conv.py # Conv-4和Conv-6代码实现│ ├── resnet.py # ResNet-12代码实现├── scripts # 运行工程脚本│ ├── mini_imagenet │ │ ├── run_frn │ │ │ ├── run_frn_metatrain.sh # 运行微调训练│ │ │ ├── run_frn_pretrain.sh # 运行预训练│ │ │ ├── run_frn_test.sh # 运行测试├── meta_train.py # 微调训练代码├── pretrain.py # 预训练代码├── test.py # 测试代码├── utils.py # 公共调用函数├── wirite_miniImagenet_filelist.py # 生成mini-ImageNet数据json文件登录后复制
可以在 pretrain.py 中设置训练与评估相关参数,具体如下:
可参考快速开始章节中的描述
执行训练开始后,将得到类似如下的输出。每一轮epoch训练将会打印当前training loss、training acc、val loss、val acc以及训练kl散度。
Epoch 0 | Batch 0/150 | Loss 4.158544best model! save...val loss is 0.00, val acc is 37.46model best acc is 37.46, best acc epoch is 0This epoch use 7.61 minutestrain loss is 3.72, train acc is 10.84Epoch 1 | Batch 0/150 | Loss 3.052964val loss is 0.00, val acc is 37.46model best acc is 37.46, best acc epoch is 0This epoch use 3.73 minutestrain loss is 2.96, train acc is 25.28Epoch 2 | Batch 0/150 | Loss 2.588413val loss is 0.00, val acc is 37.46model best acc is 37.46, best acc epoch is 0This epoch use 3.71 minutestrain loss is 2.59, train acc is 33.27...登录后复制
可参考快速开始章节中的描述
此时的输出为:
登录后复制
训练完成后,模型和相关LOG保存在./results/5w1s和./results/5w5s目录下。
训练和测试日志保存在results目录下。
2021-11-05 11:52
手游攻略2021-11-19 18:38
手游攻略2021-10-31 23:18
手游攻略2022-06-03 14:46
游戏资讯2025-06-28 12:37
单机攻略