
test_onnx

Contributor: sebapersson

test_onnx's Issues

Do not use ONNX for PEtab SciML extension

To experiment with ONNX, I built a simple feed-forward network in this repository and exported it to ONNX. Based on that experiment, I believe we should not use the ONNX standard as the format for the PEtab SciML extension, for the PROS and CONS below.

PROS

  • Flexible format that can support most architectures
  • A de facto standard format

CONS

  • Limited tool support for importing ONNX models
  • Networks must be created in a tool that can export ONNX, such as PyTorch

Support

Julia does not have an ONNX -> Flux (the main Julia ML library) importer; see the issues on ONNX.jl. Instead, ONNX.jl outputs an NNlib trace, which is not easy to convert to a Flux model.

JAX does not have an importer that produces a standard neural-network structure, such as a Haiku module. The closest I found is this example, which creates an executable function from the ONNX graph.

PyTorch does not have an official importer. There is a third-party importer, onnx2torch. However, using the output from onnx2torch to build a network in Julia, JAX or SBML is not straightforward.

Keras has an importer, but I was unable to make it work even on a simple feed-forward network.

Overall, there does not seem to be a straightforward tool for extracting an ML model architecture from an ONNX file. Thus, if we adopt the ONNX format, we would likely need to write a custom (probably Python) parser that converts an ONNX graph into an intermediate architecture file. Implementations could then construct the data-driven model from this intermediate file, avoiding the need to write an ONNX importer for every tool we want to support in SciML models. I want to highlight that being able to recover the architecture from the ONNX file is important, as it lets each implementation use the built-in layers and functions of its ML package most efficiently.
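As an illustration, the intermediate architecture file produced by such a parser could look something like the following. The file layout and field names here are hypothetical, not part of any spec:

```yaml
# Hypothetical intermediate representation extracted from an ONNX graph
layers:
  - type: Linear
    in_features: 2
    out_features: 5
    bias: true
  - type: tanh
  - type: Linear
    in_features: 5
    out_features: 2
    bias: true
```

Each implementation would only need to map this small vocabulary of layer types onto its ML package, rather than handle the full ONNX operator set.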

Export

Building a model in the ONNX library is not straightforward. See here. Therefore, the expected user workflow would involve building the model in an established tool like Keras or PyTorch. However, most standard machine learning packages are large dependencies that might be tricky to install, and expecting users to use such dependencies is not ideal.

Conclusion

Using ONNX would likely require writing an ONNX importer. This is doable; see, e.g., the graph of the ONNX model for a feed-forward network under Additional info below. However, to fully leverage the functionality of machine learning packages such as Equinox or Lux.jl, this importer would need to translate an ONNX model into each package's native format, which would be a lot of work. The importer could potentially leverage tools like onnx2torch to extract the architecture, but having PyTorch as a dependency is not ideal (e.g., in Julia we try to avoid Python dependencies).

Moreover, adopting ONNX would force users to build the data-driven model in tools like PyTorch. Since the packages supporting SciML models would likely be written in JAX or Julia, users would have to switch between many different tools.

Alternative Approach

Allow the user to write the network layer architecture in a format similar to PyTorch. For example, in the network.yaml file, specify layers like:

layers: [
             Linear(in_features=2, out_features=5, bias=true), 
             tanh, 
             Linear(in_features=5, out_features=5, bias=true),
             tanh,
             Linear(in_features=5, out_features=5, bias=true),
             tanh,
             Linear(in_features=5, out_features=2, bias=true)]

This would create a feed-forward network with two hidden layers and tanh activation functions. With this way of writing the ML model, the specification can list which layers and activation functions are allowed (e.g., tanh, relu). The format would also support more complex architectures, such as convolutional neural networks. An additional benefit is that it is straightforward for users to write the ML model, and easy for tools to parse it.
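To show that such layer strings are indeed easy to parse, here is a minimal sketch. The layer names, argument names, and true/false spelling follow the example above; the function names and regex are illustrative, not part of any spec:

```python
# Sketch: parse a layer string like 'Linear(in_features=2, out_features=5, bias=true)'
# into a (layer_type, arguments) pair. Bare activations like 'tanh' have no arguments.
import re

LAYER_RE = re.compile(r"^(\w+)\((.*)\)$")

def parse_layer(spec):
    """Parse one entry of the proposed layers list."""
    spec = spec.strip()
    m = LAYER_RE.match(spec)
    if m is None:                      # bare activation, e.g. 'tanh'
        return (spec, {})
    name, args = m.group(1), {}
    for part in m.group(2).split(","):
        key, value = part.split("=")
        value = value.strip()
        if value in ("true", "false"):  # YAML-style booleans from the example
            args[key.strip()] = value == "true"
        else:
            args[key.strip()] = int(value)
    return (name, args)

layers = [parse_layer(s) for s in [
    "Linear(in_features=2, out_features=5, bias=true)",
    "tanh",
    "Linear(in_features=5, out_features=2, bias=true)",
]]
print(layers[0])
# ('Linear', {'in_features': 2, 'out_features': 5, 'bias': True})
```

A real implementation would add validation against the allowed layer list, but the core parsing fits in a few lines in any language.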

For the parameters table: since this approach codes ML models layer by layer, parameters can be referred to by layer, e.g., the parameters in layer 1 could appear in the parameters table as networkName_layer1_weight.... Specifically, for each kind of layer (e.g., Linear), the specification can define how its parameters appear in the parameters table.
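A small sketch of this naming convention, assuming one parameters-table ID per tensor named networkName_layer{i}_{tensor}. The exact pattern and the per-layer tensor list are assumptions here, not settled spec:

```python
# Sketch: generate parameters-table IDs for a layer-by-layer model description.
def parameter_ids(network_name, layers):
    """Return parameter-table IDs for a list of (layer_type, kwargs) pairs."""
    ids = []
    for i, (layer_type, kwargs) in enumerate(layers, start=1):
        if layer_type == "Linear":
            ids.append(f"{network_name}_layer{i}_weight")
            if kwargs.get("bias", True):   # bias tensor only if the layer has one
                ids.append(f"{network_name}_layer{i}_bias")
        # activations such as tanh contribute no parameters
    return ids

layers = [("Linear", {"in_features": 2, "out_features": 5, "bias": True}),
          ("tanh", {}),
          ("Linear", {"in_features": 5, "out_features": 2, "bias": True})]
print(parameter_ids("net1", layers))
# ['net1_layer1_weight', 'net1_layer1_bias', 'net1_layer3_weight', 'net1_layer3_bias']
```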

Additional info

The ONNX model for a feed-forward network has a graph that looks like the one below. It could be parsed, but a custom parser for schemes like this requires a bit of work.

node {
  input: "star"
  input: "onnx::MatMul_24"
  output: "/input/MatMul_output_0"
  name: "/input/MatMul"
  op_type: "MatMul"
}
node {
  input: "input.bias"
  input: "/input/MatMul_output_0"
  output: "/input/Add_output_0"
  name: "/input/Add"
  op_type: "Add"
}
node {
  input: "/input/Add_output_0"
  output: "/Tanh_output_0"
  name: "/Tanh"
  op_type: "Tanh"
}
node {
  input: "/Tanh_output_0"
  input: "onnx::MatMul_25"
  output: "/hidden_1/MatMul_output_0"
  name: "/hidden_1/MatMul"
  op_type: "MatMul"
}
node {
  input: "hidden_1.bias"
  input: "/hidden_1/MatMul_output_0"
  output: "/hidden_1/Add_output_0"
  name: "/hidden_1/Add"
  op_type: "Add"
}
node {
  input: "/hidden_1/Add_output_0"
  output: "/Tanh_1_output_0"
  name: "/Tanh_1"
  op_type: "Tanh"
}
node {
  input: "/Tanh_1_output_0"
  input: "onnx::MatMul_26"
  output: "/hidden_2/MatMul_output_0"
  name: "/hidden_2/MatMul"
  op_type: "MatMul"
}
node {
  input: "hidden_2.bias"
  input: "/hidden_2/MatMul_output_0"
  output: "/hidden_2/Add_output_0"
  name: "/hidden_2/Add"
  op_type: "Add"
}
node {
  input: "/hidden_2/Add_output_0"
  output: "/Tanh_2_output_0"
  name: "/Tanh_2"
  op_type: "Tanh"
}
node {
  input: "/Tanh_2_output_0"
  input: "onnx::MatMul_27"
  output: "/output/MatMul_output_0"
  name: "/output/MatMul"
  op_type: "MatMul"
}
node {
  input: "output.bias"
  input: "/output/MatMul_output_0"
  output: "end"
  name: "/output/Add"
  op_type: "Add"
}
name: "main_graph"
initializer {
  dims: 5
  data_type: 1
  name: "input.bias"
  raw_data: "\341V\\\275\355)\315>K\277\234>#\266\204=\244\371a\275"
}
initializer {
  dims: 5
  data_type: 1
  name: "hidden_1.bias"
  raw_data: "\257z\203>[\207\324\276\341\351\254\276\024\'\262>\233d\372="
}
initializer {
  dims: 5
  data_type: 1
  name: "hidden_2.bias"
  raw_data: "Bn3>\277\356\321\276 \036\016\276\037\230\316=E\305\211>"
}
initializer {
  dims: 2
  data_type: 1
  name: "output.bias"
  raw_data: "XSR>\316@\261>"
}
initializer {
  dims: 2
  dims: 5
  data_type: 1
  name: "onnx::MatMul_24"
  raw_data: "\250\241\262>kd)?\371\357\026\277\372v\037?\256T\202\276\365\030\031\277j\006\247>\325\277\270>@\245\306\276\372L\031\276"
}
initializer {
  dims: 5
  dims: 5
  data_type: 1
  name: "onnx::MatMul_25"
  raw_data: "\023\314\370\275\345\2762=?\213L\276\232v\243>+\\7\276\351\352\337\276\027=g>\266\375\265>\243q\"\276\211~\336\276\006\275,>\222\021\004<\245\256\025>\330\255f\276o,\321>Z/v\274\021\033\270>\242\363\205>\264\226#\275\232\223-\276\214\301\013\276\217\302;\276pa\t\276#\342\231>\225\361\314>"
}
initializer {
  dims: 5
  dims: 5
  data_type: 1
  name: "onnx::MatMul_26"
  raw_data: "\230M\320>\244?{\276\234\364\335\275\254\200\331>,\240t\276\343Z\223=f\255\262=\253:\340>\337\330W=M^z>\216T\217>6f\017\276\215\252y\276Lu!>\240\350~\276\035\245\005\276<a\230\276\352\305\006=e)\273\276\314\210\326>zI\361\274_Fo>b\351\250\276#F\265>\036V\254>"
}
initializer {
  dims: 5
  dims: 2
  data_type: 1
  name: "onnx::MatMul_27"
  raw_data: "\306^B:\016dv\276m\004\026\276z\364\312\2761C\311\276]\321U>i\314A>u\035\242>\000z->\035\251.>"
}
input {
  name: "star"
  type {
    tensor_type {
      elem_type: 1
      shape {
        dim {
          dim_value: 2
        }
      }
    }
  }
}
output {
  name: "end"
  type {
    tensor_type {
      elem_type: 1
      shape {
        dim {
          dim_value: 2
        }
      }
    }
  }
}
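To give a feel for the work such a parser involves, here is a sketch that collapses the MatMul/Add/Tanh node sequence above into a layer list. The node dicts below mirror the protobuf dump by hand; a real parser would read them from the ONNX graph (e.g. via the onnx package, which is assumed to be available in practice), and would also need to match initializer shapes to layers:

```python
# Sketch: fold consecutive MatMul + Add nodes into Linear layers and keep
# activation nodes, recovering the architecture from an ONNX op sequence.
def collapse_graph(nodes):
    layers = []
    i = 0
    while i < len(nodes):
        if (nodes[i]["op_type"] == "MatMul"
                and i + 1 < len(nodes)
                and nodes[i + 1]["op_type"] == "Add"):
            layers.append("Linear")    # weight matmul + bias add = one Linear layer
            i += 2
        elif nodes[i]["op_type"] in {"Tanh", "Relu", "Sigmoid"}:
            layers.append(nodes[i]["op_type"].lower())
            i += 1
        else:
            raise ValueError(f"unsupported op: {nodes[i]['op_type']}")
    return layers

# Node op sequence transcribed from the graph dump above (names omitted).
nodes = [{"op_type": t} for t in
         ["MatMul", "Add", "Tanh",
          "MatMul", "Add", "Tanh",
          "MatMul", "Add", "Tanh",
          "MatMul", "Add"]]

print(collapse_graph(nodes))
# ['Linear', 'tanh', 'Linear', 'tanh', 'Linear', 'tanh', 'Linear']
```

Even this toy version needs pattern matching over op sequences; exporters are free to emit Gemm instead of MatMul + Add, fuse activations, or reorder inputs, which is why a robust parser is real work.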
