From 5e26eda0e0f359b6e22b1f1eeb9344cd15e0f093 Mon Sep 17 00:00:00 2001 From: Patrik Gustavsson Date: Wed, 30 Jun 2021 09:07:16 +0200 Subject: MLBEDSW-4840 Move setting of input indices to tflite reader Mapping to internal input indexing has been added to tflite_reader.py and tosa_reader.py. And the other way around in tflite_writer.py. Signed-off-by: Patrik Gustavsson Change-Id: I4d8596e747cfa7c4203884c4e785eb1977e2bcc1 --- ethosu/vela/tflite_writer.py | 13 ++++++++++--- 1 file changed, 10 insertions(+), 3 deletions(-) (limited to 'ethosu/vela/tflite_writer.py') diff --git a/ethosu/vela/tflite_writer.py b/ethosu/vela/tflite_writer.py index 8cabb0ac..3701893e 100644 --- a/ethosu/vela/tflite_writer.py +++ b/ethosu/vela/tflite_writer.py @@ -24,6 +24,7 @@ from flatbuffers.builder import UOffsetTFlags from .errors import VelaError from .nn_graph import PassPlacement from .operation import Op +from .reader_util import align_inputs_indices from .tensor import MemType from .tensor import TensorPurpose from .tflite import Buffer @@ -38,7 +39,6 @@ from .tflite_mapping import builtin_operator_inv_map from .tflite_mapping import BuiltinOperator from .tflite_mapping import datatype_inv_map - # ugh, the python flatbuffer interface is missing a method to add in file identifier. patching it in here: tflite_version = 3 @@ -90,6 +90,8 @@ class TFLiteSerialiser: for ps in sg.passes: for op in ps.ops: if op.type not in self.ops_to_ignore: + # swap from nng input indexing to TensorFlow Lite input indexing + self.align_nng_inputs_to_tflite(op) all_ops.append(op) if op.type.is_conv2d_op() or op.type.is_depthwise_conv2d_op(): # If values are None op has non-constant weights @@ -104,6 +106,11 @@ class TFLiteSerialiser: self.operator_codes = sorted(set((op.type, op.attrs.get("custom_code", "")) for op in all_ops)) self.operator_code_map = {} + def align_nng_inputs_to_tflite(self, op): + from_indices = op.type.info.indices + _, _, to_indices = builtin_operator_inv_map[op.type] + op.inputs = align_inputs_indices(from_indices, to_indices, op.inputs) + def write_byte_vector(self, v, alignment=1): builder = self.builder builder.StartVector(1, len(v), alignment) @@ -170,13 +177,13 @@ class TFLiteSerialiser: builder = self.builder custom_code_offset = None if op_type == Op.Custom: - tf_code, opt_serializer = builtin_operator_inv_map[op_type] + tf_code, opt_serializer, _ = builtin_operator_inv_map[op_type] custom_code_offset = builder.CreateString(custom_code) else: assert ( op_type in builtin_operator_inv_map ), "Vela does not contain a mapping to serialise {} operator to a TensorFlow Lite operator".format(op_type) - tf_code, opt_serializer = builtin_operator_inv_map[op_type] + tf_code, opt_serializer, _ = builtin_operator_inv_map[op_type] if op_type == Op.CustomNpuOp: assert ( -- cgit v1.2.1