diff options
Diffstat (limited to 'verif')
-rw-r--r-- | verif/conformance/tosa_main_profile_ops_info.json | 5 | ||||
-rw-r--r-- | verif/generator/tosa_arg_gen.py | 25 | ||||
-rw-r--r-- | verif/generator/tosa_test_gen.py | 22 |
3 files changed, 46 insertions, 6 deletions
diff --git a/verif/conformance/tosa_main_profile_ops_info.json b/verif/conformance/tosa_main_profile_ops_info.json index a53d0c7..b8efd35 100644 --- a/verif/conformance/tosa_main_profile_ops_info.json +++ b/verif/conformance/tosa_main_profile_ops_info.json @@ -2702,6 +2702,7 @@ "profile": [ "tosa-mi" ], + "support_for": [ "lazy_data_gen" ], "generation": { "standard": { "generator_args": [ @@ -2709,13 +2710,13 @@ "--target-dtype", "fp32", "--fp-values-range", - "-2.0,2.0" + "-max,max" ], [ "--target-dtype", "fp32", "--fp-values-range", - "-2.0,2.0", + "-max,max", "--target-shape", "1,16,512", "--target-shape", diff --git a/verif/generator/tosa_arg_gen.py b/verif/generator/tosa_arg_gen.py index f6a46b4..a4bced3 100644 --- a/verif/generator/tosa_arg_gen.py +++ b/verif/generator/tosa_arg_gen.py @@ -2821,6 +2821,31 @@ class TosaArgGen: # Return list of tuples: (arg_str, args_dict) return arg_list + @staticmethod + def agRFFT2d(testGen, opName, shapeList, dtype, error_name=None): + arg_list = [] + + shape = shapeList[0] + dot_products = gtu.product(shape) + ks = shape[1] * shape[2] # H*W + args_dict = { + "dot_products": dot_products, + "shape": shape, + "ks": ks, + "acc_type": dtype, + } + arg_list.append(("", args_dict)) + + arg_list = TosaArgGen._add_data_generators( + testGen, + opName, + dtype, + arg_list, + error_name, + ) + # Return list of tuples: (arg_str, args_dict) + return arg_list + # Helper function for reshape. Gets some factors of a larger number. @staticmethod def getFactors(val, start=1): diff --git a/verif/generator/tosa_test_gen.py b/verif/generator/tosa_test_gen.py index 9c3cd32..d82f919 100644 --- a/verif/generator/tosa_test_gen.py +++ b/verif/generator/tosa_test_gen.py @@ -2588,10 +2588,14 @@ class TosaTestGen: def build_rfft2d( self, op, - val, + inputs, + args_dict, validator_fcns=None, error_name=None, + qinfo=None, ): + assert len(inputs) == 1 + val = inputs[0] results = OutputShaper.rfft2dOp(self.ser, self.rng, val, error_name) input_names = [val.name] @@ -2629,7 +2633,14 @@ class TosaTestGen: attr.RFFTAttribute(local_bound) self.ser.addOperator(op["op"], input_names, output_names, attr) - return results + + compliance = [] + for res in results: + compliance.append( + self.tensorComplianceMetaData(op, val.dtype, args_dict, res, error_name) + ) + + return TosaTestGen.BuildInfo(results, compliance) def build_shape_op( self, op, inputs, args_dict, validator_fcns=None, error_name=None, qinfo=None @@ -4781,8 +4792,8 @@ class TosaTestGen: "build_fcn": ( build_rfft2d, TosaTensorGen.tgRFFT2d, - TosaTensorValuesGen.tvgDefault, - None, + TosaTensorValuesGen.tvgLazyGenDefault, + TosaArgGen.agRFFT2d, ), "types": [DType.FP32], "error_if_validators": ( @@ -4795,6 +4806,9 @@ class TosaTestGen: TosaErrorValidator.evKernelNotPowerOfTwo, TosaErrorValidator.evFFTOutputShapeMismatch, ), + "data_gen": { + "fp": (gtu.DataGenType.DOT_PRODUCT,), + }, }, # Shape "add_shape": { |