aboutsummaryrefslogtreecommitdiff
path: root/ethosu/vela/tflite_writer.py
diff options
context:
space:
mode:
authorTim Hall <tim.hall@arm.com>2020-06-17 14:53:11 +0100
committerTim Hall <tim.hall@arm.com>2020-06-18 17:53:52 +0100
commitc8310b1432f7a77df3c95e8ecf8248c8a953b411 (patch)
treeeaddfe6ae80db3c85ddca824e0fc70739d05a9d5 /ethosu/vela/tflite_writer.py
parent10a6618784aae35de389e0291fd2d78cbfa03bb7 (diff)
downloadethos-u-vela-c8310b1432f7a77df3c95e8ecf8248c8a953b411.tar.gz
MLBEDSW-2528: MLCE-219: Custom operator pass through
- Fixed custom operator pass through - Added error printing functions for operators and tensor - Minor cleanup of custom exception handling Signed-off-by: Tim Hall <tim.hall@arm.com> Change-Id: Idf295df1e4c544381dc480244d880c32fb285e38
Diffstat (limited to 'ethosu/vela/tflite_writer.py')
-rw-r--r--ethosu/vela/tflite_writer.py35
1 files changed, 22 insertions, 13 deletions
diff --git a/ethosu/vela/tflite_writer.py b/ethosu/vela/tflite_writer.py
index 675b6985..8db3e5b8 100644
--- a/ethosu/vela/tflite_writer.py
+++ b/ethosu/vela/tflite_writer.py
@@ -133,10 +133,9 @@ class TFLiteSerialiser:
builder.PrependUOffsetTRelative(e)
return builder.EndVector(len(v))
- def assign_buffers_to_tensors(self, tensors):
- scratch_tensors = [tens for tens in tensors if tens.purpose == TensorPurpose.Scratch]
- if len(scratch_tensors) > 0:
- scratch_tensor_mem_area = scratch_tensors[0].mem_area
+ def assign_buffers_to_tensors(self, tensors, scratch_tensor):
+ if scratch_tensor is not None:
+ scratch_tensor_mem_area = scratch_tensor.mem_area
else:
scratch_tensor_mem_area = None # all tensors are initialised to MemArea.Unknown
@@ -150,7 +149,7 @@ class TFLiteSerialiser:
buffer_map[tens] = buf_idx
buf_idx += 1
- # Initialize buffers_to_write to a length equal to numer of buffers so
+ # Initialize buffers_to_write to a length equal to number of buffers so
# they can be appended at the correct index during tensor serialization
self.buffers_to_write = [None] * (buf_idx)
@@ -176,7 +175,7 @@ class TFLiteSerialiser:
assert code == "NpuOp" # Currently only support serialising NPU operators as a custom op
custom_code_offset = builder.CreateString("ethos-u")
- self.operator_code_map[code] = (idx, tf_code, opt_serializer)
+ self.operator_code_map[code] = (idx, tf_code, opt_serializer)
OperatorCode.OperatorCodeStart(builder)
OperatorCode.OperatorCodeAddBuiltinCode(builder, tf_code)
@@ -311,19 +310,29 @@ class TFLiteSerialiser:
all_tensors = [tens for nm, idx, tens in sorted((tens.name, idx, tens) for idx, tens in enumerate(tensor_set))]
+ scratch_tensors = [tens for tens in all_tensors if tens.purpose == TensorPurpose.Scratch]
+
+ if len(scratch_tensors) == 0:
+ scratch_tensor = None
+ else:
+ assert len(scratch_tensors) == 1, "Multiple scratch tensors"
+ scratch_tensor = scratch_tensors[0]
+
self.tensor_map = {tens: idx for idx, tens in enumerate(all_tensors)}
- self.buffer_map = self.assign_buffers_to_tensors(all_tensors)
+ self.buffer_map = self.assign_buffers_to_tensors(all_tensors, scratch_tensor)
tensors_offset = self.write_offset_vector([self.serialise_tensor(tens) for tens in all_tensors])
- # Add the Scratch Tensor as input to the NPU subgraph to get it allocated by TensorFlow Lite Micro
- scratch_tensor_idx = [v for k, v in self.tensor_map.items() if k.name.endswith("scratch")]
-
# Make sure the input_tensors haven't been modified
assert all(inp in sg.original_inputs for inp in sg.input_tensors)
- inputs_offset = self.write_int_vector(
- [self.tensor_map[tens] for tens in sg.original_inputs] + scratch_tensor_idx
- )
+ inputs = [self.tensor_map[tens] for tens in sg.original_inputs]
+
+ # Add the Scratch Tensor as input to the NPU subgraph to get it allocated by TensorFlow Lite Micro
+ scratch_tensor_idx = self.tensor_map.get(scratch_tensor, None)
+ if scratch_tensor_idx is not None and scratch_tensor_idx not in inputs:
+ inputs.append(scratch_tensor_idx)
+
+ inputs_offset = self.write_int_vector(inputs)
outputs_offset = self.write_int_vector([self.tensor_map[tens] for tens in sg.output_tensors])
operators_offset = self.write_offset_vector([self.serialise_operator(op) for op in all_ops])