diff options
author | Diego Russo <diego.russo@arm.com> | 2020-04-23 18:14:37 +0100 |
---|---|---|
committer | Tim Hall <tim.hall@arm.com> | 2020-06-18 17:53:52 +0100 |
commit | d0eee26bc17ecd237c1b1e86cda78f5f310af391 (patch) | |
tree | 8b4b78d1cc0f01d3686be5459353bdf1b4ea73e8 /ethosu/vela | |
parent | e4e58e15d9916fdcef33f5c43c2f60ef124da6a6 (diff) | |
download | ethos-u-vela-d0eee26bc17ecd237c1b1e86cda78f5f310af391.tar.gz |
Add test for len1_array_to_scalar function
Moved len1_array_to_scalar from a nested function to a staticmethod
of TFLiteSubgraph.
Change-Id: I182f0b70f03070855c1a4478d26644892c1ebb15
Signed-off-by: Diego Russo <diego.russo@arm.com>
Diffstat (limited to 'ethosu/vela')
-rw-r--r-- | ethosu/vela/test/test_tflite_reader.py | 36 | ||||
-rw-r--r-- | ethosu/vela/tflite_reader.py | 33 |
2 files changed, 52 insertions, 17 deletions
diff --git a/ethosu/vela/test/test_tflite_reader.py b/ethosu/vela/test/test_tflite_reader.py new file mode 100644 index 00000000..898e3840 --- /dev/null +++ b/ethosu/vela/test/test_tflite_reader.py @@ -0,0 +1,36 @@ +# Copyright (C) 2020 Arm Limited or its affiliates. All rights reserved. +# +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the License); you may +# not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an AS IS BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# Description: +# Contains unit tests for tflite_reader +import pytest +from ethosu.vela.tflite_reader import TFLiteSubgraph + + +class TestTFLiteSubgraph: + + # Generate some data for testing len1_array_to_scalar + len1_testdata = [ + (0, None), + pytest.param(1, None, marks=pytest.mark.xfail), + ([1, 2, 3], [1, 2, 3]), + ([10], 10), + ([], []), + ] + + @pytest.mark.parametrize("test_input,expected", len1_testdata) + def test_len1_array_to_scalar(self, test_input, expected): + output = TFLiteSubgraph.len1_array_to_scalar(test_input) + assert output == expected diff --git a/ethosu/vela/tflite_reader.py b/ethosu/vela/tflite_reader.py index 4f9bd7d0..7e158aac 100644 --- a/ethosu/vela/tflite_reader.py +++ b/ethosu/vela/tflite_reader.py @@ -91,28 +91,15 @@ class TFLiteSubgraph: shape = list(np_shape) if type(np_shape) is np.ndarray else [] name = decode_str(tens_data.Name()) dtype = datatype_map[tens_data.Type()] - tens = Tensor(shape, dtype, name) - quant = tens_data.Quantization() - def len1_array_to_scalar(arr): - # The following flatbuffer quantisation fields all return a scalar value of 0 if they are not definied in - # the input buffer. This is represented in Vela by using None. - # Otherwise, the fields returned are a single or multi-element array. In which case, single element arrays - # are converted to scalars - if isinstance(arr, int) and arr == 0: - return None - if len(arr) == 1: - return arr[0] - return arr - tens.quantization = QuantizationParameters() if quant is not None: - tens.quantization.min = len1_array_to_scalar(quant.MinAsNumpy()) - tens.quantization.max = len1_array_to_scalar(quant.MaxAsNumpy()) - tens.quantization.scale_f32 = len1_array_to_scalar(quant.ScaleAsNumpy()) - tens.quantization.zero_point = len1_array_to_scalar(quant.ZeroPointAsNumpy()) + tens.quantization.min = self.len1_array_to_scalar(quant.MinAsNumpy()) + tens.quantization.max = self.len1_array_to_scalar(quant.MaxAsNumpy()) + tens.quantization.scale_f32 = self.len1_array_to_scalar(quant.ScaleAsNumpy()) + tens.quantization.zero_point = self.len1_array_to_scalar(quant.ZeroPointAsNumpy()) if dtype == DataType.uint8: tens.quantization.quant_min = 0 @@ -199,6 +186,18 @@ class TFLiteSubgraph: op.outputs[0] = intermediate_tens act_op.inputs = [intermediate_tens] + @staticmethod + def len1_array_to_scalar(arr): + # The following flatbuffer quantisation fields all return a scalar value of 0 if they are not definied in + # the input buffer. This is represented in Vela by using None. + # Otherwise, the fields returned are a single or multi-element array. In which case, single element arrays + # are converted to scalars + if isinstance(arr, int) and arr == 0: + return None + if len(arr) == 1: + return arr[0] + return arr + class TFLiteGraph: def __init__( |