
tfsnippet's Introduction

TFSnippet


TFSnippet is a set of utilities for writing and testing TensorFlow models.

The design philosophy of TFSnippet is to be non-interfering: it aims to provide a set of useful utilities that can be used alongside any other TensorFlow libraries and frameworks.

Dependencies

TensorFlow >= 1.5

Installation

pip install git+https://github.com/haowen-xu/tfsnippet.git

Documentation

Examples

Quick Tutorial

To get started, you might import TFSnippet as:

import tfsnippet as spt

Distributions

If you use TFSnippet distribution classes to obtain random samples, you get enhanced tensor objects, from which you can compute the log-likelihood simply by calling log_prob().

normal = spt.Normal(0., 1.)
# The type of `samples` is :class:`tfsnippet.stochastic.StochasticTensor`.
samples = normal.sample(n_samples=100)
# You may obtain the log-likelihood of `samples` under `normal` by:
log_prob = samples.log_prob()
# You may also obtain the distribution instance back from the samples,
# such that you may fire-and-forget the distribution instance!
distribution = samples.distribution
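
Because `log_prob` is an ordinary TensorFlow tensor, it can be fed into further ops and evaluated in a session. Below is a minimal TF 1.x sketch; the tf.reduce_mean reduction is only for illustration and is not part of the API above:

import tensorflow as tf

# Average log-likelihood over the 100 samples.
avg_log_prob = tf.reduce_mean(log_prob)

with tf.Session() as sess:
    print('average log p(x):', sess.run(avg_log_prob))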

In case we haven't provided a wrapper for a certain ZhuSuan distribution, you can cast the ZhuSuan distribution into a TFSnippet distribution class:

import zhusuan as zs

uniform = spt.as_distribution(zs.distributions.Uniform())
# The type of `samples` is :class:`tfsnippet.stochastic.StochasticTensor`.
samples = uniform.sample(n_samples=100)
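
Samples drawn from the cast distribution carry the same enhanced interface as before; for example, the log-likelihood is again available directly:

# `samples` is a StochasticTensor, so `log_prob()` works just as it
# does for the native TFSnippet distributions above.
log_prob = samples.log_prob()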

Data Flows

It is common practice to iterate through a dataset in mini-batches. tfsnippet.DataFlow provides a unified interface for assembling mini-batch iterators.

# Obtain a shuffled, two-array data flow, with batch-size 64.
# Any batch with fewer than 64 samples will be discarded.
flow = spt.DataFlow.arrays(
    [x, y], batch_size=64, shuffle=True, skip_incomplete=True)
for batch_x, batch_y in flow:
    ...  # Do something with batch_x and batch_y

# You may use a threaded data flow to prefetch the mini-batches
# in a background thread.  The threaded flow is a context object,
# where exiting the context would destroy the background thread.
with flow.threaded(prefetch=5) as threaded_flow:
    for batch_x, batch_y in threaded_flow:
        ...  # Do something with batch_x and batch_y

# If you use `MLSnippet <https://github.com/haowen-xu/mlsnippet>`_,
# you can even load data from a MongoDB via data flow.  Suppose you
# have stored all images from ImageNet into a GridFS (of MongoDB),
# along with the labels stored as ``metadata.y``.
# You may iterate through the ImageNet in batches by:
from mlsnippet.datafs import MongoFS

fs = MongoFS('mongodb://localhost', 'imagenet', 'train')
with fs.as_flow(batch_size=64, with_names=False, meta_keys=['y'],
                shuffle=True, skip_incomplete=True) as flow:
    for batch_x, batch_y in flow:
        ...  # Do something with batch_x and batch_y.  batch_x is the
             # raw content of images you stored into the GridFS.
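
For a concrete feel of the batching behaviour, here is a self-contained sketch with toy NumPy arrays (the shapes and sizes are made up for illustration):

import numpy as np
import tfsnippet as spt

x = np.random.randn(1000, 3)  # 1000 samples with 3 features each
y = np.arange(1000)           # 1000 integer labels

flow = spt.DataFlow.arrays(
    [x, y], batch_size=64, shuffle=True, skip_incomplete=True)

# 1000 // 64 = 15 full batches; the remaining 40 samples are dropped
# because `skip_incomplete=True`.
n_batches = sum(1 for _ in flow)
print(n_batches)  # => 15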

Training

After you've built the model and obtained the training operation, you can quickly run a training loop using utilities from TFSnippet:

input_x = ...  # the input x placeholder
input_y = ...  # the input y placeholder
loss = ...  # the training loss
params = tf.trainable_variables()  # the trainable parameters

# We adopt learning-rate annealing: the initial learning rate is 0.001,
# and it is annealed by a factor of 0.99995 after every step.
learning_rate = spt.AnnealingVariable('learning_rate', 0.001, 0.99995)

# Build the training operation by AdamOptimizer
optimizer = tf.train.AdamOptimizer(learning_rate)
train_op = optimizer.minimize(loss, var_list=params)

# Build the training data-flow
train_flow = spt.DataFlow.arrays(
    [train_x, train_y], batch_size=64, shuffle=True, skip_incomplete=True)
# Build the validation data-flow
valid_flow = spt.DataFlow.arrays([valid_x, valid_y], batch_size=256)

with spt.TrainLoop(params, max_epoch=max_epoch, early_stopping=True) as loop:
    trainer = spt.Trainer(loop, train_op, [input_x, input_y], train_flow,
                          metrics={'loss': loss})
    # Anneal the learning-rate after every step by 0.99995.
    trainer.anneal_after_steps(learning_rate, freq=1)
    # Do validation and apply early-stopping after every epoch.
    trainer.evaluate_after_epochs(
        spt.Evaluator(loop, loss, [input_x, input_y], valid_flow),
        freq=1
    )
    # You may log the learning-rate after every epoch by registering an
    # event handler.  Surely you may also add any other handlers.
    trainer.events.on(
        spt.EventKeys.AFTER_EPOCH,
        lambda epoch: trainer.loop.collect_metrics(lr=learning_rate),
    )
    # Print training metrics after every epoch.
    trainer.log_after_epochs(freq=1)
    # Run all the training epochs and steps.
    trainer.run()
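
After training, you might apply the model to a held-out set by iterating a test data flow yourself. The following is a minimal sketch, assuming `session` is the tf.Session used during training, `output` is your model's prediction tensor, and `test_x` is a NumPy array; these names are placeholders for your own code, not part of the TFSnippet API:

import numpy as np

test_flow = spt.DataFlow.arrays([test_x], batch_size=256)

predictions = []
for (batch_x,) in test_flow:
    # Feed each mini-batch through the model and collect the outputs.
    predictions.append(session.run(output, feed_dict={input_x: batch_x}))
predictions = np.concatenate(predictions, axis=0)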

tfsnippet's People

Contributors

haowen-xu, korepwx, tsinghuasuya


tfsnippet's Issues

Request: Access to v0.2.0-alpha1

Thank you for this repo.

This repo's requirements.txt makes use of the v0.2.0-alpha1 version of your tfsnippet repo. Is it possible to put this version back up, please?

Thanks in advance.

Installation problem

During installation, the following errors are reported:

ERROR: Could not find a version that satisfies the requirement frozendict>=1.2.0 (from TFSnippet==0.1.2) (from versions: none)
ERROR: No matching distribution found for frozendict>=1.2.0 (from TFSnippet==0.1.2)

Looking forward to a solution.

ModuleNotFoundError: No module named 'tfsnippet.modules'

File "D:\Google_\DeepADoTS-master\DeepADoTS-master\src\algorithms\donut.py", line 7, in
from donut import DonutTrainer, DonutPredictor, Donut as DonutModel, complete_timestamp, standardize_kpi
File "D:\Useful_Program\Python\lib\site-packages\donut_init_.py", line 4, in
from .model import *
File "D:\Useful_Program\Python\lib\site-packages\donut\model.py", line 6, in
from tfsnippet.modules import VAE, Lambda, Module
ModuleNotFoundError: No module named 'tfsnippet.modules'

I can't find "modules" in the tfsnippet package.
(My TF version is 1.15.)
