当前位置: 首页 > AI > 文章内容页

轻量级Vision-Transformer:EdgeViTs复现

时间:2025-07-29    作者:游乐小编    

本文聚焦轻量级Vision-Transformer模型EdgeViTs的复现。EdgeViTs为适配移动设备,采用分层金字塔结构,设计Local-Global-Local(LGL)瓶颈,通过局部聚合、全局稀疏注意力和局部传播操作,在减少计算量的同时保留全局与局部上下文信息。文中给出模型各组件及整体架构的Paddle实现代码,并基于Flowers数据集进行训练验证。

轻量级Vision-Transformer:EdgeViTs复现 - 游乐网

轻量级Vision-Transformer:EdgeViTs复现

摘要

  在计算机视觉领域,基于Self-attention的模型(如(ViTs))已经成为CNN之外的一种极具竞争力的架构。尽管越来越强的变种具有越来越高的识别精度,但由于Self-attention的二次复杂度,现有的ViT在计算和模型大小方面都有较高的要求。 虽然之前的CNN的一些成功的设计选择(例如,卷积和分层结构)已经被引入到最近的ViT中,但它们仍然不足以满足移动设备有限的计算资源需求。这促使人们最近尝试开发基于最先进的MobileNet-v2的轻型MobileViT,但MobileViT与MobileNet-v2仍然存在性能差距。 在这项工作中,作者进一步推进这一研究方向,引入了EdgeViTs,一个新的轻量级ViTs家族,也是首次使基于Self-attention的视觉模型在准确性和设备效率之间的权衡中达到最佳轻量级CNN的性能。

1 EdgeViTs

1.1 总体架构

  为了设计适用于移动/边缘设备的轻量级ViT,作者采用了最近ViT变体中使用的分层金字塔结构(图2(a))。Pyramid Transformer模型通常在不同阶段降低了空间分辨率同时也扩展了通道维度。每个阶段由多个基于Transformer Block处理相同形状的张量,类似ResNet的层次设计结构。

  在这项工作中,作者深入到Transformer Block,并引入了一个比较划算的Bottlneck,Local-Global-Local(LGL)(图2(b))。LGL通过一个稀疏注意力模块进一步减少了Self-attention的开销(图2(c)),实现了更好的准确性-延迟平衡。

轻量级Vision-Transformer:EdgeViTs复现 - 游乐网        

1.2 Local-Global-Local bottleneck(LGL)

  与以前在每个空间位置执行Self-attention的Transformer Block相比,LGL Bottleneck只对输入Token的子集计算Self-attention,但支持完整的空间交互,如在标准的Multi-Head Self-attention(MHSA)中。既会减少Token的作用域,同时也保留建模全局和局部上下文的底层信息流。

  为了实现这一点,作者将Self-attention分解为连续的模块,处理不同范围内的空间Token(图2(b))。

  这里引入了3种有效的操作:

  • Local aggregation:仅集成来自局部近似Token信号的局部聚合
  • Global sparse attention:建模一组代表性Token之间的长期关系,其中每个Token都被视为一个局部窗口的代表;
  • Local propagation:将委托学习到的全局上下文信息扩散到具有相同窗口的非代表Token。

轻量级Vision-Transformer:EdgeViTs复现 - 游乐网        

  • Local aggregation

  对于每个Token,利用Depth-wise和Point-wise卷积在大小为k×k的局部窗口中聚合信息(图3(a))。

  • Global sparse attention

  对均匀分布在空间中的稀疏代表性Token集进行采样,每个r×r窗口有一个代表性Token。这里,r表示子样本率。然后,只对这些被选择的Token应用Self-attention(图3(b))。这与所有现有的ViTs不同,在那里,所有的空间Token都作为Self-attention计算中的query被涉及到。

  • Local propagation

  通过转置卷积将代表性 Token 中编码的全局上下文信息传播到它们的相邻的 Token 中(图 3(c))。

轻量级Vision-Transformer:EdgeViTs复现 - 游乐网        

2 代码复现

In [1]

import paddleimport paddle.nn as nnfrom paddle.nn import Conv2D  as Conv2dfrom paddle.nn import BatchNorm2D  as BatchNorm2dfrom paddle.nn import Linearfrom paddle.nn import AvgPool2D as AvgPool2dfrom paddle.nn import Conv2DTranspose as ConvTranspose2dfrom paddle.nn import LayerNorm, GELU
   

In [2]

class Residual(nn.Layer):    def __init__(self, module):        super().__init__()        self.module = module        def forward(self, x):        return x + self.module(x)class LocalAgg(nn.Layer):      def __init__(self, dim):        super().__init__()        self.conv1 = Conv2d(dim, dim, 1)          self.conv2 = Conv2d(dim, dim, 3, padding=1, groups=dim)          self.conv3 = Conv2d(dim, dim, 1)          self.norm1 = BatchNorm2d(dim)          self.norm2 = BatchNorm2d(dim)                def forward(self, x):          """          [B, C, H, W] = x.shape          """          x = self.conv1(self.norm1(x))          x = self.conv2(x)          x = self.conv3(self.norm2(x))          return x  class GlobalSparseAttn(nn.Layer):      def __init__(self, dim, sample_rate = 4, scale = 1):        super().__init__()          self.head_dim = int(48)//int(1)        self.num_heads = int(1)        self.scale = scale          self.qkv = Linear(dim, dim * 3)          self.sampler = AvgPool2d(1, stride=sample_rate)          self.LocalProp = ConvTranspose2d(dim, dim, kernel_size=sample_rate, stride=sample_rate, groups=dim          )          self.proj = Linear(dim, dim)      def forward(self, x):          """          [B, C, H, W] = x.shape          """          x = self.sampler(x)        [B, C, H, W] = x.shape        x = x.flatten(2)        x = x.transpose([0,2,1])        x = self.qkv(x)        x = x.transpose([0, 2, 1])        x = x.reshape([1, 144, 14, 14])        q, k, v = x.reshape([B, self.num_heads, -1, H*W]).split([self.head_dim, self.head_dim, self.head_dim], axis=2)               attn = (q.transpose([0, 1, 3, 2]) @ k)        attn = nn.functional.softmax(attn)        x = v  @  attn.transpose([0, 1, 3, 2])        x = x.reshape([B, -1, H, W])        x = self.LocalProp(x)                 x = paddle.nn.functional.layer_norm(x, x.shape[1:])        x = x.flatten(2)        x = x.transpose([0,2,1])        x = self.proj(x)          x = x.transpose([0,2,1])        x = x.reshape([1, 48, 56, 56])        return x  class DownSampleLayer(nn.Layer):      def __init__(self, dim_in=3, dim_out=48, downsample_rate=4):          super().__init__()        self.downsample = Conv2d(dim_in, dim_out, kernel_size=downsample_rate, stride=          downsample_rate)      def forward(self, x):          x = self.downsample(x)        x = paddle.nn.functional.layer_norm(x, x.shape[1:])        return x  class PatchEmbed(nn.Layer):      def __init__(self, dim):        super().__init__()        self.embed = Conv2d(dim, dim, 3, padding=1, groups=dim)      def forward(self, x):          return x + self.embed(x)  class FFN(nn.Layer):      def __init__(self, dim=3156):        super().__init__()          self.fc1 = nn.Linear(dim, dim*4)          self.fc2 = nn.Linear(dim*4, dim)                def forward(self, x):        x = x.flatten(2)        x = x.transpose([0,2,1])               x = self.fc1(x)          x = nn.functional.gelu(x)         x = self.fc2(x)                x = x.transpose([0,2,1])        x = x.reshape([1, 48, 56, 56])        return x
   

In [ ]

class EdgeViT(nn.Layer):    def __init__(self, dim_in=3, dim_out=48, downsample_rate=4, dim=48):        super().__init__()               self.downsample1 = DownSampleLayer(dim_in=3, dim_out=48, downsample_rate=4)        self.patchembeding1 = PatchEmbed(dim=48)        self.residual_add1 = Residual(LocalAgg(dim=48))        self.residual_add1_1 = Residual(FFN(dim=48))        self.patchembeding2 = PatchEmbed(dim=48)        self.residual_add2 = Residual(GlobalSparseAttn(dim=48))        self.fc = nn.Linear(150528,103)    def forward(self, x):        x = self.downsample1(x)        x = self.patchembeding1(x)        x = self.residual_add1(x)        x = self.residual_add1_1(x)        x = self.patchembeding2(x)        x = self.residual_add2(x)        x = paddle.reshape(x,shape=[-1,48*56*56])        # x = x.transpose([0,2,1])        # print(x.shape)        x = self.fc(x)        return x
   

In [4]

cnn = EdgeViT()paddle.summary(cnn,(1,3,224,224))
       

[1, 150528]------------------------------------------------------------------------------   Layer (type)        Input Shape          Output Shape         Param #    ==============================================================================     Conv2D-1       [[1, 3, 224, 224]]    [1, 48, 56, 56]         2,352     DownSampleLayer-1   [[1, 3, 224, 224]]    [1, 48, 56, 56]           0            Conv2D-2       [[1, 48, 56, 56]]     [1, 48, 56, 56]          480         PatchEmbed-1     [[1, 48, 56, 56]]     [1, 48, 56, 56]           0         BatchNorm2D-1     [[1, 48, 56, 56]]     [1, 48, 56, 56]          192           Conv2D-3       [[1, 48, 56, 56]]     [1, 48, 56, 56]         2,352          Conv2D-4       [[1, 48, 56, 56]]     [1, 48, 56, 56]          480        BatchNorm2D-2     [[1, 48, 56, 56]]     [1, 48, 56, 56]          192           Conv2D-5       [[1, 48, 56, 56]]     [1, 48, 56, 56]         2,352         LocalAgg-1      [[1, 48, 56, 56]]     [1, 48, 56, 56]           0           Residual-1      [[1, 48, 56, 56]]     [1, 48, 56, 56]           0            Linear-1        [[1, 3136, 48]]       [1, 3136, 192]         9,408          Linear-2        [[1, 3136, 192]]      [1, 3136, 48]          9,264           FFN-1         [[1, 48, 56, 56]]     [1, 48, 56, 56]           0           Residual-2      [[1, 48, 56, 56]]     [1, 48, 56, 56]           0            Conv2D-6       [[1, 48, 56, 56]]     [1, 48, 56, 56]          480         PatchEmbed-2     [[1, 48, 56, 56]]     [1, 48, 56, 56]           0          AvgPool2D-1      [[1, 48, 56, 56]]     [1, 48, 14, 14]           0            Linear-3         [[1, 196, 48]]       [1, 196, 144]          7,056     Conv2DTranspose-1   [[1, 48, 14, 14]]     [1, 48, 56, 56]          816           Linear-4        [[1, 3136, 48]]       [1, 3136, 48]          2,352     GlobalSparseAttn-1  [[1, 48, 56, 56]]     [1, 48, 56, 56]           0           Residual-3      [[1, 48, 56, 56]]     [1, 48, 56, 56]           0            Linear-5         [[1, 150528]]           [1, 103]         15,504,487   ==============================================================================Total params: 15,542,263Trainable params: 15,541,879Non-trainable params: 384------------------------------------------------------------------------------Input size (MB): 0.57Forward/backward pass size (MB): 27.85Params size (MB): 59.29Estimated Total Size (MB): 87.71------------------------------------------------------------------------------
       

{'total_params': 15542263, 'trainable_params': 15541879}
               

3 模型训练

  论文的实验是基于ImageNet数据集进行的,但是目前平台不具备拉取该数据集的能力,故这里采用了Cifar10作为模型验证数据集,仅做调通,不设置对比实验,因为在小数据集上无对比性。

In [5]

import paddlefrom paddle.vision.datasets import Flowersfrom paddle.vision.transforms import Compose, Normalize, Resize, Transpose, ToTensornormalize = Normalize(mean=[0.5, 0.5, 0.5],                    std=[0.5, 0.5, 0.5],                    data_format='HWC')transform = Compose([ToTensor(), Normalize(), Resize(size=(224,224))])cifar10_train = paddle.vision.datasets.Flowers(mode='train',                                               transform=transform)cifar10_test = paddle.vision.datasets.Flowers(mode='test',                                              transform=transform)# 构建训练集数据加载器train_loader = paddle.io.DataLoader(cifar10_train, batch_size=1, shuffle=True)# 构建测试集数据加载器test_loader = paddle.io.DataLoader(cifar10_test, batch_size=1, shuffle=True)print('=============train dataset=============')for image, label in cifar10_train:    print('image shape: {}, label: {}'.format(image.shape, label))    break
       

=============train dataset=============image shape: [3, 224, 224], label: [1]
       

In [ ]

from paddle.metric import Accuracymodel = paddle.Model(EdgeViT())optim = paddle.optimizer.Adam(learning_rate=0.001, parameters=model.parameters())model.prepare(    optim,    paddle.nn.CrossEntropyLoss(),    Accuracy()    )model.fit(train_data=train_loader,        eval_data=test_loader,        epochs=2,        verbose=1        )
   

热门推荐

更多

热门文章

更多

首页  返回顶部

本站所有软件都由网友上传,如有侵犯您的版权,请发邮件youleyoucom@outlook.com