about | summary | refs | log | tree | commit | diff
diff options
context:
space:
mode:
author Tim Hall <tim.hall@arm.com> 2020-06-17 14:53:11 +0100
committer Tim Hall <tim.hall@arm.com> 2020-06-18 17:53:52 +0100
commit c8310b1432f7a77df3c95e8ecf8248c8a953b411 (patch)
tree eaddfe6ae80db3c85ddca824e0fc70739d05a9d5
parent 10a6618784aae35de389e0291fd2d78cbfa03bb7 (diff)
download ethos-u-vela-c8310b1432f7a77df3c95e8ecf8248c8a953b411.tar.gz
MLBEDSW-2528: MLCE-219: Custom operator pass through
- Fixed custom operator pass through
- Added error printing functions for operators and tensor
- Minor cleanup of custom exception handling

Signed-off-by: Tim Hall <tim.hall@arm.com>
Change-Id: Idf295df1e4c544381dc480244d880c32fb285e38
-rw-r--r-- ethosu/vela/errors.py 79
-rw-r--r-- ethosu/vela/mark_tensors.py 21
-rw-r--r-- ethosu/vela/model_reader.py 22
-rw-r--r-- ethosu/vela/npu_serialisation.py 2
-rw-r--r-- ethosu/vela/operation.py 4
-rw-r--r-- ethosu/vela/tensor.py 1
-rw-r--r-- ethosu/vela/test/test_model_reader.py 10
-rw-r--r-- ethosu/vela/tflite_mapping.py 37
-rw-r--r-- ethosu/vela/tflite_reader.py 31
-rw-r--r-- ethosu/vela/tflite_writer.py 35
-rw-r--r-- ethosu/vela/weight_compressor.py 9
11 files changed, 180 insertions, 71 deletions
diff --git a/ethosu/vela/errors.py b/ethosu/vela/errors.py
index efe64d5c..2c93fbc6 100644
--- a/ethosu/vela/errors.py
+++ b/ethosu/vela/errors.py
@@ -15,6 +15,10 @@
# limitations under the License.
# Description:
# Defines custom exceptions.
+import sys
+
+from .operation import Operation
+from .tensor import Tensor
class VelaError(Exception):
@@ -31,7 +35,7 @@ class InputFileError(VelaError):
"""Raised when reading the input file results in errors"""
def __init__(self, file_name, msg):
- self.data = "Error reading {}: {}".format(file_name, msg)
+ self.data = "Error reading input file {}: {}".format(file_name, msg)
class UnsupportedFeatureError(VelaError):
@@ -45,4 +49,75 @@ class OptionError(VelaError):
"""Raised when an incorrect command line option is used"""
def __init__(self, option, option_value, msg):
- self.data = "Incorrect argument: {} {}: {}".format(option, option_value, msg)
+ self.data = "Incorrect argument to CLI option: {} {}: {}".format(option, option_value, msg)
+
+
+def OperatorError(op, msg):
+ """Called when parsing an operator results in errors"""
+
+ assert isinstance(op, Operation)
+
+ if op.op_index is None:
+ data = "Invalid {} (name = {}) operator in the internal representation.".format(op.type, op.name)
+ else:
+ data = "Invalid {} (op_index = {}) operator in the input network.".format(op.type, op.op_index)
+
+ data += " {}\n".format(msg)
+
+ data += " Input tensors:\n"
+ for idx, tens in enumerate(op.inputs):
+ if isinstance(tens, Tensor):
+ tens_name = tens.name
+ else:
+ tens_name = "Not a Tensor"
+
+ data += " {} = {}\n".format(idx, tens_name)
+
+ data += " Output tensors:\n"
+ for idx, tens in enumerate(op.outputs):
+ if isinstance(tens, Tensor):
+ tens_name = tens.name
+ else:
+ tens_name = "Not a Tensor"
+
+ data += " {} = {}\n".format(idx, tens_name)
+
+ data = data[:-1] # remove last newline
+
+ print("Error: {}".format(data))
+ sys.exit(1)
+
+
+def TensorError(tens, msg):
+ """Called when parsing a tensor results in errors"""
+
+ assert isinstance(tens, Tensor)
+
+ data = "Invalid {} tensor. {}\n".format(tens.name, msg)
+
+ data += " Driving operators:\n"
+ for idx, op in enumerate(tens.ops):
+ if isinstance(op, Operation):
+ op_type = op.type
+ op_id = op.op_index
+ else:
+ op_type = "Not an Operation"
+ op_id = ""
+
+ data += " {} = {} ({})\n".format(idx, op_type, op_id)
+
+ data += " Consuming operators:\n"
+ for idx, op in enumerate(tens.consumer_list):
+ if isinstance(op, Operation):
+ op_type = op.type
+ op_id = op.op_index
+ else:
+ op_type = "Not an Operation"
+ op_id = ""
+
+ data += " {} = {} ({})\n".format(idx, op_type, op_id)
+
+ data = data[:-1] # remove last newline
+
+ print("Error: {}".format(data))
+ sys.exit(1)
diff --git a/ethosu/vela/mark_tensors.py b/ethosu/vela/mark_tensors.py
index 72ab8cfa..c4f2bae2 100644
--- a/ethosu/vela/mark_tensors.py
+++ b/ethosu/vela/mark_tensors.py
@@ -17,8 +17,10 @@
# Mark purpose and select formats for Tensors. Also compresses the weights.
from . import rewrite_graph
from . import weight_compressor
+from .errors import OperatorError
from .tensor import TensorFormat
from .tensor import TensorPurpose
+from .tflite_mapping import custom_prefix
def purpose_from_list(lst):
@@ -268,18 +270,33 @@ def mark_tensor_purpose(nng, arch, verbose_tensor_purpose=False):
if ops is None or op.type in ops:
if ops is None:
print(
- "warning: don't know how to mark up purpose for",
+ "Warning: Don't know how to mark up purpose for",
op.type,
op.inputs,
"triggering all feature map fallback",
)
+
for idx, tens in enumerate(op.inputs):
purpose = input_purpose(op, idx)
mark_tensor_helper(tens, purpose)
+
if op.type == "Reshape":
# Reshape's input and output point to same data
op.outputs[0].mem_area = op.inputs[0].mem_area
+
+ if op.type.startswith(custom_prefix) and op.attrs.get("custom_type", "") == "ExistingNpuOp":
+ scratch_tensor = None
+
+ if len(op.inputs) >= 3:
+ scratch_tensor = op.inputs[2] # should be existing scratch tensor
+ if scratch_tensor.name.endswith("_scratch"):
+ scratch_tensor.purpose = TensorPurpose.Scratch
+
+ if scratch_tensor is None:
+ raise OperatorError(op, "Scratch tensor not found.")
+
break
+
return op
for sg in nng.subgraphs:
@@ -316,6 +333,8 @@ def mark_tensor_format(nng, arch, verbose_tensor_format=False):
fmt = arch.default_feature_map_format
elif tens.purpose == TensorPurpose.Weights:
fmt = arch.default_weight_format
+ elif tens.purpose == TensorPurpose.Scratch:
+ fmt = arch.default_feature_map_format
elif tens.purpose == TensorPurpose.Unknown:
fmt = TensorFormat.Unknown
else:
diff --git a/ethosu/vela/model_reader.py b/ethosu/vela/model_reader.py
index 6deb2538..0f79f9b2 100644
--- a/ethosu/vela/model_reader.py
+++ b/ethosu/vela/model_reader.py
@@ -17,7 +17,6 @@
# Dispatcher for reading a neural network model.
from . import tflite_reader
from .errors import InputFileError
-from .errors import VelaError
class ModelReaderOptions:
@@ -32,17 +31,12 @@ class ModelReaderOptions:
def read_model(fname, options, feed_dict={}, output_node_names=[], initialisation_nodes=[]):
if fname.endswith(".tflite"):
- try:
- return tflite_reader.read_tflite(
- fname,
- options.batch_size,
- feed_dict=feed_dict,
- output_node_names=output_node_names,
- initialisation_nodes=initialisation_nodes,
- )
- except VelaError as e:
- raise e
- except Exception as e:
- raise InputFileError(fname, str(e))
+ return tflite_reader.read_tflite(
+ fname,
+ options.batch_size,
+ feed_dict=feed_dict,
+ output_node_names=output_node_names,
+ initialisation_nodes=initialisation_nodes,
+ )
else:
- raise InputFileError(fname, "Unknown input file format. Only .tflite files are supported")
+ raise InputFileError(fname, "Unsupported file extension. Only .tflite files are supported")
diff --git a/ethosu/vela/npu_serialisation.py b/ethosu/vela/npu_serialisation.py
index 08dc0d38..18d38f3f 100644
--- a/ethosu/vela/npu_serialisation.py
+++ b/ethosu/vela/npu_serialisation.py
@@ -141,7 +141,7 @@ def rewrite_npu_call_ops(nng, sg, arch):
for op in ps.ops:
if op.type == "NpuOp":
callee = op.attrs["subgraph"]
- op.attrs["custom_options"] = {"type": op.type}
+ op.attrs["custom_type"] = op.type
sz = 0
for tens in [callee.scratch_tensor, callee.flash_tensor, callee.command_stream_tensor]:
diff --git a/ethosu/vela/operation.py b/ethosu/vela/operation.py
index 51311ef7..448d8382 100644
--- a/ethosu/vela/operation.py
+++ b/ethosu/vela/operation.py
@@ -31,7 +31,7 @@ class Operation:
"""Class representing a Neural Network operation. Has a name, a type,
input and output tensors, as well as an attribute dictionary."""
- __slots__ = "type", "name", "attrs", "inputs", "outputs", "flops", "scheduled_pass", "run_on_npu"
+ __slots__ = "type", "name", "op_index", "attrs", "inputs", "outputs", "flops", "scheduled_pass", "run_on_npu"
def __init__(self, op_type, name):
self.type = op_type
@@ -42,6 +42,7 @@ input and output tensors, as well as an attribute dictionary."""
self.flops = 0
self.run_on_npu = True
self.scheduled_pass = None
+ self.op_index = None # input network operator index
def clone(self, suffix="_clone"):
res = Operation(self.type, self.name + suffix)
@@ -51,6 +52,7 @@ input and output tensors, as well as an attribute dictionary."""
res.outputs = list(self.outputs)
res.flops = self.flops
res.scheduled_pass = self.scheduled_pass
+ res.op_index = None # not relevant as not part of input network
return res
diff --git a/ethosu/vela/tensor.py b/ethosu/vela/tensor.py
index 426a710b..42d95262 100644
--- a/ethosu/vela/tensor.py
+++ b/ethosu/vela/tensor.py
@@ -226,7 +226,6 @@ class Tensor:
"weight_compressed_offsets",
"element_size_bytes",
"block_traversal",
- "offset",
"cpu_tensor",
"npu_tensor",
"equivalence_id",
diff --git a/ethosu/vela/test/test_model_reader.py b/ethosu/vela/test/test_model_reader.py
index ee9a51e8..23e7e90b 100644
--- a/ethosu/vela/test/test_model_reader.py
+++ b/ethosu/vela/test/test_model_reader.py
@@ -26,15 +26,7 @@ def test_read_model_incorrect_extension(tmpdir):
model_reader.read_model("no_tflite_file.txt", model_reader.ModelReaderOptions())
-def test_read_model_corrupt_contents(tmpdir):
- # Tests read_model with a corrupt .tflite file
- fname = tmpdir.join("corrupt.tflite")
- fname.write("abcde1234")
- with pytest.raises(InputFileError):
- model_reader.read_model(fname.strpath, model_reader.ModelReaderOptions())
-
-
def test_read_model_file_not_found(tmpdir):
# Tests read_model with a .tflite file that does not exist
- with pytest.raises(InputFileError):
+ with pytest.raises(FileNotFoundError):
model_reader.read_model("non_existing.tflite", model_reader.ModelReaderOptions())
diff --git a/ethosu/vela/tflite_mapping.py b/ethosu/vela/tflite_mapping.py
index d077768c..79521680 100644
--- a/ethosu/vela/tflite_mapping.py
+++ b/ethosu/vela/tflite_mapping.py
@@ -328,7 +328,6 @@ class OptionsSerializer:
self.module = globals()[self.name]
self.cls = getattr(self.module, self.name)
self.builtin_opt_type = builtin_options_inv_map[self.cls]
- self.custom_opt_format = 0
self.members = []
for mem in members:
deserialize = identity
@@ -347,11 +346,12 @@ class OptionsSerializer:
camelcase_mem = underscore_to_camel_case(mem)
self.members.append((underscore_mem, camelcase_mem, deserialize, serialize, is_vector))
- def deserialize(self, builtin_data, custom_data):
+ def deserialize(self, op_data):
+ builtin_options = op_data.BuiltinOptions()
attrs = {}
- if builtin_data:
+ if builtin_options:
tfattrs = self.cls()
- tfattrs.Init(builtin_data.Bytes, builtin_data.Pos)
+ tfattrs.Init(builtin_options.Bytes, builtin_options.Pos)
for underscore_mem, camelcase_mem, deserialize, serialize, is_vector in self.members:
fun = camelcase_mem
if is_vector:
@@ -376,26 +376,35 @@ class OptionsSerializer:
class CustomOptionsSerializer:
+ CUSTOM_OPTIONS_NPU_OP = [0x01, 0x04, 0x01] # NpuOp=1, FlexbufferFormat.UINT8=4, byte length=1
+ CUSTOM_OPTIONS_FORMAT_DEFAULT = 0
+
def __init__(self):
- self.builtin_opt_type = 0
self.custom_opt_format = 0
- def deserialize(self, builtin_data, custom_data):
+ def deserialize(self, op_data):
attrs = {}
- attrs["custom_options"] = custom_data
+ custom_options = op_data.CustomOptionsAsNumpy()
+ attrs["custom_options"] = custom_options
+ attrs["custom_options_format"] = op_data.CustomOptionsFormat()
+
+ if np.array_equal(custom_options, self.CUSTOM_OPTIONS_NPU_OP):
+ attrs["custom_type"] = "ExistingNpuOp"
+
return attrs
def serialize(self, builder, attrs):
-
- custom_opts = attrs.get("custom_options", [])
- custom_data = []
+ custom_type = attrs.get("custom_type", "")
+ self.custom_opt_format = attrs.get("custom_options_format", self.CUSTOM_OPTIONS_FORMAT_DEFAULT)
# Set NPU op custom options for the TensorFlow Lite custom operator
- if custom_opts["type"] == "NpuOp":
- custom_data = [0x01, 0x04, 0x01] # NpuOp=1, FlexbufferFormat.UINT8=4, byte length=1
+ if custom_type == "NpuOp":
+ custom_options = self.CUSTOM_OPTIONS_NPU_OP
+ else:
+ custom_options = attrs.get("custom_options", [])
- custom_data_bytes = struct.pack("<{0}B".format(len(custom_data)), *custom_data)
- custom_offset = write_byte_vector(builder, custom_data_bytes)
+ custom_options_bytes = struct.pack("<{0}B".format(len(custom_options)), *custom_options)
+ custom_offset = write_byte_vector(builder, custom_options_bytes)
return None, custom_offset
diff --git a/ethosu/vela/tflite_reader.py b/ethosu/vela/tflite_reader.py
index 5667aff5..9d312e52 100644
--- a/ethosu/vela/tflite_reader.py
+++ b/ethosu/vela/tflite_reader.py
@@ -20,6 +20,7 @@ import os.path
import numpy as np
from .errors import InputFileError
+from .errors import TensorError
from .nn_graph import Graph
from .nn_graph import Subgraph
from .operation import Operation
@@ -69,14 +70,16 @@ class TFLiteSubgraph:
self.tensors.append(self.parse_tensor(subgraph.Tensors(idx)))
for idx in range(subgraph.OperatorsLength()):
- self.parse_operator(subgraph.Operators(idx))
+ self.parse_operator(idx, subgraph.Operators(idx))
- self.outputs = [self.tensors[idx] for idx in subgraph.OutputsAsNumpy()]
- self.inputs = [self.tensors[idx] for idx in subgraph.InputsAsNumpy()]
+ self.outputs = self.get_tensors_from_indices_remove_duplicates(subgraph.OutputsAsNumpy(), "output")
+ self.inputs = self.get_tensors_from_indices_remove_duplicates(subgraph.InputsAsNumpy(), "input")
# Fix up tensors without operations. Generate either Placeholder or Constant ops
for tens in self.inputs:
- assert not tens.ops
+ if tens.ops != []:
+ TensorError(tens, "This subgraph input tensor has unexpected driving operators.")
+
op = Operation("Placeholder", tens.name)
op.outputs = [tens]
tens.ops = [op]
@@ -87,6 +90,21 @@ class TFLiteSubgraph:
op.outputs = [tens]
tens.ops = [op]
+ def get_tensors_from_indices_remove_duplicates(self, indices, warning_str):
+ tensors = []
+ for idx in indices:
+ tensor = self.tensors[idx]
+ if tensor not in tensors:
+ tensors.append(tensor)
+ else:
+ print(
+ "Warning: Subgraph {0} tensor ({1}) with idx = {2} already seen. Removing the duplicate.".format(
+ warning_str, tensor, idx
+ )
+ )
+
+ return tensors
+
def parse_tensor(self, tens_data):
np_shape = tens_data.ShapeAsNumpy()
shape = list(np_shape) if type(np_shape) is np.ndarray else []
@@ -121,7 +139,7 @@ class TFLiteSubgraph:
tens.values = tens.quantization.dequantize(tens.quant_values)
return tens
- def parse_operator(self, op_data):
+ def parse_operator(self, op_index, op_data):
op_type, opt_serializer = self.graph.operator_codes[op_data.OpcodeIndex()]
inputs = [self.tensors[idx] for idx in op_data.InputsAsNumpy()]
outputs = [self.tensors[idx] for idx in op_data.OutputsAsNumpy()]
@@ -129,6 +147,7 @@ class TFLiteSubgraph:
if len(outputs):
name = outputs[0].name
op = Operation(op_type, name)
+ op.op_index = op_index
op.inputs = inputs
op.outputs = outputs
for out in op.outputs:
@@ -143,7 +162,7 @@ class TFLiteSubgraph:
inputs[1] = clone_and_reshape_tensor(inputs[1], (1, 0))
if opt_serializer is not None:
- op.attrs = opt_serializer.deserialize(op_data.BuiltinOptions(), op_data.CustomOptionsAsNumpy())
+ op.attrs = opt_serializer.deserialize(op_data)
if "stride_w" in op.attrs:
op.attrs["strides"] = (1, op.attrs["stride_h"], op.attrs["stride_w"], 1)
diff --git a/ethosu/vela/tflite_writer.py b/ethosu/vela/tflite_writer.py
index 675b6985..8db3e5b8 100644
--- a/ethosu/vela/tflite_writer.py
+++ b/ethosu/vela/tflite_writer.py
@@ -133,10 +133,9 @@ class TFLiteSerialiser:
builder.PrependUOffsetTRelative(e)
return builder.EndVector(len(v))
- def assign_buffers_to_tensors(self, tensors):
- scratch_tensors = [tens for tens in tensors if tens.purpose == TensorPurpose.Scratch]
- if len(scratch_tensors) > 0:
- scratch_tensor_mem_area = scratch_tensors[0].mem_area
+ def assign_buffers_to_tensors(self, tensors, scratch_tensor):
+ if scratch_tensor is not None:
+ scratch_tensor_mem_area = scratch_tensor.mem_area
else:
scratch_tensor_mem_area = None # all tensors are initialised to MemArea.Unknown
@@ -150,7 +149,7 @@ class TFLiteSerialiser:
buffer_map[tens] = buf_idx
buf_idx += 1
- # Initialize buffers_to_write to a length equal to numer of buffers so
+ # Initialize buffers_to_write to a length equal to number of buffers so
# they can be appended at the correct index during tensor serialization
self.buffers_to_write = [None] * (buf_idx)
@@ -176,7 +175,7 @@ class TFLiteSerialiser:
assert code == "NpuOp" # Currently only support serialising NPU operators as a custom op
custom_code_offset = builder.CreateString("ethos-u")
- self.operator_code_map[code] = (idx, tf_code, opt_serializer)
+ self.operator_code_map[code] = (idx, tf_code, opt_serializer)
OperatorCode.OperatorCodeStart(builder)
OperatorCode.OperatorCodeAddBuiltinCode(builder, tf_code)
@@ -311,19 +310,29 @@ class TFLiteSerialiser:
all_tensors = [tens for nm, idx, tens in sorted((tens.name, idx, tens) for idx, tens in enumerate(tensor_set))]
+ scratch_tensors = [tens for tens in all_tensors if tens.purpose == TensorPurpose.Scratch]
+
+ if len(scratch_tensors) == 0:
+ scratch_tensor = None
+ else:
+ assert len(scratch_tensors) == 1, "Multiple scratch tensors"
+ scratch_tensor = scratch_tensors[0]
+
self.tensor_map = {tens: idx for idx, tens in enumerate(all_tensors)}
- self.buffer_map = self.assign_buffers_to_tensors(all_tensors)
+ self.buffer_map = self.assign_buffers_to_tensors(all_tensors, scratch_tensor)
tensors_offset = self.write_offset_vector([self.serialise_tensor(tens) for tens in all_tensors])
- # Add the Scratch Tensor as input to the NPU subgraph to get it allocated by TensorFlow Lite Micro
- scratch_tensor_idx = [v for k, v in self.tensor_map.items() if k.name.endswith("scratch")]
-
# Make sure the input_tensors haven't been modified
assert all(inp in sg.original_inputs for inp in sg.input_tensors)
- inputs_offset = self.write_int_vector(
- [self.tensor_map[tens] for tens in sg.original_inputs] + scratch_tensor_idx
- )
+ inputs = [self.tensor_map[tens] for tens in sg.original_inputs]
+
+ # Add the Scratch Tensor as input to the NPU subgraph to get it allocated by TensorFlow Lite Micro
+ scratch_tensor_idx = self.tensor_map.get(scratch_tensor, None)
+ if scratch_tensor_idx is not None and scratch_tensor_idx not in inputs:
+ inputs.append(scratch_tensor_idx)
+
+ inputs_offset = self.write_int_vector(inputs)
outputs_offset = self.write_int_vector([self.tensor_map[tens] for tens in sg.output_tensors])
operators_offset = self.write_offset_vector([self.serialise_operator(op) for op in all_ops])
diff --git a/ethosu/vela/weight_compressor.py b/ethosu/vela/weight_compressor.py
index c5f4ce10..77220a93 100644
--- a/ethosu/vela/weight_compressor.py
+++ b/ethosu/vela/weight_compressor.py
@@ -381,15 +381,6 @@ def calc_scales_and_pack_biases(tens, arch, oc_quantum, rescale_for_faf=False):
def update_pass_weight_and_scale_tensors(nng, arch):
- def find_npu_usage_of_tensor(tens):
- # TODO: This function is identical to the one in mark_tensors.py. A common version should be used.
- for op in tens.consumers():
- if op.type == "DMA":
- return find_npu_usage_of_tensor(op.outputs[0])
- if "npu_block_type" in op.attrs:
- return op.attrs["npu_block_type"]
- return NpuBlockType.Default
-
for sg in nng.subgraphs:
for ps in sg.passes:
tens = ps.weight_tensor