aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorTim Hall <tim.hall@arm.com>2021-01-25 21:42:56 +0000
committerTim Hall <tim.hall@arm.com>2021-01-25 21:46:24 +0000
commitb21837638cf6cb015f2a1deb1bb81176e620a306 (patch)
tree38eeba78d3cb9d137cf83fedaa3b9edcc0413518
parent8af061ade39c07cbb26f921c42217a7bfdd1b6ba (diff)
downloadethos-u-vela-b21837638cf6cb015f2a1deb1bb81176e620a306.tar.gz
MLBEDSW-3847: MLCE: Vela not handling multiple custom operators correctly
- Fixed bug with multiple 3rd party custom operators not inserting the correct custom_code. Signed-off-by: Tim Hall <tim.hall@arm.com> Change-Id: I470a964867e60d4d71f01592dd33d4ad1aa2d441
-rw-r--r--ethosu/vela/tflite_writer.py19
1 files changed, 15 insertions, 4 deletions
diff --git a/ethosu/vela/tflite_writer.py b/ethosu/vela/tflite_writer.py
index 06026ba..18905e3 100644
--- a/ethosu/vela/tflite_writer.py
+++ b/ethosu/vela/tflite_writer.py
@@ -176,13 +176,21 @@ class TFLiteSerialiser:
), "Vela does not contain a mapping to serialise {} operator to a TensorFlow Lite operator".format(op_type)
tf_code, opt_serializer = builtin_operator_inv_map[op_type]
- if tf_code == BuiltinOperator.CUSTOM:
+ if op_type == Op.CustomNpuOp:
assert (
- op_type == Op.CustomNpuOp
+ tf_code == BuiltinOperator.CUSTOM
), "Vela only supports serialising NpuOp operators as TensorFlow Lite Custom operators"
custom_code_offset = builder.CreateString("ethos-u")
- self.operator_code_map[op_type] = (idx, tf_code, opt_serializer)
+ # there can be multiple different types of 3rd party custom operators (i.e. non-"ethos-u" ones). therefore we
+ # need to add an extra level of indirection to this particular entry in the operator_code_map to allow for the
+ # correct lookup later on
+ if op_type == Op.Custom:
+ if op_type not in self.operator_code_map:
+ self.operator_code_map[op_type] = {}
+ self.operator_code_map[op_type][custom_code] = (idx, tf_code, opt_serializer)
+ else:
+ self.operator_code_map[op_type] = (idx, tf_code, opt_serializer)
OperatorCode.OperatorCodeStart(builder)
OperatorCode.OperatorCodeAddBuiltinCode(builder, tf_code)
@@ -262,7 +270,10 @@ class TFLiteSerialiser:
[self.tensor_map[tens] for tens in op.outputs if tens in self.tensor_map]
)
- op_idx, tflop, opt_serializer = self.operator_code_map[op.type]
+ if op.type == Op.Custom:
+ op_idx, tflop, opt_serializer = self.operator_code_map[op.type][op.attrs.get("custom_code", "")]
+ else:
+ op_idx, tflop, opt_serializer = self.operator_code_map[op.type]
builtin_opt_offset = None
custom_opt_offset = None