diff options
-rw-r--r-- | .pre-commit-config.yaml | 2 | ||||
-rw-r--r-- | ethosu/vela/test/test_tflite_reader.py | 36 | ||||
-rw-r--r-- | ethosu/vela/tflite_reader.py | 33 |
3 files changed, 53 insertions, 18 deletions
diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 704c15d8..8b930b79 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -31,7 +31,7 @@ repos: - id: pytest-cov name: pytest - stages: [commit] + stages: [push] language: system entry: pytest -v --cov=ethosu --cov-fail-under=0 types: [python] diff --git a/ethosu/vela/test/test_tflite_reader.py b/ethosu/vela/test/test_tflite_reader.py new file mode 100644 index 00000000..898e3840 --- /dev/null +++ b/ethosu/vela/test/test_tflite_reader.py @@ -0,0 +1,36 @@ +# Copyright (C) 2020 Arm Limited or its affiliates. All rights reserved. +# +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the License); you may +# not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an AS IS BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# Description: +# Contains unit tests for tflite_reader +import pytest +from ethosu.vela.tflite_reader import TFLiteSubgraph + + +class TestTFLiteSubgraph: + + # Generate some data for testing len1_array_to_scalar + len1_testdata = [ + (0, None), + pytest.param(1, None, marks=pytest.mark.xfail), + ([1, 2, 3], [1, 2, 3]), + ([10], 10), + ([], []), + ] + + @pytest.mark.parametrize("test_input,expected", len1_testdata) + def test_len1_array_to_scalar(self, test_input, expected): + output = TFLiteSubgraph.len1_array_to_scalar(test_input) + assert output == expected diff --git a/ethosu/vela/tflite_reader.py b/ethosu/vela/tflite_reader.py index 4f9bd7d0..7e158aac 100644 --- a/ethosu/vela/tflite_reader.py +++ b/ethosu/vela/tflite_reader.py @@ -91,28 +91,15 @@ class TFLiteSubgraph: shape = list(np_shape) if type(np_shape) is np.ndarray else [] name = decode_str(tens_data.Name()) dtype = datatype_map[tens_data.Type()] - tens = Tensor(shape, dtype, name) - quant = tens_data.Quantization() - def len1_array_to_scalar(arr): - # The following flatbuffer quantisation fields all return a scalar value of 0 if they are not definied in - # the input buffer. This is represented in Vela by using None. - # Otherwise, the fields returned are a single or multi-element array. In which case, single element arrays - # are converted to scalars - if isinstance(arr, int) and arr == 0: - return None - if len(arr) == 1: - return arr[0] - return arr - tens.quantization = QuantizationParameters() if quant is not None: - tens.quantization.min = len1_array_to_scalar(quant.MinAsNumpy()) - tens.quantization.max = len1_array_to_scalar(quant.MaxAsNumpy()) - tens.quantization.scale_f32 = len1_array_to_scalar(quant.ScaleAsNumpy()) - tens.quantization.zero_point = len1_array_to_scalar(quant.ZeroPointAsNumpy()) + tens.quantization.min = self.len1_array_to_scalar(quant.MinAsNumpy()) + tens.quantization.max = self.len1_array_to_scalar(quant.MaxAsNumpy()) + tens.quantization.scale_f32 = self.len1_array_to_scalar(quant.ScaleAsNumpy()) + tens.quantization.zero_point = self.len1_array_to_scalar(quant.ZeroPointAsNumpy()) if dtype == DataType.uint8: tens.quantization.quant_min = 0 @@ -199,6 +186,18 @@ class TFLiteSubgraph: op.outputs[0] = intermediate_tens act_op.inputs = [intermediate_tens] + @staticmethod + def len1_array_to_scalar(arr): + # The following flatbuffer quantisation fields all return a scalar value of 0 if they are not definied in + # the input buffer. This is represented in Vela by using None. + # Otherwise, the fields returned are a single or multi-element array. In which case, single element arrays + # are converted to scalars + if isinstance(arr, int) and arr == 0: + return None + if len(arr) == 1: + return arr[0] + return arr + class TFLiteGraph: def __init__( |