aboutsummaryrefslogtreecommitdiff
path: root/verif/frameworks/test_gen_utils.py
diff options
context:
space:
mode:
authorJeremy Johnson <jeremy.johnson@arm.com>2022-02-23 12:15:03 +0000
committerJeremy Johnson <jeremy.johnson@arm.com>2022-03-02 16:45:28 +0000
commit015c3550301fdc6d37606995322e144df0940ba2 (patch)
treef51448044ea8262f36aa964ce4b769406bb2cb7e /verif/frameworks/test_gen_utils.py
parentc0fe04d4105884b61b5eeca4c0a932846a77b6e2 (diff)
downloadreference_model-015c3550301fdc6d37606995322e144df0940ba2.tar.gz
Add framework unit test generation scripts
And fixes in tosa_verif_run_tests: * support for no-color printing * stop double printing of error messages on verbose * differentiate result code pass from results check Change-Id: I26e957013a8d18f7d3d3691067dfb778008a1eea Signed-off-by: Jeremy Johnson <jeremy.johnson@arm.com>
Diffstat (limited to 'verif/frameworks/test_gen_utils.py')
-rw-r--r--verif/frameworks/test_gen_utils.py76
1 files changed, 76 insertions, 0 deletions
diff --git a/verif/frameworks/test_gen_utils.py b/verif/frameworks/test_gen_utils.py
new file mode 100644
index 0000000..2d8e5d6
--- /dev/null
+++ b/verif/frameworks/test_gen_utils.py
@@ -0,0 +1,76 @@
+# Copyright (c) 2020-2022, ARM Limited.
+# SPDX-License-Identifier: Apache-2.0
+from enum import IntEnum
+from enum import unique
+
+import tensorflow as tf
+
+
+# Get a string name for a given shape
+def get_shape_str(shape, dtype):
+ shape_name = None
+ for dim in shape:
+ shape_name = (shape_name + "x" + str(dim)) if shape_name else str(dim)
+
+ if dtype == tf.float32:
+ shape_name = shape_name + "_f32"
+ elif dtype == tf.float16:
+ shape_name = shape_name + "_f16"
+ elif dtype == tf.int32:
+ shape_name = shape_name + "_i32"
+ elif dtype == tf.uint32:
+ shape_name = shape_name + "_u32"
+ elif dtype == tf.bool:
+ shape_name = shape_name + "_bool"
+ elif dtype == tf.quint8:
+ shape_name = shape_name + "_qu8"
+ elif dtype == tf.qint8:
+ shape_name = shape_name + "_qi8"
+ elif dtype == tf.qint16:
+ shape_name = shape_name + "_qi16"
+ elif dtype == tf.quint16:
+ shape_name = shape_name + "_qu16"
+ else:
+ raise Exception("Unsupported type: {}".format(dtype))
+
+ return shape_name
+
+
+@unique
+class QuantType(IntEnum):
+ UNKNOWN = 0
+ ALL_I8 = 1
+ ALL_U8 = 2
+ ALL_I16 = 3
+ # TODO: support QUINT16
+ CONV_U8_U8 = 4
+ CONV_I8_I8 = 5
+ CONV_I8_I4 = 6
+ CONV_I16_I8 = 7
+
+
+def get_tf_dtype(quantized_inference_dtype):
+ if quantized_inference_dtype == QuantType.ALL_I8:
+ return tf.qint8
+ elif quantized_inference_dtype == QuantType.ALL_U8:
+ return tf.quint8
+ elif quantized_inference_dtype == QuantType.ALL_I16:
+ return tf.qint16
+ elif quantized_inference_dtype == QuantType.CONV_U8_U8:
+ return tf.quint8
+ elif quantized_inference_dtype == QuantType.CONV_I8_I8:
+ return tf.qint8
+ elif quantized_inference_dtype == QuantType.CONV_I8_I4:
+ return tf.qint8
+ elif quantized_inference_dtype == QuantType.CONV_I16_I8:
+ return tf.qint16
+ else:
+ return None
+
+
+class TensorScale:
+ def __init__(self, _min, _max, _num_bits, _narrow_range):
+ self.min = _min
+ self.max = _max
+ self.num_bits = _num_bits
+ self.narrow_range = _narrow_range