diff options
-rw-r--r-- | verif/tosa_test_gen.py | 22 |
1 files changed, 10 insertions, 12 deletions
diff --git a/verif/tosa_test_gen.py b/verif/tosa_test_gen.py index 6f9acf4..f2f9b63 100644 --- a/verif/tosa_test_gen.py +++ b/verif/tosa_test_gen.py @@ -29,6 +29,7 @@ import queue import threading import traceback import math +import itertools from enum import IntEnum, Enum, unique @@ -643,20 +644,17 @@ class TosaArgGen: ifm_shape = shapeList[0] - perms = range(len(ifm_shape)) - for p in range(testGen.args.num_rand_permutations): - perms = np.int32(testGen.rng.permutation(perms)).tolist() + # Get all permutations + permutations = [p for p in itertools.permutations(range(len(ifm_shape)))] - # Avoid duplicates - found = False - for name, other_perm in arg_list: - if other_perm[0] == perms: - found = True - break + # Limit to possible permutations from shape dimension or argument setting + limit = min(len(permutations), testGen.args.num_rand_permutations) - if not found: - arg_list.append(("perm{}".format(p), [perms])) + # Get random permutation generator that uses all permutations + random_permutations = testGen.rng.permutation(permutations) + # Create list of required amount of permutations + arg_list = [("perm{}".format(p), [random_permutations[p].tolist()]) for p in range(limit)] return arg_list @staticmethod @@ -2327,7 +2325,7 @@ class TosaTestGen: "transpose": { "op": Op.TRANSPOSE, "operands": (1, 0), - "rank": (2, 4), # Do not allow tranpose on rank=1 + "rank": (1, 4), "build_fcn": ( build_transpose, TosaTensorGen.tgBasic, |