weilinie / relgan
Implementation of RelGAN: Relational Generative Adversarial Networks for Text Generation
License: MIT License
Among the RSGAN, standard GAN, and hinge losses, which one did you find works best for the LSTM-based model?
Hi, thank you for sharing the code!
But I have a question about sample generation: in my understanding, the Gumbel-Softmax relaxation is used to pass gradients directly back to the generator. Why, then, is it still applied when generating samples?
Thanks in advance!
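For context, here is a minimal NumPy sketch of the relaxation being discussed (not the repo's exact sampling path): the soft Gumbel-Softmax vector keeps the computation differentiable during training, while a discrete token can still be recovered with argmax at evaluation time.

import numpy as np

rng = np.random.default_rng(0)

def gumbel_softmax_sample(logits, tau=1.0):
    # Relaxed one-hot sample: differentiable in the logits, so gradients can flow to the generator.
    u = rng.uniform(1e-10, 1.0, size=logits.shape)
    gumbel_noise = -np.log(-np.log(u))
    y = np.exp((logits + gumbel_noise) / tau)
    return y / y.sum(axis=-1, keepdims=True)

logits = np.array([1.0, 2.0, 0.5])
soft = gumbel_softmax_sample(logits, tau=2.0)   # soft vector used during training
token_id = int(np.argmax(soft))                 # discrete token, e.g. for evaluation-time decoding
print(soft, token_id)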
I noticed that NLL_gen in this repo is calculated using the training data (e.g., the image_coco sentences) as the real distribution. However, the training data was already used to pretrain the generator. Is that correct?
I think NLL_gen = -E_{x ~ P_r^{test}} [log P_{\theta}(x)] should be used instead of -E_{x ~ P_r^{train}} [log P_{\theta}(x)].
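For illustration, a tiny sketch of the distinction being argued here; the per-sentence log-probability values are made up.

import numpy as np

def nll_gen(log_probs):
    # NLL_gen: average negative log-likelihood the generator assigns to a set of real sentences.
    return -np.mean(log_probs)

# Hypothetical per-sentence log P_theta(x) values under the generator.
logp_train = np.array([-35.2, -41.0, -38.7])  # sentences from the training split (seen during pretraining)
logp_test = np.array([-44.1, -47.3, -45.9])   # sentences from a held-out test split

print(nll_gen(logp_train))  # optimistic estimate, since these sentences were used for pretraining
print(nll_gen(logp_test))   # held-out estimate argued for above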
Hi, thanks for providing the source code!
I reran the code with the command line python oracle_relgan.py 0 <gpu_id>, but got a final nll_oracle of 7.7672 for length 20 (and 7.4066 a few evaluation steps earlier), while the value reported in the paper is ~6.680. All hyper-parameters were kept at their defaults. I also noticed that setting job_id to 1 selects a temperature of 2. The TensorFlow version is 1.14, and the models were trained on a single NVIDIA Titan RTX.
I would like to know how to reproduce the results reported in the original paper (6.680 for length 20; 6.765 for length 40).
Thank you for the implementation. I wonder whether you always observe the same behavior of the NLL_gen loss curve with different initializations? I am not observing the same behavior in another code base, so I am wondering what I am doing wrong.
Thanks!
As titled, pre-training seems to take forever on the EMNLP2017 WMT News dataset.
What is a reasonable time for this stage?
Are there any configurations or hyper-parameters I need to be aware of?
Hope you are doing well.
We are training your code on the StoryCloze dataset, which contains five-sentence stories, and we intend to generate stories like the following example:
Today is Sarah's 5th birthday! Her parents threw her a party at her favorite Mexican restaurant. A balloon artist made balloon hats for Sara and her friends. Sara got lots of presents. The birthday party was a great success.
We are currently facing a few problems:
During adversarial training, gen_loss starts at about 0.25 for the first 100 epochs, but after that it suddenly jumps to 0.999 and keeps fluctuating there, while dis_loss keeps decreasing.
Samples generated during adversarial training look good, but when we generate new samples from the saved model, the sample quality degrades and nll_gen climbs to around 10, even though it was around 0.6 during adversarial training. We tried different parameter values, including the defaults, but when generating samples nll_gen still shoots up to 10-12. We ran the code with the rmc_vdcnn architecture.
We tried different parameters:
batch_size = 32 or 64
gen_emb_dim = 32
dis_emb_dim = 64
mem_slots = 1
head_size = 512
num_heads = 2
gf_dim = 64 or 32
df_dim = 64 or 32
gsteps = 1 or 3
dsteps = 5
npre_epoch = 200 or 150
navd_steps = 100 or 150
dlr = 1e-6
gpre_lr = 1e-2 -> 1e-4
glr = 1e-4 -> 1e-4
beta max (temperature) = 1000 -> 1000
Do you have any suggestions for the problems above?
I was trying to get a better understanding of the behavior of g_loss and d_loss during adversarial training.
It appears that when using the oracle-generated data, gen_loss increases while dis_loss goes to zero. With the RSGAN loss, I thought both losses should converge to about 0.7 once the networks are balanced. Since you don't report these losses in the paper, I'm curious what your thoughts are on balanced training.
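For reference, a NumPy sketch of the RSGAN objective as it is usually written (not taken from this repo's code): when the discriminator cannot separate real from fake, the two logit batches match and both losses sit at log(2) ≈ 0.693, which is the "balanced" value referred to above.

import numpy as np

def rsgan_losses(d_real_logits, d_fake_logits, eps=1e-10):
    # Relativistic standard GAN losses over batches of discriminator logits.
    sigmoid = lambda x: 1.0 / (1.0 + np.exp(-x))
    d_loss = -np.mean(np.log(sigmoid(d_real_logits - d_fake_logits) + eps))
    g_loss = -np.mean(np.log(sigmoid(d_fake_logits - d_real_logits) + eps))
    return d_loss, g_loss

# Identical logits for real and fake: both losses equal log(2) ~= 0.693.
print(rsgan_losses(np.zeros(8), np.zeros(8)))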
From the code, I can see that teacher-forcing was used during the pretraining. However, during the adversarial training, it is not used. Was this decision based on any experiments? If yes, can you explain them?
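For clarity, here is a framework-free sketch of the two regimes being contrasted; step is a hypothetical one-step generator callable and is not part of the repo's code.

def teacher_forced_rollout(step, real_tokens, state):
    # Pretraining: the input at every step is the ground-truth previous token.
    outputs = []
    prev = real_tokens[0]
    for target in real_tokens[1:]:
        pred, state = step(prev, state)
        outputs.append(pred)
        prev = target  # feed the real token, regardless of the prediction
    return outputs

def free_running_rollout(step, start_token, state, length):
    # Adversarial training: the input at every step is the model's own previous sample.
    outputs, prev = [], start_token
    for _ in range(length):
        pred, state = step(prev, state)
        outputs.append(pred)
        prev = pred  # feed the model's own output back in
    return outputs

# Trivial stand-in generator that always predicts the previous token plus one.
dummy_step = lambda tok, state: (tok + 1, state)
print(teacher_forced_rollout(dummy_step, [3, 7, 2, 9], state=None))
print(free_running_rollout(dummy_step, start_token=3, state=None, length=3))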
Hello! Thank you so much for your code. I wanted to try this model on the Quora Question Pair Dataset to see if it could generate reliable paraphrases/sentences. Do you think this is possible? Thanks!
RelGAN\oracle\oracle_gan\oracle_train.py", line 265, in get_losses
log_pg = tf.reduce_mean(tf.log(gen_o + EPS)) # [1], measures the log p_g(x)
TypeError: unsupported operand type(s) for +: 'TensorArray' and 'float'
It also happens in real_train.py.
I used the same environment (TensorFlow 1.4, Python 3.5).
Why? :(
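A minimal sketch of the kind of change that usually resolves this TypeError, assuming gen_o is a tf.TensorArray collected in the sampling loop (EPS and the shapes below are illustrative): a TensorArray defines no '+' operator, so it has to be stacked into a dense tensor before any arithmetic.

import tensorflow as tf  # TF 1.x

EPS = 1e-10
# Hypothetical stand-in for gen_o: a TensorArray of per-step generator probabilities.
gen_o = tf.TensorArray(dtype=tf.float32, size=3)
for t in range(3):
    gen_o = gen_o.write(t, tf.constant([0.1, 0.2, 0.3]))

# gen_o + EPS fails with the TypeError above; stack first, then add.
log_pg = tf.reduce_mean(tf.log(gen_o.stack() + EPS))

with tf.Session() as sess:
    print(sess.run(log_pg))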
I am facing these errors. When I switched to TensorFlow 1.4, pre-training takes a seemingly infinite amount of time.
Can you check this and share the TF version you used?
Hi,
First of all, thanks for sharing your code! I'm impressed by your solid work. However, I found some issues when running your code with different hyper-parameters.
Main issues:
1. Under gpre_lr=0.005 the model collapses to a single repeated sentence, while it behaves normally under gpre_lr=0.01.
2. The same collapse happens under temperature=1, while the model generates diverse sentences under temperature=100.
Here's my system environment.
>> Operating system: Ubuntu 16.04.1
>> Program environment: Virtualenv
>> Dependencies version:
--Tensorflow 1.5.0
--Numpy 1.14.5
--Scipy 1.1.0
--NLTK 3.4
--tqdm 4.26.0
>> NVIDIA Graphics: TITAN Xp
Here are the problems I encountered when running your code.
For the Synthetic data experiment, I simply changed gpre_lr from 0.01 to 0.005. After 1620 adversarial training steps, the model only generates one repeated sentence, while it behaves normally under gpre_lr=0.01.
job_id=0
gpu_id=0
architecture='rmc_vanilla'
gantype='RSGAN'
opt_type='adam'
temperature = '2'
d_lr = '1e-4'
gadv_lr = '1e-4'
mem_slots = '1'
head_size = '256'
num_head = '2'
bs = '64'
seed = '124'
gpre_lr = '0.005' # <<< only change this parameter
hidden_dim = '32'
seq_len = '20'
dataset = 'oracle'
gsteps = '1'
dsteps = '5'
gen_emb_dim = '32'
dis_emb_dim = '64'
num_rep = '64'
sn = False
decay = False
adapt = 'exp'
npre_epochs = '200'
nadv_steps = '3000'
ntest = '20'
Log from experiment-log-relgan.csv:
pre_gen_epoch:0, g_pre_loss: 7.8529, time: 21, nll_oracle: 10.0373, nll_gen: 7.7054
pre_gen_epoch:10, g_pre_loss: 6.3980, time: 96, nll_oracle: 9.1713, nll_gen: 7.0273
pre_gen_epoch:20, g_pre_loss: 4.9447, time: 95, nll_oracle: 9.0030, nll_gen: 6.8668
pre_gen_epoch:30, g_pre_loss: 4.0855, time: 97, nll_oracle: 8.9014, nll_gen: 6.6844
pre_gen_epoch:40, g_pre_loss: 3.5856, time: 98, nll_oracle: 8.6974, nll_gen: 6.5024
pre_gen_epoch:50, g_pre_loss: 3.4276, time: 97, nll_oracle: 8.7271, nll_gen: 6.4082
pre_gen_epoch:60, g_pre_loss: 3.1279, time: 97, nll_oracle: 8.5847, nll_gen: 6.0566
pre_gen_epoch:70, g_pre_loss: 2.7486, time: 97, nll_oracle: 8.5072, nll_gen: 6.1834
pre_gen_epoch:80, g_pre_loss: 2.6509, time: 96, nll_oracle: 8.5039, nll_gen: 6.5375
pre_gen_epoch:90, g_pre_loss: 2.3952, time: 98, nll_oracle: 8.4369, nll_gen: 6.5055
pre_gen_epoch:100, g_pre_loss: 2.2010, time: 96, nll_oracle: 8.3912, nll_gen: 6.2355
pre_gen_epoch:110, g_pre_loss: 2.2762, time: 97, nll_oracle: 8.3952, nll_gen: 5.8913
pre_gen_epoch:120, g_pre_loss: 2.1142, time: 96, nll_oracle: 8.3305, nll_gen: 5.6797
pre_gen_epoch:130, g_pre_loss: 1.9376, time: 98, nll_oracle: 8.2759, nll_gen: 5.6209
pre_gen_epoch:140, g_pre_loss: 1.8160, time: 100, nll_oracle: 8.2619, nll_gen: 5.7824
pre_gen_epoch:150, g_pre_loss: 2.0162, time: 99, nll_oracle: 8.3108, nll_gen: 5.8881
pre_gen_epoch:160, g_pre_loss: 1.7529, time: 96, nll_oracle: 8.2722, nll_gen: 6.0981
pre_gen_epoch:170, g_pre_loss: 1.7227, time: 94, nll_oracle: 8.2649, nll_gen: 6.0026
pre_gen_epoch:180, g_pre_loss: 1.8253, time: 96, nll_oracle: 8.3251, nll_gen: 6.3667
pre_gen_epoch:190, g_pre_loss: 1.6638, time: 95, nll_oracle: 8.2878, nll_gen: 6.2536
adv_step: 0, nll_oracle: 8.2616, nll_gen: 5.7822
adv_step: 20, nll_oracle: 8.2877, nll_gen: 5.7955
adv_step: 40, nll_oracle: 8.2676, nll_gen: 5.8229
adv_step: 60, nll_oracle: 8.2573, nll_gen: 5.8185
adv_step: 80, nll_oracle: 8.2349, nll_gen: 5.8068
adv_step: 100, nll_oracle: 8.2085, nll_gen: 5.8215
.
.
.
adv_step: 1500, nll_oracle: 7.6447, nll_gen: 6.6017
adv_step: 1520, nll_oracle: 7.6238, nll_gen: 6.6425
adv_step: 1540, nll_oracle: 7.6210, nll_gen: 6.6755
adv_step: 1560, nll_oracle: 7.6144, nll_gen: 6.7065
adv_step: 1580, nll_oracle: 7.6039, nll_gen: 6.7313
adv_step: 1600, nll_oracle: 7.5998, nll_gen: 6.7501
adv_step: 1620, nll_oracle: 7.5982, nll_gen: 6.7718
adv_step: 1640, nll_oracle: 7.5978, nll_gen: 6.7881
adv_step: 1660, nll_oracle: 7.5963, nll_gen: 6.8009
adv_step: 1680, nll_oracle: 7.5951, nll_gen: 6.8147
adv_step: 1700, nll_oracle: 7.5922, nll_gen: 6.8249
.
.
.
adv_step: 2860, nll_oracle: 8.1041, nll_gen: 7.1184
adv_step: 2880, nll_oracle: 8.1520, nll_gen: 7.1035
adv_step: 2900, nll_oracle: 8.1901, nll_gen: 7.1025
adv_step: 2920, nll_oracle: 8.2444, nll_gen: 7.0996
adv_step: 2940, nll_oracle: 8.2802, nll_gen: 7.0863
adv_step: 2960, nll_oracle: 8.3114, nll_gen: 7.0657
adv_step: 2980, nll_oracle: 8.2215, nll_gen: 7.0739
Samples at the 1620th adversarial step, from adv_samples_01620.txt (only repeated sentences were generated):
3045 2063 3617 3352 1413 3273 3011 4855 685 2241 2988 1764 4808 391 1755 2243 3526 4197 1604 4541
3045 2063 3617 3352 1413 3273 3011 4855 685 2241 2988 1764 4808 391 1755 2243 3526 4197 1604 4541
(the same token sequence repeats for every generated sample)
For the Image COCO caption data, I simply changed the temperature from 100 to 1. The problem of generating repeated sentences arises again, while the model generates diverse sentences under temperature=100.
job_id=0
gpu_id=0
architecture='rmc_vanilla'
gantype='RSGAN'
opt_type='adam'
temperature = '1' # <<< only change this parameter
d_lr = '1e-4'
gadv_lr = '1e-4'
mem_slots = '1'
head_size = '256'
num_head = '2'
bs = '64'
seed = '124'
gpre_lr = '0.01'
hidden_dim = '32'
seq_len = '20'
dataset = 'oracle'
gsteps = '1'
dsteps = '5'
gen_emb_dim = '32'
dis_emb_dim = '64'
num_rep = '64'
sn = False
decay = False
adapt = 'exp'
npre_epochs = '150'
nadv_steps = '3000'
ntest = '20'
Log from experiment-log-relgan.csv (to save time, I didn't calculate the BLEU-3 score):
pre_gen_epoch:0, g_pre_loss: 2.4170, nll_gen: 1.2337
pre_gen_epoch:10, g_pre_loss: 0.7531, nll_gen: 0.7711
pre_gen_epoch:20, g_pre_loss: 0.6419, nll_gen: 0.6634
pre_gen_epoch:30, g_pre_loss: 0.5984, nll_gen: 0.6540
pre_gen_epoch:40, g_pre_loss: 0.5766, nll_gen: 0.6359
pre_gen_epoch:50, g_pre_loss: 0.5352, nll_gen: 0.6119
pre_gen_epoch:60, g_pre_loss: 0.5106, nll_gen: 0.6105
pre_gen_epoch:70, g_pre_loss: 0.4824, nll_gen: 0.6155
pre_gen_epoch:80, g_pre_loss: 0.4585, nll_gen: 0.6444
pre_gen_epoch:90, g_pre_loss: 0.4533, nll_gen: 0.6171
pre_gen_epoch:100, g_pre_loss: 0.4309, nll_gen: 0.5942
pre_gen_epoch:110, g_pre_loss: 0.4150, nll_gen: 0.6225
pre_gen_epoch:120, g_pre_loss: 0.4064, nll_gen: 0.6629
pre_gen_epoch:130, g_pre_loss: 0.4034, nll_gen: 0.6835
pre_gen_epoch:140, g_pre_loss: 0.3912, nll_gen: 0.6581
Start adversarial training...
adv_step: 0, nll_gen: 0.6736
adv_step: 20, nll_gen: 0.6762
adv_step: 40, nll_gen: 0.6761
adv_step: 60, nll_gen: 0.6766
adv_step: 80, nll_gen: 0.6811
adv_step: 100, nll_gen: 0.6894
adv_step: 120, nll_gen: 0.6988
adv_step: 140, nll_gen: 0.7120
adv_step: 160, nll_gen: 0.7251
adv_step: 180, nll_gen: 0.7389
adv_step: 200, nll_gen: 0.7512
adv_step: 220, nll_gen: 0.7607
adv_step: 240, nll_gen: 0.7719
adv_step: 260, nll_gen: 0.7800
.
.
.
adv_step: 1720, nll_gen: 0.7246
adv_step: 1740, nll_gen: 0.7263
adv_step: 1760, nll_gen: 0.7266
adv_step: 1780, nll_gen: 0.7278
adv_step: 1800, nll_gen: 0.7268
adv_step: 1820, nll_gen: 0.7256
adv_step: 1840, nll_gen: 0.7253
adv_step: 1860, nll_gen: 0.7246
adv_step: 1880, nll_gen: 0.7233
adv_step: 1900, nll_gen: 0.7232
adv_step: 1920, nll_gen: 0.7234
adv_step: 1940, nll_gen: 0.7228
adv_step: 1960, nll_gen: 0.7239
adv_step: 1980, nll_gen: 0.7260
Samples at the 1000th adversarial step, from adv_samples_01000.txt (only repeated sentences were generated):
a group of people are riding motorcycles on a city street .
a group of people are riding motorcycles on a city street .
(the same sentence repeats for every generated sample)
Did you encounter overfitting during generator pre-training? If so, how did you mitigate it? Also, can you shed some light on the choice of loss function used during pre-training, and how it differs from TensorFlow's implementation of CategoricalCrossentropy?
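As a point of reference, here is a small NumPy sketch (with made-up logits and targets) of the teacher-forced pretraining objective: per time step it is the same categorical cross-entropy that CategoricalCrossentropy computes against one-hot targets, and the sequence-level pretraining loss is simply its average over the time steps.

import numpy as np

def softmax(x):
    e = np.exp(x - x.max(axis=-1, keepdims=True))
    return e / e.sum(axis=-1, keepdims=True)

# Hypothetical logits for a 4-step sequence over a 5-token vocabulary, plus the target token ids.
logits = np.random.default_rng(0).normal(size=(4, 5))
targets = np.array([1, 3, 0, 2])

probs = softmax(logits)
per_token_ce = -np.log(probs[np.arange(4), targets])  # categorical cross-entropy at each time step
g_pre_loss = per_token_ce.mean()                      # teacher-forced sequence NLL (the pretraining loss)
print(per_token_ce, g_pre_loss)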
Hi, thank you for sharing the code. I have a question:
I saw in your paper that the EMNLP dataset contains 10,000 sentences for both the training and test data, but the file emnlp_news.txt in the repo contains 270,000 sentences.
How did you choose the 10,000 sentences from it? I didn't see this part in the code.
Thank you! :)
Hi,
Do you plan to release the code under open source licenses?
The original Texygen uses an MIT License, which is helpful for making derivative works.
Based on my understanding, GPT and GPT-2 use a language-model loss to train and generate text, and do not involve a GAN.
I am quite confused about this. Thank you very much.
Hello, I'm running the RelGAN code and evaluating diversity using Self-BLEU instead of NLL_gen, which suffers from the issues already discussed.
However, I'm seeing unexpected behavior of this metric during training: can anyone share their own results, just so I can be sure I'm not missing anything?
Thanks in advance! :)
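In case it helps with comparing numbers, here is a small NLTK-based sketch of how Self-BLEU is commonly computed (each generated sentence is scored against all the other generated sentences as references); the sample sentences below are made up.

from nltk.translate.bleu_score import sentence_bleu, SmoothingFunction

def self_bleu(generated, n=3):
    # Self-BLEU: score each generated sentence against all the others as references.
    smooth = SmoothingFunction().method1
    weights = tuple(1.0 / n for _ in range(n))
    scores = []
    for i, hyp in enumerate(generated):
        refs = generated[:i] + generated[i + 1:]
        scores.append(sentence_bleu(refs, hyp, weights=weights, smoothing_function=smooth))
    return sum(scores) / len(scores)

samples = [s.split() for s in [
    "a man is riding a horse",
    "a dog runs across the field",
    "a man is riding a bike",
]]
print(self_bleu(samples, n=3))  # values closer to 1.0 indicate less diverse samples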
In your article, you use the whole test set as the reference and then calculate the BLEU score of each generated sentence; the average serves as a metric of the quality of the generated text.
Conversely, why not use the whole generated set (the same size as the test set) as the reference and calculate the BLEU score of each test sentence? The average of those scores could serve as a metric of the diversity of the generated text.
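A short NLTK sketch of the two directions being discussed, with made-up sentences: the "forward" direction scores generated sentences against the test set (quality), while the proposed "reverse" direction scores test sentences against the generated set (coverage/diversity).

from nltk.translate.bleu_score import sentence_bleu, SmoothingFunction

smooth = SmoothingFunction().method1

def avg_bleu(hypotheses, references, n=3):
    # Average sentence-level BLEU of every hypothesis against the full reference set.
    weights = tuple(1.0 / n for _ in range(n))
    scores = [sentence_bleu(references, hyp, weights=weights, smoothing_function=smooth)
              for hyp in hypotheses]
    return sum(scores) / len(scores)

test_set = [s.split() for s in ["a man rides a horse", "two dogs play in the park"]]
generated = [s.split() for s in ["a man rides a bike", "a man rides a horse on the beach"]]

forward_bleu = avg_bleu(generated, test_set)  # generated vs. test references: quality
reverse_bleu = avg_bleu(test_set, generated)  # test vs. generated references: coverage/diversity
print(forward_bleu, reverse_bleu)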
Hi,
First of all, thank you for sharing your code! I'm currently performing an empirical evaluation of different GAN approaches, and I would like to include yours. I'm trying to reproduce the results on the EMNLP2017 WMT News dataset.
Is it sufficient to run python real/experiments/emnlp_relgan.py, wait for training to finish, and extract the weights of the generator?
Thanks in advance,
Lucas