aboutsummaryrefslogtreecommitdiff
path: root/verif/generator/tosa_test_gen.py
diff options
context:
space:
mode:
Diffstat (limited to 'verif/generator/tosa_test_gen.py')
-rw-r--r--verif/generator/tosa_test_gen.py32
1 files changed, 29 insertions, 3 deletions
diff --git a/verif/generator/tosa_test_gen.py b/verif/generator/tosa_test_gen.py
index c867070..28b3d28 100644
--- a/verif/generator/tosa_test_gen.py
+++ b/verif/generator/tosa_test_gen.py
@@ -36,9 +36,7 @@ logger = logging.getLogger("tosa_verif_build_tests")
class TosaTestGen:
- # Maximum rank of tensor supported by test generator.
# This currently matches the 8K level defined in the specification.
- TOSA_TENSOR_MAX_RANK = 6
TOSA_8K_LEVEL_MAX_SCALE = 64
TOSA_8K_LEVEL_MAX_KERNEL = 8192
TOSA_8K_LEVEL_MAX_STRIDE = 8192
@@ -2941,8 +2939,10 @@ class TosaTestGen:
testList = []
if testType == "negative" and "error_if_validators" in op:
error_if_validators = op["error_if_validators"]
+ num_error_types_created = 0
else:
error_if_validators = [None]
+ num_error_types_created = None
for validator in error_if_validators:
if validator is not None:
@@ -3020,6 +3020,21 @@ class TosaTestGen:
testList.append(
(opName, testStr, t, error_name, shapeList, args)
)
+ if error_name is not None:
+ # Check the last test is of the error we wanted
+ if len(testList) == 0 or testList[-1][3] != error_name:
+ if self.args.level8k:
+ logger.info(f"Missing {error_name} tests due to level8k mode")
+ else:
+ logger.error(f"ERROR: Failed to create any {error_name} tests")
+ logger.debug(
+ "Last test created: {}".format(
+ testList[-1] if testList else None
+ )
+ )
+ else:
+ # Successfully created at least one ERRROR_IF test
+ num_error_types_created += 1
if testType == "positive":
# Remove tests which are expected to fail but don't correlate to a ERROR_IF statement
@@ -3039,6 +3054,15 @@ class TosaTestGen:
if not remove_test:
clean_testList.append(test)
testList = clean_testList
+ else:
+ if num_error_types_created is not None and not self.args.level8k:
+ remaining_error_types = (
+ len(error_if_validators) - num_error_types_created
+ )
+ if remaining_error_types:
+ raise Exception(
+ f"Failed to create {remaining_error_types} error types for {opName}"
+ )
return testList
@@ -3141,9 +3165,11 @@ class TosaTestGen:
if compliance:
tensMeta["compliance"] = compliance
self.serialize("test", tensMeta)
+ return True
else:
# The test is not valid
logger.error(f"Invalid ERROR_IF test created: {opName} {testStr}")
+ return False
def createDynamicOpLists(self):
@@ -3305,7 +3331,7 @@ class TosaTestGen:
[DType.FP8E5M2, DType.FP8E5M2, DType.FP16],
]
- DEFAULT_RANK_RANGE = (1, TOSA_TENSOR_MAX_RANK)
+ DEFAULT_RANK_RANGE = (1, gtu.MAX_TENSOR_RANK)
TOSA_OP_LIST = {
# Tensor operators