aboutsummaryrefslogtreecommitdiff
path: root/verif/generator/tosa_test_gen.py
diff options
context:
space:
mode:
authorJeremy Johnson <jeremy.johnson@arm.com>2024-03-28 15:53:21 +0000
committerJeremy Johnson <jeremy.johnson@arm.com>2024-04-11 15:02:57 +0100
commit18a379d99ad10002b3cf6eda086457179221cc22 (patch)
tree9b90a31f846035236cbecb9cde379dee66b6f0c3 /verif/generator/tosa_test_gen.py
parent3f3de01fa87246161e47c15fd6c44f710b86f3e7 (diff)
downloadreference_model-18a379d99ad10002b3cf6eda086457179221cc22.tar.gz
Add rank 0 testing support
Default test range is now rank 0 to 3 instead of 1 to 4 Signed-off-by: Jeremy Johnson <jeremy.johnson@arm.com> Change-Id: Ibde66b60b58de9f4a3852a3807c01f8dae61206f
Diffstat (limited to 'verif/generator/tosa_test_gen.py')
-rw-r--r--verif/generator/tosa_test_gen.py116
1 files changed, 68 insertions, 48 deletions
diff --git a/verif/generator/tosa_test_gen.py b/verif/generator/tosa_test_gen.py
index 399fed6..c5ac0f9 100644
--- a/verif/generator/tosa_test_gen.py
+++ b/verif/generator/tosa_test_gen.py
@@ -186,25 +186,26 @@ class TosaTestGen:
def makeShape(self, rng, rank):
if self.targetted_shape:
return np.int32(self.targetted_shape)
- return np.int32(
- rng.integers(
- low=self.args.tensor_shape_range[0],
- high=self.args.tensor_shape_range[1],
- size=rank,
+ else:
+ return np.int32(
+ rng.integers(
+ low=self.args.tensor_shape_range[0],
+ high=self.args.tensor_shape_range[1],
+ size=rank,
+ )
)
- )
def setTargetShape(self, shape):
self.targetted_shape = shape
def shapeStr(self, shape):
-
- sStr = []
- # Convert to strings
- for i in shape:
- sStr.append(str(i))
-
- return "x".join(sStr)
+ assert shape is not None
+ if len(shape) > 0:
+ # Rank > 0
+ return "x".join([str(d) for d in shape])
+ else:
+ # Rank 0
+ return "0"
def typeStr(self, dtype):
if isinstance(dtype, list) or isinstance(dtype, tuple):
@@ -2839,29 +2840,36 @@ class TosaTestGen:
def create_filter_lists(
self, op, shapeFilter, rankFilter, dtypeFilter, testType, validator=None
):
- # Create a default testing rank range, 1-4 inclusive to keep test sizes reasonably small.
- default_test_rank_range = range(1, 5)
- if not shapeFilter:
- shapeFilter = [None]
+ # Create a default testing rank range
+ if testType == "positive":
+ # 0-3 inclusive to keep test sizes reasonably small.
+ default_test_rank_range = range(0, 4)
+ else:
+ # Some errors do not work with rank 0, use 1-3
+ default_test_rank_range = range(1, 4)
# Calculate the filters based on what is requested and what the operator allows
rmin, rmax = op["rank"]
- if rankFilter is not None:
- cleanRankFilter = []
- # Ensure rankFilter values are allowed by operator
- for rank in rankFilter:
- if rank >= rmin and rank <= rmax:
- cleanRankFilter.append(rank)
- elif rankFilter is None and shapeFilter[0] is None:
- # Ensure default behaviour is bounded by default range or by operator,
- # whichever is the smaller range of ranks.
- opRankRange = range(rmin, rmax + 1)
- cleanRankFilter = (
- opRankRange
- if len(opRankRange) <= len(default_test_rank_range)
- else default_test_rank_range
- )
+
+ if shapeFilter:
+ # Specified shapes - ignore rank filter and default to op ranks below
+ rankFilter = None
+ ranksToCheck = []
+ elif rankFilter is None:
+ # No set rank filter so ensure default behaviour is bounded
+ ranksToCheck = default_test_rank_range
else:
+ ranksToCheck = rankFilter
+
+ cleanRankFilter = []
+ # Ensure rank values are allowed by operator
+ for rank in ranksToCheck:
+ if rank >= rmin and rank <= rmax:
+ cleanRankFilter.append(rank)
+
+ if shapeFilter or (len(cleanRankFilter) == 0 and rankFilter is None):
+ # Shapes specified or default test ranks didn't meet
+ # op requirements - so just use op ranks
cleanRankFilter = range(rmin, rmax + 1)
dtypes = op["types"]
@@ -2877,6 +2885,9 @@ class TosaTestGen:
else:
cleanDtypeFilter = dtypes
+ if not shapeFilter:
+ shapeFilter = [None]
+
if testType == "positive":
filterDict = {
"shapeFilter": shapeFilter,
@@ -3326,7 +3337,7 @@ class TosaTestGen:
[DType.FP8E5M2, DType.FP8E5M2, DType.FP16],
]
- DEFAULT_RANK_RANGE = (1, gtu.MAX_TENSOR_RANK)
+ DEFAULT_RANK_RANGE = (0, gtu.MAX_TENSOR_RANK)
KERNELS_2D = [[1, 1], [2, 2], [3, 3], [5, 5], [3, 1], [1, 3]]
KERNELS_3D = [[1, 1, 1], [2, 1, 1], [1, 2, 1], [1, 1, 2]]
@@ -3348,7 +3359,7 @@ class TosaTestGen:
"argmax": {
"op": Op.ARGMAX,
"operands": (1, 0),
- "rank": (1, 6),
+ "rank": (1, gtu.MAX_TENSOR_RANK),
"build_fcn": (
build_argmax,
TosaTensorGen.tgBasic,
@@ -4519,6 +4530,7 @@ class TosaTestGen:
"pad": {
"op": Op.PAD,
"operands": (2, 0),
+ "rank": (1, gtu.MAX_TENSOR_RANK),
"build_fcn": (
build_pad,
TosaTensorGen.tgBasic,
@@ -4541,6 +4553,7 @@ class TosaTestGen:
"dim": {
"op": Op.DIM,
"operands": (1, 0),
+ "rank": (1, gtu.MAX_TENSOR_RANK),
"build_fcn": (
build_dim,
TosaTensorGen.tgBasic,
@@ -4560,6 +4573,7 @@ class TosaTestGen:
"reshape": {
"op": Op.RESHAPE,
"operands": (2, 0),
+ "rank": (1, gtu.MAX_TENSOR_RANK),
"build_fcn": (
build_reshape,
TosaTensorGen.tgBasic,
@@ -4599,7 +4613,7 @@ class TosaTestGen:
"slice": {
"op": Op.SLICE,
"operands": (3, 0),
- "rank": (1, 6),
+ "rank": (1, gtu.MAX_TENSOR_RANK),
"build_fcn": (
build_slice,
TosaTensorGen.tgBasic,
@@ -4629,7 +4643,7 @@ class TosaTestGen:
"tile": {
"op": Op.TILE,
"operands": (2, 0),
- "rank": (1, 6),
+ "rank": (1, gtu.MAX_TENSOR_RANK),
"build_fcn": (
build_tile,
TosaTensorGen.tgBasic,
@@ -4650,7 +4664,7 @@ class TosaTestGen:
"transpose": {
"op": Op.TRANSPOSE,
"operands": (1, 0),
- "rank": (1, 6),
+ "rank": (1, gtu.MAX_TENSOR_RANK),
"build_fcn": (
build_transpose,
TosaTensorGen.tgBasic,
@@ -5047,6 +5061,7 @@ class OutputShaper:
assert len(a.shape) == len(b.shape)
assert a.dtype == b.dtype
+ # Work out broadcasted output shape (when not ERRORIF test)
shape = []
for i in range(len(a.shape)):
if a.shape[i] == 1 and error_name is None:
@@ -5054,8 +5069,9 @@ class OutputShaper:
else:
shape.append(a.shape[i])
- fuzz_idx = rng.integers(0, len(a.shape))
- if error_name == ErrorIf.DimensionMismatch:
+ if len(shape) > 0 and error_name == ErrorIf.DimensionMismatch:
+ # Can only create this error for rank > 0
+ fuzz_idx = rng.integers(0, len(shape))
shape[fuzz_idx] += 1
if error_name == ErrorIf.WrongOutputType:
@@ -5112,6 +5128,7 @@ class OutputShaper:
assert len(a.shape) == len(b.shape) and len(a.shape) == len(cond.shape)
assert a.dtype == b.dtype
+ # Work out broadcasted output shape (when not ERRORIF test)
shape = []
for i in range(len(cond.shape)):
if cond.shape[i] == 1 and error_name is None:
@@ -5119,8 +5136,9 @@ class OutputShaper:
else:
shape.append(cond.shape[i])
- fuzz_idx = rng.integers(0, len(a.shape))
- if error_name == ErrorIf.DimensionMismatch:
+ if len(shape) > 0 and error_name == ErrorIf.DimensionMismatch:
+ # Can only create this error for rank > 0
+ fuzz_idx = rng.integers(0, len(shape))
shape[fuzz_idx] += 1
if error_name == ErrorIf.WrongOutputType:
@@ -5146,7 +5164,7 @@ class OutputShaper:
assert len(a.shape) == len(b.shape)
assert a.dtype == b.dtype
- # Do broadcast
+ # Work out broadcasted output shape
shape = []
for i in range(len(a.shape)):
if a.shape[i] == 1 and len(b.shape) > i:
@@ -5154,8 +5172,9 @@ class OutputShaper:
else:
shape.append(a.shape[i])
- fuzz_idx = rng.integers(0, len(a.shape))
- if error_name == ErrorIf.DimensionMismatch:
+ if len(shape) > 0 and error_name == ErrorIf.DimensionMismatch:
+ # Can only create this error for rank > 0
+ fuzz_idx = rng.integers(0, len(shape))
shape[fuzz_idx] += 1
if error_name == ErrorIf.WrongOutputType:
@@ -5994,12 +6013,13 @@ class OutputShaper:
assert len(a.shape) == len(b.shape)
assert a.dtype == b.dtype
- shape = []
- for i in range(len(a.shape)):
- shape.append(a.shape[i])
+ shape = a.shape.copy()
- fuzz_idx = rng.integers(0, len(a.shape))
+ # Do not expect rank 0 tests!
+ assert len(shape) > 0
if error_name == ErrorIf.DimensionMismatch:
+ # Can only create this error for rank > 0
+ fuzz_idx = rng.integers(0, len(shape))
shape[fuzz_idx] += 1
if error_name == ErrorIf.WrongOutputType: