diff options
author | Fredrik Svedberg <fredrik.svedberg@arm.com> | 2021-02-16 21:59:50 +0100 |
---|---|---|
committer | Fredrik Svedberg <fredrik.svedberg@arm.com> | 2021-02-17 09:18:39 +0100 |
commit | 8d0f4890aa0ceae92a33ebb789701ff644a6fcaa (patch) | |
tree | fcecd50a7cc6375f5f4320b42f4b6c5231b854b1 /ethosu/vela/tflite_writer.py | |
parent | 56b6c711d8faaa6bcbc810e895efa650ddd97e73 (diff) | |
download | ethos-u-vela-8d0f4890aa0ceae92a33ebb789701ff644a6fcaa.tar.gz |
[MLBEDSW-3813] Fix LSTM operator pass through
Fixed pass through of LSTM operator.
Change-Id: I23140c69ab6cdc83f6bb8129256b4cc6a7c5ffac
Signed-off-by: Fredrik Svedberg <fredrik.svedberg@arm.com>
Diffstat (limited to 'ethosu/vela/tflite_writer.py')
-rw-r--r-- | ethosu/vela/tflite_writer.py | 39 |
1 files changed, 25 insertions, 14 deletions
diff --git a/ethosu/vela/tflite_writer.py b/ethosu/vela/tflite_writer.py index e190a746..687b8876 100644 --- a/ethosu/vela/tflite_writer.py +++ b/ethosu/vela/tflite_writer.py @@ -1,4 +1,4 @@ -# Copyright (C) 2020 Arm Limited or its affiliates. All rights reserved. +# Copyright (C) 2020-2021 Arm Limited or its affiliates. All rights reserved. # # SPDX-License-Identifier: Apache-2.0 # @@ -203,6 +203,7 @@ class TFLiteSerialiser: def serialise_quantization_parameters(self, quant): builder = self.builder + qp = None min = None max = None scale = None @@ -217,16 +218,18 @@ class TFLiteSerialiser: if quant.zero_point is not None: zero_point = self.write_long_vector(make_vector(quant.zero_point)) - QuantizationParameters.QuantizationParametersStart(builder) - if min is not None: - QuantizationParameters.QuantizationParametersAddMin(builder, min) - if max is not None: - QuantizationParameters.QuantizationParametersAddMax(builder, max) - if scale is not None: - QuantizationParameters.QuantizationParametersAddScale(builder, scale) - if zero_point is not None: - QuantizationParameters.QuantizationParametersAddZeroPoint(builder, zero_point) - return QuantizationParameters.QuantizationParametersEnd(builder) + QuantizationParameters.QuantizationParametersStart(builder) + if min is not None: + QuantizationParameters.QuantizationParametersAddMin(builder, min) + if max is not None: + QuantizationParameters.QuantizationParametersAddMax(builder, max) + if scale is not None: + QuantizationParameters.QuantizationParametersAddScale(builder, scale) + if zero_point is not None: + QuantizationParameters.QuantizationParametersAddZeroPoint(builder, zero_point) + qp = QuantizationParameters.QuantizationParametersEnd(builder) + + return qp def serialise_tensor(self, tens): builder = self.builder @@ -258,7 +261,9 @@ class TFLiteSerialiser: # Empty buffers should be kept unique for TensorFlow Lite Micro Tensor.TensorAddBuffer(builder, buf_id) Tensor.TensorAddName(builder, name) - Tensor.TensorAddQuantization(builder, quant) + if quant is not None: + Tensor.TensorAddQuantization(builder, quant) + Tensor.TensorAddIsVariable(builder, tens.is_variable) res = Tensor.TensorEnd(builder) return res @@ -266,10 +271,15 @@ class TFLiteSerialiser: def serialise_operator(self, op): builder = self.builder - inputs_offset = self.write_int_vector([self.tensor_map[tens] for tens in op.inputs if tens in self.tensor_map]) + inputs_offset = self.write_int_vector( + [self.tensor_map[tens] if tens in self.tensor_map else -1 for tens in op.inputs] + ) outputs_offset = self.write_int_vector( [self.tensor_map[tens] for tens in op.outputs if tens in self.tensor_map] ) + intermediates_offset = self.write_int_vector( + [self.tensor_map[tens] for tens in op.intermediates if tens in self.tensor_map] + ) if op.type == Op.Custom: op_idx, tflop, opt_serializer = self.operator_code_map[op.type][op.attrs.get("custom_code", "")] @@ -300,6 +310,7 @@ class TFLiteSerialiser: Operator.OperatorAddOpcodeIndex(builder, op_idx) Operator.OperatorAddInputs(builder, inputs_offset) Operator.OperatorAddOutputs(builder, outputs_offset) + Operator.OperatorAddIntermediates(builder, intermediates_offset) if builtin_opt_offset is not None: Operator.OperatorAddBuiltinOptionsType(builder, opt_serializer.builtin_opt_type) @@ -328,7 +339,7 @@ class TFLiteSerialiser: # This allows us to serialise tensors which arent attached to any specific ops, # e.g. due to an empty graph containing no ops for op in all_ops + placeholder_ops: - for tens in op.inputs + op.outputs: + for tens in op.inputs + op.outputs + op.intermediates: if tens is not None: tensor_set.add(tens) |