aboutsummaryrefslogtreecommitdiff
path: root/verif/generator/tosa_test_gen.py
diff options
context:
space:
mode:
Diffstat (limited to 'verif/generator/tosa_test_gen.py')
-rw-r--r--verif/generator/tosa_test_gen.py69
1 files changed, 55 insertions, 14 deletions
diff --git a/verif/generator/tosa_test_gen.py b/verif/generator/tosa_test_gen.py
index bfafd23..68a4e94 100644
--- a/verif/generator/tosa_test_gen.py
+++ b/verif/generator/tosa_test_gen.py
@@ -381,8 +381,35 @@ class TosaTestGen:
"""Enhanced build information containing result tensor and associated compliance dict."""
def __init__(self, resultTensor, complianceDict):
- self.resultTensor = resultTensor
- self.complianceDict = complianceDict
+ if isinstance(resultTensor, list):
+ assert complianceDict is None or isinstance(complianceDict, list)
+ self.resultTensorList = resultTensor
+ self.complianceDictList = complianceDict
+ else:
+ self.resultTensorList = [resultTensor]
+ if complianceDict is None:
+ self.complianceDictList = None
+ else:
+ self.complianceDictList = [complianceDict]
+
+ def getComplianceInfo(self):
+ if self.complianceDictList is None:
+ return None
+ else:
+ tens_dict = {}
+ for tens, comp in zip(self.resultTensorList, self.complianceDictList):
+ if comp is not None:
+ tens_dict[tens.name] = comp
+
+ if tens_dict:
+ # Have some compliance data, so return the info
+ compliance = {
+ "version": "0.1",
+ "tensors": tens_dict,
+ }
+ else:
+ compliance = None
+ return compliance
def build_unary(
self, op, inputs, args_dict, validator_fcns=None, error_name=None, qinfo=None
@@ -2491,12 +2518,16 @@ class TosaTestGen:
def build_fft2d(
self,
op,
- val1,
- val2,
- inverse,
+ inputs,
+ args_dict,
validator_fcns=None,
error_name=None,
+ qinfo=None,
):
+ assert len(inputs) == 2
+ val1, val2 = inputs
+ inverse = args_dict["inverse"]
+
results = OutputShaper.fft2dOp(self.ser, self.rng, val1, val2, error_name)
input_names = [val1.name, val2.name]
@@ -2537,7 +2568,16 @@ class TosaTestGen:
attr.FFTAttribute(inverse, local_bound)
self.ser.addOperator(op["op"], input_names, output_names, attr)
- return results
+
+ compliance = []
+ for res in results:
+ compliance.append(
+ self.tensorComplianceMetaData(
+ op, val1.dtype, args_dict, res, error_name
+ )
+ )
+
+ return TosaTestGen.BuildInfo(results, compliance)
def build_rfft2d(
self,
@@ -2933,13 +2973,11 @@ class TosaTestGen:
if result:
# The test is valid, serialize it
- if isinstance(result, TosaTestGen.BuildInfo) and result.complianceDict:
- # Add the compliance meta data
- # NOTE: This currently expects only one result output
- tensMeta["compliance"] = {
- "version": "0.1",
- "tensors": {result.resultTensor.name: result.complianceDict},
- }
+ if isinstance(result, TosaTestGen.BuildInfo):
+ # Add the compliance meta data (if any)
+ compliance = result.getComplianceInfo()
+ if compliance:
+ tensMeta["compliance"] = compliance
self.serialize("test", tensMeta)
else:
# The test is not valid
@@ -4708,7 +4746,7 @@ class TosaTestGen:
"build_fcn": (
build_fft2d,
TosaTensorGen.tgFFT2d,
- TosaTensorValuesGen.tvgDefault,
+ TosaTensorValuesGen.tvgLazyGenDefault,
TosaArgGen.agFFT2d,
),
"types": [DType.FP32],
@@ -4723,6 +4761,9 @@ class TosaTestGen:
TosaErrorValidator.evFFTInputShapeMismatch,
TosaErrorValidator.evFFTOutputShapeMismatch,
),
+ "data_gen": {
+ "fp": (gtu.DataGenType.DOT_PRODUCT,),
+ },
},
"rfft2d": {
"op": Op.RFFT2D,