author     Louis Verhaard <louis.verhaard@arm.com>    2020-09-30 09:01:52 +0200
committer  Louis Verhaard <louis.verhaard@arm.com>    2020-10-08 16:29:29 +0200
commit     aee5d7537ff81ffda5ba222721b72f914ce50fb8 (patch)
tree       495b9dfff2a188c6916f8ca2e390ee88f7da8ccc /ethosu/vela/tflite_writer.py
parent     36ad73a0fb46d3f844845c97c56d92de2a7a9b3d (diff)
download   ethos-u-vela-aee5d7537ff81ffda5ba222721b72f914ce50fb8.tar.gz
MLBEDSW-3148: Refactor Operation
- op.type is now an enum instead of a string
- Removed unused operator codes
- Refactored some attributes like npu_block_type, fused_activation_function
- Refactored operator index calculation
- Refactored a number of operator sets

Change-Id: I641f65ee375794b7aec42abc0664251ae37d78e8
Signed-off-by: Louis Verhaard <louis.verhaard@arm.com>
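A rough sketch of the shape this refactor gives the writer code: string prefix checks become enum identity checks and predicate methods on Op. Only the member names that appear in the diff below are taken from the patch; everything else here (the IntEnum base, the extra members, the predicate bodies) is an illustrative assumption, not the real contents of ethosu/vela/operation.py.

    from enum import IntEnum, auto

    # Illustrative stand-in only: the real Op enum has far more members and
    # carries per-operator metadata.
    class Op(IntEnum):
        Const = auto()
        Placeholder = auto()
        SubgraphInput = auto()
        Conv2D = auto()            # hypothetical member name
        DepthwiseConv2D = auto()   # hypothetical member name
        FullyConnected = auto()
        Custom = auto()            # 3rd party custom operator
        CustomNpuOp = auto()       # Vela's own NPU custom operator

        # Predicates like these replace the old str.startswith() checks.
        def is_conv2d_op(self):
            return self is Op.Conv2D

        def is_depthwise_conv2d_op(self):
            return self is Op.DepthwiseConv2D

    # Old: if op.type.startswith("Conv2D") or op.type.startswith("DepthwiseConv2d"): ...
    # New: if op.type.is_conv2d_op() or op.type.is_depthwise_conv2d_op(): ...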
Diffstat (limited to 'ethosu/vela/tflite_writer.py')
-rw-r--r--  ethosu/vela/tflite_writer.py  35
1 file changed, 19 insertions, 16 deletions
diff --git a/ethosu/vela/tflite_writer.py b/ethosu/vela/tflite_writer.py
index 68af4874..f444ee51 100644
--- a/ethosu/vela/tflite_writer.py
+++ b/ethosu/vela/tflite_writer.py
@@ -22,6 +22,7 @@ from flatbuffers import encode
from flatbuffers.builder import UOffsetTFlags
from .nn_graph import PassPlacement
+from .operation import Op
from .tensor import MemType
from .tensor import TensorPurpose
from .tflite import Buffer
@@ -34,7 +35,6 @@ from .tflite import SubGraph
from .tflite import Tensor
from .tflite_mapping import builtin_operator_inv_map
from .tflite_mapping import BuiltinOperator
-from .tflite_mapping import custom_prefix
from .tflite_mapping import datatype_inv_map
# ugh, the python flatbuffer interface is missing a method to add in file identifier. patching it in here:
@@ -77,7 +77,7 @@ class TFLiteSerialiser:
self.scratch_fast_buf_id = 1 # Always assign scratch_fast to buffer 1
self.buffers_to_write = [] # have an empty array there
- self.ops_to_ignore = set(("Const", "Placeholder", "SubgraphInput"))
+ self.ops_to_ignore = set((Op.Const, Op.Placeholder, Op.SubgraphInput))
self.tensors_to_reshape = {}
@@ -89,16 +89,17 @@ class TFLiteSerialiser:
for op in ps.ops:
if op.type not in self.ops_to_ignore:
all_ops.append(op)
- if op.type.startswith("Conv2D") or op.type.startswith("DepthwiseConv2d"):
+ if op.type.is_conv2d_op() or op.type.is_depthwise_conv2d_op():
# If values are None op has non-constant weights
if op.inputs[1].values is not None:
self.tensors_to_reshape[op.inputs[1]] = (3, 0, 1, 2)
- if op.type.startswith("FullyConnected"):
+ if op.type == Op.FullyConnected:
# If values are None op has non-constant weights
if op.inputs[1].values is not None:
self.tensors_to_reshape[op.inputs[1]] = (1, 0)
- self.operator_codes = list(sorted(set(op.type for op in all_ops)))
+ # list of tuple(Op, string); the custom code is only used for 3rd party custom operators
+ self.operator_codes = sorted(set((op.type, op.attrs.get("custom_code", "")) for op in all_ops))
self.operator_code_map = {}
def write_byte_vector(self, v, alignment=1):
@@ -163,25 +164,25 @@ class TFLiteSerialiser:
return buffer_map
- def serialise_operator_code(self, idx, code):
+ def serialise_operator_code(self, idx, op_type, custom_code):
builder = self.builder
custom_code_offset = None
- if code.startswith(custom_prefix):
- tf_code, opt_serializer = builtin_operator_inv_map[custom_prefix]
- custom_code_offset = builder.CreateString(code[len(custom_prefix) :])
+ if op_type == Op.Custom:
+ tf_code, opt_serializer = builtin_operator_inv_map[op_type]
+ custom_code_offset = builder.CreateString(custom_code)
else:
assert (
- code in builtin_operator_inv_map
- ), "Vela does not contain a mapping to serialise {} operator to a TensorFlow Lite operator".format(code)
- tf_code, opt_serializer = builtin_operator_inv_map[code]
+ op_type in builtin_operator_inv_map
+ ), "Vela does not contain a mapping to serialise {} operator to a TensorFlow Lite operator".format(op_type)
+ tf_code, opt_serializer = builtin_operator_inv_map[op_type]
if tf_code == BuiltinOperator.CUSTOM:
assert (
- code == "NpuOp"
+ op_type == Op.CustomNpuOp
), "Vela only supports serialising NpuOp operators as TensorFlow Lite Custom operators"
custom_code_offset = builder.CreateString("ethos-u")
- self.operator_code_map[code] = (idx, tf_code, opt_serializer)
+ self.operator_code_map[op_type] = (idx, tf_code, opt_serializer)
OperatorCode.OperatorCodeStart(builder)
OperatorCode.OperatorCodeAddBuiltinCode(builder, tf_code)
@@ -281,6 +282,8 @@ class TFLiteSerialiser:
attrs["dilation_w_factor"] = attrs["dilation"][2]
if "channel_multiplier" in attrs:
attrs["depth_multiplier"] = attrs["channel_multiplier"]
+ if op.activation is not None:
+ attrs["fused_activation_function"] = op.activation
builtin_opt_offset, custom_opt_offset = opt_serializer.serialize(builder, attrs)
@@ -310,7 +313,7 @@ class TFLiteSerialiser:
for op in ps.ops:
if op.type not in self.ops_to_ignore:
all_ops.append(op)
- elif op.type == "Placeholder":
+ elif op.type == Op.Placeholder:
placeholder_ops.append(op)
# Add the tensors from all valid ops, as well as the tensors from placeholder ops
@@ -404,7 +407,7 @@ class TFLiteSerialiser:
def serialise_model(self):
builder = self.builder
operator_code_offset = self.write_offset_vector(
- [self.serialise_operator_code(idx, code) for idx, code in enumerate(self.operator_codes)]
+ [self.serialise_operator_code(idx, optype, code) for idx, (optype, code) in enumerate(self.operator_codes)]
)
description = builder.CreateString("Vela Optimised")
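For reference, a simplified, self-contained sketch of the operator-code bookkeeping introduced above: one (op type, custom code) tuple per distinct operator, where the custom code string only matters for 3rd party Custom operators and Vela's own CustomNpuOp is forced to the "ethos-u" custom code. The Op members and the builtin_operator_inv_map passed in here are minimal stand-ins, not Vela's real tables in operation.py and tflite_mapping.py.

    from enum import IntEnum, auto

    # Minimal stand-in for the real Op enum.
    class Op(IntEnum):
        FullyConnected = auto()
        Custom = auto()
        CustomNpuOp = auto()

    def collect_operator_codes(all_ops):
        # One entry per distinct (op type, custom code) pair; the custom code
        # string is only meaningful for 3rd party Custom operators.
        return sorted(set((op.type, op.attrs.get("custom_code", "")) for op in all_ops))

    def resolve_operator_code(op_type, custom_code, builtin_operator_inv_map):
        # Mirrors the serialise_operator_code logic in the patch: Op.Custom keeps
        # its original custom code, while Op.CustomNpuOp is the only operator
        # serialised as a TensorFlow Lite CUSTOM op with the "ethos-u" code.
        tf_code, opt_serializer = builtin_operator_inv_map[op_type]
        if op_type is Op.Custom:
            custom_code_str = custom_code
        elif op_type is Op.CustomNpuOp:
            custom_code_str = "ethos-u"
        else:
            custom_code_str = None
        return tf_code, opt_serializer, custom_code_str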