aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorJerry Ge <jerry.ge@arm.com>2023-05-23 22:41:20 +0000
committerEric Kunze <eric.kunze@arm.com>2023-06-02 21:54:46 +0000
commit5dd5a55bc00d0eaf9aa38511cf553b0d78dfed51 (patch)
tree96d85d4f480b45c913c98b02c4c209ba4dae5dc2
parenteb52cc18b342d6329322f84b671eab4450e663fd (diff)
downloadreference_model-5dd5a55bc00d0eaf9aa38511cf553b0d78dfed51.tar.gz
Support custom_shapes for framework test generation
Signed-off-by: Jerry Ge <jerry.ge@arm.com> Change-Id: Ia29c73cb5d0a7f91914e2a94ca52d06f375722e9
-rwxr-xr-xverif/frameworks/tosa_verif_framework_generator.py25
1 files changed, 20 insertions, 5 deletions
diff --git a/verif/frameworks/tosa_verif_framework_generator.py b/verif/frameworks/tosa_verif_framework_generator.py
index fffb842..a81ff6f 100755
--- a/verif/frameworks/tosa_verif_framework_generator.py
+++ b/verif/frameworks/tosa_verif_framework_generator.py
@@ -57,11 +57,12 @@ TYPE_FHIB = [tf.float32, tf.float16, tf.int32, tf.bool]
# be used to restrict an operator to a particular framework.
#
# And optional members:
-# 'template': boolean (indicates that this is a templated op which gets further
-# processing in createDynamicOpLists)
-# 'bias': boolean indicating that there is a bias component to be generated
-# 'qtypes': List of QuantType quantized types to generate for this op
-# 'rank': tuple (lowest rank, highest rank). Dimension range of input tensor.
+# 'template': boolean (indicates that this is a templated op which gets further
+# processing in createDynamicOpLists)
+# 'bias': boolean indicating that there is a bias component to be generated
+# 'qtypes': List of QuantType quantized types to generate for this op
+# 'rank': tuple (lowest rank, highest rank). Dimension range of input tensor.
+# 'custom_shapes': List of custom shapes for specific operators
TF_OP_LIST = {
"add": {
@@ -783,6 +784,10 @@ TF_OP_LIST = {
TYPE_F + [QuantType.ALL_U8, QuantType.ALL_I8, QuantType.ALL_I16]
),
},
+ "custom_shapes": {
+ "custom_shape_only": False,
+ "shape_list": [(3, 1, 1, 7)],
+ },
},
"left_shift": {
"operands": (1, 0),
@@ -1409,6 +1414,16 @@ def generate_op_tests(args, op_name, shape_list, result_name, filter, unit_test_
nonquantized_dtypes = list(nonquantized_dtypes_set)
quantized_dtypes = tflite_quantized_dtypes
+ # append custom_shapes or replace shape_list with custom_shapes
+ try:
+ custom_shapes = op["custom_shapes"]
+ if custom_shapes["custom_shape_only"]:
+ shape_list = custom_shapes["shape_list"]
+ else:
+ shape_list.append(custom_shapes)
+ except KeyError:
+ pass
+
# populate non quantized unit test arguments
for dtype in nonquantized_dtypes: