diff options
Diffstat (limited to 'verif/frameworks')
-rw-r--r-- | verif/frameworks/arg_gen.py | 30 | ||||
-rw-r--r-- | verif/frameworks/tensor_gen.py | 9 | ||||
-rw-r--r-- | verif/frameworks/test_builder.py | 8 | ||||
-rwxr-xr-x | verif/frameworks/tosa_verif_framework_compiler_runner.py | 16 | ||||
-rwxr-xr-x | verif/frameworks/tosa_verif_framework_generator.py | 10 |
5 files changed, 70 insertions, 3 deletions
diff --git a/verif/frameworks/arg_gen.py b/verif/frameworks/arg_gen.py index d81c3dd..61a1de0 100644 --- a/verif/frameworks/arg_gen.py +++ b/verif/frameworks/arg_gen.py @@ -1,5 +1,7 @@ -# Copyright (c) 2020-2022, ARM Limited. +# Copyright (c) 2020-2023, ARM Limited. # SPDX-License-Identifier: Apache-2.0 +import math + import numpy as np @@ -851,3 +853,29 @@ class ArgGen: else: axes.append(["_axis_m{}".format(-i), [i]]) return axes + + def agRFFT2d(op, shape, rng): + args = [] + + # Must be rank 3 input tensor + if len(shape) != 3: + return [] + + # Check rfft2d with enforced fft_length + for fft_length_h in [2, 32]: + for fft_length_w in [2, 8, 16]: + fft_length = [fft_length_h, fft_length_w] + args.append(["_fft_length_{}x{}".format(*fft_length), [fft_length]]) + + # Check rfft2d with no fft_length provided (fft_length=None). + # In this case, the height and width of the input should be + # used for the calculation. Therefore, we need to check that + # the input shape is already a power of two. + def is_power_of_two(x): + return math.log(x, 2).is_integer() + + height, width = shape[1:3] + if is_power_of_two(height) and is_power_of_two(width): + args.append(["_fft_length_None", [None]]) + + return args diff --git a/verif/frameworks/tensor_gen.py b/verif/frameworks/tensor_gen.py index 767989e..c534a58 100644 --- a/verif/frameworks/tensor_gen.py +++ b/verif/frameworks/tensor_gen.py @@ -274,3 +274,12 @@ class TGen: ) return tf_placeholders, tf_consts + + @staticmethod + def tgRFFT2d(op, shape, dtype, rng): + # Require rank 3 shape + if len(shape) != 3: + return [], [] + + tf_placeholders = [("placeholder_0", TGen.getRand(shape, dtype, rng))] + return tf_placeholders, [] diff --git a/verif/frameworks/test_builder.py b/verif/frameworks/test_builder.py index 8870f41..6e7b6a5 100644 --- a/verif/frameworks/test_builder.py +++ b/verif/frameworks/test_builder.py @@ -1243,3 +1243,11 @@ class TBuilder: def eval(self, a): return self.dense(a) + + class RFFT2d: + def __init__(self, fft_length, name): + self.fft_length = fft_length + self.result_name = name + + def eval(self, a): + return tf.signal.rfft2d(a, self.fft_length, name=self.result_name) diff --git a/verif/frameworks/tosa_verif_framework_compiler_runner.py b/verif/frameworks/tosa_verif_framework_compiler_runner.py index 3597f2a..c55864a 100755 --- a/verif/frameworks/tosa_verif_framework_compiler_runner.py +++ b/verif/frameworks/tosa_verif_framework_compiler_runner.py @@ -1,5 +1,5 @@ #!/usr/bin/env python3 -# Copyright (c) 2020-2022, ARM Limited. +# Copyright (c) 2020-2023, ARM Limited. # SPDX-License-Identifier: Apache-2.0 import argparse import glob @@ -483,6 +483,20 @@ def run_test(args, test, framework): except KeyError: assert 0, "fail to load tflite result numpy" + # TOSA has no notion of complex datatypes, it represents complex values using two + # fp32 output tensors representing real and imaginary values. When legalizing + # complex operations from frameworks, these two output tensors are combined into + # a single tensor of shape [?, ..., ?, 2] whereby each inner pair of values + # represents the real and imaginary parts of a complex value. This is completed + # by inserting reshape and concatenate TOSA operations during the legalization to + # maintain a one-to-one correspondance with framework outputs, thus simplifying + # legalization. Here tf_result should also match this format before being + # compared to the ref model output. + if tf_result.dtype == np.complex64: + ifm_shape = tf_result.shape + (2,) + tf_result = tf_result.view(np.float32) + tf_result = tf_result.reshape(ifm_shape) + # Generate test descriptor per flatbuffer generation # Input .npy will be shared across different frameworks # Output .npy will be generated in its corresponding flatbuffer diff --git a/verif/frameworks/tosa_verif_framework_generator.py b/verif/frameworks/tosa_verif_framework_generator.py index 5b8856d..36ddda5 100755 --- a/verif/frameworks/tosa_verif_framework_generator.py +++ b/verif/frameworks/tosa_verif_framework_generator.py @@ -1,5 +1,5 @@ #!/usr/bin/env python3 -# Copyright (c) 2020-2022, ARM Limited. +# Copyright (c) 2020-2023, ARM Limited. # SPDX-License-Identifier: Apache-2.0 import argparse import os @@ -839,6 +839,13 @@ TF_OP_LIST = { ] }, }, + "rfft2d": { + "operands": (1, 0), + "build_fcn": (TBuilder.RFFT2d, TGen.tgRFFT2d, ArgGen.agRFFT2d), + "types": { + "tflite": TYPE_F, + }, + }, } # Shapes to be tested; default can be overwritten @@ -847,6 +854,7 @@ shape_list = [ (64,), (14, 19), (13, 21, 3), + (1, 8, 16), (1, 4, 4, 4), (1, 8, 4, 17), (1, 4, 8, 19), |