From 5dd5a55bc00d0eaf9aa38511cf553b0d78dfed51 Mon Sep 17 00:00:00 2001 From: Jerry Ge Date: Tue, 23 May 2023 22:41:20 +0000 Subject: Support custom_shapes for framework test generation Signed-off-by: Jerry Ge Change-Id: Ia29c73cb5d0a7f91914e2a94ca52d06f375722e9 --- verif/frameworks/tosa_verif_framework_generator.py | 25 +++++++++++++++++----- 1 file changed, 20 insertions(+), 5 deletions(-) (limited to 'verif/frameworks') diff --git a/verif/frameworks/tosa_verif_framework_generator.py b/verif/frameworks/tosa_verif_framework_generator.py index fffb842..a81ff6f 100755 --- a/verif/frameworks/tosa_verif_framework_generator.py +++ b/verif/frameworks/tosa_verif_framework_generator.py @@ -57,11 +57,12 @@ TYPE_FHIB = [tf.float32, tf.float16, tf.int32, tf.bool] # be used to restrict an operator to a particular framework. # # And optional members: -# 'template': boolean (indicates that this is a templated op which gets further -# processing in createDynamicOpLists) -# 'bias': boolean indicating that there is a bias component to be generated -# 'qtypes': List of QuantType quantized types to generate for this op -# 'rank': tuple (lowest rank, highest rank). Dimension range of input tensor. +# 'template': boolean (indicates that this is a templated op which gets further +# processing in createDynamicOpLists) +# 'bias': boolean indicating that there is a bias component to be generated +# 'qtypes': List of QuantType quantized types to generate for this op +# 'rank': tuple (lowest rank, highest rank). Dimension range of input tensor. +# 'custom_shapes': List of custom shapes for specific operators TF_OP_LIST = { "add": { @@ -783,6 +784,10 @@ TF_OP_LIST = { TYPE_F + [QuantType.ALL_U8, QuantType.ALL_I8, QuantType.ALL_I16] ), }, + "custom_shapes": { + "custom_shape_only": False, + "shape_list": [(3, 1, 1, 7)], + }, }, "left_shift": { "operands": (1, 0), @@ -1409,6 +1414,16 @@ def generate_op_tests(args, op_name, shape_list, result_name, filter, unit_test_ nonquantized_dtypes = list(nonquantized_dtypes_set) quantized_dtypes = tflite_quantized_dtypes + # append custom_shapes or replace shape_list with custom_shapes + try: + custom_shapes = op["custom_shapes"] + if custom_shapes["custom_shape_only"]: + shape_list = custom_shapes["shape_list"] + else: + shape_list.append(custom_shapes) + except KeyError: + pass + # populate non quantized unit test arguments for dtype in nonquantized_dtypes: -- cgit v1.2.1