diff options
author | Jeremy Johnson <jeremy.johnson@arm.com> | 2021-06-21 15:55:35 +0100 |
---|---|---|
committer | Eric Kunze <eric.kunze@arm.com> | 2021-06-28 17:29:48 +0000 |
commit | a618557ba688fba16eefd7d9f4c10cc11cb085ef (patch) | |
tree | 7733c56efffe0f8e4b14c4d97ad27f9d8b43df86 | |
parent | 82507d77056dd5510547438ba2064c1ee8bebc2c (diff) | |
download | reference_model-a618557ba688fba16eefd7d9f4c10cc11cb085ef.tar.gz |
Fix transpose test gen of permutations & rank
Change transpose permutation generation to limit to the number of
possible permutations that can be created by the shape size or the
argument setting which ever is smaller. Also make sure all
permutations up to this number are generated rather than randomly
skipped due to duplicates.
Allow rank 1 transpose tests as the specification allows rank 1.
Change-Id: I28ea64c1d819f3af72c97bed43cfe7279c7e2f9c
Signed-off-by: Jeremy Johnson <jeremy.johnson@arm.com>
-rw-r--r-- | verif/tosa_test_gen.py | 22 |
1 files changed, 10 insertions, 12 deletions
diff --git a/verif/tosa_test_gen.py b/verif/tosa_test_gen.py index 6f9acf4..f2f9b63 100644 --- a/verif/tosa_test_gen.py +++ b/verif/tosa_test_gen.py @@ -29,6 +29,7 @@ import queue import threading import traceback import math +import itertools from enum import IntEnum, Enum, unique @@ -643,20 +644,17 @@ class TosaArgGen: ifm_shape = shapeList[0] - perms = range(len(ifm_shape)) - for p in range(testGen.args.num_rand_permutations): - perms = np.int32(testGen.rng.permutation(perms)).tolist() + # Get all permutations + permutations = [p for p in itertools.permutations(range(len(ifm_shape)))] - # Avoid duplicates - found = False - for name, other_perm in arg_list: - if other_perm[0] == perms: - found = True - break + # Limit to possible permutations from shape dimension or argument setting + limit = min(len(permutations), testGen.args.num_rand_permutations) - if not found: - arg_list.append(("perm{}".format(p), [perms])) + # Get random permutation generator that uses all permutations + random_permutations = testGen.rng.permutation(permutations) + # Create list of required amount of permutations + arg_list = [("perm{}".format(p), [random_permutations[p].tolist()]) for p in range(limit)] return arg_list @staticmethod @@ -2327,7 +2325,7 @@ class TosaTestGen: "transpose": { "op": Op.TRANSPOSE, "operands": (1, 0), - "rank": (2, 4), # Do not allow tranpose on rank=1 + "rank": (1, 4), "build_fcn": ( build_transpose, TosaTensorGen.tgBasic, |