TF-seq2seq

Sequence to sequence (seq2seq) learning using TensorFlow.

The core building blocks are RNN encoder-decoder architectures and attention mechanisms.

The package was largely implemented using the latest (1.2) tf.contrib.seq2seq modules:

  • AttentionWrapper
  • Decoder
  • BasicDecoder
  • BeamSearchDecoder

The package supports:

  • Multi-layer GRU/LSTM
  • Residual connections
  • Dropout
  • Attention and input feeding
  • Beamsearch decoding
  • Writing n-best lists

Dependencies

  • NumPy >= 1.11.1
  • TensorFlow >= 1.2

History

  • June 5, 2017: Major update
  • June 6, 2017: Supports batch beamsearch decoding
  • June 11, 2017: Separated training / decoding
  • June 22, 2017: Supports TensorFlow 1.2 (contrib.rnn -> python.ops.rnn_cell)

Usage Instructions

Data Preparation

To preprocess the raw parallel data in sample_data.src and sample_data.trg, run:

cd data/
./preprocess.sh src trg sample_data ${max_seq_len}

Running the above script performs the following preprocessing steps, which are widely used in machine translation (MT):

  • Normalizing punctuation
  • Tokenizing
  • Byte-pair encoding with 30,000 merge operations (Sennrich et al., 2016)
  • Removing sequences longer than ${max_seq_len}
  • Shuffling
  • Building dictionaries
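The byte-pair encoding step is handled by the subword-nmt scripts listed in the acknowledgements. As a rough illustration of what the 30,000 "merges" are, here is a minimal sketch of the merge-learning loop described by Sennrich et al. (2016), run on a toy word-frequency table. It is not the subword-nmt code, and the helper names (`get_stats`, `merge_vocab`, `learn_bpe`) are illustrative only.

```python
import collections
import re
from typing import Dict, List, Tuple

Pair = Tuple[str, str]


def get_stats(vocab: Dict[str, int]) -> collections.Counter:
    """Count adjacent symbol pairs, weighted by word frequency."""
    pairs: collections.Counter = collections.Counter()
    for word, freq in vocab.items():
        symbols = word.split()
        for left, right in zip(symbols, symbols[1:]):
            pairs[left, right] += freq
    return pairs


def merge_vocab(pair: Pair, vocab: Dict[str, int]) -> Dict[str, int]:
    """Replace every standalone occurrence of `pair` with the merged symbol."""
    bigram = re.escape(" ".join(pair))
    # Only match whole symbols: no non-space character directly before or after.
    pattern = re.compile(r"(?<!\S)" + bigram + r"(?!\S)")
    return {pattern.sub("".join(pair), word): freq for word, freq in vocab.items()}


def learn_bpe(vocab: Dict[str, int], num_merges: int) -> Tuple[List[Pair], Dict[str, int]]:
    """Greedily learn up to `num_merges` merge operations."""
    merges: List[Pair] = []
    for _ in range(num_merges):
        pairs = get_stats(vocab)
        if not pairs:
            break  # every word is already a single symbol
        best = max(pairs, key=pairs.get)  # ties go to the pair counted first
        vocab = merge_vocab(best, vocab)
        merges.append(best)
    return merges, vocab


# Toy word counts: each word is space-separated characters plus an end-of-word marker.
toy_vocab = {"l o w </w>": 5, "l o w e r </w>": 2, "n e w e s t </w>": 6, "w i d e s t </w>": 3}
merges, merged = learn_bpe(toy_vocab, num_merges=3)
```

Each learned merge becomes one entry in an ordered merge list; applying that list to new text splits rare words into subword units seen during training.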

Training

To train a seq2seq model, run:

$ python train.py   --cell_type 'lstm' \
                    --attention_type 'luong' \
                    --hidden_units 1024 \
                    --depth 2 \
                    --embedding_size 500 \
                    --num_encoder_symbols 30000 \
                    --num_decoder_symbols 30000 ...

Decoding

To decode with a trained model, run:

$ python decode.py  --beam_width 5 \
                    --decode_batch_size 30 \
                    --model_path $PATH_TO_A_MODEL_CHECKPOINT \
                    --max_decode_step 300 \
                    --write_n_best False \
                    --decode_input $PATH_TO_DECODE_INPUT \
                    --decode_output $PATH_TO_DECODE_OUTPUT

Here $PATH_TO_A_MODEL_CHECKPOINT is a checkpoint path such as model/translate.ckpt-100.

If --beam_width=1, greedy decoding is performed, i.e., the single most probable token is chosen at each time step.
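The difference between greedy and beam search decoding can be illustrated with a toy model. The sketch below is a plain-Python beam search over a hypothetical `step_fn` that returns next-token log-probabilities for a given prefix. It is not the package's `BeamSearchDecoder` code path, applies no length normalization, and assumes a `</s>` end-of-sequence token. With `beam_width=1` it reduces to greedy decoding, and the returned list plays the role of an n-best list (n = beam width).

```python
import math
from typing import Callable, Dict, List, Tuple

EOS = "</s>"  # hypothetical end-of-sequence token
Hypothesis = Tuple[float, List[str]]  # (cumulative log-probability, tokens)
StepFn = Callable[[List[str]], Dict[str, float]]


def beam_search(step_fn: StepFn, beam_width: int, max_decode_step: int) -> List[Hypothesis]:
    """Return up to `beam_width` hypotheses, best first (an n-best list)."""
    beams: List[Hypothesis] = [(0.0, [])]
    finished: List[Hypothesis] = []
    for _ in range(max_decode_step):
        # Expand every live hypothesis by every possible next token.
        candidates = [
            (score + logp, tokens + [token])
            for score, tokens in beams
            for token, logp in step_fn(tokens).items()
        ]
        # Keep only the `beam_width` best expansions (the sort is stable for ties).
        candidates.sort(key=lambda hyp: hyp[0], reverse=True)
        beams = []
        for score, tokens in candidates[:beam_width]:
            if tokens[-1] == EOS:
                finished.append((score, tokens))
            else:
                beams.append((score, tokens))
        if not beams:
            break
    finished.extend(beams)  # hypotheses cut off by max_decode_step
    finished.sort(key=lambda hyp: hyp[0], reverse=True)
    return finished[:beam_width]


def toy_step(prefix: List[str]) -> Dict[str, float]:
    """A fixed toy model: next-token log-probabilities given the prefix."""
    table = {
        (): {"a": 0.6, "b": 0.4},
        ("a",): {"x": 0.5, "y": 0.5},
        ("b",): {"x": 0.9, "y": 0.1},
    }
    probs = table.get(tuple(prefix), {EOS: 1.0})
    return {token: math.log(p) for token, p in probs.items()}


greedy = beam_search(toy_step, beam_width=1, max_decode_step=10)
beam = beam_search(toy_step, beam_width=2, max_decode_step=10)
```

Greedy decoding commits to "a" (0.6) and ends with sequence probability 0.6 × 0.5 = 0.3, while a beam of width 2 keeps "b" alive and finds "b x" with probability 0.4 × 0.9 = 0.36.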

Arguments

Data params

  • --source_vocabulary : Path to source vocabulary
  • --target_vocabulary : Path to target vocabulary
  • --source_train_data : Path to source training data
  • --target_train_data : Path to target training data
  • --source_valid_data : Path to source validation data
  • --target_valid_data : Path to target validation data
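The vocabulary files, together with the --num_encoder_symbols and --num_decoder_symbols sizes under Network params, determine how tokens become integer ids. The sketch below shows one common scheme under stated assumptions: ids are assigned by descending frequency after a few reserved tokens, and any id at or beyond the configured vocabulary size is replaced by an unknown-token id. The reserved token names, the function names, and the in-memory format are hypothetical and not taken from this package.

```python
from collections import Counter
from typing import Dict, Iterable, List

# Hypothetical reserved tokens; the package's actual special tokens may differ.
RESERVED = ["_GO", "_EOS", "_UNK"]


def build_vocab(sentences: Iterable[str]) -> Dict[str, int]:
    """Map tokens to ids: reserved tokens first, then by descending frequency."""
    counts = Counter(tok for sentence in sentences for tok in sentence.split())
    vocab = {tok: idx for idx, tok in enumerate(RESERVED)}
    # Break frequency ties alphabetically so ids are deterministic.
    for tok, _ in sorted(counts.items(), key=lambda kv: (-kv[1], kv[0])):
        if tok not in vocab:
            vocab[tok] = len(vocab)
    return vocab


def encode(sentence: str, vocab: Dict[str, int], num_symbols: int) -> List[int]:
    """Convert tokens to ids, sending unseen or out-of-range tokens to _UNK."""
    unk = vocab["_UNK"]
    ids = []
    for tok in sentence.split():
        idx = vocab.get(tok, unk)
        ids.append(idx if idx < num_symbols else unk)
    return ids


vocab = build_vocab(["the cat sat", "the dog sat", "the cat ran"])
```

Capping the vocabulary this way keeps the embedding and output layers at a fixed size while still covering the most frequent tokens.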

Network params

  • --cell_type : RNN cell to use for encoder and decoder (gru, lstm) (default: lstm)
  • --attention_type : Attention mechanism (bahdanau, luong), (default: bahdanau)
  • --hidden_units : Number of hidden units in each layer
  • --depth : Number of layers in each of the encoder and decoder (default: 2)
  • --embedding_size : Embedding dimensions of encoder and decoder inputs (default: 500)
  • --num_encoder_symbols : Source vocabulary size to use (default: 30000)
  • --num_decoder_symbols : Target vocabulary size to use (default: 30000)
  • --use_residual : Use residual connection between layers (default: True)
  • --attn_input_feeding : Use input feeding method in attentional decoder (Luong et al., 2015) (default: True)
  • --use_dropout : Use dropout on RNN cell outputs (default: True)
  • --dropout_rate : Dropout probability for cell outputs (0.0: no dropout) (default: 0.3)
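--attention_type switches between two classic scoring functions. The sketch below writes them out in plain Python on toy vectors: Luong-style multiplicative scoring h_t^T W_a h_s and Bahdanau-style additive scoring v^T tanh(W_mem h_s + W_query h_t), followed by the shared softmax-and-weighted-sum step that produces a context vector. The weight matrices are hand-picked stand-ins for learned parameters and the helper names are illustrative; in TF 1.2 the corresponding classes are tf.contrib.seq2seq.LuongAttention and tf.contrib.seq2seq.BahdanauAttention, which learn these projections. Input feeding (--attn_input_feeding) is a separate decoder-side step and is not shown.

```python
import math
from typing import List, Tuple

Vector = List[float]
Matrix = List[List[float]]


def matvec(m: Matrix, v: Vector) -> Vector:
    return [sum(w * x for w, x in zip(row, v)) for row in m]


def dot(a: Vector, b: Vector) -> float:
    return sum(x * y for x, y in zip(a, b))


def softmax(scores: Vector) -> Vector:
    shift = max(scores)  # subtract the max for numerical stability
    exps = [math.exp(s - shift) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]


def luong_scores(query: Vector, memory: List[Vector], w_a: Matrix) -> Vector:
    """Multiplicative score: h_t^T W_a h_s for each encoder state h_s."""
    return [dot(query, matvec(w_a, h_s)) for h_s in memory]


def bahdanau_scores(query: Vector, memory: List[Vector],
                    w_mem: Matrix, w_query: Matrix, v: Vector) -> Vector:
    """Additive score: v^T tanh(W_mem h_s + W_query h_t) for each h_s."""
    projected_query = matvec(w_query, query)
    return [
        dot(v, [math.tanh(a + b) for a, b in zip(matvec(w_mem, h_s), projected_query)])
        for h_s in memory
    ]


def attend(scores: Vector, memory: List[Vector]) -> Tuple[Vector, Vector]:
    """Normalize scores into weights and return (weights, context vector)."""
    weights = softmax(scores)
    context = [sum(w * h_s[i] for w, h_s in zip(weights, memory))
               for i in range(len(memory[0]))]
    return weights, context


memory = [[1.0, 0.0], [0.0, 1.0]]  # two toy encoder states
query = [2.0, 0.0]                 # toy decoder state
identity = [[1.0, 0.0], [0.0, 1.0]]
luong_weights, luong_context = attend(luong_scores(query, memory, identity), memory)
bahdanau_weights, _ = attend(bahdanau_scores(query, memory, identity, identity, [1.0, 1.0]), memory)
```

Both variants share the normalization and context step; they differ only in how a decoder state is compared with each encoder state.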

Training params

  • --learning_rate : Learning rate (default: 0.0002)
  • --max_gradient_norm : Clip gradients to this norm (default: 1.0)
  • --batch_size : Batch size
  • --max_epochs : Maximum training epochs
  • --max_load_batches : Maximum number of batches to prefetch at one time.
  • --max_seq_length : Maximum sequence length
  • --display_freq : Display training status every this many iterations
  • --save_freq : Save a model checkpoint every this many iterations
  • --valid_freq : Evaluate the model every this many iterations (requires --source_valid_data and --target_valid_data)
  • --optimizer : Optimizer for training: (adadelta, adam, rmsprop) (default: adam)
  • --model_dir : Path to save model checkpoints
  • --model_name : File name used for model checkpoints
  • --shuffle_each_epoch : Shuffle training dataset for each epoch (default: True)
  • --sort_by_length : Sort pre-fetched minibatches by their target sequence lengths (default: True)
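The --batch_size, --max_load_batches, --max_seq_length, --shuffle_each_epoch and --sort_by_length options can be read together as one data-loading loop. The sketch below is one plausible interpretation, not the package's actual iterator: it assumes over-long pairs are skipped, prefetches batch_size × max_load_batches examples at a time, sorts that chunk by target length, and slices it into minibatches. The function name and signature are hypothetical.

```python
import random
from typing import Iterator, List, Optional, Sequence, Tuple

Pair = Tuple[List[int], List[int]]  # (source token ids, target token ids)


def iterate_minibatches(pairs: Sequence[Pair], batch_size: int, max_load_batches: int,
                        max_seq_length: int, shuffle_each_epoch: bool = True,
                        sort_by_length: bool = True,
                        seed: Optional[int] = None) -> Iterator[List[Pair]]:
    """Yield one epoch of minibatches from a list of parallel sequence pairs."""
    # Assumption: drop pairs where either side exceeds the length limit.
    data = [p for p in pairs if len(p[0]) <= max_seq_length and len(p[1]) <= max_seq_length]
    if shuffle_each_epoch:
        random.Random(seed).shuffle(data)
    # Prefetch up to `max_load_batches` batches' worth of examples at a time.
    chunk_size = batch_size * max_load_batches
    for start in range(0, len(data), chunk_size):
        chunk = data[start:start + chunk_size]
        if sort_by_length:
            # Pairs with similar target lengths land in the same batch, reducing padding.
            chunk.sort(key=lambda pair: len(pair[1]))
        for i in range(0, len(chunk), batch_size):
            yield chunk[i:i + batch_size]


# Toy data: target lengths 5, 1, 4, 2, 3, 6 (the length-6 pair exceeds the limit).
toy_pairs = [([7], [7] * n) for n in [5, 1, 4, 2, 3, 6]]
batches = list(iterate_minibatches(toy_pairs, batch_size=2, max_load_batches=2,
                                   max_seq_length=5, shuffle_each_epoch=False))
```

Sorting within a prefetched chunk, rather than over the whole corpus, keeps batch lengths similar while shuffling still varies batch composition from epoch to epoch.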

Decoding params

  • --beam_width : Beam width used in beamsearch (default: 1)
  • --decode_batch_size : Batch size used in decoding
  • --max_decode_step : Maximum time step limit in decoding (default: 500)
  • --write_n_best : Write beamsearch n-best list (n=beam_width) (default: False)
  • --decode_input : Input file path to decode
  • --decode_output : Output file path of decoding output

Runtime params

  • --allow_soft_placement : Allow device soft placement
  • --log_device_placement : Log placement of ops on devices

Acknowledgements

The implementation is based on the following projects:

  • nematus: Theano implementation of neural machine translation; the major reference for this project
  • subword-nmt: Included subword-unit scripts to preprocess input data
  • moses: Included preprocessing scripts to preprocess input data
  • tf.seq2seq_legacy: Legacy TensorFlow seq2seq tutorial
  • tf_tutorial_plus: Nice tutorials for tf.contrib.seq2seq API

For any comments or feedback, please email me at [email protected] or open an issue here.

Contributors

  • jayparks