Code Monkey home page Code Monkey logo

tfbert's Introduction

tfbert

  • 基于tensorflow 1.x 的bert系列预训练模型工具
  • 支持多GPU训练,支持梯度累积,支持pb模型导出,自动剔除adam参数
  • 采用dataset 和 string handle配合,可以灵活训练、验证、测试,在训练阶段也可以使用验证集测试模型,并根据验证结果保存参数。

说明

config、tokenizer参考的transformers的实现。

内置有自定义的Trainer,像pytorch一样使用tensorflow1.14,具体使用下边会介绍。

目前内置 文本分类文本多标签分类命名实体识别例子。

内置的几个例子的数据处理代码都支持多进程处理,实现方式参考的transformers。

内置代码示例数据集百度网盘提取码:rhxk

支持模型

bert、electra、albert、nezha、wobert、ChineseBert(GlyceBert)

requirements

tensorflow==1.x
tqdm
jieba

目前本项目都是在tensorflow 1.x下实现并测试的,最好使用1.14及以上版本,因为内部tf导包都是用的

import tensorflow.compat.v1 as tf

使用说明

Config 和 Tokenizer

使用方法和transformers一样

from tfbert import BertTokenizer, BertConfig

config = BertConfig.from_pretrained('config_path')
tokenizer = BertTokenizer.from_pretrained('vocab_path', do_lower_case=True)

inputs = tokenizer.encode_plus(
'测试样例', text_pair=None, max_length=128, padding="max_length", add_special_tokens=True)

config.save_pretrained("save_path")
tokenizer.save_pretrained("save_path")

多卡运行方式,需要设置环境变量CUDA_VISIBLE_DEVICES,内置trainer会读取参数:

CUDA_VISIBLE_DEVICES=1,2 python run.py

详情查看代码样例

XLA和混合精度训练训练速度测试

使用哈工大的rbt3权重进行实验对比,数据为example中的文本分类数据集。 开启xla和混合精度后刚开始训练需要等待一段时间优化,所以第一轮会比较慢, 等开启后训练速度会加快很多。最大输入长度32,批次大小32,训练3个epoch, 测试环境为tensorflow1.14,GPU是2080ti。

use_xla mixed_precision first epoch (s/epoch) second epoch (s/epoch) eval accuracy
False False 76 61 0.9570
True False 73 42 0.9584
True True 85 37 0.9582

开启混合精度比较慢,base版本模型的话需要一两分钟,但是开启后越到后边越快,训练步数少的话可以只开启xla就行了,如果多的话 最好xla和混合精度(混合精度前提是你的卡支持fp16)都打开。

可加载中文权重链接

模型简称 下载链接
BERT wwm 系列 Chinese-BERT-wwm
BERT-base, ChineseGoogle Google Cloud
ALBERT-base, ChineseGoogle google-research/albert
MacBERT, Chinese MacBERT
ELECTRA, Chinese Chinese-ELECTRA
ERNIE 1.0.1, Chinese 百度网盘(xrku)
ERNIE gram base, Chinese 百度网盘(7xet)
ChineseBert, Chinese base(sxhj) large(zi0r)

更新记录

  • 2021/7/31 内置模型新增香侬科技开源的ChineseBert,见glyce_bert,目前官方只有torch版本。 模型增加了字形和拼音特征作为embedding表示,获得了和mac bert接近的效果,官方见ChineseBert 。tf权重已经转好,可自行下载。

  • 2021/5/19 增加机器阅读理解示例代码,以dureader2021比赛数据为例,应该兼容大部分squad格式的数据。 同时更新tokenizer代码,贴合transformers使用接口,大部分直接整合的transformers的tokenizer

  • 2021/5/9 增加fgm,pgd,freelb接口,代码见adversarial.py, 使用方式,在trainer的build_model中传入adversarial_type即可,这两天没GPU和相应数据集,所以功能还没测试。

  • 2021/4/18 花了一天时间重整Trainer,新增一个Dataset类。由于更新有点多,还没来得及写太多注释,敬请见谅。具体更新:

    1. trainer封装了train、evaluate、predict方法,具体见新版的使用例子。
    2. 写了一个Dataset类,支持简单的数据包装,也可以直接导出tf的dataset类型,具体dataset.py.
    3. 去除了原版需要自定义shapes和types的方式(原有data代码还没删),都可以通过新增的Dataset类下的方法直接自行获取。
  • 2021/4/17 新增SimpleTrainer,采用feed dict的方式进行调用,操作简单,但是相比Trainer的dataset方式要慢好多, 随便写了个例子simple_trainer.py,以后有时间再完善

  • tf.layers.dropout 需要将training设置为None才会根据tf.keras.backend.learning_phase()进行mode判定。 之前默认的training为False,dropout都没起作用,非常抱歉。

  • 增加resize_word_embeddings方法,可对已保存权重文件的embedding部分就行词表大小修改。 具体见resize_word_embeddings方法

  • 对抗训练暂不可用...代码实现错误

  • 2021年2月22日 增加FGM对抗训练方式,可以在trainer.build_model()时设置use_fgm为True, 即可开启fgm对抗训练,目前未测试效果。

  • 2021年2月8日 毕业论文写完了,花了点时间进行大更新,此次更新对原有代码重组,进一步提升训练速度。使用NVIDIA的方法修改梯度累积方式,去除backward、 zero_grad方法,统一使用train_step训练。梯度累积只需要在配置优化节点时传入梯度累积步数即可。 最后,代码增加xla加速和混合精度训练,混合精度目前只支持部分gpu,支持情况自行百度。 最后详情请自行看使用例子对比。

  • 2020年11月14日 增加xla加速模块,可以在trainer设定use_xla传参决定是否开启,开启后可以加速训练。backward、zero_grad、train_step模式增加开启关闭操作, 可以在trainer设定use_torch_mode决定是否取消该模式,取消后不支持梯度累积,直接调用train_step进行训练, 这样会加快训练速度。

  • 2020年9月23日 增加梯度累积,采用trainer.backward(), trainer.zero_grad(), trainer.train_step() 一同进行训练,参考pytorch训练方式。

  • 2020年9月21日 第一次上传,支持模型bert、albert、electra、nezha、wobert。

Reference

  1. Transformers: State-of-the-art Natural Language Processing for TensorFlow 2.0 and PyTorch.
  2. TensorFlow code and pre-trained models for BERT
  3. ALBERT: A Lite BERT for Self-supervised Learning of Language Representations
  4. NEZHA-TensorFlow
  5. ELECTRA: Pre-training Text Encoders as Discriminators Rather Than Generators
  6. 基于词颗粒度的中文WoBERT
  7. NVIDIA/BERT模型使用方案

tfbert's People

Contributors

huiresearch avatar

Stargazers

 avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar

Watchers

 avatar  avatar  avatar

tfbert's Issues

UNILM实现

您好!,很赞的项目,请问后续有实现UNILM的打算吗

I got those error when I was run the python file of run_element_extract.py.

/opt/conda/lib/python3.7/site-packages/numpy/core/_asarray.py:83: VisibleDeprecationWarning: Creating an ndarray from ragged nested sequences (which is a list-or-tuple of lists-or-tuples-or ndarrays with different lengths or shapes) is deprecated. If you meant to do this, you must specify 'dtype=object' when creating the ndarray
return array(a, dtype, copy=False, order=order)
Traceback (most recent call last):
File "run_element_extract.py", line 277, in
main()
File "run_element_extract.py", line 234, in main
trainer.build_model(model_fn=get_model_fn(config, args))
File "/opt/tfbert/tfbert/trainer.py", line 621, in build_model
model_output = model_fn(inputs, True)
File "run_element_extract.py", line 126, in model_fn
**inputs
File "/opt/tfbert/tfbert/models/for_task.py", line 173, in init
compute_type=compute_type
File "/opt/tfbert/tfbert/models/bert.py", line 152, in init
input_shape = model_utils.get_shape_list(input_ids, expected_rank=2)
File "/opt/tfbert/tfbert/models/model_utils.py", line 196, in get_shape_list
assert_rank(tensor, expected_rank, name)
File "/opt/tfbert/tfbert/models/model_utils.py", line 241, in assert_rank
(name, scope_name, actual_rank, str(tensor.shape), str(expected_rank)))
ValueError: For the tensor IteratorGetNext:1 in scope ``, the actual rank 1 (shape = (?,)) is not equal to the expected rank `

The problem happened when I use chinese_bert_chinese_wwm_L-12_H-768_A-12 model with run_elment_extract.py script.

WARNING:tensorflow:From /opt/tfbert/tfbert/models/layers.py:28: dense (from tensorflow.python.layers.core) is deprecated and will be removed in a future version.
Instructions for updating:
Use keras.layers.Dense instead.
Traceback (most recent call last):
File "/opt/tfbert/run_element_extract.py", line 292, in
main()
File "/opt/tfbert/run_element_extract.py", line 253, in main
args.model_dir if args.pretrained_checkpoint_path is None else args.pretrained_checkpoint_path)
File "/opt/tfbert/tfbert/trainer.py", line 175, in from_pretrained
utils.init_checkpoints(ckpt, True)
File "/opt/tfbert/tfbert/utils.py", line 261, in init_checkpoints
prefix=prefix)
File "/opt/tfbert/tfbert/utils.py", line 239, in get_assignment_map_from_checkpoint
init_vars = tf.train.list_variables(init_checkpoint)
File "/opttensorflow_core/python/training/checkpoint_utils.py", line 97, in list_variables
reader = load_checkpoint(ckpt_dir_or_file)
File "/opttensorflow_core/python/training/checkpoint_utils.py", line 66, in load_checkpoint
return pywrap_tensorflow.NewCheckpointReader(filename)
File "/opttensorflow_core/python/pywrap_tensorflow_internal.py", line 873, in NewCheckpointReader
return CheckpointReader(compat.as_bytes(filepattern))
File "/opttensorflow_core/python/pywrap_tensorflow_internal.py", line 885, in init
this = _pywrap_tensorflow_internal.new_CheckpointReader(filename)
tensorflow.python.framework.errors_impl.DataLossError: Unable to open table file /opt/models/bert/chinese_bert_chinese_wwm_L-12_H-768_A-12/publish/bert_model.ckpt.data-00000-of-00001: Data loss: not an sstable (bad magic number): perhaps your file is in a different file format and you need to use a different restore operator?

加载模型问题

NER任务加载模型进行测试比直接训练之后的测试降了五个点。
加载完模型之后不进行初始化会报错,请问是否是部分参数未加载导致的。

多卡会报错

使用楼主代码运行,多卡会报错,tf.split 分发数据出错

对抗训练保存模型

hello,有一个疑问,目前的模型对抗训练保存的文件是否还是会前向两次,按理说应该是训练的时候前向两次,验证和测试都只前向一次,不需要对抗

Recommend Projects

  • React photo React

    A declarative, efficient, and flexible JavaScript library for building user interfaces.

  • Vue.js photo Vue.js

    🖖 Vue.js is a progressive, incrementally-adoptable JavaScript framework for building UI on the web.

  • Typescript photo Typescript

    TypeScript is a superset of JavaScript that compiles to clean JavaScript output.

  • TensorFlow photo TensorFlow

    An Open Source Machine Learning Framework for Everyone

  • Django photo Django

    The Web framework for perfectionists with deadlines.

  • D3 photo D3

    Bring data to life with SVG, Canvas and HTML. 📊📈🎉

Recommend Topics

  • javascript

    JavaScript (JS) is a lightweight interpreted programming language with first-class functions.

  • web

    Some thing interesting about web. New door for the world.

  • server

    A server is a program made to process requests and deliver data to clients.

  • Machine learning

    Machine learning is a way of modeling and interpreting data that allows a piece of software to respond intelligently.

  • Game

    Some thing interesting about game, make everyone happy.

Recommend Org

  • Facebook photo Facebook

    We are working to build community through open source technology. NB: members must have two-factor auth.

  • Microsoft photo Microsoft

    Open source projects and samples from Microsoft.

  • Google photo Google

    Google ❤️ Open Source for everyone.

  • D3 photo D3

    Data-Driven Documents codes.