
text_rnn_attention's Introduction

Text classification with RNN+Attention and Word2vec

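This README follows my earlier blog post "text-cnn". Using the same dataset, it reports text classification results for an RNN+attention model with word-level embeddings.

The main goal is to compare how a CNN model and an RNN+attention model train on the same data. On the validation set, the CNN reaches 96.5% accuracy and RNN+attention reaches 96.8%.

If you are interested, see my other project: text-cnn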

1 Environment

python3
tensorflow 1.3 or later (CPU environment)
gensim
jieba
scipy
numpy
scikit-learn

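2 RNN (recurrent neural network) + attention mechanism

The parameters of the RNN+attention model are configured in text_model.py, as follows:

image

The overall structure of the RNN+attention model is roughly:

image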
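As a rough illustration of what the attention layer does, the sketch below pools per-step RNN outputs into a single vector with NumPy. It uses a common additive-attention formulation (tanh projection, score vector, softmax). The function name, parameter names, and shapes are assumptions made for this example, and the actual layer in text_model.py may differ.

```python
import numpy as np

def attention_pool(rnn_outputs, W, b, u):
    """Collapse per-step RNN outputs of shape (T, H) into one vector of shape (H,).

    W (H, A), b (A,) and u (A,) stand in for trainable parameters; the names
    and shapes are illustrative and are not taken from text_model.py.
    """
    v = np.tanh(rnn_outputs @ W + b)                 # (T, A) projection of each step
    scores = v @ u                                   # (T,) one relevance score per step
    scores = scores - scores.max()                   # shift for a numerically stable softmax
    alphas = np.exp(scores) / np.exp(scores).sum()   # (T,) attention weights
    context = alphas @ rnn_outputs                   # (H,) weighted sum of step outputs
    return context, alphas

rng = np.random.default_rng(0)
T, H, A = 200, 8, 4   # T=200 mirrors the sequence length in section 9; H and A are arbitrary
outputs = rng.normal(size=(T, H))
context, alphas = attention_pool(
    outputs, rng.normal(size=(H, A)), np.zeros(A), rng.normal(size=A)
)
print(context.shape, alphas.shape)
```

The context vector would then feed a dense layer and a softmax over the 10 categories.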

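3 Dataset

This experiment again uses a subset of THUCNews for training and testing. Please download the dataset yourself from THUCTC: an efficient Chinese text classification toolkit, and follow the data provider's open-source license.

The texts cover 10 categories: categories = ['体育', '财经', '房产', '家居', '教育', '科技', '时尚', '时政', '游戏', '娱乐'], with 6,500 samples per category.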
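A classifier needs integer labels rather than category names. The sketch below shows one plausible mapping, using list order as the id; the names cat_to_id and one_hot are hypothetical and not taken from the repository.

```python
categories = ['体育', '财经', '房产', '家居', '教育', '科技', '时尚', '时政', '游戏', '娱乐']

# Map each category name to an integer id, following list order.
cat_to_id = {name: idx for idx, name in enumerate(categories)}

def one_hot(label, num_classes=len(categories)):
    """Return a one-hot list for a category name."""
    vec = [0] * num_classes
    vec[cat_to_id[label]] = 1
    return vec

print(cat_to_id['科技'], one_hot('体育'))
```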

cnews.train.txt: training set (5000*10)

cnews.val.txt: validation set (500*10)

cnews.test.txt: test set (1000*10)
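The file layout is not described here. Assuming each line has the form label, a tab, then the text (an assumption, so check your copy of the data), a reader could look like the sketch below. The function name read_labeled_lines is hypothetical, and an in-memory buffer stands in for a real file.

```python
import io

def read_labeled_lines(handle):
    """Yield (label, text) pairs from lines shaped like 'label<TAB>text'."""
    for line in handle:
        line = line.rstrip('\n')
        if not line:
            continue                      # skip blank lines
        label, sep, text = line.partition('\t')
        if sep:                           # skip malformed lines that contain no tab
            yield label, text

# In-memory stand-in for a file such as cnews.train.txt (made-up sentences).
sample = io.StringIO("体育\t球队昨晚赢得了比赛\n财经\t股市今日小幅上涨\n\n坏行没有制表符\n")
pairs = list(read_labeled_lines(sample))
print(pairs)
```

For a real file, pass `open('cnews.train.txt', encoding='utf-8')` instead of the buffer.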

The training data and the pre-trained word vectors can be downloaded here: link: https://pan.baidu.com/s/1gka7SgYIRijSaXgRfYZzwA , password: mmbk

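4 Preprocessing

Preprocessing mainly segments the training text into words, for two reasons: the word vectors are trained on segmented text, and the model takes its input as word vectors.

In addition, punctuation is removed and stop words are filtered out using ./data/stopwords.txt.

All preprocessing code is in loader.py.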
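The three preprocessing steps (punctuation removal, word segmentation, stop-word filtering) can be sketched as follows. This is not the code in loader.py: the regular expression, the function name preprocess, and the tokenize parameter are assumptions for illustration. The demo uses a toy per-character tokenizer so that it runs without extra packages.

```python
import re

# Keep only CJK characters, ASCII letters and digits; everything else
# (punctuation, symbols, whitespace) acts as a separator.
NON_WORD = re.compile(r'[^\u4e00-\u9fa5A-Za-z0-9]+')

def preprocess(text, stopwords, tokenize):
    """Strip punctuation, tokenize each remaining fragment, drop stop words."""
    tokens = []
    for fragment in NON_WORD.split(text):
        if fragment:
            tokens.extend(tokenize(fragment))
    return [t for t in tokens if t.strip() and t not in stopwords]

# Toy tokenizer for the demo: one token per character.
stopwords = {'的', '了'}
print(preprocess('今天的天气,真好!', stopwords, tokenize=list))
```

With jieba installed, passing `tokenize=jieba.lcut` gives word-level tokens, and the stop-word set can be built from the lines of ./data/stopwords.txt.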

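5 How to run

python train_word2vec.py: segments the training data and trains word vectors with Word2vec (vector_word.txt)

python text_train.py: trains the model

python text_test.py: evaluates the model on the test set

python text_predict.py: runs predictions with the trained model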

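6 Training results

Run: python text_train.py

Training stopped after 2 epochs, once the termination condition was met. The best validation accuracy, 96.8%, was reached at global_step=1500.

image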
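The exact termination condition is not stated here. A common choice is to stop when validation accuracy has not improved for a fixed number of steps; the sketch below assumes that rule. The function name, eval_every, patience_steps, and the accuracy values are all made up for illustration and are not results from this experiment.

```python
def train_with_early_stopping(val_accuracies, eval_every=100, patience_steps=1000):
    """Return (best_step, best_acc, stop_step) for a sequence of validation accuracies.

    val_accuracies[i] is the accuracy measured at global step (i + 1) * eval_every.
    """
    best_acc, best_step = 0.0, 0
    for i, acc in enumerate(val_accuracies):
        step = (i + 1) * eval_every
        if acc > best_acc:
            best_acc, best_step = acc, step    # a real loop would save a checkpoint here
        elif step - best_step >= patience_steps:
            return best_step, best_acc, step   # no improvement for too long: stop
    return best_step, best_acc, len(val_accuracies) * eval_every

# Synthetic accuracies, evaluated every 100 steps.
synthetic = [0.5, 0.7, 0.9, 0.85, 0.88, 0.89]
print(train_with_early_stopping(synthetic, eval_every=100, patience_steps=300))
```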

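7 Test results

Run: python text_test.py

On the test set, test_loss=0.14 and test_accuracy=95.8%. The result for the "体育" (sports) class is 100%, and overall precision=recall=F1=96%.
For comparison, the CNN model's test results are test_loss=0.13, test_accuracy=96.7%, precision=recall=F1=97%.

image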
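For readers unfamiliar with the reported metrics, the sketch below computes per-class precision, recall, and F1 from true and predicted labels in plain Python. It is not the code in text_test.py, the function name per_class_metrics is hypothetical, and the four labels are toy data. How the overall figures above are averaged across classes is not stated, so the sketch stops at per-class values.

```python
from collections import Counter

def per_class_metrics(y_true, y_pred):
    """Return {label: (precision, recall, f1)} for every label seen."""
    tp, fp, fn = Counter(), Counter(), Counter()
    for t, p in zip(y_true, y_pred):
        if t == p:
            tp[t] += 1
        else:
            fp[p] += 1   # predicted p, but it was wrong
            fn[t] += 1   # a true t that was missed
    metrics = {}
    for label in sorted(set(y_true) | set(y_pred)):
        predicted = tp[label] + fp[label]
        actual = tp[label] + fn[label]
        precision = tp[label] / predicted if predicted else 0.0
        recall = tp[label] / actual if actual else 0.0
        f1 = 2 * precision * recall / (precision + recall) if precision + recall else 0.0
        metrics[label] = (precision, recall, f1)
    return metrics

# Toy labels, not experiment data.
y_true = ['体育', '体育', '财经', '财经']
y_pred = ['体育', '体育', '财经', '体育']
for label, (p, r, f1) in per_class_metrics(y_true, y_pred).items():
    print(label, round(p, 3), round(r, 3), round(f1, 3))
```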

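8 Prediction results

Run: python text_predict.py

Five samples were picked at random from the test data. The output shows each sample's text, its true label, and its predicted label. All five predictions in the figure below are correct.

image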

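9 Comparison and conclusion

Compared with the CNN model, the validation accuracy during training, 96.8%, is slightly better than the CNN's, but on the test set the model does not perform as well as the CNN. One reason, I suspect, is that the CNN processes texts of length 600 while RNN+attention processes texts of length 200, and the latter cannot handle very long texts. The longer the text, the more feature information it contains, so overall I personally think the CNN model is better suited to classifying long texts.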
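To make the length difference concrete, the sketch below forces a token-id sequence to a fixed length. The function name pad_or_truncate, the padding id 0, and the choice to keep the head of the text and pad at the end are assumptions; the repository may pad or truncate differently.

```python
def pad_or_truncate(token_ids, max_len, pad_id=0):
    """Force a list of token ids to exactly max_len entries."""
    clipped = token_ids[:max_len]                          # drop everything past max_len
    return clipped + [pad_id] * (max_len - len(clipped))   # pad short inputs at the end

doc = list(range(1, 451))   # a made-up document of 450 tokens
for max_len in (200, 600):
    fixed = pad_or_truncate(doc, max_len)
    kept = sum(1 for t in fixed if t != 0)
    print(max_len, len(fixed), kept)
```

With a limit of 200, more than half of this 450-token document is discarded, while a limit of 600 keeps all of it.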

10 References

  1. Convolutional Neural Networks for Sentence Classification
  2. gaussic/text-classification-cnn-rnn
  3. YCG09/tf-text-classification

Contributors

cjymz886
