aboutsummaryrefslogtreecommitdiff
path: root/verif/frameworks
diff options
context:
space:
mode:
authorLuke Hutton <luke.hutton@arm.com>2023-01-10 14:50:31 +0000
committerLuke Hutton <luke.hutton@arm.com>2023-01-24 13:40:17 +0000
commit261b7b62b959a6c7312d810d9152069fdff69f3e (patch)
tree2be25cefa14cd21379a9fc6f6c499622b6de8bf8 /verif/frameworks
parentc253e64710f22016894c0e3ac4e9eb76d62cb2f9 (diff)
downloadreference_model-261b7b62b959a6c7312d810d9152069fdff69f3e.tar.gz
Add RFFT2d to the reference model
Includes: * RFFT2d reference implementation * TFLite framework tests * Basic TOSA tests * Serialization submodule upgrade with support for FFT/RFFT Signed-off-by: Luke Hutton <luke.hutton@arm.com> Change-Id: I2a687e9cf87fb62a26160ea52439ba9830bea36e
Diffstat (limited to 'verif/frameworks')
-rw-r--r--verif/frameworks/arg_gen.py30
-rw-r--r--verif/frameworks/tensor_gen.py9
-rw-r--r--verif/frameworks/test_builder.py8
-rwxr-xr-xverif/frameworks/tosa_verif_framework_compiler_runner.py16
-rwxr-xr-xverif/frameworks/tosa_verif_framework_generator.py10
5 files changed, 70 insertions, 3 deletions
diff --git a/verif/frameworks/arg_gen.py b/verif/frameworks/arg_gen.py
index d81c3dd..61a1de0 100644
--- a/verif/frameworks/arg_gen.py
+++ b/verif/frameworks/arg_gen.py
@@ -1,5 +1,7 @@
-# Copyright (c) 2020-2022, ARM Limited.
+# Copyright (c) 2020-2023, ARM Limited.
# SPDX-License-Identifier: Apache-2.0
+import math
+
import numpy as np
@@ -851,3 +853,29 @@ class ArgGen:
else:
axes.append(["_axis_m{}".format(-i), [i]])
return axes
+
+ def agRFFT2d(op, shape, rng):
+ args = []
+
+ # Must be rank 3 input tensor
+ if len(shape) != 3:
+ return []
+
+ # Check rfft2d with enforced fft_length
+ for fft_length_h in [2, 32]:
+ for fft_length_w in [2, 8, 16]:
+ fft_length = [fft_length_h, fft_length_w]
+ args.append(["_fft_length_{}x{}".format(*fft_length), [fft_length]])
+
+ # Check rfft2d with no fft_length provided (fft_length=None).
+ # In this case, the height and width of the input should be
+ # used for the calculation. Therefore, we need to check that
+ # the input shape is already a power of two.
+ def is_power_of_two(x):
+ return math.log(x, 2).is_integer()
+
+ height, width = shape[1:3]
+ if is_power_of_two(height) and is_power_of_two(width):
+ args.append(["_fft_length_None", [None]])
+
+ return args
diff --git a/verif/frameworks/tensor_gen.py b/verif/frameworks/tensor_gen.py
index 767989e..c534a58 100644
--- a/verif/frameworks/tensor_gen.py
+++ b/verif/frameworks/tensor_gen.py
@@ -274,3 +274,12 @@ class TGen:
)
return tf_placeholders, tf_consts
+
+ @staticmethod
+ def tgRFFT2d(op, shape, dtype, rng):
+ # Require rank 3 shape
+ if len(shape) != 3:
+ return [], []
+
+ tf_placeholders = [("placeholder_0", TGen.getRand(shape, dtype, rng))]
+ return tf_placeholders, []
diff --git a/verif/frameworks/test_builder.py b/verif/frameworks/test_builder.py
index 8870f41..6e7b6a5 100644
--- a/verif/frameworks/test_builder.py
+++ b/verif/frameworks/test_builder.py
@@ -1243,3 +1243,11 @@ class TBuilder:
def eval(self, a):
return self.dense(a)
+
+ class RFFT2d:
+ def __init__(self, fft_length, name):
+ self.fft_length = fft_length
+ self.result_name = name
+
+ def eval(self, a):
+ return tf.signal.rfft2d(a, self.fft_length, name=self.result_name)
diff --git a/verif/frameworks/tosa_verif_framework_compiler_runner.py b/verif/frameworks/tosa_verif_framework_compiler_runner.py
index 3597f2a..c55864a 100755
--- a/verif/frameworks/tosa_verif_framework_compiler_runner.py
+++ b/verif/frameworks/tosa_verif_framework_compiler_runner.py
@@ -1,5 +1,5 @@
#!/usr/bin/env python3
-# Copyright (c) 2020-2022, ARM Limited.
+# Copyright (c) 2020-2023, ARM Limited.
# SPDX-License-Identifier: Apache-2.0
import argparse
import glob
@@ -483,6 +483,20 @@ def run_test(args, test, framework):
except KeyError:
assert 0, "fail to load tflite result numpy"
+ # TOSA has no notion of complex datatypes, it represents complex values using two
+ # fp32 output tensors representing real and imaginary values. When legalizing
+ # complex operations from frameworks, these two output tensors are combined into
+ # a single tensor of shape [?, ..., ?, 2] whereby each inner pair of values
+ # represents the real and imaginary parts of a complex value. This is completed
+ # by inserting reshape and concatenate TOSA operations during the legalization to
+ # maintain a one-to-one correspondance with framework outputs, thus simplifying
+ # legalization. Here tf_result should also match this format before being
+ # compared to the ref model output.
+ if tf_result.dtype == np.complex64:
+ ifm_shape = tf_result.shape + (2,)
+ tf_result = tf_result.view(np.float32)
+ tf_result = tf_result.reshape(ifm_shape)
+
# Generate test descriptor per flatbuffer generation
# Input .npy will be shared across different frameworks
# Output .npy will be generated in its corresponding flatbuffer
diff --git a/verif/frameworks/tosa_verif_framework_generator.py b/verif/frameworks/tosa_verif_framework_generator.py
index 5b8856d..36ddda5 100755
--- a/verif/frameworks/tosa_verif_framework_generator.py
+++ b/verif/frameworks/tosa_verif_framework_generator.py
@@ -1,5 +1,5 @@
#!/usr/bin/env python3
-# Copyright (c) 2020-2022, ARM Limited.
+# Copyright (c) 2020-2023, ARM Limited.
# SPDX-License-Identifier: Apache-2.0
import argparse
import os
@@ -839,6 +839,13 @@ TF_OP_LIST = {
]
},
},
+ "rfft2d": {
+ "operands": (1, 0),
+ "build_fcn": (TBuilder.RFFT2d, TGen.tgRFFT2d, ArgGen.agRFFT2d),
+ "types": {
+ "tflite": TYPE_F,
+ },
+ },
}
# Shapes to be tested; default can be overwritten
@@ -847,6 +854,7 @@ shape_list = [
(64,),
(14, 19),
(13, 21, 3),
+ (1, 8, 16),
(1, 4, 4, 4),
(1, 8, 4, 17),
(1, 4, 8, 19),