aboutsummaryrefslogtreecommitdiff
path: root/ethosu/vela/tflite_writer.py
diff options
context:
space:
mode:
Diffstat (limited to 'ethosu/vela/tflite_writer.py')
-rw-r--r--ethosu/vela/tflite_writer.py39
1 files changed, 25 insertions, 14 deletions
diff --git a/ethosu/vela/tflite_writer.py b/ethosu/vela/tflite_writer.py
index e190a74..687b887 100644
--- a/ethosu/vela/tflite_writer.py
+++ b/ethosu/vela/tflite_writer.py
@@ -1,4 +1,4 @@
-# Copyright (C) 2020 Arm Limited or its affiliates. All rights reserved.
+# Copyright (C) 2020-2021 Arm Limited or its affiliates. All rights reserved.
#
# SPDX-License-Identifier: Apache-2.0
#
@@ -203,6 +203,7 @@ class TFLiteSerialiser:
def serialise_quantization_parameters(self, quant):
builder = self.builder
+ qp = None
min = None
max = None
scale = None
@@ -217,16 +218,18 @@ class TFLiteSerialiser:
if quant.zero_point is not None:
zero_point = self.write_long_vector(make_vector(quant.zero_point))
- QuantizationParameters.QuantizationParametersStart(builder)
- if min is not None:
- QuantizationParameters.QuantizationParametersAddMin(builder, min)
- if max is not None:
- QuantizationParameters.QuantizationParametersAddMax(builder, max)
- if scale is not None:
- QuantizationParameters.QuantizationParametersAddScale(builder, scale)
- if zero_point is not None:
- QuantizationParameters.QuantizationParametersAddZeroPoint(builder, zero_point)
- return QuantizationParameters.QuantizationParametersEnd(builder)
+ QuantizationParameters.QuantizationParametersStart(builder)
+ if min is not None:
+ QuantizationParameters.QuantizationParametersAddMin(builder, min)
+ if max is not None:
+ QuantizationParameters.QuantizationParametersAddMax(builder, max)
+ if scale is not None:
+ QuantizationParameters.QuantizationParametersAddScale(builder, scale)
+ if zero_point is not None:
+ QuantizationParameters.QuantizationParametersAddZeroPoint(builder, zero_point)
+ qp = QuantizationParameters.QuantizationParametersEnd(builder)
+
+ return qp
def serialise_tensor(self, tens):
builder = self.builder
@@ -258,7 +261,9 @@ class TFLiteSerialiser:
# Empty buffers should be kept unique for TensorFlow Lite Micro
Tensor.TensorAddBuffer(builder, buf_id)
Tensor.TensorAddName(builder, name)
- Tensor.TensorAddQuantization(builder, quant)
+ if quant is not None:
+ Tensor.TensorAddQuantization(builder, quant)
+ Tensor.TensorAddIsVariable(builder, tens.is_variable)
res = Tensor.TensorEnd(builder)
return res
@@ -266,10 +271,15 @@ class TFLiteSerialiser:
def serialise_operator(self, op):
builder = self.builder
- inputs_offset = self.write_int_vector([self.tensor_map[tens] for tens in op.inputs if tens in self.tensor_map])
+ inputs_offset = self.write_int_vector(
+ [self.tensor_map[tens] if tens in self.tensor_map else -1 for tens in op.inputs]
+ )
outputs_offset = self.write_int_vector(
[self.tensor_map[tens] for tens in op.outputs if tens in self.tensor_map]
)
+ intermediates_offset = self.write_int_vector(
+ [self.tensor_map[tens] for tens in op.intermediates if tens in self.tensor_map]
+ )
if op.type == Op.Custom:
op_idx, tflop, opt_serializer = self.operator_code_map[op.type][op.attrs.get("custom_code", "")]
@@ -300,6 +310,7 @@ class TFLiteSerialiser:
Operator.OperatorAddOpcodeIndex(builder, op_idx)
Operator.OperatorAddInputs(builder, inputs_offset)
Operator.OperatorAddOutputs(builder, outputs_offset)
+ Operator.OperatorAddIntermediates(builder, intermediates_offset)
if builtin_opt_offset is not None:
Operator.OperatorAddBuiltinOptionsType(builder, opt_serializer.builtin_opt_type)
@@ -328,7 +339,7 @@ class TFLiteSerialiser:
# This allows us to serialise tensors which arent attached to any specific ops,
# e.g. due to an empty graph containing no ops
for op in all_ops + placeholder_ops:
- for tens in op.inputs + op.outputs:
+ for tens in op.inputs + op.outputs + op.intermediates:
if tens is not None:
tensor_set.add(tens)