axon's Introduction

Axon

Nx-powered Neural Networks for Elixir.

Axon consists of the following components:

  • Functional API – A low-level API of numerical definitions (defn) on which all other APIs build.
  • Model Creation API – A high-level model creation API which manages model initialization and application.
  • Training API – An API for quickly training models, inspired by PyTorch Ignite.

Axon provides abstractions that enable easy integration while maintaining a level of separation between each component. You should be able to use any of the APIs without dependencies on others. By decoupling the APIs, Axon gives you full control over each aspect of creating and training a neural network. Axon uses Polaris for its optimization API.

Overview

For an in-depth overview, see: Axon: Deep Learning in Elixir

Functional API

At the lowest level, Axon consists of a number of modules with functional implementations of common methods in deep learning:

  • Axon.Activations – Element-wise activation functions.
  • Axon.Initializers – Model parameter initialization functions.
  • Axon.Layers – Common deep learning layer implementations.
  • Axon.Losses – Common loss functions.
  • Axon.Metrics – Training metrics such as accuracy, absolute error, precision, etc.

All of the methods in the functional API are implemented as numerical definitions (defn). That means you can use any Nx compiler or backend to accelerate Axon. Additionally, you can arbitrarily compose methods in the Axon functional API with your own numerical definitions. Axon works entirely on Nx tensors, so any library built on top of Nx is likely to integrate well with Axon.
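As a rough illustration of that composition, the script below defines its own numerical definition that calls Axon.Activations.relu/1 and Axon.Losses.mean_squared_error/3 alongside plain Nx arithmetic. It is a sketch assuming Axon 0.6 installed via Mix.install; the module and function names (MyNumerics, scaled_relu_mse) are hypothetical.

```elixir
Mix.install([{:axon, "~> 0.6.0"}])

defmodule MyNumerics do
  import Nx.Defn

  # Hypothetical defn mixing plain Nx arithmetic with Axon's functional API
  defn scaled_relu_mse(x, y_true) do
    # Plain Nx operation (inside defn, * is Nx.multiply), then an Axon activation
    y_pred = Axon.Activations.relu(x * 2)
    # Axon loss; reduction: :mean collapses the result to a scalar
    Axon.Losses.mean_squared_error(y_true, y_pred, reduction: :mean)
  end
end

x = Nx.tensor([-1.0, 0.5, 2.0])
y = Nx.tensor([1.0, 1.0, 3.0])

# relu(2x) = [0.0, 1.0, 4.0]; squared errors = [1.0, 0.0, 1.0]; mean = 2/3
loss = MyNumerics.scaled_relu_mse(x, y)
IO.inspect(Nx.to_number(loss), label: "loss")
```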

Because Axon’s high-level APIs build on top of the functional API, the same benefits apply. Every neural network can be JIT or AOT compiled using any Nx compiler or backend, or even transformed into high-level neural network formats like TensorFlow Lite and ONNX.

Model Creation

An example model looks something like:

model =
  Axon.input("input", shape: {nil, 784})
  |> Axon.dense(128)
  |> Axon.dense(10, activation: :softmax)

The model is just an Elixir struct, so serializing it to multiple formats in the future is straightforward. The default inspect protocol provides a simple summary of the model. For a more detailed summary, use the Axon.Display module. For example, Axon.Display.as_table/2 renders a table summary of the model:

+-----------------------------------------------------------------------------------------------------------+
|                                                   Model                                                   |
+==================================+=============+==============+===================+=======================+
| Layer                            | Input Shape | Output Shape | Options           | Parameters            |
+==================================+=============+==============+===================+=======================+
| input ( input )                  | []          | {1, 784}     | shape: {nil, 784} |                       |
|                                  |             |              | optional: false   |                       |
+----------------------------------+-------------+--------------+-------------------+-----------------------+
| dense_0 ( dense["input"] )       | [{1, 784}]  | {1, 128}     |                   | kernel: f32[784][128] |
|                                  |             |              |                   | bias: f32[128]        |
+----------------------------------+-------------+--------------+-------------------+-----------------------+
| dense_1 ( dense["dense_0"] )     | [{1, 128}]  | {1, 10}      |                   | kernel: f32[128][10]  |
|                                  |             |              |                   | bias: f32[10]         |
+----------------------------------+-------------+--------------+-------------------+-----------------------+
| softmax_0 ( softmax["dense_1"] ) | [{1, 10}]   | {1, 10}      |                   |                       |
+----------------------------------+-------------+--------------+-------------------+-----------------------+
Total Parameters: 101770
Total Parameters Memory: 407080 bytes
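For reference, a table like the one above can be produced with a script along these lines. This is a sketch assuming Axon 0.6, where Axon.Display.as_table/2 takes an input template and relies on the optional table_rex dependency. The parameter total follows from 784 * 128 + 128 weights and biases in dense_0 plus 128 * 10 + 10 in dense_1, at 4 bytes per f32 parameter.

```elixir
Mix.install([{:axon, "~> 0.6.0"}, {:table_rex, "~> 3.1"}])

model =
  Axon.input("input", shape: {nil, 784})
  |> Axon.dense(128)
  |> Axon.dense(10, activation: :softmax)

# A template describes shape and type only; no data is allocated
template = Nx.template({1, 784}, :f32)

table = Axon.Display.as_table(model, template)
IO.puts(table)

# dense_0: 784 * 128 kernel + 128 bias; dense_1: 128 * 10 kernel + 10 bias
expected_params = 784 * 128 + 128 + 128 * 10 + 10
# Each f32 parameter occupies 4 bytes
expected_bytes = expected_params * 4
```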

Axon provides a few conveniences for working with models. First, we chose to take the philosophy that a model’s only concerns are initialization and application. That means the model shouldn’t be concerned at all with details like training. Axon provides the Axon.build/2 function for building the Axon data structure into initialization and prediction functions:

model =
  Axon.input("input", shape: {nil, 784})
  |> Axon.dense(128, activation: :relu)
  |> Axon.dropout(rate: 0.5)
  |> Axon.dense(10, activation: :softmax)

{init_fn, predict_fn} = Axon.build(model, compiler: EXLA)

params = init_fn.(Nx.template({1, 784}, :f32), %{})
# input: an f32 tensor of shape {batch_size, 784}
predict_fn.(params, input)

You can pass functions directly to defn, meaning you can easily integrate model execution with existing numerical definitions.
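A self-contained version of the Axon.build/2 example above might look like the following. It is a sketch assuming Axon 0.6: it drops the EXLA compiler option so it runs on Nx's default evaluator, and it substitutes a synthetic input for a real image. Because Axon.build/2 defaults to inference mode, dropout is inactive and predict_fn returns the output tensor directly.

```elixir
Mix.install([{:axon, "~> 0.6.0"}])

model =
  Axon.input("input", shape: {nil, 784})
  |> Axon.dense(128, activation: :relu)
  |> Axon.dropout(rate: 0.5)
  |> Axon.dense(10, activation: :softmax)

# No :compiler option, so Nx's default evaluator runs the model
{init_fn, predict_fn} = Axon.build(model)

# Initialize parameters from an input template (shape and type only)
params = init_fn.(Nx.template({1, 784}, :f32), %{})

# Synthetic stand-in for one flattened 28x28 image, values in [0, 1)
input = Nx.iota({1, 784}, type: :f32) |> Nx.divide(784)

probs = predict_fn.(params, input)
predicted_class = probs |> Nx.argmax(axis: 1) |> Nx.to_flat_list()
IO.inspect(predicted_class, label: "predicted class")
```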

Axon currently has support for the same high-level layers you'd find in a framework like PyTorch or TensorFlow Keras. Our goal is to maintain an API that is productive, extensible, and on par with other modern deep learning frameworks. If there is functionality you need to see that’s not included, feel free to open an issue.

Optimization and training

The purpose of the training API is to provide conveniences and common routines for implementing training loops. The API is inspired by the excellent PyTorch Ignite library.

The general pattern for training a model is:

  1. Define model
  2. Define loop using one of the factory methods (here Axon.Loop.trainer/3)
  3. Instrument loop with metrics and event handlers
  4. Run loop on data

model =
  Axon.input("input", shape: {nil, 784})
  |> Axon.dense(128)
  |> Axon.dense(10, activation: :softmax)

model_state =
  model
  |> Axon.Loop.trainer(:categorical_cross_entropy, Polaris.Optimizers.adamw(learning_rate: 0.005))
  |> Axon.Loop.metric(:accuracy)
  |> Axon.Loop.handle(:iteration_completed, &log_metrics/1, every: 50)
  |> Axon.Loop.run(data, %{}, epochs: 10, compiler: EXLA)
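To make the loop above concrete, here is a sketch on a toy two-feature dataset, assuming Axon 0.6 and Polaris 0.1. The dataset, layer sizes, and epoch count are arbitrary illustrations; the log_metrics/1 handler is left out because it is user-defined, and no compiler is given, so Nx's default evaluator is used. In Axon 0.6 the trainer loop returns the trained parameters, keyed by layer name.

```elixir
Mix.install([{:axon, "~> 0.6.0"}])

model =
  Axon.input("input", shape: {nil, 2})
  |> Axon.dense(8, activation: :relu)
  |> Axon.dense(2, activation: :softmax)

# Toy batch: the target is class 1 when the first feature is larger
x = Nx.tensor([[0.9, 0.1], [0.2, 0.8], [0.7, 0.3], [0.1, 0.6]])
y = Nx.tensor([[0.0, 1.0], [1.0, 0.0], [0.0, 1.0], [1.0, 0.0]])

# A finite, re-enumerable stream of {input, target} batches;
# each epoch enumerates it again from the start
data = Stream.repeatedly(fn -> {x, y} end) |> Stream.take(25)

model_state =
  model
  |> Axon.Loop.trainer(:categorical_cross_entropy, Polaris.Optimizers.adamw(learning_rate: 0.005))
  |> Axon.Loop.metric(:accuracy)
  |> Axon.Loop.run(data, %{}, epochs: 2)
```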

Axon uses Polaris for its optimization API. It’s important to note that the optimization API does not directly depend on Axon models. You can use the API to optimize any differentiable objective function.
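As an illustration that Polaris works on an arbitrary differentiable objective, the sketch below minimizes f(w) = sum((w - 3)^2) with plain SGD and no Axon model at all. It assumes Polaris 0.1, whose optimizers return an {init_fn, update_fn} pair; the Quadratic module is hypothetical. With a learning rate of 0.1, each step moves w by -0.2 * (w - 3), so the distance to the minimum shrinks by a factor of 0.8 per step.

```elixir
Mix.install([{:polaris, "~> 0.1"}])

defmodule Quadratic do
  import Nx.Defn

  # Objective f(w) = sum((w - 3)^2), minimized at w = 3 for every element
  defn objective(%{w: w}), do: Nx.sum((w - 3) ** 2)

  # Gradient with respect to the whole params map: %{w: 2 * (w - 3)}
  defn gradient(params), do: grad(params, fn p -> objective(p) end)
end

{init_fn, update_fn} = Polaris.Optimizers.sgd(learning_rate: 0.1)

params = %{w: Nx.tensor([0.0, 10.0])}

{params, _opt_state} =
  Enum.reduce(1..100, {params, init_fn.(params)}, fn _step, {params, opt_state} ->
    grads = Quadratic.gradient(params)
    # update_fn takes (gradients, optimizer state, params) and returns scaled updates
    {updates, opt_state} = update_fn.(grads, opt_state, params)
    {Polaris.Updates.apply_updates(params, updates), opt_state}
  end)

IO.inspect(params.w)
```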

In the future we plan to support distributed training loops. We are also seeking ways to improve the performance of our training loops by running them entirely on native accelerators.

Installation

In order to use Axon, you will need Elixir installed. Then create an Elixir project via the mix build tool:

$ mix new my_app

Then add Axon to your dependencies:

def deps do
  [
    {:axon, "~> 0.6"}
  ]
end

You'll also likely want to include an Nx compiler such as EXLA for any practical deep learning workload:

def deps do
  [
    {:axon, "~> 0.6"},
    {:exla, "~> 0.6"},
  ]
end

Integration with other platforms

See Ortex, which provides full-blown compatibility for running ONNX models via ONNX Runtime bindings. Alternatively, see AxonONNX to convert ONNX models to Axon models whenever possible to achieve better integration with Nx.

Sponsors

DockYard

License

Copyright (c) 2021 Sean Moriarity

Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License.

axon's People

Contributors

alisinabh, arpieb, cigrainger, floatn, grzuy, hanrelan, ian-gl, joaogui1, joelpaulkoch, jonatanklosko, josevalim, kianmeng, matiasgali, meanderingstream, msluszniak, nduatik, nickgnd, nskins, ntodd, polvalente, preciz, ricardosantos-99, robinmonjo, rubysolo, seanmor5, t-rutten, tcoyze, vans163, wtedw, zeionara

axon's Issues

Support logging in training API

Somewhat related to #21

PyTorch Lightning supports building custom loggers for integration with third-party logging tools (like TensorBoard). We should include a similar API so training can be monitored in tools like TensorBoard.

Support custom layers

Modularity of layers is easy because we can use regular Elixir functions, but we need a solution for specifying and using custom trainable parameters in layers.

Support for non-DL models and unsupervised learning?

Is Axon only going to be focused on DL approaches to machine learning, or should it also include non-DL supervised learning approaches that leverage labeled datasets and could leverage Nx like SVM, decision trees, random forests, ensembles, etc?

Along the same lines, what about unsupervised learning approaches?

I guess what it comes down to is whether Axon is strictly positioned against DL frameworks like TF, Keras, PyTorch, etc., or whether we want it to encompass other statistical ML approaches like scikit-learn or Shogun ML.

Re-use subgraphs in compilation of combinators

Currently combinators like add, concatenate, etc. traverse back up entire subgraphs and treat them as completely different parts of the computation graph. For example:

x
|> dense(128)
|> add(x)

x's entire subgraph will appear in the resulting expression twice, even though it's actually the same thing. This will lead to extremely large expressions in complex models, among other possible issues. Additionally, for #28, we would end up returning multiple independent graphs even if the base of the model shares the same subgraph.

Model inference mode

Similar API for PyTorch: https://pytorch.org/docs/1.9.0/generated/torch.inference_mode.html?highlight=inference%20mode#torch.inference_mode

Keras inference mode is a property of the model (training=False or trainable=False I can't remember).

Unfortunately, simply doing:

{init_fn, predict_fn} = Axon.compile(model)

is not really good enough. There are important differences in behaviors of certain layers in inference mode including:

  • Dropout is disabled
  • BatchNorm uses the EMA of mean/variance calculated over the course of training

We may want to adjust mixed precision in inference mode as well.

It'd be nice to have an API that "freezes" the parameters into the model, drops training-only behavior in the forward pass, and maybe performs some slight optimizations to the forward expression. One option is to introduce a :mode option to compile that specifies inference versus training mode, although I'm not sold on it yet.

One thing to note on "freezing" parameters or inlining them into the forward pass: this pattern should probably be discouraged, or at the very least noted as possibly harmful:

{init_fn, predict_fn} = Axon.compile(model)

params = init_fn.()
inference_fn = &predict_fn.(params, &1)

While you now have a function with parameters inlined as constants, if you can't guarantee that the shape of your inputs is consistent (e.g. the batch size is the same), subsequent calls to inference_fn with different shapes or types will load a new, potentially very large executable onto your device, quickly leading to OOM. One way around this is to abstract the parameters away by placing them on the device somewhere and holding a reference to them in some global state.
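Axon.build/2 in later releases accepts a :mode option (:inference by default, or :train) that covers the layer-behavior part of this discussion. The sketch below assumes Axon 0.6 and builds a dropout-only model in both modes: in inference mode the layer passes its input through unchanged, while in train mode the result is a map whose :prediction either zeroes each element or scales it by 1 / (1 - rate).

```elixir
Mix.install([{:axon, "~> 0.6.0"}])

model =
  Axon.input("x", shape: {nil, 8})
  |> Axon.dropout(rate: 0.5)

{init_fn, infer_fn} = Axon.build(model, mode: :inference)
{_init_fn, train_fn} = Axon.build(model, mode: :train)

params = init_fn.(Nx.template({1, 8}, :f32), %{})
input = Nx.broadcast(1.0, {1, 8})

# Inference mode: dropout is the identity
inference_out = infer_fn.(params, input)

# Train mode: the result is a map; kept elements are scaled by 1 / (1 - 0.5) = 2.0
%{prediction: train_out} = train_fn.(params, input)
IO.inspect(train_out)
```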

Create mechanism for easy model composition

For now, we'll only consider how this should work in the model creation and execution API, but it will touch the training API as well.

Consider the models in a basic GAN:

generator =
  Axon.input({nil, 100})
  |> Axon.dense(128, activation: :tanh)
  |> Axon.dense(512, activation: :tanh)
  |> Axon.dense(784, activation: :tanh)
  |> Axon.reshape({1, 28, 28})

discriminator =
  Axon.input({nil, 1, 28, 28})
  |> Axon.dense(128, activation: :relu)
  |> Axon.dense(1, activation: :sigmoid)

In order to train, what you'd want to do is something like:

combined = compose(discriminator, generator)  # represents D(G(input))
step_d = Axon.Training.step(discriminator, :binary_cross_entropy, Axon.Optimizers.sgd(0.005))
step_g = Axon.Training.step(combined, :binary_cross_entropy, Axon.Optimizers.adam(0.01))

And then you can alternate using step_d and step_g to train on valid / fake images. Unfortunately, we currently don't support model composition in this sense - you can define functions generator and discriminator without an input block, but there's no way to cleanly determine which parameters belong to which model. Ideally, you'd be able to compose models in some way so that when you initialize, predict, train, etc. parameters are grouped:

combined = compose(discriminator, generator)
{d_params, g_params} = combined_params = Axon.init(combined)
Axon.predict(combined, combined_params)

{{d_params, g_params}, _} =
  combined
  |> Axon.Training.step(:binary_cross_entropy, Axon.Optimizers.adam(0.01))
  |> Axon.Training.train(inputs, targets)

Whatever the implementation is, it will involve adding some metadata to parameters that expresses their ownership by a given model. From an API perspective, one option is to introduce Axon.compose for composing Axon structs into a single model while preserving parameter information, although I'm not sure I love that right now.

Issues running mnist examples.

Trying to evaluate examples/mnist.exs or notebooks/mnist.livemd fails on the last (training) step with the errors below.

Environment was just installed:

MacBook Pro (16-inch, 2019) Intel based.

Erlang/OTP 24 [erts-12.0.2] [source] [64-bit] [smp:16:16] [ds:16:16:10] [async-threads:1] [jit] [dtrace]
Elixir 1.12.1 (compiled with Erlang/OTP 24)
Livebook 0.1.2

examples/mnist.exs output:

--------------------------------------------------
                      Model
==================================================
 Layer                    Shape        Parameters
==================================================
 input_7 ( input )        {nil, 784}   0
 dense_10 ( dense )       {nil, 128}   100480
 relu_11 ( relu )         {nil, 128}   0
 dropout_12 ( dropout )   {nil, 128}   0
 dense_15 ( dense )       {nil, 10}    1290
 softmax_16 ( softmax )   {nil, 10}    0
--------------------------------------------------


15:04:33.745 [info]  XLA service 0x7f855952ee20 initialized for platform Host (this does not guarantee that XLA will be used). Devices:

15:04:33.745 [info]    StreamExecutor device (0): Host, Default Version
** (ArgumentError) expected a %Nx.Tensor{} or a number, got: {32, 128}
    (nx 0.1.0-dev) lib/nx.ex:1170: Nx.to_tensor/1
    (nx 0.1.0-dev) lib/nx.ex:2561: Nx.element_wise_pred_op/3
    (axon 0.1.0-dev) lib/axon/layers.ex:1098: anonymous fn/1 in Axon.Layers."__defn:dropout__"/2
    (axon 0.1.0-dev) lib/axon/layers.ex:1095: Axon.Layers."__defn:dropout__"/2
    (axon 0.1.0-dev) lib/axon/compiler.ex:218: anonymous fn/7 in Axon.Compiler.recur_predict_fun/3
    (axon 0.1.0-dev) lib/axon/compiler.ex:196: anonymous fn/4 in Axon.Compiler.recur_predict_fun/3
    (axon 0.1.0-dev) lib/axon/training.ex:133: anonymous fn/6 in Axon.Training.step/4
    (nx 0.1.0-dev) lib/nx/defn/grad.ex:15: Nx.Defn.Grad.transform/3

notebooks/mnist.livemd output:

** (ArgumentError) expected a %Nx.Tensor{} or a number, got: {32, 10}
    (nx 0.1.0-dev) lib/nx.ex:1170: Nx.to_tensor/1
    (nx 0.1.0-dev) lib/nx.ex:2561: Nx.element_wise_pred_op/3
    (axon 0.1.0-dev) lib/axon/shared.ex:21: anonymous fn/1 in Axon.Shared."__defn:assert_shape!__"/2
    (nx 0.1.0-dev) lib/nx/defn/compiler.ex:307: Nx.Defn.Compiler.__remote__/4
    (axon 0.1.0-dev) lib/axon/losses.ex:164: Axon.Losses."__defn:categorical_cross_entropy__"/3
    (axon 0.1.0-dev) lib/axon/training.ex:134: anonymous fn/6 in Axon.Training.step/4
    (nx 0.1.0-dev) lib/nx/defn/grad.ex:15: Nx.Defn.Grad.transform/3
    (axon 0.1.0-dev) lib/axon/training.ex:80: anonymous fn/7 in Axon.Training.step/3

Validate input shape as part of model compilation

Invalid input shapes can lead to confusing/surprising error messages. Trivial example:

model = Axon.input({nil, 1, 32}) |> Axon.max_pool()
input = Nx.random_uniform({1, 32}) # oops, forgot a dimension

{init_fn, predict_fn} = Axon.compile(model)
predict_fn.(init_fn.(), input)

Results in:

** (ArgumentError) invalid window dimensions, rank of shape (2) does not match rank of window (3)

Which can be confusing. We should raise a clear error if the input shape/rank is incorrect.
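One possible check, written here as a hypothetical standalone helper rather than anything in Axon, compares the input against the declared shape (with nil meaning any size) and raises a descriptive error before the tensor reaches the pooling kernel. It assumes only Nx.

```elixir
Mix.install([{:nx, "~> 0.6"}])

defmodule ShapeCheck do
  # Hypothetical helper: `expected` is a shape tuple in which `nil`
  # accepts any size (typically the batch dimension)
  def validate!(%Nx.Tensor{} = tensor, expected) when is_tuple(expected) do
    actual = Nx.shape(tensor)
    same_rank = tuple_size(actual) == tuple_size(expected)

    # Only compare dimensions when the ranks agree
    dims_match =
      same_rank and
        Enum.all?(Enum.zip(Tuple.to_list(expected), Tuple.to_list(actual)), fn {exp, act} ->
          is_nil(exp) or exp == act
        end)

    if not dims_match do
      raise ArgumentError,
            "invalid input shape: expected #{inspect(expected)} " <>
              "(nil matches any size), got #{inspect(actual)}"
    end

    tensor
  end
end

ok = ShapeCheck.validate!(Nx.iota({1, 1, 32}), {nil, 1, 32})

# The input from the issue, missing a dimension
error =
  try do
    ShapeCheck.validate!(Nx.iota({1, 32}), {nil, 1, 32})
  rescue
    e in ArgumentError -> e.message
  end

IO.puts(error)
```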

Implement a high level optimization API

The CIFAR example is updated and demonstrates the usage of the low-level constructs in updates.ex to create more advanced optimizers. What's required is essentially the same as what we've already implemented with the layer API. We need to construct optimizer combinators that:

  1. Accept a model (and possibly hyperparameters)
  2. Initialize state w.r.t each parameter in the model
  3. Apply updates according to some transformations defined in updates.ex

I propose we follow an approach similar to the one taken for the layer API and have the following macros in an Axon.Optimizer namespace:

  • init(optimizer, model, opts \\ []) - initializes the optimizer with state (e.g. first and second moment of an update)
  • apply_updates(optimizer, optimizer_state, gradients, params) / apply_gradients / step - applies updates and returns new parameters and updated optimizer state

With this approach, we can implement common optimizers as regular Elixir functions and then users can apply them trivially from within defn and def. I think this leaves us with room to play. It may be that in the future optimizers get taken out of Axon and placed in a separate library similar to optax.

An alternative approach is to implement optimizers as behaviours, although this falls a bit short because we still need to pattern match on inputs, and then we'd have to implement macros for each "behaviour".
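The combinator design described here is close to what Polaris, the optimization library named in the README above, exposes. Assuming Polaris 0.1, Adam can be assembled by piping Polaris.Updates.scale_by_adam/1 into Polaris.Updates.scale/2: init_fn builds per-parameter moment state, and update_fn turns gradients into updates that Polaris.Updates.apply_updates/2 adds to the parameters. On the first step, Adam's bias-corrected update is close to the sign of the gradient, so each element moves by roughly the step size.

```elixir
Mix.install([{:polaris, "~> 0.1"}])

# Adam as a composition of two transformations:
# scale_by_adam normalizes gradients by bias-corrected moment estimates,
# scale then multiplies the result by the (negative) step size
{init_fn, update_fn} =
  Polaris.Updates.scale_by_adam([])
  |> Polaris.Updates.scale(-0.01)

params = %{w: Nx.tensor([0.0, 10.0])}

# Optimizer state (first/second moments and step count) for each parameter
opt_state = init_fn.(params)

# Gradients of sum((w - 3)^2) at w = [0.0, 10.0] are 2 * (w - 3) = [-6.0, 14.0]
grads = %{w: Nx.tensor([-6.0, 14.0])}

{updates, _opt_state} = update_fn.(grads, opt_state, params)
new_params = Polaris.Updates.apply_updates(params, updates)
IO.inspect(new_params.w)
```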

Integrate named tensors

On the high-level API, we can integrate named tensors by specifying the expected names on input:

Axon.input(batch: nil, channels: 3, height: 224, width: 224)

Then we'll need to consider how these are transformed through the network. For other layers, we could consider adding an option to specify output features as keywords:

Axon.input(batch: nil, pixels: 784)
|> Axon.dense(features: 128, activation: :relu)
|> Axon.dense(label: 10, activation: :softmax)
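Nx already supports named axes, which such an API would build on. As a small sketch assuming only Nx, axes can be named when a tensor is created and then referenced by name in reductions; the remaining axis keeps its name.

```elixir
Mix.install([{:nx, "~> 0.6"}])

# Two flattened 28x28 "images", with both axes named
images = Nx.iota({2, 784}, type: :f32, names: [:batch, :pixels])

# Reductions can refer to axes by name rather than position
per_image_sum = Nx.sum(images, axes: [:pixels])

IO.inspect(Nx.names(per_image_sum))
```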

Axon.Training.train gives an error

Having a small issue. I try to run Axon.Training.train and it gives an error

function Nx.dot/4 is undefined or private. Did you mean one of:

  * dot/2
  * dot/6

Could you please help.

Thanks

Remove dimensional suffixes

After some deliberation I have decided that it's best to drop dimensional suffixes to simplify the API. They're not necessary with most tensor compilers (optimal kernels are generated/compiled based on shape anyway) and I believe it makes more sense to settle for a simpler API. So:

conv1d, conv2d, conv3d -> conv
...and so on...

Allow options in activation functions

Some activation functions support options (e.g. LeakyReLU supports an alpha option). As of now there's no way to include this using the high-level API.

Unify dropout layers

As of now we have 6 (soon to be 7) dropout layers:

  • dropout - general dropout
  • spatial_dropout1d, spatial_dropout2d, spatial_dropout3d - spatial dropout, which masks across entire input feature channels
  • alpha_dropout - rather than 0 masking, computes a mask that maintains the mean and standard deviation of the input
  • feature_alpha_dropout - pretty much like spatial dropout, but masks with negative selu instead of 0
  • dropblock - computes a mask with contiguous regions across feature channels (e.g. this will mask large chunks of pixels in an image rather than random pixels)

We can extract the following pattern from each of these dropout layers:

  • rng_state - not maintained in our current implementation, discussion for another issue
  • noise_shape - spatial layers compute the noise shape such that masked layers will be implicitly broadcasted across feature channels
  • mask - how to compute the value of the mask
  • shift_and_scale - regular dropout "scales" by (1 / (1 - rate)), alpha dropout shifts and scales to maintain mean/variance

Knowing this, I propose the following generalized dropout method:

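defn dropout(input, opts \\ []) do
  opts = keyword!(opts, [:rate, noise_shape: Nx.shape(input), mask_values: 0.0, gamma: 1.0, beta: 0.0])
  # keep each element with probability (1 - rate)
  mask = Nx.less(Nx.random_uniform(opts[:noise_shape]), 1 - opts[:rate])
  # dropped positions take mask_values (0.0 for regular dropout)
  x = Nx.select(mask, input, opts[:mask_values])
  # scale and shift: x * gamma + beta
  x |> Nx.multiply(opts[:gamma]) |> Nx.add(opts[:beta])
end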

Dropblock will more than likely need to be considered separately as it requires some advanced indexing, but this will simplify the layer API and allow users to explore custom dropout layers.
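Against the newer explicit-key Nx.Random API, the same generalization could look like the sketch below. It assumes Nx 0.6 or later; GeneralizedDropout is a hypothetical module, the PRNG key is passed in explicitly instead of living in hidden state, the noise shape is fixed to the input shape for brevity, and gamma and beta are supplied by the caller (gamma = 1 / (1 - rate) with beta = 0.0 recovers regular dropout).

```elixir
Mix.install([{:nx, "~> 0.6"}])

defmodule GeneralizedDropout do
  import Nx.Defn

  # Hypothetical generalized dropout: mask, replace, then scale and shift
  defn dropout(input, key, opts \\ []) do
    opts = keyword!(opts, [:rate, mask_value: 0.0, gamma: 1.0, beta: 0.0])

    # Uniform noise in [0, 1) with the same shape as the input
    {noise, _next_key} = Nx.Random.uniform(key, 0.0, 1.0, shape: Nx.shape(input))

    # Keep each element with probability (1 - rate)
    keep = noise < 1 - opts[:rate]

    # Dropped elements take mask_value; then apply x * gamma + beta
    x = Nx.select(keep, input, opts[:mask_value])
    x * opts[:gamma] + opts[:beta]
  end
end

key = Nx.Random.key(42)
input = Nx.broadcast(1.0, {1000})

# Regular dropout with rate 0.5: kept elements are scaled by 1 / (1 - 0.5) = 2.0
out = GeneralizedDropout.dropout(input, key, rate: 0.5, gamma: 2.0)
```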

Model inspection does not enforce correct layer ordering

This also has some implications for compiling layers like add, which reference entire subgraphs. The current inspection traverses the entire subgraph, so layers are displayed out of order for complex models. See examples/resnet.exs for an example.

Add model import/export API

Need ability to serialize models to/from external formats. Model serialization is serialization of the actual computation graph. We should also have the ability to save and load model parameters, but I believe part of that discussion needs to happen upstream with a common Nx tensor serialization format. See e.g. elixir-nx/nx#354

Parameters declared out of order initialize incorrectly

This stems from our current (fragile) way of ensuring parameters are in the correct order by generating a unique ID for each parameter like:

System.unique_integer([:positive, :monotonic])

and then sorting on the ID because they are guaranteed to be ordered. Internally this isn't that much of a problem, but with the addition of custom layers, it's possible to run into the following case:

bias = Axon.param("bias", {}, initializer: :zeros)
weight = Axon.param("weight", {}, initializer: :ones)

Axon.layer(x, fn x, w, b -> Nx.add(Nx.multiply(x, w), b) end, output_shape, [weight, bias])

During initialization, weight will be initialized as bias and bias will be initialized as weight. When they are the same shapes as in this instance, I think this will lead to silent bugs. We could document, but I believe our current method is fragile and needs to be refactored altogether.

We need a better way to track parameters at each layer. This is really an issue in the sense that we can't just declare a parameter and use it in an operation like:

def dense(%Axon{output_shape: shape} = x, units) do
  w = Axon.param("weight", {elem(shape, 1), units})
  b = Axon.param("bias", {1, units})
  Nx.add(Nx.dot(x, w), b)
end

So we need a better way to ensure parameters are initialized and used in the correct places. I believe the best option is to ensure each parameter has a unique name, and then pass params as a map to defn - which would make this an upstream issue.
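Name-keyed parameters are how this can look in practice. The sketch below assumes Axon 0.6, where Axon.param/3 declares a named parameter, Axon.layer/3 passes parameters to the implementation in the order they are listed, and initialized parameters live in a map keyed by layer name and then parameter name, so declaring bias before weight no longer matters. Affine is a hypothetical module.

```elixir
Mix.install([{:axon, "~> 0.6.0"}])

defmodule Affine do
  import Nx.Defn

  # Implementation receives the input, then parameters in the order
  # given to Axon.layer/3, then the layer options
  defn affine_impl(x, weight, bias, _opts \\ []) do
    x * weight + bias
  end

  def affine(%Axon{} = input, opts \\ []) do
    # Declared "out of order" on purpose: bias before weight
    bias = Axon.param("bias", fn _ -> {} end, initializer: :zeros)
    weight = Axon.param("weight", fn _ -> {} end, initializer: :ones)
    Axon.layer(&affine_impl/4, [input, weight, bias], name: opts[:name], op_name: :affine)
  end
end

model = Axon.input("x", shape: {nil, 3}) |> Affine.affine(name: "affine")

{init_fn, predict_fn} = Axon.build(model)
params = init_fn.(Nx.template({1, 3}, :f32), %{})

# weight = 1.0 and bias = 0.0, so the layer returns its input
out = predict_fn.(params, Nx.tensor([[1.0, 2.0, 3.0]]))
```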

Unify normalization layers

The current API has 4 normalization layers:

  • Batch Normalization
  • Instance Normalization
  • Group Normalization
  • Layer Normalization

All of these implementations are built on a fundamental formula:

defn normalize(input, mean, variance, gamma, bias, opts \\ []) do
  opts = keyword!(opts, epsilon: 1.0e-6)
  scale =
    variance
    |> Nx.add(opts[:epsilon])
    |> Nx.rsqrt()
    |> Nx.multiply(gamma)

  input
  |> Nx.subtract(mean)
  |> Nx.multiply(scale)
  |> Nx.add(bias)
end

They differ, however, in how they compute the mean and variance across the input:

  • Batch Normalization - calculated for each individual channel across all samples and spatial dimensions.
    • reduction_axes: [:batch, :height, :width, ...]
  • Instance Normalization - calculated for each individual channel for each individual sample across both spatial dimensions.
    • reduction_axes: [:height, :width, ...]
  • Layer Normalization - calculated for each individual sample across all channels and both spatial dimensions.
    • reduction_axes: [:channels, :height, :width, ...]
  • Group Normalization - calculated across groups of channels and both spatial dimensions for the given group size.
    • reduction_axes: [:groups, :height, :width, ...] (after some reshaping to get :groups)

Additionally, some of these layers are stateful (batch/instance norm) and some are stateless (layer/group norm). Stateful normalization layers return the transformed input and a running average mean and variance adjusted with momentum, relying on the state to compute the next iteration of normalization. Stateless normalization layers return just the transformed input.

In order to unify these normalization layers under the lower-level functional API, rather than have individualized functions for each layer we will instead have:

In the layers API:

  • normalize - see above

In a separate module:

  • batch_norm_stats(input, ra_mean, ra_var, opts \\ []) - returns {mean, var}
  • instance_norm_stats(input, ra_mean, ra_var, opts \\ []) - returns {mean, var}
  • group_norm_stats(input, opts \\ []) - returns {mean, var}
  • layer_norm_stats(input, opts \\ []) - returns {mean, var}

In a separate module (probably an updates.ex or something that has gradient/parameter transforms):

  • ema(x, momentum) - returns a scaled x, exponential moving average

I think this limits code duplication and still enables us to easily build these normalization layers into a high-level API.
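The unified design can be sketched in a few lines of Nx. The module below is hypothetical and assumes only Nx: normalize/6 is the shared formula from this issue, the two stats helpers differ only in their reduction axes for a {batch, features} input, and update_running_average/3 shows the momentum update a stateful layer would apply to its running statistics.

```elixir
Mix.install([{:nx, "~> 0.6"}])

defmodule NormSketch do
  import Nx.Defn

  # Shared formula: (input - mean) * rsqrt(variance + epsilon) * gamma + bias
  defn normalize(input, mean, variance, gamma, bias, opts \\ []) do
    opts = keyword!(opts, epsilon: 1.0e-6)
    scale = Nx.rsqrt(variance + opts[:epsilon]) * gamma
    (input - mean) * scale + bias
  end

  # Batch-norm style statistics: per feature, reduced over the batch axis
  defn batch_norm_stats(input) do
    {Nx.mean(input, axes: [0], keep_axes: true), Nx.variance(input, axes: [0], keep_axes: true)}
  end

  # Layer-norm style statistics: per sample, reduced over the feature axis
  defn layer_norm_stats(input) do
    {Nx.mean(input, axes: [1], keep_axes: true), Nx.variance(input, axes: [1], keep_axes: true)}
  end

  # Exponential moving average of a statistic, as kept by stateful layers
  defn update_running_average(running, batch_stat, momentum) do
    momentum * running + (1 - momentum) * batch_stat
  end
end

x = Nx.tensor([[1.0, 2.0, 3.0], [4.0, 6.0, 8.0]])

{bn_mean, bn_var} = NormSketch.batch_norm_stats(x)
{ln_mean, ln_var} = NormSketch.layer_norm_stats(x)

layer_normed = NormSketch.normalize(x, ln_mean, ln_var, 1.0, 0.0)

running_mean = NormSketch.update_running_average(Nx.broadcast(0.0, {1, 3}), bn_mean, 0.9)
```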

Add additional optimizers and updates

Optimizers:

  • lamb
  • yogi
  • noisy_sgd
  • fromage
  • adamw

Updates:

  • scale_by_yogi
  • add_decayed_weights
  • scale_by_trust_ratio
  • add_noise

Also requires changing update functions from update(updates, state) to update(updates, params, state)

Functionality Roadmap

An issue to track some baseline functionality:

Activations

  • celu
  • elu
  • exp
  • gelu
  • hard_tanh
  • hard_sigmoid
  • hard_silu/hard_swish
  • leaky_relu
  • log_sigmoid
  • relu
  • relu6
  • selu
  • sigmoid
  • silu
  • softmax
  • softplus
  • softsign
  • tanh

Initializers

  • glorot_uniform
  • glorot_normal
  • he_normal
  • he_uniform
  • lecun_uniform
  • lecun_normal
  • normal
  • ones
  • orthogonal - requires elixir-nx/nx#174
  • uniform
  • zeros

Loss Functions

  • binary_crossentropy
  • categorical_crossentropy
  • categorical_hinge
  • cosine_similarity - requires elixir-nx/nx#174
  • ctc
  • hinge
  • kl_divergence
  • log_cosh
  • margin_ranking
  • mean_absolute_error
  • mean_squared_error
  • poisson
  • soft_margin

Metrics

  • accuracy
  • mean_squared_error - requires defndelegate
  • mean_absolute_error - requires defndelegate
  • precision
  • recall
  • sensitivity
  • specificity

Optimizers

Optax style transformations:

  • scale
  • scale_by_adam
  • scale_by_rss
  • scale_by_belief
  • scale_by_rms
  • trace
  • clip
  • clip_by_global_norm
  • centralize
  • scale_by_trust_ratio
  • scale_by_schedule
  • scale_by_radam
  • scale_by_stddev

Schedules

  • polynomial_schedule
  • exponential_decay_schedule
  • cosine_decay_schedule
  • constant_schedule

Layers

For now, just functional implementations resembling torch.nn.functional or tf.nn:

Linear Layers

Convolutional Layers

  • conv
  • conv_transpose
  • depthwise_conv
  • separable_conv2d
  • separable_conv3d

Pooling Layers

  • avg_pool
  • max_pool
  • lp_pool
  • adaptive_avg_pool
  • adaptive_max_pool
  • adaptive_lp_pool
  • global_avg_pool
  • global_max_pool
  • global_lp_pool

Normalization Layers

  • batch_norm
  • group_norm
  • instance_norm
  • layer_norm

Dropout Layers

  • dropout
  • alpha_dropout
  • feature_alpha_dropout
  • spatial_dropout

Attention Layers

  • dot_product_attention - requires elixir-nx/nx#182
  • additive_attention - requires repeat/gather on Nx

Visual Layers

  • resize

We can drop off dimensional suffixes in favor of generic implementations too.

Add high-level layers

Missing so far:

  • conv_transpose
  • transpose
  • reshape (convenience to ignore batch dimensions)
  • pad (convenience to ignore batch dimensions)
  • concatenate
  • add
  • subtract
  • multiply

Other layers tracked in #1.

Ability to disable/limit training messages

While training messages are helpful, in some cases it would be nice to be able to either limit the verbosity or disable the output completely.

Use cases:

  • When running in a Docker container using an external IDE like VSCode, the stream of per-batch messages bogs down training run output, in turn bogging down the runtime of the app (e.g. MNIST takes ~2s/epoch running directly within a Docker container, ~20-30s/epoch running as a remote container against GPUs). In this case per-epoch data would be great, per-batch not so much
  • When running something like an RL model where you're more interested in logging the completion status, state and rewards of an agent per traversal than the incremental loss/accuracy/validation, it's often desirable to disable logging completely.

Maybe being able to set a reporting level like :per_batch, :per_epoch, or :none?

Add shape / type assertions

We should add shape/type assertions to layers to provide more specific error messages than falling back on whatever Nx may give.

Implement dynamic unrolling of RNNs

We currently unroll the RNN at compile-time rather than compiling RNNs using a loop. Statically unrolling can be more efficient for short sequences at the expense of more memory consumption; however, we will need the ability to dynamically unroll. Requires elixir-nx/nx#122
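In defn, a dynamically unrolled RNN would use a while loop, which compiles the cell once and iterates at runtime instead of emitting one copy of the cell per time step. The sketch below is hypothetical and assumes Nx 0.6 or later: a single-layer tanh RNN over a {seq_len, features} input that returns the final hidden state, where indexing inputs with the loop-counter tensor (inputs[i]) is a dynamic slice.

```elixir
Mix.install([{:nx, "~> 0.6"}])

defmodule DynamicRNN do
  import Nx.Defn

  # inputs: {seq_len, features}; w: {features, hidden}; u: {hidden, hidden}
  defn final_state(inputs, w, u) do
    seq_len = Nx.axis_size(inputs, 0)
    h0 = Nx.broadcast(0.0, {Nx.axis_size(u, 0)})

    # One compiled loop body, executed seq_len times at runtime
    {_i, h, _inputs, _w, _u} =
      while {i = 0, h = h0, inputs, w, u}, i < seq_len do
        # h_t = tanh(x_t . W + h_{t-1} . U)
        h = Nx.tanh(Nx.dot(inputs[i], w) + Nx.dot(h, u))
        {i + 1, h, inputs, w, u}
      end

    h
  end
end

inputs = Nx.tensor([[0.5], [0.2]])
w = Nx.tensor([[1.0]])
u = Nx.tensor([[1.0]])

# Two steps: h1 = tanh(0.5), h2 = tanh(0.2 + h1)
h = DynamicRNN.final_state(inputs, w, u)
```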

Add more examples

Willing to accept examples on different datasets and models to demonstrate different parts of the Axon API and to demonstrate Axon's viability in the ecosystem. The TensorFlow guides are a great place to look for different datasets and problems. If you're blocked on any specific issue feel free to comment on the relevant issue with your use case :)

Integrate validation and testing into training API

More than likely this can be almost identical to the PyTorch approach, where we define validation_step and test_step; validation can then optionally be included with training, and testing run with a separate Axon.test.

Add penalties for parameter regularization

At the lower level, this should be supported with an Axon.Regularization or Axon.Penalty module, with L1/L2 penalties implemented as defn. Penalties can then be added in custom objective functions passed to step, or through a high-level interface in step.
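A minimal sketch of such penalties as numerical definitions, assuming only Nx. The Penalty module name echoes the one floated in this issue but is hypothetical here, and lambda is the regularization strength; the objective shows a penalty added to a data loss inside a custom objective function.

```elixir
Mix.install([{:nx, "~> 0.6"}])

defmodule Penalty do
  import Nx.Defn

  # L1 penalty: lambda * sum(|w|)
  defn l1(w, lambda), do: lambda * Nx.sum(Nx.abs(w))

  # L2 penalty: lambda * sum(w^2)
  defn l2(w, lambda), do: lambda * Nx.sum(w * w)

  # Example objective: mean squared error plus an L2 penalty on the weights
  defn objective(y_true, y_pred, w, lambda) do
    Nx.mean((y_true - y_pred) ** 2) + l2(w, lambda)
  end
end

w = Nx.tensor([1.0, -2.0])

# 0.5 * (1 + 2) = 1.5 and 0.5 * (1 + 4) = 2.5
l1 = Penalty.l1(w, 0.5)
l2 = Penalty.l2(w, 0.5)

# Perfect predictions, so the objective equals the L2 penalty
loss = Penalty.objective(Nx.tensor([1.0]), Nx.tensor([1.0]), w, 0.5)
```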

Add Recurrent Layers

Requires a solution for managing layers that maintain state and/or return multiple outputs:

  • gru
  • lstm
  • conv_lstm
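The state problem is easiest to see in an LSTM cell, which returns an output and also a carry that must be threaded to the next step. This scalar, plain-Elixir sketch uses the standard input/forget/cell/output gate formulation; the `{output, {c, h}}` return convention and the parameter map layout are assumptions for illustration, not Axon's design.

```elixir
defmodule LSTMSketch do
  defp sigmoid(x), do: 1.0 / (1.0 + :math.exp(-x))

  # Scalar LSTM cell. `p` maps each gate (:i, :f, :g, :o) to {w_x, w_h, b}.
  # Returns {output, new_carry}: two values the caller must keep track of.
  def cell(x, {c, h}, p) do
    gate = fn name ->
      {wx, wh, b} = Map.fetch!(p, name)
      wx * x + wh * h + b
    end

    i = sigmoid(gate.(:i))
    f = sigmoid(gate.(:f))
    g = :math.tanh(gate.(:g))
    o = sigmoid(gate.(:o))
    c = f * c + i * g
    h = o * :math.tanh(c)
    {h, {c, h}}
  end

  # Runs the cell over a sequence: all outputs plus the final carry.
  def run(xs, carry, p) do
    Enum.map_reduce(xs, carry, fn x, carry -> cell(x, carry, p) end)
  end
end
```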

Introduce a high-level layer API

The next step after #1 is to implement higher-level constructs on top of the lower-level functional implementations. The goal of the higher-level API is to provide abstractions for building neural networks that:

  • limit the overhead of writing a neural network from scratch
  • are easy to understand, especially for beginners
  • are flexible enough for more complicated architectures (ResNets, GANs, etc.)
  • can be represented as composition of Nx functions for low-level JIT/AOT compilation OR can be represented as higher-level constructs for compilation using specific NN compilers (ONNX, TFLite, etc.)

For simplicity, we'll leave the discussion of efficient handling of network state to a later issue. This issue will only focus on the architecture/representation of a network.

Axon Struct

We will introduce an %Axon{} struct that represents a constructed network. For now, the struct will have the following attributes:

  • :input - input to this layer/model/etc., or in the case of a literal input layer some metadata
  • :shape - shape of the layer's parameters, can be inferred from :input, we can also allow input layer shapes to have nil batch dimensions to represent arbitrary sized batches
  • :transformation - how does this layer transform the input, can be an atom (like :dense) which resolves to an already implemented layer, or a numerical definition for arbitrarily complex transformations
  • other - there's a lot of other metadata that should be included in some way: :initializer and initializer options, :activation and activation options, layer specific options, possibly constraints, callbacks, etc.

The struct would be built up with calls to high-level functions in the root Axon namespace. For example, MNIST:

model =
  Axon.input({nil, 784})
  |> Axon.dense(128, activation: :relu)
  |> Axon.dense(10, activation: :softmax)

Then a model would be compiled, perhaps like:

compiled_model = Axon.compile(model, options)

At a minimum, Axon.compile initializes parameters and returns a compiled function using whatever backend the user specifies. It can also be abstracted away in some higher-level training logic, but that's a discussion for another issue.
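A plain-Elixir sketch of the struct and a compile step under these assumptions: each layer points at its parent through `:input`, the output shape is inferred from the parent, and "compiling" means initializing one `{kernel, bias}` pair per dense layer and returning a predict function. `AxonSketch` is a hypothetical stand-in using lists instead of Nx tensors; it is not Axon's implementation, and the constant 0.1 kernels only keep the example deterministic.

```elixir
defmodule AxonSketch do
  # Stand-in for the proposed %Axon{} struct; opts holds e.g. :activation.
  defstruct [:input, :shape, :transformation, opts: []]

  def input(shape), do: %__MODULE__{shape: shape, transformation: :input}

  # Output shape is inferred from the parent's batch dimension.
  def dense(%__MODULE__{shape: {batch, _features}} = parent, units, opts \\ []) do
    %__MODULE__{input: parent, shape: {batch, units}, transformation: :dense, opts: opts}
  end

  # Returns {params, predict_fun}, where predict_fun.(params, example).
  def compile(model) do
    layers = dense_layers(model)
    params = Enum.map(layers, &init_params/1)

    predict = fn params, x ->
      layers
      |> Enum.zip(params)
      |> Enum.reduce(x, fn {layer, p}, acc -> apply_dense(acc, p, layer.opts[:activation]) end)
    end

    {params, predict}
  end

  # Walks back from the output; returns layers in input-to-output order.
  defp dense_layers(%__MODULE__{transformation: :input}), do: []

  defp dense_layers(%__MODULE__{transformation: :dense, input: parent} = layer),
    do: dense_layers(parent) ++ [layer]

  # Deterministic init: kernel is in_features rows of out_features 0.1s.
  defp init_params(%__MODULE__{input: %__MODULE__{shape: {_, in_f}}, shape: {_, out_f}}) do
    {List.duplicate(List.duplicate(0.1, out_f), in_f), List.duplicate(0.0, out_f)}
  end

  # out_j = bias_j + sum_i x_i * kernel_ij, then the activation.
  defp apply_dense(x, {kernel, bias}, activation) do
    kernel
    |> Enum.zip(x)
    |> Enum.reduce(bias, fn {row, xi}, acc -> Enum.zip_with(acc, row, &(&1 + xi * &2)) end)
    |> activate(activation)
  end

  defp activate(out, nil), do: out
  defp activate(out, :relu), do: Enum.map(out, &max(&1, 0.0))

  defp activate(out, :softmax) do
    exps = Enum.map(out, &:math.exp/1)
    total = Enum.sum(exps)
    Enum.map(exps, &(&1 / total))
  end
end
```

The MNIST model above maps directly onto this sketch as `AxonSketch.input({nil, 784}) |> AxonSketch.dense(128, activation: :relu) |> AxonSketch.dense(10, activation: :softmax)`.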

Structuring the network this way makes arbitrary compilation easy, keeps the API simple and easy to understand, and is flexible enough for complex models. For example, a GAN:

generator =
  Axon.input({nil, 128})
  |> Axon.dense(256, activation: :tanh)
  |> Axon.dense(512, activation: :tanh)
  |> Axon.dense(784, activation: :tanh)
  |> Axon.reshape({28, 28})

# The discriminator consumes the generator's 28x28 images, so flatten first
discriminator =
  Axon.input({nil, 28, 28})
  |> Axon.flatten()
  |> Axon.dense(128, activation: :relu)
  |> Axon.dense(1, activation: :sigmoid)

combined =
  generator
  |> Axon.compose(discriminator)

Layers

High-level layers are tied directly to their functional implementations in Axon.Layers. Some of them have layer specific options which can be passed during layer creation.

Combinators

We will use combinators similar to: https://thinc.ai/docs/api-layers/#combinators to represent more complex relationships between layers. At a minimum we'd have:

  • compose - function composition
  • add - adds layers
  • concat - concats layers
  • residual - residual output
  • parallel/split - something to represent multiple-model outputs
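To show how these combinators relate, here is a sketch that models a layer as a plain function from a vector (a list of numbers) to a vector. This is an assumption made to keep the example runnable; in Axon they would build up %Axon{} structs rather than closures.

```elixir
defmodule Combinators do
  # compose: feed f's output into g
  def compose(f, g), do: fn x -> g.(f.(x)) end

  # add: element-wise sum of two layers' outputs
  def add(f, g), do: fn x -> Enum.zip_with(f.(x), g.(x), &+/2) end

  # concat: join two layers' outputs
  def concat(f, g), do: fn x -> f.(x) ++ g.(x) end

  # residual: x + f(x)
  def residual(f), do: fn x -> Enum.zip_with(x, f.(x), &+/2) end

  # parallel: run several layers on the same input, keep every output
  def parallel(fs), do: fn x -> Enum.map(fs, & &1.(x)) end
end
```

With `double` and `inc` as layers, `Combinators.compose(double, inc).([1, 2])` gives `[3, 5]` and `Combinators.residual(double).([1, 2])` gives `[3, 6]`.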

Implement model inspection

For those familiar with Keras model.summary(), it can be useful to see what your compiled model looks like in terms of shape at each layer, number of trainable parameters, layer names, etc. Because a model is just an Axon struct, we can implement something similar through the Inspect protocol. I personally like Keras style model summaries; however, I'm open to other ideas about how to render a model summary during inspection.
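A minimal sketch of a Keras-style summary via the Inspect protocol. The hypothetical `ModelSummary` struct stores one `{name, output_shape, param_count}` row per layer, and `dense_stack/2` counts in * out kernel weights plus out biases per dense layer; a real implementation would derive these rows from the %Axon{} graph.

```elixir
defmodule ModelSummary do
  defstruct rows: []

  # Builds rows for a stack of dense layers on a flat input of `in_features`.
  def dense_stack(in_features, units) do
    {rows, _} =
      units
      |> Enum.with_index()
      |> Enum.map_reduce(in_features, fn {out, i}, in_f ->
        {{"dense_#{i}", {nil, out}, in_f * out + out}, out}
      end)

    %__MODULE__{rows: [{"input", {nil, in_features}, 0} | rows]}
  end
end

defimpl Inspect, for: ModelSummary do
  # Returns a plain string, which is a valid Inspect.Algebra document.
  def inspect(%ModelSummary{rows: rows}, _opts) do
    header = String.pad_trailing("Layer", 12) <> String.pad_trailing("Shape", 14) <> "Params"

    lines =
      Enum.map(rows, fn {name, shape, params} ->
        String.pad_trailing(name, 12) <>
          String.pad_trailing(Kernel.inspect(shape), 14) <> Integer.to_string(params)
      end)

    total = rows |> Enum.map(&elem(&1, 2)) |> Enum.sum()
    Enum.join([header | lines] ++ ["Total params: #{total}"], "\n")
  end
end
```

For the MNIST model, `inspect(ModelSummary.dense_stack(784, [128, 10]))` lists 100480 parameters for the first dense layer, 1290 for the second, and a total of 101770.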

Introduce a training API

Given the core components implemented in #1, we can implement an efficient, simple, but flexible training API. I am proposing an API similar to trax.supervised.training under the Axon.Training namespace that represents a general supervised training pipeline for models.

Training Behaviour

We can consider the training loop to take the following inputs:

  • model_state - parameters, discussed in a future issue for state management and model initialization
  • optimizer - encapsulates both optimizer state, and the update step, discussed in a future issue
  • train_objective (note I'm not using Task to avoid confusion with Elixir tasks) - an objective (loss) function parameterized by the input model, such that grad(model_state, objective) differentiates the objective with respect to the model parameters
  • eval_objective - metrics for evaluating model performance on validation sets (loss, accuracy, mse, mae, etc.) and some associated state for monitoring training progress
  • dataset - inputs and labels
  • options - miscellaneous

and to perform the following algorithm (this is half pseudocode, half Elixir):

def train(model_state, optimizer, train_objective, eval_objective, dataset, options) do
  epochs = options[:epochs]

  for _epoch <- 1..epochs do
    for {input, target} <- dataset do
      gradients = grad(model_state, train_objective)
      update(model_state, gradients, optimizer)
    end
    evaluate(model_state, eval_objective)
  end
end

It's common to use metrics as an easy way to monitor training, so we can introduce a metrics object which encapsulates metric state and metric evaluation functions:

def train(model_state, optimizer, train_objective, eval_objective, dataset, options) do
  epochs = options[:epochs]

  for _epoch <- 1..epochs do
    for {input, target} <- dataset do
      gradients = grad(model_state, train_objective)
      update(model_state, gradients, optimizer)
      metrics(model_state, train_objective)
    end
    evaluate(model_state, eval_objective)
  end
end

We can further extend this API with before_x and after_x callbacks (writing checkpoints, plotting graphs, etc.):

def train(model_state, optimizer, train_objective, eval_objective, dataset, options) do
  epochs = options[:epochs]

  for _epoch <- 1..epochs do
    before_epoch(model_state)

    for {input, target} <- dataset do
      before_batch(model_state)

      gradients = grad(model_state, train_objective)
      update(model_state, gradients, optimizer)
      metrics(model_state, train_objective)

      after_batch(model_state)
    end
    evaluate(model_state, eval_objective)

    after_epoch(model_state)
  end
end

For more flexibility, we can extract each train step into its own function; this makes it easier to write custom training loops:

def train_on_batch(batch, model_state, train_objective, optimizer) do
  before_batch(model_state)

  gradients = grad(model_state, train_objective(model_state, batch))
  update(model_state, gradients, optimizer)
  metrics(train_objective(model_state, batch))

  after_batch(model_state)
end

def train(model_state, optimizer, train_objective, eval_objective, dataset, options) do
  epochs = options[:epochs]
  # :unlimited consumes the whole dataset each epoch
  steps = options[:steps] || :unlimited
  batches = if steps == :unlimited, do: dataset, else: Stream.take(dataset, steps)

  for _epoch <- 1..epochs do
    before_epoch(model_state)

    for batch <- batches do
      train_on_batch(batch, model_state, train_objective, optimizer)
    end
    evaluate(model_state, eval_objective)

    after_epoch(model_state)
  end
end
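Because Elixir data is immutable, a runnable version of this loop has to thread state through `Enum.reduce` rather than update it in place. The sketch below assumes a linear model y = w * x + b with mean squared error standing in for train_objective, hand-written gradients in place of `grad`, plain SGD in place of the optimizer, and callbacks that return the new state directly (rather than `{:ok, state}`). `LoopSketch` and its options are hypothetical.

```elixir
defmodule LoopSketch do
  # One step: before_batch, SGD update on the batch's mean squared error, after_batch.
  def train_on_batch(batch, state, lr, callbacks) do
    state = callbacks.before_batch.(state)
    n = length(batch)

    {gw, gb} =
      Enum.reduce(batch, {0.0, 0.0}, fn {x, y}, {gw, gb} ->
        err = state.w * x + state.b - y
        {gw + 2 * err * x / n, gb + 2 * err / n}
      end)

    state = %{state | w: state.w - lr * gw, b: state.b - lr * gb}
    callbacks.after_batch.(state)
  end

  def train(state, dataset, opts) do
    epochs = Keyword.fetch!(opts, :epochs)
    lr = Keyword.get(opts, :lr, 0.05)
    callbacks = Map.merge(default_callbacks(), Map.new(Keyword.get(opts, :callbacks, [])))

    # :unlimited consumes every batch; an integer caps the batches per epoch.
    batches =
      case Keyword.get(opts, :steps, :unlimited) do
        :unlimited -> dataset
        steps -> Enum.take(dataset, steps)
      end

    Enum.reduce(1..epochs, state, fn _epoch, state ->
      state = callbacks.before_epoch.(state)
      state = Enum.reduce(batches, state, &train_on_batch(&1, &2, lr, callbacks))
      callbacks.after_epoch.(state)
    end)
  end

  # Every hook defaults to returning the state unchanged.
  defp default_callbacks do
    Map.new([:before_epoch, :after_epoch, :before_batch, :after_batch], fn key ->
      {key, &Function.identity/1}
    end)
  end
end

# Usage: fit y = 2x + 1, recording the mean squared error after each epoch.
data = for x <- [-1.0, 1.0, -2.0, 2.0], do: {x, 2.0 * x + 1.0}

mse = fn state ->
  errs = for {x, y} <- data, do: state.w * x + state.b - y
  Enum.sum(Enum.map(errs, &(&1 * &1))) / length(errs)
end

record = fn state -> %{state | history: state.history ++ [mse.(state)]} end

final =
  LoopSketch.train(%{w: 0.0, b: 0.0, history: []}, Enum.chunk_every(data, 2),
    epochs: 100,
    callbacks: [after_epoch: record]
  )
```

Here evaluation lives in an after_epoch callback, which is one way to fold `evaluate(model_state, eval_objective)` into the hook mechanism.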

Given this framework, the training API would have at a minimum the following callbacks:

defmodule Axon.Training do
  # Placeholder types so the callback specs compile
  @type model_state :: term()
  @type reason :: term()
  @type batch :: term()
  @type optimizer :: term()
  @type train_objective :: term()
  @type eval_objective :: term()
  @type dataset :: Enumerable.t()
  @type options :: keyword()

  # Runs before each epoch
  @callback before_epoch(model_state()) :: {:ok, model_state()} | {:error, reason()}

  # Runs after each epoch
  @callback after_epoch(model_state()) :: {:ok, model_state()} | {:error, reason()}

  # Runs before each batch
  @callback before_batch(model_state()) :: {:ok, model_state()} | {:error, reason()}

  # Runs after each batch
  @callback after_batch(model_state()) :: {:ok, model_state()} | {:error, reason()}

  # Runs a single train step, this can also be `defn` for working with infeed/outfeed
  @callback train_on_batch(batch(), model_state(), train_objective(), optimizer()) :: model_state()

  # Runs a training loop to convergence
  @callback train(model_state(), optimizer(), train_objective(), eval_objective(), dataset(), options()) ::
              {:ok, model_state()} | {:error, reason()}
end
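One design option for a behaviour like this is to ship no-op defaults through `__using__` and `defoverridable`, so an implementation only overrides the hooks it cares about, and to halt the loop on the first `{:error, reason}`. The sketch below trims the behaviour down to the two epoch hooks; `TrainingHooks`, `run/4`, and the example modules are hypothetical.

```elixir
defmodule TrainingHooks do
  @type state :: term()
  @callback before_epoch(state()) :: {:ok, state()} | {:error, term()}
  @callback after_epoch(state()) :: {:ok, state()} | {:error, term()}

  defmacro __using__(_opts) do
    quote do
      @behaviour TrainingHooks

      # No-op defaults; implementers override only what they need.
      @impl true
      def before_epoch(state), do: {:ok, state}

      @impl true
      def after_epoch(state), do: {:ok, state}

      defoverridable before_epoch: 1, after_epoch: 1
    end
  end

  # Runs `epochs` epochs, stopping at the first {:error, reason} from a hook.
  def run(module, state, epochs, epoch_fun) do
    Enum.reduce_while(1..epochs, {:ok, state}, fn _epoch, {:ok, state} ->
      with {:ok, state} <- module.before_epoch(state),
           state = epoch_fun.(state),
           {:ok, state} <- module.after_epoch(state) do
        {:cont, {:ok, state}}
      else
        {:error, _reason} = error -> {:halt, error}
      end
    end)
  end
end

defmodule PlainHooks do
  use TrainingHooks
end

defmodule StopAfterThree do
  use TrainingHooks

  # Overrides only after_epoch, e.g. as a crude early-stopping rule.
  @impl true
  def after_epoch(%{epoch: epoch} = state) do
    if epoch >= 3, do: {:error, {:stopped_at, epoch}}, else: {:ok, state}
  end
end
```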

I left a lot of key pieces out because I believe this motivates discussion about how to best separate concerns between modules to provide both maximum flexibility, as well as ease-of-use.

Objectives

The spirit of autograd is that writing a machine learning model is as simple as defining a differentiable objective function. I believe that is a principle we should stick to, so I've separated the idea of an objective into what could be a separate module, behaviour, function, etc. Objectives need to encapsulate both evaluation objectives and training objectives, and they need to support parameterization by a model. I think they should also carry information about the associated metrics and evaluation criteria tracked during training. Objectives could possibly be defined as a behaviour with two callbacks, predict and loss, where loss depends on predict and predict represents the model definition. I'm not sure I really like that idea, but objectives definitely deserve a well-thought-out discussion in a separate issue.
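A sketch of the predict/loss idea as an Elixir behaviour, with a third callback listing the metrics to track. The `Objective` behaviour, the `LinearMSE` implementation, and the `:mean_absolute_error` metric name are all hypothetical; the point is only that loss/3 is defined in terms of predict/2.

```elixir
defmodule Objective do
  @callback predict(params :: term(), input :: number()) :: number()
  @callback loss(params :: term(), input :: number(), target :: number()) :: number()
  @callback metrics() :: [atom()]
end

defmodule LinearMSE do
  @behaviour Objective

  # The model definition: y = w * x + b.
  @impl true
  def predict({w, b}, x), do: w * x + b

  # Squared error, defined in terms of predict/2.
  @impl true
  def loss(params, x, y) do
    diff = predict(params, x) - y
    diff * diff
  end

  # Metrics to track alongside the loss (hypothetical names).
  @impl true
  def metrics, do: [:mean_absolute_error]
end
```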

Optimizers and Updates

From a design standpoint, updates and optimizers should be kept separate. From a performance standpoint, however, you might want to fuse gradient calculation with updates; I believe this is possible by silently wrapping both the update and grad(objective) in another defn, because defn calls are inlined and compiled. Implementing optimizers as separate modules is a common pattern, so I would go for a behaviour here, with common implementations built on the primitive updates in updates.ex.
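A sketch of that layering in plain Elixir: a hypothetical `Updates` module holds primitive transforms, an `Optimizer` behaviour defines `init/1` and `update/3`, and SGD and momentum are both built from the same primitives. The fixed learning rate and momentum coefficient in module attributes are simplifications for the example; none of these modules are Axon's API.

```elixir
defmodule Updates do
  # Primitive transforms over lists of floats.
  def scale(grads, step), do: Enum.map(grads, &(&1 * step))
  def apply_updates(params, updates), do: Enum.zip_with(params, updates, &+/2)
end

defmodule Optimizer do
  @callback init(params :: [float()]) :: term()
  @callback update(grads :: [float()], state :: term(), params :: [float()]) ::
              {[float()], term()}
end

defmodule SGD do
  @behaviour Optimizer
  @lr 0.1

  @impl true
  def init(_params), do: nil

  @impl true
  def update(grads, state, _params), do: {Updates.scale(grads, -@lr), state}
end

defmodule Momentum do
  @behaviour Optimizer
  @lr 0.1
  @beta 0.9

  # One velocity entry per parameter.
  @impl true
  def init(params), do: Enum.map(params, fn _ -> 0.0 end)

  @impl true
  def update(grads, velocity, _params) do
    velocity = Enum.zip_with(velocity, grads, fn v, g -> @beta * v + g end)
    {Updates.scale(velocity, -@lr), velocity}
  end
end

defmodule Minimize do
  # Minimizes sum(p^2), whose gradient is 2p, with any Optimizer module.
  def run(optimizer, params, steps) do
    {params, _state} =
      Enum.reduce(1..steps, {params, optimizer.init(params)}, fn _, {params, state} ->
        grads = Enum.map(params, &(2 * &1))
        {updates, state} = optimizer.update(grads, state, params)
        {Updates.apply_updates(params, updates), state}
      end)

    params
  end
end
```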

State

There is a lot of state to keep track of in the above example: model state, optimizer state, metric state, evaluation state, etc. I think it makes sense to wrap state in a common API so stateful parameters can be handled flexibly. Another advantage is that we can limit assumptions about the actual state management solution used in practice, so users can implement their own if they choose.
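One way to keep state management swappable is a small behaviour that training code is written against, with interchangeable backends. In this hypothetical sketch, `MapStore` is purely functional (put returns a new map) while `AgentStore` keeps state in an Agent process (put returns the same pid); `TrainState.record_step/3` works unchanged with either.

```elixir
defmodule StateStore do
  @callback get(store :: term(), key :: atom()) :: term()
  @callback put(store :: term(), key :: atom(), value :: term()) :: term()
end

defmodule MapStore do
  # Purely functional backend: put/3 returns a new map.
  @behaviour StateStore
  @impl true
  def get(store, key), do: Map.get(store, key)
  @impl true
  def put(store, key, value), do: Map.put(store, key, value)
end

defmodule AgentStore do
  # Process-backed backend: the Agent pid is the store.
  @behaviour StateStore

  def new do
    {:ok, pid} = Agent.start_link(fn -> %{} end)
    pid
  end

  @impl true
  def get(pid, key), do: Agent.get(pid, &Map.get(&1, key))

  @impl true
  def put(pid, key, value) do
    :ok = Agent.update(pid, &Map.put(&1, key, value))
    pid
  end
end

defmodule TrainState do
  # Written only against the StateStore callbacks.
  def record_step(mod, store, loss) do
    steps = (mod.get(store, :steps) || 0) + 1
    store = mod.put(store, :steps, steps)
    mod.put(store, :last_loss, loss)
  end
end
```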

Dataset

The above just lists a dataset as containing batches. I would basically try to represent this as a stream that can be consumed. I don't think dataset implementations fall in this library, but I think Axon should enforce some standard for what datasets look like.
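A sketch of a dataset as a lazily consumed stream of batches, assuming the convention that each batch is an `{inputs, targets}` tuple of lists (the convention itself is an assumption, not something Axon defines here). Because every step is a `Stream` function, infinite sources work as long as the consumer takes a finite number of batches.

```elixir
defmodule DatasetSketch do
  # Pairs inputs with targets, groups them into batches of `batch_size`,
  # and yields {inputs, targets} per batch; nothing runs until consumed.
  def batches(inputs, targets, batch_size) do
    inputs
    |> Stream.zip(targets)
    |> Stream.chunk_every(batch_size)
    |> Stream.map(&Enum.unzip/1)
  end
end
```

For example, `DatasetSketch.batches(1..5, [:a, :b, :c, :d, :e], 2)` yields `{[1, 2], [:a, :b]}`, `{[3, 4], [:c, :d]}`, and a final short batch `{[5], [:e]}`.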

Conclusion

I believe this lays out a plan for integrating higher-level APIs moving forward. Obviously this is very general, because it inherently depends on implementation details from the aspects left out above. However, I believe starting with a training API is a sensible way to work out how to split up the rest of the work.
