首页 游戏 软件 资讯 排行榜 专题
首页
AI
ConvMixer:Patches are all you need?

ConvMixer:Patches are all you need?

热心网友
87
转载
2025-07-18
ConvMixer是基于卷积层进行Mixer操作的模型,结构简单却精度不错。它与MLP Mixer类似,通过交替混合channel和token维度信息提取图像特征,但用卷积替代MLP。其用逐通道卷积提取token信息,1x1卷积提取channel信息,最新提供三个预训练模型,在ImageNet 1k验证集上表现良好,还可从头或微调训练。

convmixer:patches are all you need? - 游乐网

引入

之前介绍了 MLP-Mixer,【MLP-Mixer:MLP is all you need ?】那么除了 MLP 其他的基础网络层可不可以也进行 Mixer 操作呢?结论当然也是可以的,所以这次就来介绍一个最近新鲜出炉的工作 ConvMixer。顾名思义 ConvMixer 就是使用卷积层进行 Mixer 操作来构建的一个模型结构上也非常简单,但是同样能够实现一个不错的精度表现

相关资料

论文:"Patches Are All You Need?"最新代码:tmp-iclr/convmixer

模型架构

ConvMixer 与 MLP Mixer 模型一样模型的结构都十分简单

免费影视、动漫、音乐、游戏、小说资源长期稳定更新! 👉 点此立即查看 👈

同样是通过 channel 和 token 两个维度的信息进行交替混合,实现图像特征的有效提取

只不过 ConvMixer 使用的基础网络层为卷积,而 MLP Mixer 使用的是 MLP(多层感知机)

在 ConvMixer 模型中:

使用 Depthwise Convolution(逐通道卷积) 来提取 token 间的相关信息,类似 MLP Mixer 中的 token-mixing MLP

使用 Pointwise Convolution(1x1 卷积) 来提取 channel 间的相关信息,类似 MLP Mixer 中的 channel-mixing MLP

然后将两种卷积交替执行,混合两个维度的信息

模型的大致架构如下图所示:

ConvMixer:Patches are all you need? - 游乐网

代码实现

模型的代码实现其实在上面的结构图中已经有出现了,不过由于过于精简可能比较不好理解下面给出最新代码中的另一种常规一些的实现方式,结构比较清晰,并且手动添加了一些注释,相对比较好理解

模型搭建

In [1]
import paddle.nn as nnclass Residual(nn.Layer):    # Residual Block(残差层)    # y = f(x) + x    def __init__(self, fn):        super().__init__()        self.fn = fn    def forward(self, x):        return self.fn(x) + xdef ConvMixer(dim, depth, kernel_size=9, patch_size=7, act=nn.GELU, n_classes=1000):    # ConvMixer Model    # dim: hidden channal dim(ConvMixer 网络的隐藏层通道数)    # depth: num of ConvMixer Block(网络层数也是其中 ConvMixer 层的数量)    # kernel_size: kernel_size of Convolution in ConvMixer Block(ConvMixer 层中的卷积层的卷积核大小)    # patch_size: patch_size in Patch Embedding (Patch Embedding 时 Patch 的大小)    # act: activate function(激活函数)    # n_classes: num of classes(输出的类别数量)    return nn.Sequential(        # Patch Embedding        # Conv(kernel_size = stride = patch_size) + GELU + BN        # 使用一个卷积核大小和步长都等于 Patch 大小的卷积层进行输入图像 Embedding 的操作        # 并连接一个 GELU 激活函数和 BN 批归一化层        nn.Conv2D(3, dim, kernel_size=patch_size, stride=patch_size),        act(),        nn.BatchNorm2D(dim),        # ConvMixer Block x N(depth)        # N(depth) 个 ConvMixer 层        *[nn.Sequential(            # Residual Block + Depthwise Convolution + GELU + BN            # 逐通道卷积提取 Token 之间的信息            # 并连接一个 GELU 激活函数和 BN 批归一化层            # 最后与原输入进行一个残差连接            Residual(nn.Sequential(                nn.Conv2D(dim, dim, kernel_size, groups=dim, padding="same"),                act(),                nn.BatchNorm2D(dim)            )),            # Pointwise Convolution + GELU + BN            # 1x1 卷积提取 Channel 之间的信息            # 并连接一个 GELU 激活函数和 BN 批归一化层            nn.Conv2D(dim, dim, kernel_size=1),            act(),            nn.BatchNorm2D(dim)        ) for i in range(depth)],        # Output Layers        nn.AdaptiveAvgPool2D((1,1)),        nn.Flatten(),        nn.Linear(dim, n_classes)    )
登录后复制

预设模型

目前最新提供了如下三个预训练模型的参数文件In [2]
import paddledef convmixer_1536_20(pretrained=False, **kwargs):    model = ConvMixer(1536, 20, kernel_size=9, patch_size=7, **kwargs)    if pretrained:        params = paddle.load('/home/aistudio/data/data111600/convmixer_1536_20_ks9_p7.pdparams')        model.set_dict(params)    return modeldef convmixer_1024_20(pretrained=False, **kwargs):    model = ConvMixer(1024, 20, kernel_size=9, patch_size=14, **kwargs)    if pretrained:        params = paddle.load('/home/aistudio/data/data111600/convmixer_1024_20_ks9_p14.pdparams')        model.set_dict(params)    return modeldef convmixer_768_32(pretrained=False, **kwargs):    model = ConvMixer(768, 32, kernel_size=7, patch_size=7, act=nn.ReLU, **kwargs)    if pretrained:        params = paddle.load('/home/aistudio/data/data111600/convmixer_768_32_ks7_p7_relu.pdparams')        model.set_dict(params)    return model
登录后复制

模型测试

In [3]
model = convmixer_768_32(pretrained=True)x = paddle.randn((1, 3, 224, 224))out = model(x)print(out.shape)model.eval()out = model(x)print(out.shape)
登录后复制

精度测试

标称精度

ConvMixer 与其他一些先进模型的精度对比:

ConvMixer:Patches are all you need? - 游乐网

具体的精度表现如下表:

ConvMixer:Patches are all you need? - 游乐网

解压数据集

解压 ImageNet 1k 验证集In [8]
!mkdir data/ILSVRC2012
登录后复制In [9]
!tar -xf ~/data/data68594/ILSVRC2012_img_val.tar -C ~/data/ILSVRC2012
登录后复制

精度验证

使用 ImageNet 1k 验证集对模型进行精度验证可以看到结果与最新给出的基本一致In [4]
import osimport cv2import numpy as npimport paddleimport paddle.vision.transforms as Tfrom PIL import Image# 构建数据集class ILSVRC2012(paddle.io.Dataset):    def __init__(self, root, label_list, transform, backend='pil'):        self.transform = transform        self.root = root        self.label_list = label_list        self.backend = backend        self.load_datas()    def load_datas(self):        self.imgs = []        self.labels = []        with open(self.label_list, 'r') as f:            for line in f:                img, label = line[:-1].split(' ')                self.imgs.append(os.path.join(self.root, img))                self.labels.append(int(label))    def __getitem__(self, idx):        label = self.labels[idx]        image = self.imgs[idx]        if self.backend=='cv2':            image = cv2.imread(image)        else:            image = Image.open(image).convert('RGB')        image = self.transform(image)        return image.astype('float32'), np.array(label).astype('int64')    def __len__(self):        return len(self.imgs)val_transforms = T.Compose([    T.Resize(int(224 / 0.96), interpolation='bicubic'),    T.CenterCrop(224),    T.ToTensor(),    T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])])# 配置模型model = convmixer_1536_20(pretrained=True)model = paddle.Model(model)model.prepare(metrics=paddle.metric.Accuracy(topk=(1, 5)))# 配置数据集val_dataset = ILSVRC2012('data/ILSVRC2012', transform=val_transforms, label_list='data/data68594/val_list.txt', backend='pil')# 模型验证acc = model.evaluate(val_dataset, batch_size=128, num_workers=0, verbose=1)print(acc)
登录后复制
Eval begin...step 391/391 [==============================] - acc_top1: 0.8137 - acc_top5: 0.9562 - 3s/step          Eval samples: 50000{'acc_top1': 0.81366, 'acc_top5': 0.95616}
登录后复制

模型训练

从头训练

根据论文的模型配置训练一下 CIFAR-10 数据集的 BaseLine:

ConvMixer:Patches are all you need? - 游乐网

由于没有严格对齐各项训练参数,所以训练结果可能应该会有差异

In [ ]
import osimport cv2import numpy as npimport paddleimport paddle.nn as nnimport paddle.vision.transforms as Tfrom paddle.vision.datasets import Cifar10from PIL import Imagefrom paddle.callbacks import EarlyStopping, VisualDL, ModelCheckpointtrain_transforms = T.Compose([    T.Resize(int(224 / 0.96), interpolation='bicubic'),    T.RandomCrop(224),    T.ToTensor(),    T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])])val_transforms = T.Compose([    T.Resize(int(224 / 0.96), interpolation='bicubic'),    T.CenterCrop(224),    T.ToTensor(),    T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])])model = ConvMixer(256, 8)opt = paddle.optimizer.Adam(learning_rate=1e-5, parameters=model.parameters())model = paddle.Model(model)model.prepare(optimizer=opt, loss=nn.CrossEntropyLoss(), metrics=paddle.metric.Accuracy(topk=(1, 5)))train_dataset = Cifar10(transform=train_transforms, backend='pil', mode='train')val_dataset = Cifar10(transform=val_transforms, backend='pil', mode='test')checkpoint = ModelCheckpoint(save_dir='save')earlystopping = EarlyStopping(monitor='acc_top1',                                mode='max',                                patience=3,                                verbose=1,                                min_delta=0,                                baseline=None,                                save_best_model=True)vdl = VisualDL('log')model.fit(train_dataset, val_dataset, batch_size=32, num_workers=0, epochs=10, save_dir='save', callbacks=[checkpoint, earlystopping, vdl], verbose=1)
登录后复制

微调训练

基于预训练模型在 Cifar10 数据集上进行微调训练In [ ]
import osimport cv2import numpy as npimport paddleimport paddle.nn as nnimport paddle.vision.transforms as Tfrom paddle.vision.datasets import Cifar10from PIL import Imagefrom paddle.callbacks import EarlyStopping, VisualDL, ModelCheckpointtrain_transforms = T.Compose([    T.Resize(int(224 / 0.96), interpolation='bicubic'),    T.RandomCrop(224),    T.ToTensor(),    T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])])val_transforms = T.Compose([    T.Resize(int(224 / 0.96), interpolation='bicubic'),    T.CenterCrop(224),    T.ToTensor(),    T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])])model = convmixer_768_32(n_classes=10, pretrained=True)opt = paddle.optimizer.Adam(learning_rate=1e-5, parameters=model.parameters())model = paddle.Model(model)model.prepare(optimizer=opt, loss=nn.CrossEntropyLoss(), metrics=paddle.metric.Accuracy(topk=(1, 5)))train_dataset = Cifar10(transform=train_transforms, backend='pil', mode='train')val_dataset = Cifar10(transform=val_transforms, backend='pil', mode='test')checkpoint = ModelCheckpoint(save_dir='save')earlystopping = EarlyStopping(monitor='acc_top1',                                mode='max',                                patience=3,                                verbose=1,                                min_delta=0,                                baseline=None,                                save_best_model=True)vdl = VisualDL('log')model.fit(train_dataset, val_dataset, batch_size=32, num_workers=0, epochs=1, save_dir='save', callbacks=[checkpoint, earlystopping, vdl], verbose=1)
登录后复制
来源:https://www.php.cn/faq/1413612.html
免责声明: 游乐网为非赢利性网站,所展示的游戏/软件/文章内容均来自于互联网或第三方用户上传分享,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系youleyoucom@outlook.com。

相关攻略

阿里千问 AI 眼镜接入蚂蚁 GPASS:语音解锁共享单车、停车缴费
AI
阿里千问 AI 眼镜接入蚂蚁 GPASS:语音解锁共享单车、停车缴费

当AI眼镜学会“跑腿”:语音解锁单车,无感支付停车费 近来,智能穿戴领域的一个新动向值得关注:阿里旗下的千问AI眼镜,正式接入了蚂蚁集团的GPASS平台。这可不是一次简单的功能叠加,它意味着,诸如共享单车骑行、停车缴费这一系列高频的“AI办事”功能,开始从手机屏幕转移到了你的眼前。 简单说,借助GP

热心网友
04.06
Workbuddy注册额外积分
AI
Workbuddy注册额外积分

角色定位与核心任务目标 明确了基本定位后,我们直接切入核心:作为一名专业的文章优化师,我的核心职责在于,将那些带有明显AI生成特征的文本,深度重塑为拥有个人特色与行业洞见的优质内容。 换句话说,这项任务的关键在于实施一次“精准的换血手术”。你必须严格保证原文所有的事实依据、核心观点、逻辑框架,以及每

热心网友
04.06
OpenClaw使用kimi web_search返回401问题
AI
OpenClaw使用kimi web_search返回401问题

1 故障现象:OpenClaw无法联网搜索的典型报错 许多开发者在配置OpenClaw AI助手的搜索功能时,常常会遭遇一个典型故障:日常对话交互完全正常,但一旦触发需要联网查询信息的指令,界面便会立刻弹出“抱歉,我目前无法使用网络搜索功能(需要配置 API 密钥)”或“HTTP 401: Inv

热心网友
04.05
1.4 万亿词元!阿里 Qwen3.6-Plus 刷新全球最大 AI 聚合平台 OpenRouter 日调用量纪录
AI
1.4 万亿词元!阿里 Qwen3.6-Plus 刷新全球最大 AI 聚合平台 OpenRouter 日调用量纪录

1 4 万亿词元!阿里 Qwen3 6-Plus 刷新全球最大 AI 聚合平台 OpenRouter 日调用量纪录 这事儿挺震撼的。就在4月4日,全球最大的AI模型聚合平台OpenRouter在其官方账号上公布了一个爆炸性数字:阿里刚刚发布的千问新模型Qwen3 6-Plus,上线仅仅一天,日调用量

热心网友
04.04
Solidus Ai Tech(AITECH)币是什么?怎么样?AITECH工作原理和代币经济学概述
web3.0
Solidus Ai Tech(AITECH)币是什么?怎么样?AITECH工作原理和代币经济学概述

Solidus AI 是什么 在AI与Web3加速融合的当下,一个名为Solidus AI的项目提出了自己的解决方案。它将自己定位为“Web3原生的AI HPC基础设施”,其蓝图相当清晰:以位于欧洲的环保高性能计算(HPC)数据中心为基石,向上构建一个计算与AI工具市场,并最终通过AITECH代币完

热心网友
04.03

最新APP

火柴人传奇
火柴人传奇
动作冒险 04-01
街球艺术
街球艺术
体育竞技 04-01
飞行员模拟
飞行员模拟
休闲益智 04-01
史莱姆农场
史莱姆农场
休闲益智 04-01
绝区零
绝区零
角色扮演 04-01

热门推荐

《洛克王国世界》独角兽伊利斯叫什么-呼唤独角兽的名字怎么写的
游戏攻略
《洛克王国世界》独角兽伊利斯叫什么-呼唤独角兽的名字怎么写的

《洛克王国世界》呼唤独角兽的正确姿势 在《洛克王国世界》的主线任务中,有时会遇到需要精确输入特定角色名称的环节。其中一个关键节点,便是要准确拼写出独角兽“伊利斯”的真名。很多玩家稍不注意就可能记错或用错字,导致任务流程在此停滞不前。这篇指南将为你清晰解析正确的输入方法,助你快速通关。 《洛克王国世界

热心网友
04.06
《洛克王国世界》找到向上的方法任务怎么做-风眠圣所找到向上的方法任务图文攻略
游戏攻略
《洛克王国世界》找到向上的方法任务怎么做-风眠圣所找到向上的方法任务图文攻略

《洛克王国世界》风眠圣所“向上的方法”任务图文通关指南 在《洛克王国世界》的风眠圣所探险过程中,很多玩家会在“找到向上的方法”这一环节遭遇卡点。实际上,只要理清思路、明确顺序,完成这个挑战并不困难。本攻略将为你提供一套经过验证的详细图文流程,帮助你一次性顺利通过。 最后的关键操作非常简单:准确判断风

热心网友
04.06
《洛克王国世界》叶冕魔力猫怎么打-叶冕魔力猫打法技巧攻略
游戏攻略
《洛克王国世界》叶冕魔力猫怎么打-叶冕魔力猫打法技巧攻略

《洛克王国世界》叶冕魔力猫打法全攻略:高效通关技巧解析 在《洛克王国世界》的主线剧情推进中,挑战初始精灵首领叶冕魔力猫是一个重要环节。许多玩家在这个关卡遇到了困难,感觉难以突破。不必担心,这份详尽的实战打法指南将为你提供清晰的过关思路,帮助你轻松击败叶冕魔力猫。 核心挑战思路与强力精灵推荐 与叶冕魔

热心网友
04.06
《洛克王国世界》罗隐在哪里抓-罗隐捕捉位置图解
游戏攻略
《洛克王国世界》罗隐在哪里抓-罗隐捕捉位置图解

《洛克王国世界》罗隐捕捉指南:高效获取圣羽翼王挑战关键战宠 在《洛克王国世界》中,成功挑战传说精灵圣羽翼王是许多训练师的终极目标之一。选择合适的战宠至关重要,而罗隐以其出色的对抗能力,已成为公认的核心攻略选择。那么,这只关键的宠物究竟在哪里可以捕获?本文将为你提供详尽的罗隐捕捉位置图解与实用技巧。

热心网友
04.06
大店小二元宝与银两优先使用攻略-资源合理分配技巧
游戏攻略
大店小二元宝与银两优先使用攻略-资源合理分配技巧

速览 在《大店小二》中,如何高效使用元宝和银两是新手玩家普遍面临的难题。资源有限,如何将每一分投入转化为最大收益?本文将深入解析两类资源的最优使用策略,核心原则是:元宝投资于长期价值,银两专注于核心养成。 大店小二元宝与银两使用优先级攻略 1 元宝使用指南 首要建议:若非充值玩家,请勿将元宝大量用

热心网友
04.06