aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorDiego Russo <diego.russo@arm.com>2020-04-23 18:14:37 +0100
committerTim Hall <tim.hall@arm.com>2020-06-18 17:53:52 +0100
commitd0eee26bc17ecd237c1b1e86cda78f5f310af391 (patch)
tree8b4b78d1cc0f01d3686be5459353bdf1b4ea73e8
parente4e58e15d9916fdcef33f5c43c2f60ef124da6a6 (diff)
downloadethos-u-vela-d0eee26bc17ecd237c1b1e86cda78f5f310af391.tar.gz
Add test for len1_array_to_scalar function
Moved len1_array_to_scalar from a nested function to a staticmethod of TFLiteSubgraph. Change-Id: I182f0b70f03070855c1a4478d26644892c1ebb15 Signed-off-by: Diego Russo <diego.russo@arm.com>
-rw-r--r--.pre-commit-config.yaml2
-rw-r--r--ethosu/vela/test/test_tflite_reader.py36
-rw-r--r--ethosu/vela/tflite_reader.py33
3 files changed, 53 insertions, 18 deletions
diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml
index 704c15d8..8b930b79 100644
--- a/.pre-commit-config.yaml
+++ b/.pre-commit-config.yaml
@@ -31,7 +31,7 @@ repos:
- id: pytest-cov
name: pytest
- stages: [commit]
+ stages: [push]
language: system
entry: pytest -v --cov=ethosu --cov-fail-under=0
types: [python]
diff --git a/ethosu/vela/test/test_tflite_reader.py b/ethosu/vela/test/test_tflite_reader.py
new file mode 100644
index 00000000..898e3840
--- /dev/null
+++ b/ethosu/vela/test/test_tflite_reader.py
@@ -0,0 +1,36 @@
+# Copyright (C) 2020 Arm Limited or its affiliates. All rights reserved.
+#
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the License); you may
+# not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an AS IS BASIS, WITHOUT
+# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# Description:
+# Contains unit tests for tflite_reader
+import pytest
+from ethosu.vela.tflite_reader import TFLiteSubgraph
+
+
+class TestTFLiteSubgraph:
+
+ # Generate some data for testing len1_array_to_scalar
+ len1_testdata = [
+ (0, None),
+ pytest.param(1, None, marks=pytest.mark.xfail),
+ ([1, 2, 3], [1, 2, 3]),
+ ([10], 10),
+ ([], []),
+ ]
+
+ @pytest.mark.parametrize("test_input,expected", len1_testdata)
+ def test_len1_array_to_scalar(self, test_input, expected):
+ output = TFLiteSubgraph.len1_array_to_scalar(test_input)
+ assert output == expected
diff --git a/ethosu/vela/tflite_reader.py b/ethosu/vela/tflite_reader.py
index 4f9bd7d0..7e158aac 100644
--- a/ethosu/vela/tflite_reader.py
+++ b/ethosu/vela/tflite_reader.py
@@ -91,28 +91,15 @@ class TFLiteSubgraph:
shape = list(np_shape) if type(np_shape) is np.ndarray else []
name = decode_str(tens_data.Name())
dtype = datatype_map[tens_data.Type()]
-
tens = Tensor(shape, dtype, name)
-
quant = tens_data.Quantization()
- def len1_array_to_scalar(arr):
- # The following flatbuffer quantisation fields all return a scalar value of 0 if they are not definied in
- # the input buffer. This is represented in Vela by using None.
- # Otherwise, the fields returned are a single or multi-element array. In which case, single element arrays
- # are converted to scalars
- if isinstance(arr, int) and arr == 0:
- return None
- if len(arr) == 1:
- return arr[0]
- return arr
-
tens.quantization = QuantizationParameters()
if quant is not None:
- tens.quantization.min = len1_array_to_scalar(quant.MinAsNumpy())
- tens.quantization.max = len1_array_to_scalar(quant.MaxAsNumpy())
- tens.quantization.scale_f32 = len1_array_to_scalar(quant.ScaleAsNumpy())
- tens.quantization.zero_point = len1_array_to_scalar(quant.ZeroPointAsNumpy())
+ tens.quantization.min = self.len1_array_to_scalar(quant.MinAsNumpy())
+ tens.quantization.max = self.len1_array_to_scalar(quant.MaxAsNumpy())
+ tens.quantization.scale_f32 = self.len1_array_to_scalar(quant.ScaleAsNumpy())
+ tens.quantization.zero_point = self.len1_array_to_scalar(quant.ZeroPointAsNumpy())
if dtype == DataType.uint8:
tens.quantization.quant_min = 0
@@ -199,6 +186,18 @@ class TFLiteSubgraph:
op.outputs[0] = intermediate_tens
act_op.inputs = [intermediate_tens]
+ @staticmethod
+ def len1_array_to_scalar(arr):
+ # The following flatbuffer quantisation fields all return a scalar value of 0 if they are not definied in
+ # the input buffer. This is represented in Vela by using None.
+ # Otherwise, the fields returned are a single or multi-element array. In which case, single element arrays
+ # are converted to scalars
+ if isinstance(arr, int) and arr == 0:
+ return None
+ if len(arr) == 1:
+ return arr[0]
+ return arr
+
class TFLiteGraph:
def __init__(