Noise Contrastive Estimation (NCE)

Introduction

This is a PyTorch implementation of Noise Contrastive Estimation (NCE) on 2D datasets.

NCE is a method for estimating energy-based models (EBMs)

$$p_\theta(x) = \frac{\exp[-f_\theta(x)]}{Z(\theta)}$$

where

$$Z(\theta) = \int\exp[-f_\theta(x)]dx$$

is the normalizing constant, which is generally intractable to compute. In NCE, the log normalizing constant is treated as a trainable parameter $c=\log Z$. With $Z$ a free parameter, we cannot directly do maximum likelihood estimation (MLE) by solving $\displaystyle\max_\theta p_\theta(x)$, because $p_\theta(x)$ can be made arbitrarily large by letting $Z\to0$ (or $c\to -\infty$). Instead, NCE trains the energy-based model by (nonlinear) logistic regression, that is, binary classification between samples from the data distribution $p_{\mathrm{data}}$ and samples from some noise distribution $q$.
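As an illustration of the parameterization $\log p_\theta(x) = -f_\theta(x) - c$, here is a minimal sketch of an energy network with a learnable scalar $c$. The class name `EnergyModel`, the layer sizes, and the activation are assumptions made for this example and are not taken from the repository.

```python
import torch
import torch.nn as nn


class EnergyModel(nn.Module):
    """Unnormalized 2D density with a learnable log-normalizer c (hypothetical names)."""

    def __init__(self, dim: int = 2, hidden: int = 64):
        super().__init__()
        # f_theta: maps each point x to a scalar energy.
        self.energy = nn.Sequential(
            nn.Linear(dim, hidden),
            nn.SiLU(),
            nn.Linear(hidden, hidden),
            nn.SiLU(),
            nn.Linear(hidden, 1),
        )
        # c plays the role of log Z and is trained jointly with theta.
        self.c = nn.Parameter(torch.zeros(1))

    def log_prob(self, x: torch.Tensor) -> torch.Tensor:
        # log p_theta(x) = -f_theta(x) - c, returned with shape (batch,)
        return -self.energy(x).squeeze(-1) - self.c


model = EnergyModel()
x = torch.randn(5, 2)
log_p = model.log_prob(x)  # one log-density value per input point
```

Because `c` is an ordinary `nn.Parameter`, any optimizer that receives `model.parameters()` updates it together with the network weights.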

There are three requirements for the noise distribution $q$:

  (1) log density can be evaluated on any input;

  (2) samples can be obtained from the distribution;

  (3) $q(x)\neq0$ for all $x$ such that $p_{\mathrm{data}}(x)\neq0$.

Here we use a multivariate Gaussian as the noise distribution.
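A Gaussian satisfies all three requirements, which can be checked with a short sketch using `torch.distributions.MultivariateNormal`. The zero mean and the covariance $4I$ are assumed values chosen for illustration; the parameters used in the repository may differ.

```python
import torch
from torch.distributions import MultivariateNormal

# Assumed parameters: zero mean and a broad isotropic covariance in 2D.
noise = MultivariateNormal(loc=torch.zeros(2), covariance_matrix=4.0 * torch.eye(2))

# Requirement (2): samples can be drawn, here with shape (1000, 2).
x_noise = noise.sample((1000,))

# Requirement (1): the log density can be evaluated on any input, shape (1000,).
log_q = noise.log_prob(x_noise)

# Requirement (3): a Gaussian density is strictly positive everywhere,
# so log q stays finite even far away from the mean.
log_q_far = noise.log_prob(torch.tensor([50.0, -50.0]))
```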

The objective is to maximize the posterior log-likelihood of the classification

$$V(\theta) = \mathbb{E}_{x\sim p_{\text{data}}}\log\frac{p_\theta(x)}{p_\theta(x)+q(x)} + \mathbb{E}_{\tilde{x}\sim q}\log\frac{q(\tilde{x})}{p_\theta(\tilde{x}) + q(\tilde{x})}.$$

This objective is implemented in the file util.py as the value function (we minimize $-V(\theta)$). We use Adam as the optimizer.
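Since $\frac{p_\theta}{p_\theta+q} = \sigma(\log p_\theta - \log q)$ and $\frac{q}{p_\theta+q} = \sigma(\log q - \log p_\theta)$, where $\sigma$ is the logistic sigmoid, the objective can be estimated from log-densities alone. The following is a minimal sketch of one way to write it, not the code in util.py: the function name `nce_value`, the small network, the learning rate, and the placeholder data batch are all assumptions.

```python
import torch
import torch.nn.functional as F


def nce_value(log_p_data, log_q_data, log_p_noise, log_q_noise):
    """Monte Carlo estimate of V(theta); every argument has shape (batch,)."""
    # log[p / (p + q)] on data samples equals log sigmoid(log p - log q).
    data_term = F.logsigmoid(log_p_data - log_q_data).mean()
    # log[q / (p + q)] on noise samples equals log sigmoid(log q - log p).
    noise_term = F.logsigmoid(log_q_noise - log_p_noise).mean()
    return data_term + noise_term


torch.manual_seed(0)
# Hypothetical stand-ins for the energy network f_theta and the parameter c.
energy = torch.nn.Sequential(
    torch.nn.Linear(2, 32), torch.nn.Tanh(), torch.nn.Linear(32, 1)
)
c = torch.nn.Parameter(torch.zeros(1))
noise = torch.distributions.MultivariateNormal(torch.zeros(2), 4.0 * torch.eye(2))
optimizer = torch.optim.Adam(list(energy.parameters()) + [c], lr=1e-3)


def log_p(x):
    # log p_theta(x) = -f_theta(x) - c
    return -energy(x).squeeze(-1) - c


x_data = torch.randn(256, 2)     # placeholder for a batch drawn from p_data
x_noise = noise.sample((256,))   # batch drawn from q

# One optimization step: minimize -V(theta) with Adam.
loss = -nce_value(
    log_p(x_data), noise.log_prob(x_data),
    log_p(x_noise), noise.log_prob(x_noise),
)
optimizer.zero_grad()
loss.backward()
optimizer.step()
```

Working with `logsigmoid` of log-density differences avoids exponentiating $p_\theta$ and $q$ directly, which tends to be more stable numerically than forming the ratios.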

Installation

Clone the repository to your local machine with

git clone https://github.com/lifeitech/nce.git

In your Python environment, cd into the repository and run

pip install -r requirements.txt

Training

To train the model, run

python train.py

For macOS users: PyTorch currently has only limited support for the mps backend, so make sure to run the script with PYTORCH_ENABLE_MPS_FALLBACK=1. You can add

export PYTORCH_ENABLE_MPS_FALLBACK=1 

to your .zshrc file.
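If you prefer not to edit your shell configuration, the variable can also be set from Python. This is a sketch under two assumptions: that the variable is set before torch is imported, which is the commonly recommended order, and that the device is chosen as shown. Whether train.py selects its device this way is not stated in this README.

```python
import os

# Set the fallback flag before importing torch (assumed to be required).
os.environ.setdefault("PYTORCH_ENABLE_MPS_FALLBACK", "1")

import torch

# Pick the best available device; the ordering here is an illustrative choice.
if torch.backends.mps.is_available():
    device = torch.device("mps")
elif torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
```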

Available datasets:

  • 8gaussians (default)
  • 2spirals
  • checkerboard
  • rings
  • pinwheel

A density plot is saved in the folder images after every epoch. After training, you can create GIF animations like the ones below by running the Python script in that folder:

cd images
python create_gif.py
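A density plot of this kind amounts to evaluating $\exp[-f_\theta(x) - c]$ on a regular 2D grid. The sketch below shows that evaluation step only and is not the repository's plotting code: the helper name `density_on_grid`, the plot range, and the resolution are assumptions, and a standard normal stands in for the trained model.

```python
import torch


def density_on_grid(log_prob_fn, lim=4.0, n=100):
    """Evaluate exp(log_prob_fn) on an n-by-n grid over [-lim, lim]^2."""
    xs = torch.linspace(-lim, lim, n)
    gx, gy = torch.meshgrid(xs, xs, indexing="ij")
    points = torch.stack([gx.reshape(-1), gy.reshape(-1)], dim=-1)  # (n*n, 2)
    with torch.no_grad():
        density = log_prob_fn(points).exp().reshape(n, n)
    return xs, density


# Stand-in for a trained model: a standard normal log-density.
standard_normal = torch.distributions.MultivariateNormal(torch.zeros(2), torch.eye(2))
xs, density = density_on_grid(standard_normal.log_prob)

# Riemann-sum estimate of the total probability mass covered by the grid.
cell_area = (xs[1] - xs[0]) ** 2
total_mass = (density.sum() * cell_area).item()
```

Applied to a trained model, `total_mass` gives a rough check of how close the learned $c$ is to $\log Z$: a value near 1 suggests the model is approximately normalized over the plotted region.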

Examples

Some visualizations of the learned energy densities are listed below.

  • 8gaussians dataset

[Figure: 8gaussians]

  • pinwheel dataset

[Figure: pinwheel]

  • 2spirals dataset

[Figure: 2spirals]
