onnx-pytorch

Generates PyTorch code from an ONNX model.

Installation

  • From PyPI
pip install onnx-pytorch
  • From source
git clone https://github.com/fumihwh/onnx-pytorch.git
cd onnx-pytorch
pip install -r requirements.txt
pip install -e .

Usage

By Command Line

python -m onnx_pytorch.code_gen -h

usage: code_gen.py [-h] [--onnx_model_path ONNX_MODEL_PATH] [--output_dir OUTPUT_DIR] [--overwrite OVERWRITE] [--tensor_inplace TENSOR_INPLACE] [--continue_on_error CONTINUE_ON_ERROR] [--simplify_names SIMPLIFY_NAMES]

optional arguments:
  -h, --help            show this help message and exit
  --onnx_model_path ONNX_MODEL_PATH
                        The onnx model path.
  --output_dir OUTPUT_DIR
                        The output dir
  --overwrite OVERWRITE
                        Should overwrite the output dir.
  --tensor_inplace TENSOR_INPLACE
                        Try best to inplace tensor.
  --continue_on_error CONTINUE_ON_ERROR
                        Continue on error.
  --simplify_names SIMPLIFY_NAMES
                        Use indexing shorten name instead of original name.
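
The options listed above can also be assembled programmatically. The following is a minimal sketch, not part of onnx-pytorch: the helper name build_code_gen_command is hypothetical, and it only builds an argument list from the flag names shown in the help text. How the tool interprets each value string (for example, whether "False" actually disables a flag) is not stated in the help output and is left unverified here.

```python
import sys


def build_code_gen_command(onnx_model_path, output_dir, **options):
    """Build an argv list for `python -m onnx_pytorch.code_gen`.

    Hypothetical helper: it only mirrors the flag names from the help text.
    """
    # Flag names copied from the help output; anything else is rejected.
    known = {"overwrite", "tensor_inplace", "continue_on_error", "simplify_names"}
    unknown = set(options) - known
    if unknown:
        raise ValueError(f"unknown options: {sorted(unknown)}")

    cmd = [
        sys.executable, "-m", "onnx_pytorch.code_gen",
        "--onnx_model_path", str(onnx_model_path),
        "--output_dir", str(output_dir),
    ]
    for name in sorted(options):
        # Per the usage line every flag takes a value; it is passed as a string.
        cmd += [f"--{name}", str(options[name])]
    return cmd
```

The resulting list could then be handed to subprocess.run(cmd, check=True) in an environment where onnx-pytorch is installed.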

By Python

from onnx_pytorch import code_gen
code_gen.gen("/path/to/onnx_model", "/path/to/output_dir")

This creates a model.py file and a variables/ folder under output_dir/.
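
As a quick sanity check after generation, one could confirm that the expected artifacts exist. This is a small sketch under the assumption that the layout is exactly as described here (a model.py file and a variables/ folder directly under the output directory); the helper name check_generated_output is hypothetical and not part of onnx-pytorch.

```python
from pathlib import Path


def check_generated_output(output_dir):
    """Report which of the expected generated artifacts exist in output_dir."""
    out = Path(output_dir)
    return {
        "model.py": (out / "model.py").is_file(),    # generated PyTorch module
        "variables/": (out / "variables").is_dir(),  # folder with the model's variables
    }
```

Calling check_generated_output("/path/to/output_dir") returns a dict of booleans, so any False entry points at a missing artifact.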

Tutorial

  1. Download the resnet18 ONNX model.
wget https://github.com/onnx/models/raw/master/vision/classification/resnet/model/resnet18-v2-7.onnx
  2. Use onnx-pytorch to generate PyTorch code and variables.
from onnx_pytorch import code_gen
code_gen.gen("resnet18-v2-7.onnx", "./")
  3. Compare the generated model's output with ONNX Runtime's.
import numpy as np
import onnx
import onnxruntime
import torch
torch.set_printoptions(8)

# model.py was generated into the current directory in the previous step.
from model import Model

# Run the generated PyTorch model on a random input.
model = Model()
model.eval()
inp = np.random.randn(1, 3, 224, 224).astype(np.float32)
with torch.no_grad():
  torch_outputs = model(torch.from_numpy(inp))

# Run the original ONNX model on the same input with ONNX Runtime.
onnx_model = onnx.load("resnet18-v2-7.onnx")
sess_options = onnxruntime.SessionOptions()
session = onnxruntime.InferenceSession(onnx_model.SerializeToString(),
                                       sess_options)
inputs = {session.get_inputs()[0].name: inp}
ort_outputs = session.run(None, inputs)

# Both outputs should agree within the given tolerances.
print(
    "Comparison result:",
    np.allclose(torch_outputs.detach().numpy(),
                ort_outputs[0],
                atol=1e-5,
                rtol=1e-5))
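
When the comparison prints False, it can help to see by how much the tolerance is exceeded. np.allclose(a, b, atol, rtol) passes when |a - b| <= atol + rtol * |b| holds for every element. The sketch below spells out that check in plain Python so the margin is visible; the function name max_tolerance_violation is hypothetical, and it assumes both inputs are flat sequences of equal length.

```python
def max_tolerance_violation(actual, expected, atol=1e-5, rtol=1e-5):
    """Return the largest value of |a - e| - (atol + rtol * |e|) over all pairs.

    A result <= 0 means every element satisfies the allclose-style check;
    a positive result is the amount by which the worst element exceeds it.
    """
    worst = float("-inf")
    for a, e in zip(actual, expected):
        a, e = float(a), float(e)
        # Same inequality as np.allclose, rearranged so that 0 is the threshold.
        worst = max(worst, abs(a - e) - (atol + rtol * abs(e)))
    return worst
```

With the tutorial's variables this could be called as max_tolerance_violation(torch_outputs.numpy().ravel(), ort_outputs[0].ravel()).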

Test

pytest onnx_pytorch/tests

Contributors

fumihwh, helion-du-mas-des-bourboux-thales, jorgemcgomes, maimaixiong, rogier-stegeman
