aboutsummaryrefslogtreecommitdiff
path: root/verif/frameworks/arg_gen.py
diff options
context:
space:
mode:
authorLuke Hutton <luke.hutton@arm.com>2023-01-10 14:50:31 +0000
committerLuke Hutton <luke.hutton@arm.com>2023-01-24 13:40:17 +0000
commit261b7b62b959a6c7312d810d9152069fdff69f3e (patch)
tree2be25cefa14cd21379a9fc6f6c499622b6de8bf8 /verif/frameworks/arg_gen.py
parentc253e64710f22016894c0e3ac4e9eb76d62cb2f9 (diff)
downloadreference_model-261b7b62b959a6c7312d810d9152069fdff69f3e.tar.gz
Add RFFT2d to the reference model
Includes: * RFFT2d reference implementation * TFLite framework tests * Basic TOSA tests * Serialization submodule upgrade with support for FFT/RFFT Signed-off-by: Luke Hutton <luke.hutton@arm.com> Change-Id: I2a687e9cf87fb62a26160ea52439ba9830bea36e
Diffstat (limited to 'verif/frameworks/arg_gen.py')
-rw-r--r--verif/frameworks/arg_gen.py30
1 files changed, 29 insertions, 1 deletions
diff --git a/verif/frameworks/arg_gen.py b/verif/frameworks/arg_gen.py
index d81c3dd..61a1de0 100644
--- a/verif/frameworks/arg_gen.py
+++ b/verif/frameworks/arg_gen.py
@@ -1,5 +1,7 @@
-# Copyright (c) 2020-2022, ARM Limited.
+# Copyright (c) 2020-2023, ARM Limited.
# SPDX-License-Identifier: Apache-2.0
+import math
+
import numpy as np
@@ -851,3 +853,29 @@ class ArgGen:
else:
axes.append(["_axis_m{}".format(-i), [i]])
return axes
+
+ def agRFFT2d(op, shape, rng):
+ args = []
+
+ # Must be rank 3 input tensor
+ if len(shape) != 3:
+ return []
+
+ # Check rfft2d with enforced fft_length
+ for fft_length_h in [2, 32]:
+ for fft_length_w in [2, 8, 16]:
+ fft_length = [fft_length_h, fft_length_w]
+ args.append(["_fft_length_{}x{}".format(*fft_length), [fft_length]])
+
+ # Check rfft2d with no fft_length provided (fft_length=None).
+ # In this case, the height and width of the input should be
+ # used for the calculation. Therefore, we need to check that
+ # the input shape is already a power of two.
+ def is_power_of_two(x):
+ return math.log(x, 2).is_integer()
+
+ height, width = shape[1:3]
+ if is_power_of_two(height) and is_power_of_two(width):
+ args.append(["_fft_length_None", [None]])
+
+ return args