aboutsummaryrefslogtreecommitdiff
path: root/verif/generator/tosa_test_gen.py
diff options
context:
space:
mode:
authorJeremy Johnson <jeremy.johnson@arm.com>2024-01-18 16:57:28 +0000
committerJeremy Johnson <jeremy.johnson@arm.com>2024-02-07 10:57:40 +0000
commitc8330811352f753e36f2ee7be4c7d0e6002f21e7 (patch)
tree967eeb59876e7c6abea26ff2e892d5ff94134992 /verif/generator/tosa_test_gen.py
parent9847722e2b172b69fe9ae80a05c27ca5c8c36617 (diff)
downloadreference_model-c8330811352f753e36f2ee7be4c7d0e6002f21e7.tar.gz
Main Compliance: FFT2D support
Improve access to DOT_PRODUCT generator index and location for debugging. Enable multiple result files for compliance and improve output. Fix up precise and abs modes for FFT2D in ref model to produce correct results and bounds using abs weights. Signed-off-by: Jeremy Johnson <jeremy.johnson@arm.com> Change-Id: Ide0c9f9f80397e5f1e07ca30a1036d6014b5784d
Diffstat (limited to 'verif/generator/tosa_test_gen.py')
-rw-r--r--verif/generator/tosa_test_gen.py69
1 files changed, 55 insertions, 14 deletions
diff --git a/verif/generator/tosa_test_gen.py b/verif/generator/tosa_test_gen.py
index bfafd23..68a4e94 100644
--- a/verif/generator/tosa_test_gen.py
+++ b/verif/generator/tosa_test_gen.py
@@ -381,8 +381,35 @@ class TosaTestGen:
"""Enhanced build information containing result tensor and associated compliance dict."""
def __init__(self, resultTensor, complianceDict):
- self.resultTensor = resultTensor
- self.complianceDict = complianceDict
+ if isinstance(resultTensor, list):
+ assert complianceDict is None or isinstance(complianceDict, list)
+ self.resultTensorList = resultTensor
+ self.complianceDictList = complianceDict
+ else:
+ self.resultTensorList = [resultTensor]
+ if complianceDict is None:
+ self.complianceDictList = None
+ else:
+ self.complianceDictList = [complianceDict]
+
+ def getComplianceInfo(self):
+ if self.complianceDictList is None:
+ return None
+ else:
+ tens_dict = {}
+ for tens, comp in zip(self.resultTensorList, self.complianceDictList):
+ if comp is not None:
+ tens_dict[tens.name] = comp
+
+ if tens_dict:
+ # Have some compliance data, so return the info
+ compliance = {
+ "version": "0.1",
+ "tensors": tens_dict,
+ }
+ else:
+ compliance = None
+ return compliance
def build_unary(
self, op, inputs, args_dict, validator_fcns=None, error_name=None, qinfo=None
@@ -2491,12 +2518,16 @@ class TosaTestGen:
def build_fft2d(
self,
op,
- val1,
- val2,
- inverse,
+ inputs,
+ args_dict,
validator_fcns=None,
error_name=None,
+ qinfo=None,
):
+ assert len(inputs) == 2
+ val1, val2 = inputs
+ inverse = args_dict["inverse"]
+
results = OutputShaper.fft2dOp(self.ser, self.rng, val1, val2, error_name)
input_names = [val1.name, val2.name]
@@ -2537,7 +2568,16 @@ class TosaTestGen:
attr.FFTAttribute(inverse, local_bound)
self.ser.addOperator(op["op"], input_names, output_names, attr)
- return results
+
+ compliance = []
+ for res in results:
+ compliance.append(
+ self.tensorComplianceMetaData(
+ op, val1.dtype, args_dict, res, error_name
+ )
+ )
+
+ return TosaTestGen.BuildInfo(results, compliance)
def build_rfft2d(
self,
@@ -2933,13 +2973,11 @@ class TosaTestGen:
if result:
# The test is valid, serialize it
- if isinstance(result, TosaTestGen.BuildInfo) and result.complianceDict:
- # Add the compliance meta data
- # NOTE: This currently expects only one result output
- tensMeta["compliance"] = {
- "version": "0.1",
- "tensors": {result.resultTensor.name: result.complianceDict},
- }
+ if isinstance(result, TosaTestGen.BuildInfo):
+ # Add the compliance meta data (if any)
+ compliance = result.getComplianceInfo()
+ if compliance:
+ tensMeta["compliance"] = compliance
self.serialize("test", tensMeta)
else:
# The test is not valid
@@ -4708,7 +4746,7 @@ class TosaTestGen:
"build_fcn": (
build_fft2d,
TosaTensorGen.tgFFT2d,
- TosaTensorValuesGen.tvgDefault,
+ TosaTensorValuesGen.tvgLazyGenDefault,
TosaArgGen.agFFT2d,
),
"types": [DType.FP32],
@@ -4723,6 +4761,9 @@ class TosaTestGen:
TosaErrorValidator.evFFTInputShapeMismatch,
TosaErrorValidator.evFFTOutputShapeMismatch,
),
+ "data_gen": {
+ "fp": (gtu.DataGenType.DOT_PRODUCT,),
+ },
},
"rfft2d": {
"op": Op.RFFT2D,