diff options
Diffstat (limited to 'verif/generator/tosa_arg_gen.py')
-rw-r--r-- | verif/generator/tosa_arg_gen.py | 48 |
1 files changed, 48 insertions, 0 deletions
diff --git a/verif/generator/tosa_arg_gen.py b/verif/generator/tosa_arg_gen.py index 05a7d2b..370570c 100644 --- a/verif/generator/tosa_arg_gen.py +++ b/verif/generator/tosa_arg_gen.py @@ -417,6 +417,45 @@ class TosaTensorGen: return [ifm_shape, filter_shape, bias_shape] @staticmethod + def tgFFT2d(testGen, op, rank, error_name=None): + pl, const = op["operands"] + + if error_name != ErrorIf.WrongRank: + assert rank == 3 + assert pl == 2 and const == 0 + + # IFM dimensions are NHW + ifm_shape = testGen.makeShape(rank) + + # Select nearest lower power of two from input height and width + ifm_shape[1] = 2 ** int(math.log(ifm_shape[1], 2)) + ifm_shape[2] = 2 ** int(math.log(ifm_shape[2], 2)) + + # Constrict the overall size of the shape when creating ERROR_IF tests + if error_name: + ifm_shape = TosaErrorIfArgGen.eiRestrictDimensions(ifm_shape) + + # Generate an invalid kernel that is not a power of two + if error_name == ErrorIf.KernelNotPowerOfTwo: + inc_h = 2 if ifm_shape[1] == 1 else 1 + inc_w = 2 if ifm_shape[2] == 1 else 1 + inc_choices = [(inc_h, 0), (0, inc_w), (inc_h, inc_w)] + selected_inc = testGen.rng.choice(inc_choices) + ifm_shape[1] += selected_inc[0] + ifm_shape[2] += selected_inc[1] + + ifm_shape = testGen.constrictBatchSize(ifm_shape) + + ifm_shapes = [ifm_shape.copy(), ifm_shape.copy()] + if error_name == ErrorIf.FFTInputShapeMismatch: + modify_shape = testGen.rng.choice([0, 1]) + # Only modify kernel (H, W) + modify_dim = testGen.rng.choice([1, 2]) + ifm_shapes[modify_shape][modify_dim] *= 2 + + return [ifm_shapes[0], ifm_shapes[1]] + + @staticmethod def tgRFFT2d(testGen, op, rank, error_name=None): pl, const = op["operands"] @@ -1613,6 +1652,15 @@ class TosaArgGen: return arg_list + @staticmethod + def agFFT2d(testGen, opName, shapeList, dtype, error_name=None): + arg_list = [] + + arg_list.append(("inverseTrue", [True])) + arg_list.append(("inverseFalse", [False])) + + return arg_list + # Helper function for reshape. Gets some factors of a larger number. @staticmethod def getFactors(val, start=1): |