
Image Captioning with CNN and RNN

TensorFlow/Keras implementation of an image captioning neural network using a CNN and an RNN.

Description

This is an unofficial implementation of the image captioning model proposed in the paper "Show and Tell: A Neural Image Caption Generator" [1].

This implementation is a faithful reproduction of the technique proposed in the paper. During training, each example is composed of an image and its caption (the inputs) together with the same caption shifted one position to the left (the ground truth), so at every timestep the model must predict the next word of the caption. This approach is faster to train and easier to understand than the alternative of feeding caption prefixes to the model and predicting a single word per prefix. The input caption is needed because teacher forcing is applied during training.
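As a minimal sketch (using illustrative token ids, not the project's actual vocabulary), building the decoder input and the shifted target from a tokenized caption looks like this:

```python
import tensorflow as tf

def make_training_pair(caption_ids):
    """Build (decoder input, target) from a tokenized caption.

    caption_ids is a 1-D int tensor such as [<start>, w1, ..., wn, <end>].
    The target is the same sequence shifted one position to the left, so at
    every timestep the model has to predict the next word of the caption.
    """
    decoder_input = caption_ids[:-1]   # [<start>, w1, ..., wn]
    target = caption_ids[1:]           # [w1, ..., wn, <end>]
    return decoder_input, target

# Toy example where 1 stands for <start> and 2 for <end>:
x, y = make_training_pair(tf.constant([1, 5, 9, 7, 2]))
# x -> [1, 5, 9, 7], y -> [5, 9, 7, 2]
```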

The following diagram shows the components of the model during training (the losses at each timestep are shown in red).

diagram of the model during training

The project is organized to make it easy to modify the Keras model, and it provides scripts for training, evaluating, and using the model. A project script downloads the Flickr8k dataset through the Kaggle API, and the dataset is then transformed with the TensorFlow Data API, which handles the split into the official train, validation, and test sets.

Getting Started

Install the Python requirements in your virtual environment.

The dataset used is Flickr8k, containing 8000 images, where each image is associated with five different captions.

From the project root you can run

python src/data/download_dataset.py

to download the dataset from Kaggle. You will need to set up your Kaggle username and API key locally (instructions). Check that the dataset has been downloaded correctly inside /data/raw.
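If you prefer not to create a kaggle.json file, the Kaggle API also reads the credentials from environment variables; a minimal sketch (the values are placeholders):

```python
import os

# The kaggle package falls back to these environment variables when no
# ~/.kaggle/kaggle.json file is present (placeholder values shown).
os.environ["KAGGLE_USERNAME"] = "your_kaggle_username"
os.environ["KAGGLE_KEY"] = "your_kaggle_api_key"
```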

You can then run

python src/data/make_dataset.py

to transform the captions into sequences using the custom spaCy tokenizer contained in this project and to store the train/val/test sets inside /data/processed. A custom spaCy tokenizer was preferred over the TensorFlow tokenizer because spaCy supports multiple languages and the behavior of this custom implementation can be adapted to different use cases.
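For reference, the core idea of tokenizing a caption with spaCy looks roughly like this (the project's tokenizer class may differ; this is only an illustration):

```python
import spacy

# A blank English pipeline is enough for rule-based tokenization.
nlp = spacy.blank("en")

def tokenize(caption):
    """Lowercase the caption and keep only alphabetic tokens."""
    return [token.text.lower() for token in nlp(caption) if token.is_alpha]

print(tokenize("A dog runs through the grass."))
# ['a', 'dog', 'runs', 'through', 'the', 'grass']
```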

The splits will be stored as TFRecords containing examples made of the images together with their five captions. Using TFRecords makes it easy to use TensorFlow's Data API to train and evaluate the model. Each TFRecord contains 200 examples. The ids of the images in each split are stored in the .txt files in /data/raw, already included in the repository; this should be the official split proposed by the creators of Flickr8k.
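As a rough sketch of how such records can be read back with the Data API (the feature names, shapes, and file pattern below are assumptions, not necessarily the project's actual schema):

```python
import tensorflow as tf

# Hypothetical schema: one encoded JPEG plus five serialized caption sequences.
feature_description = {
    "image": tf.io.FixedLenFeature([], tf.string),
    "captions": tf.io.FixedLenFeature([], tf.string),
}

def parse_example(serialized):
    example = tf.io.parse_single_example(serialized, feature_description)
    image = tf.io.decode_jpeg(example["image"], channels=3)
    image = tf.image.resize(image, (299, 299))  # fixed size so batching works
    captions = tf.io.parse_tensor(example["captions"], out_type=tf.int32)
    captions = tf.reshape(captions, (5, -1))    # five padded captions per image
    return image, captions

dataset = (
    tf.data.TFRecordDataset(tf.io.gfile.glob("data/processed/train*.tfrecord"))
    .map(parse_example, num_parallel_calls=tf.data.AUTOTUNE)
    .batch(32)
    .prefetch(tf.data.AUTOTUNE)
)
```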

You can modify the model in the source at src/models/model.py and change the training parameters in src/models/train_model.py. Training can then be started with

python src/models/train_model.py

and it will show the loss and accuracy (computed with teacher forcing) of the model during training, together with the same values on the validation dataset. The configuration and the weights of the trained model are saved inside /models/config and /models/weights (with the same filename, except for the extension). By saving the model configuration, you can experiment with different training and model options.
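The idea of saving the configuration and the weights under matching filenames can be sketched as follows (the stand-in model and filenames are illustrative only):

```python
import tensorflow as tf

# Illustrative stand-in model; the real one is defined in src/models/model.py.
model = tf.keras.Sequential([tf.keras.Input(shape=(4,)), tf.keras.layers.Dense(8)])

run_name = "example_run"  # hypothetical filename shared by config and weights
with open(f"models/config/{run_name}.json", "w") as f:
    f.write(model.to_json())                                  # model configuration
model.save_weights(f"models/weights/{run_name}.weights.h5")  # learned parameters
```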

The default training configuration is inspired by the one proposed in the paper, but different options can be provided on the command line when running the training script. You can check the options that configure the model and the training process via

python src/models/train_model.py --help

To make predictions on custom images, place your *.jpg files inside the /data/custom directory. Then run

python src/models/predict_model.py

to show the generated captions in the terminal. Use the --model_filename option to specify the filename (without extension) of the model you want to restore from the /models/config folder. If the option is not specified, the most recent model (the last in alphabetical order, which matches chronological order if the default names are kept) will be loaded.

Finally, evaluate the model with

python src/models/evaluate_model.py

to compute the BLEU-1 and BLEU-4 scores of the model on the Flickr8k test or validation set (specify the --model_filename and --mode options).
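Internally, BLEU-1 and BLEU-4 can be computed with NLTK, roughly as in the sketch below (the evaluation script may use a different implementation; the captions are made up):

```python
from nltk.translate.bleu_score import corpus_bleu

# One entry per image: a list of reference captions (token lists) and one hypothesis.
references = [[
    ["a", "dog", "is", "running", "on", "the", "grass"],
    ["a", "brown", "dog", "runs", "outside"],
]]
hypotheses = [["a", "dog", "is", "running", "on", "grass"]]

bleu1 = corpus_bleu(references, hypotheses, weights=(1.0, 0.0, 0.0, 0.0))
bleu4 = corpus_bleu(references, hypotheses, weights=(0.25, 0.25, 0.25, 0.25))
print(f"BLEU-1: {bleu1:.3f}  BLEU-4: {bleu4:.3f}")
```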

Sample Predictions

Sample predictions on the Flickr8k validation set after 15 epochs with the default model configuration, with a beam width of 3:

sample validation predictions

Notice that the validation loss was still dropping after 15 epochs, so model training could have been continued to achieve better results. Nevertheless, the model seems to be learning some useful patterns for describing the images.

Project Organization

├── LICENSE
├── Makefile           <- Makefile with commands like `make data` or `make train`
├── README.md          <- The top-level README for developers using this project.
├── data
│   ├── external       <- Data from third party sources.
│   ├── interim        <- Intermediate data that has been transformed.
│   ├── processed      <- The final, canonical data sets for modeling.
│   └── raw            <- The original, immutable data dump.
│
├── docs               <- A default Sphinx project; see sphinx-doc.org for details
│
├── models             <- Trained and serialized models, model predictions, or model summaries
│
├── notebooks          <- Jupyter notebooks. Naming convention is a number (for ordering),
│                         the creator's initials, and a short `-` delimited description, e.g.
│                         `1.0-jqp-initial-data-exploration`.
│
├── references         <- Data dictionaries, manuals, and all other explanatory materials.
│
├── reports            <- Generated analysis as HTML, PDF, LaTeX, etc.
│   └── figures        <- Generated graphics and figures to be used in reporting
│
├── requirements.txt   <- The requirements file for reproducing the analysis environment, e.g.
│                         generated with `pip freeze > requirements.txt`
│
├── setup.py           <- makes project pip installable (pip install -e .) so src can be imported
├── src                <- Source code for use in this project.
│   ├── __init__.py    <- Makes src a Python module
│   │
│   ├── data           <- Scripts to download or generate data
│   │   └── make_dataset.py
│   │
│   ├── features       <- Scripts to turn raw data into features for modeling
│   │   └── build_features.py
│   │
│   ├── models         <- Scripts to train models and then use trained models to make
│   │   │                 predictions
│   │   ├── predict_model.py
│   │   └── train_model.py
│   │
│   └── visualization  <- Scripts to create exploratory and results oriented visualizations
│       └── visualize.py
│
└── tox.ini            <- tox file with settings for running tox; see tox.readthedocs.io

Project based on the cookiecutter data science project template. #cookiecutterdatascience

Contributing

Contributions of any type (code, docs, suggestions...) are highly appreciated! Feel free to open an issue and ask questions using the question label.

References

[1] Vinyals, Oriol, et al. "Show and tell: A neural image caption generator." Proceedings of the IEEE conference on computer vision and pattern recognition. 2015.

[2] Tanti, Marc, Albert Gatt, and Kenneth P. Camilleri. "Where to put the image in an image caption generator." Natural Language Engineering 24.3 (2018): 467-489.


Issues

Long evaluation time

We should find a way to make evaluation (computation of the BLEU scores over different beam widths) faster: a recent run of a grid search over 3 beam widths took around 4 hours.

Load model weights

Provide pretrained model weights so that users can use the model directly, without training it first.

"Separate" the encoder from the entire model

Use the Subclassing API to divide the ShowAndTell model into two sub-models: the encoder and the decoder. This is needed because, at prediction time, the model is applied to the same image over and over again at every step of beam search; the image encoding never changes, so it can be computed just once at the beginning of the process.
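A rough sketch of the kind of split the Subclassing API allows (the layer choices and dimensions are illustrative, not the repository's exact architecture):

```python
import tensorflow as tf

class Encoder(tf.keras.Model):
    """Maps an image to a fixed-size encoding; can be run once per image."""

    def __init__(self, embedding_dim):
        super().__init__()
        self.cnn = tf.keras.applications.InceptionV3(include_top=False, pooling="avg")
        self.cnn.trainable = False
        self.fc = tf.keras.layers.Dense(embedding_dim, activation="relu")

    def call(self, images):
        return self.fc(self.cnn(images))

class Decoder(tf.keras.Model):
    """Generates caption tokens from a precomputed image encoding."""

    def __init__(self, vocab_size, embedding_dim, units):
        super().__init__()
        # Word embeddings must have the same size as the image encoding,
        # because the encoding is injected as the first "word" of the sequence.
        self.embedding = tf.keras.layers.Embedding(vocab_size, embedding_dim)
        self.lstm = tf.keras.layers.LSTM(units, return_sequences=True, return_state=True)
        self.fc = tf.keras.layers.Dense(vocab_size, activation="softmax")

    def call(self, image_encoding, caption_ids, states=None):
        x = self.embedding(caption_ids)                         # (batch, T, embed)
        x = tf.concat([image_encoding[:, None, :], x], axis=1)  # prepend the image
        outputs, h, c = self.lstm(x, initial_state=states)
        return self.fc(outputs), [h, c]
```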

Update the Makefile

This issue is a reminder to update the Makefile once a clear pipeline, with CLI arguments etc., is defined.

Implement model prediction

Implement the content of the predict_model.py file.

In this file, a stateful model should be loaded and used for prediction.

The current implementation requires that an image is given as input to the model together with a sequence of length MAX_CAPTION_LENGTH, where the first element of the sequence is 1 (corresponding to the <start> token) and all the others are zeros (for masking).

Then we take the model output at the first timestep, corresponding to the softmax probability distribution over the vocabulary, and sample one word from this distribution (remember that the 0-th neuron corresponds to the token <start>, which has index 1 in the vocabulary, and so on). Sampling can be performed:

  • by taking a random element according to the probability distribution
  • by taking the argmax of the probability distribution
  • with a beam search.

These three methods should all be implemented.

When we sample a word, we build the input for the next timestep: the same image (which will be encoded by the model again but not used; this should be adjusted) and a new caption sequence whose first element is the index of the sampled word, with all the other elements set to 0. Because the model is stateful, it keeps the state it had at the previous timestep and can continue predicting the sequence without problems. We sample until MAX_CAPTION_LENGTH is reached or until <end> is produced by the model.
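A compact sketch of the argmax variant of this loop, written against two hypothetical helpers (`init_states` and `decoder_step`) rather than the repository's actual API:

```python
import numpy as np

def greedy_decode(decoder_step, init_states, image_encoding, start_id, end_id, max_len):
    """Greedy (argmax) decoding loop.

    init_states(image_encoding) returns the initial recurrent state, and
    decoder_step(token_id, states) returns (probabilities over the vocabulary,
    new states). Both helpers are hypothetical, used only for illustration.
    """
    states = init_states(image_encoding)
    token, caption = start_id, []
    for _ in range(max_len):
        probs, states = decoder_step(token, states)
        token = int(np.argmax(probs))  # swap in np.random.choice to sample instead
        if token == end_id:
            break
        caption.append(token)
    return caption
```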

Improve (?) implementation of the BLEU metric

The BLEU metric is a numerical metric commonly used to evaluate image captioning models.

It needs to be implemented inside the src/models/metrics.py file (not sure whether that is the right place, by the way).
I think it should not be implemented as a tensor metric usable by TensorFlow, but as a metric applied directly to strings.

Basically, we will provide the ground-truth caption string and a string predicted by the model. How the predicted string is produced depends on the approach being tried (sampling, beam search, or maximum likelihood), but this is not relevant to the BLEU implementation.

Support mini-batches at prediction time

This issue is related to the performance we want to improve in #1.

The current implementation of beam search only works with a batch size equal to 1. Some slight changes can be made to support mini-batches with a size greater than 1.
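One possible direction, sketched below with made-up shapes and a hypothetical `decoder_step` callable: flatten the beams of all images into a single batch so the decoder is called once per step instead of once per candidate.

```python
import numpy as np

def batched_beam_step(decoder_step, encodings, beams):
    """Score every beam candidate of every image in one forward pass (sketch).

    encodings: (num_images, enc_dim) image encodings.
    beams:     (num_images, beam_width, seq_len) partial caption candidates.
    decoder_step is a hypothetical callable returning the next-word
    distribution, with shape (batch, vocab_size), for each candidate.
    """
    num_images, beam_width, seq_len = beams.shape
    flat_beams = beams.reshape(num_images * beam_width, seq_len)
    flat_encodings = np.repeat(encodings, beam_width, axis=0)  # align images with beams
    probs = decoder_step(flat_encodings, flat_beams)           # (num_images * beam_width, vocab)
    return probs.reshape(num_images, beam_width, -1)
```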
