aboutsummaryrefslogtreecommitdiff
path: root/verif/tosa_test_gen.py
diff options
context:
space:
mode:
Diffstat (limited to 'verif/tosa_test_gen.py')
-rw-r--r--verif/tosa_test_gen.py22
1 files changed, 10 insertions, 12 deletions
diff --git a/verif/tosa_test_gen.py b/verif/tosa_test_gen.py
index 6f9acf4..f2f9b63 100644
--- a/verif/tosa_test_gen.py
+++ b/verif/tosa_test_gen.py
@@ -29,6 +29,7 @@ import queue
import threading
import traceback
import math
+import itertools
from enum import IntEnum, Enum, unique
@@ -643,20 +644,17 @@ class TosaArgGen:
ifm_shape = shapeList[0]
- perms = range(len(ifm_shape))
- for p in range(testGen.args.num_rand_permutations):
- perms = np.int32(testGen.rng.permutation(perms)).tolist()
+ # Get all permutations
+ permutations = [p for p in itertools.permutations(range(len(ifm_shape)))]
- # Avoid duplicates
- found = False
- for name, other_perm in arg_list:
- if other_perm[0] == perms:
- found = True
- break
+ # Limit to possible permutations from shape dimension or argument setting
+ limit = min(len(permutations), testGen.args.num_rand_permutations)
- if not found:
- arg_list.append(("perm{}".format(p), [perms]))
+ # Get random permutation generator that uses all permutations
+ random_permutations = testGen.rng.permutation(permutations)
+ # Create list of required amount of permutations
+ arg_list = [("perm{}".format(p), [random_permutations[p].tolist()]) for p in range(limit)]
return arg_list
@staticmethod
@@ -2327,7 +2325,7 @@ class TosaTestGen:
"transpose": {
"op": Op.TRANSPOSE,
"operands": (1, 0),
- "rank": (2, 4), # Do not allow tranpose on rank=1
+ "rank": (1, 4),
"build_fcn": (
build_transpose,
TosaTensorGen.tgBasic,