aboutsummaryrefslogtreecommitdiff
path: root/verif/generator/tosa_arg_gen.py
diff options
context:
space:
mode:
Diffstat (limited to 'verif/generator/tosa_arg_gen.py')
-rw-r--r--verif/generator/tosa_arg_gen.py48
1 files changed, 48 insertions, 0 deletions
diff --git a/verif/generator/tosa_arg_gen.py b/verif/generator/tosa_arg_gen.py
index 05a7d2b..370570c 100644
--- a/verif/generator/tosa_arg_gen.py
+++ b/verif/generator/tosa_arg_gen.py
@@ -417,6 +417,45 @@ class TosaTensorGen:
return [ifm_shape, filter_shape, bias_shape]
@staticmethod
+ def tgFFT2d(testGen, op, rank, error_name=None):
+ pl, const = op["operands"]
+
+ if error_name != ErrorIf.WrongRank:
+ assert rank == 3
+ assert pl == 2 and const == 0
+
+ # IFM dimensions are NHW
+ ifm_shape = testGen.makeShape(rank)
+
+ # Select nearest lower power of two from input height and width
+ ifm_shape[1] = 2 ** int(math.log(ifm_shape[1], 2))
+ ifm_shape[2] = 2 ** int(math.log(ifm_shape[2], 2))
+
+ # Constrict the overall size of the shape when creating ERROR_IF tests
+ if error_name:
+ ifm_shape = TosaErrorIfArgGen.eiRestrictDimensions(ifm_shape)
+
+ # Generate an invalid kernel that is not a power of two
+ if error_name == ErrorIf.KernelNotPowerOfTwo:
+ inc_h = 2 if ifm_shape[1] == 1 else 1
+ inc_w = 2 if ifm_shape[2] == 1 else 1
+ inc_choices = [(inc_h, 0), (0, inc_w), (inc_h, inc_w)]
+ selected_inc = testGen.rng.choice(inc_choices)
+ ifm_shape[1] += selected_inc[0]
+ ifm_shape[2] += selected_inc[1]
+
+ ifm_shape = testGen.constrictBatchSize(ifm_shape)
+
+ ifm_shapes = [ifm_shape.copy(), ifm_shape.copy()]
+ if error_name == ErrorIf.FFTInputShapeMismatch:
+ modify_shape = testGen.rng.choice([0, 1])
+ # Only modify kernel (H, W)
+ modify_dim = testGen.rng.choice([1, 2])
+ ifm_shapes[modify_shape][modify_dim] *= 2
+
+ return [ifm_shapes[0], ifm_shapes[1]]
+
+ @staticmethod
def tgRFFT2d(testGen, op, rank, error_name=None):
pl, const = op["operands"]
@@ -1613,6 +1652,15 @@ class TosaArgGen:
return arg_list
+ @staticmethod
+ def agFFT2d(testGen, opName, shapeList, dtype, error_name=None):
+ arg_list = []
+
+ arg_list.append(("inverseTrue", [True]))
+ arg_list.append(("inverseFalse", [False]))
+
+ return arg_list
+
# Helper function for reshape. Gets some factors of a larger number.
@staticmethod
def getFactors(val, start=1):