diff options
Diffstat (limited to 'verif/frameworks/arg_gen.py')
-rw-r--r-- | verif/frameworks/arg_gen.py | 30 |
1 files changed, 29 insertions, 1 deletions
diff --git a/verif/frameworks/arg_gen.py b/verif/frameworks/arg_gen.py index d81c3dd..61a1de0 100644 --- a/verif/frameworks/arg_gen.py +++ b/verif/frameworks/arg_gen.py @@ -1,5 +1,7 @@ -# Copyright (c) 2020-2022, ARM Limited. +# Copyright (c) 2020-2023, ARM Limited. # SPDX-License-Identifier: Apache-2.0 +import math + import numpy as np @@ -851,3 +853,29 @@ class ArgGen: else: axes.append(["_axis_m{}".format(-i), [i]]) return axes + + def agRFFT2d(op, shape, rng): + args = [] + + # Must be rank 3 input tensor + if len(shape) != 3: + return [] + + # Check rfft2d with enforced fft_length + for fft_length_h in [2, 32]: + for fft_length_w in [2, 8, 16]: + fft_length = [fft_length_h, fft_length_w] + args.append(["_fft_length_{}x{}".format(*fft_length), [fft_length]]) + + # Check rfft2d with no fft_length provided (fft_length=None). + # In this case, the height and width of the input should be + # used for the calculation. Therefore, we need to check that + # the input shape is already a power of two. + def is_power_of_two(x): + return math.log(x, 2).is_integer() + + height, width = shape[1:3] + if is_power_of_two(height) and is_power_of_two(width): + args.append(["_fft_length_None", [None]]) + + return args |