aboutsummaryrefslogtreecommitdiff
path: root/verif/generator/tosa_arg_gen.py
diff options
context:
space:
mode:
Diffstat (limited to 'verif/generator/tosa_arg_gen.py')
-rw-r--r--verif/generator/tosa_arg_gen.py41
1 files changed, 29 insertions, 12 deletions
diff --git a/verif/generator/tosa_arg_gen.py b/verif/generator/tosa_arg_gen.py
index 9209d9c..e2a69f1 100644
--- a/verif/generator/tosa_arg_gen.py
+++ b/verif/generator/tosa_arg_gen.py
@@ -1043,6 +1043,18 @@ class TosaArgGen:
return axes
@staticmethod
+ def _calculate_sparsity(num_tests, sparsity_factor):
+ sparsity = num_tests // sparsity_factor + 1
+ # If there are only a small number of tests, just select them all
+ if sparsity < 13:
+ sparsity = 1
+ # To get a variety of parameter combinations sparsity should not be a
+ # multiple of 2, 3 or 5
+ while sparsity % 2 == 0 or sparsity % 3 == 0 or sparsity % 5 == 0:
+ sparsity += 1
+ return sparsity
+
+ @staticmethod
def agConv(testGen, opName, shapeList, dtypes, error_name=None):
arg_list = []
@@ -1101,14 +1113,9 @@ class TosaArgGen:
# There are too many parameter combinations, so generate them sparsely,
# very sparse for negative tests
sparsity_factor = 2 if error_name else 120
- sparsity = len(paddings) * len(strides) * len(dilations) // sparsity_factor + 1
- # If there are only a small number of tests, just select them all
- if sparsity < 13:
- sparsity = 1
- # To get a variety of parameter combinations sparsity should not be a
- # multiple of 2, 3 or 5
- while sparsity % 2 == 0 or sparsity % 3 == 0 or sparsity % 5 == 0:
- sparsity += 1
+ sparsity = TosaArgGen._calculate_sparsity(
+ len(paddings) * len(strides) * len(dilations), sparsity_factor
+ )
n = 0
for s in sorted(list(strides)):
@@ -1311,7 +1318,17 @@ class TosaArgGen:
else:
return []
- for paddings in shape_pad_values:
+ list_shape_pad_values = list(shape_pad_values)
+ # If we are producing tests for rank 6 or greater use sparsity
+ if len(list_shape_pad_values) > 1024:
+ sparsity_factor = 2 if error_name else 120
+ sparsity = TosaArgGen._calculate_sparsity(
+ len(list_shape_pad_values), sparsity_factor
+ )
+ else:
+ sparsity = 1
+
+ for n, paddings in enumerate(list_shape_pad_values):
paddings = list(paddings)
args_valid = True
@@ -1325,8 +1342,7 @@ class TosaArgGen:
paddings[i] = (0, 0)
if all([p > -1 for p in paddings[i]]):
args_valid = False
-
- if args_valid:
+ if args_valid and n % sparsity == 0:
name = "pad"
for r in range(rank):
before, after = paddings[r]
@@ -1670,7 +1686,8 @@ class TosaArgGen:
factors = TosaArgGen.getFactors(totalElements)
for p in range(testGen.args.num_rand_permutations):
- newRank = testGen.randInt(1, 7)
+ # Rank from 1 to TOSA_TENSOR_MAX_RANK
+ newRank = testGen.randInt(1, (testGen.TOSA_TENSOR_MAX_RANK + 1))
if len(factors) < newRank:
continue