
fastai-tf-fit

Fit your TensorFlow model using fastai and PyTorch

Installation

pip install git+https://github.com/fastai/tf-fit.git

Features

This project is an extension of fastai that allows training TensorFlow models with an interface similar to fastai's. It uses fastai DataBunch objects, so loading data works exactly the same way. For training, the TfLearner supports many of the same features as the fastai Learner. The currently supported features are:

  • Training Tensorflow models with constant learning rate and weight decay
  • Training using the 1cycle policy
  • Learning rate finder
  • Fit with callbacks with access to hyper parameter updates
  • Discriminative learning rates
  • Freezing layers so their parameters are not updated during training
  • True weight decay option
  • L2 regularization (true_wd=False)
  • Removing weight decay from batchnorm layers option (bn_wd=False)
  • Momentum
  • Option to train batchnorm layers even if the layer is frozen (train_bn=True)
  • Model saving and loading
  • Default image data format is channels × height × width (NCHW)
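
The distinction between true weight decay and L2 regularization (the `true_wd` flag above) can be sketched in plain Python. This is an illustrative single-parameter SGD step, not the library's actual update code:

```python
def l2_step(w, g, lr, wd):
    # L2 regularization: the decay term is folded into the gradient
    # before the optimizer sees it
    return w - lr * (g + wd * w)

def true_wd_step(w, g, lr, wd):
    # true (decoupled) weight decay: the weight is shrunk directly,
    # separately from the gradient step
    return w - lr * g - lr * wd * w
```

For plain SGD the two updates coincide; they diverge with adaptive optimizers such as Adam, where the L2 term gets rescaled by the per-parameter statistics while decoupled decay does not, which is why the `true_wd` option exists.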

To do

This project is a work in progress so there may be missing features or obscure bugs.

  • Get predictions function
  • Tensorflow train/eval functionality for dropout and batchnorm in eager mode
  • Pip and conda packages

Examples

Setup

Set up the fastai DataBunch, optimizer, loss function, and metrics.

from fastai.vision import *
from fastai_tf_fit import *

path = untar_data(URLs.CIFAR)
ds_tfms = ([*rand_pad(4, 32), flip_lr(p=0.5)], [])
data = ImageDataBunch.from_folder(path, valid='test', ds_tfms=ds_tfms, bs=512).normalize(cifar_stats)

opt_fn = tf.train.AdamOptimizer

loss_fn = tf.losses.sparse_softmax_cross_entropy

def categorical_accuracy(y_pred, y_true):
    # cast the labels so their dtype matches the int64 output of argmax
    y_true = tf.cast(y_true, tf.int64)
    return tf.keras.backend.mean(tf.keras.backend.equal(y_true, tf.keras.backend.argmax(y_pred, axis=-1)))
metrics = [categorical_accuracy]
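
For intuition, the metric above is equivalent to the following NumPy version (a standalone sketch, not part of the library): it takes the highest-scoring class per row and reports the fraction of rows where that matches the label.

```python
import numpy as np

def categorical_accuracy_np(y_pred, y_true):
    # fraction of samples whose highest-scoring class matches the label
    return float(np.mean(np.argmax(y_pred, axis=-1) == y_true))

y_pred = np.array([[0.1, 0.9],   # predicts class 1
                   [0.8, 0.2],   # predicts class 0
                   [0.3, 0.7]])  # predicts class 1
y_true = np.array([1, 0, 0])     # last sample is wrong
categorical_accuracy_np(y_pred, y_true)  # 2/3
```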

Using tf.keras.Model

class Simple_CNN(tf.keras.Model):
    def __init__(self):
        super().__init__()
        self.conv1 = tf.keras.layers.Conv2D(16, kernel_size=3, strides=(2,2), padding='same')
        self.bn1 = tf.keras.layers.BatchNormalization(axis=1)
        self.conv2 = tf.keras.layers.Conv2D(16, kernel_size=3, strides=(2,2), padding='same')
        self.bn2 = tf.keras.layers.BatchNormalization(axis=1)
        self.conv3 = tf.keras.layers.Conv2D(10, kernel_size=3, strides=(2,2), padding='same')
        self.bn3 = tf.keras.layers.BatchNormalization(axis=1)
    def call(self, xb):
        xb = tf.nn.relu(self.bn1(self.conv1(xb)))
        xb = tf.nn.relu(self.bn2(self.conv2(xb)))
        xb = tf.nn.relu(self.bn3(self.conv3(xb)))
        xb = tf.nn.pool(xb, (4,4), 'AVG', 'VALID', data_format="NCHW")
        xb = tf.reshape(xb, (-1, 10))
        return xb

model = Simple_CNN()

Using Keras functional API

inputs = tf.keras.layers.Input(shape=(3,32,32))
x = tf.keras.layers.Conv2D(16, kernel_size=3, strides=(2,2), padding='same')(inputs)
x = tf.keras.layers.BatchNormalization(axis=1)(x)
x = tf.keras.layers.Activation("relu")(x)
x = tf.keras.layers.Conv2D(16, kernel_size=3, strides=(2,2), padding='same')(x)
x = tf.keras.layers.BatchNormalization(axis=1)(x)
x = tf.keras.layers.Activation("relu")(x)
x = tf.keras.layers.Conv2D(10, kernel_size=3, strides=(2,2), padding='same')(x)
x = tf.keras.layers.BatchNormalization(axis=1)(x)
x = tf.keras.layers.Activation("relu")(x)
x = tf.keras.layers.AveragePooling2D(pool_size=(4, 4), padding='same')(x)
x = tf.keras.layers.Reshape((10,))(x)
predictions = tf.keras.layers.Dense(10)(x)  # output logits: tf.losses.sparse_softmax_cross_entropy applies softmax itself
model = tf.keras.models.Model(inputs=inputs, outputs=predictions)

Training

Create a TfLearner object.

learn = TfLearner(data, model, opt_fn, loss_fn, metrics=metrics, true_wd=True, bn_wd=True, wd=defaults.wd, train_bn=True)

Run the learning rate finder and plot the results.

learn.lr_find()
learn.recorder.plot()
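
Under the hood, a learning rate finder of this kind (Leslie Smith's LR range test) trains briefly with an exponentially increasing learning rate while recording the loss; you then pick a rate just below the point where the loss diverges. A minimal sketch of the schedule, with hypothetical start/end values:

```python
def lr_range(start=1e-7, end=10.0, n=100):
    # n learning rates spaced exponentially from start to end,
    # so each step multiplies the rate by a constant factor
    return [start * (end / start) ** (i / (n - 1)) for i in range(n)]

lrs = lr_range()  # lrs[0] is 1e-7, lrs[-1] is 10.0
```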

Train the model for 3 epochs with a learning rate of 3e-3 and weight decay of 0.4.

learn.fit(3, lr=3e-3, wd=0.4)

Fit the model with the 1cycle policy for a cycle length of 10, using discriminative learning rates.

learn.fit_one_cycle(10, max_lr=slice(6e-3, 3e-3))
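
With discriminative learning rates, earlier layer groups train with smaller rates than later ones. One common way to spread a range across groups is geometric interpolation between the two endpoints; the exact mapping from the slice endpoints to layer groups is defined by fastai, so the sketch below (assuming three layer groups) is only illustrative:

```python
def spread_lrs(lo, hi, n_groups=3):
    # geometrically interpolate n_groups learning rates from lo to hi,
    # so the earliest layers get the smallest rate
    if n_groups == 1:
        return [hi]
    return [lo * (hi / lo) ** (i / (n_groups - 1)) for i in range(n_groups)]

spread_lrs(3e-3, 6e-3)  # three rates, smallest first
```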

Freeze all layers, unfreeze all layers, or freeze everything up to the last layer group.

learn.freeze()
learn.unfreeze()
learn.freeze_to(-1)

Save and load model weights.

learn.save('cnn-1')
learn.load('cnn-1')

Metrics

Plot learning rate and momentum schedules.

learn.recorder.plot_lr(show_moms=True)

Plot train and validation losses.

learn.recorder.plot_losses()

Plot metrics.

learn.recorder.plot_metrics()
