
pixel-rnn-tensorflow's Introduction

PixelCNN & PixelRNN in TensorFlow

TensorFlow implementation of Pixel Recurrent Neural Networks. This implementation contains:

model

  1. PixelCNN
  • Masked Convolution (A, B)
  2. PixelRNN
  • Row LSTM (in progress)
  • Diagonal BiLSTM (skew, unskew)
  • Residual Connections
  • Multi-Scale PixelRNN (in progress)
  3. Datasets
  • MNIST
  • CIFAR-10 (in progress)
  • ImageNet (in progress)

Requirements

Usage

First, install prerequisites with:

$ pip install tqdm gym[all]

To train a pixel_rnn model with mnist data (slow iteration, fast convergence):

$ python main.py --data=mnist --model=pixel_rnn

To train a pixel_cnn model with mnist data (fast iteration, slow convergence):

$ python main.py --data=mnist --model=pixel_cnn --hidden_dims=64 --recurrent_length=2 --out_hidden_dims=64

To generate images with a trained model:

$ python main.py --data=mnist --model=pixel_rnn --is_train=False

Samples

Samples generated with pixel_cnn after 50 epochs.

generation_2016_08_01_16_40_28.jpg

Training details

The results below use two different parameter settings:

[1] --hidden_dims=16 --recurrent_length=7 --out_hidden_dims=32
[2] --hidden_dims=64 --recurrent_length=2 --out_hidden_dims=64

Training results of pixel_rnn with [1] (yellow) and [2] (green) with epoch as x-axis:

pixel_rnn

Training results of pixel_cnn with [1] (orange) and [2] (purple) with epoch as x-axis:

pixel_cnn

Training results of pixel_rnn (yellow, green) and pixel_cnn (orange, purple) with hour as x-axis:

pixel_rnn_cnn_relative

References

Author

Taehoon Kim / @carpedm20

pixel-rnn-tensorflow's Issues

where is activation_fn defined?

Hi, everyone.
I see the activation_fn function used in ops.py, but I can't find where it is defined and implemented.

I'm also confused about this function:
def binarize(images):
  return (np.random.uniform(size=images.shape) < images).astype('float32')

Sorry, I'm really new to this architecture. Could someone explain the two functions above?
Thanks all.
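
The binarize function quoted above can be read as stochastic binarization: each pixel intensity in [0, 1] is treated as the probability that the binarized pixel is 1. The sketch below is a minimal illustration of that reading, not the repository's code; the `rng` argument is an assumption added here only to make the demonstration reproducible.

```python
import numpy as np

def binarize(images, rng=None):
    """Sample a 0/1 image where each pixel is 1 with probability equal to its intensity.

    `rng` is a hypothetical extra argument (not in the quoted code) so that
    the sampling can be seeded.
    """
    rng = np.random.default_rng() if rng is None else rng
    # uniform() draws from [0, 1); the comparison is True with probability `images`.
    return (rng.uniform(size=images.shape) < images).astype("float32")

rng = np.random.default_rng(0)
black = binarize(np.zeros((4, 4)), rng)        # intensity 0 never turns on
white = binarize(np.ones((4, 4)), rng)         # intensity 1 always turns on
gray = binarize(np.full((100000,), 0.3), rng)  # roughly 30% of pixels turn on
```

Because the draw is random, the same grayscale image generally yields a different binary image on every call.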

custom dataset as input?

I have about 15,000 JPEG images, all 16x16 in size. How would I go about using them as training data? I could not find anything looking through the code, but it's possible I missed something. MNIST is great and all, but I would like to use my own images for training.

Thanks.
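
One possible starting point, sketched under assumptions: the images have already been decoded into uint8 arrays by an image library of your choice, and the training loop accepts float32 batches shaped [N, height, width, channel] with values in [0, 1]. The helper names `to_training_array` and `iterate_minibatches` are hypothetical and do not exist in the repository.

```python
import numpy as np

def to_training_array(decoded_images):
    """Stack decoded uint8 images into a float32 array [N, H, W, C] scaled to [0, 1]."""
    batch = np.stack([np.asarray(img) for img in decoded_images])
    batch = batch.astype(np.float32) / 255.0
    if batch.ndim == 3:  # grayscale images have no channel axis yet
        batch = batch[..., np.newaxis]
    return batch

def iterate_minibatches(data, batch_size, rng):
    """Yield shuffled mini-batches; the last one may be smaller."""
    order = rng.permutation(len(data))
    for start in range(0, len(data), batch_size):
        yield data[order[start:start + batch_size]]

# Stand-in for five decoded 16x16 RGB images (real data would come from JPEG files).
fake_images = [np.full((16, 16, 3), 255, dtype=np.uint8) for _ in range(5)]
data = to_training_array(fake_images)
batches = list(iterate_minibatches(data, 2, np.random.default_rng(0)))
```

How such batches would be wired into main.py in place of the MNIST loader depends on the code and is not shown here.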

Trained weights for MNIST

Hi All!

Where can I get a pretrained network or weights file so that I can run inference only?
We ran into difficulties training the net under TensorFlow 1.12:
it reported that the graph could not be sorted in topological order, and training seems to stall.

Thanks in advance.

Mask A when applied to RGB channels

In your code for the mask:
  mask = np.ones((kernel_h, kernel_w, channel, num_outputs), dtype=np.float32)

  mask[center_h, center_w+1:, :, :] = 0.
  mask[center_h+1:, :, :, :] = 0.

  if mask_type == 'a':
    mask[center_h, center_w, :, :] = 0.

I am confused about how this applies to RGB channels. For example, when predicting the G channel,
the code above means the current pixel depends only on the previously generated pixels to the left and above, across all three R, G, B channels. But the paper says "when predicting G channel, the value of the R channel can also be used as context in addition to the previously generated pixels." So shouldn't the mask also let the G prediction see the R channel of the current pixel?
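
A channel-aware variant of the mask along the lines of the quoted sentence could look like the sketch below. It is an illustration under assumptions, not the repository's implementation: the function name `channel_aware_mask` is hypothetical, and it assumes feature maps are split into three contiguous groups ordered R, G, B.

```python
import numpy as np

def channel_aware_mask(kernel_h, kernel_w, in_channels, out_channels,
                       mask_type="a", num_colors=3):
    """Build a [kernel_h, kernel_w, in_channels, out_channels] mask.

    Assumption: channels are grouped contiguously by colour (R, then G, then B).
    """
    mask = np.ones((kernel_h, kernel_w, in_channels, out_channels), dtype=np.float32)
    center_h, center_w = kernel_h // 2, kernel_w // 2

    # Spatial part: hide pixels right of the centre and every row below it.
    mask[center_h, center_w + 1:, :, :] = 0.
    mask[center_h + 1:, :, :, :] = 0.

    # Colour group (0=R, 1=G, 2=B) of each input and output feature map.
    in_color = np.arange(in_channels) * num_colors // in_channels
    out_color = np.arange(out_channels) * num_colors // out_channels

    # At the centre pixel, mask A lets an output see strictly earlier colours;
    # mask B also lets it see its own colour.
    if mask_type == "a":
        allowed = in_color[:, None] < out_color[None, :]
    else:
        allowed = in_color[:, None] <= out_color[None, :]
    mask[center_h, center_w, :, :] = allowed.astype(np.float32)
    return mask

mask_a = channel_aware_mask(3, 3, 3, 3, mask_type="a")
mask_b = channel_aware_mask(3, 3, 3, 3, mask_type="b")
```

With this layout, `mask_a[1, 1, 0, 1]` is 1 (the G output may read the current pixel's R value) while `mask_a[1, 1, 1, 1]` is 0 (it may not read its own G value). For single-channel data such as MNIST the two variants of the centre entry reduce to the all-or-nothing behaviour of the code quoted above.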

IOError: [Errno 101] Network is unreachable

Hi, @carpedm20

When I run the command:

$ python main.py --data=mnist --model=pixel_rnn

Its output is as follows:

I tensorflow/stream_executor/dso_loader.cc:108] successfully opened CUDA library libcublas.so locally
I tensorflow/stream_executor/dso_loader.cc:108] successfully opened CUDA library libcudnn.so locally
I tensorflow/stream_executor/dso_loader.cc:108] successfully opened CUDA library libcufft.so locally
I tensorflow/stream_executor/dso_loader.cc:108] successfully opened CUDA library libcuda.so.1 locally
I tensorflow/stream_executor/dso_loader.cc:108] successfully opened CUDA library libcurand.so locally
{'batch_size': 100,
'data': 'mnist',
'data_dir': 'data',
'display': False,
'grad_clip': 1,
'hidden_dims': 16,
'is_train': True,
'learning_rate': 0.001,
'log_level': 'INFO',
'max_epoch': 100000,
'model': 'pixel_rnn',
'out_hidden_dims': 32,
'out_recurrent_length': 2,
'random_seed': 123,
'recurrent_length': 7,
'sample_dir': 'samples',
'save_step': 1000,
'test_step': 100,
'use_dynamic_rnn': False,
'use_gpu': True,
'use_residual': False}
[09-22 12:18:55] Skip creating directory: data/mnist
[09-22 12:18:55] Skip creating directory: samples/mnist/checkpoints/data=mnist/batch_size=100/grad_clip=1/hidden_dims=16/learning_rate=0.001/model=pixel_rnn/out_hidden_dims=32/out_recurrent_length=2/recurrent_length=7/use_dynamic_rnn=False/use_gpu=True/use_residual=False/
Traceback (most recent call last):
  File "main.py", line 139, in <module>
    tf.app.run()
  File "/home/clu/.local/lib/python2.7/site-packages/tensorflow/python/platform/app.py", line 30, in run
    sys.exit(main(sys.argv))
  File "main.py", line 71, in main
    mnist = input_data.read_data_sets(DATA_DIR, one_hot=True)
  File "/home/clu/.local/lib/python2.7/site-packages/tensorflow/contrib/learn/python/learn/datasets/mnist.py", line 189, in read_data_sets
    SOURCE_URL + TRAIN_IMAGES)
  File "/home/clu/.local/lib/python2.7/site-packages/tensorflow/contrib/learn/python/learn/datasets/base.py", line 157, in maybe_download
    urllib.request.urlretrieve(source_url, temp_file_name)
  File "/usr/lib/python2.7/urllib.py", line 98, in urlretrieve
    return opener.retrieve(url, filename, reporthook, data)
  File "/usr/lib/python2.7/urllib.py", line 245, in retrieve
    fp = self.open(url, data)
  File "/usr/lib/python2.7/urllib.py", line 213, in open
    return getattr(self, name)(url)
  File "/usr/lib/python2.7/urllib.py", line 350, in open_http
    h.endheaders(data)
  File "/usr/lib/python2.7/httplib.py", line 1053, in endheaders
    self._send_output(message_body)
  File "/usr/lib/python2.7/httplib.py", line 897, in _send_output
    self.send(msg)
  File "/usr/lib/python2.7/httplib.py", line 859, in send
    self.connect()
  File "/usr/lib/python2.7/httplib.py", line 836, in connect
    self.timeout, self.source_address)
  File "/usr/lib/python2.7/socket.py", line 575, in create_connection
    raise err
IOError: [Errno socket error] [Errno 101] Network is unreachable

Could you kindly give any suggestions about it? Thanks.

Sequential generation error

The current implementation works as an autoencoder, but it can't generate from the beginning, starting from a blank image. This error should be related to masking.

Trained Weights

Not really an issue (tell me a better place to put this plz). Does anyone have a trained model I can use, perhaps trained on ImageNet or CIFAR? :)

Generated samples become zero

I followed the master branch and trained like this:
$ python main.py --data=mnist --model=pixel_rnn

But after about 400 epochs, the pixels vanished and the generated picture looks like an empty picture:
sample_2016_10_15_12_50_09

What's wrong with my procedure?

What are the input and desired output?

Thank you for your helpful explanation and code. I don't understand what the input and the target are. Would you please tell me what they are?

Are the original images the input, or half of them?

Number of output filters of the conv layers

For the 7x7 conv with mask A, what is the number of output filters? And for the 3x3 conv layers with mask B? The paper says "the pixel cnn has 15 layers and h=128"; is the number of output filters equal to h? And for the ReLU followed by a 1x1 conv, what is the number of output filters? I did not find parameters for this in the paper. How did you set it? Thanks.

Image data format

In network.py:

    if conf.use_gpu:
      data_format = "NHWC"
    else:
      data_format = "NCHW"

I think NCHW is preferred for GPU, while for CPU, we should use NHWC.

data_format == "NCHW"

In your network.py:
    if data_format == "NHWC":
      input_shape = [None, height, width, channel]
    elif data_format == "NCHW":
      input_shape = [None, height, width, channel]
    else:
      raise ValueError("Unknown data_format: %s" % data_format)
When data_format == "NCHW", shouldn't input_shape be [None, channel, height, width]?
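
For reference, the two layouts differ only in where the channel axis sits. The sketch below shows the shape each layout name implies and how a batch can be moved from one to the other; the helper names `input_shape_for` and `nhwc_to_nchw` are hypothetical and are not taken from network.py.

```python
import numpy as np

def input_shape_for(data_format, height, width, channel):
    """Return the placeholder shape implied by the layout name (batch size left as None)."""
    if data_format == "NHWC":
        return [None, height, width, channel]
    elif data_format == "NCHW":
        return [None, channel, height, width]
    raise ValueError("Unknown data_format: %s" % data_format)

def nhwc_to_nchw(batch):
    """Move the channel axis from last position to position 1."""
    return np.transpose(batch, (0, 3, 1, 2))

nhwc_batch = np.zeros((8, 28, 28, 1), dtype=np.float32)
nchw_batch = nhwc_to_nchw(nhwc_batch)
```

Here `input_shape_for("NCHW", 28, 28, 1)` gives `[None, 1, 28, 28]`, which matches the shape of `nchw_batch` apart from the batch dimension.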

add two output maps in diagonal BiLSTM

output_state_bw_with_last_zeros = tf.concat(1, [output_state_bw_except_last, dummy_zeros])

should be

output_state_bw_with_last_zeros = tf.concat(1, [dummy_zeros, output_state_bw_except_last])
