lgeiger commented on May 26, 2024

It looks like using @tf.function(experimental_implements="larq.bsign") would require writing a custom MLIR transform, similar to how TF handles LSTMs. From a quick look, I can't see any other usage of it in the TensorFlow code at this point.
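To make the idea concrete, here is a toy, pure-Python sketch of what such a transform does conceptually: the converter looks for functions tagged with the op they implement and emits one fused op per call instead of inlining the decomposed body. The Function class, the "implements" attribute key, the assumed Sign/AddV2/Sign decomposition and the LceBsign op name are all hypothetical; this is not the TF MLIR converter's API.

```python
from dataclasses import dataclass, field


@dataclass
class Function:
    """Hypothetical stand-in for a traced tf.function in a toy graph."""

    name: str
    body: list  # decomposed ops the function lowers to
    attrs: dict = field(default_factory=dict)


def fuse_composite_calls(calls, functions, implements, fused_op):
    """Lower a sequence of function calls to a flat op list.

    Calls to functions whose "implements" attribute matches `implements`
    become a single `fused_op`; all other functions are inlined op by op.
    """
    ops = []
    for name in calls:
        fn = functions[name]
        if fn.attrs.get("implements") == implements:
            ops.append(fused_op)  # one fused op replaces the whole body
        else:
            ops.extend(fn.body)  # untagged functions are inlined as-is
    return ops


functions = {
    # Assumed decomposition of the binarizing sign, for illustration only.
    "bsign": Function("bsign", ["Sign", "AddV2", "Sign"], {"implements": "larq.bsign"}),
    "dense": Function("dense", ["MatMul", "BiasAdd"]),
}
print(fuse_composite_calls(["bsign", "dense"], functions, "larq.bsign", "LceBsign"))
# ['LceBsign', 'MatMul', 'BiasAdd']
```

The key point is that the match happens on function metadata rather than on the ops inside, which is why every new composite op needs its own converter-side transform.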

OpHints (tf.compat.v1.lite.OpHint) are deprecated, but they still seem to be supported by both converters and are the only way to achieve this without adding custom TF ops. However, the API is not really nice to expose to users.
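For reference, a toy model of how an OpHint-style annotation marks a subgraph: hinted inputs and outputs are wrapped in Identity nodes that carry the hint's function name, a per-use id and an argument index, so a later pass can find the subgraph boundaries. This is a sketch loosely modeled on tf.compat.v1.lite.OpHint; ToyGraph, ToyOpHint and the attribute keys are hypothetical.

```python
import itertools
import uuid


class ToyGraph:
    """Hypothetical graph: an ordered list of node dicts."""

    def __init__(self):
        self.nodes = []
        self._ids = itertools.count()

    def add(self, op, inputs=(), attrs=None):
        name = f"{op.lower()}_{next(self._ids)}"
        self.nodes.append(
            {"name": name, "op": op, "inputs": list(inputs), "attrs": attrs or {}}
        )
        return name


class ToyOpHint:
    """Marks subgraph boundaries by wrapping tensors in tagged Identity nodes."""

    def __init__(self, graph, function_name):
        self.graph = graph
        self.function_name = function_name
        self.uuid = uuid.uuid4().hex  # distinguishes separate uses of one hint
        self._next_input = itertools.count()
        self._next_output = itertools.count()

    def _tag(self, tensor, index_key, index):
        attrs = {
            "hint_function": self.function_name,
            "hint_uuid": self.uuid,
            index_key: index,
        }
        return self.graph.add("Identity", [tensor], attrs)

    def add_input(self, tensor):
        return self._tag(tensor, "hint_input_index", next(self._next_input))

    def add_output(self, tensor):
        return self._tag(tensor, "hint_output_index", next(self._next_output))


g = ToyGraph()
x = g.add("Placeholder")
hint = ToyOpHint(g, "larq.bsign")
x_in = hint.add_input(x)
s = g.add("Sign", [x_in])
y = hint.add_output(s)
print([node["op"] for node in g.nodes])
# ['Placeholder', 'Identity', 'Sign', 'Identity']
```

The annotation itself adds only identity nodes, so the graph still runs unchanged in TensorFlow; the boundaries only matter to a converter that knows to look for them.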
For LCE, I think this could be implemented as follows:

  1. Add an OpHint to lq.math.sign.
  2. Subclass the TFLiteConverter and use convert_op_hints_to_stubs to transform the graph_def before (or after) running it through Grappler. Unfortunately, this will rely on a bunch of private TF methods, but it should get us started.
  3. Remove the BSign TF op from LCE.
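Steps 1 and 2 can be sketched on a toy graph: starting from each hinted output boundary, walk backwards to the matching input boundaries and replace everything in between with a single stub node named after the hint. That stub is what ends up as a custom op in the converted model, which is why step 3 can drop the dedicated TF op. This is a hypothetical, simplified stand-in for convert_op_hints_to_stubs (dict-based nodes, one output per hint, every external input routed through an input boundary), not its actual implementation.

```python
def convert_hints_to_stubs(nodes):
    """Collapse each hinted subgraph into one stub node (toy version).

    Assumes nodes are dicts in topological order, one output boundary per
    hint, and that every external input enters through an input boundary.
    """
    by_name = {node["name"]: node for node in nodes}
    removed = set()
    stubs = {}  # output-boundary name -> stub node that replaces it

    for node in nodes:
        attrs = node["attrs"]
        if "hint_output_index" not in attrs:
            continue
        boundary_inputs = []  # (argument index, upstream tensor) pairs
        stack = list(node["inputs"])
        while stack:
            name = stack.pop()
            if name in removed:
                continue
            inner = by_name[name]
            removed.add(name)
            inner_attrs = inner["attrs"]
            if (
                inner_attrs.get("hint_uuid") == attrs["hint_uuid"]
                and "hint_input_index" in inner_attrs
            ):
                # Stop at the input boundary and remember what feeds it.
                index = inner_attrs["hint_input_index"]
                boundary_inputs.append((index, inner["inputs"][0]))
            else:
                stack.extend(inner["inputs"])
        # The stub keeps the output boundary's name, so consumers stay wired up.
        stubs[node["name"]] = {
            "name": node["name"],
            "op": attrs["hint_function"],
            "inputs": [tensor for _, tensor in sorted(boundary_inputs)],
            "attrs": {},
        }

    return [
        stubs.get(node["name"], node)
        for node in nodes
        if node["name"] not in removed
    ]


bsign_hint = {"hint_function": "larq.bsign", "hint_uuid": "u1"}
graph = [
    {"name": "x", "op": "Placeholder", "inputs": [], "attrs": {}},
    {"name": "in0", "op": "Identity", "inputs": ["x"], "attrs": {**bsign_hint, "hint_input_index": 0}},
    {"name": "c", "op": "Const", "inputs": [], "attrs": {}},
    {"name": "s1", "op": "Sign", "inputs": ["in0"], "attrs": {}},
    {"name": "add", "op": "AddV2", "inputs": ["s1", "c"], "attrs": {}},
    {"name": "s2", "op": "Sign", "inputs": ["add"], "attrs": {}},
    {"name": "out0", "op": "Identity", "inputs": ["s2"], "attrs": {**bsign_hint, "hint_output_index": 0}},
    {"name": "w", "op": "Const", "inputs": [], "attrs": {}},
    {"name": "mm", "op": "MatMul", "inputs": ["out0", "w"], "attrs": {}},
]
print([node["op"] for node in convert_hints_to_stubs(graph)])
# ['Placeholder', 'larq.bsign', 'Const', 'MatMul']
```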

lgeiger commented on May 26, 2024

So this can be done with the old converter by adding a call to tf.compat.v1.lite.experimental.convert_op_hints_to_stubs in the middle of the conversion, though it may be more fruitful to run custom graph transforms after the TFLite conversion, similar to #141:

```python
import netron
import larq as lq
import tensorflow as tf

# The imports below are private TensorFlow internals used by the stock
# TFLiteConverter; they are not a stable API and may change between releases.
from tensorflow.lite.python import lite_constants as constants
from tensorflow.lite.python.convert import toco_convert_impl as _toco_convert_impl

from tensorflow.lite.python.util import build_debug_info_func as _build_debug_info_func
from tensorflow.lite.python.util import get_debug_info as _get_debug_info
from tensorflow.lite.python.util import get_tensor_name as _get_tensor_name
from tensorflow.lite.python.util import (
    run_graph_optimizations as _run_graph_optimizations,
)
from tensorflow.python.framework import convert_to_constants as _convert_to_constants
from tensorflow.python.framework import dtypes as _dtypes


class TFLiteConverter(tf.lite.TFLiteConverter):
    """TFLiteConverter that collapses OpHint-annotated subgraphs into stub ops.

    convert() follows the stock old-converter path, with one extra
    convert_op_hints_to_stubs call after the Grappler pass.
    """

    def convert(self):
        """Converts a TensorFlow GraphDef based on instance variables.
        Returns:
            The converted data in serialized format.
        Raises:
            ValueError:
                Multiple concrete functions are specified.
                Input shape is not specified.
                Invalid quantization parameters.
        """
        # TODO(b/130297984): Add support for converting multiple function.
        if len(self._funcs) != 1:
            raise ValueError(
                "This converter can only convert a single "
                "ConcreteFunction. Converting multiple functions is "
                "under development."
            )

        frozen_func = _convert_to_constants.convert_variables_to_constants_v2(
            self._funcs[0], lower_control_flow=False
        )
        input_tensors = [
            tensor for tensor in frozen_func.inputs if tensor.dtype != _dtypes.resource
        ]
        output_tensors = frozen_func.outputs

        # Run a Grappler pass.
        graph_def = frozen_func.graph.as_graph_def()
        graph_def = _run_graph_optimizations(
            graph_def,
            input_tensors,
            output_tensors,
            config=self._grappler_config(),
            graph=frozen_func.graph,
        )
        # Replace each OpHint-annotated subgraph (e.g. "larq.bsign") with a
        # single stub node, which the old converter emits as a custom op.
        graph_def = tf.compat.v1.lite.experimental.convert_op_hints_to_stubs(
            graph_def=graph_def
        )

        # Checks dimensions in input tensor.
        for tensor in input_tensors:
            # Note that shape_list might be empty for scalar shapes.
            shape_list = tensor.shape.as_list()
            if None in shape_list[1:]:
                raise ValueError(
                    "None is only supported in the 1st dimension. Tensor '{0}' has "
                    "invalid shape '{1}'.".format(_get_tensor_name(tensor), shape_list)
                )
            elif shape_list and shape_list[0] is None:
                # Set the batch size to 1 if undefined.
                shape = tensor.shape.as_list()
                shape[0] = 1
                tensor.set_shape(shape)

        self._validate_quantization()
        self._validate_representative_dataset()
        self._debug_info = _get_debug_info(
            _build_debug_info_func(self._funcs[0].graph), graph_def
        )
        converter_kwargs = self._get_base_converter_args()

        # Converts model.
        result = _toco_convert_impl(
            input_data=graph_def,
            input_tensors=input_tensors,
            output_tensors=output_tensors,
            **converter_kwargs
        )

        if self._is_calibration_quantize():
            result = self._calibrate_quantize_model(
                result,
                constants.FLOAT,
                constants.FLOAT,
                self.experimental_new_quantizer,
            )

        return result


def ste_sign(x):
    """Sign with a straight-through estimator, annotated as "larq.bsign"."""
    op_hint = tf.compat.v1.lite.OpHint("larq.bsign")
    # Mark x as the input of the hinted subgraph.
    x = op_hint.add_input(x)

    @tf.custom_gradient
    def _call(x):
        def grad(dy):
            # Pass gradients through where |x| <= 1 and zero them elsewhere.
            zeros = tf.zeros_like(dy)
            mask = tf.math.less_equal(tf.math.abs(x), 1.0)
            return tf.where(mask, dy, zeros)

        return lq.math.sign(x), grad

    output = _call(x)
    # Mark the result as the output of the hinted subgraph.
    return op_hint.add_output(output)


model = tf.keras.Sequential(
    [
        tf.keras.layers.Conv2D(32, 3, activation="relu", input_shape=(28, 28, 1)),
        tf.keras.layers.MaxPooling2D(),
        tf.keras.layers.Flatten(),
        lq.layers.QuantDense(
            64,
            input_quantizer=ste_sign,
            kernel_quantizer=ste_sign,
            kernel_constraint="weight_clip",
        ),
        tf.keras.layers.Dense(10, activation="softmax"),
    ]
)

file_name = "/tmp/quantized_mnist.tflite"
converter = TFLiteConverter.from_keras_model(model)
# The stub ops have no TFLite builtin equivalent, so allow custom ops,
# and keep the old (non-MLIR) converter used by convert() above.
converter.allow_custom_ops = True
converter.experimental_new_converter = False
tflite_model = converter.convert()

with open(file_name, "wb") as file:
    file.write(tflite_model)

# Inspect the converted model with Netron (requires the netron package).
netron.start(file_name)
```
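
The gradient rule in ste_sign above is a straight-through estimator. Here is a scalar, pure-Python sketch of its forward and backward pass, assuming (as larq's lq.math.sign does) that sign(0) is +1; plain floats stand in for tensors.

```python
def sign(x):
    """Binarize like larq's lq.math.sign: -1 for x < 0, +1 otherwise."""
    return 1.0 if x >= 0 else -1.0


def ste_sign_grad(x, dy):
    """Straight-through estimator: pass dy through where |x| <= 1, else 0."""
    return dy if abs(x) <= 1.0 else 0.0


xs = [-2.0, -0.5, 0.0, 0.7, 1.0, 3.0]
upstream = [1.0] * len(xs)
print([sign(x) for x in xs])
# [-1.0, -1.0, 1.0, 1.0, 1.0, 1.0]
print([ste_sign_grad(x, dy) for x, dy in zip(xs, upstream)])
# [0.0, 1.0, 1.0, 1.0, 1.0, 0.0]
```

Because the mask uses less_equal, the gradient still flows at exactly |x| = 1, matching tf.math.less_equal in the original.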

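The input-shape check in convert() above only tolerates an unknown batch dimension and pins it to 1. The same rule applied to a plain Python list, as a sketch: normalize_input_shape is a hypothetical helper, and unlike the original it returns a new list instead of calling tensor.set_shape.

```python
def normalize_input_shape(name, shape):
    """Validate an input shape the way convert() above does.

    None is only allowed in the first (batch) dimension, where it is
    replaced by 1; scalar shapes (empty lists) pass through unchanged.
    """
    if None in shape[1:]:
        raise ValueError(
            "None is only supported in the 1st dimension. Tensor '{0}' has "
            "invalid shape '{1}'.".format(name, shape)
        )
    if shape and shape[0] is None:
        return [1] + list(shape[1:])
    return list(shape)


print(normalize_input_shape("input_1", [None, 28, 28, 1]))  # [1, 28, 28, 1]
print(normalize_input_shape("scalar", []))  # []
```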