aboutsummaryrefslogtreecommitdiff
path: root/verif/generator/tosa_test_gen.py
diff options
context:
space:
mode:
Diffstat (limited to 'verif/generator/tosa_test_gen.py')
-rw-r--r--verif/generator/tosa_test_gen.py22
1 files changed, 18 insertions, 4 deletions
diff --git a/verif/generator/tosa_test_gen.py b/verif/generator/tosa_test_gen.py
index 9c3cd32..d82f919 100644
--- a/verif/generator/tosa_test_gen.py
+++ b/verif/generator/tosa_test_gen.py
@@ -2588,10 +2588,14 @@ class TosaTestGen:
def build_rfft2d(
self,
op,
- val,
+ inputs,
+ args_dict,
validator_fcns=None,
error_name=None,
+ qinfo=None,
):
+ assert len(inputs) == 1
+ val = inputs[0]
results = OutputShaper.rfft2dOp(self.ser, self.rng, val, error_name)
input_names = [val.name]
@@ -2629,7 +2633,14 @@ class TosaTestGen:
attr.RFFTAttribute(local_bound)
self.ser.addOperator(op["op"], input_names, output_names, attr)
- return results
+
+ compliance = []
+ for res in results:
+ compliance.append(
+ self.tensorComplianceMetaData(op, val.dtype, args_dict, res, error_name)
+ )
+
+ return TosaTestGen.BuildInfo(results, compliance)
def build_shape_op(
self, op, inputs, args_dict, validator_fcns=None, error_name=None, qinfo=None
@@ -4781,8 +4792,8 @@ class TosaTestGen:
"build_fcn": (
build_rfft2d,
TosaTensorGen.tgRFFT2d,
- TosaTensorValuesGen.tvgDefault,
- None,
+ TosaTensorValuesGen.tvgLazyGenDefault,
+ TosaArgGen.agRFFT2d,
),
"types": [DType.FP32],
"error_if_validators": (
@@ -4795,6 +4806,9 @@ class TosaTestGen:
TosaErrorValidator.evKernelNotPowerOfTwo,
TosaErrorValidator.evFFTOutputShapeMismatch,
),
+ "data_gen": {
+ "fp": (gtu.DataGenType.DOT_PRODUCT,),
+ },
},
# Shape
"add_shape": {