
Pet Image Segmentation with Attention U-Net

热心网友 · Reposted · 2025-07-22
This article implements an Attention U-Net model for pet image segmentation, following the paper "Attention U-Net: Learning Where to Look for the Pancreas". It splits the dataset, builds a network with attention gates, trains it with the RMSProp optimizer and a cross-entropy loss for 15 epochs, and then runs prediction on the test set; the results demonstrate the model's segmentation quality on pet images and confirm its effectiveness.



Paper: Attention U-Net: Learning Where to Look for the Pancreas


Introduction

This work was the first to use soft attention inside a CNN for medical images; the module can replace hard attention in classification tasks and the external localization modules used for organ localization. Attention U-Net introduces an attention gate (AG) model for medical imaging that automatically learns to focus on target structures of varying shape and size. It implicitly learns to suppress irrelevant regions of the input image while highlighting salient features useful for the task at hand. The attention modules add only a small computational overhead while improving the model's sensitivity and prediction accuracy.

Results

(Figure: example segmentation results on pet images)

Model Architecture

(Figure: Attention U-Net architecture diagram)

The Attention Gate Module

Attention here means directing focus onto the target region; in simple terms, the values inside the target region are amplified. The attention modules sit on the skip connections: the original U-Net simply concatenates the same-level encoder features directly onto the decoder features, whereas the modified network first passes the same-level encoder feature map and the decoder feature map from the level above through an attention gate, and only then concatenates the result with the upsampled features.
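Written out (a transcription of the paper's additive-attention formulation, with $\sigma_1$ = ReLU and $\sigma_2$ the sigmoid, matching the `Attention_block` code below), the gate computes:

```latex
q_{att} = \psi^{T}\!\left(\sigma_1\!\left(W_g^{T} g + W_x^{T} x + b_g\right)\right) + b_{\psi},
\qquad
\alpha = \sigma_2\!\left(q_{att}\right),
\qquad
\hat{x} = \alpha \cdot x
```

where $g$ is the gating signal from the coarser decoder level, $x$ the skip-connection features, and $\alpha$ the per-pixel attention coefficients that rescale $x$.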

(Figure: attention gate structure)

Environment Setup

In [1]
import os
import io
import numpy as np
import matplotlib.pyplot as plt
from PIL import Image as PilImage
import paddle
import paddle.nn as nn
import paddle.nn.functional as F

paddle.set_device('gpu')
paddle.__version__

'2.1.0'

Data Processing

The data-processing steps here are adapted from part 06 of the 『跟着雨哥学AI』 series: a fun case study of U-Net-based pet image segmentation.

In [2]
# Extract the archives
!tar -xf data/data50154/images.tar.gz
!tar -xf data/data50154/annotations.tar.gz

In [3]
IMAGE_SIZE = (160, 160)
train_images_path = "images/"
label_images_path = "annotations/trimaps/"

image_count = len([os.path.join(train_images_path, image_name)
                   for image_name in os.listdir(train_images_path)
                   if image_name.endswith('.jpg')])
print("Number of images available for training:", image_count)

# Split the dataset into training and test sets
def _sort_images(image_dir, image_type):
    """
    Sort the images in a directory by file name.
    """
    files = []
    for image_name in os.listdir(image_dir):
        if image_name.endswith('.{}'.format(image_type)) \
                and not image_name.startswith('.'):
            files.append(os.path.join(image_dir, image_name))
    return sorted(files)

def write_file(mode, images, labels):
    with open('./{}.txt'.format(mode), 'w') as f:
        for i in range(len(images)):
            f.write('{}\t{}\n'.format(images[i], labels[i]))

images = _sort_images(train_images_path, 'jpg')
labels = _sort_images(label_images_path, 'png')
eval_num = int(image_count * 0.15)

write_file('train', images[:-eval_num], labels[:-eval_num])
write_file('test', images[-eval_num:], labels[-eval_num:])
write_file('predict', images[-eval_num:], labels[-eval_num:])
Number of images available for training: 7390
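As a quick sanity check on the 85/15 split, here is a sketch of the slicing arithmetic, using the image count of 7390 printed above:

```python
# Reproduce the split arithmetic from the cell above.
image_count = 7390
eval_num = int(image_count * 0.15)    # size of the held-out slice (test/predict)
train_num = image_count - eval_num    # images[:-eval_num] -> train.txt

print(train_num, eval_num)  # → 6282 1108
```

The 1108 held-out samples match the "Predict samples: 1108" reported by `model.predict` later in the article.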
In [4]
with open('./train.txt', 'r') as f:
    i = 0
    for line in f.readlines():
        image_path, label_path = line.strip().split('\t')
        image = np.array(PilImage.open(image_path))
        label = np.array(PilImage.open(label_path))

        if i > 2:
            break

        # Display the image alongside its label
        plt.figure()
        plt.subplot(1, 2, 1)
        plt.title('Train Image')
        plt.imshow(image.astype('uint8'))
        plt.axis('off')
        plt.subplot(1, 2, 2)
        plt.title('Label')
        plt.imshow(label.astype('uint8'), cmap='gray')
        plt.axis('off')
        plt.show()
        i = i + 1

Dataset Class Definition

In [5]
import random
from paddle.io import Dataset
from paddle.vision.transforms import transforms as T

class PetDataset(Dataset):
    """
    Dataset definition
    """
    def __init__(self, mode='train'):
        """
        Constructor
        """
        self.image_size = IMAGE_SIZE
        self.mode = mode.lower()

        assert self.mode in ['train', 'test', 'predict'], \
            "mode should be 'train' or 'test' or 'predict', but got {}".format(self.mode)

        self.train_images = []
        self.label_images = []

        with open('./{}.txt'.format(self.mode), 'r') as f:
            for line in f.readlines():
                image, label = line.strip().split('\t')
                self.train_images.append(image)
                self.label_images.append(label)

    def _load_img(self, path, color_mode='rgb', transforms=[]):
        """
        Unified image-loading helper that normalizes image size and channels.
        """
        with open(path, 'rb') as f:
            img = PilImage.open(io.BytesIO(f.read()))
            if color_mode == 'grayscale':
                # if image is not already an 8-bit, 16-bit or 32-bit grayscale image
                # convert it to an 8-bit grayscale image.
                if img.mode not in ('L', 'I;16', 'I'):
                    img = img.convert('L')
            elif color_mode == 'rgba':
                if img.mode != 'RGBA':
                    img = img.convert('RGBA')
            elif color_mode == 'rgb':
                if img.mode != 'RGB':
                    img = img.convert('RGB')
            else:
                raise ValueError('color_mode must be "grayscale", "rgb", or "rgba"')

            return T.Compose([
                T.Resize(self.image_size)
            ] + transforms)(img)

    def __getitem__(self, idx):
        """
        Return an (image, label) pair
        """
        train_image = self._load_img(self.train_images[idx],
                                     transforms=[
                                         T.Transpose(),
                                         T.Normalize(mean=127.5, std=127.5)
                                     ])  # load the input image
        label_image = self._load_img(self.label_images[idx],
                                     color_mode='grayscale',
                                     transforms=[T.Grayscale()])  # load the label image

        train_image = np.array(train_image, dtype='float32')
        label_image = np.array(label_image, dtype='int64')
        return train_image, label_image

    def __len__(self):
        """
        Return the total number of samples
        """
        return len(self.train_images)
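One detail worth noting in `__getitem__`: `T.Normalize(mean=127.5, std=127.5)` maps uint8 pixel values from [0, 255] into [-1, 1]. A minimal sketch of the per-pixel arithmetic:

```python
# Per-pixel rescaling performed by Normalize(mean=127.5, std=127.5)
def normalize(px):
    return (px - 127.5) / 127.5

print(normalize(0), normalize(127.5), normalize(255))  # → -1.0 0.0 1.0
```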

Building the Model

Basic Blocks

In [6]
class conv_block(nn.Layer):
    def __init__(self, ch_in, ch_out):
        super(conv_block, self).__init__()
        self.conv = nn.Sequential(
            nn.Conv2D(ch_in, ch_out, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm(ch_out),
            nn.ReLU(),
            nn.Conv2D(ch_out, ch_out, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm(ch_out),
            nn.ReLU()
        )

    def forward(self, x):
        x = self.conv(x)
        return x

class up_conv(nn.Layer):
    def __init__(self, ch_in, ch_out):
        super(up_conv, self).__init__()
        self.up = nn.Sequential(
            nn.Upsample(scale_factor=2),
            nn.Conv2D(ch_in, ch_out, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm(ch_out),
            nn.ReLU()
        )

    def forward(self, x):
        x = self.up(x)
        return x

class single_conv(nn.Layer):
    def __init__(self, ch_in, ch_out):
        super(single_conv, self).__init__()
        self.conv = nn.Sequential(
            nn.Conv2D(ch_in, ch_out, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm(ch_out),
            nn.ReLU()
        )

    def forward(self, x):
        x = self.conv(x)
        return x

The Attention Block

In [7]
class Attention_block(nn.Layer):
    def __init__(self, F_g, F_l, F_int):
        super(Attention_block, self).__init__()
        self.W_g = nn.Sequential(
            nn.Conv2D(F_g, F_int, kernel_size=1, stride=1, padding=0),
            nn.BatchNorm(F_int)
        )
        self.W_x = nn.Sequential(
            nn.Conv2D(F_l, F_int, kernel_size=1, stride=1, padding=0),
            nn.BatchNorm(F_int)
        )
        self.psi = nn.Sequential(
            nn.Conv2D(F_int, 1, kernel_size=1, stride=1, padding=0),
            nn.BatchNorm(1),
            nn.Sigmoid()
        )
        self.relu = nn.ReLU()

    def forward(self, g, x):
        g1 = self.W_g(g)
        x1 = self.W_x(x)
        psi = self.relu(g1 + x1)
        psi = self.psi(psi)
        return x * psi
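To make the gating concrete, here is a framework-free sketch (toy values, not taken from the model) of what `x * psi` does in `forward`: the sigmoid turns the combined features into per-pixel coefficients in (0, 1), which then scale the skip-connection features.

```python
import math

def sigmoid(v):
    return 1.0 / (1.0 + math.exp(-v))

# Toy 2x2 single-channel map standing in for the ReLU output g1 + x1
pre = [[-4.0, 0.0], [0.0, 4.0]]
# Skip-connection features x that the gate will rescale
x = [[1.0, 2.0], [3.0, 4.0]]

# psi (1x1 conv + sigmoid) yields attention coefficients alpha in (0, 1)
alpha = [[sigmoid(v) for v in row] for row in pre]

# Gate output: element-wise product x * alpha; alpha near 0 suppresses
# a pixel, alpha near 1 passes it through almost unchanged
gated = [[x[i][j] * alpha[i][j] for j in range(2)] for i in range(2)]
```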

Attention U-Net

In [9]
class AttU_Net(nn.Layer):
    def __init__(self, img_ch=3, output_ch=1):
        super(AttU_Net, self).__init__()
        self.Maxpool = nn.MaxPool2D(kernel_size=2, stride=2)
        self.Maxpool1 = nn.MaxPool2D(kernel_size=2, stride=2)
        self.Maxpool2 = nn.MaxPool2D(kernel_size=2, stride=2)
        self.Maxpool3 = nn.MaxPool2D(kernel_size=2, stride=2)

        self.Conv1 = conv_block(ch_in=img_ch, ch_out=64)
        self.Conv2 = conv_block(ch_in=64, ch_out=128)
        self.Conv3 = conv_block(ch_in=128, ch_out=256)
        self.Conv4 = conv_block(ch_in=256, ch_out=512)
        self.Conv5 = conv_block(ch_in=512, ch_out=1024)

        self.Up5 = up_conv(ch_in=1024, ch_out=512)
        self.Att5 = Attention_block(F_g=512, F_l=512, F_int=256)
        self.Up_conv5 = conv_block(ch_in=1024, ch_out=512)

        self.Up4 = up_conv(ch_in=512, ch_out=256)
        self.Att4 = Attention_block(F_g=256, F_l=256, F_int=128)
        self.Up_conv4 = conv_block(ch_in=512, ch_out=256)

        self.Up3 = up_conv(ch_in=256, ch_out=128)
        self.Att3 = Attention_block(F_g=128, F_l=128, F_int=64)
        self.Up_conv3 = conv_block(ch_in=256, ch_out=128)

        self.Up2 = up_conv(ch_in=128, ch_out=64)
        self.Att2 = Attention_block(F_g=64, F_l=64, F_int=32)
        self.Up_conv2 = conv_block(ch_in=128, ch_out=64)

        self.Conv_1x1 = nn.Conv2D(64, output_ch, kernel_size=1, stride=1, padding=0)

    def forward(self, x):
        # encoding path
        x1 = self.Conv1(x)
        x2 = self.Maxpool(x1)
        x2 = self.Conv2(x2)
        x3 = self.Maxpool1(x2)
        x3 = self.Conv3(x3)
        x4 = self.Maxpool2(x3)
        x4 = self.Conv4(x4)
        x5 = self.Maxpool3(x4)
        x5 = self.Conv5(x5)

        # decoding + concat path
        d5 = self.Up5(x5)
        x4 = self.Att5(g=d5, x=x4)
        d5 = paddle.concat(x=[x4, d5], axis=1)
        d5 = self.Up_conv5(d5)

        d4 = self.Up4(d5)
        x3 = self.Att4(g=d4, x=x3)
        d4 = paddle.concat(x=[x3, d4], axis=1)
        d4 = self.Up_conv4(d4)

        d3 = self.Up3(d4)
        x2 = self.Att3(g=d3, x=x2)
        d3 = paddle.concat(x=[x2, d3], axis=1)
        d3 = self.Up_conv3(d3)

        d2 = self.Up2(d3)
        x1 = self.Att2(g=d2, x=x1)
        d2 = paddle.concat(x=[x1, d2], axis=1)
        d2 = self.Up_conv2(d2)

        d1 = self.Conv_1x1(d2)
        return d1

Model Visualization

In [10]
num_classes = 4
network = AttU_Net(img_ch=3, output_ch=num_classes)
model = paddle.Model(network)
model.summary((-1, 3,) + IMAGE_SIZE)
-----------------------------------------------------------------------------
  Layer (type)        Input Shape          Output Shape         Param #
=============================================================================
    Conv2D-1       [[1, 3, 160, 160]]   [1, 64, 160, 160]        1,792
   BatchNorm-1    [[1, 64, 160, 160]]   [1, 64, 160, 160]         256
     ReLU-1       [[1, 64, 160, 160]]   [1, 64, 160, 160]          0
    Conv2D-2      [[1, 64, 160, 160]]   [1, 64, 160, 160]       36,928
   BatchNorm-2    [[1, 64, 160, 160]]   [1, 64, 160, 160]         256
     ReLU-2       [[1, 64, 160, 160]]   [1, 64, 160, 160]          0
  conv_block-1     [[1, 3, 160, 160]]   [1, 64, 160, 160]          0
  MaxPool2D-1    [[1, 64, 160, 160]]    [1, 64, 80, 80]           0
        ... (intermediate encoder/decoder layers elided) ...
 conv_block-9    [[1, 128, 160, 160]]  [1, 64, 160, 160]          0
    Conv2D-35     [[1, 64, 160, 160]]    [1, 4, 160, 160]         260
=============================================================================
Total params: 34,894,392
Trainable params: 34,863,144
Non-trainable params: 31,248
-----------------------------------------------------------------------------
Input size (MB): 0.29
Forward/backward pass size (MB): 563.67
Params size (MB): 133.11
Estimated Total Size (MB): 697.07
-----------------------------------------------------------------------------
{'total_params': 34894392, 'trainable_params': 34863144}

Model Training

In [11]
train_dataset = PetDataset(mode='train')  # training set
val_dataset = PetDataset(mode='test')     # validation set

optim = paddle.optimizer.RMSProp(learning_rate=0.001,
                                 rho=0.9,
                                 momentum=0.0,
                                 epsilon=1e-07,
                                 centered=False,
                                 parameters=model.parameters())

model.prepare(optim, paddle.nn.CrossEntropyLoss(axis=1))
model.fit(train_dataset,
          val_dataset,
          epochs=15,
          batch_size=32,
          verbose=1)
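`CrossEntropyLoss(axis=1)` treats axis 1 of the (N, C, H, W) logits as the class axis, so the loss is computed per pixel against the integer trimap label. A hand-rolled sketch for one pixel with num_classes = 4 (toy logits, not taken from the model):

```python
import math

logits = [2.0, 0.5, -1.0, 0.1]   # one pixel's scores along the class axis
label = 0                        # trimap class id for that pixel

# Softmax over the class axis, then negative log-likelihood of the label
exps = [math.exp(v) for v in logits]
probs = [e / sum(exps) for e in exps]
loss = -math.log(probs[label])
```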

Model Prediction

In [12]
predict_dataset = PetDataset(mode='predict')
predict_results = model.predict(predict_dataset)
Predict begin...
step 1108/1108 [==============================] - 20ms/step
Predict samples: 1108

In [13]
plt.figure(figsize=(10, 10))

i = 0
mask_idx = 0

with open('./predict.txt', 'r') as f:
    for line in f.readlines():
        image_path, label_path = line.strip().split('\t')
        resize_t = T.Compose([
            T.Resize(IMAGE_SIZE)
        ])
        image = resize_t(PilImage.open(image_path))
        label = resize_t(PilImage.open(label_path))

        image = np.array(image).astype('uint8')
        label = np.array(label).astype('uint8')

        if i > 8:
            break

        plt.subplot(3, 3, i + 1)
        plt.imshow(image)
        plt.title('Input Image')
        plt.axis("off")

        plt.subplot(3, 3, i + 2)
        plt.imshow(label, cmap='gray')
        plt.title('Label')
        plt.axis("off")

        # The model has a single output, so predict_results[0] holds the
        # predictions for all 1108 samples. Index by mask_idx to fetch
        # the prediction for the current image and extract its mask.
        data = predict_results[0][mask_idx][0].transpose((1, 2, 0))
        mask = np.argmax(data, axis=-1)

        plt.subplot(3, 3, i + 3)
        plt.imshow(mask.astype('uint8'), cmap='gray')
        plt.title('Predict')
        plt.axis("off")

        i += 3
        mask_idx += 1

plt.show()
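The `np.argmax(data, axis=-1)` call above collapses each pixel's num_classes = 4 scores into a single class id. Per pixel, the operation reduces to this (toy scores for illustration):

```python
scores = [0.1, 2.3, -0.5, 1.0]  # one pixel's 4 class scores (channel-last)
# Index of the largest score is the predicted class for that pixel
pred_class = max(range(len(scores)), key=lambda c: scores[c])
print(pred_class)  # → 1
```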
Source: https://www.php.cn/faq/1422072.html
