aboutsummaryrefslogtreecommitdiff
path: root/verif/generator/tosa_test_gen.py
diff options
context:
space:
mode:
Diffstat (limited to 'verif/generator/tosa_test_gen.py')
-rw-r--r--verif/generator/tosa_test_gen.py43
1 files changed, 29 insertions, 14 deletions
diff --git a/verif/generator/tosa_test_gen.py b/verif/generator/tosa_test_gen.py
index 8beb2ae..8fcea29 100644
--- a/verif/generator/tosa_test_gen.py
+++ b/verif/generator/tosa_test_gen.py
@@ -9,6 +9,7 @@ from pathlib import Path
import generator.tosa_utils as gtu
import numpy as np
import serializer.tosa_serializer as ts
+from generator.datagenerator import GenerateLibrary
from generator.tosa_arg_gen import TosaArgGen
from generator.tosa_arg_gen import TosaQuantGen
from generator.tosa_arg_gen import TosaTensorGen
@@ -55,6 +56,11 @@ class TosaTestGen:
self.random_fp_high = max(args.tensor_fp_value_range)
# JSON schema validation
self.descSchemaValidator = TestDescSchemaValidator()
+ # Data generator library when not generating the data later
+ if not args.lazy_data_gen:
+ self.dgl = GenerateLibrary(args.generate_lib_path)
+ else:
+ self.dgl = None
def createSerializer(self, opName, testPath):
self.testPath = os.path.join(opName, testPath)
@@ -92,15 +98,21 @@ class TosaTestGen:
self.descSchemaValidator.validate_config(desc)
if metaData:
- if self.args.lazy_data_gen and "data_gen" in metaData:
- # Output datagen meta data as CPP data
- path_md = path / f"{testName}_meta_data_gen.cpp"
- with path_md.open("w") as fd:
- fd.write(TOSA_AUTOGENERATED_HEADER)
- fd.write("// Test meta data for data generation setup\n\n")
- fd.write(f'const char* json_tdg_config_{path.stem} = R"(')
- json.dump(metaData["data_gen"], fd)
- fd.write(')";\n\n')
+ if "data_gen" in metaData:
+ if self.args.lazy_data_gen:
+ # Output datagen meta data as CPP data
+ path_md = path / f"{testName}_meta_data_gen.cpp"
+ with path_md.open("w") as fd:
+ fd.write(TOSA_AUTOGENERATED_HEADER)
+ fd.write("// Test meta data for data generation setup\n\n")
+ fd.write(f'const char* json_tdg_config_{path.stem} = R"(')
+ json.dump(metaData["data_gen"], fd)
+ fd.write(')";\n\n')
+ else:
+ # Generate the data
+ self.dgl.set_config(desc)
+ self.dgl.write_numpy_files(path)
+
if "compliance" in metaData:
# Output datagen meta data as CPP data
path_md = path / f"{testName}_meta_compliance.cpp"
@@ -282,8 +294,8 @@ class TosaTestGen:
)
def tensorComplianceMetaData(self, op, argsDict, outputTensor, errorName):
- if errorName or not gtu.dtypeIsFloat(outputTensor.dtype):
- # No compliance for error tests or integer tests currently
+ if errorName or not gtu.dtypeIsSupportedByCompliance(outputTensor.dtype):
+ # No compliance for error tests or other data types currently
return None
# Create compliance meta data for expected output tensor
@@ -1099,9 +1111,12 @@ class TosaTestGen:
self.ser.addOperator(op["op"], input_list, output_list, attr)
- compliance = self.tensorComplianceMetaData(
- op, args_dict, result_tensor, error_name
- )
+ if gtu.dtypeIsSupportedByCompliance(a.dtype):
+ compliance = self.tensorComplianceMetaData(
+ op, args_dict, result_tensor, error_name
+ )
+ else:
+ compliance = None
return TosaTestGen.BuildInfo(result_tensor, compliance)