This article reproduces the ConViT model, which injects the CNN inductive bias into ViT through the GPSA module. The code is implemented in Paddle and covers building the network structure and defining the model. It is validated on the Cifar10 dataset; because it incorporates the strengths of convolution, it outperforms DeiT when training samples are scarce. Pretrained weights are also provided, with corresponding accuracies for each architecture on the ImageNet validation set.
In this paper, we take a new step towards bridging the gap between CNNs and Transformers, by presenting a new method to "softly" introduce a convolutional inductive bias into the ViT.
paper:https://arxiv.org/abs/2103.10697
code:https://github.com/facebookresearch/convit
Hi guys, we meet again. This time we reproduce ConViT; its performance is summarized below.
Convolutional neural networks come with an inductive bias that makes training sample-efficient, but the downside is a lower performance ceiling: when the dataset is small, CNNs outperform ViT, while with abundant data ViT outperforms CNNs. Motivated by this, the paper proposes the GPSA (gated positional self-attention) module, which brings the CNN-style inductive bias into ViT and achieves better performance than DeiT on ImageNet.
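Concretely, each GPSA head blends a content-based attention map with a position-based one through a learnable gate. The toy snippet below (illustrative shapes only, not part of the original notebook) sketches the gating that the GPSA module implemented later actually performs:

import paddle
import paddle.nn.functional as F

# Toy example: 1 head, 4 patches, head_dim = 8 (purely illustrative).
q = paddle.rand([1, 4, 8])
k = paddle.rand([1, 4, 8])
pos_score = paddle.rand([1, 4, 4])      # would come from relative patch positions
gating = paddle.zeros([1])              # learnable gate (lambda), one per head

patch_score = F.softmax((q @ k.transpose([0, 2, 1])) * 8 ** -0.5, axis=-1)
pos_score = F.softmax(pos_score, axis=-1)
gate = F.sigmoid(gating)
attn = (1. - gate) * patch_score + gate * pos_score   # gated mix of content and position
attn = attn / attn.sum(axis=-1, keepdim=True)         # renormalize, as in GPSA below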
The network architecture is shown below.
import paddle
import paddle.nn as nn
import paddle.nn.functional as F
from functools import partial
import numpy as np
zeros_ = nn.initializer.Constant(value=0.)
ones_ = nn.initializer.Constant(value=1.)
trunc_normal_ = nn.initializer.TruncatedNormal(std=.02)


def to_2tuple(x):
    return tuple([x] * 2)


def drop_path(x, drop_prob=0., training=False):
    if drop_prob == 0. or not training:
        return x
    keep_prob = 1 - drop_prob
    shape = (x.shape[0],) + (1,) * (x.ndim - 1)
    random_tensor = paddle.to_tensor(keep_prob) + paddle.rand(shape)
    random_tensor = paddle.floor(random_tensor)
    output = x / keep_prob * random_tensor  # scale so the expectation stays unchanged
    return output


class DropPath(nn.Layer):
    def __init__(self, drop_prob=None):
        super(DropPath, self).__init__()
        self.drop_prob = drop_prob

    def forward(self, x):
        return drop_path(x, self.drop_prob, self.training)


class Identity(nn.Layer):
    def __init__(self, *args, **kwargs):
        super(Identity, self).__init__()

    def forward(self, input):
        return input


class Mlp(nn.Layer):
    def __init__(self, in_features, hidden_features=None, out_features=None,
                 act_layer=nn.GELU, drop=0.):
        super().__init__()
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features
        self.fc1 = nn.Linear(in_features, hidden_features)
        self.act = act_layer()
        self.fc2 = nn.Linear(hidden_features, out_features)
        self.drop = nn.Dropout(drop)

    def forward(self, x):
        x = self.fc1(x)
        x = self.act(x)
        x = self.drop(x)
        x = self.fc2(x)
        x = self.drop(x)
        return x


class PatchEmbed(nn.Layer):
    def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768,
                 norm_layer=None, flatten=True):
        super().__init__()
        img_size = to_2tuple(img_size)
        patch_size = to_2tuple(patch_size)
        self.img_size = img_size
        self.patch_size = patch_size
        self.grid_size = (img_size[0] // patch_size[0], img_size[1] // patch_size[1])
        self.num_patches = self.grid_size[0] * self.grid_size[1]
        self.flatten = flatten
        self.proj = nn.Conv2D(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
        self.norm = norm_layer(embed_dim) if norm_layer else Identity()

    def forward(self, x):
        B, C, H, W = x.shape
        assert H == self.img_size[0] and W == self.img_size[1], \
            f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."
        x = self.proj(x)
        if self.flatten:
            x = x.flatten(2).transpose((0, 2, 1))  # BCHW -> BNC
        x = self.norm(x)
        return x


class HybridEmbed(nn.Layer):
    def __init__(self, backbone, img_size=224, patch_size=1, feature_size=None,
                 in_chans=3, embed_dim=768):
        super().__init__()
        assert isinstance(backbone, nn.Layer)
        img_size = to_2tuple(img_size)
        patch_size = to_2tuple(patch_size)
        self.img_size = img_size
        self.patch_size = patch_size
        self.backbone = backbone
        if feature_size is None:
            # infer the backbone's output spatial size and channels with a dry run
            with paddle.no_grad():
                training = backbone.training
                if training:
                    backbone.eval()
                o = self.backbone(paddle.zeros([1, in_chans, img_size[0], img_size[1]]))
                if isinstance(o, (list, tuple)):
                    o = o[-1]
                feature_size = o.shape[-2:]
                feature_dim = o.shape[1]
                if training:
                    backbone.train()
        else:
            feature_size = to_2tuple(feature_size)
            if hasattr(self.backbone, 'feature_info'):
                feature_dim = self.backbone.feature_info.channels()[-1]
            else:
                feature_dim = self.backbone.num_features
        assert feature_size[0] % patch_size[0] == 0 and feature_size[1] % patch_size[1] == 0
        self.num_patches = feature_size[0] // patch_size[0] * feature_size[1] // patch_size[1]
        self.proj = nn.Conv2D(feature_dim, embed_dim, kernel_size=patch_size, stride=patch_size)

    def forward(self, x):
        x = self.backbone(x)
        if isinstance(x, (list, tuple)):
            x = x[-1]
        x = self.proj(x).flatten(2).transpose([0, 2, 1])
        return x


def repeat(x, rep):
    return paddle.to_tensor(np.tile(x.numpy(), rep))


def repeat_interleave(x, rep, axis):
    return paddle.to_tensor(np.repeat(x.numpy(), rep, axis=axis))


def einsum(equation, distances, attn_map):
    d = distances.numpy()
    a = attn_map.numpy()
    out = np.einsum(equation, d, a)
    return paddle.to_tensor(out)
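As a quick sanity check of the patch embedding (a hypothetical snippet, not from the original notebook), a 224x224 image is split into 14x14 = 196 patch tokens:

# Hypothetical check: 224 / 16 = 14, so 14 * 14 = 196 tokens of dim 768.
embed = PatchEmbed(img_size=224, patch_size=16, in_chans=3, embed_dim=768)
tokens = embed(paddle.randn([1, 3, 224, 224]))
print(tokens.shape)  # [1, 196, 768]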
GPSA
class GPSA(nn.Layer):
    def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0.,
                 proj_drop=0., locality_strength=1., use_local_init=True):
        super().__init__()
        self.num_heads = num_heads
        self.dim = dim
        head_dim = dim // num_heads
        self.scale = qk_scale or head_dim ** -0.5

        self.qk = nn.Linear(dim, dim * 2, bias_attr=qkv_bias)
        self.v = nn.Linear(dim, dim, bias_attr=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.pos_proj = nn.Linear(3, num_heads)
        self.proj_drop = nn.Dropout(proj_drop)
        self.locality_strength = locality_strength
        self.gating_param = self.create_parameter(shape=[self.num_heads], default_initializer=ones_)
        self.add_parameter("gating_param", self.gating_param)

    def forward(self, x):
        B, N, C = x.shape
        if not hasattr(self, 'rel_indices') or self.rel_indices.shape[1] != N:
            self.get_rel_indices(N)
        attn = self.get_attention(x)
        v = self.v(x).reshape([B, N, self.num_heads, C // self.num_heads]).transpose([0, 2, 1, 3])
        x = (attn @ v).transpose([0, 2, 1, 3])
        x = x.reshape([B, N, C])
        x = self.proj(x)
        x = self.proj_drop(x)
        return x

    def get_attention(self, x):
        B, N, C = x.shape
        qk = self.qk(x).reshape([B, N, 2, self.num_heads, C // self.num_heads]).transpose([2, 0, 3, 1, 4])
        q, k = qk[0], qk[1]
        pos_score = self.rel_indices.expand([B, -1, -1, -1])
        pos_score = self.pos_proj(pos_score).transpose([0, 3, 1, 2])
        patch_score = (q @ k.transpose([0, 1, 3, 2])) * self.scale
        patch_score = F.softmax(patch_score, axis=-1)
        pos_score = F.softmax(pos_score, axis=-1)

        # gated mix of content-based and position-based attention
        gating = self.gating_param.reshape([1, -1, 1, 1])
        attn = (1. - F.sigmoid(gating)) * patch_score + F.sigmoid(gating) * pos_score
        attn /= attn.sum(axis=-1).unsqueeze(-1)
        attn = self.attn_drop(attn)
        return attn

    def get_attention_map(self, x, return_map=False):
        attn_map = self.get_attention(x).mean(0)
        distances = self.rel_indices.squeeze()[:, :, -1] ** .5
        dist = einsum('nm,hnm->h', distances, attn_map)  # attention-weighted average distance
        dist /= distances.shape[0]
        if return_map:
            return dist, attn_map
        else:
            return dist

    def get_rel_indices(self, num_patches):
        img_size = int(num_patches ** .5)
        rel_indices = paddle.zeros([1, num_patches, num_patches, 3])
        ind = paddle.arange(img_size).reshape([1, -1]) - paddle.arange(img_size).reshape([-1, 1])
        indx = repeat(ind, [img_size, img_size])
        indy = repeat_interleave(ind, img_size, axis=0)
        indy = repeat_interleave(indy, img_size, axis=1)
        indd = indx ** 2 + indy ** 2
        rel_indices[:, :, :, 2] = indd.unsqueeze(0)
        rel_indices[:, :, :, 1] = indy.unsqueeze(0)
        rel_indices[:, :, :, 0] = indx.unsqueeze(0)
        self.rel_indices = rel_indices

    def local_init(self):
        self.v.weight.set_value(paddle.eye(self.dim))
        locality_distance = 1  # max(1, 1/locality_strength**.5)
        kernel_size = int(self.num_heads ** .5)
        center = (kernel_size - 1) / 2 if kernel_size % 2 == 0 else kernel_size // 2
        for h1 in range(kernel_size):
            for h2 in range(kernel_size):
                position = h1 + kernel_size * h2
                self.pos_proj.weight[2, position] = -1
                self.pos_proj.weight[1, position] = 2 * (h1 - center) * locality_distance
                self.pos_proj.weight[0, position] = 2 * (h2 - center) * locality_distance
        self.pos_proj.weight.set_value(self.pos_proj.weight * self.locality_strength)


class MHSA(nn.Layer):
    def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0.):
        super().__init__()
        self.num_heads = num_heads
        head_dim = dim // num_heads
        self.scale = qk_scale or head_dim ** -0.5
        self.qkv = nn.Linear(dim, dim * 3, bias_attr=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

    def get_attention_map(self, x, return_map=False):
        B, N, C = x.shape
        qkv = self.qkv(x).reshape([B, N, 3, self.num_heads, C // self.num_heads]).transpose([2, 0, 3, 1, 4])
        q, k, v = qkv[0], qkv[1], qkv[2]
        attn_map = (q @ k.transpose([0, 1, 3, 2])) * self.scale
        attn_map = F.softmax(attn_map, axis=-1).mean(0)

        img_size = int(N ** .5)
        ind = paddle.arange(img_size).reshape([1, -1]) - paddle.arange(img_size).reshape([-1, 1])
        indx = repeat(ind, [img_size, img_size])
        indy = repeat_interleave(ind, img_size, axis=0)
        indy = repeat_interleave(indy, img_size, axis=1)
        indd = indx ** 2 + indy ** 2
        distances = indd ** .5
        dist = einsum('nm,hnm->h', distances, attn_map)
        dist /= N
        if return_map:
            return dist, attn_map
        else:
            return dist

    def forward(self, x):
        B, N, C = x.shape
        qkv = self.qkv(x).reshape([B, N, 3, self.num_heads, C // self.num_heads]).transpose([2, 0, 3, 1, 4])
        q, k, v = qkv[0], qkv[1], qkv[2]
        attn = (q @ k.transpose([0, 1, 3, 2])) * self.scale
        attn = F.softmax(attn, axis=-1)
        attn = self.attn_drop(attn)
        x = (attn @ v).transpose([0, 2, 1, 3]).reshape([B, N, C])
        x = self.proj(x)
        x = self.proj_drop(x)
        return x


class Block(nn.Layer):
    def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0.,
                 attn_drop=0., drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm,
                 use_gpsa=True, **kwargs):
        super().__init__()
        self.norm1 = norm_layer(dim)
        self.use_gpsa = use_gpsa
        if self.use_gpsa:
            self.attn = GPSA(dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale,
                             attn_drop=attn_drop, proj_drop=drop, **kwargs)
        else:
            self.attn = MHSA(dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale,
                             attn_drop=attn_drop, proj_drop=drop, **kwargs)
        self.drop_path = DropPath(drop_path) if drop_path > 0. else Identity()
        self.norm2 = norm_layer(dim)
        mlp_hidden_dim = int(dim * mlp_ratio)
        self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)

    def forward(self, x):
        x = x + self.drop_path(self.attn(self.norm1(x)))
        x = x + self.drop_path(self.mlp(self.norm2(x)))
        return x


class VisionTransformer(nn.Layer):
    """ Vision Transformer with support for patch or hybrid CNN input stage """

    def __init__(self, img_size=224, patch_size=16, in_chans=3, num_classes=1000, embed_dim=48,
                 depth=12, num_heads=12, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop_rate=0.,
                 attn_drop_rate=0., drop_path_rate=0., hybrid_backbone=None, norm_layer=nn.LayerNorm,
                 global_pool=None, local_up_to_layer=10, locality_strength=1., use_pos_embed=True):
        super().__init__()
        embed_dim *= num_heads
        self.num_classes = num_classes
        self.local_up_to_layer = local_up_to_layer
        self.num_features = self.embed_dim = embed_dim
        self.use_pos_embed = use_pos_embed

        if hybrid_backbone is not None:
            self.patch_embed = HybridEmbed(
                hybrid_backbone, img_size=img_size, in_chans=in_chans, embed_dim=embed_dim)
        else:
            self.patch_embed = PatchEmbed(
                img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim)
        num_patches = self.patch_embed.num_patches
        self.num_patches = num_patches

        self.cls_token = self.create_parameter(
            shape=[1, 1, embed_dim],
            default_initializer=nn.initializer.TruncatedNormal(mean=0.0, std=.02))
        self.add_parameter("cls_token", self.cls_token)
        self.pos_drop = nn.Dropout(p=drop_rate)
        if self.use_pos_embed:
            self.pos_embed = self.create_parameter(
                shape=[1, num_patches, embed_dim],
                default_initializer=nn.initializer.TruncatedNormal(mean=0.0, std=.02))
            self.add_parameter("pos_embed", self.pos_embed)

        dpr = [x.item() for x in paddle.linspace(0, drop_path_rate, depth)]  # stochastic depth decay rule
        self.blocks = nn.LayerList([
            Block(
                dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias,
                qk_scale=qk_scale, drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i],
                norm_layer=norm_layer, use_gpsa=True, locality_strength=locality_strength)
            if i < local_up_to_layer else
            Block(
                dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias,
                qk_scale=qk_scale, drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i],
                norm_layer=norm_layer, use_gpsa=False)
            for i in range(depth)])
        self.norm = norm_layer(embed_dim)

        # classifier head
        self.head = nn.Linear(embed_dim, num_classes) if num_classes > 0 else Identity()

        self.apply(self._init_weights)
        for n, m in self.named_sublayers():
            if hasattr(m, 'local_init'):
                m.local_init()

    def _init_weights(self, m):
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight)
            if isinstance(m, nn.Linear) and m.bias is not None:
                zeros_(m.bias)
        elif isinstance(m, nn.LayerNorm):
            zeros_(m.bias)
            ones_(m.weight)

    def forward_features(self, x):
        B = x.shape[0]
        x = self.patch_embed(x)
        cls_tokens = self.cls_token.expand([B, -1, -1])
        if self.use_pos_embed:
            x = x + self.pos_embed
        x = self.pos_drop(x)
        for u, blk in enumerate(self.blocks):
            # the class token joins the sequence only once the GPSA stage is over
            if u == self.local_up_to_layer:
                x = paddle.concat((cls_tokens, x), axis=1)
            x = blk(x)
        x = self.norm(x)
        return x[:, 0]

    def forward(self, x):
        x = self.forward_features(x)
        x = self.head(x)
        return x
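Note the hybrid layout: the first local_up_to_layer blocks run GPSA over the patch tokens only, and the class token is concatenated just before the remaining MHSA blocks. A quick smoke test of the class (hypothetical, not from the original notebook):

# Hypothetical smoke test: num_heads=4 gives embed_dim = 48 * 4 = 192.
net = VisionTransformer(num_heads=4, num_classes=10)
logits = net(paddle.randn([2, 3, 224, 224]))
print(logits.shape)  # [2, 10]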
def convit_tiny(**kwargs):
    model = VisionTransformer(
        num_heads=4, norm_layer=partial(nn.LayerNorm, epsilon=1e-6), **kwargs)
    return model


def convit_small(**kwargs):
    model = VisionTransformer(
        num_heads=9, norm_layer=partial(nn.LayerNorm, epsilon=1e-6), **kwargs)
    return model


def convit_base(**kwargs):
    model = VisionTransformer(
        num_heads=16, norm_layer=partial(nn.LayerNorm, epsilon=1e-6), **kwargs)
    return model
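Reading the constructor above, embed_dim is 48 * num_heads, so the three factories differ only in width (a short illustrative note, not from the original notebook):

# convit_tiny  -> 4 heads  -> embed_dim 192
# convit_small -> 9 heads  -> embed_dim 432
# convit_base  -> 16 heads -> embed_dim 768
net = convit_tiny(num_classes=10)   # e.g. a 10-way head for Cifar10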
paddle.Model(convit_base()).summary((1, 3, 224, 224))
--------------------------------------------------------------------------- Layer (type) Input Shape Output Shape Param # =========================================================================== Conv2D-1 [[1, 3, 224, 224]] [1, 768, 14, 14] 590,592 Identity-1 [[1, 196, 768]] [1, 196, 768] 0 PatchEmbed-1 [[1, 3, 224, 224]] [1, 196, 768] 0 Dropout-1 [[1, 196, 768]] [1, 196, 768] 0 LayerNorm-1 [[1, 196, 768]] [1, 196, 768] 1,536 Linear-1 [[1, 196, 768]] [1, 196, 1536] 1,179,648 Linear-4 [[1, 196, 196, 3]] [1, 196, 196, 16] 64 Dropout-2 [[1, 16, 196, 196]] [1, 16, 196, 196] 0 Linear-2 [[1, 196, 768]] [1, 196, 768] 589,824 Linear-3 [[1, 196, 768]] [1, 196, 768] 590,592 Dropout-3 [[1, 196, 768]] [1, 196, 768] 0 GPSA-1 [[1, 196, 768]] [1, 196, 768] 16 Identity-2 [[1, 196, 768]] [1, 196, 768] 0 LayerNorm-2 [[1, 196, 768]] [1, 196, 768] 1,536 Linear-5 [[1, 196, 768]] [1, 196, 3072] 2,362,368 GELU-1 [[1, 196, 3072]] [1, 196, 3072] 0 Dropout-4 [[1, 196, 768]] [1, 196, 768] 0 Linear-6 [[1, 196, 3072]] [1, 196, 768] 2,360,064 Mlp-1 [[1, 196, 768]] [1, 196, 768] 0 Block-1 [[1, 196, 768]] [1, 196, 768] 0 LayerNorm-3 [[1, 196, 768]] [1, 196, 768] 1,536 Linear-7 [[1, 196, 768]] [1, 196, 1536] 1,179,648 Linear-10 [[1, 196, 196, 3]] [1, 196, 196, 16] 64 Dropout-5 [[1, 16, 196, 196]] [1, 16, 196, 196] 0 Linear-8 [[1, 196, 768]] [1, 196, 768] 589,824 Linear-9 [[1, 196, 768]] [1, 196, 768] 590,592 Dropout-6 [[1, 196, 768]] [1, 196, 768] 0 GPSA-2 [[1, 196, 768]] [1, 196, 768] 16 Identity-3 [[1, 196, 768]] [1, 196, 768] 0 LayerNorm-4 [[1, 196, 768]] [1, 196, 768] 1,536 Linear-11 [[1, 196, 768]] [1, 196, 3072] 2,362,368 GELU-2 [[1, 196, 3072]] [1, 196, 3072] 0 Dropout-7 [[1, 196, 768]] [1, 196, 768] 0 Linear-12 [[1, 196, 3072]] [1, 196, 768] 2,360,064 Mlp-2 [[1, 196, 768]] [1, 196, 768] 0 Block-2 [[1, 196, 768]] [1, 196, 768] 0 LayerNorm-5 [[1, 196, 768]] [1, 196, 768] 1,536 Linear-13 [[1, 196, 768]] [1, 196, 1536] 1,179,648 Linear-16 [[1, 196, 196, 3]] [1, 196, 196, 16] 64 Dropout-8 [[1, 16, 196, 196]] [1, 16, 196, 196] 0 Linear-14 [[1, 196, 768]] [1, 196, 768] 589,824 Linear-15 [[1, 196, 768]] [1, 196, 768] 590,592 Dropout-9 [[1, 196, 768]] [1, 196, 768] 0 GPSA-3 [[1, 196, 768]] [1, 196, 768] 16 Identity-4 [[1, 196, 768]] [1, 196, 768] 0 LayerNorm-6 [[1, 196, 768]] [1, 196, 768] 1,536 Linear-17 [[1, 196, 768]] [1, 196, 3072] 2,362,368 GELU-3 [[1, 196, 3072]] [1, 196, 3072] 0 Dropout-10 [[1, 196, 768]] [1, 196, 768] 0 Linear-18 [[1, 196, 3072]] [1, 196, 768] 2,360,064 Mlp-3 [[1, 196, 768]] [1, 196, 768] 0 Block-3 [[1, 196, 768]] [1, 196, 768] 0 LayerNorm-7 [[1, 196, 768]] [1, 196, 768] 1,536 Linear-19 [[1, 196, 768]] [1, 196, 1536] 1,179,648 Linear-22 [[1, 196, 196, 3]] [1, 196, 196, 16] 64 Dropout-11 [[1, 16, 196, 196]] [1, 16, 196, 196] 0 Linear-20 [[1, 196, 768]] [1, 196, 768] 589,824 Linear-21 [[1, 196, 768]] [1, 196, 768] 590,592 Dropout-12 [[1, 196, 768]] [1, 196, 768] 0 GPSA-4 [[1, 196, 768]] [1, 196, 768] 16 Identity-5 [[1, 196, 768]] [1, 196, 768] 0 LayerNorm-8 [[1, 196, 768]] [1, 196, 768] 1,536 Linear-23 [[1, 196, 768]] [1, 196, 3072] 2,362,368 GELU-4 [[1, 196, 3072]] [1, 196, 3072] 0 Dropout-13 [[1, 196, 768]] [1, 196, 768] 0 Linear-24 [[1, 196, 3072]] [1, 196, 768] 2,360,064 Mlp-4 [[1, 196, 768]] [1, 196, 768] 0 Block-4 [[1, 196, 768]] [1, 196, 768] 0 LayerNorm-9 [[1, 196, 768]] [1, 196, 768] 1,536 Linear-25 [[1, 196, 768]] [1, 196, 1536] 1,179,648 Linear-28 [[1, 196, 196, 3]] [1, 196, 196, 16] 64 Dropout-14 [[1, 16, 196, 196]] [1, 16, 196, 196] 0 Linear-26 [[1, 196, 768]] [1, 196, 768] 589,824 
Linear-27 [[1, 196, 768]] [1, 196, 768] 590,592 Dropout-15 [[1, 196, 768]] [1, 196, 768] 0 GPSA-5 [[1, 196, 768]] [1, 196, 768] 16 Identity-6 [[1, 196, 768]] [1, 196, 768] 0 LayerNorm-10 [[1, 196, 768]] [1, 196, 768] 1,536 Linear-29 [[1, 196, 768]] [1, 196, 3072] 2,362,368 GELU-5 [[1, 196, 3072]] [1, 196, 3072] 0 Dropout-16 [[1, 196, 768]] [1, 196, 768] 0 Linear-30 [[1, 196, 3072]] [1, 196, 768] 2,360,064 Mlp-5 [[1, 196, 768]] [1, 196, 768] 0 Block-5 [[1, 196, 768]] [1, 196, 768] 0 LayerNorm-11 [[1, 196, 768]] [1, 196, 768] 1,536 Linear-31 [[1, 196, 768]] [1, 196, 1536] 1,179,648 Linear-34 [[1, 196, 196, 3]] [1, 196, 196, 16] 64 Dropout-17 [[1, 16, 196, 196]] [1, 16, 196, 196] 0 Linear-32 [[1, 196, 768]] [1, 196, 768] 589,824 Linear-33 [[1, 196, 768]] [1, 196, 768] 590,592 Dropout-18 [[1, 196, 768]] [1, 196, 768] 0 GPSA-6 [[1, 196, 768]] [1, 196, 768] 16 Identity-7 [[1, 196, 768]] [1, 196, 768] 0 LayerNorm-12 [[1, 196, 768]] [1, 196, 768] 1,536 Linear-35 [[1, 196, 768]] [1, 196, 3072] 2,362,368 GELU-6 [[1, 196, 3072]] [1, 196, 3072] 0 Dropout-19 [[1, 196, 768]] [1, 196, 768] 0 Linear-36 [[1, 196, 3072]] [1, 196, 768] 2,360,064 Mlp-6 [[1, 196, 768]] [1, 196, 768] 0 Block-6 [[1, 196, 768]] [1, 196, 768] 0 LayerNorm-13 [[1, 196, 768]] [1, 196, 768] 1,536 Linear-37 [[1, 196, 768]] [1, 196, 1536] 1,179,648 Linear-40 [[1, 196, 196, 3]] [1, 196, 196, 16] 64 Dropout-20 [[1, 16, 196, 196]] [1, 16, 196, 196] 0 Linear-38 [[1, 196, 768]] [1, 196, 768] 589,824 Linear-39 [[1, 196, 768]] [1, 196, 768] 590,592 Dropout-21 [[1, 196, 768]] [1, 196, 768] 0 GPSA-7 [[1, 196, 768]] [1, 196, 768] 16 Identity-8 [[1, 196, 768]] [1, 196, 768] 0 LayerNorm-14 [[1, 196, 768]] [1, 196, 768] 1,536 Linear-41 [[1, 196, 768]] [1, 196, 3072] 2,362,368 GELU-7 [[1, 196, 3072]] [1, 196, 3072] 0 Dropout-22 [[1, 196, 768]] [1, 196, 768] 0 Linear-42 [[1, 196, 3072]] [1, 196, 768] 2,360,064 Mlp-7 [[1, 196, 768]] [1, 196, 768] 0 Block-7 [[1, 196, 768]] [1, 196, 768] 0 LayerNorm-15 [[1, 196, 768]] [1, 196, 768] 1,536 Linear-43 [[1, 196, 768]] [1, 196, 1536] 1,179,648 Linear-46 [[1, 196, 196, 3]] [1, 196, 196, 16] 64 Dropout-23 [[1, 16, 196, 196]] [1, 16, 196, 196] 0 Linear-44 [[1, 196, 768]] [1, 196, 768] 589,824 Linear-45 [[1, 196, 768]] [1, 196, 768] 590,592 Dropout-24 [[1, 196, 768]] [1, 196, 768] 0 GPSA-8 [[1, 196, 768]] [1, 196, 768] 16 Identity-9 [[1, 196, 768]] [1, 196, 768] 0 LayerNorm-16 [[1, 196, 768]] [1, 196, 768] 1,536 Linear-47 [[1, 196, 768]] [1, 196, 3072] 2,362,368 GELU-8 [[1, 196, 3072]] [1, 196, 3072] 0 Dropout-25 [[1, 196, 768]] [1, 196, 768] 0 Linear-48 [[1, 196, 3072]] [1, 196, 768] 2,360,064 Mlp-8 [[1, 196, 768]] [1, 196, 768] 0 Block-8 [[1, 196, 768]] [1, 196, 768] 0 LayerNorm-17 [[1, 196, 768]] [1, 196, 768] 1,536 Linear-49 [[1, 196, 768]] [1, 196, 1536] 1,179,648 Linear-52 [[1, 196, 196, 3]] [1, 196, 196, 16] 64 Dropout-26 [[1, 16, 196, 196]] [1, 16, 196, 196] 0 Linear-50 [[1, 196, 768]] [1, 196, 768] 589,824 Linear-51 [[1, 196, 768]] [1, 196, 768] 590,592 Dropout-27 [[1, 196, 768]] [1, 196, 768] 0 GPSA-9 [[1, 196, 768]] [1, 196, 768] 16 Identity-10 [[1, 196, 768]] [1, 196, 768] 0 LayerNorm-18 [[1, 196, 768]] [1, 196, 768] 1,536 Linear-53 [[1, 196, 768]] [1, 196, 3072] 2,362,368 GELU-9 [[1, 196, 3072]] [1, 196, 3072] 0 Dropout-28 [[1, 196, 768]] [1, 196, 768] 0 Linear-54 [[1, 196, 3072]] [1, 196, 768] 2,360,064 Mlp-9 [[1, 196, 768]] [1, 196, 768] 0 Block-9 [[1, 196, 768]] [1, 196, 768] 0 LayerNorm-19 [[1, 196, 768]] [1, 196, 768] 1,536 Linear-55 [[1, 196, 768]] [1, 196, 1536] 1,179,648 Linear-58 [[1, 196, 
196, 3]] [1, 196, 196, 16] 64 Dropout-29 [[1, 16, 196, 196]] [1, 16, 196, 196] 0 Linear-56 [[1, 196, 768]] [1, 196, 768] 589,824 Linear-57 [[1, 196, 768]] [1, 196, 768] 590,592 Dropout-30 [[1, 196, 768]] [1, 196, 768] 0 GPSA-10 [[1, 196, 768]] [1, 196, 768] 16 Identity-11 [[1, 196, 768]] [1, 196, 768] 0 LayerNorm-20 [[1, 196, 768]] [1, 196, 768] 1,536 Linear-59 [[1, 196, 768]] [1, 196, 3072] 2,362,368 GELU-10 [[1, 196, 3072]] [1, 196, 3072] 0 Dropout-31 [[1, 196, 768]] [1, 196, 768] 0 Linear-60 [[1, 196, 3072]] [1, 196, 768] 2,360,064 Mlp-10 [[1, 196, 768]] [1, 196, 768] 0 Block-10 [[1, 196, 768]] [1, 196, 768] 0 LayerNorm-21 [[1, 197, 768]] [1, 197, 768] 1,536 Linear-61 [[1, 197, 768]] [1, 197, 2304] 1,769,472 Dropout-32 [[1, 16, 197, 197]] [1, 16, 197, 197] 0 Linear-62 [[1, 197, 768]] [1, 197, 768] 590,592 Dropout-33 [[1, 197, 768]] [1, 197, 768] 0 MHSA-1 [[1, 197, 768]] [1, 197, 768] 0 Identity-12 [[1, 197, 768]] [1, 197, 768] 0 LayerNorm-22 [[1, 197, 768]] [1, 197, 768] 1,536 Linear-63 [[1, 197, 768]] [1, 197, 3072] 2,362,368 GELU-11 [[1, 197, 3072]] [1, 197, 3072] 0 Dropout-34 [[1, 197, 768]] [1, 197, 768] 0 Linear-64 [[1, 197, 3072]] [1, 197, 768] 2,360,064 Mlp-11 [[1, 197, 768]] [1, 197, 768] 0 Block-11 [[1, 197, 768]] [1, 197, 768] 0 LayerNorm-23 [[1, 197, 768]] [1, 197, 768] 1,536 Linear-65 [[1, 197, 768]] [1, 197, 2304] 1,769,472 Dropout-35 [[1, 16, 197, 197]] [1, 16, 197, 197] 0 Linear-66 [[1, 197, 768]] [1, 197, 768] 590,592 Dropout-36 [[1, 197, 768]] [1, 197, 768] 0 MHSA-2 [[1, 197, 768]] [1, 197, 768] 0 Identity-13 [[1, 197, 768]] [1, 197, 768] 0 LayerNorm-24 [[1, 197, 768]] [1, 197, 768] 1,536 Linear-67 [[1, 197, 768]] [1, 197, 3072] 2,362,368 GELU-12 [[1, 197, 3072]] [1, 197, 3072] 0 Dropout-37 [[1, 197, 768]] [1, 197, 768] 0 Linear-68 [[1, 197, 3072]] [1, 197, 768] 2,360,064 Mlp-12 [[1, 197, 768]] [1, 197, 768] 0 Block-12 [[1, 197, 768]] [1, 197, 768] 0 LayerNorm-25 [[1, 197, 768]] [1, 197, 768] 1,536 Linear-69 [[1, 768]] [1, 1000] 769,000 ===========================================================================Total params: 86,388,744Trainable params: 86,388,744Non-trainable params: 0---------------------------------------------------------------------------Input size (MB): 0.57Forward/backward pass size (MB): 398.67Params size (MB): 329.55Estimated Total Size (MB): 728.79---------------------------------------------------------------------------登录后复制
{'total_params': 86388744, 'trainable_params': 86388744}
We use the Cifar10 dataset with no heavy data augmentation.
import paddle.vision.transforms as T
from paddle.vision.datasets import Cifar10

paddle.set_device('gpu')

# Data preparation
transform = T.Compose([
    T.Resize(size=(224, 224)),
    T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225], data_format='HWC'),
    T.ToTensor()
])
train_dataset = Cifar10(mode='train', transform=transform)
val_dataset = Cifar10(mode='test', transform=transform)
model = paddle.Model(convit_small(num_classes=10))
Due to time constraints, we only train for 6 epochs; interested readers can continue training.
model.prepare(optimizer=paddle.optimizer.Adam(learning_rate=0.001, parameters=model.parameters()),
              loss=paddle.nn.CrossEntropyLoss(),
              metrics=paddle.metric.Accuracy())
visualdl = paddle.callbacks.VisualDL(log_dir='visual_log')  # enable training visualization
model.fit(
    train_data=train_dataset,
    eval_data=val_dataset,
    batch_size=64,
    epochs=6,
    verbose=1,
    callbacks=[visualdl]
)
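To keep training beyond 6 epochs or reuse the result later, a minimal follow-up with the high-level Model API could look like this (a sketch; the checkpoint name is hypothetical):

# Evaluate on the test split and save the weights so a later run can resume.
model.evaluate(val_dataset, batch_size=64, verbose=1)
model.save('convit_cifar10')    # writes convit_cifar10.pdparams / .pdopt
# later: model.load('convit_cifar10') before calling model.fit(...) again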
This project also provides pretrained weights for each architecture, evaluated on the ImageNet validation set.
# convit tiny
model = convit_tiny()
model.set_state_dict(paddle.load('data/data93780/convit_tiny.pdparams'))

# convit small
model = convit_small()
model.set_state_dict(paddle.load('data/data93780/convit_small.pdparams'))

# convit base
model = convit_base()
model.set_state_dict(paddle.load('data/data93780/convit_base.pdparams'))
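To fine-tune one of these ImageNet checkpoints on Cifar10 instead of training from scratch, one option (a sketch, not from the original notebook) is to drop the 1000-way classifier head before loading:

# Hypothetical fine-tuning setup: keep the backbone weights, re-initialize the head.
net = convit_small(num_classes=10)
state = paddle.load('data/data93780/convit_small.pdparams')
state = {k: v for k, v in state.items() if not k.startswith('head.')}  # skip the 1000-class head
net.set_state_dict(state)
model = paddle.Model(net)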
Experiments show that, thanks to the added CNN inductive bias, ConViT outperforms DeiT when training samples are scarce.
When data is insufficient, CNNs with their inductive bias outperform ViT; when data is abundant, ViT outperforms CNNs.
ConViT combines the advantages of the convolutional inductive bias, but the difficulty of training from scratch still remains.