diff options
Diffstat (limited to 'verif/generator/tosa_test_gen.py')
-rw-r--r-- | verif/generator/tosa_test_gen.py | 43 |
1 files changed, 29 insertions, 14 deletions
diff --git a/verif/generator/tosa_test_gen.py b/verif/generator/tosa_test_gen.py index 8beb2ae..8fcea29 100644 --- a/verif/generator/tosa_test_gen.py +++ b/verif/generator/tosa_test_gen.py @@ -9,6 +9,7 @@ from pathlib import Path import generator.tosa_utils as gtu import numpy as np import serializer.tosa_serializer as ts +from generator.datagenerator import GenerateLibrary from generator.tosa_arg_gen import TosaArgGen from generator.tosa_arg_gen import TosaQuantGen from generator.tosa_arg_gen import TosaTensorGen @@ -55,6 +56,11 @@ class TosaTestGen: self.random_fp_high = max(args.tensor_fp_value_range) # JSON schema validation self.descSchemaValidator = TestDescSchemaValidator() + # Data generator library when not generating the data later + if not args.lazy_data_gen: + self.dgl = GenerateLibrary(args.generate_lib_path) + else: + self.dgl = None def createSerializer(self, opName, testPath): self.testPath = os.path.join(opName, testPath) @@ -92,15 +98,21 @@ class TosaTestGen: self.descSchemaValidator.validate_config(desc) if metaData: - if self.args.lazy_data_gen and "data_gen" in metaData: - # Output datagen meta data as CPP data - path_md = path / f"{testName}_meta_data_gen.cpp" - with path_md.open("w") as fd: - fd.write(TOSA_AUTOGENERATED_HEADER) - fd.write("// Test meta data for data generation setup\n\n") - fd.write(f'const char* json_tdg_config_{path.stem} = R"(') - json.dump(metaData["data_gen"], fd) - fd.write(')";\n\n') + if "data_gen" in metaData: + if self.args.lazy_data_gen: + # Output datagen meta data as CPP data + path_md = path / f"{testName}_meta_data_gen.cpp" + with path_md.open("w") as fd: + fd.write(TOSA_AUTOGENERATED_HEADER) + fd.write("// Test meta data for data generation setup\n\n") + fd.write(f'const char* json_tdg_config_{path.stem} = R"(') + json.dump(metaData["data_gen"], fd) + fd.write(')";\n\n') + else: + # Generate the data + self.dgl.set_config(desc) + self.dgl.write_numpy_files(path) + if "compliance" in metaData: # Output datagen meta data as CPP data path_md = path / f"{testName}_meta_compliance.cpp" @@ -282,8 +294,8 @@ class TosaTestGen: ) def tensorComplianceMetaData(self, op, argsDict, outputTensor, errorName): - if errorName or not gtu.dtypeIsFloat(outputTensor.dtype): - # No compliance for error tests or integer tests currently + if errorName or not gtu.dtypeIsSupportedByCompliance(outputTensor.dtype): + # No compliance for error tests or other data types currently return None # Create compliance meta data for expected output tensor @@ -1099,9 +1111,12 @@ class TosaTestGen: self.ser.addOperator(op["op"], input_list, output_list, attr) - compliance = self.tensorComplianceMetaData( - op, args_dict, result_tensor, error_name - ) + if gtu.dtypeIsSupportedByCompliance(a.dtype): + compliance = self.tensorComplianceMetaData( + op, args_dict, result_tensor, error_name + ) + else: + compliance = None return TosaTestGen.BuildInfo(result_tensor, compliance) |