【AI达人创造营第二期】 一文读懂双向循环神经网络 BRNN
本文详细介绍了RNN和BRNN的原理,结构和优缺点等。包含了一个使用vanilla RNN,LSTM,BiLSTM,GRU和BiGRU模型分别分类的谣言检测项目。
一、项目简介
本文详细介绍了rnn和brnn的结构,训练时的反向传播,梯度消失和爆炸的原理以及它们的优缺点。包含了一个使用vanilla rnn,lstm,bilstm,gru和bigru模型分别分类的谣言检测项目帮助理解
免费影视、动漫、音乐、游戏、小说资源长期稳定更新! 👉 点此立即查看 👈
二、RNN
2.1 RNN简介
RNN全称为循环神经网络,是一种不同于一般前馈神经网络的特殊神经网络,其旨在处理时间序列数据,假定当前时间步是由先前时间步和当前输入决定的。
RNN维护一个状态变量,用于捕获序列数据中存在的各种模式,因此,它们能够对序列数据建模。并且随时间共享参数集,这也是RNN能学习序列每一时刻模式的主要原因之一。
2.2 RNN的结构
假设已有数据如下所示
x={x1,x2,...,xT}y={y1,y2,...,yT}
假定使用函数逼近器f表示如下两个关系,其中θ与φ表示参数集,ht是当前状态,ht−1是前一状态
ht=f1(xt,ht−1;θ)yt=f2(ht;ϕ)
我们可以将f1,f2的复合看做产生x,y的真正模型的近似,于是得到公式
yt=f2(f1(xt,ht−1;θ);ϕ)
例如y3可以表示为如下公式和图
y3=f2(f1(x3,f2(f1(x2,f2(f1(x1,h0;θ);ϕ);θ);ϕ);θ);ϕ)
通过此可以推出采用循环连接的RNN单步计算
2.3 RNN的技术描述
神经网络通常由一组权重和偏置以及一些激活函数组成,所以上述的ht和yt两个公式可以写成如下形式(形式不唯一,可根据不同任务调整)
ht=tanh(Uxt+Wht−1)y=softmax(Vht)
这里U,W,V是不同的权重矩阵,tanh和softmax是不同的激活函数
2.4 RNN的反向传播
训练RNN需要用到基于时间的反向传播(BPTT),在此我们假设预测误差E,权重矩阵w,样本x的正确标签l,样本x的预测标签y,损失函数L为均方差。
举例:以普通链式法则求∂w3∂E
$$\frac {\partial E}{\partial w_3}={\frac {\partial L}{\partial y}}{\frac {\partial y}{\partial h}}{\frac {\partial h}{\partial w_3}}$$
可以变形为
$$\frac {\partial E}{\partial w_3}={\frac {\partial (y-l)^2}{\partial y}}{\frac {\partial w_2h}{\partial h}}({\frac {\partial (w_1x)}{\partial w_3}+\frac {\partial (w_3h)}{\partial w_3})}$$
而∂w3∂(w3h)这一项会产生问题,因为h是个递归变量且依赖w3,最终会产生无限项。若要解决这一问题,可以将输入序列随时间展开,为每个输入x,创建RNN的副本,并分别计算每个副本的导数,并通过计算梯度的总和将它们回滚,以计算需要更新的权重大小。我们接下来将讨论细节。
计算时考虑完整的输入序列,也就是说要创建4个RNN副本计算直到第四个时间步的所有时间步之和,于是我们可以得到以下结果
$$\frac {\partial E}{\partial w_3}=\sum_{j=1}^3{\frac {\partial L}{\partial y_4}}{\frac {\partial y_4}{\partial h_4}}{\frac {\partial h_4}{\partial h_j}}{\frac {\partial h_j}{\partial w_3}}$$
其中的∂hj∂ht需要t−j+1个副本。然后将副本汇总到单个RNN,求所有先前时间步长的梯度得到一个梯度,最后更新RNN。
然而随着时间步数的增加,这会使得计算变得复杂。所以我们可以使用BPTT的近似,即截断的基于时间的反向传播(TBPTT)。在TBPTT中,我们仅仅计算固定数量的时间步的梯度,即导数只算到t−T,不算到最开始
$$\frac {\partial E}{\partial w_3}=\sum_{j=t-T}^{t-1}{\frac {\partial L}{\partial y_t}}{\frac {\partial y_t}{\partial h_t}}{\frac {\partial h_t}{\partial h_j}}{\frac {\partial h_j}{\partial w_3}}$$
这样随着序列的增长,我们依旧向后计算固定数量的导数,不会使计算成本变得过大。
2.5 RNN的梯度消失与爆炸
在真实的任务训练过程中,RNN存在一个明显的缺陷,那就是当阅读很长的序列时,网络内部的信息会逐渐变得越来越复杂,以至于超过网络的记忆能力,使得最终的输出信息变得混乱无用。
假设标准RNN隐藏状态计算公式为,激活函数σ是sigmoid ,W是权重矩阵
ht=σ(Wxxt+Whht−1)
为了简化计算,忽略与当前输入相关的项,重点放在循环部分
ht=σ(Whht−1)
结合链式法则对ht求偏导得到
∂ht−k∂ht=i=0∏k−1Whσ(Whht−k+i)(1−σ(Whht−k+i))
提取出k个权重矩阵
∂ht−k∂ht=Whki=0∏k−1σ(Whht−k+i)(1−σ(Whht−k+i))
通过公式可以总结出一下特点
当Whht−k+i
前三条需要通过调整RNN结构(LSTM,GRU)来避免,最后一条可以通过调整初始化或者梯度裁剪来规避
2.6 RNN的应用
一对一RNN:适用于每个输入都有输出的问题,例如场景分类,需要对图像中每个像素进行标注。一对多RNN:适用于接收一个输入得到一个输出序列的问题,例如生成图像标题,对于给定的图像可以生成一个文本序列。多对一RNN:适用于输入一个序列,得到一个输出的问题,例如文本分类任务,输入一个句子或者一篇文章判断其类别。多对多RNN:适用于输入一个任意长度序列得到一个任意长度序列的任务,例如机器翻译任务,输入中文文本序列,输入出英文文本序列。三、BRNN
3.1 BRNN简介
BRNN意为双向循环神经网络,是一种常见的提升RNN预测质量的方法。之前介绍RNN时假设了当前时间步是由前面的较早时间步和当前输入决定的。可是当输入序列是文本时句子当前时间步也可能是由后面的时间步决定的,例如:
我晚上不能喝 ___ ,因为要开车回家。
当预测是从头开始读句子,预测出可乐,雪碧,果汁,酒 都合理。因为没有足够的文本信息来帮助预测
而预测时同时从前到后以及从后向前读句子,则能较大概率的预测出 酒 ,因为上下文信息(或许还要结合一些知识库的支持)足够我们进行正确预测。
3.2 BRNN结构
其结构和RNN非常相似,就是两个反向的RNN的叠加,Ht是隐藏状态,由Ht→正向隐藏状态和Ht←反向隐藏状态组成,x为输入,O为输出,W为权重矩阵,b为偏置矩阵。
Ht→=σ(XtWxh(f)+Ht−1→Whh(f)+bh(f))
Ht←=σ(XtWxh(b)+Ht−1←Whh(b)+bh(b))
Ht=(Ht→,Ht←)
Ot=HtWhq+bq
两个方向的隐藏层的连接方式是 concat 拼接起来,蕴含了两个方向的信息 ,根据上图可以想成一个信息往上传,一个信息往下传。不同方向上的隐藏单元个数也可以不同。其反向传播方式和RNN类似,同时这也带来了RNN的一些容易梯度爆炸和梯度消失的缺点,在此不再赘述。
四、RNN以及其变体在项目中的运用
4.1.项目介绍
本项目介绍如何从零开始完成一个谣言检测任务。通过引入微博谣言数据集,基于Paddle框架使用RNN,lstm,bilstm,gru,bigru完成谣言文本的判断。
4.1.1 什么是谣言检测任务
传统的谣言检测模型一般根据谣言的内容、用户属性、传播方式人工地构造特征,而人工构建特征存在考虑片面、浪费人力等现象。本次实践使用基于循环神经网络(RNN)的谣言检测模型,将文本中的谣言事件向量化,通过循环神经网络的学习训练来挖掘表示文本深层的特征,避免了特征构建的问题,并能发现那些不容易被人发现的特征,从而产生更好的效果。
4.2、数据集
4.2.1 数据集概览
本次实践所使用的数据是从新浪微博不实信息举报平台抓取的中文谣言数据,数据集中一共包含1538条谣言和1849条非谣言。如下图所示,每条数据均为json格式,其中text字段代表微博原文的文字内容。
更多数据集介绍请参考https://github.com/thunlp/Chinese_Rumor_Dataset。
4.2.2 数据集处理
(1)解压数据,读取并解析数据,生成all_data.txt
(2)生成数据字典,即dict.txt
(3)生成数据列表,并进行训练集与验证集的划分,train_list.txt 、eval_list.txt
(4)定义训练数据集提供器train_reader和验证数据集提供器eval_reader
In [1]#解压原始数据集,将Rumor_Dataset.zip解压至data目录下import zipfileimport osimport randomfrom PIL import Imagefrom PIL import ImageEnhanceimport json src_path="/home/aistudio/data/data20519/Rumor_Dataset.zip"target_path="/home/aistudio/data/Chinese_Rumor_Dataset-master"if(not os.path.isdir(target_path)): z = zipfile.ZipFile(src_path, 'r') z.extractall(path=target_path) z.close()登录后复制In [2]
#分别为谣言数据、非谣言数据、全部数据的文件路径rumor_class_dirs = os.listdir(target_path+"/Chinese_Rumor_Dataset-master/CED_Dataset/rumor-repost/")non_rumor_class_dirs = os.listdir(target_path+"/Chinese_Rumor_Dataset-master/CED_Dataset/non-rumor-repost/")original_microblog = target_path+"/Chinese_Rumor_Dataset-master/CED_Dataset/original-microblog/"#谣言标签为0,非谣言标签为1rumor_label="0"non_rumor_label="1"#分别统计谣言数据与非谣言数据的总数rumor_num = 0non_rumor_num = 0all_rumor_list = []all_non_rumor_list = []#解析谣言数据for rumor_class_dir in rumor_class_dirs: if(rumor_class_dir != '.DS_Store'): #遍历谣言数据,并解析 with open(original_microblog + rumor_class_dir, 'r') as f: rumor_content = f.read() rumor_dict = json.loads(rumor_content) all_rumor_list.append(rumor_label+"\t"+rumor_dict["text"]+"\n") rumor_num +=1#解析非谣言数据for non_rumor_class_dir in non_rumor_class_dirs: if(non_rumor_class_dir != '.DS_Store'): with open(original_microblog + non_rumor_class_dir, 'r') as f2: non_rumor_content = f2.read() non_rumor_dict = json.loads(non_rumor_content) all_non_rumor_list.append(non_rumor_label+"\t"+non_rumor_dict["text"]+"\n") non_rumor_num +=1 print("谣言数据总量为:"+str(rumor_num))print("非谣言数据总量为:"+str(non_rumor_num))登录后复制谣言数据总量为:1538非谣言数据总量为:1849登录后复制In [3]
#全部数据进行乱序后写入all_data.txtdata_list_path="/home/aistudio/data/"all_data_path=data_list_path + "all_data.txt"all_data_list = all_rumor_list + all_non_rumor_list # 正负样本列表相加random.shuffle(all_data_list) # 打乱list#在生成all_data.txt之前,首先将其清空with open(all_data_path, 'w') as f: f.seek(0) f.truncate() with open(all_data_path, 'a') as f: for data in all_data_list: # 按行写入,一行一个样本 f.write(data)# with open(all_data_path,'r',encoding='UTF-8') as f:# print(f.read())登录后复制In [4]
# 导入必要的包%matplotlib inlineimport osfrom multiprocessing import cpu_countimport numpy as npimport shutilimport paddleimport paddle.fluid as fluidfrom PIL import Imageimport matplotlib.pyplot as pltfrom matplotlib.font_manager import FontPropertiesfont = FontProperties(fname='simhei.ttf', size=16)登录后复制In [5]
# all_data_path= data_list_path + "all_data.txt"# data_list_path="/home/aistudio/data/"# dict_path = data_list_path + "dict.txt"# 生成数据字典def create_dict(data_path, dict_path): dict_set = set() # 创建一个无序不重复元素集,集合 # 读取全部数据 with open(data_path, 'r', encoding='utf-8') as f: lines = f.readlines() # with open('data/test_lines.txt', 'w', encoding='utf-8') as f: # f.write(str(lines)) #print(lines) # 把数据生成一个元组 for line in lines: content = line.split('\t')[-1].replace('\n', '') # 以\t为分隔符,取最后一段。去掉每一行的换行符 # print(line,content) for s in content: dict_set.add(s) # print(s,dict_set) #将每一个字加入到一个无序不重复元素集中(元组) # 把集合转换成字典,一个字对应一个数字 dict_list = [] i = 0 for s in dict_set: dict_list.append([s, i]) i += 1 # print(dict_list) # 添加未知字符 dict_txt = dict(dict_list) # list转字典 end_dict = {"": i} dict_txt.update(end_dict) # 把这些字典保存到本地中 with open(dict_path, 'w', encoding='utf-8') as f: f.write(str(dict_txt)) print("数据字典生成完成!") # 获取字典的长度def get_dict_len(dict_path): with open(dict_path, 'r', encoding='utf-8') as f: line = eval(f.readlines()[0]) return len(line.keys())create_dict('data/all_data.txt','data/dict.txt')get_dict_len('data/dict.txt') 登录后复制数据字典生成完成!登录后复制
4410登录后复制In [6]
# 创建序列化表示的数据,并按照一定比例划分训练数据与验证数据def create_data_list(data_list_path): #在生成数据之前,首先将eval_list.txt和train_list.txt清空 with open(os.path.join(data_list_path, 'eval_list.txt'), 'w', encoding='utf-8') as f_eval: f_eval.seek(0) f_eval.truncate() with open(os.path.join(data_list_path, 'train_list.txt'), 'w', encoding='utf-8') as f_train: f_train.seek(0) f_train.truncate() with open(os.path.join(data_list_path, 'dict.txt'), 'r', encoding='utf-8') as f_data: dict_txt = eval(f_data.readlines()[0]) # print(dict_txt) with open(os.path.join(data_list_path, 'all_data.txt'), 'r', encoding='utf-8') as f_data: lines = f_data.readlines() # print(lines) i = 0 with open(os.path.join(data_list_path, 'eval_list.txt'), 'a', encoding='utf-8') as f_eval,open(os.path.join(data_list_path, 'train_list.txt'), 'a', encoding='utf-8') as f_train: for line in lines: # print(line) words = line.split('\t')[-1].replace('\n', '') label = line.split('\t')[0] labs = "" if i % 8 == 0: for s in words: lab = str(dict_txt[s]) labs = labs + lab + ',' labs = labs[:-1] labs = labs + '\t' + label + '\n' f_eval.write(labs) else: for s in words: lab = str(dict_txt[s]) labs = labs + lab + ',' labs = labs[:-1] labs = labs + '\t' + label + '\n' f_train.write(labs) i += 1 print("数据列表生成完成!")create_data_list('/home/aistudio/data/')登录后复制数据列表生成完成!登录后复制In [7]
#dict_path为数据字典存放路径#all_data_path= data_list_path + "all_data.txt"#data_list_path="/home/aistudio/data/"dict_path = data_list_path + "dict.txt"#创建数据字典,存放位置:dict.txt。在生成之前先清空dict.txtwith open(dict_path, 'w') as f: f.seek(0) f.truncate() create_dict(all_data_path, dict_path)#创建数据列表,存放位置:train_list.txt eval_list.txtcreate_data_list(data_list_path)登录后复制
数据字典生成完成!数据列表生成完成!登录后复制In [8]
def data_mapper(sample): data, label = sample data = [int(data) for data in data.split(',')] return data, int(label)#定义数据读取器def data_reader(data_path): def reader(): with open(data_path, 'r') as f: lines = f.readlines() for line in lines: data, label = line.split('\t') yield data, label return paddle.reader.xmap_readers(data_mapper, reader, cpu_count(), 1024)登录后复制In [9]# 获取训练数据读取器和测试数据读取器,设置超参数# data_list_path="/home/aistudio/data/"BATCH_SIZE = 256train_list_path = data_list_path+'train_list.txt'eval_list_path = data_list_path+'eval_list.txt'train_reader = paddle.batch(reader=data_reader(train_list_path), batch_size=BATCH_SIZE)eval_reader = paddle.batch(reader=data_reader(eval_list_path), batch_size=BATCH_SIZE)登录后复制
4.3、模型组网
数据准备的工作完成之后,接下来我们将动手来搭建一个循环神经网络,进行文本特征的提取,从而实现微博谣言检测。
4.3.1 搭建网络
4.3.1.1 rnn
In [10]def rnn_net(ipt, input_dim): #循环神经网络 # 以数据的IDs作为输入, ipt 输入 数据集 emb = fluid.layers.embedding(input=ipt, size=[input_dim, 128],is_sparse=True) drnn = fluid.layers.DynamicRNN() with drnn.block(): # 将embedding标记为RNN的输入,每个时间步取句子中的一个字进行处理 word=drnn.step_input(emb) # 将memory初始化为一个值为0的常量Tensor,shape=[batch_size, 200],其中batch_size由输入embedding决定 memory = drnn.memory(shape=[200]) hidden = fluid.layers.fc(input=[word, memory], size=200, act='relu') # 用hidden更新memory drnn.update_memory(ex_mem=memory, new_mem=hidden) # 将hidden标记为RNN的输出 drnn.output(hidden) # 最大序列池操作 fc = fluid.layers.sequence_pool(input=drnn(), pool_type='max') # 以softmax作为全连接的输出层,大小为2,也就是正负面 out = fluid.layers.fc(input=fc, size=2, act='softmax') return out登录后复制
4.3.1.2 lstm
In [11]def lstm_net(ipt, input_dim): # 长短期记忆网络 # 以数据的IDs作为输入 emb = fluid.layers.embedding(input=ipt, size=[input_dim, 128], is_sparse=True) # 第一个全连接层 fc1 = fluid.layers.fc(input=emb, size=128) # 进行一个长短期记忆操作 lstm1, _ = fluid.layers.dynamic_lstm(input=fc1, #返回:隐藏状态(hidden state),LSTM的神经元状态 size=128) #size=4*hidden_size # 第一个最大序列池操作 fc2 = fluid.layers.sequence_pool(input=fc1, pool_type='max') # 第二个最大序列池操作 lstm2 = fluid.layers.sequence_pool(input=lstm1, pool_type='max') # 以softmax作为全连接的输出层,大小为2,也就是正负面 out = fluid.layers.fc(input=[fc2, lstm2], size=2, act='softmax') return out登录后复制
4.3.1.3 bilstm
In [12]def bilstm_net(ipt, input_dim): # 双向长短期神经网络 # 以数据的IDs作为输入 emb = fluid.layers.embedding(input=ipt, size=[input_dim, 128], is_sparse=True) # 第一个全连接层 fc1_f = fluid.layers.fc(input=emb, size=128) fc1_b = fluid.layers.fc(input=emb, size=128) # 进行一个长短期记忆操作 lstm1_f, _ = fluid.layers.dynamic_lstm(input=fc1_f, #返回:隐藏状态(hidden state),LSTM的神经元状态 size=128) #size=4*hidden_size lstm1_b, _ = fluid.layers.dynamic_lstm(input=fc1_b, #返回:隐藏状态(hidden state),LSTM的神经元状态 is_reverse = True, size=128) #size=4*hidden_size # 第一个最大序列池操作 fc2_f = fluid.layers.sequence_pool(input=fc1_f, pool_type='max') fc2_b = fluid.layers.sequence_pool(input=fc1_b, pool_type='max') # 第二个最大序列池操作 lstm2_f = fluid.layers.sequence_pool(input=lstm1_f, pool_type='max') lstm2_b = fluid.layers.sequence_pool(input=lstm1_b, pool_type='max') lstm2 = fluid.layers.concat(input=[lstm2_f, lstm2_b], axis=1) fc2 = fluid.layers.concat(input=[fc2_f, fc2_b], axis=1) # 以softmax作为全连接的输出层,大小为2,也就是正负面 out = fluid.layers.fc(input=[fc2, lstm2], size=2, act='softmax') return out登录后复制
4.3.1.4 gru
In [13]def gru_net(ipt, input_dim): # 门控循环单元 # 以数据的IDs作为输入 emb = fluid.layers.embedding(input=ipt, size=[input_dim, 128], is_sparse=True) # 第一个全连接层 fc1 = fluid.layers.fc(input=emb, size=384) # 进行一个长短期记忆操作 gru1= fluid.layers.dynamic_gru(input=fc1,size=128) # 第一个最大序列池操作 fc2 = fluid.layers.sequence_pool(input=fc1, pool_type='max') # 第二个最大序列池操作 gru2 = fluid.layers.sequence_pool(input=gru1, pool_type='max') # 以softmax作为全连接的输出层,大小为2,也就是正负面 out = fluid.layers.fc(input=[fc2, gru2], size=2, act='softmax') return out登录后复制
4.3.1.5 bigru
In [14]def bigru_net(ipt, input_dim): # 双向门控循环单元 # 以数据的IDs作为输入 emb = fluid.layers.embedding(input=ipt, size=[input_dim, 128], is_sparse=True) # 第一个全连接层 fc1_f = fluid.layers.fc(input=emb, size=384) fc1_b = fluid.layers.fc(input=emb, size=384) # 进行一个长短期记忆操作 gru1_f= fluid.layers.dynamic_gru(input=fc1_f,size=128) gru1_b= fluid.layers.dynamic_gru(input=fc1_b,size=128,is_reverse=True) # 第一个最大序列池操作 fc2_f = fluid.layers.sequence_pool(input=fc1_f, pool_type='max') fc2_b = fluid.layers.sequence_pool(input=fc1_b, pool_type='max') # 第二个最大序列池操作 gru2_f = fluid.layers.sequence_pool(input=gru1_f, pool_type='max') gru2_b = fluid.layers.sequence_pool(input=gru1_b, pool_type='max') gru2 = fluid.layers.concat(input=[gru2_f, gru2_b], axis=1) fc2 = fluid.layers.concat(input=[fc2_f, fc2_b], axis=1) # 以softmax作为全连接的输出层,大小为2,也就是正负面 out = fluid.layers.fc(input=[fc2, gru2], size=2, act='softmax') return out登录后复制
4.3.2 定义数据层
In [15]# 定义输入数据, lod_level不为0指定输入数据为序列数据paddle.enable_static()words = fluid.data(name='words', shape=[None,1], dtype='int64', lod_level=1) label = fluid.data(name='label', shape=[None,1], dtype='int64')登录后复制
4.3.3 获取分类器
In [16]# 获取数据字典长度dict_dim = get_dict_len(dict_path)# 获取分类器# model = rnn_net(words, dict_dim) # model = lstm_net(words, dict_dim) model = bilstm_net(words, dict_dim) # model = gru_net(words, dict_dim) # model = bigru_net(words, dict_dim)登录后复制
4.3.4 定义损失函数和准确率
定义了一个损失函数之后,还有对它求平均值,因为定义的是一个Batch的损失值。
同时我们还可以定义一个准确率函数,这个可以在我们训练的时候输出分类的准确率。
In [17]# 获取损失函数和准确率cost = fluid.layers.cross_entropy(input=model, label=label)avg_cost = fluid.layers.mean(cost)acc = fluid.layers.accuracy(input=model, label=label)# 获取预测程序test_program = fluid.default_main_program().clone(for_test=True)登录后复制
4.3.5 定义优化方法
In [18]# 定义优化方法optimizer = fluid.optimizer.AdagradOptimizer(learning_rate=0.001)opt = optimizer.minimize(avg_cost)登录后复制
4.4、训练网络
4.4.1 创建Executor
In [19]# use_cuda为False,表示运算场所为CPU;use_cuda为True,表示运算场所为GPU use_cuda = True place = fluid.CUDAPlace(0) if use_cuda else fluid.CPUPlace()exe = fluid.Executor(place) # 进行参数初始化exe.run(fluid.default_startup_program())登录后复制
W0221 21:51:10.279976 21097 device_context.cc:447] Please NOTE: device: 0, GPU Compute Capability: 7.0, Driver API Version: 10.1, Runtime API Version: 10.1W0221 21:51:10.283764 21097 device_context.cc:465] device: 0, cuDNN Version: 7.6.登录后复制
[]登录后复制
4.4.2 定义数据映射器
DataFeeder负责将数据提供器(train_reader,test_reader)返回的数据转成一种特殊的数据结构,使其可以输入到Executor中。
feed_list设置向模型输入的向变量表或者变量表名
In [20]# 定义数据映射器feeder = fluid.DataFeeder(place=place, feed_list=[words, label])登录后复制
4.4.3 展示模型训练曲线
In [21]all_train_iter=0all_train_iters=[]all_train_costs=[]all_train_accs=[]all_eval_iter=0all_eval_iters=[]all_eval_costs=[]all_eval_accs=[]def draw_process(title,iters,costs,accs,label_cost,lable_acc): plt.title(title, fontsize=24) plt.xlabel("iter", fontsize=20) plt.ylabel("cost/acc", fontsize=20) plt.plot(iters, costs,color='red',label=label_cost) plt.plot(iters, accs,color='green',label=lable_acc) plt.legend() plt.grid() plt.show()登录后复制4.4.4 训练并保存模型
Executor接收传入的program,并根据feed map(输入映射表)和fetch_list(结果获取表) 向program中添加feed operators(数据输入算子)和fetch operators(结果获取算子)。
feed map为该program提供输入数据。fetch_list提供program训练结束后用户预期的变量。
每一轮训练结束之后,再使用验证集进行验证,并求出相应的损失值Cost和准确率acc。
In [22]EPOCH_NUM=20 #训练轮数model_save_dir = '/home/aistudio/work/infer_model/' #模型保存路径# 开始训练for pass_id in range(EPOCH_NUM): # 进行训练 print('epoch:',int(pass_id)+1) for batch_id, data in enumerate(train_reader()): train_cost, train_acc = exe.run(program=fluid.default_main_program(), feed=feeder.feed(data), fetch_list=[avg_cost, acc]) all_train_iter=all_train_iter+BATCH_SIZE all_train_iters.append(all_train_iter) all_train_costs.append(train_cost[0]) all_train_accs.append(train_acc[0]) if batch_id % 10 == 0: print('Pass:%d, Batch:%d, Cost:%0.5f, Acc:%0.5f' % (pass_id, batch_id, train_cost[0], train_acc[0])) # 进行验证 eval_costs = [] eval_accs = [] for batch_id, data in enumerate(eval_reader()): eval_cost, eval_acc = exe.run(program=test_program, feed=feeder.feed(data), fetch_list=[avg_cost, acc]) eval_costs.append(eval_cost[0]) eval_accs.append(eval_acc[0]) all_eval_iter=all_eval_iter+BATCH_SIZE all_eval_iters.append(all_eval_iter) all_eval_costs.append(eval_cost[0]) all_eval_accs.append(eval_acc[0]) # 计算平均预测损失在和准确率 eval_cost = (sum(eval_costs) / len(eval_costs)) eval_acc = (sum(eval_accs) / len(eval_accs)) print('Test:%d, Cost:%0.5f, ACC:%0.5f' % (pass_id, eval_cost, eval_acc))# 保存模型if not os.path.exists(model_save_dir): os.makedirs(model_save_dir) fluid.io.save_inference_model(model_save_dir, feeded_var_names=[words.name], target_vars=[model], executor=exe)print('训练模型保存完成!') draw_process("train",all_train_iters,all_train_costs,all_train_accs,"trainning cost","trainning acc")draw_process("eval",all_eval_iters,all_eval_costs,all_eval_accs,"evaling cost","evaling acc")登录后复制epoch: 1Pass:0, Batch:0, Cost:0.69911, Acc:0.47266Pass:0, Batch:10, Cost:0.66337, Acc:0.56250Test:0, Cost:0.66470, ACC:0.53683epoch: 2Pass:1, Batch:0, Cost:0.65396, Acc:0.55469Pass:1, Batch:10, Cost:0.63739, Acc:0.62500Test:1, Cost:0.64528, ACC:0.63030epoch: 3Pass:2, Batch:0, Cost:0.63112, Acc:0.70703Pass:2, Batch:10, Cost:0.61620, Acc:0.71875Test:2, Cost:0.62902, ACC:0.69699epoch: 4Pass:3, Batch:0, Cost:0.61294, Acc:0.77344Pass:3, Batch:10, Cost:0.59836, Acc:0.76953Test:3, Cost:0.61558, ACC:0.71670epoch: 5Pass:4, Batch:0, Cost:0.59646, Acc:0.78906Pass:4, Batch:10, Cost:0.58233, Acc:0.79688Test:4, Cost:0.60275, ACC:0.73047epoch: 6Pass:5, Batch:0, Cost:0.58203, Acc:0.80078Pass:5, Batch:10, Cost:0.56806, Acc:0.81250Test:5, Cost:0.59103, ACC:0.76581epoch: 7Pass:6, Batch:0, Cost:0.56877, Acc:0.81641Pass:6, Batch:10, Cost:0.55464, Acc:0.81250Test:6, Cost:0.57990, ACC:0.77269epoch: 8Pass:7, Batch:0, Cost:0.55748, Acc:0.82422Pass:7, Batch:10, Cost:0.54018, Acc:0.83984Test:7, Cost:0.57005, ACC:0.78162epoch: 9Pass:8, Batch:0, Cost:0.54539, Acc:0.83203Pass:8, Batch:10, Cost:0.52911, Acc:0.85156Test:8, Cost:0.56025, ACC:0.78757epoch: 10Pass:9, Batch:0, Cost:0.53370, Acc:0.83594Pass:9, Batch:10, Cost:0.52003, Acc:0.85547Test:9, Cost:0.55105, ACC:0.79846epoch: 11Pass:10, Batch:0, Cost:0.52343, Acc:0.84375Pass:10, Batch:10, Cost:0.50598, Acc:0.87500Test:10, Cost:0.54268, ACC:0.79548epoch: 12Pass:11, Batch:0, Cost:0.51352, Acc:0.84375Pass:11, Batch:10, Cost:0.49801, Acc:0.87891Test:11, Cost:0.53439, ACC:0.80134epoch: 13Pass:12, Batch:0, Cost:0.50305, Acc:0.84766Pass:12, Batch:10, Cost:0.48714, Acc:0.89453Test:12, Cost:0.52690, ACC:0.80627epoch: 14Pass:13, Batch:0, Cost:0.49415, Acc:0.83984Pass:13, Batch:10, Cost:0.47674, Acc:0.89844Test:13, Cost:0.51935, ACC:0.81017epoch: 15Pass:14, Batch:0, Cost:0.48502, Acc:0.84375Pass:14, Batch:10, Cost:0.46768, Acc:0.89844Test:14, Cost:0.51199, ACC:0.82199epoch: 16Pass:15, Batch:0, Cost:0.47742, Acc:0.85156Pass:15, Batch:10, Cost:0.45791, Acc:0.89844Test:15, Cost:0.50501, ACC:0.82394epoch: 17Pass:16, Batch:0, Cost:0.46781, Acc:0.85938Pass:16, Batch:10, Cost:0.45186, Acc:0.89844Test:16, Cost:0.49864, ACC:0.82394epoch: 18Pass:17, Batch:0, Cost:0.45989, Acc:0.86328Pass:17, Batch:10, Cost:0.44218, Acc:0.89453Test:17, Cost:0.49238, ACC:0.82292epoch: 19Pass:18, Batch:0, Cost:0.45258, Acc:0.85938Pass:18, Batch:10, Cost:0.43367, Acc:0.89453Test:18, Cost:0.48672, ACC:0.82487epoch: 20Pass:19, Batch:0, Cost:0.44533, Acc:0.85547Pass:19, Batch:10, Cost:0.42680, Acc:0.89453Test:19, Cost:0.48060, ACC:0.82096训练模型保存完成!登录后复制
/opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages/matplotlib/cbook/__init__.py:2349: DeprecationWarning: Using or importing the ABCs from 'collections' instead of from 'collections.abc' is deprecated, and in 3.8 it will stop working if isinstance(obj, collections.Iterator):/opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages/matplotlib/cbook/__init__.py:2366: DeprecationWarning: Using or importing the ABCs from 'collections' instead of from 'collections.abc' is deprecated, and in 3.8 it will stop working return list(data) if isinstance(data, collections.MappingView) else data登录后复制
登录后复制登录后复制
登录后复制登录后复制
4.5、模型预测
In [23]# 用训练好的模型进行预测并输出预测结果# 创建执行器place = fluid.CPUPlace()infer_exe = fluid.Executor(place)infer_exe.run(fluid.default_startup_program())save_path = '/home/aistudio/work/infer_model/'# 从模型中获取预测程序、输入数据名称列表、分类器[infer_program, feeded_var_names, target_var] = fluid.io.load_inference_model(dirname=save_path, executor=infer_exe)# 获取数据def get_data(sentence): # 读取数据字典 with open('/home/aistudio/data/dict.txt', 'r', encoding='utf-8') as f_data: dict_txt = eval(f_data.readlines()[0]) dict_txt = dict(dict_txt) # 把字符串数据转换成列表数据 keys = dict_txt.keys() data = [] for s in sentence: # 判断是否存在未知字符 if not s in keys: s = '' data.append(int(dict_txt[s])) return datadata = []# 获取图片数据data1 = get_data('兴仁县今天抢小孩没抢走,把孩子母亲捅了一刀,看见这车的注意了,真事,车牌号辽HFM055!!!!!赶紧散播! 都别带孩子出去瞎转悠了 尤其别让老人自己带孩子出去 太危险了 注意了!!!!辽HFM055北京现代朗动,在各学校门口抢小孩!!!110已经 证实!!全市通缉!!')data2 = get_data('重庆真实新闻:2016年6月1日在重庆梁平县袁驿镇发生一起抢儿童事件,做案人三个中年男人,在三中学校到镇街上的一条小路上,把小孩直接弄晕(儿童是袁驿新幼儿园中班的一名学生),正准备带走时被家长及时发现用棒子赶走了做案人,故此获救!请各位同胞们以此引起非常重视,希望大家有爱心的人传递下')data3 = get_data('@尾熊C 要提前预习育儿知识的话,建议看一些小巫写的书,嘻嘻')data.append(data1)data.append(data2)data.append(data3)# 获取每句话的单词数量base_shape = [[len(c) for c in data]]# 生成预测数据tensor_words = fluid.create_lod_tensor(data, base_shape, place)# 执行预测result = exe.run(program=infer_program, feed={feeded_var_names[0]: tensor_words}, fetch_list=target_var)# 分类名称names = [ '谣言', '非谣言']# 获取结果概率最大的labelfor i in range(len(data)): lab = np.argsort(result)[0][i][-1] print('预测结果标签为:%d, 分类为:%s, 概率为:%f' % (lab, names[lab], result[0][i][lab])) 登录后复制预测结果标签为:0, 分类为:谣言, 概率为:0.722868预测结果标签为:0, 分类为:谣言, 概率为:0.754371预测结果标签为:1, 分类为:非谣言, 概率为:0.577661登录后复制代码解释
五、总结
展示的项目使用了RNN,LSTM,BiLSTM,GRU和BiGRU分别进行谣言分类
相比与vanilla RNN(原始RNN),其变体(LSTM,GRU)在实际项目中运用比较广泛
双向循环神经网络虽然可以更全面的结合文本信息进行预测,但是带来了更多参数,减慢了模型训练速度。
采用其他网络的谣言检测任务可以在本项目的基础上修改实现
相关攻略
常见报错解析:“Access Not Configured”故障排除指南 许多开发者和团队成员在使用OpenClaw集成飞书时,都曾遭遇过一个典型的中断提示:“access not configured”(访问未配置)。该提示会明确显示您的飞书账户ID及一组唯一的配对验证码,并指出需要联系机器人所有
OpenClaw 常用指令大全与使用详解 openclaw status:此命令是查看OpenClaw系统整体健康状态的核心指令,执行后即获取服务运行状况的全面报告,是日常运维的首要诊断工具。 openclaw gateway restart:在修改网关配置后,必须运行此指令以重启网关服务,使配置文
如何通过 OpenClaw 实现 Chrome 浏览器自动化操控 在软件开发与自动化测试领域,持续学习是常态。本文旨在详细介绍如何利用 OpenClaw 连接并控制一个已开启的 Chrome 浏览器实例,实现点击、文本输入、文件上传、页面滚动、屏幕截图以及执行 JavaScript 等自动化操作。整
项目概述 你是否希望将强大的 AI 助手带入日常聊天?本教程将指导你完成搭建流程,让你能在 QQ 上直接调用 OpenClaw 智能助手,实现无门槛的 AI 对话体验。 架构说明 ┌─────────────┐ ┌──────────────┐ ┌─────────────┐ │ QQ 用户 │ ─
一 下载并安装Node js,全程保持默认设置 首先,请前往Node js官方网站的下载中心:https: nodejs org zh-cn download。根据您的操作系统(Windows Mac Linux)下载对应的安装程序。运行安装向导时,整个过程非常简单,您只需连续点击“下一步”按钮
热门专题
热门推荐
末日生存手游推荐:前往九游开启你的废土冒险之旅 近年来,末日生存题材手游以其独特的沉浸感与生存挑战,持续吸引着大量玩家。在废墟世界中探索资源、应对危机、重建秩序的核心玩法,带来了紧张而富有成就感的游戏体验。如果你正在寻找一款高品质的末日生存手游,九游平台无疑是理想的起点。这里汇集了多款深受好评的末日
《纪念碑谷3》第二关“小镇”超详细图文攻略 《纪念碑谷》系列凭借其独特的视觉艺术与空间谜题设计广受赞誉。最新发布的《纪念碑谷3》在第二章节“小镇”中,将这一美学风格与机关逻辑提升到了新的层次。本章节不仅延续了标志性的极简主义美学,其空间层次感与交互严谨性也更具挑战性。本攻略将为你完整解析《纪念碑谷3
《生存33天》:“沙漠之王”高效通关攻略 在热门生存手游《生存33天》中,玩家面临的挑战远不止于无尽的丧尸潮。游戏深度结合了生存资源管理与高难度首领战策略,其中“沙漠之王”堪称游戏中期最具考验的BOSS。它不仅是实力分水岭,击败后更能获得稀有材料、限定头衔及海量经验金币,大幅推动队伍成长。本文将深入
《生存33天》“四只手”首领完全通关攻略 你是否在“四只手”首领关卡止步不前?不必焦虑,这个Boss在《生存33天》中素有“新秀杀手”之称。初次遭遇时,其独特的机制与高额伤害往往让玩家措手不及,不少冒险者在此耗费了数日时光。然而,只要掌握了它的核心规律,你就会发现这个敌人不过是外强中干。以下这份详尽
《剑与远征:启程》前排坦克英雄赫普深度解析:双形态切换机制与实战搭配指南 在《剑与远征:启程》这款策略放置手游中,组建一支攻守兼备的队伍至关重要,而前排坦克英雄的选择往往是决定胜败的关键。今天,我们将聚焦于蛮血部族的一位特色英雄——赫普。作为一名超稀有品质的坦克,赫普不仅具备坚实的防御力,更凭借独特





