aboutsummaryrefslogtreecommitdiff
path: root/verif/tosa_test_gen.py
diff options
context:
space:
mode:
Diffstat (limited to 'verif/tosa_test_gen.py')
-rw-r--r--verif/tosa_test_gen.py51
1 files changed, 34 insertions, 17 deletions
diff --git a/verif/tosa_test_gen.py b/verif/tosa_test_gen.py
index f2f9b63..2566f69 100644
--- a/verif/tosa_test_gen.py
+++ b/verif/tosa_test_gen.py
@@ -593,7 +593,7 @@ class TosaArgGen:
def getFactors(val, start=1):
factors = []
- for i in range(start, int(np.sqrt(val))):
+ for i in range(start, int(np.sqrt(val)) + 1):
if (val % i) == 0:
factors.append(i)
@@ -614,27 +614,44 @@ class TosaArgGen:
for p in range(testGen.args.num_rand_permutations):
newRank = testGen.randInt(1, 6)
- newShape = []
if len(factors) < newRank:
continue
- remainingElements = totalElements
- shuffledFactors = testGen.rng.permutation(factors)
- for i in range(newRank):
- # pick rank-1 factors
- newShape.append(shuffledFactors[0])
- remainingElements = remainingElements // shuffledFactors[0]
- shuffledFactors = testGen.rng.permutation(
- TosaArgGen.getFactors(remainingElements)
- )
- newShape.append(remainingElements)
+ found = True
+ # escape_counter breaks while loop if it continues on for too long
+ escape_counter = 0
+ while found:
+ newShape = []
+ # Generate newShape ensuring it isn't a duplicate
+ remainingElements = totalElements
+ shuffledFactors = testGen.rng.permutation(factors)
+ for i in range(newRank):
+ # pick rank-1 factors
+ newShape.append(shuffledFactors[0])
+ remainingElements = remainingElements // shuffledFactors[0]
+ shuffledFactors = testGen.rng.permutation(
+ TosaArgGen.getFactors(remainingElements)
+ )
+ newShape.append(remainingElements)
+
+ # Toss in a -1 sometimes
+ minusOne = testGen.randInt(0, newRank * 4)
+ if minusOne < newRank:
+ newShape[minusOne] = -1
+
+ # Check for duplicates
+ found = False
+ for name, other_shape in arg_list:
+ if other_shape[0] == newShape:
+ found = True
+ break
- # Toss in a -1 sometimes
- minusOne = testGen.randInt(0, newRank * 4)
- if minusOne < newRank:
- newShape[minusOne] = -1
+ escape_counter += 1
+ if escape_counter >= 100:
+ break
- arg_list.append(("perm{}_rank{}".format(p, newRank), [newShape]))
+ if not found:
+ arg_list.append(("perm{}_rank{}".format(p, newRank), [newShape]))
return arg_list