Regularized RBM for Feature Selection and Embedding

Introduction

Directly learning the statistical dependencies between all observed variables in an RBM brings noisy information from irrelevant variables into the model. We therefore introduce an ℓ1 regularizer that specifically mitigates the impact of these noisy variables: we impose an ℓ1 penalty on the activation probability. In this penalty, t is a very small constant, which penalizes reconstructed observed variables that are sensitive to large values. The penalty provides a natural way to select the most important features (which correspond to observed variables in the RBM). A further advantage is that its gradient is easy to compute.
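As a minimal sketch (not the repository's implementation), assume Bernoulli observed variables with p(v_j = 1 | h) = σ(b_j + Σ_k W_jk h_k) and assume the penalty takes the form lam · Σ_j |p(v_j = 1 | h) − t|. Under those assumptions, the penalty and its subgradient with respect to the weights W and visible biases b_v have closed forms. The function name `l1_activation_penalty` is hypothetical and not part of the RegRBM API.

```python
import numpy as np


def sigmoid(z):
    """Logistic function."""
    return 1.0 / (1.0 + np.exp(-z))


def l1_activation_penalty(h, W, b_v, t, lam):
    """Assumed penalty lam * sum_j |p(v_j = 1 | h) - t|, averaged over a batch.

    h:   (batch, n_hidden) hidden states or probabilities
    W:   (n_visible, n_hidden) weight matrix
    b_v: (n_visible,) visible biases
    Returns the penalty value and its subgradients w.r.t. W and b_v.
    """
    p = sigmoid(h @ W.T + b_v)  # (batch, n_visible) activation probabilities
    penalty = lam * np.abs(p - t).sum(axis=1).mean()
    # d|p - t|/dp = sign(p - t); the sigmoid contributes the factor p * (1 - p).
    delta = lam * np.sign(p - t) * p * (1.0 - p)  # (batch, n_visible)
    grad_W = delta.T @ h / h.shape[0]  # (n_visible, n_hidden)
    grad_b_v = delta.mean(axis=0)  # (n_visible,)
    return penalty, grad_W, grad_b_v
```

Because p lies in (0, 1) and t is very small, sign(p − t) is almost always +1 under this assumed form, so the penalty mostly pushes the reconstruction probabilities of uninformative variables down toward t.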

Thus, given a single training example x, we solve the following optimization problem, which gives a new formulation that performs selection of the observed variables of the RBM:

[Equation: regularized log-likelihood objective]
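For orientation only, one plausible way to write such an objective, assuming Bernoulli observed variables, parameters θ = (W, b, c), and a penalty of the form λ Σ_j |p(x_j = 1 | h) − t| (an assumption, not the exact expression from the paper), is:

```latex
\min_{\theta} \; -\log p(x;\theta) \;+\; \lambda \sum_{j=1}^{n_v} \bigl| p(x_j = 1 \mid h;\theta) - t \bigr|,
\qquad
p(x_j = 1 \mid h;\theta) = \sigma\Bigl(b_j + \sum_{k} W_{jk} h_k\Bigr)
```

Here λ corresponds to the `lam` parameter and t to the `t` parameter of RegRBM.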

We solve this optimization problem by gradient descent. The problem is non-convex, and gradient descent is the default approach for it. With the penalty term, the gradients can be rewritten as follows:

[Equation: gradients of the regularized objective]
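RBM log-likelihood gradients are commonly approximated with contrastive divergence (CD). The sketch below assumes a Bernoulli-Bernoulli RBM, a single Gibbs step (CD-1), and the same assumed penalty λ Σ_j |p(v_j = 1 | h) − t| evaluated at the one-step reconstruction. It shows how the penalty gradient would be subtracted from the CD estimate in one update. The function `cd1_step_with_penalty` is hypothetical, and the momentum term that RegRBM uses is omitted.

```python
import numpy as np


def sigmoid(z):
    """Logistic function."""
    return 1.0 / (1.0 + np.exp(-z))


def cd1_step_with_penalty(v0, W, b_v, b_h, lr, t, lam, rng):
    """One CD-1 update with an assumed l1 penalty on visible activation probabilities.

    v0: (n, n_visible) binary data batch; W: (n_visible, n_hidden).
    Returns updated copies of W, b_v and b_h.
    """
    n = v0.shape[0]
    # Positive phase: hidden probabilities and a binary sample given the data.
    ph0 = sigmoid(v0 @ W + b_h)
    h0 = (rng.random(ph0.shape) < ph0).astype(float)
    # Negative phase: one Gibbs step back to the visible layer and up again.
    pv1 = sigmoid(h0 @ W.T + b_v)  # reconstruction probabilities
    ph1 = sigmoid(pv1 @ W + b_h)
    # CD-1 estimate of the log-likelihood gradient (ascent direction).
    dW = (v0.T @ ph0 - pv1.T @ ph1) / n
    db_v = (v0 - pv1).mean(axis=0)
    db_h = (ph0 - ph1).mean(axis=0)
    # Subgradient of the assumed penalty, evaluated at the reconstruction pv1.
    delta = lam * np.sign(pv1 - t) * pv1 * (1.0 - pv1)
    pW = delta.T @ h0 / n
    pb_v = delta.mean(axis=0)
    # Ascend the likelihood and descend the penalty; hidden biases are unpenalized.
    return W + lr * (dW - pW), b_v + lr * (db_v - pb_v), b_h + lr * db_h
```

Only the weights and visible biases receive a penalty term here, since the assumed penalty depends on the visible activation probabilities alone.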

How to use it

This work is based on tensorfow-rbm, and RegRBM keeps the same API as tensorfow-rbm.

Below is a simple example of how to train a RegRBM:

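```python
# n_x is the number of observed (visible) variables in data_x;
# t, lam, lr and n_epoches are user-chosen hyperparameters.
rbm = RegRBM(n_visible=n_x, n_hidden=1000, t=t, lam=lam,
             learning_rate=lr, momentum=0.95, err_function="mse",
             sample_visible=False, sigma=1.)
# errs holds the training errors recorded during fitting.
errs, zeros = rbm.fit(data_x, n_epoches=n_epoches, batch_size=20,
                      shuffle=True, verbose=True)
```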

where the additional parameters are t, the constant that controls the threshold for disabling variables, and lam, the regularization weight.
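To illustrate how a threshold like t could be used to read off disabled and selected variables, here is a minimal sketch. It assumes (hypothetically) that a variable counts as disabled when its reconstructed activation probability never exceeds t on any sample; the helpers `count_disabled` and `selected_features` are illustrative and not part of RegRBM.

```python
import numpy as np


def count_disabled(pv, t):
    """Count observed variables whose reconstructed probability never exceeds t.

    pv: (n_samples, n_visible) reconstructed activation probabilities.
    """
    return int(np.sum(pv.max(axis=0) <= t))


def selected_features(pv, t):
    """Indices of observed variables that exceed t on at least one sample."""
    return np.flatnonzero(pv.max(axis=0) > t)
```

Tracking `count_disabled` after each epoch would give a curve comparable in spirit to the "number of eliminated variables over iterations" reported below, under the stated assumption about the criterion.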

Experiments

RBMs fitted with and without the proposed penalty term on 2,056 crime texts (7,038 keywords in total), under the same experimental settings: (a) training error over iterations; (b) number of eliminated (disabled) variables over iterations; (c) cross-validation results for different values of λ.

[Figure: panels (a) to (c)]
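Panel (c) compares regularization weights by cross-validation. A generic K-fold selection of λ can be sketched as follows; `fit_and_score` is a hypothetical user-supplied callable (for example, one that trains a RegRBM with `lam` on the training fold and returns its reconstruction error on the validation fold), and the fold scheme is an assumption rather than the authors' protocol.

```python
import numpy as np


def select_lambda(data, lams, fit_and_score, k=5, seed=0):
    """Return the lambda with the lowest mean validation error, plus all means.

    fit_and_score(train, val, lam) -> float is a hypothetical callable that fits
    a model with regularization weight lam on train and returns its error on val.
    """
    rng = np.random.default_rng(seed)
    folds = np.array_split(rng.permutation(len(data)), k)
    mean_errors = {}
    for lam in lams:
        errs = []
        for i in range(k):
            val_idx = folds[i]
            train_idx = np.concatenate([folds[j] for j in range(k) if j != i])
            errs.append(fit_and_score(data[train_idx], data[val_idx], lam))
        mean_errors[lam] = float(np.mean(errs))
    best = min(mean_errors, key=mean_errors.get)
    return best, mean_errors
```

Using the same folds for every λ keeps the comparison between regularization weights fair.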

Selected features: (a) the standard deviation of the tf-idf intensity of each keyword over the 2,056 crime texts; (b) the same plot as (a), but with the tf-idf intensity reconstructed by a regularized RBM fitted on the raw data. The 15 keywords with the highest standard deviations are annotated beside their bars. The x-axis shows the 7,038 keywords, and the y-axis shows the standard deviation of each keyword.

[Figure: panels (a) and (b)]
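The ranking used for the annotations can be sketched as below: compute the per-keyword standard deviation of a documents-by-keywords tf-idf matrix (raw or RBM-reconstructed) and keep the k largest. The function name `top_keywords_by_std` and the `vocab` argument are hypothetical; NumPy's default population standard deviation (ddof=0) is assumed.

```python
import numpy as np


def top_keywords_by_std(X, vocab, k=15):
    """Return (keyword, std) pairs for the k keywords with the largest std.

    X:     (n_documents, n_keywords) tf-idf matrix
    vocab: sequence of n_keywords keyword strings
    """
    stds = X.std(axis=0)  # per-keyword standard deviation across documents
    order = np.argsort(stds)[::-1][:k]  # indices of the k largest, descending
    return [(vocab[i], float(stds[i])) for i in order]
```

Applying the same function to the raw matrix and to the RBM reconstruction yields the two annotated rankings for panels (a) and (b).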
