kimmo1019 / roundtrip
Roundtrip: density estimation with deep generative neural networks
Home Page: https://pypi.org/project/pyroundtrip/
License: MIT License
When I tried to visualize the estimated density on a 2D region, the following error occurred:
Traceback (most recent call last):
  File "evaluate.py", line 233, in <module>
    RTM = load_model(path,epoch,pretrain=args.pretrain)
  File "evaluate.py", line 197, in load_model
    xs = util.Gaussian_sampler(mean=np.zeros(x_dim),sd=1.0)
TypeError: __init__() takes at least 3 arguments (3 given)
After I added N=20000 to the call at line 197 of "evaluate.py":
xs = util.Gaussian_sampler(N=20000,mean=np.zeros(x_dim),sd=1.0)
the code ran successfully.
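The traceback above is consistent with a constructor that requires N and gives it no default. A minimal sketch of that situation follows, assuming a hypothetical stand-in class: ToyGaussianSampler and try_build are illustrative names, not the repository's util.Gaussian_sampler, whose actual signature should be checked in util.py.

```python
import random


class ToyGaussianSampler:
    """Hypothetical stand-in; NOT the repository's util.Gaussian_sampler."""

    def __init__(self, N, mean, sd=1.0):
        # N has no default value, so omitting it raises TypeError here.
        self.N = N
        self.mean = list(mean)
        self.sd = sd
        rng = random.Random(0)
        # Pre-draw N points, one Gaussian coordinate per entry of `mean`.
        self.X = [[rng.gauss(m, sd) for m in self.mean] for _ in range(N)]


def try_build(**kwargs):
    """Return the sampler, or the TypeError message if construction fails."""
    try:
        return ToyGaussianSampler(**kwargs)
    except TypeError as exc:
        return str(exc)


missing_n = try_build(mean=[0.0, 0.0], sd=1.0)    # a TypeError message
with_n = try_build(N=5, mean=[0.0, 0.0], sd=1.0)  # a sampler holding 5 points
```

Passing N by keyword, as in the fix reported above, keeps the call unambiguous even if the positional order of the parameters differs from this sketch.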
I am running the code with TF 1.13.1 and Python 3.7 in a venv. When I run example (1), runtime problems appear as soon as cross-validation starts.
I used the following command to shorten the run for debugging:
CUDA_VISIBLE_DEVICES=0 python main_density_est.py --dx 2 --dy 2 --train True --data indep_gmm --epochs 10 --cv_epoch 3 --patience 5
When cross-validation starts at epoch 3, the following runtime warnings appear:
/home/connor/Roundtrip/venv/lib/python3.7/site-packages/numpy/core/fromnumeric.py:3420: RuntimeWarning: Mean of empty slice.
  out=out, **kwargs)
/home/connor/Roundtrip/venv/lib/python3.7/site-packages/numpy/core/_methods.py:188: RuntimeWarning: invalid value encountered in double_scalars
  ret = ret.dtype.type(ret / rcount)
/home/connor/Roundtrip/venv/lib/python3.7/site-packages/numpy/core/_methods.py:262: RuntimeWarning: Degrees of freedom <= 0 for slice
  keepdims=keepdims, where=where)
/home/connor/Roundtrip/venv/lib/python3.7/site-packages/numpy/core/_methods.py:222: RuntimeWarning: invalid value encountered in true_divide
  subok=False)
/home/connor/Roundtrip/venv/lib/python3.7/site-packages/numpy/core/_methods.py:253: RuntimeWarning: invalid value encountered in double_scalars
  ret = ret.dtype.type(ret / rcount)
The program keeps running, however, until it reaches the last two if conditions in the function "train", where it fails with a "local variable referenced before assignment" error.
Do you have any suggestions or solutions?
Thank you!
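One plausible way the two symptoms above are linked: a mean over an empty slice evaluates to NaN, every comparison against NaN is False, and a variable assigned only inside such a comparison branch is never bound. The sketch below reproduces that pattern in plain Python. The function names and the scoring logic are hypothetical and are not taken from the repository's train function.

```python
import math


def train_unguarded(cv_scores):
    """Mimics a loop whose result variable is bound only inside an if branch."""
    best = -math.inf
    for epoch, score in enumerate(cv_scores):
        if score > best:        # always False when score is NaN
            best = score
            best_epoch = epoch  # the only place best_epoch is assigned
    return best_epoch           # UnboundLocalError if the branch never ran


def train_guarded(cv_scores):
    """Same loop, but with a default value and an explicit NaN check."""
    best, best_epoch = -math.inf, None
    for epoch, score in enumerate(cv_scores):
        if math.isnan(score):
            continue            # skip epochs whose validation score is undefined
        if score > best:
            best, best_epoch = score, epoch
    return best_epoch


nan = float("nan")  # what a mean over an empty slice evaluates to
try:
    train_unguarded([nan, nan])
    unguarded_failed = False
except UnboundLocalError:
    unguarded_failed = True
```

If this is the cause, the underlying question is why the slice being averaged is empty when cross-validation starts, for example because no validation points pass a filter at that epoch.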
I plotted the target density for case a), the independent Gaussian mixture with 9 modes, and found that the standard deviation of the 3 components is 0.1, not the 0.5 written in the paper.
Running the following code with sigma = 0.1 produces the same figure as Fig. 2A in the paper.
import torch
import matplotlib.pyplot as plt

# Unnormalized target density: product of two 1D mixtures, each with three
# Gaussian components centered at -1, 0 and 1 (sigma is read at call time).
u_z = lambda z: (torch.exp(-0.5*((z[:,1]+1)/sigma)**2)+\
                 torch.exp(-0.5*(z[:,1]/sigma)**2)+\
                 torch.exp(-0.5*((z[:,1]-1)/sigma)**2))*\
                (torch.exp(-0.5*((z[:,0]+1)/sigma)**2)+\
                 torch.exp(-0.5*(z[:,0]/sigma)**2)+\
                 torch.exp(-0.5*((z[:,0]-1)/sigma)**2))

# Parameters
sigma = torch.Tensor([0.1])
range_lim = 1.5
n = 201

# Draw the heatmap on an n x n grid over [-range_lim, range_lim]^2
x = torch.linspace(-range_lim, range_lim, n)
xx, yy = torch.meshgrid((x, x))
zz = torch.stack((xx.flatten(), yy.flatten()), dim=-1).squeeze()
fig, ax = plt.subplots(figsize=(3,3))
ax.pcolormesh(xx, yy, u_z(zz).view(n,n).data, cmap="Blues")
for ax in plt.gcf().axes:
    ax.set_xlim(-range_lim, range_lim)
    ax.set_ylim(-range_lim, range_lim)
    ax.get_xaxis().set_visible(True)
    ax.get_yaxis().set_visible(True)
    ax.invert_yaxis()
plt.tight_layout()
plt.savefig('target_density.png')
plt.show()
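The difference between the two standard deviations can also be checked numerically without plotting. The sketch below evaluates the same unnormalized 1D mixture (centers at -1, 0 and 1) and compares its value halfway between two centers with its value at a center. The helper names marginal and valley_to_peak are illustrative only.

```python
import math


def marginal(x, sigma):
    """Unnormalized 1D mixture with three Gaussian components at -1, 0, 1."""
    return sum(math.exp(-0.5 * ((x - c) / sigma) ** 2) for c in (-1.0, 0.0, 1.0))


def valley_to_peak(sigma):
    """Density halfway between two centers divided by the density at a center."""
    return marginal(0.5, sigma) / marginal(0.0, sigma)


narrow = valley_to_peak(0.1)  # close to 0: the modes are well separated
wide = valley_to_peak(0.5)    # close to 1: neighboring modes almost merge
```

With sigma = 0.1 the ratio is on the order of 1e-5, which matches a picture of 9 distinct modes. With sigma = 0.5 it is about 0.96, so adjacent components would overlap almost completely on this grid.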
Hi, I really like the idea of estimating density with neural networks and want to use it in my own research. When I tried the MNIST training example on a Linux server, I encountered the following error:
AbortedError (see above for traceback): Operation received an exception:Status: 5, message: could not create a view primitive descriptor, in file tensorflow/core/kernels/mkl_slice_op.cc:435
[[node gradients/g_net_1/concat_1_grad/Slice_1 (defined at main_density_est_img.py:103) ]]
Do you have any idea about the issue? Thank you in advance!
System: Ubuntu 22.04.1 LTS
Python: 2.7
TensorFlow: 1.13.1
Hi,
I enjoyed reading your paper, kudos!
Could you please share details of the evaluation on the simulated independent GMM dataset (dimension > 2)? Was a random seed used when generating the dataset (for reproducibility)? What function was used to compute the Spearman correlation?
It would also be great if you could share the raw data for Appendix Fig. S2.
Thanks!
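For readers who want a reproducible baseline while this question is open, here is a minimal sketch of a seeded Spearman rank correlation in plain Python. It is an assumption-laden illustration, not the evaluation used in the paper: the source does not say which function or seed was used, the helper names are hypothetical, and the ranking assumes there are no tied values. A library routine such as scipy.stats.spearmanr is a common alternative that also handles ties.

```python
import random


def ranks(values):
    """Rank values from 1 to n; assumes no ties (continuous data)."""
    order = sorted(range(len(values)), key=lambda i: values[i])
    result = [0] * len(values)
    for rank, idx in enumerate(order, start=1):
        result[idx] = rank
    return result


def spearman(a, b):
    """Spearman's rho via 1 - 6 * sum(d^2) / (n * (n^2 - 1)); no ties assumed."""
    n = len(a)
    d2 = sum((x - y) ** 2 for x, y in zip(ranks(a), ranks(b)))
    return 1.0 - 6.0 * d2 / (n * (n * n - 1))


rng = random.Random(0)  # fixed seed so the toy data is reproducible
true_density = [rng.random() for _ in range(100)]
estimate = [t ** 2 for t in true_density]  # monotone transform of the truth
rho = spearman(true_density, estimate)
```

Because Spearman's rho depends only on ranks, any strictly increasing transform of the true density scores as a perfect correlation, which is worth keeping in mind when comparing estimators with this metric.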
I was trying to get the conditional density for the MNIST dataset by running:
CUDA_VISIBLE_DEVICES=0 python main_density_est_img.py --dx 100 --dy 784 --train True --data mnist --epochs 100 --cv_epoch 50 --patience 5
Training completes successfully, but no density output is written to any directory. Could you help me find a way to get the estimated conditional density?
I'm running the code with TensorFlow 1.13.1 (Python 2.7) on a GPU in a conda environment, and I get the following error:
InternalError (see above for traceback): cuDNN launch failure : input shape ([64,128,1,1])
[[node dx_net/BatchNorm/FusedBatchNorm (defined at /public/home/***/Roundtrip/model.py:37) ]]
[[node Mean_4 (defined at main_density_est.py:68) ]]
I can't figure out why this happens. I also tried the fixes suggested online (adding os.environ['TF_FORCE_GPU_ALLOW_GROWTH'] = 'true' or changing the batch size), but nothing changed. Do you have any suggestions?
Thank you!