BERT for TensorFlow v2

This repo contains a TensorFlow 2.0 Keras implementation of google-research/bert that can load the original pre-trained weights and produces activations numerically identical to those calculated by the original model.

ALBERT and adapter-BERT are also supported through the corresponding configuration parameters (shared_layer=True and embedding_size for ALBERT, adapter_size for adapter-BERT). Setting both results in an adapter-ALBERT, which shares the BERT parameters across all layers while adapting every layer with its own layer-specific adapter.
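To make the effect of these settings concrete, here is a rough parameter count. It is a sketch under stated assumptions: a standard BERT encoder layer (Q/K/V/output projections with biases, a two-layer feed-forward block, two LayerNorms), bottleneck adapters made of a down- and an up-projection, and two adapters per layer as described in the adapter-BERT paper (arXiv:1902.00751). The helper functions are hypothetical and not part of bert-for-tf2; they count total parameters, not trainable ones.

```python
# Rough parameter counts illustrating ALBERT layer sharing, ALBERT embedding
# factorization and adapter-BERT bottleneck adapters. The layer layout and the
# two-adapters-per-layer default are assumptions, not bert-for-tf2 internals.


def encoder_layer_params(hidden_size: int, intermediate_size: int) -> int:
    h, i = hidden_size, intermediate_size
    attention = 4 * (h * h + h)        # Q, K, V and attention output projections
    ffn = (h * i + i) + (i * h + h)    # intermediate and output dense layers
    layer_norms = 2 * (2 * h)          # gamma and beta for two LayerNorms
    return attention + ffn + layer_norms


def adapter_params(hidden_size: int, adapter_size: int) -> int:
    # One bottleneck adapter: down-projection h -> m, up-projection m -> h.
    h, m = hidden_size, adapter_size
    return (h * m + m) + (m * h + h)


def encoder_params(num_layers, hidden_size, intermediate_size,
                   shared_layer=False, adapter_size=None, adapters_per_layer=2):
    # With shared_layer=True (ALBERT) a single set of layer weights is reused.
    distinct_layers = 1 if shared_layer else num_layers
    total = distinct_layers * encoder_layer_params(hidden_size, intermediate_size)
    if adapter_size:
        # Adapters stay layer-specific even when the layer weights are shared.
        total += num_layers * adapters_per_layer * adapter_params(hidden_size, adapter_size)
    return total


def embedding_params(vocab_size, hidden_size, embedding_size=None):
    # ALBERT factorizes the V x H wordpiece table into V x E plus E x H.
    if embedding_size is None:
        return vocab_size * hidden_size
    return vocab_size * embedding_size + embedding_size * hidden_size


print(encoder_params(12, 768, 3072))                                       # 85054464
print(encoder_params(12, 768, 3072, shared_layer=True))                    # 7087872
print(encoder_params(12, 768, 3072, shared_layer=True, adapter_size=64))   # 9467136
print(embedding_params(30000, 768), embedding_params(30000, 768, 128))     # 23040000 3938304
```

Under these assumptions, sharing one layer across 12 layers cuts the encoder to a twelfth of its size, and adding size-64 adapters to every layer costs only about 2.4M parameters on top of that.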

The implementation is built from scratch using only basic TensorFlow operations, following the code in google-research/bert/modeling.py (but skipping dead code and applying some simplifications). It also uses kpe/params-flow to reduce common Keras boilerplate code (related to passing model and layer configuration arguments).
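One piece of modeling.py worth knowing is its GELU, the activation selected by intermediate_activation = "gelu" in the Usage section: it is computed with a tanh approximation rather than the exact erf form. The NumPy sketch below writes out that formula and compares it with the exact GELU. That bert-for-tf2's own "gelu" uses the identical expression is an assumption based on it following modeling.py.

```python
import math

import numpy as np


def gelu_tanh(x):
    """GELU via the tanh approximation used in google-research/bert modeling.py."""
    x = np.asarray(x, dtype=np.float64)
    inner = np.sqrt(2.0 / np.pi) * (x + 0.044715 * np.power(x, 3))
    return 0.5 * x * (1.0 + np.tanh(inner))


def gelu_exact(x: float) -> float:
    """Exact GELU: x * Phi(x), where Phi is the standard normal CDF."""
    return 0.5 * x * (1.0 + math.erf(x / math.sqrt(2.0)))


for v in (-1.0, 0.0, 1.0, 3.0):
    print(v, float(gelu_tanh(v)), gelu_exact(v))
```

For these inputs the approximation and the exact form differ by less than 1e-3.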

bert-for-tf2 should work with both TensorFlow 2.0 and TensorFlow 1.14 or newer.

NEWS

  • 11.Oct.2019 - support for loading the released pre-trained ALBERT weights for Chinese.
  • 10.Oct.2019 - support for ALBERT through the shared_layer=True and embedding_size=128 params.
  • 03.Sep.2019 - walkthrough on fine-tuning with adapter-BERT and storing the fine-tuned fraction of the weights in a separate checkpoint (see tests/test_adapter_finetune.py).
  • 02.Sep.2019 - support for extending the token type embeddings of a pre-trained model by returning the mismatched weights in load_stock_weights() (see tests/test_extend_segments.py).
  • 25.Jul.2019 - there are now two colab notebooks under examples/ showing how to fine-tune an IMDB Movie Reviews sentiment classifier from pre-trained BERT weights using an adapter-BERT model architecture on a GPU or TPU in Google Colab.
  • 28.Jun.2019 - v.0.3.0 supports adapter-BERT (google-research/adapter-bert) for "Parameter-Efficient Transfer Learning for NLP", i.e. fine-tuning small overlay adapter layers over BERT's transformer encoders without changing the frozen BERT weights.
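The 02.Sep.2019 item is about growing the token-type (segment) embedding table beyond the two segments of the stock checkpoints. The core operation can be sketched in NumPy: keep the pre-trained rows and initialize the new ones. The helper name is hypothetical, and drawing the new rows from a plain normal with stddev 0.02 (BERT's default initializer_range, though BERT itself uses a truncated normal) is an assumption. See tests/test_extend_segments.py for the repository's own approach.

```python
import numpy as np


def extend_token_type_embeddings(pretrained, new_vocab_size, stddev=0.02, seed=0):
    """Grow a [token_type_vocab_size, hidden_size] table to new_vocab_size rows.

    Hypothetical helper: pre-trained rows are copied unchanged, new rows are
    drawn from a normal distribution (an assumed initialization).
    """
    old_vocab_size, hidden_size = pretrained.shape
    if new_vocab_size < old_vocab_size:
        raise ValueError("new_vocab_size must not be smaller than the pre-trained size")
    rng = np.random.default_rng(seed)
    extended = np.empty((new_vocab_size, hidden_size), dtype=pretrained.dtype)
    extended[:old_vocab_size] = pretrained  # keep the pre-trained segment rows
    extended[old_vocab_size:] = rng.normal(
        0.0, stddev, size=(new_vocab_size - old_vocab_size, hidden_size))
    return extended


stock = np.ones((2, 4), dtype=np.float32)  # stand-in for a 2-segment table
extended = extend_token_type_embeddings(stock, 3)
print(extended.shape)  # (3, 4)
```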

LICENSE

MIT. See the LICENSE file.

Install

bert-for-tf2 is on the Python Package Index (PyPI):

pip install bert-for-tf2

Usage

BERT in bert-for-tf2 is implemented as a Keras layer. You can instantiate it like this:

from bert import BertModelLayer

l_bert = BertModelLayer(BertModelLayer.Params(
  vocab_size               = 16000,        # embedding params
  use_token_type           = True,
  use_position_embeddings  = True,
  token_type_vocab_size    = 2,

  num_layers               = 12,           # transformer encoder params
  hidden_size              = 768,
  hidden_dropout           = 0.1,
  intermediate_size        = 4*768,
  intermediate_activation  = "gelu",

  adapter_size             = None,         # see arXiv:1902.00751 (adapter-BERT)

  shared_layer             = False,        # True for ALBERT (arXiv:1909.11942)
  embedding_size           = None,         # None for BERT, wordpiece embedding size for ALBERT

  name                     = "bert"        # any other Keras layer params
))

or create it from the bert_config.json of a pre-trained Google model:

import tensorflow as tf
from tensorflow import keras

from bert import BertModelLayer
from bert import params_from_pretrained_ckpt
from bert import load_stock_weights

model_dir = ".models/uncased_L-12_H-768_A-12"

bert_params = params_from_pretrained_ckpt(model_dir)
l_bert = BertModelLayer.from_params(bert_params, name="bert")
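params_from_pretrained_ckpt() builds the layer parameters from the bert_config.json in the checkpoint directory. To show roughly what such a translation involves, the sketch below maps the stock google-research/bert config keys onto the parameter names used in the BertModelLayer.Params example above. The mapping table and config_to_params() are illustrative assumptions, not the library's implementation, which may handle more keys (for example the number of attention heads).

```python
import json

# Stock bert_config.json key -> parameter name from the Params example above.
# Illustrative assumption; keys missing from this table are ignored.
STOCK_TO_PARAMS = {
    "vocab_size": "vocab_size",
    "type_vocab_size": "token_type_vocab_size",
    "num_hidden_layers": "num_layers",
    "hidden_size": "hidden_size",
    "hidden_dropout_prob": "hidden_dropout",
    "intermediate_size": "intermediate_size",
    "hidden_act": "intermediate_activation",
}


def config_to_params(config_text: str) -> dict:
    """Translate the JSON text of a stock bert_config.json into a params dict."""
    stock = json.loads(config_text)
    return {ours: stock[theirs]
            for theirs, ours in STOCK_TO_PARAMS.items()
            if theirs in stock}


# Example values for a BERT-Base configuration.
EXAMPLE_CONFIG = """{
  "attention_probs_dropout_prob": 0.1,
  "hidden_act": "gelu",
  "hidden_dropout_prob": 0.1,
  "hidden_size": 768,
  "initializer_range": 0.02,
  "intermediate_size": 3072,
  "max_position_embeddings": 512,
  "num_attention_heads": 12,
  "num_hidden_layers": 12,
  "type_vocab_size": 2,
  "vocab_size": 30522
}"""

print(config_to_params(EXAMPLE_CONFIG))
```

With a real checkpoint you would read the text of os.path.join(model_dir, "bert_config.json") instead of the inline string.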

Now you can use the BERT layer in your Keras model like this:

from tensorflow import keras

max_seq_len = 128
l_input_ids      = keras.layers.Input(shape=(max_seq_len,), dtype='int32')
l_token_type_ids = keras.layers.Input(shape=(max_seq_len,), dtype='int32')

# using the default token_type/segment id 0
output = l_bert(l_input_ids)                              # output: [batch_size, max_seq_len, hidden_size]
model = keras.Model(inputs=l_input_ids, outputs=output)
model.build(input_shape=(None, max_seq_len))

# provide a custom token_type/segment id as a layer input
output = l_bert([l_input_ids, l_token_type_ids])          # [batch_size, max_seq_len, hidden_size]
model = keras.Model(inputs=[l_input_ids, l_token_type_ids], outputs=output)
model.build(input_shape=[(None, max_seq_len), (None, max_seq_len)])
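The second model above takes token_type_ids alongside input_ids. Following BERT's sentence-pair convention, the [CLS] token, the first segment and its [SEP] get segment id 0, and the second segment with its closing [SEP] gets 1. The NumPy sketch below packs two already-tokenized wordpiece id lists into arrays of shape (1, max_seq_len). The helper is hypothetical; the padding id 0, the [CLS]/[SEP] ids 101/102 (as in the uncased BERT vocabulary) and the example wordpiece ids are assumptions.

```python
import numpy as np


def make_pair_inputs(ids_a, ids_b, max_seq_len, cls_id, sep_id, pad_id=0):
    """Pack two wordpiece-id sequences as [CLS] A [SEP] B [SEP], then pad.

    Returns (input_ids, token_type_ids), each of shape (1, max_seq_len).
    Padding positions get token id pad_id and segment id 0.
    """
    input_ids = [cls_id] + list(ids_a) + [sep_id] + list(ids_b) + [sep_id]
    token_type_ids = [0] * (len(ids_a) + 2) + [1] * (len(ids_b) + 1)
    if len(input_ids) > max_seq_len:
        raise ValueError("sequence pair does not fit into max_seq_len")
    padding = max_seq_len - len(input_ids)
    input_ids += [pad_id] * padding
    token_type_ids += [0] * padding
    return (np.array([input_ids], dtype=np.int32),
            np.array([token_type_ids], dtype=np.int32))


input_ids, token_type_ids = make_pair_inputs([7592], [2088], max_seq_len=8,
                                             cls_id=101, sep_id=102)
print(input_ids)       # [[ 101 7592  102 2088  102    0    0    0]]
print(token_type_ids)  # [[0 0 0 1 1 0 0 0]]
```

With max_seq_len = 128 to match the Input layers, the two arrays would be fed to the second model as [input_ids, token_type_ids].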

If you choose adapter-BERT by setting the adapter_size parameter, you will probably also want to freeze all the original BERT layers by calling:

l_bert.apply_adapter_freeze()

Once the model has been built or compiled, the original pre-trained weights can be loaded into the BERT layer:

import os

from bert import load_stock_weights

# model_dir and l_bert as defined above
bert_ckpt_file   = os.path.join(model_dir, "bert_model.ckpt")
load_stock_weights(l_bert, bert_ckpt_file)

N.B. see tests/test_bert_activations.py for a complete example.
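The apply_adapter_freeze() call shown earlier freezes the original BERT layers. In the adapter-BERT paper (arXiv:1902.00751) the adapters and the layer-normalization parameters are trained while the original BERT weights stay frozen, and the same split determines the small fraction of weights worth storing in a separate checkpoint. The sketch below performs that split on variable names; the name patterns and example names are hypothetical and may not match bert-for-tf2's actual variable names, and this is not how apply_adapter_freeze() is implemented.

```python
import re

# Hypothetical name patterns: adapter weights and LayerNorm parameters stay trainable.
TRAINABLE_PATTERNS = (
    re.compile(r"adapter", re.IGNORECASE),
    re.compile(r"layer_?norm", re.IGNORECASE),
)


def split_adapter_weights(weight_names):
    """Return (trainable, frozen) name lists under the adapter-BERT split."""
    trainable, frozen = [], []
    for name in weight_names:
        if any(p.search(name) for p in TRAINABLE_PATTERNS):
            trainable.append(name)
        else:
            frozen.append(name)
    return trainable, frozen


# Hypothetical variable names, for illustration only.
names = [
    "bert/embeddings/word_embeddings/embeddings:0",
    "bert/encoder/layer_0/attention/self/query/kernel:0",
    "bert/encoder/layer_0/attention/output/adapter-down/kernel:0",
    "bert/encoder/layer_0/attention/output/LayerNorm/gamma:0",
]
trainable, frozen = split_adapter_weights(names)
print(trainable)
print(frozen)
```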

Resources

  • BERT - BERT: Pre-training of Deep Bidirectional Transformers for Language Understanding
  • adapter-BERT - adapter-BERT: Parameter-Efficient Transfer Learning for NLP
  • ALBERT - ALBERT: A Lite BERT for Self-Supervised Learning of Language Representations
  • google-research/bert - the original BERT implementation
  • kpe/params-flow - A Keras coding style for reducing Keras boilerplate code in custom layers by utilizing kpe/py-params
