diff options
Diffstat (limited to 'ethosu/vela/tflite_mapping.py')
-rw-r--r-- | ethosu/vela/tflite_mapping.py | 37 |
1 files changed, 23 insertions, 14 deletions
diff --git a/ethosu/vela/tflite_mapping.py b/ethosu/vela/tflite_mapping.py index d077768c..79521680 100644 --- a/ethosu/vela/tflite_mapping.py +++ b/ethosu/vela/tflite_mapping.py @@ -328,7 +328,6 @@ class OptionsSerializer: self.module = globals()[self.name] self.cls = getattr(self.module, self.name) self.builtin_opt_type = builtin_options_inv_map[self.cls] - self.custom_opt_format = 0 self.members = [] for mem in members: deserialize = identity @@ -347,11 +346,12 @@ class OptionsSerializer: camelcase_mem = underscore_to_camel_case(mem) self.members.append((underscore_mem, camelcase_mem, deserialize, serialize, is_vector)) - def deserialize(self, builtin_data, custom_data): + def deserialize(self, op_data): + builtin_options = op_data.BuiltinOptions() attrs = {} - if builtin_data: + if builtin_options: tfattrs = self.cls() - tfattrs.Init(builtin_data.Bytes, builtin_data.Pos) + tfattrs.Init(builtin_options.Bytes, builtin_options.Pos) for underscore_mem, camelcase_mem, deserialize, serialize, is_vector in self.members: fun = camelcase_mem if is_vector: @@ -376,26 +376,35 @@ class OptionsSerializer: class CustomOptionsSerializer: + CUSTOM_OPTIONS_NPU_OP = [0x01, 0x04, 0x01] # NpuOp=1, FlexbufferFormat.UINT8=4, byte length=1 + CUSTOM_OPTIONS_FORMAT_DEFAULT = 0 + def __init__(self): - self.builtin_opt_type = 0 self.custom_opt_format = 0 - def deserialize(self, builtin_data, custom_data): + def deserialize(self, op_data): attrs = {} - attrs["custom_options"] = custom_data + custom_options = op_data.CustomOptionsAsNumpy() + attrs["custom_options"] = custom_options + attrs["custom_options_format"] = op_data.CustomOptionsFormat() + + if np.array_equal(custom_options, self.CUSTOM_OPTIONS_NPU_OP): + attrs["custom_type"] = "ExistingNpuOp" + return attrs def serialize(self, builder, attrs): - - custom_opts = attrs.get("custom_options", []) - custom_data = [] + custom_type = attrs.get("custom_type", "") + self.custom_opt_format = attrs.get("custom_options_format", self.CUSTOM_OPTIONS_FORMAT_DEFAULT) # Set NPU op custom options for the TensorFlow Lite custom operator - if custom_opts["type"] == "NpuOp": - custom_data = [0x01, 0x04, 0x01] # NpuOp=1, FlexbufferFormat.UINT8=4, byte length=1 + if custom_type == "NpuOp": + custom_options = self.CUSTOM_OPTIONS_NPU_OP + else: + custom_options = attrs.get("custom_options", []) - custom_data_bytes = struct.pack("<{0}B".format(len(custom_data)), *custom_data) - custom_offset = write_byte_vector(builder, custom_data_bytes) + custom_options_bytes = struct.pack("<{0}B".format(len(custom_options)), *custom_options) + custom_offset = write_byte_vector(builder, custom_options_bytes) return None, custom_offset |