aboutsummaryrefslogtreecommitdiff
path: root/ethosu/vela/tflite_writer.py
diff options
context:
space:
mode:
Diffstat (limited to 'ethosu/vela/tflite_writer.py')
-rw-r--r--ethosu/vela/tflite_writer.py19
1 files changed, 15 insertions, 4 deletions
diff --git a/ethosu/vela/tflite_writer.py b/ethosu/vela/tflite_writer.py
index 06026ba5..18905e35 100644
--- a/ethosu/vela/tflite_writer.py
+++ b/ethosu/vela/tflite_writer.py
@@ -176,13 +176,21 @@ class TFLiteSerialiser:
), "Vela does not contain a mapping to serialise {} operator to a TensorFlow Lite operator".format(op_type)
tf_code, opt_serializer = builtin_operator_inv_map[op_type]
- if tf_code == BuiltinOperator.CUSTOM:
+ if op_type == Op.CustomNpuOp:
assert (
- op_type == Op.CustomNpuOp
+ tf_code == BuiltinOperator.CUSTOM
), "Vela only supports serialising NpuOp operators as TensorFlow Lite Custom operators"
custom_code_offset = builder.CreateString("ethos-u")
- self.operator_code_map[op_type] = (idx, tf_code, opt_serializer)
+ # there can be multiple different types of 3rd party custom operators (i.e. non-"ethos-u" ones). therefore we
+ # need to add an extra level of indirection to this particular entry in the operator_code_map to allow for the
+ # correct lookup later on
+ if op_type == Op.Custom:
+ if op_type not in self.operator_code_map:
+ self.operator_code_map[op_type] = {}
+ self.operator_code_map[op_type][custom_code] = (idx, tf_code, opt_serializer)
+ else:
+ self.operator_code_map[op_type] = (idx, tf_code, opt_serializer)
OperatorCode.OperatorCodeStart(builder)
OperatorCode.OperatorCodeAddBuiltinCode(builder, tf_code)
@@ -262,7 +270,10 @@ class TFLiteSerialiser:
[self.tensor_map[tens] for tens in op.outputs if tens in self.tensor_map]
)
- op_idx, tflop, opt_serializer = self.operator_code_map[op.type]
+ if op.type == Op.Custom:
+ op_idx, tflop, opt_serializer = self.operator_code_map[op.type][op.attrs.get("custom_code", "")]
+ else:
+ op_idx, tflop, opt_serializer = self.operator_code_map[op.type]
builtin_opt_offset = None
custom_opt_offset = None