diff options
author | Louis Verhaard <louis.verhaard@arm.com> | 2020-09-30 09:01:52 +0200 |
---|---|---|
committer | Louis Verhaard <louis.verhaard@arm.com> | 2020-10-08 16:29:29 +0200 |
commit | aee5d7537ff81ffda5ba222721b72f914ce50fb8 (patch) | |
tree | 495b9dfff2a188c6916f8ca2e390ee88f7da8ccc /ethosu/vela/tflite_reader.py | |
parent | 36ad73a0fb46d3f844845c97c56d92de2a7a9b3d (diff) | |
download | ethos-u-vela-aee5d7537ff81ffda5ba222721b72f914ce50fb8.tar.gz |
MLBEDSW-3148: Refactor Operation
- op.type is now an enum instead of a string
- Removed unused operator codes
- Refactored some attributes like npu_block_type, fused_activation_function
- Refactored operator index calculation
- Refactored a number of operator sets
Change-Id: I641f65ee375794b7aec42abc0664251ae37d78e8
Signed-off-by: Louis Verhaard <louis.verhaard@arm.com>
Diffstat (limited to 'ethosu/vela/tflite_reader.py')
-rw-r--r-- | ethosu/vela/tflite_reader.py | 39 |
1 files changed, 19 insertions, 20 deletions
diff --git a/ethosu/vela/tflite_reader.py b/ethosu/vela/tflite_reader.py index 77cc7963..a03f9ec2 100644 --- a/ethosu/vela/tflite_reader.py +++ b/ethosu/vela/tflite_reader.py @@ -23,6 +23,7 @@ from .errors import InputFileError from .errors import TensorError from .nn_graph import Graph from .nn_graph import Subgraph +from .operation import Op from .operation import Operation from .tensor import QuantizationParameters from .tensor import Tensor @@ -53,7 +54,7 @@ def clone_and_reshape_tensor(src_tens, reorder): if tens.quant_values is not None: tens.quant_values = tens.quant_values.transpose(reorder) - op = Operation("Const", tens.name) + op = Operation(Op.Const, tens.name) op.set_output_tensor(tens) return tens @@ -78,12 +79,12 @@ class TFLiteSubgraph: if tens.ops != []: TensorError(tens, "This subgraph input tensor has unexpected driving operators.") - op = Operation("Placeholder", tens.name) + op = Operation(Op.Placeholder, tens.name) op.set_output_tensor(tens) for tens in self.tensors: if not tens.ops: - op = Operation("Const", tens.name) + op = Operation(Op.Const, tens.name) op.set_output_tensor(tens) def get_tensors_from_indices_remove_duplicates(self, indices, warning_str): @@ -136,7 +137,7 @@ class TFLiteSubgraph: return tens def parse_operator(self, op_index, op_data): - op_type, opt_serializer = self.graph.operator_codes[op_data.OpcodeIndex()] + op_type, opt_serializer, custom_code = self.graph.operator_codes[op_data.OpcodeIndex()] inputs = [self.tensors[idx] if idx != -1 else None for idx in op_data.InputsAsNumpy()] outputs = [self.tensors[idx] if idx != -1 else None for idx in op_data.OutputsAsNumpy()] name = "unknown_op_name" @@ -149,19 +150,13 @@ class TFLiteSubgraph: for out in op.outputs: out.ops = [op] - if op_type.startswith("DepthwiseConv2d") or op_type.startswith("Conv2D"): + if op.type.is_depthwise_conv2d_op() or op.type.is_conv2d_op() or op.type == Op.FullyConnected: if inputs[1].values is not None: - inputs[1] = clone_and_reshape_tensor(inputs[1], (1, 2, 3, 0)) - if len(inputs) < 3 or (len(inputs) < 4 and "Backprop" in op_type): - # No Bias tensor - inputs.append(None) - if inputs[-1]: - inputs[-1] = clone_and_reshape_tensor(inputs[-1], (0,)) - - if op_type.startswith("FullyConnected"): - if inputs[1].values is not None: - inputs[1] = clone_and_reshape_tensor(inputs[1], (1, 0)) - if len(inputs) < 3: + if op.type == Op.FullyConnected: + inputs[1] = clone_and_reshape_tensor(inputs[1], (1, 0)) + else: + inputs[1] = clone_and_reshape_tensor(inputs[1], (1, 2, 3, 0)) + if op.type.needs_bias() and len(inputs) <= op_type.info.indices.biases[0]: # No Bias tensor inputs.append(None) if inputs[-1]: @@ -170,11 +165,11 @@ class TFLiteSubgraph: if opt_serializer is not None: op.attrs = opt_serializer.deserialize(op_data) - if op_type == "Reshape" and "new_shape" not in op.attrs: + if op_type == Op.Reshape and "new_shape" not in op.attrs: # Reshape should have an attrib "new_shape" but if it is missing, add it based on the output shape op.attrs["new_shape"] = outputs[0].shape - if op_type == "Cast": + if op_type == Op.Cast: # Cast op should have "in/out_data_type" attribs add if missing if "in_data_type" not in op.attrs: op.attrs["in_data_type"] = inputs[0].dtype @@ -190,6 +185,9 @@ class TFLiteSubgraph: if "depth_multiplier" in op.attrs: op.attrs["channel_multiplier"] = op.attrs["depth_multiplier"] + op.activation = op.attrs.pop("fused_activation_function", None) + if custom_code is not None: + op.attrs["custom_code"] = custom_code @staticmethod def len1_array_to_scalar(arr): @@ -260,9 +258,10 @@ class TFLiteGraph: msg = "The input file contains operator code {} which is currently not supported".format(c) raise InputFileError(self.name, msg) op_type, ser = builtin_operator_map[c] + custom_code = None if c == BuiltinOperator.CUSTOM: - op_type += decode_str(code.CustomCode()) - return op_type, ser + custom_code = decode_str(code.CustomCode()) + return op_type, ser, custom_code def read_tflite( |