devnag / pytorch-generative-adversarial-networks
A very simple generative adversarial network (GAN) in PyTorch
License: Apache License 2.0
On line 152 in gan_pytorch.py it says:
d_fake_error = criterion(d_fake_decision, Variable(torch.zeros([1,1])))  # zeros = fake
Shouldn't it say
...torch.zeros([0,0]))
instead?
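For what it's worth, the [1,1] argument to torch.zeros is the tensor's *shape*, not its contents. torch.zeros([1,1]) builds a 1x1 tensor holding the value 0.0 (the "fake" target label), while torch.zeros([0,0]) would be an empty tensor. A quick check:

```python
import torch

target = torch.zeros([1, 1])   # a 1x1 tensor filled with 0.0: the "fake" label
assert target.shape == (1, 1)
assert target.item() == 0.0

empty = torch.zeros([0, 0])    # [0, 0] would give an empty tensor, not a label
assert empty.numel() == 0
```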
In the code we have "d_input_size = 100", which is very strange: the input dimension of the discriminator should be the same as the output dimension of the generator, which is 1.
Nothing is being detached, so does this train the discriminator?
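For context, the usual pattern is to detach the generator's output when computing the discriminator's loss on fake data, so that the discriminator update does not propagate gradients into the generator. A minimal sketch (stand-in one-layer G and D, not the repo's networks):

```python
import torch
import torch.nn as nn

G = nn.Linear(1, 1)   # stand-in generator
D = nn.Linear(1, 1)   # stand-in discriminator
criterion = nn.BCEWithLogitsLoss()
d_optimizer = torch.optim.SGD(D.parameters(), lr=0.01)

z = torch.randn(8, 1)
fake = G(z).detach()   # detach: D's loss will not flow back into G
d_fake_loss = criterion(D(fake), torch.zeros(8, 1))  # zeros = fake labels
d_optimizer.zero_grad()
d_fake_loss.backward()
d_optimizer.step()

assert G.weight.grad is None       # generator received no gradient
assert D.weight.grad is not None   # discriminator did
```

Without the detach, backward() also accumulates gradients in G; as long as only the discriminator's optimizer steps, D still trains, but the extra graph traversal is wasted work.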
When using the raw data [ (name, preprocess, d_input_func) = ("Raw data", lambda data: data, lambda x: x) ], the 'fake' distribution cannot match the 'real' one even with a very large num_epochs.
I think this is probably because of PyTorch updates: torch.mean() no longer accepts the keep_dim=True argument (the keyword is now spelled keepdim).
I removed it and the code runs fine. You probably need to fix it in the code!
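To illustrate the rename: on current PyTorch the keyword is keepdim, and passing it keeps the reduced dimension as size 1 instead of dropping it.

```python
import torch

x = torch.ones(4, 3)
m = torch.mean(x, 1, keepdim=True)   # keepdim (not keep_dim) on current PyTorch
assert m.shape == (4, 1)             # the reduced dim is kept with size 1

m2 = torch.mean(x, 1)                # without keepdim the dim is dropped
assert m2.shape == (4,)
```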
The best results are at about 18,400 epochs. After that, the mean value of the faked data increases quickly. Is there any early-stopping strategy when training a GAN?
Can you please explain what decorate_with_diffs() does? Thanks.
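Roughly, the idea is to concatenate the raw values with each value's deviation from the sample mean raised to a power, so the discriminator sees the spread of a sample as well as its raw values. A simplified sketch of that idea (not the repo's exact code, which handles Variable wrapping):

```python
import torch

def decorate_with_diffs(data, exponent):
    mean = torch.mean(data, 1, keepdim=True)   # per-row mean
    diffs = torch.pow(data - mean, exponent)   # deviations from the mean^exponent
    return torch.cat([data, diffs], 1)         # doubles the feature width

x = torch.tensor([[1.0, 2.0, 3.0]])
out = decorate_with_diffs(x, 2.0)
assert out.shape == (1, 6)
assert torch.allclose(out, torch.tensor([[1.0, 2.0, 3.0, 1.0, 0.0, 1.0]]))
```

This is why d_input_func doubles the discriminator's input size when the "diffs" preprocessing is selected.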
I have seen your article "Generative Adversarial Networks (GANs) in 50 lines of code (PyTorch)" on Medium. Your work inspired me a lot, so I want to cite it in my paper. May I ask if there is a corresponding paper for this work?
[edit] The issue was that I was not training long enough.
I cloned this exact code, just removed the keep_dims=True in line 91, and then plotted the output of the generator during training (even after a number of epochs). The generated distribution's mean and std have converged to the values of the real distribution (that's cool!), BUT the distribution, when plotted, does not look like a Gaussian. It does not look like the Gaussian plot in your blog. Here is my plotting code. Any ideas why this happens?
if (epoch + 1) % 1000 == 0:
    plt.clf()
    gen_input = Variable(gi_sampler(minibatch_size, g_input_size))
    g_sample_data = G(gen_input)
    d_fake_data_numpy = g_sample_data.data.cpu().numpy()
    d_real_data_numpy = d_real_data.data.cpu().numpy()
    # note: `normed` was removed from plt.hist in newer matplotlib; use density=True
    plt.hist(d_real_data_numpy.flatten(), 20, density=True, facecolor='r', alpha=0.75)
    plt.hist(d_fake_data_numpy.flatten(), 20, density=True, facecolor='b', alpha=0.75)
    plt.show()
After 10,000 epochs it looks like this (red is the real distribution, blue is the generated). Despite the mean and std of the distribution being spot on, it doesn't take the shape/form of a Gaussian distribution:
and 20,000 epochs:
and 30,000 epochs:
and 40,000 epochs:
and 50,000 epochs:
I realized it starts to look pretty good after 60,000 runs (if it has not diverged; it sometimes does).
I was just not training long enough.
and then 80,000 runs:
Sorry, I want to get detailed information from your code, but I can't open the link. I don't know why. Can you help me?
I want to run this code on my own dataset. How can I do it?
I have been researching and improving GANs for a long time. If you are interested in GANs or deep learning, feel free to contact me. WeChat: lovedaixiaobaby
In gan_pytorch.py, the default mean is 4. When you run the code it seems correct, since the number is small. But if you change the mean to 400, the generator will lose direction. This is because get_generator_input_sampler
generates uniform data as the input. The function should be changed from lambda m, n: torch.rand(m, n)
to lambda m, n: torch.randn(m, n)
, which generates a normal distribution rather than a uniform one.
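The difference is easy to see empirically: torch.rand samples Uniform(0, 1), which is bounded, while torch.randn samples a standard normal N(0, 1). A quick sanity check:

```python
import torch

torch.manual_seed(0)
u = torch.rand(10000, 1)    # uniform in [0, 1): bounded, mean near 0.5
n = torch.randn(10000, 1)   # standard normal: unbounded, mean near 0, std near 1

assert u.min().item() >= 0.0 and u.max().item() < 1.0
assert abs(n.mean().item()) < 0.1
assert abs(n.std().item() - 1.0) < 0.1
```

With a bounded uniform input, the generator has to stretch a narrow latent range over a far-away target mean like 400, which plausibly makes training much harder than with an unbounded normal input.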