diff options
Diffstat (limited to 'verif/generator/tosa_arg_gen.py')
-rw-r--r-- | verif/generator/tosa_arg_gen.py | 33 |
1 files changed, 24 insertions, 9 deletions
diff --git a/verif/generator/tosa_arg_gen.py b/verif/generator/tosa_arg_gen.py index 504bfa9..41ef4df 100644 --- a/verif/generator/tosa_arg_gen.py +++ b/verif/generator/tosa_arg_gen.py @@ -1979,7 +1979,7 @@ class TosaArgGen: return sparsity # Maximum number of error_if variants to produce - MAX_CONV_ERROR_IFS = 3 + MAX_TESTS_ERROR_IFS = 3 @staticmethod def agConv(testGen, rng, opName, shapeList, dtypes, error_name=None): @@ -2252,7 +2252,7 @@ class TosaArgGen: ) if ( error_name - and len(arg_list) >= TosaArgGen.MAX_CONV_ERROR_IFS + and len(arg_list) >= TosaArgGen.MAX_TESTS_ERROR_IFS ): # Found enough errors logger.debug( @@ -2691,12 +2691,15 @@ class TosaArgGen: {x for x in itertools.product(*([[startPad, bigPadding]] * 4))} ) - # There are too many parameter combinations, so generate them sparsely, - # very sparse for negative tests - sparsity_factor = 2 if error_name else 500 - sparsity = ( - len(paddings) * len(strides) * len(kernels) // sparsity_factor + 1 - ) + if error_name: + # Cycle through all error_if tests but we only keep the first few + sparsity = 1 + else: + # There are too many parameter combinations, so generate them sparsely + sparsity_factor = 500 + sparsity = ( + len(paddings) * len(strides) * len(kernels) // sparsity_factor + 1 + ) else: # We have already limited test output combinations for 8k tests sparsity = 1 @@ -2732,6 +2735,7 @@ class TosaArgGen: args_dict["acc_type"] = accum return (arg_str.format(*arg_str_elems), args_dict) + more_tests = True n = 0 for a in accum_dtypes: for s in sorted(list(strides)): @@ -2751,7 +2755,8 @@ class TosaArgGen: get_arg_list_element(a, sNew, pNew, kNew, shape) ) elif ( - n % sparsity == 0 + more_tests + and n % sparsity == 0 # padding must not exceed the kernel size and p[0] < k[0] and p[1] < k[0] @@ -2792,6 +2797,16 @@ class TosaArgGen: arg_list.append( get_arg_list_element(a, s, p, k, dp, shape) ) + if ( + error_name + and len(arg_list) >= TosaArgGen.MAX_TESTS_ERROR_IFS + ): + # Found enough errors + logger.debug( + f"Skipping creating more pooling error tests for {error_name}" + ) + more_tests = False + n += 1 # Now add data generator types |