From 8746026555c37c8d208fcbedeb04d0ae6d25d53e Mon Sep 17 00:00:00 2001 From: Jeremy Johnson Date: Mon, 25 Mar 2024 09:46:02 +0000 Subject: Fix missing Pooling ERRORIFs Signed-off-by: Jeremy Johnson Change-Id: I749b1b9fbfa32954d8748a860280c86087c08e7f --- verif/generator/tosa_arg_gen.py | 33 ++++++++++++++++++++++++--------- 1 file changed, 24 insertions(+), 9 deletions(-) diff --git a/verif/generator/tosa_arg_gen.py b/verif/generator/tosa_arg_gen.py index 504bfa9..41ef4df 100644 --- a/verif/generator/tosa_arg_gen.py +++ b/verif/generator/tosa_arg_gen.py @@ -1979,7 +1979,7 @@ class TosaArgGen: return sparsity # Maximum number of error_if variants to produce - MAX_CONV_ERROR_IFS = 3 + MAX_TESTS_ERROR_IFS = 3 @staticmethod def agConv(testGen, rng, opName, shapeList, dtypes, error_name=None): @@ -2252,7 +2252,7 @@ class TosaArgGen: ) if ( error_name - and len(arg_list) >= TosaArgGen.MAX_CONV_ERROR_IFS + and len(arg_list) >= TosaArgGen.MAX_TESTS_ERROR_IFS ): # Found enough errors logger.debug( @@ -2691,12 +2691,15 @@ class TosaArgGen: {x for x in itertools.product(*([[startPad, bigPadding]] * 4))} ) - # There are too many parameter combinations, so generate them sparsely, - # very sparse for negative tests - sparsity_factor = 2 if error_name else 500 - sparsity = ( - len(paddings) * len(strides) * len(kernels) // sparsity_factor + 1 - ) + if error_name: + # Cycle through all error_if tests but we only keep the first few + sparsity = 1 + else: + # There are too many parameter combinations, so generate them sparsely + sparsity_factor = 500 + sparsity = ( + len(paddings) * len(strides) * len(kernels) // sparsity_factor + 1 + ) else: # We have already limited test output combinations for 8k tests sparsity = 1 @@ -2732,6 +2735,7 @@ class TosaArgGen: args_dict["acc_type"] = accum return (arg_str.format(*arg_str_elems), args_dict) + more_tests = True n = 0 for a in accum_dtypes: for s in sorted(list(strides)): @@ -2751,7 +2755,8 @@ class TosaArgGen: get_arg_list_element(a, sNew, pNew, kNew, shape) ) elif ( - n % sparsity == 0 + more_tests + and n % sparsity == 0 # padding must not exceed the kernel size and p[0] < k[0] and p[1] < k[0] @@ -2792,6 +2797,16 @@ class TosaArgGen: arg_list.append( get_arg_list_element(a, s, p, k, dp, shape) ) + if ( + error_name + and len(arg_list) >= TosaArgGen.MAX_TESTS_ERROR_IFS + ): + # Found enough errors + logger.debug( + f"Skipping creating more pooling error tests for {error_name}" + ) + more_tests = False + n += 1 # Now add data generator types -- cgit v1.2.1