aboutsummaryrefslogtreecommitdiff
path: root/verif/generator
diff options
context:
space:
mode:
authorJeremy Johnson <jeremy.johnson@arm.com>2024-01-18 16:57:28 +0000
committerJeremy Johnson <jeremy.johnson@arm.com>2024-02-07 10:57:40 +0000
commitc8330811352f753e36f2ee7be4c7d0e6002f21e7 (patch)
tree967eeb59876e7c6abea26ff2e892d5ff94134992 /verif/generator
parent9847722e2b172b69fe9ae80a05c27ca5c8c36617 (diff)
downloadreference_model-c8330811352f753e36f2ee7be4c7d0e6002f21e7.tar.gz
Main Compliance: FFT2D support
Improve access to DOT_PRODUCT generator index and location for debugging. Enable multiple result files for compliance and improve output. Fix up precise and abs modes for FFT2D in ref model to produce correct results and bounds using abs weights. Signed-off-by: Jeremy Johnson <jeremy.johnson@arm.com> Change-Id: Ide0c9f9f80397e5f1e07ca30a1036d6014b5784d
Diffstat (limited to 'verif/generator')
-rw-r--r--verif/generator/tosa_arg_gen.py22
-rw-r--r--verif/generator/tosa_test_gen.py69
2 files changed, 75 insertions, 16 deletions
diff --git a/verif/generator/tosa_arg_gen.py b/verif/generator/tosa_arg_gen.py
index b4939da..f6a46b4 100644
--- a/verif/generator/tosa_arg_gen.py
+++ b/verif/generator/tosa_arg_gen.py
@@ -2798,9 +2798,27 @@ class TosaArgGen:
def agFFT2d(testGen, opName, shapeList, dtype, error_name=None):
arg_list = []
- arg_list.append(("inverseTrue", [True]))
- arg_list.append(("inverseFalse", [False]))
+ shape = shapeList[0]
+ dot_products = gtu.product(shape)
+ ks = 2 * shape[1] * shape[2] # 2*H*W
+ for inverse in (True, False):
+ args_dict = {
+ "dot_products": dot_products,
+ "shape": shape,
+ "ks": ks,
+ "acc_type": dtype,
+ "inverse": inverse,
+ }
+ arg_list.append((f"inverse{inverse}", args_dict))
+ arg_list = TosaArgGen._add_data_generators(
+ testGen,
+ opName,
+ dtype,
+ arg_list,
+ error_name,
+ )
+ # Return list of tuples: (arg_str, args_dict)
return arg_list
# Helper function for reshape. Gets some factors of a larger number.
diff --git a/verif/generator/tosa_test_gen.py b/verif/generator/tosa_test_gen.py
index bfafd23..68a4e94 100644
--- a/verif/generator/tosa_test_gen.py
+++ b/verif/generator/tosa_test_gen.py
@@ -381,8 +381,35 @@ class TosaTestGen:
"""Enhanced build information containing result tensor and associated compliance dict."""
def __init__(self, resultTensor, complianceDict):
- self.resultTensor = resultTensor
- self.complianceDict = complianceDict
+ if isinstance(resultTensor, list):
+ assert complianceDict is None or isinstance(complianceDict, list)
+ self.resultTensorList = resultTensor
+ self.complianceDictList = complianceDict
+ else:
+ self.resultTensorList = [resultTensor]
+ if complianceDict is None:
+ self.complianceDictList = None
+ else:
+ self.complianceDictList = [complianceDict]
+
+ def getComplianceInfo(self):
+ if self.complianceDictList is None:
+ return None
+ else:
+ tens_dict = {}
+ for tens, comp in zip(self.resultTensorList, self.complianceDictList):
+ if comp is not None:
+ tens_dict[tens.name] = comp
+
+ if tens_dict:
+ # Have some compliance data, so return the info
+ compliance = {
+ "version": "0.1",
+ "tensors": tens_dict,
+ }
+ else:
+ compliance = None
+ return compliance
def build_unary(
self, op, inputs, args_dict, validator_fcns=None, error_name=None, qinfo=None
@@ -2491,12 +2518,16 @@ class TosaTestGen:
def build_fft2d(
self,
op,
- val1,
- val2,
- inverse,
+ inputs,
+ args_dict,
validator_fcns=None,
error_name=None,
+ qinfo=None,
):
+ assert len(inputs) == 2
+ val1, val2 = inputs
+ inverse = args_dict["inverse"]
+
results = OutputShaper.fft2dOp(self.ser, self.rng, val1, val2, error_name)
input_names = [val1.name, val2.name]
@@ -2537,7 +2568,16 @@ class TosaTestGen:
attr.FFTAttribute(inverse, local_bound)
self.ser.addOperator(op["op"], input_names, output_names, attr)
- return results
+
+ compliance = []
+ for res in results:
+ compliance.append(
+ self.tensorComplianceMetaData(
+ op, val1.dtype, args_dict, res, error_name
+ )
+ )
+
+ return TosaTestGen.BuildInfo(results, compliance)
def build_rfft2d(
self,
@@ -2933,13 +2973,11 @@ class TosaTestGen:
if result:
# The test is valid, serialize it
- if isinstance(result, TosaTestGen.BuildInfo) and result.complianceDict:
- # Add the compliance meta data
- # NOTE: This currently expects only one result output
- tensMeta["compliance"] = {
- "version": "0.1",
- "tensors": {result.resultTensor.name: result.complianceDict},
- }
+ if isinstance(result, TosaTestGen.BuildInfo):
+ # Add the compliance meta data (if any)
+ compliance = result.getComplianceInfo()
+ if compliance:
+ tensMeta["compliance"] = compliance
self.serialize("test", tensMeta)
else:
# The test is not valid
@@ -4708,7 +4746,7 @@ class TosaTestGen:
"build_fcn": (
build_fft2d,
TosaTensorGen.tgFFT2d,
- TosaTensorValuesGen.tvgDefault,
+ TosaTensorValuesGen.tvgLazyGenDefault,
TosaArgGen.agFFT2d,
),
"types": [DType.FP32],
@@ -4723,6 +4761,9 @@ class TosaTestGen:
TosaErrorValidator.evFFTInputShapeMismatch,
TosaErrorValidator.evFFTOutputShapeMismatch,
),
+ "data_gen": {
+ "fp": (gtu.DataGenType.DOT_PRODUCT,),
+ },
},
"rfft2d": {
"op": Op.RFFT2D,