diff options
author | Patrik Gustavsson <patrik.gustavsson@arm.com> | 2021-06-30 09:07:16 +0200 |
---|---|---|
committer | Patrik Gustavsson <patrik.gustavsson@arm.com> | 2021-07-09 09:51:44 +0200 |
commit | 5e26eda0e0f359b6e22b1f1eeb9344cd15e0f093 (patch) | |
tree | dce92ab9d8a6ceb261c48353ff7077295efa21da /ethosu/vela/tflite_writer.py | |
parent | 8f1f9aaa58175b17cd2e505bfcdb0e40c955ea72 (diff) | |
download | ethos-u-vela-5e26eda0e0f359b6e22b1f1eeb9344cd15e0f093.tar.gz |
MLBEDSW-4840 Move setting of input indices to tflite reader
Mapping to internal input indexing has been added to
tflite_reader.py and tosa_reader.py.
And the other way around in tflite_writer.py.
Signed-off-by: Patrik Gustavsson <patrik.gustavsson@arm.com>
Change-Id: I4d8596e747cfa7c4203884c4e785eb1977e2bcc1
Diffstat (limited to 'ethosu/vela/tflite_writer.py')
-rw-r--r-- | ethosu/vela/tflite_writer.py | 13 |
1 files changed, 10 insertions, 3 deletions
diff --git a/ethosu/vela/tflite_writer.py b/ethosu/vela/tflite_writer.py index 8cabb0ac..3701893e 100644 --- a/ethosu/vela/tflite_writer.py +++ b/ethosu/vela/tflite_writer.py @@ -24,6 +24,7 @@ from flatbuffers.builder import UOffsetTFlags from .errors import VelaError from .nn_graph import PassPlacement from .operation import Op +from .reader_util import align_inputs_indices from .tensor import MemType from .tensor import TensorPurpose from .tflite import Buffer @@ -38,7 +39,6 @@ from .tflite_mapping import builtin_operator_inv_map from .tflite_mapping import BuiltinOperator from .tflite_mapping import datatype_inv_map - # ugh, the python flatbuffer interface is missing a method to add in file identifier. patching it in here: tflite_version = 3 @@ -90,6 +90,8 @@ class TFLiteSerialiser: for ps in sg.passes: for op in ps.ops: if op.type not in self.ops_to_ignore: + # swap from nng input indexing to TensorFlow Lite input indexing + self.align_nng_inputs_to_tflite(op) all_ops.append(op) if op.type.is_conv2d_op() or op.type.is_depthwise_conv2d_op(): # If values are None op has non-constant weights @@ -104,6 +106,11 @@ class TFLiteSerialiser: self.operator_codes = sorted(set((op.type, op.attrs.get("custom_code", "")) for op in all_ops)) self.operator_code_map = {} + def align_nng_inputs_to_tflite(self, op): + from_indices = op.type.info.indices + _, _, to_indices = builtin_operator_inv_map[op.type] + op.inputs = align_inputs_indices(from_indices, to_indices, op.inputs) + def write_byte_vector(self, v, alignment=1): builder = self.builder builder.StartVector(1, len(v), alignment) @@ -170,13 +177,13 @@ class TFLiteSerialiser: builder = self.builder custom_code_offset = None if op_type == Op.Custom: - tf_code, opt_serializer = builtin_operator_inv_map[op_type] + tf_code, opt_serializer, _ = builtin_operator_inv_map[op_type] custom_code_offset = builder.CreateString(custom_code) else: assert ( op_type in builtin_operator_inv_map ), "Vela does not contain a mapping to serialise {} operator to a TensorFlow Lite operator".format(op_type) - tf_code, opt_serializer = builtin_operator_inv_map[op_type] + tf_code, opt_serializer, _ = builtin_operator_inv_map[op_type] if op_type == Op.CustomNpuOp: assert ( |