diff options
Diffstat (limited to 'verif/generator/tosa_test_gen.py')
-rw-r--r-- | verif/generator/tosa_test_gen.py | 32 |
1 files changed, 29 insertions, 3 deletions
diff --git a/verif/generator/tosa_test_gen.py b/verif/generator/tosa_test_gen.py index c867070..28b3d28 100644 --- a/verif/generator/tosa_test_gen.py +++ b/verif/generator/tosa_test_gen.py @@ -36,9 +36,7 @@ logger = logging.getLogger("tosa_verif_build_tests") class TosaTestGen: - # Maximum rank of tensor supported by test generator. # This currently matches the 8K level defined in the specification. - TOSA_TENSOR_MAX_RANK = 6 TOSA_8K_LEVEL_MAX_SCALE = 64 TOSA_8K_LEVEL_MAX_KERNEL = 8192 TOSA_8K_LEVEL_MAX_STRIDE = 8192 @@ -2941,8 +2939,10 @@ class TosaTestGen: testList = [] if testType == "negative" and "error_if_validators" in op: error_if_validators = op["error_if_validators"] + num_error_types_created = 0 else: error_if_validators = [None] + num_error_types_created = None for validator in error_if_validators: if validator is not None: @@ -3020,6 +3020,21 @@ class TosaTestGen: testList.append( (opName, testStr, t, error_name, shapeList, args) ) + if error_name is not None: + # Check the last test is of the error we wanted + if len(testList) == 0 or testList[-1][3] != error_name: + if self.args.level8k: + logger.info(f"Missing {error_name} tests due to level8k mode") + else: + logger.error(f"ERROR: Failed to create any {error_name} tests") + logger.debug( + "Last test created: {}".format( + testList[-1] if testList else None + ) + ) + else: + # Successfully created at least one ERRROR_IF test + num_error_types_created += 1 if testType == "positive": # Remove tests which are expected to fail but don't correlate to a ERROR_IF statement @@ -3039,6 +3054,15 @@ class TosaTestGen: if not remove_test: clean_testList.append(test) testList = clean_testList + else: + if num_error_types_created is not None and not self.args.level8k: + remaining_error_types = ( + len(error_if_validators) - num_error_types_created + ) + if remaining_error_types: + raise Exception( + f"Failed to create {remaining_error_types} error types for {opName}" + ) return testList @@ -3141,9 +3165,11 @@ class TosaTestGen: if compliance: tensMeta["compliance"] = compliance self.serialize("test", tensMeta) + return True else: # The test is not valid logger.error(f"Invalid ERROR_IF test created: {opName} {testStr}") + return False def createDynamicOpLists(self): @@ -3305,7 +3331,7 @@ class TosaTestGen: [DType.FP8E5M2, DType.FP8E5M2, DType.FP16], ] - DEFAULT_RANK_RANGE = (1, TOSA_TENSOR_MAX_RANK) + DEFAULT_RANK_RANGE = (1, gtu.MAX_TENSOR_RANK) TOSA_OP_LIST = { # Tensor operators |