
text-style-transfer-gan's Introduction

text-style-transfer-gan

Style transfer in text using cycle-consistent WGANs

Architecture

Requirements

Usage

Instructions for running

For training on the YAFC dataset:

  1. Pretrain a language model (LM) for each of the formal and informal styles (a minimal sketch of the objective follows the two commands below):
$ python main.py --batch_size 20 \
        --dataFile data/yafc_formal.h5 \
        --jsonFile data/yafc_formal.json \
        --shuffle True --train_mode pretrain_lm --embedding_size 300 \
        --hidden_size 350 --num_rnn_layers 1 --use_lstm True \
        --epochs 100 --lr 1e-4 --weight_decay 1e-4 \
        --dropout_p 0.5 --max_norm 10 \
        --log_dir logs/pretrain_lm/yafc_formal \
        --num_sample_sents 5 \
        --save_path models/pretrain_lm/yafc_formal --model_name model

Similarly, for the informal style:

$ python main.py --batch_size 20 \
        --dataFile data/yafc_informal.h5 \
        --jsonFile data/yafc_informal.json \
        --shuffle True --train_mode pretrain_lm --embedding_size 300 \
        --hidden_size 350 --num_rnn_layers 1 --use_lstm True \
        --epochs 100 --lr 1e-4 --weight_decay 1e-4 \
        --dropout_p 0.5 --max_norm 10 \
        --log_dir logs/pretrain_lm/yafc_informal \
        --num_sample_sents 5 \
        --save_path models/pretrain_lm/yafc_informal --model_name model
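
Under the hood, pretrain_lm optimizes standard next-token prediction with an LSTM. Below is a minimal sketch of that objective with hypothetical module names (the repo's actual classes live in main.py); the hyperparameters mirror the flags above:

# Minimal sketch of the pretrain_lm objective (hypothetical names; the
# repo's actual model classes live in main.py).
import torch
import torch.nn as nn

class WordLM(nn.Module):
    def __init__(self, vocab_size, embedding_size=300, hidden_size=350,
                 num_rnn_layers=1, dropout_p=0.5):
        super().__init__()
        self.embed = nn.Embedding(vocab_size, embedding_size)
        self.lstm = nn.LSTM(embedding_size, hidden_size, num_rnn_layers,
                            batch_first=True)
        self.drop = nn.Dropout(dropout_p)
        self.proj = nn.Linear(hidden_size, vocab_size)

    def forward(self, tokens):                    # tokens: (batch, seq)
        out, _ = self.lstm(self.embed(tokens))
        return self.proj(self.drop(out))          # (batch, seq, vocab)

model = WordLM(vocab_size=10000)
tokens = torch.randint(0, 10000, (20, 35))        # --batch_size 20
logits = model(tokens[:, :-1])                    # predict the next token
loss = nn.functional.cross_entropy(
    logits.reshape(-1, 10000), tokens[:, 1:].reshape(-1))
loss.backward()
nn.utils.clip_grad_norm_(model.parameters(), max_norm=10)  # --max_norm 10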
  2. Pretrain the Seq2Seq model with MLE training so that it converts s1 to s2 and s2 back to s1 (pretrained LM weights initialize the generators; see the warm-start sketch after the command):
$ python main.py --batch_size 128 \
        --dataFile data/yafc_formal.h5 \
        --jsonFile data/yafc_formal.json \
        --pdataFile data/yafc_informal.h5 \
        --pjsonFile data/yafc_informal.json \
        --shuffle True --train_mode train_seq2seq --embedding_size 300 \
        --hidden_size 350 --num_rnn_layers 1 --use_lstm True \
        --epochs 100 --lr 1e-4 --weight_decay 1e-4 \
        --dropout_p 0.2 --max_norm 10 \
        --log_dir logs/train_seq2seq \
        --num_sample_sents 5 \
        --save_path models/train_seq2seq --model_name model \
        --pretrained_lm1_model_path models/pretrain_lm/yafc_formal/model_best.net \
        --pretrained_lm2_model_path models/pretrain_lm/yafc_informal/model_best.net \
        --skip_weight_decay 0.995 \
        --log_iter 10 --sent_sample_iter 100
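
A common way to implement the warm start is to copy every LM parameter whose name and shape match a generator parameter, leaving the rest (attention, output projections) freshly initialized. A hedged sketch, with an assumed helper name and checkpoint layout rather than the repo's actual code:

import torch
import torch.nn as nn

def warm_start_from_lm(generator: nn.Module, lm_ckpt_path: str) -> None:
    # Copy every LM parameter whose name and shape match into the
    # generator; all other generator parameters keep their fresh init.
    lm_state = torch.load(lm_ckpt_path, map_location='cpu')
    if not isinstance(lm_state, dict):    # a whole pickled module was saved
        lm_state = lm_state.state_dict()
    gen_state = generator.state_dict()
    shared = {k: v for k, v in lm_state.items()
              if k in gen_state and v.shape == gen_state[k].shape}
    gen_state.update(shared)
    generator.load_state_dict(gen_state)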
  3. Finally, train the Seq2Seq model in finetune_cyclegan mode (the WGAN-specific flags are sketched after the command):
$ python main.py --batch_size 128 \
        --dataFile data/yafc_formal.h5 \
        --jsonFile data/yafc_formal.json \
        --pdataFile data/yafc_informal.h5 \
        --pjsonFile data/yafc_informal.json \
        --shuffle True --train_mode finetune_cyclegan --embedding_size 300 \
        --hidden_size 350 --num_rnn_layers 1 --use_lstm True --use_attention True \
        --epochs 100 --lr 5e-6 --weight_decay 1e-4 \
        --dropout_p 0.2 --max_norm 1 \
        --log_dir logs/finetune_cyclegan/ \
        --num_sample_sents 5 --save_path models/finetune_cyclegan/ --model_name model \
        --pretrained_lm1_model_path models/pretrain_lm/yafc_formal/model_best.net \
        --pretrained_lm2_model_path models/pretrain_lm/yafc_informal/model_best.net \
        --pretrained_seq2seq_model_path models/train_seq2seq/model_best.net \
        --num_searches 1 --g_update_step_diff 25 --d_update_step_diff 1 \
        --lr_ratio_D_by_G 20.0 --discount_factor 0.99 \
        --lamda_rl 1e-0 --lamda_rec_ii 1e-2 --lamda_rec_ij 1e-3 \
        --lamda_cos_ij 1e-1 \
        --freeze_embeddings True --clamp_lower -0.01 --clamp_upper 0.01 \
        --d_pretrain_num_epochs 3 --disc_recalibrate 400 \
        --g_update_step_diff_recalib 200 \
        --log_iter 10 --sent_sample_iter 100
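
Two groups of flags above carry the WGAN machinery: --clamp_lower/--clamp_upper clip the discriminator's weights after each update to keep it roughly 1-Lipschitz, and the lamda_* coefficients mix the RL (adversarial) reward with the reconstruction terms (rec_ii for the i->i autoencoding path, rec_ij for the i->j->i cycle). A minimal sketch of the critic side, assuming a standard weight-clipping WGAN:

import torch

def clip_critic_weights(critic, lower=-0.01, upper=0.01):
    # --clamp_lower / --clamp_upper: project every weight back into
    # [lower, upper] after each discriminator step (original WGAN recipe).
    for p in critic.parameters():
        p.data.clamp_(lower, upper)

def critic_loss(score_real, score_fake):
    # The WGAN critic maximizes the score gap between real and generated
    # sentences, so we minimize the negated gap (no log/sigmoid terms).
    return -(score_real.mean() - score_fake.mean())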

Evaluation

We evaluate our models with the BLEU score for n-gram orders ranging from 1 to 4 (a reproduction sketch follows the command):

$ python eval.py --model_path models/finetune_cyclegan/model_best.net \
        --dataFile data/yafc_formal.h5 \
        --jsonFile data/yafc_formal.json \
        --pdataFile data/yafc_informal.h5 \
        --pjsonFile data/yafc_informal.json \
        --split val_and_test
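
eval.py's exact scoring code is not reproduced here, but BLEU-n for n = 1..4 can be computed with NLTK along these lines (the smoothing choice is an assumption):

# Compute BLEU-1 through BLEU-4 over a toy corpus with NLTK.
from nltk.translate.bleu_score import corpus_bleu, SmoothingFunction

references = [[['i', 'am', 'not', 'certain']]]  # per-hypothesis reference lists
hypotheses = [['i', 'am', 'not', 'sure']]       # model outputs, tokenized

smooth = SmoothingFunction().method1
for n in range(1, 5):
    weights = tuple([1.0 / n] * n)              # uniform weights up to n-grams
    score = corpus_bleu(references, hypotheses, weights=weights,
                        smoothing_function=smooth)
    print('BLEU-%d: %.4f' % (n, score))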

text-style-transfer-gan's People

Contributors

p-kar


text-style-transfer-gan's Issues

assert (glove_reader.vect_size == opts.embedding_size), "Embedding size mismatch"

$ python main.py --batch_size 20 \
        --dataFile misc/shakespeare_original.h5 \
        --jsonFile misc/shakespeare_original.json \
        --shuffle True --train_mode pretrain_lm --embedding_size 300 \
        --hidden_size 350 --num_rnn_layers 1 --use_lstm True \
        --epochs 100 --lr 1e-4 --weight_decay 1e-4 \
        --dropout_p 0.5 --max_norm 10 \
        --log_dir logs/pretrain_lm/shakespear_formal \
        --num_sample_sents 5 \
        --save_path models/pretrain_lm/shakespear_formal --model_name model
======================= Parameters =======================

batch_size : 20
dataFile : misc/shakespeare_original.h5
jsonFile : misc/shakespeare_original.json
pdataFile : misc/shakespeare_modern.h5
pjsonFile : misc/shakespeare_modern.json
shuffle : True
train_mode : pretrain_lm
embedding_size : 300
hidden_size : 350
num_rnn_layers : 1
use_lstm : True
use_attention : True
epochs : 100
lr : 0.0001
weight_decay : 0.0001
dropout_p : 0.5
max_norm : 10.0
enable_scheduled_sampling : False
scheduled_sampling_decay_type : linear
scheduled_sampling_decay_factor: 0.005
scheduled_sampling_min_eps : 0.5
log_dir : logs/pretrain_lm/shakespear_formal
num_sample_sents : 5
log_iter : 10
sent_sample_iter : 50
save_path : models/pretrain_lm/shakespear_formal
model_name : model
pretrained_lm1_model_path : ./models/shakespeare_formal_lm_best.net
pretrained_lm2_model_path : ./models/shakespeare_informal_lm_best.net
pretrained_seq2seq_model_path : ./models/shakespeare_formal_informal.net
pretrained_glove_vector_path : ./data/glove/glove.twitter.27B.200d.txt
use_glove_embeddings : True
num_searches : 1
g_update_step_diff : 25
d_update_step_diff : 1
lr_ratio_D_by_G : 1.0
discount_factor : 0.99
lamda_rl : 1.0
lamda_rec_ii : 0.01
lamda_rec_ij : 0.001
lamda_cos_ij : 0.1
skip_weight_decay : 0.995
freeze_embeddings : True
clamp_lower : -0.01
clamp_upper : 0.01
d_pretrain_num_epochs : 3
disc_recalibrate : 100
g_update_step_diff_recalib : 200

############## Language Model Pretraining ##############
########################################################

Loading GloVe vectors from ./data/glove/glove.twitter.27B.200d.txt... done
Traceback (most recent call last):
File "main.py", line 1470, in
main(opts)
File "main.py", line 1385, in main
pretrain_language_model(opts)
File "main.py", line 132, in pretrain_language_model
load_embedding_from_glove(opts, loader, embedding)
File "main.py", line 110, in load_embedding_from_glove
assert (glove_reader.vect_size == opts.embedding_size), "Embedding size mismatch"
AssertionError: Embedding size mismatch
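
The likely cause: the run sets --embedding_size 300 while pretrained_glove_vector_path points at glove.twitter.27B.200d.txt, whose vectors are 200-dimensional. Either pass --embedding_size 200 or point the flag at a 300-d GloVe file (e.g. glove.6B.300d.txt). A quick way to check a GloVe file's dimensionality:

# Print the vector dimensionality of a GloVe text file.
with open('./data/glove/glove.twitter.27B.200d.txt', encoding='utf-8') as f:
    first_line = f.readline().rstrip().split(' ')
print(len(first_line) - 1)   # 200 for this file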

Line 47 of preprocess_yafc.py seems corrupted

transtab = str.maketrans('éè—üəāöóŕ–ƒùšïĕ§†ûàäñıáí', 'ee ueaoor fusiestuaaniai', string.punctuation + '¢®£©´™“’”¿¨…\u200b綠嘉豆加因人為義薏仁»œ►•·ºس恭¹♡¡λ發‘˝◄½م喜♥☺æها財')
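
For context: str.maketrans(x, y, z) requires len(x) == len(y), so if any of the accented characters above gets mangled in copy/paste or by an editor's encoding, the call raises ValueError. A more robust equivalent builds the table from an explicit dict, making each mapping visible and a length mismatch impossible (only a subset of the characters is shown):

# Dict-based translation table: maps accented characters to ASCII and
# deletes punctuation plus stray symbols (subset shown for brevity).
import string

pairs = {'é': 'e', 'è': 'e', '—': ' ', 'ü': 'u', 'ö': 'o', 'ñ': 'n', 'í': 'i'}
delete = string.punctuation + '¢®£©´™“’”¿¨…'
transtab = str.maketrans(pairs)
for ch in delete:
    transtab[ord(ch)] = None     # None means "delete this character"

print('café — sí!'.translate(transtab))   # -> 'cafe   si'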

Not an issue - a research question regarding LeakGAN

Are you familiar with LeakGAN?
https://arxiv.org/abs/1709.08624

Recently, by combining with policy gradient, Generative Adversarial Nets (GAN) that use a discriminative model to guide the training of the generative model as a reinforcement learning policy has shown promising results in text generation. However, the scalar guiding signal is only available after the entire text has been generated and lacks intermediate information about text structure during the generative process. As such, it limits its success when the length of the generated text samples is long (more than 20 words). We allow the discriminative net to leak its own high-level extracted features to the generative net to further help the guidance.

I looked at the diagram and wondered whether this approach was considered to improve results/performance.


yafc_formal.h5 - how to create this file?

python main.py
======================= Parameters =======================

batch_size : 128
dataFile : data/yafc_formal.h5
jsonFile : data/yafc_formal.json
pdataFile : data/yafc_informal.h5
pjsonFile : data/yafc_informal.json
shuffle : True
train_mode : pretrain_lm
embedding_size : 300
hidden_size : 350
num_rnn_layers : 1
use_lstm : True
use_attention : True
epochs : 1000
lr : 1e-05
weight_decay : 0.0001
dropout_p : 0.2
max_norm : 1
enable_scheduled_sampling : False
scheduled_sampling_decay_type : linear
scheduled_sampling_decay_factor: 0.005
scheduled_sampling_min_eps : 0.5
log_dir : ./logs
num_sample_sents : 5
log_iter : 10
sent_sample_iter : 50
save_path : ./models
model_name : model
pretrained_lm1_model_path : ./models/yafc_formal_lm_best.net
pretrained_lm2_model_path : ./models/yafc_informal_lm_best.net
pretrained_seq2seq_model_path : ./models/yafc_formal_informal.net
pretrained_glove_vector_path : ./data/glove/glove.twitter.27B.200d.txt
use_glove_embeddings : True
num_searches : 1
g_update_step_diff : 25
d_update_step_diff : 1
lr_ratio_D_by_G : 1.0
discount_factor : 0.99
lamda_rl : 1.0
lamda_rec_ii : 0.01
lamda_rec_ij : 0.001
lamda_cos_ij : 0.1
skip_weight_decay : 0.995
freeze_embeddings : True
clamp_lower : -0.01
clamp_upper : 0.01
d_pretrain_num_epochs : 3
disc_recalibrate : 100
g_update_step_diff_recalib : 200

############## Language Model Pretraining ##############
########################################################

Traceback (most recent call last):
File "main.py", line 1466, in
main(opts)
File "main.py", line 1385, in main
pretrain_language_model(opts)
File "main.py", line 122, in pretrain_language_model
loader = DataLoader(opts)
File "misc/DataLoader.py", line 11, in init
self.h5_file = h5py.File(opts.dataFile, 'r')
File "/Users/admin/miniconda3/envs/pytorch/lib/python3.6/site-packages/h5py/_hl/files.py", line 312, in init
fid = make_fid(name, mode, userblock_size, fapl, swmr=swmr)
File "/Users/admin/miniconda3/envs/pytorch/lib/python3.6/site-packages/h5py/_hl/files.py", line 142, in make_fid
fid = h5f.open(name, flags, fapl=fapl)
File "h5py/_objects.pyx", line 54, in h5py._objects.with_phil.wrapper
File "h5py/_objects.pyx", line 55, in h5py._objects.with_phil.wrapper
File "h5py/h5f.pyx", line 78, in h5py.h5f.open
OSError: Unable to open file (unable to open file: name = 'data/yafc_formal.h5', errno = 2, error message = 'No such file or directory', flags = 0, o_flags = 0)
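
The data/*.h5 and data/*.json pairs are produced by the repo's preprocessing (see preprocess_yafc.py, mentioned in another issue); they are not checked in, presumably because the YAFC corpus itself cannot be redistributed. For orientation only, here is a hypothetical sketch of the kind of file the DataLoader opens; the dataset and key names are illustrative, so check misc/DataLoader.py for the real schema:

# Hypothetical sketch of writing an .h5/.json pair with h5py; the actual
# names expected by misc/DataLoader.py must be taken from the repo.
import json
import h5py
import numpy as np

sentences = [['how', 'are', 'you'], ['i', 'am', 'fine']]
vocab_words = sorted({w for s in sentences for w in s})
vocab = {w: i + 1 for i, w in enumerate(vocab_words)}  # 0 reserved for padding

max_len = max(len(s) for s in sentences)
tokens = np.zeros((len(sentences), max_len), dtype=np.int64)
for i, s in enumerate(sentences):
    tokens[i, :len(s)] = [vocab[w] for w in s]

with h5py.File('data/yafc_formal.h5', 'w') as f:
    f.create_dataset('tokens', data=tokens)       # padded token-index matrix

with open('data/yafc_formal.json', 'w') as f:
    json.dump({'word_to_idx': vocab}, f)          # vocabulary mapping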
