heatmaps

Important

This repository is unmaintained; use it at your own risk. It works only with multi-backend Keras. Better alternatives, such as gradient-based methods, give more accurate results.

This repository contains functions that transform a functional Keras model that classifies images into a heatmap generator.
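The general idea behind such a transformation (also used in the heuritech repository) is that the layers at the end of a classifier look at a fixed k x k window of features; sliding that same computation over a larger feature map yields one score per position, which is a heatmap. The sketch below is a minimal plain-Python illustration, assuming a single channel for brevity: applying a dense layer to every k x k window is a valid convolution whose kernel is the dense weights. The function names are hypothetical and this is not the library's implementation.

```python
def dense_score(window, weights, bias):
    """Score one k x k window with a hypothetical dense layer whose weights are k x k."""
    return sum(
        w * x
        for w_row, x_row in zip(weights, window)
        for w, x in zip(w_row, x_row)
    ) + bias


def sliding_dense(feature_map, weights, bias):
    """Apply the same dense layer at every k x k position (a 'valid' convolution)."""
    k = len(weights)
    height, width = len(feature_map), len(feature_map[0])
    heatmap = []
    for i in range(height - k + 1):
        row = []
        for j in range(width - k + 1):
            window = [r[j:j + k] for r in feature_map[i:i + k]]
            row.append(dense_score(window, weights, bias))
        heatmap.append(row)
    return heatmap


weights = [[1, 0], [0, 1]]
# An input of exactly the training size gives a single score...
print(sliding_dense([[1, 2], [3, 4]], weights, 0.5))  # [[5.5]]
# ...while a larger input gives a grid of scores, i.e. a heatmap.
print(sliding_dense([[1, 2, 3], [4, 5, 6], [7, 8, 9]], weights, 0))  # [[6, 8], [12, 14]]
```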

Some code and a lot of great ideas come from this repository: https://github.com/heuritech/convnets-keras

The resolution of the heatmaps is quite limited (but still better than in the heuritech repository if the model has a Flatten layer).

This code should work with Theano, TensorFlow, CNTK and with all data formats, but it has only been tested with TensorFlow.

Installation

Now installable with pip!

git clone https://github.com/gabrieldemarmiesse/heatmaps.git
cd heatmaps
pip install -e .

Example with VGG16

Here is a code sample that shows what is going on:

import matplotlib.pyplot as plt
import numpy as np
from keras.applications.imagenet_utils import preprocess_input
from keras.applications.vgg16 import VGG16
from keras.preprocessing import image
from keras import backend as K

from heatmap import to_heatmap, synset_to_dfs_ids


def display_heatmap(new_model, img_path, ids, preprocessing=None):
    # The image is loaded at a reduced resolution to limit memory use.
    # With more than 8 GB of RAM, you can try a larger target_size.
    img = image.load_img(img_path, target_size=(800, 1280))
    x = image.img_to_array(img)
    x = np.expand_dims(x, axis=0)
    if preprocessing is not None:
        # Use the preprocessing function passed in, so it matches the model.
        x = preprocessing(x)

    out = new_model.predict(x)

    heatmap = out[0]  # Removing batch axis.

    if K.image_data_format() == 'channels_first':
        heatmap = heatmap[ids]
        if heatmap.ndim == 3:
            heatmap = np.sum(heatmap, axis=0)
    else:
        heatmap = heatmap[:, :, ids]
        if heatmap.ndim == 3:
            heatmap = np.sum(heatmap, axis=2)

    plt.imshow(heatmap, interpolation="none")
    plt.show()


model = VGG16()
new_model = to_heatmap(model)

s = "n02084071"  # Imagenet code for "dog"
ids = synset_to_dfs_ids(s)
display_heatmap(new_model, "./dog.jpg", ids, preprocess_input)
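ImageNet has no single "dog" class. As its name suggests, synset_to_dfs_ids presumably collects the indices of all ImageNet classes found under the n02084071 synset, and display_heatmap sums their maps into one heatmap. A minimal pure-Python sketch of that reduction, assuming a channels_last output of shape (height, width, classes); the helper name sum_class_maps is hypothetical.

```python
def sum_class_maps(output, ids):
    """Sum the selected class channels of a channels_last output (H x W x C nested lists)."""
    ids = [ids] if isinstance(ids, int) else list(ids)  # accept one index or several
    return [[sum(pixel[c] for c in ids) for pixel in row] for row in output]


# A 1 x 2 spatial grid with 3 class scores per position.
output = [[[0.25, 0.5, 0.25], [0.5, 0.25, 0.25]]]
print(sum_class_maps(output, [1, 2]))  # [[0.75, 0.5]]
print(sum_class_maps(output, 0))       # [[0.25, 0.5]]
```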

The function to_heatmap also takes an optional second argument, input_shape.

Use it only if your classifier does not have a fixed width and height. In that case, pass the image size that was used during training, e.g. to_heatmap(model, input_shape=(3, 256, 256)) (here in channels_first order).
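The order of input_shape depends on the data format: (3, 256, 256) is channels_first, while a channels_last setup would use (256, 256, 3). The sketch below builds the tuple for either format; it assumes that to_heatmap expects the shape in the backend's data format, as Keras layers do, and the helper training_input_shape is hypothetical.

```python
def training_input_shape(height, width, channels=3, data_format="channels_last"):
    """Order (height, width, channels) the way Keras expects for the given data format."""
    if data_format == "channels_first":
        return (channels, height, width)
    if data_format == "channels_last":
        return (height, width, channels)
    raise ValueError("unknown data_format: %r" % data_format)


print(training_input_shape(256, 256, data_format="channels_first"))  # (3, 256, 256)
print(training_input_shape(256, 256))                                # (256, 256, 3)
```

In a Keras session you would pass K.image_data_format() as data_format and give the result to to_heatmap as input_shape.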

Example with ResNet50

You can also try this:

from keras.applications.resnet50 import ResNet50, preprocess_input
from heatmap import to_heatmap, synset_to_dfs_ids

# display_heatmap is the function defined in the VGG16 example above.
model = ResNet50()
new_model = to_heatmap(model)

s = "n02084071"  # Imagenet code for "dog"
ids = synset_to_dfs_ids(s)
display_heatmap(new_model, "./dog.jpg", ids, preprocess_input)

Example with your own model

It should also work with custom classifiers. Suppose your classifier has two classes: dog (first class) and not dog (second class). The following code then produces a heatmap for the dog class:

new_model = to_heatmap(my_custom_model)
idx = 0  # The index of the class you care about, here the first one ("dog").
# Pass your model's preprocessing function as a fourth argument if it needs one.
display_heatmap(new_model, "./dog.jpg", idx)

Note on the sizes of the heatmaps

Because of the architecture of common classification networks, the heatmap is smaller than the input image. The downsampling usually happens at max-pooling layers or at strided convolution layers.

Here is a table to get an idea of the size of the heatmap that you will obtain.

The size of the input image is assumed to be 1024x1024.

Network             Heatmap size for a 1024 x 1024 image
-----------------   ------------------------------------
VGG16               51 x 51
VGG19               51 x 51
ResNet50            26 x 26
InceptionV3         30 x 30
Xception            32 x 32
InceptionResNetV2   30 x 30
MobileNet           32 x 32
DenseNet121         32 x 32
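A rough way to anticipate these sizes: an unpadded layer with kernel k, stride s and dilation d maps a side of length n to floor((n - (d(k - 1) + 1)) / s) + 1. The sketch below is a simplified model, assuming the network first downsamples by a total stride of 32 and then slides a window x window classification window. It reproduces 26 for ResNet50 (7 x 7 average pooling) and 32 for networks ending in global pooling (window of 1), but it does not model the padding details of networks such as InceptionV3. The function names are hypothetical.

```python
def output_size(n, kernel, stride=1, dilation=1):
    """Side length after an unpadded ('valid') convolution or pooling layer."""
    effective = dilation * (kernel - 1) + 1  # span covered by a (dilated) kernel
    return (n - effective) // stride + 1


def heatmap_side(image_side, total_stride, window):
    """Simplified model: downsample by total_stride, then slide a window x window classifier."""
    feature_side = image_side // total_stride
    return output_size(feature_side, window)


print(heatmap_side(1024, 32, 7))  # ResNet50-like: 26
print(heatmap_side(1024, 32, 1))  # global pooling (e.g. Xception, MobileNet): 32
```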

VGG16 and VGG19 get a higher resolution because a trick can be applied before the Flatten layer: the convolutions are replaced with dilated convolutions.

This library performs this optimization out of the box, without you having to do anything.
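The following arithmetic reproduces the 51 x 51 figure for VGG16. It is an assumption about how the trick works, not a reading of the library's code: the last max pool keeps its 2 x 2 window but uses stride 1, and the 7 x 7 window that replaces the first fully connected layer is dilated by 2 so it still spans the same region. Without the trick, the five stride-2 pools reduce 1024 to 32 and the 7 x 7 window leaves 26. The function names are hypothetical, and output_size is redefined so the block stands alone.

```python
def output_size(n, kernel, stride=1, dilation=1):
    """Side length after an unpadded ('valid') convolution or pooling layer."""
    effective = dilation * (kernel - 1) + 1
    return (n - effective) // stride + 1


def vgg16_heatmap_side(image_side, dilated):
    side = image_side
    for _ in range(4):  # pool1..pool4: 2x2 max pooling, stride 2 ('same' convs keep the size)
        side = output_size(side, 2, stride=2)
    if dilated:
        side = output_size(side, 2, stride=1)       # assumed: pool5 kept with stride 1
        return output_size(side, 7, dilation=2)     # fc1 as a 7x7 conv with dilation 2
    side = output_size(side, 2, stride=2)           # regular pool5
    return output_size(side, 7)                     # fc1 as a plain 7x7 conv


print(vgg16_heatmap_side(1024, dilated=False))  # 26
print(vgg16_heatmap_side(1024, dilated=True))   # 51
```

Dilation doubles the spacing between kernel taps, which compensates for the stride that pool5 no longer applies, so the classifier sees the same image region while producing roughly twice as many positions per side.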
