aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorJeremy Johnson <jeremy.johnson@arm.com>2021-10-07 12:06:00 +0100
committerEric Kunze <eric.kunze@arm.com>2021-10-11 14:04:25 +0000
commit03bec734775dad8e45145b4d0dae4584199c1b84 (patch)
tree386f008654f5b027582adf5c8d4207bf8cf2902a
parente4ecdb2ee8471cc713e7562fbec4118820f81a72 (diff)
downloadreference_model-03bec734775dad8e45145b4d0dae4584199c1b84.tar.gz
Fix rank and dtype filtering for ops like conv3d & fully_connected
Change-Id: Ic2aebe40b5cce61d4576a64f4f48ff87b36475c2 Signed-off-by: Jeremy Johnson <jeremy.johnson@arm.com>
-rw-r--r--verif/tosa_test_gen.py16
1 files changed, 7 insertions, 9 deletions
diff --git a/verif/tosa_test_gen.py b/verif/tosa_test_gen.py
index 2478331..8d69831 100644
--- a/verif/tosa_test_gen.py
+++ b/verif/tosa_test_gen.py
@@ -2694,21 +2694,19 @@ class TosaTestGen:
if rank >= rmin and rank <= rmax:
cleanRankFilter.append(rank)
elif rankFilter is None and shapeFilter[0] is None:
- cleanRankFilter = []
- # Ensure default behaviour is bounded by default range or by operator, whichever is smaller.
- rankRange = range(rmin, rmax + 1)
- for rank in rankRange:
- if rank >= min(default_test_rank_range) and rank <= max(default_test_rank_range):
- cleanRankFilter.append(rank)
+ # Ensure default behaviour is bounded by default range or by operator,
+ # whichever is the smaller range of ranks.
+ opRankRange = range(rmin, rmax + 1)
+ cleanRankFilter = opRankRange if len(opRankRange) <= len(default_test_rank_range) else default_test_rank_range
else:
cleanRankFilter = range(rmin, rmax + 1)
dtypes = op["types"]
if dtypeFilter is not None:
cleanDtypeFilter = []
- # Ensure filtered dtypes are allowed by operator
- for dtype in dtypeFilter:
- if dtype in dtypes:
+ # Create list of operator dtypes filtered by requested dtypes
+ for dtype in dtypes:
+ if dtype in dtypeFilter or (isinstance(dtype, list) and dtype[0] in dtypeFilter):
cleanDtypeFilter.append(dtype)
else:
cleanDtypeFilter = dtypes