aboutsummaryrefslogtreecommitdiff
path: root/verif/generator/tosa_arg_gen.py
diff options
context:
space:
mode:
Diffstat (limited to 'verif/generator/tosa_arg_gen.py')
-rw-r--r--verif/generator/tosa_arg_gen.py22
1 files changed, 20 insertions, 2 deletions
diff --git a/verif/generator/tosa_arg_gen.py b/verif/generator/tosa_arg_gen.py
index b4939da..f6a46b4 100644
--- a/verif/generator/tosa_arg_gen.py
+++ b/verif/generator/tosa_arg_gen.py
@@ -2798,9 +2798,27 @@ class TosaArgGen:
def agFFT2d(testGen, opName, shapeList, dtype, error_name=None):
arg_list = []
- arg_list.append(("inverseTrue", [True]))
- arg_list.append(("inverseFalse", [False]))
+ shape = shapeList[0]
+ dot_products = gtu.product(shape)
+ ks = 2 * shape[1] * shape[2] # 2*H*W
+ for inverse in (True, False):
+ args_dict = {
+ "dot_products": dot_products,
+ "shape": shape,
+ "ks": ks,
+ "acc_type": dtype,
+ "inverse": inverse,
+ }
+ arg_list.append((f"inverse{inverse}", args_dict))
+ arg_list = TosaArgGen._add_data_generators(
+ testGen,
+ opName,
+ dtype,
+ arg_list,
+ error_name,
+ )
+ # Return list of tuples: (arg_str, args_dict)
return arg_list
# Helper function for reshape. Gets some factors of a larger number.