aboutsummaryrefslogtreecommitdiff
path: root/verif/generator/tosa_arg_gen.py
diff options
context:
space:
mode:
Diffstat (limited to 'verif/generator/tosa_arg_gen.py')
-rw-r--r--verif/generator/tosa_arg_gen.py37
1 files changed, 36 insertions, 1 deletions
diff --git a/verif/generator/tosa_arg_gen.py b/verif/generator/tosa_arg_gen.py
index 4e15b06..fed91f6 100644
--- a/verif/generator/tosa_arg_gen.py
+++ b/verif/generator/tosa_arg_gen.py
@@ -1,4 +1,4 @@
-# Copyright (c) 2021-2022, ARM Limited.
+# Copyright (c) 2021-2023, ARM Limited.
# SPDX-License-Identifier: Apache-2.0
import itertools
import math
@@ -417,6 +417,41 @@ class TosaTensorGen:
return [ifm_shape, filter_shape, bias_shape]
@staticmethod
+ def tgRFFT2d(testGen, op, rank, error_name=None):
+ pl, const = op["operands"]
+
+ if error_name != ErrorIf.WrongRank:
+ assert rank == 3
+ assert pl == 1 and const == 0
+
+ # IFM dimensions are NHW
+ ifm_shape = testGen.makeShape(rank)
+
+ # Select nearest lower power of two from input height and width
+ ifm_shape[1] = 2 ** int(math.log(ifm_shape[1], 2))
+ ifm_shape[2] = 2 ** int(math.log(ifm_shape[2], 2))
+
+ # Constrict the overall size of the shape when creating ERROR_IF tests
+ if error_name:
+ ifm_shape = TosaErrorIfArgGen.eiRestrictDimensions(ifm_shape)
+
+ # Generate an invalid kernel that is not a power of two
+ if error_name == ErrorIf.KernelNotPowerOfTwo:
+ # We must increment by 2 if current size is 1
+ inc_h = 2 if ifm_shape[1] == 1 else 1
+ inc_w = 2 if ifm_shape[2] == 1 else 1
+ inc_choices = [(inc_h, 0), (0, inc_w), (inc_h, inc_w)]
+ selected_inc = testGen.rng.choice(inc_choices)
+ ifm_shape[1] += selected_inc[0]
+ ifm_shape[2] += selected_inc[1]
+
+ # Constrict the batch size
+ if testGen.args.max_batch_size:
+ ifm_shape[0] = (ifm_shape[0] % testGen.args.max_batch_size) + 1
+
+ return [ifm_shape]
+
+ @staticmethod
def tgFullyConnected(testGen, op, rank, error_name=None):
pl, const = op["operands"]