aboutsummaryrefslogtreecommitdiff
path: root/verif
diff options
context:
space:
mode:
authorJeremy Johnson <jeremy.johnson@arm.com>2021-06-21 15:55:35 +0100
committerEric Kunze <eric.kunze@arm.com>2021-06-28 17:29:48 +0000
commita618557ba688fba16eefd7d9f4c10cc11cb085ef (patch)
tree7733c56efffe0f8e4b14c4d97ad27f9d8b43df86 /verif
parent82507d77056dd5510547438ba2064c1ee8bebc2c (diff)
downloadreference_model-a618557ba688fba16eefd7d9f4c10cc11cb085ef.tar.gz
Fix transpose test gen of permutations & rank
Change transpose permutation generation to limit to the number of possible permutations that can be created by the shape size or the argument setting which ever is smaller. Also make sure all permutations up to this number are generated rather than randomly skipped due to duplicates. Allow rank 1 transpose tests as the specification allows rank 1. Change-Id: I28ea64c1d819f3af72c97bed43cfe7279c7e2f9c Signed-off-by: Jeremy Johnson <jeremy.johnson@arm.com>
Diffstat (limited to 'verif')
-rw-r--r--verif/tosa_test_gen.py22
1 files changed, 10 insertions, 12 deletions
diff --git a/verif/tosa_test_gen.py b/verif/tosa_test_gen.py
index 6f9acf4..f2f9b63 100644
--- a/verif/tosa_test_gen.py
+++ b/verif/tosa_test_gen.py
@@ -29,6 +29,7 @@ import queue
import threading
import traceback
import math
+import itertools
from enum import IntEnum, Enum, unique
@@ -643,20 +644,17 @@ class TosaArgGen:
ifm_shape = shapeList[0]
- perms = range(len(ifm_shape))
- for p in range(testGen.args.num_rand_permutations):
- perms = np.int32(testGen.rng.permutation(perms)).tolist()
+ # Get all permutations
+ permutations = [p for p in itertools.permutations(range(len(ifm_shape)))]
- # Avoid duplicates
- found = False
- for name, other_perm in arg_list:
- if other_perm[0] == perms:
- found = True
- break
+ # Limit to possible permutations from shape dimension or argument setting
+ limit = min(len(permutations), testGen.args.num_rand_permutations)
- if not found:
- arg_list.append(("perm{}".format(p), [perms]))
+ # Get random permutation generator that uses all permutations
+ random_permutations = testGen.rng.permutation(permutations)
+ # Create list of required amount of permutations
+ arg_list = [("perm{}".format(p), [random_permutations[p].tolist()]) for p in range(limit)]
return arg_list
@staticmethod
@@ -2327,7 +2325,7 @@ class TosaTestGen:
"transpose": {
"op": Op.TRANSPOSE,
"operands": (1, 0),
- "rank": (2, 4), # Do not allow tranpose on rank=1
+ "rank": (1, 4),
"build_fcn": (
build_transpose,
TosaTensorGen.tgBasic,