aboutsummaryrefslogtreecommitdiff
path: root/verif/generator/tosa_test_gen.py
diff options
context:
space:
mode:
Diffstat (limited to 'verif/generator/tosa_test_gen.py')
-rw-r--r--verif/generator/tosa_test_gen.py48
1 files changed, 36 insertions, 12 deletions
diff --git a/verif/generator/tosa_test_gen.py b/verif/generator/tosa_test_gen.py
index a768da0..7fef942 100644
--- a/verif/generator/tosa_test_gen.py
+++ b/verif/generator/tosa_test_gen.py
@@ -14,6 +14,7 @@ from generator.tosa_error_if import TosaErrorIfArgGen
from generator.tosa_error_if import TosaErrorValidator
from generator.tosa_error_if import TosaInvalidValidator
from generator.tosa_utils import DTYPE_ATTRIBUTES
+from generator.tosa_utils import get_rank_mismatch_shape
from generator.tosa_utils import get_wrong_output_type
from generator.tosa_utils import MAX_RESIZE_DIMENSION
from generator.tosa_utils import usableDTypes
@@ -1263,6 +1264,7 @@ class TosaTestGen:
input_list=input_list,
output_list=output_list,
num_operands=num_operands,
+ input1=a,
):
return None
@@ -1369,6 +1371,7 @@ class TosaTestGen:
input_list=input_list,
output_list=output_list,
num_operands=num_operands,
+ input1=a,
):
return None
@@ -1404,6 +1407,7 @@ class TosaTestGen:
input_list=input_list,
output_list=output_list,
num_operands=num_operands,
+ input1=a,
):
return None
@@ -1438,6 +1442,7 @@ class TosaTestGen:
input_list=input_list,
output_list=output_list,
num_operands=num_operands,
+ input1=a,
):
return None
@@ -3657,6 +3662,8 @@ class TosaTestGen:
TosaErrorValidator.evWrongOutputType,
TosaErrorValidator.evWrongInputList,
TosaErrorValidator.evWrongOutputList,
+ TosaErrorValidator.evRankMismatch,
+ TosaErrorValidator.evWrongRank,
),
},
"reshape": {
@@ -3699,7 +3706,7 @@ class TosaTestGen:
"slice": {
"op": Op.SLICE,
"operands": (1, 0),
- "rank": (1, 4),
+ "rank": (1, 6),
"build_fcn": (
build_slice,
TosaTensorGen.tgBasic,
@@ -3718,11 +3725,13 @@ class TosaTestGen:
TosaErrorValidator.evWrongOutputType,
TosaErrorValidator.evWrongInputList,
TosaErrorValidator.evWrongOutputList,
+ TosaErrorValidator.evRankMismatch,
),
},
"tile": {
"op": Op.TILE,
"operands": (1, 0),
+ "rank": (1, 6),
"build_fcn": (
build_tile,
TosaTensorGen.tgBasic,
@@ -3735,12 +3744,14 @@ class TosaTestGen:
TosaErrorValidator.evWrongOutputType,
TosaErrorValidator.evWrongInputList,
TosaErrorValidator.evWrongOutputList,
+ TosaErrorValidator.evRankMismatch,
+ TosaErrorValidator.evWrongRank,
),
},
"transpose": {
"op": Op.TRANSPOSE,
"operands": (1, 0),
- "rank": (1, 4),
+ "rank": (1, 6),
"build_fcn": (
build_transpose,
TosaTensorGen.tgBasic,
@@ -3755,6 +3766,9 @@ class TosaTestGen:
TosaErrorValidator.evWrongOutputType,
TosaErrorValidator.evWrongInputList,
TosaErrorValidator.evWrongOutputList,
+ TosaErrorValidator.evWrongRank,
+ TosaErrorValidator.evRankMismatch,
+ TosaErrorValidator.evTensorSizeInputOutputMismatch,
),
},
# Data nodes
@@ -4539,6 +4553,8 @@ class OutputShaper:
if error_name == ErrorIf.PadOutputShapeMismatch:
bad_dim = rng.choice(range(len(output_shape)))
output_shape[bad_dim] -= rng.choice([1, 2])
+ elif error_name == ErrorIf.RankMismatch:
+ output_shape = get_rank_mismatch_shape(rng, output_shape)
if error_name == ErrorIf.WrongOutputType:
all_dtypes = [
@@ -4583,7 +4599,7 @@ class OutputShaper:
return ser.addOutput(output_shape, outputDType)
@staticmethod
- def sliceOp(ser, rng, a, start, size, error_name=None):
+ def sliceOp(ser, rng, input, start, size, error_name=None):
if error_name == ErrorIf.WrongOutputType:
all_dtypes = [
@@ -4595,13 +4611,13 @@ class OutputShaper:
DType.FP16,
DType.BF16,
]
- wrong_dtypes = list(set(all_dtypes) - set([a.dtype]))
+ wrong_dtypes = list(set(all_dtypes) - set([input.dtype]))
outputDType = rng.choice(wrong_dtypes)
else:
- outputDType = a.dtype
+ outputDType = input.dtype
+ output_shape = size.copy()
if error_name == ErrorIf.SizeOutputShapeMismatch:
- output_shape = size.copy()
for index in range(len(output_shape)):
if output_shape[index] <= 2:
output_shape[index] = output_shape[index] + rng.choice([1, 2])
@@ -4609,8 +4625,10 @@ class OutputShaper:
output_shape[index] = output_shape[index] + rng.choice(
[-2, -1, 1, 2]
)
- else:
- output_shape = size.copy()
+ elif error_name == ErrorIf.InputSizeStartLengthMismatch:
+ output_shape = input.shape.copy()
+ elif error_name == ErrorIf.RankMismatch:
+ output_shape = get_rank_mismatch_shape(rng, output_shape)
return ser.addOutput(output_shape, outputDType)
@@ -4623,6 +4641,9 @@ class OutputShaper:
for i in range(len(output_shape)):
output_shape[i] = a.shape[i] * multiples[i]
+ if error_name == ErrorIf.RankMismatch:
+ output_shape = get_rank_mismatch_shape(rng, output_shape)
+
if error_name == ErrorIf.WrongOutputType:
all_dtypes = [
DType.INT8,
@@ -4646,13 +4667,16 @@ class OutputShaper:
assert len(perms) == len(output_shape)
- if error_name == ErrorIf.IndexOutsideBounds:
- for i in range(len(output_shape)):
- output_shape[i] = a.shape[0]
- else:
+ if error_name not in [ErrorIf.IndexOutsideBounds, ErrorIf.IndexUsedTwice]:
for i in range(len(output_shape)):
output_shape[i] = a.shape[perms[i]]
+ if error_name == ErrorIf.TensorSizeInputOutputMismatch:
+ for i in range(len(output_shape)):
+ output_shape[i] += rng.integers(1, 10)
+ elif error_name == ErrorIf.RankMismatch:
+ output_shape = get_rank_mismatch_shape(rng, output_shape)
+
if error_name == ErrorIf.WrongOutputType:
all_dtypes = [
DType.INT8,