Comments (2)
It looks like using @tf.function(experimental_implements="larq.bsign") would require writing a custom MLIR transform, similar to how TF handles LSTMs. From a quick look, I can't see any other usage in the TensorFlow code at this point.
tf.compat.v1.lite.OpHint is deprecated, but it still seems to be supported by both converters and is the only way to achieve this without adding custom TF ops. However, the API is not really nice to expose to users.
For LCE I think this could be implemented like this:
- Add an OpHint to lq.math.sign.
- Subclass the TFLiteConverter and use convert_op_hints_to_stubs to transform the graph_def before (or after) running it through Grappler. Unfortunately this will rely on a bunch of private TF methods, but that should get us started.
- Remove the BSign TF op from LCE.
So this can be done with the old converter by adding a call to tf.compat.v1.lite.experimental.convert_op_hints_to_stubs in the middle of the conversion, though it may be more fruitful to run custom graph transforms after the TFLite conversion, similar to #141:
import netron
import larq as lq
import tensorflow as tf
from tensorflow.lite.python import lite_constants as constants
from tensorflow.lite.python.convert import toco_convert_impl as _toco_convert_impl
from tensorflow.lite.python.util import build_debug_info_func as _build_debug_info_func
from tensorflow.lite.python.util import get_debug_info as _get_debug_info
from tensorflow.lite.python.util import get_tensor_name as _get_tensor_name
from tensorflow.lite.python.util import (
    run_graph_optimizations as _run_graph_optimizations,
)
from tensorflow.python.framework import convert_to_constants as _convert_to_constants
from tensorflow.python.framework import dtypes as _dtypes
class TFLiteConverter(tf.lite.TFLiteConverter):
    def convert(self):
        """Converts a TensorFlow GraphDef based on instance variables.

        Returns:
          The converted data in serialized format.

        Raises:
          ValueError:
            Multiple concrete functions are specified.
            Input shape is not specified.
            Invalid quantization parameters.
        """
        # TODO(b/130297984): Add support for converting multiple function.
        if len(self._funcs) != 1:
            raise ValueError(
                "This converter can only convert a single "
                "ConcreteFunction. Converting multiple functions is "
                "under development."
            )

        frozen_func = _convert_to_constants.convert_variables_to_constants_v2(
            self._funcs[0], lower_control_flow=False
        )
        input_tensors = [
            tensor for tensor in frozen_func.inputs if tensor.dtype != _dtypes.resource
        ]
        output_tensors = frozen_func.outputs

        # Run a Grappler pass.
        graph_def = frozen_func.graph.as_graph_def()
        graph_def = _run_graph_optimizations(
            graph_def,
            input_tensors,
            output_tensors,
            config=self._grappler_config(),
            graph=frozen_func.graph,
        )

        # Replace the OpHint-annotated subgraphs with stub ops before conversion.
        graph_def = tf.compat.v1.lite.experimental.convert_op_hints_to_stubs(
            graph_def=graph_def
        )

        # Checks dimensions in input tensor.
        for tensor in input_tensors:
            # Note that shape_list might be empty for scalar shapes.
            shape_list = tensor.shape.as_list()
            if None in shape_list[1:]:
                raise ValueError(
                    "None is only supported in the 1st dimension. Tensor '{0}' has "
                    "invalid shape '{1}'.".format(_get_tensor_name(tensor), shape_list)
                )
            elif shape_list and shape_list[0] is None:
                # Set the batch size to 1 if undefined.
                shape = tensor.shape.as_list()
                shape[0] = 1
                tensor.set_shape(shape)

        self._validate_quantization()
        self._validate_representative_dataset()
        self._debug_info = _get_debug_info(
            _build_debug_info_func(self._funcs[0].graph), graph_def
        )
        converter_kwargs = self._get_base_converter_args()

        # Converts model.
        result = _toco_convert_impl(
            input_data=graph_def,
            input_tensors=input_tensors,
            output_tensors=output_tensors,
            **converter_kwargs
        )

        if self._is_calibration_quantize():
            result = self._calibrate_quantize_model(
                result,
                constants.FLOAT,
                constants.FLOAT,
                self.experimental_new_quantizer,
            )

        return result
def ste_sign(x):
    op_hint = tf.compat.v1.lite.OpHint("larq.bsign")
    x = op_hint.add_input(x)

    @tf.custom_gradient
    def _call(x):
        def grad(dy):
            zeros = tf.zeros_like(dy)
            mask = tf.math.less_equal(tf.math.abs(x), 1.0)
            return tf.where(mask, dy, zeros)

        return lq.math.sign(x), grad

    output = _call(x)
    return op_hint.add_output(output)
model = tf.keras.Sequential(
    [
        tf.keras.layers.Conv2D(32, 3, activation="relu", input_shape=(28, 28, 1)),
        tf.keras.layers.MaxPooling2D(),
        tf.keras.layers.Flatten(),
        lq.layers.QuantDense(
            64,
            input_quantizer=ste_sign,
            kernel_quantizer=ste_sign,
            kernel_constraint="weight_clip",
        ),
        tf.keras.layers.Dense(10, activation="softmax"),
    ]
)
file_name = "/tmp/quantized_mnist.tflite"
converter = TFLiteConverter.from_keras_model(model)
converter.allow_custom_ops = True
converter.experimental_new_converter = False
tflite_model = converter.convert()
with open(file_name, "wb") as file:
    file.write(tflite_model)
netron.start(file_name)
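For readers less familiar with the gradient defined inside ste_sign above, here is a minimal pure-Python sketch of the same straight-through behaviour on scalars. The helper names sign and ste_sign_grad are hypothetical and only for illustration; mapping 0 to +1.0 is an assumption about how lq.math.sign treats zero.

```python
def sign(x):
    """Binarize to -1.0 or +1.0; 0 maps to +1.0 (assumed to match lq.math.sign)."""
    return 1.0 if x >= 0 else -1.0


def ste_sign_grad(x, dy):
    """Pass the upstream gradient dy through where |x| <= 1, and zero it elsewhere."""
    return dy if abs(x) <= 1.0 else 0.0


inputs = [-2.0, -0.5, 0.0, 0.5, 2.0]
upstream = [1.0] * len(inputs)

# Forward pass: every value collapses to -1.0 or +1.0.
forward = [sign(x) for x in inputs]
# Backward pass: the gradient is kept only for inputs inside [-1, 1].
backward = [ste_sign_grad(x, dy) for x, dy in zip(inputs, upstream)]
```

The OpHint wrapping in ste_sign only marks the op boundary for the converter; the forward and backward behaviour sketched here is what the tf.custom_gradient part computes elementwise.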