Code Monkey home page Code Monkey logo

keras-onnx's Introduction

keras2onnx

Linux Windows
Build Status Build Status

Introduction

The keras2onnx model converter enables users to convert Keras models into the ONNX model format. Initially, the Keras converter was developed in the project onnxmltools. To support more kinds of Keras models and reduce the complexity of mixing multiple converters, keras2onnx was created to convert the Keras model only.

keras2onnx converter supports not only all keras built-in layers, but also the lambda/custom layer by working with the tf2onnx converter which is embedded into the source tree directly now to avoid version conflicts and installation complexity.

keras2onnx has been tested on Python 3.5 and 3.6 (CI build). It does not support Python 2.x.

Notes

tf.keras v.s. keras.io

Both Keras model types are supported now. If keras package was installed as the one from https://keras.io/, the converter converts the model as it was created by this keras.io package, otherwise it will convert as it was by tf.keras.
If you want to override this behaviour, please specify the environment variable TF_KERAS=1 before invoking the converter python API.

Usage

Before running the converter, please notice that tensorflow has to be installed in your python environment, you can choose tensorflow package(CPU version) or tensorflow-gpu(GPU version)

Validate pre-trained Keras application models

It will be useful to convert the models from Keras to ONNX from a python script. You can use the following API:

import keras2onnx
keras2onnx.convert_keras(model, name=None, doc_string='', target_opset=None, channel_first_inputs=None):
    # type: (keras.Model, str, str, int, []) -> onnx.ModelProto
    """
    :param model: keras model
    :param name: the converted onnx model internal name
    :param doc_string:
    :param target_opset:
    :param channel_first_inputs: A list of channel first input.
    :return:
    """

Use the following script to convert keras application models to onnx, and then perform inference:

import numpy as np
from keras.preprocessing import image
from keras.applications.resnet50 import preprocess_input
import keras2onnx
import onnxruntime

# image preprocessing
img_path = 'elephant.jpg'   # make sure the image is in img_path
img_size = 224
img = image.load_img(img_path, target_size=(img_size, img_size))
x = image.img_to_array(img)
x = np.expand_dims(x, axis=0)
x = preprocess_input(x)

# load keras model
from keras.applications.resnet50 import ResNet50
model = ResNet50(include_top=True, weights='imagenet')

# convert to onnx model
onnx_model = keras2onnx.convert_keras(model, model.name)

# runtime prediction
content = onnx_model.SerializeToString()
sess = onnxruntime.InferenceSession(content)
x = x if isinstance(x, list) else [x]
feed = dict([(input.name, x[n]) for n, input in enumerate(sess.get_inputs())])
pred_onnx = sess.run(None, feed)

An alternative way to load onnx model to runtime session is to save the model first:

import onnx
temp_model_file = 'model.onnx'
onnx.save_model(onnx_model, temp_model_file)
sess = onnxruntime.InferenceSession(temp_model_file)

We converted successfully for all the keras application models such as: Xception, VGG16, VGG19, ResNet50, InceptionV3, InceptionResNetV2, MobileNet, MobileNetV2, DenseNet121, DenseNet169, DenseNet201, NASNetMobile, and NASNetLarge. Try the following models and convert them to onnx using the code above.

from keras.applications.xception import Xception
model = Xception(include_top=True, weights='imagenet')

from keras.applications.vgg16 import VGG16
model = VGG16(include_top=True, weights='imagenet')

from keras.applications.vgg19 import VGG19
model = VGG19(include_top=True, weights='imagenet')

from keras.applications.resnet50 import ResNet50
model = ResNet50(include_top=True, weights='imagenet')

from keras.applications.inception_v3 import InceptionV3
model = InceptionV3(include_top=True, weights='imagenet')

from keras.applications.inception_resnet_v2 import InceptionResNetV2
model = InceptionResNetV2(include_top=True, weights='imagenet')

from keras.applications import mobilenet
model = mobilenet.MobileNet(weights='imagenet')

from keras.applications import mobilenet_v2
model = mobilenet_v2.MobileNetV2(weights='imagenet')

from keras.applications.densenet import DenseNet121
model = DenseNet121(include_top=True, weights='imagenet')

from keras.applications.densenet import DenseNet169
model = DenseNet169(include_top=True, weights='imagenet')

from keras.applications.densenet import DenseNet201
model = DenseNet201(include_top=True, weights='imagenet')

from keras.applications.nasnet import NASNetMobile
model = NASNetMobile(include_top=True, weights='imagenet')

from keras.applications.nasnet import NASNetLarge
model = NASNetLarge(include_top=True, weights='imagenet')

Contribute

We welcome contributions in the form of feedback, ideas, or code.

License

MIT License

Recommend Projects

  • React photo React

    A declarative, efficient, and flexible JavaScript library for building user interfaces.

  • Vue.js photo Vue.js

    ๐Ÿ–– Vue.js is a progressive, incrementally-adoptable JavaScript framework for building UI on the web.

  • Typescript photo Typescript

    TypeScript is a superset of JavaScript that compiles to clean JavaScript output.

  • TensorFlow photo TensorFlow

    An Open Source Machine Learning Framework for Everyone

  • Django photo Django

    The Web framework for perfectionists with deadlines.

  • D3 photo D3

    Bring data to life with SVG, Canvas and HTML. ๐Ÿ“Š๐Ÿ“ˆ๐ŸŽ‰

Recommend Topics

  • javascript

    JavaScript (JS) is a lightweight interpreted programming language with first-class functions.

  • web

    Some thing interesting about web. New door for the world.

  • server

    A server is a program made to process requests and deliver data to clients.

  • Machine learning

    Machine learning is a way of modeling and interpreting data that allows a piece of software to respond intelligently.

  • Game

    Some thing interesting about game, make everyone happy.

Recommend Org

  • Facebook photo Facebook

    We are working to build community through open source technology. NB: members must have two-factor auth.

  • Microsoft photo Microsoft

    Open source projects and samples from Microsoft.

  • Google photo Google

    Google โค๏ธ Open Source for everyone.

  • D3 photo D3

    Data-Driven Documents codes.