From b21837638cf6cb015f2a1deb1bb81176e620a306 Mon Sep 17 00:00:00 2001 From: Tim Hall Date: Mon, 25 Jan 2021 21:42:56 +0000 Subject: MLBEDSW-3847: MLCE: Vela not handling multiple custom operators correctly - Fixed bug with multiple 3rd party custom operators not inserting the correct custom_code. Signed-off-by: Tim Hall Change-Id: I470a964867e60d4d71f01592dd33d4ad1aa2d441 --- ethosu/vela/tflite_writer.py | 19 +++++++++++++++---- 1 file changed, 15 insertions(+), 4 deletions(-) diff --git a/ethosu/vela/tflite_writer.py b/ethosu/vela/tflite_writer.py index 06026ba5..18905e35 100644 --- a/ethosu/vela/tflite_writer.py +++ b/ethosu/vela/tflite_writer.py @@ -176,13 +176,21 @@ class TFLiteSerialiser: ), "Vela does not contain a mapping to serialise {} operator to a TensorFlow Lite operator".format(op_type) tf_code, opt_serializer = builtin_operator_inv_map[op_type] - if tf_code == BuiltinOperator.CUSTOM: + if op_type == Op.CustomNpuOp: assert ( - op_type == Op.CustomNpuOp + tf_code == BuiltinOperator.CUSTOM ), "Vela only supports serialising NpuOp operators as TensorFlow Lite Custom operators" custom_code_offset = builder.CreateString("ethos-u") - self.operator_code_map[op_type] = (idx, tf_code, opt_serializer) + # there can be multiple different types of 3rd party custom operators (i.e. non-"ethos-u" ones). therefore we + # need to add an extra level of indirection to this particular entry in the operator_code_map to allow for the + # correct lookup later on + if op_type == Op.Custom: + if op_type not in self.operator_code_map: + self.operator_code_map[op_type] = {} + self.operator_code_map[op_type][custom_code] = (idx, tf_code, opt_serializer) + else: + self.operator_code_map[op_type] = (idx, tf_code, opt_serializer) OperatorCode.OperatorCodeStart(builder) OperatorCode.OperatorCodeAddBuiltinCode(builder, tf_code) @@ -262,7 +270,10 @@ class TFLiteSerialiser: [self.tensor_map[tens] for tens in op.outputs if tens in self.tensor_map] ) - op_idx, tflop, opt_serializer = self.operator_code_map[op.type] + if op.type == Op.Custom: + op_idx, tflop, opt_serializer = self.operator_code_map[op.type][op.attrs.get("custom_code", "")] + else: + op_idx, tflop, opt_serializer = self.operator_code_map[op.type] builtin_opt_offset = None custom_opt_offset = None -- cgit v1.2.1