Autoencoder

This project aims to build several kinds of autoencoders. Currently only the vanilla autoencoder (AE) and the variational autoencoder (VAE) are implemented; RAE and VRAE will be added in the future.

Example

The examples are based on the TensorFlow example tutorials.

AE

The training script saves the model in the Model/AE/ folder. Make sure the folder exists, otherwise the files will fail to save.
During evaluation, the pretrained model is restored for further testing.

# train
python train_AE.py
# test
python eval_AE.py
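Because training fails to save when Model/AE/ is missing, a small helper like the one below could create the output folders before running train_AE.py. This is a minimal sketch, not part of the repository: it assumes the folders are resolved relative to the directory you run the scripts from, and it also creates the Log/<name>/ folder used for the graph.

```python
import os


def ensure_output_dirs(model_name, root="."):
    """Create Model/<model_name>/ and Log/<model_name>/ under `root` if missing.

    Hypothetical helper. exist_ok=True makes it safe to call repeatedly.
    Returns the (model_dir, log_dir) paths.
    """
    model_dir = os.path.join(root, "Model", model_name)
    log_dir = os.path.join(root, "Log", model_name)
    for path in (model_dir, log_dir):
        os.makedirs(path, exist_ok=True)
    return model_dir, log_dir
```

Calling `ensure_output_dirs("AE")` from the repository root before training prepares the AE folders, and `ensure_output_dirs("VAE")` does the same for the VAE example.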

The model graph is also saved in the Log/AE/ folder before training. Use the following command to visualize it:

# launch tensorboard
tensorboard --logdir=Log/AE

The model configuration is defined in config_AE.py.

VAE

As with AE, the model is saved in Model/VAE/, the graph in Log/VAE/, and the model configuration is defined in config_VAE.py.

# train
python train_VAE.py
# test
python eval_VAE.py
# launch tensorboard
tensorboard --logdir=Log/VAE

Unlike the AE example, which restores the whole model for evaluation, this example restores only the decoder.
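Restoring only the decoder amounts to restoring the subset of variables whose names fall under the decoder's scope. The sketch below shows that filtering on plain variable-name strings. The scope name "decoder" and the example variable names are hypothetical; the real scope names come from the "name" fields in config_VAE.py. In TensorFlow 1.x, the matching variables would typically be passed as `var_list` to `tf.train.Saver` before calling `restore`.

```python
def variables_in_scope(var_names, scope):
    """Return the names that live under `scope`, e.g. "decoder/...".

    Matching on "scope/" (with the slash) keeps a sibling scope such as
    "decoder_2" from being picked up by accident.
    """
    prefix = scope.rstrip("/") + "/"
    return [name for name in var_names if name.startswith(prefix)]


# Hypothetical variable names in TensorFlow's "scope/layer/param:0" style.
all_vars = [
    "encoder/fc_1/weights:0",
    "encoder/fc_1/biases:0",
    "sampler/mean/weights:0",
    "decoder/fc_1/weights:0",
    "decoder/fc_1/biases:0",
    "decoder_2/fc_1/weights:0",
]
decoder_vars = variables_in_scope(all_vars, "decoder")
# decoder_vars == ["decoder/fc_1/weights:0", "decoder/fc_1/biases:0"]
```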

Usage

Model Config

Use a config to specify the model architecture.
The config must define "model" and "loss".
The optional setting "random_init" controls how variables are initialized.

model

"model" is a list of blocks. Each block has its own "name", which becomes the scope name of the block.

block

Blocks can be nested by defining "blocks", which is also a list of blocks.
Each block also has "layers", a list of layers, which could be FC, RNN, CNN, and so on. (Currently only FC is available.)
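Because a block's "name" becomes its scope name and blocks can nest, every layer ends up under a path of scope names. A minimal sketch of that naming follows, assuming scopes are joined with "/" the way TensorFlow variable scopes are; the exact naming and traversal order in model.py may differ, and the block and layer names are hypothetical.

```python
def scoped_layer_names(blocks, parent=""):
    """Walk (possibly nested) blocks and yield "scope/.../layer" names.

    In this sketch nested "blocks" are visited before the block's own
    "layers"; the actual order used by model.py is not specified here.
    """
    for block in blocks:
        scope = parent + block["name"] + "/"
        yield from scoped_layer_names(block.get("blocks", []), scope)
        for layer in block.get("layers", []):
            yield scope + layer["name"]


model = [
    {
        "name": "encoder",
        "blocks": [
            {"name": "inner", "layers": [{"type": "FC", "name": "fc_1"}]},
        ],
        "layers": [{"type": "FC", "name": "fc_2"}],
    },
    {"name": "decoder", "layers": [{"type": "FC", "name": "fc_3"}]},
]
print(list(scoped_layer_names(model)))
# ['encoder/inner/fc_1', 'encoder/fc_2', 'decoder/fc_3']
```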

layer

Each layer must have its own unique "name", and what "type" it is.
Different types of layers have different attributes. Check model.py or the example config files for more information.

import tensorflow as tf

config = {
    # optional; custom_random_init is user-defined (see model.py for the expected form)
    "random_init": custom_random_init,
    "model": [
        # list of blocks
        {
            "name": "block name",
            "blocks": [
                # list of nested blocks
            ],
            "layers": [
                # list of layers
                {
                    "type": "FC",
                    "name": "layer_name",
                    "input": "some input layer",
                    "output_size": 100,
                    "activation": tf.nn.tanh
                },
                # ... more layers
            ]
        },
        # ... more blocks
    ],
    "loss": [
        {
            "name": "loss name",
            "weight": 1,
            "ground_truth": "some label layer",
            "prediction": "some output layer",
            # optional; custom_loss_func is user-defined, MSE is used if omitted
            "loss_func": custom_loss_func
        }
    ]
}
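Before building a Model, it can help to check a config against the rules stated above: "model" and "loss" are required, and every layer needs a "type" and a unique "name". The checker below is a hypothetical helper, not part of the repository. It checks only this structure; type-specific attributes such as "output_size" are left to model.py.

```python
def validate_config(config):
    """Return a list of problems found in `config` (empty if none).

    Hypothetical helper: checks only the rules stated in this README.
    """
    problems = [f'missing required key "{k}"' for k in ("model", "loss") if k not in config]
    seen = set()

    def walk(blocks):
        for block in blocks:
            if "name" not in block:
                problems.append("block without a name")
            walk(block.get("blocks", []))
            for layer in block.get("layers", []):
                name = layer.get("name")
                if name is None:
                    problems.append("layer without a name")
                elif name in seen:
                    problems.append(f'duplicate layer name "{name}"')
                else:
                    seen.add(name)
                if "type" not in layer:
                    problems.append(f'layer "{name}" has no type')

    walk(config.get("model", []))
    return problems


config = {
    "model": [
        {"name": "encoder", "layers": [{"type": "FC", "name": "hidden"}]},
        {"name": "decoder", "layers": [{"name": "hidden"}]},
    ],
}
print(validate_config(config))
# ['missing required key "loss"', 'duplicate layer name "hidden"', 'layer "hidden" has no type']
```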

loss

"loss" defines the optimisation objective.
By default, the loss function is the MSE between "ground_truth" and "prediction". If "loss_func" is defined, "ground_truth" and "prediction" are passed to that custom loss function instead.
"weight" sets the weight of this loss relative to the other losses.
For the variational autoencoder, the sampler's loss (the KL divergence) is added to the total loss automatically.
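The rule above can be sketched in plain Python: each term is the MSE between "ground_truth" and "prediction" unless a "loss_func" is given, terms are multiplied by their "weight" and summed, and for a VAE the sampler's KL divergence is added. The KL term uses the standard closed form for a diagonal Gaussian N(mu, exp(log_var)) against a standard normal prior. That the sampler is parameterised by mean and log-variance is an assumption, as are all function names here.

```python
import math


def mse(ground_truth, prediction):
    """Mean squared error between two equal-length sequences."""
    return sum((g - p) ** 2 for g, p in zip(ground_truth, prediction)) / len(ground_truth)


def gaussian_kl(mu, log_var):
    """KL(N(mu, exp(log_var)) || N(0, 1)), summed over latent dimensions."""
    return -0.5 * sum(1 + lv - m ** 2 - math.exp(lv) for m, lv in zip(mu, log_var))


def total_loss(loss_specs, values, kl=None):
    """Weighted sum of loss terms, plus an optional KL term (VAE only).

    `loss_specs` mirrors the config's "loss" list; `values` maps the layer
    names used in "ground_truth"/"prediction" to their outputs.
    """
    total = 0.0
    for spec in loss_specs:
        loss_func = spec.get("loss_func", mse)  # default: MSE
        term = loss_func(values[spec["ground_truth"]], values[spec["prediction"]])
        total += spec.get("weight", 1) * term  # weight assumed to default to 1
    if kl is not None:
        total += kl
    return total


values = {"x": [1.0, 2.0], "x_hat": [1.0, 4.0]}
specs = [{"name": "recon", "weight": 0.5, "ground_truth": "x", "prediction": "x_hat"}]
print(total_loss(specs, values))  # 0.5 * ((0 + 4) / 2) = 1.0
```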

Training and Testing

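import tensorflow as tf
from autoencoder.model import Model

# build model
model = Model(config)
# specify training parameters (this sets up the optimizer; it does not train yet)
learning_rate = 0.001  # example value
model.train(learning_rate)

sess = tf.Session()
# init tensorflow variables
sess.run(model.init)

# tensors must be looked up by name in the graph;
# the names below are placeholders for the layer names in your config
graph = tf.get_default_graph()
model_input = graph.get_tensor_by_name("input_node_name:0")
model_output = graph.get_tensor_by_name("output_node_name:0")
model_encoder = graph.get_tensor_by_name("encoder_output_node_name:0")
model_decoder_input = graph.get_tensor_by_name("decoder_input_node_name:0")

# train (one optimization step on a batch `data`)
sess.run(model.optimizer, feed_dict={model_input: data})

# test full model
sess.run(model_output, feed_dict={model_input: data})
# test encoder
sess.run(model_encoder, feed_dict={model_input: data})
# test decoder
sess.run(model_output, feed_dict={model_decoder_input: data})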
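The train step runs the optimizer once on `data`. In practice it is usually repeated over shuffled mini-batches for several epochs; a minimal, framework-independent batching sketch follows. The function name, batch size and epoch count are illustrative assumptions, not settings taken from train_AE.py or train_VAE.py.

```python
import random


def minibatches(data, batch_size, epochs, seed=0):
    """Yield shuffled mini-batches of `data` for the given number of epochs.

    Works for any indexable sequence. The last batch of an epoch may be
    smaller than batch_size.
    """
    rng = random.Random(seed)
    indices = list(range(len(data)))
    for _ in range(epochs):
        rng.shuffle(indices)
        for start in range(0, len(indices), batch_size):
            yield [data[i] for i in indices[start:start + batch_size]]


# Hypothetical use with the session set up above:
# for batch in minibatches(train_data, batch_size=64, epochs=10):
#     sess.run(model.optimizer, feed_dict={model_input: batch})
```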

Version

TensorFlow 1.7.1, TensorBoard 1.7.0

Reference

Tensorflow Example

Contributors

y4cj4sul3