tensorflow-DeViSE

Attempts to understand deep learning and the TensorFlow RNN API by implementing a (very) crude version of the Google DeViSE paper (2013).

Series of events which led to this repo:

  1. Tried to understand Theano.
  2. Tried to hack using Theano.
  3. Tried to hack using TensorFlow.

The paper's objective (as described in Section 3) is to "leverage semantic knowledge learned in the text domain, and transfer it to a model trained for visual object recognition", which is exactly what I did...sort of.

I developed and tested this on my Dell laptop, which has an ATI graphics card (so no CUDA). My definition of success for this project was to make something that (1) works and (2) decreases the loss. I'm putting the code online because I could not find beginner-level code for using/hacking recurrent neural networks. This "work" is the result of repeatedly editing (and breaking) the RNN language model examples available with TensorFlow. Hopefully, someone can use this to start their own RNN experiments.

How this works:

The idea was to build an image search engine: given a query (a sequence of tokens), return the candidate set of images sorted from most to least appropriate. Simple enough (sarcasm). All I had to do was encode the entire query into a single fixed-length vector, encode each image into a single vector of the same length, and ensure that the encoding process captures the ground truth. By capturing the ground truth, I mean that appropriate <image, query> pairs are close together in the embedding space and inappropriate <image, query> pairs are far apart. Concretely, for an image of a dog drinking water, an "appropriate" query could be "dog is drinking water" and an "inappropriate" query could be "man driving a car". The encoding process must place the appropriate <image, query> vector pair closer together than the inappropriate one.
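
To make the ranking step concrete, here is a minimal NumPy sketch, assuming the query and the images have already been encoded into vectors of the same length. It scores each image by cosine similarity to the query and sorts the scores; cosine similarity is an assumption made for illustration, and the repo may measure closeness differently. `rank_images` is a hypothetical helper, not a function from new_model.py.

```python
import numpy as np

def rank_images(query_vec, image_vecs):
    """Return image indices sorted from most to least similar to the query.

    query_vec:  shape (d,)   -- encoded query
    image_vecs: shape (n, d) -- encoded images, same d as the query
    """
    q = query_vec / np.linalg.norm(query_vec)
    imgs = image_vecs / np.linalg.norm(image_vecs, axis=1, keepdims=True)
    scores = imgs @ q                  # cosine similarity of each image to the query
    return np.argsort(-scores), scores

# Toy 3-d embeddings (made-up numbers, for illustration only)
query = np.array([1.0, 0.0, 0.0])
images = np.array([[0.0, 1.0, 0.0],    # unrelated image
                   [0.9, 0.1, 0.0],    # close match
                   [0.5, 0.5, 0.0]])   # partial match
order, scores = rank_images(query, images)
print(order)  # [1 2 0]
```

Training is what makes this ranking meaningful: the encoders have to learn to put appropriate pairs near each other before a simple sort like this returns useful results.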

Encoding query strings:

This is a two-step process. First, I convert the words into vectors (see what I did there). The size of these vectors (or word embeddings) is independent of the size of the encoded images. The paper suggests training a language model that learns these word embeddings from scratch. Instead, I use pre-trained Stanford GloVe word embeddings. I chose these over Google's Word2Vec because they offered better coverage of my training vocabulary. More on this here.
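
A minimal sketch of this first step, assuming the plain-text GloVe format (each line is a word followed by its vector components, separated by spaces) and simple lowercase whitespace tokenization. `load_glove` and `embed_query` are hypothetical helpers. Here, out-of-vocabulary tokens are simply dropped, which may not match what the repo does.

```python
import io
import numpy as np

def load_glove(lines):
    """Parse GloVe text format: a word followed by its vector components."""
    vectors = {}
    for line in lines:
        parts = line.rstrip().split(" ")
        if len(parts) < 2:
            continue                      # skip blank or malformed lines
        vectors[parts[0]] = np.asarray(parts[1:], dtype=np.float32)
    return vectors

def embed_query(query, vectors, dim):
    """Lowercase and whitespace-tokenize the query; return one row per known token."""
    rows = [vectors[t] for t in query.lower().split() if t in vectors]
    if not rows:
        return np.zeros((0, dim), dtype=np.float32)
    return np.stack(rows)                 # shape (num_known_tokens, dim)

# Toy 3-d vocabulary; a real file would be opened with
# open("glove.6B.50d.txt", encoding="utf-8") and passed in the same way.
toy = io.StringIO("dog 0.1 0.2 0.3\nis 0.0 0.1 0.0\n"
                  "drinking 0.5 0.4 0.3\nwater 0.9 0.8 0.7\n")
vecs = load_glove(toy)
seq = embed_query("Dog is drinking water", vecs, dim=3)
print(seq.shape)  # (4, 3)
```

The result is a variable-length sequence of vectors, one per token, which is exactly what the next step has to condense.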

Once the words have been encoded as vectors, I had to "condense" the list of vectors into a single vector. For this I use a Recurrent Neural Network (RNN) with an LSTM cell. Colah's blog is a great resource for understanding related topics; I used it to learn more about RNNs. This reddit post got me started with basic RNN code using the TensorFlow RNN API. I treat the output returned by the RNN as the "intent/meaning" of a query. The dimensionality of this output must equal that of the encoded images.
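
To illustrate the "condensing" step, here is a NumPy-only sketch of a single-layer LSTM forward pass that returns the final hidden state as the query encoding. This is not the TensorFlow RNN API the repo uses. The `TinyLSTM` class, its random (untrained) weights, and the 128-dimensional hidden size are all assumptions for illustration. The hidden size is the parameter that has to match the encoded image size.

```python
import numpy as np

def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))

class TinyLSTM:
    """Minimal single-layer LSTM forward pass (no training), NumPy only."""

    def __init__(self, input_dim, hidden_dim, seed=0):
        rng = np.random.default_rng(seed)
        # One weight matrix covering the four gates: input, forget, output, candidate
        self.W = rng.normal(0, 0.1, size=(input_dim + hidden_dim, 4 * hidden_dim))
        self.b = np.zeros(4 * hidden_dim)
        self.hidden_dim = hidden_dim

    def encode(self, seq):
        """Run over seq (shape (T, input_dim)) and return the final hidden state."""
        h = np.zeros(self.hidden_dim)
        c = np.zeros(self.hidden_dim)
        for x in seq:
            z = np.concatenate([x, h]) @ self.W + self.b
            i, f, o, g = np.split(z, 4)
            i, f, o, g = sigmoid(i), sigmoid(f), sigmoid(o), np.tanh(g)
            c = f * c + i * g          # update the cell state
            h = o * np.tanh(c)         # new hidden state
        return h

rng = np.random.default_rng(1)
word_vectors = rng.normal(size=(4, 50))   # e.g. 4 tokens with 50-d GloVe vectors (assumed)
lstm = TinyLSTM(input_dim=50, hidden_dim=128)
query_vec = lstm.encode(word_vectors)
print(query_vec.shape)  # (128,)
```

Using only the last hidden state is the simplest choice; it means the whole query has to be summarized by the time the final token is read.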

Encoding images:

Once again, a two-step process. First, I extract the "best possible" feature vector from the image. The paper suggests training a visual object recognition system based on the ILSVRC 2012 winner. Instead, I use pre-trained VGG16 weights for TensorFlow, which can be downloaded from here. The features are extracted from the last fully connected layer of the network.

Once the image feature vector has been extracted, a linear transformation maps it into a new space. The size of the encoded images must be the same as the size of the encoded queries.
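
The linear transformation can be sketched as a single matrix multiply plus a bias. `FEATURE_DIM = 4096` assumes the 4096-dimensional output of one of VGG16's first two fully connected layers (the final 1000-way layer would give 1000). `EMBED_DIM = 128` is an arbitrary choice, random numbers stand in for real VGG16 features, and `project_image` is a hypothetical name.

```python
import numpy as np

FEATURE_DIM = 4096   # assumed size of the VGG16 fully connected feature vector
EMBED_DIM = 128      # assumed; must match the query encoder's output size

rng = np.random.default_rng(0)
W = rng.normal(0, 0.01, size=(FEATURE_DIM, EMBED_DIM))  # trainable projection weights
b = np.zeros(EMBED_DIM)                                  # trainable bias

def project_image(features):
    """Map image features of shape (n, FEATURE_DIM) to shape (n, EMBED_DIM)."""
    return features @ W + b

vgg_features = rng.random((2, FEATURE_DIM))  # stand-in for real VGG16 features
print(project_image(vgg_features).shape)     # (2, 128)
```

Only `W` and `b` are learned on the image side; the VGG16 weights that produce the features stay fixed.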

Loss function:

Once encoded, the query string and the image have equal dimensions. The loss function proposed in the paper is the hinge rank loss described in Section 3.3. It's similar to a hinge loss with the addition of contrastive (or inappropriate) pairs: at every training step, the weights are updated based on how close the appropriate training pairs are and how far apart the inappropriate pairs are.
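
In the paper, the loss for one training example sums the term max(0, margin - t_label M v(image) + t_j M v(image)) over the contrastive labels j. The sketch below adapts that idea to this setup. It assumes the encoded queries play the role of the label vectors and that the dot product is the similarity. The `margin=0.1` default and the `hinge_rank_loss` name are assumptions, and new_model.py may formulate the loss differently.

```python
import numpy as np

def hinge_rank_loss(image_vec, pos_query, neg_queries, margin=0.1):
    """Sum over negatives of max(0, margin - s(img, pos) + s(img, neg)), s = dot product.

    image_vec:   shape (d,)   -- encoded image
    pos_query:   shape (d,)   -- encoding of an appropriate caption
    neg_queries: shape (k, d) -- encodings of inappropriate captions
    """
    pos_score = image_vec @ pos_query
    neg_scores = neg_queries @ image_vec
    return np.maximum(0.0, margin - pos_score + neg_scores).sum()

img = np.array([1.0, 0.0])
pos = np.array([0.8, 0.2])
negs = np.array([[0.0, 1.0],    # far from the image: contributes 0
                 [0.75, 0.0]])  # within the margin of the positive: contributes a penalty
print(hinge_rank_loss(img, pos, negs))  # approx. 0.05
```

The loss is zero only when every inappropriate query scores at least `margin` below the appropriate one, so well-separated negatives stop contributing gradient.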

The model is trained to minimize this loss, updating the weights of all trainable parameters with gradient descent. I used the Pascal dataset for development. It contains 5 captions per image for 20 different categories of objects, with 50 images per category. While developing this, I used a subset (150 images) of dogs, cats and birds.
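
The loss needs inappropriate pairs to push apart. One simple way to build them from a captioned subset like this is to pair each image with captions drawn from other categories. This sampling strategy, the `sample_negatives` helper, and the toy captions and image IDs below are assumptions for illustration, not the repo's method.

```python
import random

def sample_negatives(pairs, k, seed=0):
    """For each (image_id, category, caption), pick k captions from other categories.

    pairs: list of (image_id, category, caption) tuples
    Returns a list of (image_id, positive_caption, [negative_captions]).
    """
    rng = random.Random(seed)
    out = []
    for image_id, category, caption in pairs:
        pool = [c for _, cat, c in pairs if cat != category]  # captions of other categories
        out.append((image_id, caption, rng.sample(pool, k)))
    return out

pairs = [
    ("dog_01", "dog", "a dog is drinking water"),
    ("cat_01", "cat", "a cat sleeps on a sofa"),
    ("bird_01", "bird", "a bird sits on a branch"),
]
triples = sample_negatives(pairs, k=2)
for image_id, positive, negatives in triples:
    print(image_id, "|", positive, "|", negatives)
```

Drawing negatives from other categories avoids accidentally treating a second caption of the same dog as "inappropriate"; the scan over all pairs is quadratic, which is fine for a 150-image subset.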

Code:

All the code is in new_model.py. The hyperparameters are global variables; a main() function creates the model and starts training. The model and training code make up half of the file; the rest handles data processing and I/O. I've tried to comment regularly, so it should be simple enough to read starting from the main() function. If you're feeling up to the task of hacking this for your own experiments (best of luck!!!), here's a list of the Python modules you'll need to get started:

  • tensorflow 0.8
  • numpy 1.11
  • cv2 2.4.8
  • skimage 0.9.3

I've seen boilerplate code where people used Caffe to extract image features with AlexNet, CaffeNet, etc. Please feel free to use that if it's easier; that part of the pipeline is only used for feature extraction (it is not trained or updated during backprop, so it should be easy to substitute).
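
Because the extractor is frozen, features can be computed once per image and cached, and swapping in a different network only changes the extraction callable. The sketch below illustrates this under stated assumptions. `precompute_features` and `fake_extractor` are hypothetical; a real `extract_fn` would load the image and run VGG16, AlexNet, or CaffeNet.

```python
import numpy as np

def precompute_features(image_paths, extract_fn):
    """Run a frozen feature extractor once per unique image and cache the result.

    extract_fn: any callable mapping an image path to a 1-D feature vector.
    It is never updated during training, so its outputs can be reused every epoch.
    """
    cache = {}
    for path in image_paths:
        if path not in cache:
            cache[path] = np.asarray(extract_fn(path), dtype=np.float32)
    return cache

def fake_extractor(path):
    """Hypothetical stand-in extractor, for illustration only."""
    return np.full(4, float(len(path)))

feats = precompute_features(["dog.jpg", "cat.jpg", "dog.jpg"], fake_extractor)
print(len(feats))  # 2
```

With the features cached (for example saved with `np.save`), each training epoch only touches the projection layer and the RNN.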

To do:

  • Add utility/interface for querying images once the model has been trained.
  • Add a better way to evaluate model training (accuracy, recall, etc.).
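
For the evaluation item, recall@k is one common retrieval metric. It measures the fraction of queries whose ground-truth image appears among the top k results. Here is a minimal NumPy sketch that assumes one relevant image per query and dot-product scoring (both assumptions). `recall_at_k` is a hypothetical helper.

```python
import numpy as np

def recall_at_k(query_vecs, image_vecs, relevant, k):
    """Fraction of queries whose relevant image appears in the top-k results.

    query_vecs: shape (m, d); image_vecs: shape (n, d)
    relevant:   length-m sequence; relevant[i] is the ground-truth image index for query i
    """
    scores = query_vecs @ image_vecs.T               # (m, n) dot-product similarities
    top_k = np.argsort(-scores, axis=1)[:, :k]       # best k image indices per query
    hits = [relevant[i] in top_k[i] for i in range(len(relevant))]
    return float(np.mean(hits))

queries = np.array([[1.0, 0.0], [0.0, 1.0]])
images = np.array([[0.9, 0.1], [0.2, 0.8], [0.6, 0.6]])
print(recall_at_k(queries, images, relevant=[0, 2], k=1))  # 0.5
print(recall_at_k(queries, images, relevant=[0, 2], k=2))  # 1.0
```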

Contributors

priyamtejaswin
