aboutsummaryrefslogtreecommitdiff
path: root/ethosu/vela/tflite_reader.py
diff options
context:
space:
mode:
authorLouis Verhaard <louis.verhaard@arm.com>2020-09-30 09:01:52 +0200
committerLouis Verhaard <louis.verhaard@arm.com>2020-10-08 16:29:29 +0200
commitaee5d7537ff81ffda5ba222721b72f914ce50fb8 (patch)
tree495b9dfff2a188c6916f8ca2e390ee88f7da8ccc /ethosu/vela/tflite_reader.py
parent36ad73a0fb46d3f844845c97c56d92de2a7a9b3d (diff)
downloadethos-u-vela-aee5d7537ff81ffda5ba222721b72f914ce50fb8.tar.gz
MLBEDSW-3148: Refactor Operation
- op.type is now an enum instead of a string - Removed unused operator codes - Refactored some attributes like npu_block_type, fused_activation_function - Refactored operator index calculation - Refactored a number of operator sets Change-Id: I641f65ee375794b7aec42abc0664251ae37d78e8 Signed-off-by: Louis Verhaard <louis.verhaard@arm.com>
Diffstat (limited to 'ethosu/vela/tflite_reader.py')
-rw-r--r--ethosu/vela/tflite_reader.py39
1 files changed, 19 insertions, 20 deletions
diff --git a/ethosu/vela/tflite_reader.py b/ethosu/vela/tflite_reader.py
index 77cc7963..a03f9ec2 100644
--- a/ethosu/vela/tflite_reader.py
+++ b/ethosu/vela/tflite_reader.py
@@ -23,6 +23,7 @@ from .errors import InputFileError
from .errors import TensorError
from .nn_graph import Graph
from .nn_graph import Subgraph
+from .operation import Op
from .operation import Operation
from .tensor import QuantizationParameters
from .tensor import Tensor
@@ -53,7 +54,7 @@ def clone_and_reshape_tensor(src_tens, reorder):
if tens.quant_values is not None:
tens.quant_values = tens.quant_values.transpose(reorder)
- op = Operation("Const", tens.name)
+ op = Operation(Op.Const, tens.name)
op.set_output_tensor(tens)
return tens
@@ -78,12 +79,12 @@ class TFLiteSubgraph:
if tens.ops != []:
TensorError(tens, "This subgraph input tensor has unexpected driving operators.")
- op = Operation("Placeholder", tens.name)
+ op = Operation(Op.Placeholder, tens.name)
op.set_output_tensor(tens)
for tens in self.tensors:
if not tens.ops:
- op = Operation("Const", tens.name)
+ op = Operation(Op.Const, tens.name)
op.set_output_tensor(tens)
def get_tensors_from_indices_remove_duplicates(self, indices, warning_str):
@@ -136,7 +137,7 @@ class TFLiteSubgraph:
return tens
def parse_operator(self, op_index, op_data):
- op_type, opt_serializer = self.graph.operator_codes[op_data.OpcodeIndex()]
+ op_type, opt_serializer, custom_code = self.graph.operator_codes[op_data.OpcodeIndex()]
inputs = [self.tensors[idx] if idx != -1 else None for idx in op_data.InputsAsNumpy()]
outputs = [self.tensors[idx] if idx != -1 else None for idx in op_data.OutputsAsNumpy()]
name = "unknown_op_name"
@@ -149,19 +150,13 @@ class TFLiteSubgraph:
for out in op.outputs:
out.ops = [op]
- if op_type.startswith("DepthwiseConv2d") or op_type.startswith("Conv2D"):
+ if op.type.is_depthwise_conv2d_op() or op.type.is_conv2d_op() or op.type == Op.FullyConnected:
if inputs[1].values is not None:
- inputs[1] = clone_and_reshape_tensor(inputs[1], (1, 2, 3, 0))
- if len(inputs) < 3 or (len(inputs) < 4 and "Backprop" in op_type):
- # No Bias tensor
- inputs.append(None)
- if inputs[-1]:
- inputs[-1] = clone_and_reshape_tensor(inputs[-1], (0,))
-
- if op_type.startswith("FullyConnected"):
- if inputs[1].values is not None:
- inputs[1] = clone_and_reshape_tensor(inputs[1], (1, 0))
- if len(inputs) < 3:
+ if op.type == Op.FullyConnected:
+ inputs[1] = clone_and_reshape_tensor(inputs[1], (1, 0))
+ else:
+ inputs[1] = clone_and_reshape_tensor(inputs[1], (1, 2, 3, 0))
+ if op.type.needs_bias() and len(inputs) <= op_type.info.indices.biases[0]:
# No Bias tensor
inputs.append(None)
if inputs[-1]:
@@ -170,11 +165,11 @@ class TFLiteSubgraph:
if opt_serializer is not None:
op.attrs = opt_serializer.deserialize(op_data)
- if op_type == "Reshape" and "new_shape" not in op.attrs:
+ if op_type == Op.Reshape and "new_shape" not in op.attrs:
# Reshape should have an attrib "new_shape" but if it is missing, add it based on the output shape
op.attrs["new_shape"] = outputs[0].shape
- if op_type == "Cast":
+ if op_type == Op.Cast:
# Cast op should have "in/out_data_type" attribs add if missing
if "in_data_type" not in op.attrs:
op.attrs["in_data_type"] = inputs[0].dtype
@@ -190,6 +185,9 @@ class TFLiteSubgraph:
if "depth_multiplier" in op.attrs:
op.attrs["channel_multiplier"] = op.attrs["depth_multiplier"]
+ op.activation = op.attrs.pop("fused_activation_function", None)
+ if custom_code is not None:
+ op.attrs["custom_code"] = custom_code
@staticmethod
def len1_array_to_scalar(arr):
@@ -260,9 +258,10 @@ class TFLiteGraph:
msg = "The input file contains operator code {} which is currently not supported".format(c)
raise InputFileError(self.name, msg)
op_type, ser = builtin_operator_map[c]
+ custom_code = None
if c == BuiltinOperator.CUSTOM:
- op_type += decode_str(code.CustomCode())
- return op_type, ser
+ custom_code = decode_str(code.CustomCode())
+ return op_type, ser, custom_code
def read_tflite(