diff options
-rw-r--r-- | ethosu/vela/tflite_writer.py | 19 |
1 files changed, 15 insertions, 4 deletions
diff --git a/ethosu/vela/tflite_writer.py b/ethosu/vela/tflite_writer.py index 06026ba5..18905e35 100644 --- a/ethosu/vela/tflite_writer.py +++ b/ethosu/vela/tflite_writer.py @@ -176,13 +176,21 @@ class TFLiteSerialiser: ), "Vela does not contain a mapping to serialise {} operator to a TensorFlow Lite operator".format(op_type) tf_code, opt_serializer = builtin_operator_inv_map[op_type] - if tf_code == BuiltinOperator.CUSTOM: + if op_type == Op.CustomNpuOp: assert ( - op_type == Op.CustomNpuOp + tf_code == BuiltinOperator.CUSTOM ), "Vela only supports serialising NpuOp operators as TensorFlow Lite Custom operators" custom_code_offset = builder.CreateString("ethos-u") - self.operator_code_map[op_type] = (idx, tf_code, opt_serializer) + # there can be multiple different types of 3rd party custom operators (i.e. non-"ethos-u" ones). therefore we + # need to add an extra level of indirection to this particular entry in the operator_code_map to allow for the + # correct lookup later on + if op_type == Op.Custom: + if op_type not in self.operator_code_map: + self.operator_code_map[op_type] = {} + self.operator_code_map[op_type][custom_code] = (idx, tf_code, opt_serializer) + else: + self.operator_code_map[op_type] = (idx, tf_code, opt_serializer) OperatorCode.OperatorCodeStart(builder) OperatorCode.OperatorCodeAddBuiltinCode(builder, tf_code) @@ -262,7 +270,10 @@ class TFLiteSerialiser: [self.tensor_map[tens] for tens in op.outputs if tens in self.tensor_map] ) - op_idx, tflop, opt_serializer = self.operator_code_map[op.type] + if op.type == Op.Custom: + op_idx, tflop, opt_serializer = self.operator_code_map[op.type][op.attrs.get("custom_code", "")] + else: + op_idx, tflop, opt_serializer = self.operator_code_map[op.type] builtin_opt_offset = None custom_opt_offset = None |