
Project of Team Celeste in Google AI ML Winter Camp Shanghai

14 - 18 January, 2019

Project Name: Hinted Quick Draw

Team Members: Minjun Li & Shaoxiang Chen, Fudan University

This project won the final Best Project Award. Many thanks to Google and TensorFlow🌞!

1. Introduction

Our project mainly explores the Quick Draw dataset and improves the Quick Draw game based on our findings. In the Quick Draw game, the player is asked to draw an object of a specified class with several strokes, so we first need a recognition model to classify sketches. However, our focus is not to train a highly accurate model to recognize the drawings, but to perform interesting analyses on top of a trained model with reasonable performance. We first explored training different models (including an RNN and a CNN) to recognize drawings in the Quick Draw dataset, and found that a hand-crafted CNN can be trained in a short time and achieve reasonable accuracy (~70% accuracy is enough to play the game).
With the trained CNN, we can perform various interesting analyses, such as: inter-class similarity analysis to find out which classes are easily confused with others (including t-SNE visualization and confusing-pair analysis), CNN class activation map visualization to interpret how the CNN makes its classification decisions, and definitive stroke analysis and visualization, which finds the specific strokes that push the CNN's prediction towards the desired class.
Finally, based on our analysis, we try to make the Quick Draw game more interesting by hinting the players of Quick Draw with (1) our CNN and (2) Sketch RNN. In (1), the player gets hints about whether their current stroke makes the drawing more like an object of the desired class. In (2), the player gets a direct hint from Sketch RNN about what the next stroke should be. Techniques from papers [1, 2] are used in our work.

Slides (demo videos inside!) and Poster are available in Google Drive.

2. About the Code

.
├── cluster                             # clustering analysis
│   ├── analysis.ipynb                      # notebook for inter-class similarity analysis
│   ├── class_id_map.pkl                    # file containing class label --> id mapping
│   ├── extract_feature.py                  # script to extract features of validation images from the trained CNN
│   ├── tsne_cls100.png                     # image of t-SNE visualization of CNN features
│   ├── tsne.ipynb                          # notebook for t-SNE visualization
│   └── tsne.png                            # image of t-SNE visualization of CNN features
├── common                              # common files and code used by other scripts
│   ├── class_id_map.pkl                    # file containing class label --> id mapping
│   ├── fixed_model.py                      # final CNN model for sketch recognition
│   ├── model.py                            # contains various CNN models we tried
│   ├── preprocessing                       # preprocessing styles for different networks
│   │   ├── cifarnet_preprocessing.py
│   │   ├── inception_preprocessing.py          # we only use the inception style
│   │   ├── __init__.py
│   │   ├── lenet_preprocessing.py
│   │   ├── preprocessing_factory.py
│   │   └── vgg_preprocessing.py
│   └── utils.py                            # utility functions
├── infer                               # things to do after having a trained CNN model
│   ├── bee.png                             # sample image containing a 'bee'
│   ├── best_stroke.ipynb                   # definitive stroke analysis
│   ├── ckpt                                # TensorFlow model checkpoints
│   │   ├── classifier                          # CNN classifier
│   │   └── flamingo                            # sketch RNN
│   ├── class_id_map.pkl
│   ├── feat_extraction.ipynb               # extracts features & original strokes for definitive stroke analysis
│   ├── infer.py                            # inference utility of the CNN, providing an image --> class API; supports local and network modes
│   ├── infer_rnn.py                        # inference utility of sketch RNN
│   ├── __init__.py
│   ├── sketch_ai_play.py                   # attempt to make sketch RNN play by itself
│   ├── sketch_cli.py                       # Quick Draw GUI that connects to an inference server for definitive stroke hints
│   ├── sketch_no_hint.py                   # Quick Draw GUI without any hint
│   └── sketch.py                           # Quick Draw GUI that gets hints from sketch RNN
├── legacy                              # legacy train & val code
│   ├── train.ipynb
│   └── val.ipynb
├── LICENSE
├── mkdata                              # make training data
│   ├── class_id_map.pkl
│   └── mkdata.ipynb                        # convert raw strokes into images and save as tfrecord
├── README.md
└── trainval                            # final train & validation scripts
    ├── cnn                                 # hand-crafted CNN
    │   ├── cnn_vis.ipynb                       # notebook for CNN class activation map visualization
    │   ├── gp_model.py                         # CNN model with global pooling for visualization
    │   ├── log
    │   ├── train.py                            # training script
    │   └── val.py                              # validation script
    └── rnn                                 # RNN from the official TensorFlow tutorial
        ├── create_dataset.py                   # save raw strokes to tfrecord
        └── train.py                            # training & validation script

3. Model Training

We trained an RNN and a CNN to recognize sketches. While the RNN can achieve higher accuracy, it takes a long time to train, so we hand-crafted a shallow CNN instead, which reaches reasonable performance in a short time. Our objective is not to train a highly accurate recognition model, but to explore and analyze the dataset with the goal of finding interesting insights that could help us (maybe) improve the Quick Draw game.

To train the CNN, we render the strokes as images and resize them to 128x128. The CNN is trained with a batch size of 512 and the Adam optimizer with learning rate 0.001 for 100,000 iterations. All of the following analysis is based on this trained CNN.
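The sketch below illustrates this preprocessing and the stated hyper-parameters. It is a minimal example, not the project's exact mkdata.ipynb code, and it assumes the Quick Draw "simplified drawing" format where each stroke is a pair of coordinate lists with values in [0, 255].

```python
# Minimal sketch: rasterize Quick Draw strokes into a 128x128 grayscale image.
# Assumes `strokes` is a list of [x_coords, y_coords] pairs (simplified drawing format).
import numpy as np
from PIL import Image, ImageDraw

def strokes_to_image(strokes, canvas_size=256, out_size=128, line_width=3):
    """Draw each stroke as a polyline on a white canvas, then resize."""
    img = Image.new('L', (canvas_size, canvas_size), color=255)
    draw = ImageDraw.Draw(img)
    for xs, ys in strokes:
        points = list(zip(xs, ys))
        if len(points) > 1:
            draw.line(points, fill=0, width=line_width)
        else:
            draw.point(points, fill=0)
    return np.array(img.resize((out_size, out_size), Image.BILINEAR))

# Hyper-parameters used for the CNN, as stated above:
BATCH_SIZE = 512
LEARNING_RATE = 1e-3      # Adam
NUM_ITERATIONS = 100000
```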

4. Inter-Class Similarity Analysis

IPython Notebook here.

To see whether the CNN feature from the 'FC 512' layer captures inter-class similarity, we first do a t-SNE visualization.
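A minimal sketch of this step, assuming the 512-d FC features of the validation sketches have already been dumped (e.g. by cluster/extract_feature.py) to arrays on disk; the file names here are illustrative:

```python
# t-SNE over the FC-512 features; `features` is (N, 512), `labels` holds integer class ids.
import numpy as np
import matplotlib.pyplot as plt
from sklearn.manifold import TSNE

features = np.load('features.npy')   # hypothetical file names
labels = np.load('labels.npy')

embedded = TSNE(n_components=2, perplexity=30, init='pca').fit_transform(features)
plt.figure(figsize=(10, 10))
plt.scatter(embedded[:, 0], embedded[:, 1], c=labels, s=2, cmap='tab20')
plt.savefig('tsne.png')
```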

Some classes form dense clusters in the visualization, but others scatter all over the space, which indicates they could be very similar to other classes. To find the confusing classes, we sort the 340 classes by their similarity to all other classes in descending order. The top ones are:

These classes all have very simple shapes that are close to a rectangle. See the full list here.

For each class, its most similar class can be found here, along with some samples here.
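One plausible way to compute such a ranking (a sketch under our own assumptions, not necessarily the notebook's exact metric) is to average the FC-512 features per class and rank classes by their mean cosine similarity to every other class mean:

```python
# Rank classes by how similar their mean CNN feature is to all other class means.
import numpy as np

def rank_confusing_classes(features, labels, num_classes=340):
    means = np.stack([features[labels == c].mean(axis=0) for c in range(num_classes)])
    means /= np.linalg.norm(means, axis=1, keepdims=True)
    sim = means @ means.T                      # pairwise cosine similarity
    np.fill_diagonal(sim, 0.0)                 # ignore self-similarity
    avg_sim = sim.mean(axis=1)                 # similarity to all other classes
    return np.argsort(-avg_sim)                # most "confusing" classes first
```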

5. CNN Class Activation Map Visualization for Interpretability

IPython Notebook here.

To understand why the CNN made a particular prediction, we use the technique from [1] to compute a class activation map that visualizes the contribution of each spatial region. Below are samples for 'cookie', 'hospital' and 'cell phone'.

It makes sense that the button and some edges contribute the most in 'cell phone' images.
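The map itself is the weighted sum of the last convolutional feature maps, weighted by the final fully-connected weights for the class of interest [1]. A minimal sketch, assuming the global-pooling model (gp_model.py) exposes those tensors; the names here are illustrative:

```python
# Class activation map (CAM) for one image, following [1].
import numpy as np

def class_activation_map(conv_features, fc_weights, class_id):
    """conv_features: (H, W, K) output of the last conv layer for one image.
    fc_weights: (K, num_classes) weights of the FC layer after global average pooling."""
    cam = np.tensordot(conv_features, fc_weights[:, class_id], axes=([2], [0]))  # (H, W)
    cam -= cam.min()
    cam /= cam.max() + 1e-8   # normalize to [0, 1] for visualization
    # upsample to the input resolution (e.g. with scipy.ndimage.zoom) before overlaying
    return cam
```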

6. Definitive Stroke Analysis and Visualization

IPython Notebook here.

We further want to analyze which stroke is the most effective at pushing the model's decision towards the target class. The approach is: for the images in each class, add strokes one by one and track how the probability of the target class changes. The definitive stroke is the one that gives the most significant probability increase. We visualize three pairs of the most similar classes and their corresponding definitive strokes.
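A minimal sketch of this search, where `classify` stands in for the CNN inference call (e.g. infer.py's image --> class API) and `strokes_to_image` for the rasterizer shown earlier; both names are illustrative:

```python
# Find the definitive stroke: the stroke whose addition most increases the target-class probability.
import numpy as np

def definitive_stroke(strokes, target_class, classify, strokes_to_image):
    probs = []
    for i in range(1, len(strokes) + 1):
        image = strokes_to_image(strokes[:i])          # drawing with the first i strokes
        probs.append(classify(image)[target_class])    # probability of the target class
    gains = np.diff([0.0] + probs)                     # per-stroke probability increase
    return int(np.argmax(gains)), probs                # index of the definitive stroke
```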

7. Demo

We wrote two demos in which the player gets hints from (1) our CNN, about whether a stroke is good or bad, and (2) Sketch RNN. In both demos, the player is asked to draw a flamingo.
To tell whether a stroke is good or bad, we track the probability of the desired class as the player draws. If the player's current stroke lowers that probability, the stroke is retracted. Sketch RNN [2] (architecture below) is a sequence-to-sequence model for sketch generation. We only use the decoder part, and at each step we feed the player's strokes into the decoder so that the model outputs the next strokes as hints to the player.
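A minimal sketch of the "good or bad stroke" logic from demo (1), assuming the same illustrative `classify` and `strokes_to_image` names as above:

```python
# After each finished stroke, keep it only if it raises the target-class probability.
def on_stroke_finished(strokes, new_stroke, target_class, classify, strokes_to_image, best_prob):
    candidate = strokes + [new_stroke]
    prob = classify(strokes_to_image(candidate))[target_class]
    if prob < best_prob:
        return strokes, best_prob      # retract: the stroke made the drawing less like the target
    return candidate, prob             # keep the stroke and update the best probability
```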

References

[1] Zhou, Bolei, Aditya Khosla, Agata Lapedriza, Aude Oliva, and Antonio Torralba. "Learning deep features for discriminative localization." CVPR 2016.

[2] Ha, David, and Douglas Eck. "A neural representation of sketch drawings." ICLR 2018.
