aboutsummaryrefslogtreecommitdiff
path: root/ethosu/vela/tflite_writer.py
diff options
context:
space:
mode:
Diffstat (limited to 'ethosu/vela/tflite_writer.py')
-rw-r--r--ethosu/vela/tflite_writer.py35
1 files changed, 19 insertions, 16 deletions
diff --git a/ethosu/vela/tflite_writer.py b/ethosu/vela/tflite_writer.py
index 68af4874..f444ee51 100644
--- a/ethosu/vela/tflite_writer.py
+++ b/ethosu/vela/tflite_writer.py
@@ -22,6 +22,7 @@ from flatbuffers import encode
from flatbuffers.builder import UOffsetTFlags
from .nn_graph import PassPlacement
+from .operation import Op
from .tensor import MemType
from .tensor import TensorPurpose
from .tflite import Buffer
@@ -34,7 +35,6 @@ from .tflite import SubGraph
from .tflite import Tensor
from .tflite_mapping import builtin_operator_inv_map
from .tflite_mapping import BuiltinOperator
-from .tflite_mapping import custom_prefix
from .tflite_mapping import datatype_inv_map
# ugh, the python flatbuffer interface is missing a method to add in file identifier. patching it in here:
@@ -77,7 +77,7 @@ class TFLiteSerialiser:
self.scratch_fast_buf_id = 1 # Always assign scratch_fast to buffer 1
self.buffers_to_write = [] # have an empty array there
- self.ops_to_ignore = set(("Const", "Placeholder", "SubgraphInput"))
+ self.ops_to_ignore = set((Op.Const, Op.Placeholder, Op.SubgraphInput))
self.tensors_to_reshape = {}
@@ -89,16 +89,17 @@ class TFLiteSerialiser:
for op in ps.ops:
if op.type not in self.ops_to_ignore:
all_ops.append(op)
- if op.type.startswith("Conv2D") or op.type.startswith("DepthwiseConv2d"):
+ if op.type.is_conv2d_op() or op.type.is_depthwise_conv2d_op():
# If values are None op has non-constant weights
if op.inputs[1].values is not None:
self.tensors_to_reshape[op.inputs[1]] = (3, 0, 1, 2)
- if op.type.startswith("FullyConnected"):
+ if op.type == Op.FullyConnected:
# If values are None op has non-constant weights
if op.inputs[1].values is not None:
self.tensors_to_reshape[op.inputs[1]] = (1, 0)
- self.operator_codes = list(sorted(set(op.type for op in all_ops)))
+ # list of tuple(Op, string); the custom code is only used for 3rd party custom operators
+ self.operator_codes = sorted(set((op.type, op.attrs.get("custom_code", "")) for op in all_ops))
self.operator_code_map = {}
def write_byte_vector(self, v, alignment=1):
@@ -163,25 +164,25 @@ class TFLiteSerialiser:
return buffer_map
- def serialise_operator_code(self, idx, code):
+ def serialise_operator_code(self, idx, op_type, custom_code):
builder = self.builder
custom_code_offset = None
- if code.startswith(custom_prefix):
- tf_code, opt_serializer = builtin_operator_inv_map[custom_prefix]
- custom_code_offset = builder.CreateString(code[len(custom_prefix) :])
+ if op_type == Op.Custom:
+ tf_code, opt_serializer = builtin_operator_inv_map[op_type]
+ custom_code_offset = builder.CreateString(custom_code)
else:
assert (
- code in builtin_operator_inv_map
- ), "Vela does not contain a mapping to serialise {} operator to a TensorFlow Lite operator".format(code)
- tf_code, opt_serializer = builtin_operator_inv_map[code]
+ op_type in builtin_operator_inv_map
+ ), "Vela does not contain a mapping to serialise {} operator to a TensorFlow Lite operator".format(op_type)
+ tf_code, opt_serializer = builtin_operator_inv_map[op_type]
if tf_code == BuiltinOperator.CUSTOM:
assert (
- code == "NpuOp"
+ op_type == Op.CustomNpuOp
), "Vela only supports serialising NpuOp operators as TensorFlow Lite Custom operators"
custom_code_offset = builder.CreateString("ethos-u")
- self.operator_code_map[code] = (idx, tf_code, opt_serializer)
+ self.operator_code_map[op_type] = (idx, tf_code, opt_serializer)
OperatorCode.OperatorCodeStart(builder)
OperatorCode.OperatorCodeAddBuiltinCode(builder, tf_code)
@@ -281,6 +282,8 @@ class TFLiteSerialiser:
attrs["dilation_w_factor"] = attrs["dilation"][2]
if "channel_multiplier" in attrs:
attrs["depth_multiplier"] = attrs["channel_multiplier"]
+ if op.activation is not None:
+ attrs["fused_activation_function"] = op.activation
builtin_opt_offset, custom_opt_offset = opt_serializer.serialize(builder, attrs)
@@ -310,7 +313,7 @@ class TFLiteSerialiser:
for op in ps.ops:
if op.type not in self.ops_to_ignore:
all_ops.append(op)
- elif op.type == "Placeholder":
+ elif op.type == Op.Placeholder:
placeholder_ops.append(op)
# Add the tensors from all valid ops, as well as the tensors from placeholder ops
@@ -404,7 +407,7 @@ class TFLiteSerialiser:
def serialise_model(self):
builder = self.builder
operator_code_offset = self.write_offset_vector(
- [self.serialise_operator_code(idx, code) for idx, code in enumerate(self.operator_codes)]
+ [self.serialise_operator_code(idx, optype, code) for idx, (optype, code) in enumerate(self.operator_codes)]
)
description = builder.CreateString("Vela Optimised")