diff options
author | Matthew Haddon <matthew.haddon@arm.com> | 2021-06-22 16:55:23 +0100 |
---|---|---|
committer | Eric Kunze <eric.kunze@arm.com> | 2021-07-07 00:04:23 +0000 |
commit | 2ad047dcc814b8edb519efb5472ab03fbc30e9e5 (patch) | |
tree | 3dd158ef888fe2b9fd2c2e61e528f1c22146b8bd | |
parent | d5934146b0b4e18bb6bad213901e48a5a20bef00 (diff) | |
download | reference_model-2ad047dcc814b8edb519efb5472ab03fbc30e9e5.tar.gz |
Fix bug causing identical reshape permutations
* When generating permutations of a reshape operator test there was
no check to ensure that the permutation was unique, this patch adds
a check to ensure that no two newShape variables are the same.
* Added a 'escape_counter' which will break out of while loop
if it continues on for too long.
Signed-off-by: Matthew Haddon <matthew.haddon@arm.com>
Change-Id: I231eb9b546a73431835b5dc899784f69cc22a773
-rw-r--r-- | verif/tosa_test_gen.py | 51 |
1 files changed, 34 insertions, 17 deletions
diff --git a/verif/tosa_test_gen.py b/verif/tosa_test_gen.py index f2f9b63..2566f69 100644 --- a/verif/tosa_test_gen.py +++ b/verif/tosa_test_gen.py @@ -593,7 +593,7 @@ class TosaArgGen: def getFactors(val, start=1): factors = [] - for i in range(start, int(np.sqrt(val))): + for i in range(start, int(np.sqrt(val)) + 1): if (val % i) == 0: factors.append(i) @@ -614,27 +614,44 @@ class TosaArgGen: for p in range(testGen.args.num_rand_permutations): newRank = testGen.randInt(1, 6) - newShape = [] if len(factors) < newRank: continue - remainingElements = totalElements - shuffledFactors = testGen.rng.permutation(factors) - for i in range(newRank): - # pick rank-1 factors - newShape.append(shuffledFactors[0]) - remainingElements = remainingElements // shuffledFactors[0] - shuffledFactors = testGen.rng.permutation( - TosaArgGen.getFactors(remainingElements) - ) - newShape.append(remainingElements) + found = True + # escape_counter breaks while loop if it continues on for too long + escape_counter = 0 + while found: + newShape = [] + # Generate newShape ensuring it isn't a duplicate + remainingElements = totalElements + shuffledFactors = testGen.rng.permutation(factors) + for i in range(newRank): + # pick rank-1 factors + newShape.append(shuffledFactors[0]) + remainingElements = remainingElements // shuffledFactors[0] + shuffledFactors = testGen.rng.permutation( + TosaArgGen.getFactors(remainingElements) + ) + newShape.append(remainingElements) + + # Toss in a -1 sometimes + minusOne = testGen.randInt(0, newRank * 4) + if minusOne < newRank: + newShape[minusOne] = -1 + + # Check for duplicates + found = False + for name, other_shape in arg_list: + if other_shape[0] == newShape: + found = True + break - # Toss in a -1 sometimes - minusOne = testGen.randInt(0, newRank * 4) - if minusOne < newRank: - newShape[minusOne] = -1 + escape_counter += 1 + if escape_counter >= 100: + break - arg_list.append(("perm{}_rank{}".format(p, newRank), [newShape])) + if not found: + arg_list.append(("perm{}_rank{}".format(p, newRank), [newShape])) return arg_list |