diff options
Diffstat (limited to 'verif/generator/tosa_arg_gen.py')
-rw-r--r-- | verif/generator/tosa_arg_gen.py | 41 |
1 files changed, 29 insertions, 12 deletions
diff --git a/verif/generator/tosa_arg_gen.py b/verif/generator/tosa_arg_gen.py index 9209d9c..e2a69f1 100644 --- a/verif/generator/tosa_arg_gen.py +++ b/verif/generator/tosa_arg_gen.py @@ -1043,6 +1043,18 @@ class TosaArgGen: return axes @staticmethod + def _calculate_sparsity(num_tests, sparsity_factor): + sparsity = num_tests // sparsity_factor + 1 + # If there are only a small number of tests, just select them all + if sparsity < 13: + sparsity = 1 + # To get a variety of parameter combinations sparsity should not be a + # multiple of 2, 3 or 5 + while sparsity % 2 == 0 or sparsity % 3 == 0 or sparsity % 5 == 0: + sparsity += 1 + return sparsity + + @staticmethod def agConv(testGen, opName, shapeList, dtypes, error_name=None): arg_list = [] @@ -1101,14 +1113,9 @@ class TosaArgGen: # There are too many parameter combinations, so generate them sparsely, # very sparse for negative tests sparsity_factor = 2 if error_name else 120 - sparsity = len(paddings) * len(strides) * len(dilations) // sparsity_factor + 1 - # If there are only a small number of tests, just select them all - if sparsity < 13: - sparsity = 1 - # To get a variety of parameter combinations sparsity should not be a - # multiple of 2, 3 or 5 - while sparsity % 2 == 0 or sparsity % 3 == 0 or sparsity % 5 == 0: - sparsity += 1 + sparsity = TosaArgGen._calculate_sparsity( + len(paddings) * len(strides) * len(dilations), sparsity_factor + ) n = 0 for s in sorted(list(strides)): @@ -1311,7 +1318,17 @@ class TosaArgGen: else: return [] - for paddings in shape_pad_values: + list_shape_pad_values = list(shape_pad_values) + # If we are producing tests for rank 6 or greater use sparsity + if len(list_shape_pad_values) > 1024: + sparsity_factor = 2 if error_name else 120 + sparsity = TosaArgGen._calculate_sparsity( + len(list_shape_pad_values), sparsity_factor + ) + else: + sparsity = 1 + + for n, paddings in enumerate(list_shape_pad_values): paddings = list(paddings) args_valid = True @@ -1325,8 +1342,7 @@ class TosaArgGen: paddings[i] = (0, 0) if all([p > -1 for p in paddings[i]]): args_valid = False - - if args_valid: + if args_valid and n % sparsity == 0: name = "pad" for r in range(rank): before, after = paddings[r] @@ -1670,7 +1686,8 @@ class TosaArgGen: factors = TosaArgGen.getFactors(totalElements) for p in range(testGen.args.num_rand_permutations): - newRank = testGen.randInt(1, 7) + # Rank from 1 to TOSA_TENSOR_MAX_RANK + newRank = testGen.randInt(1, (testGen.TOSA_TENSOR_MAX_RANK + 1)) if len(factors) < newRank: continue |