

Transformers Android Demo
(TensorFlow Lite & PyTorch Mobile)

Sentiment Classification with ELECTRA-Small

A sentiment classifier fine-tuned on movie review datasets (the English IMDB dataset and the Korean NSMC dataset). Both English and Korean are supported.

Available models:

  1. "Original" TorchScript ELECTRA-Small (53MB)
  2. "Original" TFLite ELECTRA-Small (53MB)
  3. FP16 post-training-quantized TFLite ELECTRA-Small (26MB)
  4. "Hybrid" (8-bit weights) post-training-quantized TFLite ELECTRA-Small (13MB)
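The sizes above line up with the storage cost per weight: the original model stores float32 (32-bit) weights, fp16 halves that, and 8-bit "hybrid" quantization quarters it. A rough sanity check in Python (the 53MB figure is taken from the list above; real file sizes also include graph metadata, so this is only an estimate):

```python
FP32_MODEL_MB = 53  # original TFLite/TorchScript ELECTRA-Small

def quantized_size_mb(fp32_mb: float, bits: int) -> float:
    """Estimate model size after weight quantization (ignores metadata)."""
    return fp32_mb * bits / 32

print(quantized_size_mb(FP32_MODEL_MB, 16))  # ~26.5 -> matches the 26MB fp16 model
print(quantized_size_mb(FP32_MODEL_MB, 8))   # ~13.25 -> matches the 13MB 8-bit model
```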

Demo

Most of the assets are from the official PyTorch Android demo code. (Tested on a Galaxy S10.)

📱 APK Download Link 📱

Build the demo app

Android Studio

Prerequisites

  • If you don't already have it, install Android Studio by following the instructions on the website.
  • Android Studio 3.2 or later.
  • Install Android SDK and Android NDK using Android Studio UI.
  • You need an Android device and an Android development environment with a minimum of API level 26.
  • The libs directory contains a custom build of TensorFlow Lite with TensorFlow ops built-in, which is used by the app. It results in a bigger binary than the "normal" build but allows compatibility with ELECTRA-Small.

Building

  • Open Android Studio, and from the Welcome screen, select Open an existing Android Studio project.
  • From the Open File or Project window that appears, select the directory where you cloned this repo.
  • You may also need to install various platforms and tools according to error messages.
  • If it asks you to use Instant Run, click Proceed Without Instant Run.

Running

  • You need to have an Android device plugged in with developer options enabled at this point. See here for more details on setting up developer devices.
  • Click Run to run the demo app on your Android device.
Gradle (Command Line)

If the Android SDK and NDK are already installed, you can install this application on a connected Android device with:

./gradlew installDebug

Dependencies

1. Android

Converting the original model to TFLite format requires Select TensorFlow ops. This results in a bigger binary than the "normal" build but allows compatibility with the Transformers architecture.

converter = tf.lite.TFLiteConverter.from_keras_model(model)
# Allow built-in TFLite ops, falling back to Select TF ops where needed
converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS,
                                       tf.lite.OpsSet.SELECT_TF_OPS]

🚨 To use the Transformers TFLite model, you have to build the AAR file yourself. (Please check this documentation for TFLite Select ops.) 🚨

In this app, I used the same AAR file provided by the Hugging Face demo app. (The libs directory contains this custom AAR build.)

dependencies {
    implementation 'org.pytorch:pytorch_android:1.5.0'
    implementation 'org.pytorch:pytorch_android_torchvision:1.5.0'
    // implementation 'org.tensorflow:tensorflow-lite:2.1.0'
    // implementation 'org.tensorflow:tensorflow-lite-select-tf-ops:0.0.0-nightly'
    implementation(name: 'tensorflow-lite-with-select-tf-ops-0.0.0-nightly', ext: 'aar')
}

2. Python

  • torch==1.5.0
  • transformers==2.9.1
  • tensorflow==2.1.0

🚨 I highly recommend using tensorflow v2.1.0 instead of v2.2.0: TFLite conversion does not work in tensorflow v2.2.0. (Related Issue) 🚨

Convert the model to TFLite or TorchScript

※ The models are already uploaded to the Hugging Face S3 bucket and will be downloaded automatically during the build. If you want to download the fp16 or 8-bit model, uncomment the corresponding line in download.gradle.

🚨 TFLite conversion does not work in a CPU-only environment; it works fine in a GPU environment. 🚨

You should specify the input shape (= max_seq_len) when converting the model.

# torchscript
$ python3 model_converter/${TASK_NAME}/jit_compile.py --max_seq_len 40
# tflite (default)
$ python3 model_converter/${TASK_NAME}/tflite_converter.py --max_seq_len 40
# tflite (fp16)
$ python3 model_converter/${TASK_NAME}/tflite_converter.py --max_seq_len 40 --model fp16
# tflite (8bits)
$ python3 model_converter/${TASK_NAME}/tflite_converter.py --max_seq_len 40 --model 8bits

More Details

1. Length & Padding

MAX_SEQ_LEN is set to 40 in this app. You can change it yourself.

  • Be careful about the input shape when converting the model (the --max_seq_len option in the Python scripts).
  • You also need to change MAX_SEQ_LEN in the Android source code:
private static final int MAX_SEQ_LEN = 40;
  • TFLite does not support dynamic input sizes (Related Issue), so the app crashes if the input shape doesn't match max_seq_len :( You must pad the input sequence for the TFLite model.
  • TorchScript, however, accepts variable-length inputs even though the input shape was specified at conversion time, so I didn't pad the sequence for the PyTorch demo. If you want padding in the PyTorch demo, change this variable:
private static final boolean PAD_TO_MAX_LENGTH = true;
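Since the fixed-shape TFLite model expects exactly MAX_SEQ_LEN tokens, every token-id sequence has to be padded or truncated before inference. A minimal Python sketch of that step (the helper name and the pad id of 0 are illustrative, not taken from the app's source; use your tokenizer's actual pad id):

```python
MAX_SEQ_LEN = 40   # must match the --max_seq_len used at conversion time
PAD_TOKEN_ID = 0   # illustrative; use the tokenizer's real pad id

def pad_to_max_length(token_ids, max_len=MAX_SEQ_LEN, pad_id=PAD_TOKEN_ID):
    """Pad (or truncate) a token-id list to the fixed length the
    fixed-shape TFLite model requires."""
    ids = token_ids[:max_len]                      # truncate if too long
    return ids + [pad_id] * (max_len - len(ids))   # pad if too short

print(len(pad_to_max_length([101, 2023, 3185, 102])))  # 40
```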

2. FP16 & 8Bits on TFLite

I've already uploaded the fp16 and 8-bit TFLite models (both English and Korean) to the Hugging Face S3 bucket.

If you want to use those models, uncomment the corresponding lines in download.gradle as shown below. They will be downloaded automatically during the Gradle build.

task downloadLiteModel {
    def downloadFiles = [
        // "https://s3.amazonaws.com/models.huggingface.co/bert/monologg/koelectra-small-finetuned-sentiment/nsmc_small_fp16.tflite" : "nsmc_small_fp16.tflite",
        // "https://s3.amazonaws.com/models.huggingface.co/bert/monologg/koelectra-small-finetuned-sentiment/nsmc_small_8bits.tflite": "nsmc_small_8bits.tflite",
    ]
}

You also need to change MODEL_PATH in the Activity:

// 1. fp16
private static final String MODEL_PATH = "imdb_small_fp16.tflite";
// 2. 8bits hybrid
private static final String MODEL_PATH = "imdb_small_8bits.tflite";

3. Slow inference when using TorchScript

The first inference with TorchScript is quite slow; after the first pass, inference time returns to normal.

It seems the first forward pass does some warm-up work. (Not sure about it...)
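This warm-up behavior is common for JIT-style runtimes. One way to hide it from users is to run a throwaway forward pass at load time so later calls hit the already-optimized path. A sketch with a stand-in model function (the real app would call the TorchScript module's forward instead; `warmed_up` and `fake_model` are hypothetical names):

```python
import time

def warmed_up(model_fn, dummy_input, n_warmup=1):
    """Run throwaway forward passes so later calls skip the slow first pass."""
    for _ in range(n_warmup):
        model_fn(dummy_input)   # result discarded; this just triggers JIT work
    return model_fn

def fake_model(x):              # stand-in for a TorchScript module's forward
    return [v * 2 for v in x]

model = warmed_up(fake_model, [0] * 40)   # dummy input shaped like real input
start = time.perf_counter()
out = model([1, 2, 3])
elapsed = time.perf_counter() - start     # timed call excludes warm-up cost
print(out)  # [2, 4, 6]
```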
