From 2ad047dcc814b8edb519efb5472ab03fbc30e9e5 Mon Sep 17 00:00:00 2001 From: Matthew Haddon Date: Tue, 22 Jun 2021 16:55:23 +0100 Subject: Fix bug causing identical reshape permutations * When generating permutations of a reshape operator test there was no check to ensure that the permutation was unique, this patch adds a check to ensure that no two newShape variables are the same. * Added a 'escape_counter' which will break out of while loop if it continues on for too long. Signed-off-by: Matthew Haddon Change-Id: I231eb9b546a73431835b5dc899784f69cc22a773 --- verif/tosa_test_gen.py | 51 +++++++++++++++++++++++++++++++++----------------- 1 file changed, 34 insertions(+), 17 deletions(-) diff --git a/verif/tosa_test_gen.py b/verif/tosa_test_gen.py index f2f9b63..2566f69 100644 --- a/verif/tosa_test_gen.py +++ b/verif/tosa_test_gen.py @@ -593,7 +593,7 @@ class TosaArgGen: def getFactors(val, start=1): factors = [] - for i in range(start, int(np.sqrt(val))): + for i in range(start, int(np.sqrt(val)) + 1): if (val % i) == 0: factors.append(i) @@ -614,27 +614,44 @@ class TosaArgGen: for p in range(testGen.args.num_rand_permutations): newRank = testGen.randInt(1, 6) - newShape = [] if len(factors) < newRank: continue - remainingElements = totalElements - shuffledFactors = testGen.rng.permutation(factors) - for i in range(newRank): - # pick rank-1 factors - newShape.append(shuffledFactors[0]) - remainingElements = remainingElements // shuffledFactors[0] - shuffledFactors = testGen.rng.permutation( - TosaArgGen.getFactors(remainingElements) - ) - newShape.append(remainingElements) + found = True + # escape_counter breaks while loop if it continues on for too long + escape_counter = 0 + while found: + newShape = [] + # Generate newShape ensuring it isn't a duplicate + remainingElements = totalElements + shuffledFactors = testGen.rng.permutation(factors) + for i in range(newRank): + # pick rank-1 factors + newShape.append(shuffledFactors[0]) + remainingElements = remainingElements // shuffledFactors[0] + shuffledFactors = testGen.rng.permutation( + TosaArgGen.getFactors(remainingElements) + ) + newShape.append(remainingElements) + + # Toss in a -1 sometimes + minusOne = testGen.randInt(0, newRank * 4) + if minusOne < newRank: + newShape[minusOne] = -1 + + # Check for duplicates + found = False + for name, other_shape in arg_list: + if other_shape[0] == newShape: + found = True + break - # Toss in a -1 sometimes - minusOne = testGen.randInt(0, newRank * 4) - if minusOne < newRank: - newShape[minusOne] = -1 + escape_counter += 1 + if escape_counter >= 100: + break - arg_list.append(("perm{}_rank{}".format(p, newRank), [newShape])) + if not found: + arg_list.append(("perm{}_rank{}".format(p, newRank), [newShape])) return arg_list -- cgit v1.2.1