Representation Learning with Contrastive Predictive Coding

This repository contains a Keras implementation of the algorithm presented in the paper Representation Learning with Contrastive Predictive Coding.

The goal of unsupervised representation learning is to capture semantic information about the world, recognizing patterns in the data without using annotations. This paper presents a new method called Contrastive Predictive Coding (CPC) that can do so across multiple applications. The main ideas of the paper are:

  • Contrastive: the model is trained with a contrastive objective, that is, it has to distinguish correct data sequences from incorrect ones.
  • Predictive: the model has to predict future patterns given the current context.
  • Coding: the model performs this prediction in a latent space, transforming code vectors into other code vectors (in contrast with predicting high-dimensional data directly).

CPC has to predict the next item in a sequence using only an embedded representation of the data, provided by an encoder. In order to solve the task, this encoder has to learn a meaningful representation of the data space. After training, this encoder can be used for other downstream tasks like supervised classification.
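As a rough illustration of the contrastive idea in latent space, the sketch below scores how well predicted code vectors agree with the code vectors of a candidate continuation and turns that score into a binary "right or wrong sequence" loss. It is a minimal sketch, not the repository's implementation: the function names, the dot-product score, and the binary cross-entropy formulation are assumptions made for illustration (the paper itself uses a softmax over one positive and several negative samples).

```python
import math


def dot(u, v):
    """Dot product of two equal-length vectors given as lists of floats."""
    return sum(a * b for a, b in zip(u, v))


def sigmoid(x):
    return 1.0 / (1.0 + math.exp(-x))


def contrastive_loss(predicted, actual, label):
    """Binary cross-entropy on the agreement between code vectors.

    predicted, actual: lists of P code vectors (hypothetical encoder/predictor outputs).
    label: 1 if `actual` is the true continuation, 0 if it is a wrong one.
    """
    # One agreement probability per future step, averaged over the P steps.
    probs = [sigmoid(dot(p, a)) for p, a in zip(predicted, actual)]
    prob = sum(probs) / len(probs)
    eps = 1e-7  # keeps log() away from zero
    prob = min(max(prob, eps), 1.0 - eps)
    return -(label * math.log(prob) + (1 - label) * math.log(1.0 - prob))


# Toy code vectors: the matching continuation aligns with the prediction.
predicted = [[1.0, 0.0], [0.0, 1.0]]
matching = [[1.0, 0.0], [0.0, 1.0]]
mismatched = [[-1.0, 0.0], [0.0, -1.0]]
loss_right = contrastive_loss(predicted, matching, label=1)
loss_wrong = contrastive_loss(predicted, mismatched, label=1)
```

With these toy vectors, labelling the matching continuation as correct gives a lower loss than labelling the mismatched one as correct, which is the signal that pushes the encoder towards useful representations.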

CPC algorithm

To train the CPC algorithm, I have created a toy dataset consisting of sequences of modified MNIST digits (64x64 RGB images). Positive samples contain digits in sorted order, whereas negative samples contain random digits. For example, assume that the context sequence length is S=4 and that CPC is asked to predict the next P=2 digits. A positive sample could look like [2, 3, 4, 5]->[6, 7], whereas a negative one could be [1, 2, 3, 4]->[0, 8]. Of course, CPC only sees the image patches, not the actual digit labels.
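The labelling scheme described above can be sketched as follows. This is only an illustration of how positive and negative digit sequences might be generated at the label level, not the repository's data generator: the function name make_sequence, the wrap-around from 9 back to 0, and the choice to randomize only the predicted part of negative samples are all assumptions. The real dataset would additionally map each digit to a 64x64 RGB image.

```python
import random


def make_sequence(context_len=4, predict_len=2, positive=True, rng=random):
    """Return (context, future) digit labels for one toy sample."""
    start = rng.randrange(10)
    # Consecutive digits; wrapping past 9 back to 0 is an assumption.
    run = [(start + i) % 10 for i in range(context_len + predict_len)]
    context, future = run[:context_len], run[context_len:]
    if not positive:
        # Replace the future with random digits, resampling in the
        # unlikely case that they equal the true continuation.
        true_future = future
        while future == true_future:
            future = [rng.randrange(10) for _ in range(predict_len)]
    return context, future


rng = random.Random(0)  # seeded for reproducibility
pos_context, pos_future = make_sequence(positive=True, rng=rng)
neg_context, neg_future = make_sequence(positive=False, rng=rng)
```

Here context_len and predict_len play the roles of S and P from the text.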

Disclaimer: this code is provided as is. If you encounter a bug, please report it as an issue. Your help is very welcome!

Results

After 10 training epochs, CPC reports 99% accuracy on the contrastive task. After training, I froze the encoder and trained an MLP on top of it to perform supervised digit classification on the same MNIST data. The MLP achieved 90% accuracy after 10 epochs, demonstrating the effectiveness of CPC for unsupervised feature extraction.

Usage

  • Execute python train_model.py to train the CPC model.
  • Execute python benchmark_model.py to train the MLP on top of the CPC encoder.

Requisites

References
