aboutsummaryrefslogtreecommitdiff
path: root/verif/generator/tosa_test_gen.py
diff options
context:
space:
mode:
authorJeremy Johnson <jeremy.johnson@arm.com>2024-01-30 16:10:50 +0000
committerJeremy Johnson <jeremy.johnson@arm.com>2024-02-08 11:12:55 +0000
commit6f57e6e665094959aed40c0e388ac81fbd118720 (patch)
tree82fdfa4b40baf370aa346e3d19fa3f1760294ee9 /verif/generator/tosa_test_gen.py
parent47ab1762d1c15a7b4c0c068d7294111c5c5f92a2 (diff)
downloadreference_model-6f57e6e665094959aed40c0e388ac81fbd118720.tar.gz
Main Compliance: RFFT2D support
Correct ref model to produce imaginery values of zero as specification indicates at certain output positions. Fix up precise and abs modes for RFFT2D in ref model to produce correct results and bounds using abs weights. Signed-off-by: Jeremy Johnson <jeremy.johnson@arm.com> Change-Id: I33767e4219a260278f7933f28b1799223a95a3cc
Diffstat (limited to 'verif/generator/tosa_test_gen.py')
-rw-r--r--verif/generator/tosa_test_gen.py22
1 files changed, 18 insertions, 4 deletions
diff --git a/verif/generator/tosa_test_gen.py b/verif/generator/tosa_test_gen.py
index 9c3cd32..d82f919 100644
--- a/verif/generator/tosa_test_gen.py
+++ b/verif/generator/tosa_test_gen.py
@@ -2588,10 +2588,14 @@ class TosaTestGen:
def build_rfft2d(
self,
op,
- val,
+ inputs,
+ args_dict,
validator_fcns=None,
error_name=None,
+ qinfo=None,
):
+ assert len(inputs) == 1
+ val = inputs[0]
results = OutputShaper.rfft2dOp(self.ser, self.rng, val, error_name)
input_names = [val.name]
@@ -2629,7 +2633,14 @@ class TosaTestGen:
attr.RFFTAttribute(local_bound)
self.ser.addOperator(op["op"], input_names, output_names, attr)
- return results
+
+ compliance = []
+ for res in results:
+ compliance.append(
+ self.tensorComplianceMetaData(op, val.dtype, args_dict, res, error_name)
+ )
+
+ return TosaTestGen.BuildInfo(results, compliance)
def build_shape_op(
self, op, inputs, args_dict, validator_fcns=None, error_name=None, qinfo=None
@@ -4781,8 +4792,8 @@ class TosaTestGen:
"build_fcn": (
build_rfft2d,
TosaTensorGen.tgRFFT2d,
- TosaTensorValuesGen.tvgDefault,
- None,
+ TosaTensorValuesGen.tvgLazyGenDefault,
+ TosaArgGen.agRFFT2d,
),
"types": [DType.FP32],
"error_if_validators": (
@@ -4795,6 +4806,9 @@ class TosaTestGen:
TosaErrorValidator.evKernelNotPowerOfTwo,
TosaErrorValidator.evFFTOutputShapeMismatch,
),
+ "data_gen": {
+ "fp": (gtu.DataGenType.DOT_PRODUCT,),
+ },
},
# Shape
"add_shape": {