Prototrain

A TensorFlow-based training framework for running computer vision experiments.

Background

This project provides a simple framework to experiment with and develop machine learning models.

More specifically, at Verizon Media we have used this code to perform research on the topic of "visual search" and "metric learning" within the e-commerce domain. With this repository, you can train image retrieval models to find relevant products from a catalog of thousands of products given a single query image.

We presented this work in a poster session at the BayLearn Symposium on Oct 11, 2018. To learn more about our work, please read our technical report "Learning embeddings for Product Visual Search with Triplet Loss and Online Sampling".

The training code in this repository aims to provide a flexible framework for experimenting with low-level neural network details. While high-level APIs such as tf.estimator and tf.keras provide a quick and simple way of training standard convolutional neural networks, they abstract away many details of the training loop that could be interesting to examine within a research context.

The goal of this code base is to give researchers finer-grained control over these details while still providing "just enough" infrastructure to track, configure, and reproduce experiments in a collaborative setting.

A few benefits of using this code for training:

  • Simple Model definition API - All models can be created by defining a model.config dictionary object and the following two functions: model.datasets() and model.build(inputs, train_phase).
  • Debug mode - Use the --debug flag when calling train.py to jump into an IPython session immediately after model setup has completed.
  • Resumable models - You can interrupt and restart the training job without worry. The trainer will pick up from the last saved checkpoint. If you want to start fresh, use the --fresh flag to delete the existing checkpoint and start from scratch.
  • Checkpoint on exit - If your job exits for any reason (including if you manually interrupt training), the model will be saved to a checkpoint before exiting.
  • Config serialization - Every training run saves the config hyperparameter dictionary in both a human-readable text format and a machine-readable format for later reference.
  • Transparent/Customizable training loop - The training loop is deliberately simple and easy to follow. To customize it, copy trainers/default.py into a new file and edit the copy.
  • Hyperparameter tuning - Because config is nothing more than a Python dictionary that can be manipulated at runtime, automated hyperparameter tuning is possible.
  • Multi-GPU Training - Train on multiple GPUs at a time.
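
As an illustration of the hyperparameter tuning point above, the following is a minimal sketch of how a grid of config variants could be generated from a base dictionary. The flat dotted-key layout, the make_variants helper, and the "exp-xx-sweep" names are assumptions made for this example, not part of the repository.

```python
import copy
import itertools

# Hypothetical base config. The keys mirror settings documented in this
# README, but the flat dotted-key layout is an assumption.
base_config = {
    "model.batch_size": 64,
    "model.lr_schedule.method": "constant",
    "model.lr_schedule.constant.value": 0.01,
    "trainer.experiment_name": "exp-xx-sweep-0",
}


def make_variants(config, grid):
    """Return one independent copy of `config` per combination in `grid`."""
    keys = sorted(grid)
    variants = []
    combos = itertools.product(*(grid[k] for k in keys))
    for run_number, combo in enumerate(combos):
        variant = copy.deepcopy(config)
        variant.update(zip(keys, combo))
        # Give every run its own experiment name so outputs do not collide.
        variant["trainer.experiment_name"] = "exp-xx-sweep-%d" % run_number
        variants.append(variant)
    return variants


grid = {
    "model.batch_size": [32, 64],
    "model.lr_schedule.constant.value": [0.1, 0.01, 0.001],
}
variants = make_variants(base_config, grid)
for v in variants:
    print("%s batch_size=%d lr=%g" % (
        v["trainer.experiment_name"],
        v["model.batch_size"],
        v["model.lr_schedule.constant.value"],
    ))
```

Each variant is a deep copy, so the base dictionary stays untouched and every run can be launched and tracked independently.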

Install

This code currently runs on Python 2.7.11, so make sure you have this version installed. We have plans to migrate to Python 3 in the future.

Before using this library, you will need to install a few prerequisites via pip:

pip install numpy tensorflow tensorflow-hub

Then, simply clone this repo.

git clone https://github.com/yahoo/prototrain && cd prototrain

Usage

You can start a new experiment by running the following commands from the command line.

export CUDA_VISIBLE_DEVICES=1,2
python train.py -m models.simple

This will train the model in the models/simple folder.

Configuration

The following are configuration flags shared across all models.

Trainer settings

| dtype | setting | description |
| --- | --- | --- |
| str | trainer.experiment_base_dir | The directory path to store all experiments |
| str | trainer.experiment_name | The experiment name. This, combined with experiment_base_dir, will be where all experiment logs, checkpoints, and summaries will be saved. The suggested format for this string is "exp-<initials>-<descriptive_name>-<run_number>" like the following example: "exp-hn-evergreen-pascal-voc80" |
| int | trainer.validation_step_interval | Determines interval of steps between feeding a batch of validation data through session |
| int | trainer.summary_save_interval | Determines interval of steps between saving tensorboard summaries |
| int | trainer.print_interval | Determines interval of steps between printing loss |
| int | trainer.checkpoint_save_interval | Determines interval of steps between saving checkpoints |
| int | trainer.max_checkpoints_to_keep | The number of checkpoints to save at any given time |
| str | trainer.random_seed | A random seed to initialize tensorflow and numpy with |
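
To make the trainer settings concrete, here is a small sketch of a config fragment and of how the experiment directory and the step intervals might be interpreted. The values are illustrative, the flat dotted-key layout is an assumption, and experiment_dir and is_due are hypothetical helpers rather than functions from this repository.

```python
import os

# Illustrative values only. The flat dotted-key layout is an assumption.
config = {
    "trainer.experiment_base_dir": "/tmp/experiments",
    "trainer.experiment_name": "exp-hn-evergreen-pascal-voc80",
    "trainer.validation_step_interval": 100,
    "trainer.summary_save_interval": 50,
    "trainer.print_interval": 10,
    "trainer.checkpoint_save_interval": 1000,
    "trainer.max_checkpoints_to_keep": 5,
    "trainer.random_seed": "1234",  # listed as str in the settings table
}


def experiment_dir(cfg):
    """Directory where logs, checkpoints, and summaries would be saved."""
    return os.path.join(
        cfg["trainer.experiment_base_dir"],
        cfg["trainer.experiment_name"],
    )


def is_due(step, interval):
    """True when `step` lands on a multiple of `interval`.

    Whether the real trainer also fires at step 0 is not documented,
    so this sketch skips step 0.
    """
    return step > 0 and step % interval == 0


print(experiment_dir(config))
print(is_due(1000, config["trainer.checkpoint_save_interval"]))
```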

Model general settings

| dtype | setting | description |
| --- | --- | --- |
| str | model.weights_load_path | A path to a pretrained checkpoint that you'd like to load for fine tuning |
| int | model.batch_size | The full batch size to use for training |
| int | model.num_steps | The number of steps to run training |
| func | model.trainable_variables.filter_fn | A function of the form fn(v) -> Bool that returns True if variable v should be updated via gradient descent during training |
| str | model.trainable_variables.filter_regex | A regex that returns a match if a variable should be updated during training |
| func | model.weight_decay.filter_fn | A function of the form fn(v) -> Bool that returns True if variable v should be considered when computing the weight decay loss |
| str | model.weight_decay.filter_regex | A regex that returns a match if a variable should be considered when computing the weight decay loss |
| bool | model.clip_gradients | Turn on gradient clipping during training (default: False) |
| float | model.clip_gradients.minval | Minimum gradient value (default: -1.0) |
| float | model.clip_gradients.maxval | Maximum gradient value (default: 1.0) |
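
The filter_fn and filter_regex settings both select a subset of variables. The sketch below shows the idea on stand-in objects instead of real TensorFlow variables. The Var type, the variable names, and the two select helpers are hypothetical, and the use of re.search (rather than re.match) is an assumption about how the regex is applied.

```python
import re
from collections import namedtuple

# Stand-in for a TensorFlow variable; only the `name` attribute is used.
Var = namedtuple("Var", ["name"])

variables = [
    Var("backbone/conv1/kernel:0"),
    Var("backbone/conv1/bias:0"),
    Var("embedding/dense/kernel:0"),
    Var("embedding/dense/bias:0"),
]


def select_by_regex(variables, pattern):
    """Keep variables whose name matches `pattern` (re.search assumed)."""
    regex = re.compile(pattern)
    return [v for v in variables if regex.search(v.name)]


def select_by_fn(variables, filter_fn):
    """Keep variables for which filter_fn(v) returns True."""
    return [v for v in variables if filter_fn(v)]


# Fine-tune only the embedding head.
trainable = select_by_regex(variables, r"^embedding/")

# Apply weight decay to kernels but not to biases.
decayed = select_by_fn(variables, lambda v: "kernel" in v.name)

print([v.name for v in trainable])
print([v.name for v in decayed])
```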

Model optimizer settings

| dtype | setting | description |
| --- | --- | --- |
| str | model.optimizer.method | One of the following optimizers: msgd (Momentum SGD), adam, rmsprop, sgd. Based upon this value, additional settings apply below. |

Momentum SGD settings:

| dtype | setting | description |
| --- | --- | --- |
| float | model.optimizer.msgd.momentum | Sets momentum when using msgd optimizer |
| bool | model.optimizer.msgd.use_nesterov | Uses nesterov accelerated gradients if True |
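
For readers unfamiliar with the two msgd settings, the following sketch shows one momentum update on a scalar, following the formulation used by TensorFlow's momentum optimizer. That the msgd option wraps that optimizer is an assumption, and momentum_step is a hypothetical helper written for this example.

```python
def momentum_step(var, accum, grad, lr, momentum, use_nesterov=False):
    """Apply one momentum SGD update and return the new (var, accum)."""
    # The accumulator is a decaying sum of past gradients.
    accum = momentum * accum + grad
    if use_nesterov:
        # Nesterov looks ahead along the updated accumulator.
        var = var - lr * (grad + momentum * accum)
    else:
        var = var - lr * accum
    return var, accum


plain = momentum_step(1.0, 0.0, 2.0, lr=0.1, momentum=0.9)
nesterov = momentum_step(1.0, 0.0, 2.0, lr=0.1, momentum=0.9,
                         use_nesterov=True)
print(plain)
print(nesterov)
```

With the same gradient, the Nesterov variant takes a larger first step because it adds the look-ahead term momentum * accum to the raw gradient.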

ADAM settings:

| dtype | setting | description |
| --- | --- | --- |
| float | model.optimizer.adam.beta1 | ADAM beta1 |
| float | model.optimizer.adam.beta2 | ADAM beta2 |
| float | model.optimizer.adam.epsilon | ADAM epsilon |

RMSProp settings:

| dtype | setting | description |
| --- | --- | --- |
| float | model.optimizer.rmsprop.decay | RMSProp decay rate |
| float | model.optimizer.rmsprop.momentum | RMSProp momentum |
| float | model.optimizer.rmsprop.epsilon | RMSProp epsilon |

Model learning rate schedule

| dtype | setting | description |
| --- | --- | --- |
| str | model.lr_schedule.method | One of the following [step, exp, constant, custom]. Based upon this value, additional settings apply below. |

Piecewise learning rate schedule:

| dtype | setting | description |
| --- | --- | --- |
| list(int) | model.lr_schedule.step.boundaries | A list of step boundary values that determine when to set the learning rate to the next learning rate value. |
| list(float) | model.lr_schedule.step.values | List of floating point values for the learning rate. This list should have len(boundaries) + 1 entries, where the first entry is the starting learning rate. |
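
The relationship between boundaries and values can be sketched in plain Python as follows. step_lr is a hypothetical helper, and the boundary handling (a step equal to a boundary still uses the earlier value) is an assumption modeled on TensorFlow's piecewise-constant schedule.

```python
import bisect


def step_lr(step, boundaries, values):
    """Return the learning rate for `step` under a piecewise schedule."""
    if len(values) != len(boundaries) + 1:
        raise ValueError("values must have len(boundaries) + 1 entries")
    # bisect_left counts the boundaries strictly below `step`, which is
    # the index of the value that applies at this step.
    return values[bisect.bisect_left(boundaries, step)]


boundaries = [1000, 2000]
values = [0.1, 0.01, 0.001]
for step in (0, 1000, 1001, 2000, 2001):
    print("%d %g" % (step, step_lr(step, boundaries, values)))
```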

Constant learning rate schedule:

| dtype | setting | description |
| --- | --- | --- |
| float | model.lr_schedule.constant.value | A constant value to be used for the learning rate throughout the entire training job. |

Exponential decay learning rate:

| dtype | setting | description |
| --- | --- | --- |
| float | model.lr_schedule.exp.value | Initial value for learning rate |
| float | model.lr_schedule.exp.decay_rate | A value to decay learning rate by |
| float | model.lr_schedule.exp.decay_steps | Number of steps to run before decaying the learning rate |
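
Assuming the three settings combine in the usual way, value * decay_rate ^ (step / decay_steps), the schedule can be sketched as below. exp_lr is a hypothetical helper, and the staircase argument is an addition for illustration; whether this framework decays continuously or in discrete jumps is not stated here.

```python
def exp_lr(step, value, decay_rate, decay_steps, staircase=False):
    """Exponentially decayed learning rate at `step`."""
    if staircase:
        # Decay in discrete jumps, once every `decay_steps` steps.
        exponent = step // decay_steps
    else:
        # Decay smoothly at every step.
        exponent = float(step) / decay_steps
    return value * decay_rate ** exponent


for step in (0, 500, 1000, 2000):
    print("%d %g" % (step, exp_lr(step, 0.1, 0.5, 1000)))
```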

Custom learning rate schedule:

| dtype | setting | description |
| --- | --- | --- |
| func | model.lr_schedule.custom.fn | A function of the form fn(step) -> tf.float that returns a float value that represents the learning rate at the provided step. |
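
As an example of the shape such a function could take, here is a linear warmup followed by a constant rate, written in plain Python. make_warmup_fn is a hypothetical helper. Since the documented signature returns tf.float, the real function presumably receives a step tensor, in which case the Python `if` below would need to be replaced with a TensorFlow op such as tf.cond or tf.where.

```python
def make_warmup_fn(base_lr, warmup_steps):
    """Build fn(step) that ramps linearly to base_lr, then holds it."""
    def fn(step):
        if step < warmup_steps:
            # Ramp from base_lr / warmup_steps up to base_lr.
            return base_lr * float(step + 1) / warmup_steps
        return base_lr
    return fn


custom_lr = make_warmup_fn(base_lr=0.1, warmup_steps=100)
for step in (0, 49, 99, 500):
    print("%d %g" % (step, custom_lr(step)))
```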

Contribute

Please refer to the contributing.md file for information about how to get involved. We welcome issues, questions, and pull requests.

Core contributors / Maintainers

Acknowledgements

We thank the following people for their contributions, testing, and feedback on this code:

  • Jack Culpepper
  • Simao Herdade
  • Kofi Boakye
  • Andrew Kae
  • Tobias Baumgartner
  • Armin Kappeler
  • Pierre Garrigues

License

This project is licensed under the terms of the Apache 2.0 open source license. Please refer to LICENSE for the full terms.
