
sugartensor's Introduction

Sugar Tensor - A slim tensorflow wrapper that provides syntactic sugar for tensor variables

Sugar Tensor aims to help deep learning researchers/practitioners. It adds some syntactic sugar functions to tensorflow to avoid tedious repetitive tasks. Sugar Tensor was developed under the following principles:

Principles

  1. Don't mess up tensorflow. We provide no wrapper classes. Instead, we use the tensor itself, so developers can program as freely as before with tensorflow.
  2. Don't mess up the Python style. We believe Python source code should look pretty and simple. Practical deep learning code is very different from the code of a complex GUI program. Do we really need inheritance and/or encapsulation in our deep learning code? Instead, we seek simplicity and readability. To that end, we use pure Python functions only and avoid class-style conventions.

Installation

  1. Requirements

    1. tensorflow >= rc0.10
    2. tqdm >= 4.8.4 ( for a console progress bar )
  2. Installation

pip install sugartensor

Quick start

Imports

import sugartensor as tf   # no need of 'import tensorflow'

Features

Sugar functions

All tensors (variables, operations, and constants) automatically have sugar functions, whose names start with 'sg_' to avoid namespace chaos. :-)
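One way to picture how prefixed functions can end up on every tensor is to attach plain Python functions to a class as methods. The sketch below is only an illustration under stated assumptions: `Tensor`, `sg_scale`, `sg_shift`, and `inject_sugar` are hypothetical names invented for this example, and nothing here is taken from sugartensor's own implementation.

```python
class Tensor(object):
    """Toy stand-in for a tensor type (hypothetical, NOT tensorflow's Tensor)."""

    def __init__(self, value):
        self.value = value


def sg_scale(tensor, factor=1):
    """Hypothetical sugar function: multiply the wrapped value by `factor`."""
    return Tensor(tensor.value * factor)


def sg_shift(tensor, offset=0):
    """Hypothetical sugar function: add `offset` to the wrapped value."""
    return Tensor(tensor.value + offset)


def inject_sugar(cls, funcs, prefix='sg_'):
    """Attach plain functions to `cls` as methods, enforcing the name prefix."""
    for func in funcs:
        if not func.__name__.startswith(prefix):
            raise ValueError('sugar function names must start with %r' % prefix)
        # A function stored on a class becomes a method, so the instance
        # is passed as the first argument when it is called.
        setattr(cls, func.__name__, func)


inject_sugar(Tensor, [sg_scale, sg_shift])

# Every call returns a Tensor again, so the calls can be chained.
result = Tensor(3).sg_scale(factor=10).sg_shift(offset=1)
print(result.value)
```

The prefix check is the point of the design: any attribute starting with `sg_` is known to be an added function, so it cannot silently shadow an attribute the original class already defines.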

Chainable object syntax

Inspired by the prettytensor library, we support chainable object syntax for all sugar functions. This should improve productivity and readability. Look at the following snippet.


logit = (tf.placeholder(tf.float32, shape=(BATCH_SIZE, DATA_SIZE))
         .sg_dense(dim=400, act='relu', bn=True)
         .sg_dense(dim=200, act='relu', bn=True)
         .sg_dense(dim=10))

All returned objects are tensors.

In the above snippet, all values returned by sugar functions are plain tensorflow tensors (variables or constants). So, the following example is completely legal.


ph = tf.placeholder(tf.float32, shape=(BATCH_SIZE, DATA_SIZE))   # <-- this is a tensor
ph = ph.sg_dense(dim=400, act='relu', bn=True)   # <-- this is a tensor
ph = ph * 100 + 10  # <-- this is ok.
ph = tf.reshape(ph, (-1, 20, 20, 1)).sg_conv(dim=30)   # <-- any tensorflow function can be applied and chained.

Practical DRY (Don't repeat yourself) functions for deep learning researchers

We provide powerful predefined training and report functions for practical developers. The following code is a full MNIST training module with saver, report, and early stopping support.


# -*- coding: utf-8 -*-
import sugartensor as tf

# MNIST input tensor ( with QueueRunner )
data = tf.sg_data.Mnist()

# inputs
x = data.train.image
y = data.train.label

# create training graph
logit = (x.sg_flatten()
         .sg_dense(dim=400, act='relu', bn=True)
         .sg_dense(dim=200, act='relu', bn=True)
         .sg_dense(dim=10))

# cross entropy loss with logit ( for training set )
loss = logit.sg_ce(target=y)

# accuracy evaluation ( for validation set )
acc = (logit.sg_reuse(input=data.valid.image).sg_softmax()
       .sg_accuracy(target=data.valid.label, name='val'))

# train
tf.sg_train(loss=loss, eval_metric=[acc], ep_size=data.train.num_batch)
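For readers who want to see what the loss and the accuracy metric in this module measure, here is a minimal pure-Python sketch of softmax cross entropy and classification accuracy. It assumes integer class labels and uses the textbook definitions; the helper names `softmax`, `cross_entropy`, and `accuracy` are illustrative only and are not sugartensor functions, and the library's own numerics may differ.

```python
import math


def softmax(logits):
    """Turn a list of raw scores into probabilities that sum to 1."""
    m = max(logits)  # subtract the max for numerical stability
    exps = [math.exp(v - m) for v in logits]
    total = sum(exps)
    return [e / total for e in exps]


def cross_entropy(logits, target):
    """Negative log-probability assigned to the true class index `target`."""
    return -math.log(softmax(logits)[target])


def accuracy(batch_logits, targets):
    """Fraction of rows whose highest-scoring class equals the label."""
    hits = 0
    for logits, target in zip(batch_logits, targets):
        predicted = max(range(len(logits)), key=logits.__getitem__)
        hits += int(predicted == target)
    return hits / float(len(targets))


# Toy batch: two samples, two classes.
batch = [[2.0, 1.0], [0.0, 3.0]]
labels = [0, 0]
print(cross_entropy(batch[0], labels[0]))
print(accuracy(batch, labels))
```

Softmax does not change which class scores highest, so accuracy can be computed from either the logits or the probabilities.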

You can check all statistics through TensorBoard's web interface.

If you want to write a more complex training module without repeating the saver, report, and other boilerplate, you can do so as follows.


import numpy as np

# alternate training function
# (loss_disc, train_disc, loss_gen and train_gen are assumed to be defined elsewhere, e.g. in a GAN graph)
@tf.sg_train_func   # <-- sugar decorator that wraps a training function
def alt_train(sess, opt):
    l_disc = sess.run([loss_disc, train_disc])[0]  # training discriminator
    l_gen = sess.run([loss_gen, train_gen])[0]  # training generator
    return np.mean(l_disc) + np.mean(l_gen)

# do training
alt_train(log_interval=10, ep_size=data.train.num_batch, early_stop=False, save_dir='asset/train/gan')
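The general idea of wrapping a per-step function in a reusable training loop can be sketched with an ordinary Python decorator. This is a simplified illustration, not sugartensor's implementation: `train_func`, `toy_step`, the `max_ep` and `epoch` option names, and the default values are all assumptions made for this example, and the session argument is replaced by `None`.

```python
import functools


def train_func(step):
    """Hypothetical decorator: run `step(sess, opt)` in an epoch loop."""

    @functools.wraps(step)
    def wrapper(**kwargs):
        # Assumed defaults; the caller's keyword arguments override them.
        opt = dict(max_ep=2, ep_size=3, log_interval=1)
        opt.update(kwargs)
        history = []
        for ep in range(1, opt['max_ep'] + 1):
            opt['epoch'] = ep
            # One call per batch; a real loop would pass a live session.
            losses = [step(None, opt) for _ in range(opt['ep_size'])]
            mean_loss = sum(losses) / float(len(losses))
            history.append(mean_loss)
            if ep % opt['log_interval'] == 0:
                print('epoch %d: mean loss %.4f' % (ep, mean_loss))
        return history

    return wrapper


@train_func
def toy_step(sess, opt):
    """Toy step whose 'loss' simply shrinks as the epoch number grows."""
    return 1.0 / opt['epoch']


history = toy_step(max_ep=3, ep_size=4)
```

The decorated function only describes one step, while the loop, logging, and option handling live in one place and are shared by every training function that uses the decorator.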

Please see the example code in the 'sugartensor/example/' directory.

Author

Namju Kim ([email protected]) at Jamonglabs Co., Ltd.

