aboutsummaryrefslogtreecommitdiff
path: root/verif
diff options
context:
space:
mode:
authorLuke Hutton <luke.hutton@arm.com>2023-01-10 14:50:31 +0000
committerLuke Hutton <luke.hutton@arm.com>2023-01-24 13:40:17 +0000
commit261b7b62b959a6c7312d810d9152069fdff69f3e (patch)
tree2be25cefa14cd21379a9fc6f6c499622b6de8bf8 /verif
parentc253e64710f22016894c0e3ac4e9eb76d62cb2f9 (diff)
downloadreference_model-261b7b62b959a6c7312d810d9152069fdff69f3e.tar.gz
Add RFFT2d to the reference model
Includes: * RFFT2d reference implementation * TFLite framework tests * Basic TOSA tests * Serialization submodule upgrade with support for FFT/RFFT Signed-off-by: Luke Hutton <luke.hutton@arm.com> Change-Id: I2a687e9cf87fb62a26160ea52439ba9830bea36e
Diffstat (limited to 'verif')
-rw-r--r--verif/frameworks/arg_gen.py30
-rw-r--r--verif/frameworks/tensor_gen.py9
-rw-r--r--verif/frameworks/test_builder.py8
-rwxr-xr-xverif/frameworks/tosa_verif_framework_compiler_runner.py16
-rwxr-xr-xverif/frameworks/tosa_verif_framework_generator.py10
-rw-r--r--verif/generator/tosa_arg_gen.py37
-rw-r--r--verif/generator/tosa_error_if.py103
-rw-r--r--verif/generator/tosa_test_gen.py130
8 files changed, 282 insertions, 61 deletions
diff --git a/verif/frameworks/arg_gen.py b/verif/frameworks/arg_gen.py
index d81c3dd..61a1de0 100644
--- a/verif/frameworks/arg_gen.py
+++ b/verif/frameworks/arg_gen.py
@@ -1,5 +1,7 @@
-# Copyright (c) 2020-2022, ARM Limited.
+# Copyright (c) 2020-2023, ARM Limited.
# SPDX-License-Identifier: Apache-2.0
+import math
+
import numpy as np
@@ -851,3 +853,29 @@ class ArgGen:
else:
axes.append(["_axis_m{}".format(-i), [i]])
return axes
+
+ def agRFFT2d(op, shape, rng):
+ args = []
+
+ # Must be rank 3 input tensor
+ if len(shape) != 3:
+ return []
+
+ # Check rfft2d with enforced fft_length
+ for fft_length_h in [2, 32]:
+ for fft_length_w in [2, 8, 16]:
+ fft_length = [fft_length_h, fft_length_w]
+ args.append(["_fft_length_{}x{}".format(*fft_length), [fft_length]])
+
+ # Check rfft2d with no fft_length provided (fft_length=None).
+ # In this case, the height and width of the input should be
+ # used for the calculation. Therefore, we need to check that
+ # the input shape is already a power of two.
+ def is_power_of_two(x):
+ return math.log(x, 2).is_integer()
+
+ height, width = shape[1:3]
+ if is_power_of_two(height) and is_power_of_two(width):
+ args.append(["_fft_length_None", [None]])
+
+ return args
diff --git a/verif/frameworks/tensor_gen.py b/verif/frameworks/tensor_gen.py
index 767989e..c534a58 100644
--- a/verif/frameworks/tensor_gen.py
+++ b/verif/frameworks/tensor_gen.py
@@ -274,3 +274,12 @@ class TGen:
)
return tf_placeholders, tf_consts
+
+ @staticmethod
+ def tgRFFT2d(op, shape, dtype, rng):
+ # Require rank 3 shape
+ if len(shape) != 3:
+ return [], []
+
+ tf_placeholders = [("placeholder_0", TGen.getRand(shape, dtype, rng))]
+ return tf_placeholders, []
diff --git a/verif/frameworks/test_builder.py b/verif/frameworks/test_builder.py
index 8870f41..6e7b6a5 100644
--- a/verif/frameworks/test_builder.py
+++ b/verif/frameworks/test_builder.py
@@ -1243,3 +1243,11 @@ class TBuilder:
def eval(self, a):
return self.dense(a)
+
+ class RFFT2d:
+ def __init__(self, fft_length, name):
+ self.fft_length = fft_length
+ self.result_name = name
+
+ def eval(self, a):
+ return tf.signal.rfft2d(a, self.fft_length, name=self.result_name)
diff --git a/verif/frameworks/tosa_verif_framework_compiler_runner.py b/verif/frameworks/tosa_verif_framework_compiler_runner.py
index 3597f2a..c55864a 100755
--- a/verif/frameworks/tosa_verif_framework_compiler_runner.py
+++ b/verif/frameworks/tosa_verif_framework_compiler_runner.py
@@ -1,5 +1,5 @@
#!/usr/bin/env python3
-# Copyright (c) 2020-2022, ARM Limited.
+# Copyright (c) 2020-2023, ARM Limited.
# SPDX-License-Identifier: Apache-2.0
import argparse
import glob
@@ -483,6 +483,20 @@ def run_test(args, test, framework):
except KeyError:
assert 0, "fail to load tflite result numpy"
+ # TOSA has no notion of complex datatypes, it represents complex values using two
+ # fp32 output tensors representing real and imaginary values. When legalizing
+ # complex operations from frameworks, these two output tensors are combined into
+ # a single tensor of shape [?, ..., ?, 2] whereby each inner pair of values
+ # represents the real and imaginary parts of a complex value. This is completed
+ # by inserting reshape and concatenate TOSA operations during the legalization to
+ # maintain a one-to-one correspondance with framework outputs, thus simplifying
+ # legalization. Here tf_result should also match this format before being
+ # compared to the ref model output.
+ if tf_result.dtype == np.complex64:
+ ifm_shape = tf_result.shape + (2,)
+ tf_result = tf_result.view(np.float32)
+ tf_result = tf_result.reshape(ifm_shape)
+
# Generate test descriptor per flatbuffer generation
# Input .npy will be shared across different frameworks
# Output .npy will be generated in its corresponding flatbuffer
diff --git a/verif/frameworks/tosa_verif_framework_generator.py b/verif/frameworks/tosa_verif_framework_generator.py
index 5b8856d..36ddda5 100755
--- a/verif/frameworks/tosa_verif_framework_generator.py
+++ b/verif/frameworks/tosa_verif_framework_generator.py
@@ -1,5 +1,5 @@
#!/usr/bin/env python3
-# Copyright (c) 2020-2022, ARM Limited.
+# Copyright (c) 2020-2023, ARM Limited.
# SPDX-License-Identifier: Apache-2.0
import argparse
import os
@@ -839,6 +839,13 @@ TF_OP_LIST = {
]
},
},
+ "rfft2d": {
+ "operands": (1, 0),
+ "build_fcn": (TBuilder.RFFT2d, TGen.tgRFFT2d, ArgGen.agRFFT2d),
+ "types": {
+ "tflite": TYPE_F,
+ },
+ },
}
# Shapes to be tested; default can be overwritten
@@ -847,6 +854,7 @@ shape_list = [
(64,),
(14, 19),
(13, 21, 3),
+ (1, 8, 16),
(1, 4, 4, 4),
(1, 8, 4, 17),
(1, 4, 8, 19),
diff --git a/verif/generator/tosa_arg_gen.py b/verif/generator/tosa_arg_gen.py
index 4e15b06..fed91f6 100644
--- a/verif/generator/tosa_arg_gen.py
+++ b/verif/generator/tosa_arg_gen.py
@@ -1,4 +1,4 @@
-# Copyright (c) 2021-2022, ARM Limited.
+# Copyright (c) 2021-2023, ARM Limited.
# SPDX-License-Identifier: Apache-2.0
import itertools
import math
@@ -417,6 +417,41 @@ class TosaTensorGen:
return [ifm_shape, filter_shape, bias_shape]
@staticmethod
+ def tgRFFT2d(testGen, op, rank, error_name=None):
+ pl, const = op["operands"]
+
+ if error_name != ErrorIf.WrongRank:
+ assert rank == 3
+ assert pl == 1 and const == 0
+
+ # IFM dimensions are NHW
+ ifm_shape = testGen.makeShape(rank)
+
+ # Select nearest lower power of two from input height and width
+ ifm_shape[1] = 2 ** int(math.log(ifm_shape[1], 2))
+ ifm_shape[2] = 2 ** int(math.log(ifm_shape[2], 2))
+
+ # Constrict the overall size of the shape when creating ERROR_IF tests
+ if error_name:
+ ifm_shape = TosaErrorIfArgGen.eiRestrictDimensions(ifm_shape)
+
+ # Generate an invalid kernel that is not a power of two
+ if error_name == ErrorIf.KernelNotPowerOfTwo:
+ # We must increment by 2 if current size is 1
+ inc_h = 2 if ifm_shape[1] == 1 else 1
+ inc_w = 2 if ifm_shape[2] == 1 else 1
+ inc_choices = [(inc_h, 0), (0, inc_w), (inc_h, inc_w)]
+ selected_inc = testGen.rng.choice(inc_choices)
+ ifm_shape[1] += selected_inc[0]
+ ifm_shape[2] += selected_inc[1]
+
+ # Constrict the batch size
+ if testGen.args.max_batch_size:
+ ifm_shape[0] = (ifm_shape[0] % testGen.args.max_batch_size) + 1
+
+ return [ifm_shape]
+
+ @staticmethod
def tgFullyConnected(testGen, op, rank, error_name=None):
pl, const = op["operands"]
diff --git a/verif/generator/tosa_error_if.py b/verif/generator/tosa_error_if.py
index c9d35c7..40c5d13 100644
--- a/verif/generator/tosa_error_if.py
+++ b/verif/generator/tosa_error_if.py
@@ -1,5 +1,7 @@
-# Copyright (c) 2021-2022, ARM Limited.
+# Copyright (c) 2021-2023, ARM Limited.
# SPDX-License-Identifier: Apache-2.0
+import math
+
import numpy as np
from generator.tosa_utils import MAX_RESIZE_DIMENSION
from generator.tosa_utils import product
@@ -76,6 +78,7 @@ class ErrorIf(object):
CondIfCondNotMatchingBool = "CondIfCondNotMatchingBool"
CondIfCondShapeNotSizeOne = "CondIfCondShapeNotSizeOne"
CondGraphOutputShapeNotSizeOne = "CondGraphOutputShapeNotSizeOne"
+ KernelNotPowerOfTwo = "KernelNotPowerOfTwo"
class TosaErrorIfArgGen:
@@ -548,6 +551,10 @@ class TosaErrorValidator:
):
error_result = True
+ elif op["op"] == Op.RFFT2D:
+ if not all([ty == input_dtype for ty in output_dtype]):
+ error_result = True
+
elif op["op"] in {
Op.CONV2D,
Op.CONV3D,
@@ -665,9 +672,13 @@ class TosaErrorValidator:
error_reason = "Op output list does not match expected output"
if check:
+ op = kwargs["op"]
output_list = kwargs["output_list"]
- # Note this will be incorrect if an operator returns more than one output
- if len(output_list) != 1:
+ expected_length = 1
+ if op["op"] == Op.RFFT2D:
+ expected_length = 2
+
+ if len(output_list) != expected_length:
error_result = True
info_dict = {
@@ -711,7 +722,7 @@ class TosaErrorValidator:
@staticmethod
def evBatchMismatch(check=False, **kwargs):
error_name = ErrorIf.BatchMismatch
- param_reqs = {"rank": [4, 4], "dtype": None, "shape": None}
+ param_reqs = {"rank": None, "dtype": None, "shape": None}
error_result = False
error_reason = "Input batch size not equal to output batch size"
@@ -722,12 +733,15 @@ class TosaErrorValidator:
if check:
input_shape = kwargs["input_shape"]
- output_shape = kwargs[
- "result_tensor"
- ].shape # Note this is just (N, OH, OW, C)
- if (len(input_shape) in rank_range) and (input_shape[0] != output_shape[0]):
- error_result = True
+ for output in kwargs["result_tensors"]:
+ output_shape = (
+ output.shape
+ ) # Note batch is expected to be the first dim
+ if (len(input_shape) in rank_range) and (
+ input_shape[0] != output_shape[0]
+ ):
+ error_result = True
info_dict = {
"error_name": error_name,
@@ -751,11 +765,12 @@ class TosaErrorValidator:
if check:
input_shape = kwargs["input_shape"]
- output_shape = kwargs[
- "result_tensor"
- ].shape # Note this is just (N, OH, OW, C)
- if (len(input_shape) in rank_range) and (input_shape[3] != output_shape[3]):
- error_result = True
+ for output in kwargs["result_tensors"]:
+ output_shape = output.shape # Note this is just (N, OH, OW, C)
+ if (len(input_shape) in rank_range) and (
+ input_shape[3] != output_shape[3]
+ ):
+ error_result = True
info_dict = {
"error_name": error_name,
@@ -1044,13 +1059,15 @@ class TosaErrorValidator:
input3_shape = (
kwargs["input3"].shape if "input3" in kwargs else input2_shape
)
- output_shape = kwargs["result_tensor"].shape
- if (
- (len(input1_shape) != len(output_shape))
- or (len(input2_shape) != len(output_shape))
- or (len(input3_shape) != len(output_shape))
- ):
- error_result = True
+
+ for output in kwargs["result_tensors"]:
+ output_shape = output.shape
+ if (
+ (len(input1_shape) != len(output_shape))
+ or (len(input2_shape) != len(output_shape))
+ or (len(input3_shape) != len(output_shape))
+ ):
+ error_result = True
info_dict = {
"error_name": error_name,
@@ -1074,16 +1091,18 @@ class TosaErrorValidator:
input3_shape = (
kwargs["input3"].shape if "input3" in kwargs else input2_shape
)
- output_shape = kwargs["result_tensor"].shape
- for i in range(
- min(len(input1_shape), len(input2_shape), len(input3_shape))
- ):
- if (
- (input1_shape[i] != 1 and input1_shape[i] != output_shape[i])
- or (input2_shape[i] != 1 and input2_shape[i] != output_shape[i])
- or (input3_shape[i] != 1 and input3_shape[i] != output_shape[i])
+
+ for output in kwargs["result_tensors"]:
+ output_shape = output.shape
+ for i in range(
+ min(len(input1_shape), len(input2_shape), len(input3_shape))
):
- error_result = True
+ if (
+ (input1_shape[i] != 1 and input1_shape[i] != output_shape[i])
+ or (input2_shape[i] != 1 and input2_shape[i] != output_shape[i])
+ or (input3_shape[i] != 1 and input3_shape[i] != output_shape[i])
+ ):
+ error_result = True
info_dict = {
"error_name": error_name,
@@ -2392,6 +2411,30 @@ class TosaErrorValidator:
}
return info_dict
+ @staticmethod
+ def evKernelNotPowerOfTwo(check=False, **kwargs):
+ error_name = ErrorIf.KernelNotPowerOfTwo
+ param_reqs = {"rank": None, "dtype": None, "shape": None}
+ error_result = False
+ error_reason = "kernel height and/or width not a power of two"
+
+ def is_power_of_two(x):
+ return math.log(x, 2).is_integer()
+
+ if check:
+ shape = kwargs["input_shape"]
+ if len(shape) == 3:
+ valid_kernel = is_power_of_two(shape[1]) and is_power_of_two(shape[2])
+ error_result = not valid_kernel
+
+ info_dict = {
+ "error_name": error_name,
+ "error_result": error_result,
+ "error_reason": error_reason,
+ "param_reqs": param_reqs,
+ }
+ return info_dict
+
class TosaInvalidValidator:
@staticmethod
diff --git a/verif/generator/tosa_test_gen.py b/verif/generator/tosa_test_gen.py
index c29763b..fddf942 100644
--- a/verif/generator/tosa_test_gen.py
+++ b/verif/generator/tosa_test_gen.py
@@ -255,7 +255,7 @@ class TosaTestGen:
input_dtype=a.dtype,
output_dtype=result_tens.dtype,
qinfo=qinfo,
- result_tensor=result_tens,
+ result_tensors=[result_tens],
input_list=input_list,
output_list=output_list,
num_operands=num_operands,
@@ -293,7 +293,7 @@ class TosaTestGen:
input2=b,
input_dtype=a.dtype,
output_dtype=result_tens.dtype,
- result_tensor=result_tens,
+ result_tensors=[result_tens],
input_list=input_list,
output_list=output_list,
num_operands=num_operands,
@@ -333,7 +333,7 @@ class TosaTestGen:
input2=b,
input_dtype=a.dtype,
output_dtype=result_tens.dtype,
- result_tensor=result_tens,
+ result_tensors=[result_tens],
input_list=input_list,
output_list=output_list,
num_operands=num_operands,
@@ -378,7 +378,7 @@ class TosaTestGen:
input2=b,
input_dtype=a.dtype,
output_dtype=result_tens.dtype,
- result_tensor=result_tens,
+ result_tensors=[result_tens],
input_list=input_list,
output_list=output_list,
num_operands=num_operands,
@@ -414,7 +414,7 @@ class TosaTestGen:
input_shape=a.shape,
input_dtype=a.dtype,
output_dtype=result_tens.dtype,
- result_tensor=result_tens,
+ result_tensors=[result_tens],
input_list=input_list,
output_list=output_list,
num_operands=num_operands,
@@ -448,7 +448,7 @@ class TosaTestGen:
input_shape=a.shape,
input_dtype=a.dtype,
output_dtype=result_tens.dtype,
- result_tensor=result_tens,
+ result_tensors=[result_tens],
input_list=input_list,
output_list=output_list,
num_operands=num_operands,
@@ -487,7 +487,7 @@ class TosaTestGen:
input_dtype=a.dtype,
output_shape=result_tens.shape,
output_dtype=result_tens.dtype,
- result_tensor=result_tens,
+ result_tensors=[result_tens],
input_list=input_list,
output_list=output_list,
num_operands=num_operands,
@@ -523,7 +523,7 @@ class TosaTestGen:
input_dtype=a.dtype,
output_shape=result_tens.shape,
output_dtype=result_tens.dtype,
- result_tensor=result_tens,
+ result_tensors=[result_tens],
input_list=input_list,
output_list=output_list,
num_operands=num_operands,
@@ -582,7 +582,7 @@ class TosaTestGen:
stride=stride,
pad=pad,
qinfo=qinfo,
- result_tensor=result_tens,
+ result_tensors=[result_tens],
input_list=input_list,
output_list=output_list,
num_operands=num_operands,
@@ -938,7 +938,7 @@ class TosaTestGen:
output_shape=result_tens.shape,
output_dtype=result_tens.dtype,
qinfo=qinfo,
- result_tensor=result_tens,
+ result_tensors=[result_tens],
input_list=input_list,
output_list=output_list,
num_operands=num_operands,
@@ -980,7 +980,7 @@ class TosaTestGen:
output_shape=result_tens.shape,
output_dtype=result_tens.dtype,
qinfo=qinfo,
- result_tensor=result_tens,
+ result_tensors=[result_tens],
input_list=input_list,
output_list=output_list,
num_operands=num_operands,
@@ -1016,7 +1016,7 @@ class TosaTestGen:
output_shape=result_tens.shape,
input_dtype=a.dtype,
output_dtype=result_tens.dtype,
- result_tensor=result_tens,
+ result_tensors=[result_tens],
input_list=input_list,
output_list=output_list,
num_operands=num_operands,
@@ -1064,7 +1064,7 @@ class TosaTestGen:
output_shape=result_tens.shape,
input_dtype=a.dtype,
output_dtype=result_tens.dtype,
- result_tensor=result_tens,
+ result_tensors=[result_tens],
input_list=input_list,
output_list=output_list,
num_operands=num_operands,
@@ -1122,7 +1122,7 @@ class TosaTestGen:
output_shape=result_tens.shape,
input_dtype=a.dtype,
output_dtype=result_tens.dtype,
- result_tensor=result_tens,
+ result_tensors=[result_tens],
input_list=input_list,
output_list=output_list,
num_operands=num_operands,
@@ -1153,7 +1153,7 @@ class TosaTestGen:
output_shape=result_tens.shape,
input_dtype=a.dtype,
output_dtype=result_tens.dtype,
- result_tensor=result_tens,
+ result_tensors=[result_tens],
input_list=input_list,
output_list=output_list,
num_operands=num_operands,
@@ -1199,7 +1199,7 @@ class TosaTestGen:
input_dtype=a[0].dtype,
output_dtype=result_tens.dtype,
inputs=a,
- result_tensor=result_tens,
+ result_tensors=[result_tens],
input_list=input_list,
output_list=output_list,
num_operands=num_operands,
@@ -1250,7 +1250,7 @@ class TosaTestGen:
output_dtype=result_tens.dtype,
pad=padding,
qinfo=qinfo,
- result_tensor=result_tens,
+ result_tensors=[result_tens],
input_list=input_list,
output_list=output_list,
num_operands=num_operands,
@@ -1283,7 +1283,7 @@ class TosaTestGen:
output_shape=result_tens.shape,
input_dtype=a.dtype,
output_dtype=result_tens.dtype,
- result_tensor=result_tens,
+ result_tensors=[result_tens],
input_list=input_list,
output_list=output_list,
num_operands=num_operands,
@@ -1318,7 +1318,7 @@ class TosaTestGen:
output_shape=result_tens.shape,
input_dtype=a.dtype,
output_dtype=result_tens.dtype,
- result_tensor=result_tens,
+ result_tensors=[result_tens],
input_list=input_list,
output_list=output_list,
num_operands=num_operands,
@@ -1356,7 +1356,7 @@ class TosaTestGen:
perms=perms,
input_dtype=a.dtype,
output_dtype=result_tens.dtype,
- result_tensor=result_tens,
+ result_tensors=[result_tens],
input_list=input_list,
output_list=output_list,
num_operands=num_operands,
@@ -1391,7 +1391,7 @@ class TosaTestGen:
output_dtype=result_tens.dtype,
start=start,
size=size,
- result_tensor=result_tens,
+ result_tensors=[result_tens],
input_list=input_list,
output_list=output_list,
num_operands=num_operands,
@@ -1425,7 +1425,7 @@ class TosaTestGen:
output_shape=result_tens.shape,
input_dtype=a.dtype,
output_dtype=result_tens.dtype,
- result_tensor=result_tens,
+ result_tensors=[result_tens],
input_list=input_list,
output_list=output_list,
num_operands=num_operands,
@@ -1474,7 +1474,7 @@ class TosaTestGen:
output_shape=result_tens.shape,
input_dtype=values.dtype,
output_dtype=result_tens.dtype,
- result_tensor=result_tens,
+ result_tensors=[result_tens],
input_list=input_list,
output_list=output_list,
num_operands=num_operands,
@@ -1519,7 +1519,7 @@ class TosaTestGen:
output_shape=result_tens.shape,
input_dtype=values_in.dtype,
output_dtype=result_tens.dtype,
- result_tensor=result_tens,
+ result_tensors=[result_tens],
input_list=input_list,
output_list=output_list,
num_operands=num_operands,
@@ -1580,7 +1580,7 @@ class TosaTestGen:
border=border,
input_list=input_list,
output_list=output_list,
- result_tensor=result_tens,
+ result_tensors=[result_tens],
num_operands=num_operands,
):
return None
@@ -1628,7 +1628,7 @@ class TosaTestGen:
output_shape=result_tens.shape,
input_dtype=val.dtype,
output_dtype=result_tens.dtype,
- result_tensor=result_tens,
+ result_tensors=[result_tens],
input_list=input_list,
output_list=output_list,
num_operands=num_operands,
@@ -1774,7 +1774,7 @@ class TosaTestGen:
double_round=double_round,
input_list=input_list,
output_list=output_list,
- result_tensor=result_tens,
+ result_tensors=[result_tens],
num_operands=num_operands,
):
return None
@@ -2083,6 +2083,38 @@ class TosaTestGen:
return acc_out
+ def build_rfft2d(self, op, val, validator_fcns=None, error_name=None):
+ results = OutputShaper.rfft2dOp(self.ser, self.rng, val, error_name)
+
+ input_names = [val.name]
+ pCount, cCount = op["operands"]
+ num_operands = pCount + cCount
+
+ output_names = [res.name for res in results]
+ output_dtypes = [res.dtype for res in results]
+
+ input_names, output_names = TosaErrorIfArgGen.eiInvalidateInputOutputList(
+ self, error_name, input_names, output_names
+ )
+
+ if not TosaErrorValidator.evValidateErrorIfs(
+ self.ser,
+ validator_fcns,
+ error_name,
+ op=op,
+ input_shape=val.shape,
+ input_dtype=val.dtype,
+ output_dtype=output_dtypes,
+ result_tensors=results,
+ input_list=input_names,
+ output_list=output_names,
+ num_operands=num_operands,
+ ):
+ return None
+
+ self.ser.addOperator(op["op"], input_names, output_names)
+ return results
+
def create_filter_lists(
self, op, shapeFilter, rankFilter, dtypeFilter, testType, validator=None
):
@@ -3897,6 +3929,27 @@ class TosaTestGen:
TosaErrorValidator.evCondGraphOutputShapeNotSizeOne,
),
},
+ "rfft2d": {
+ "op": Op.RFFT2D,
+ "operands": (1, 0),
+ "rank": (3, 3),
+ "build_fcn": (
+ build_rfft2d,
+ TosaTensorGen.tgRFFT2d,
+ TosaTensorValuesGen.tvgDefault,
+ TosaArgGen.agNone,
+ ),
+ "types": [DType.FP32],
+ "error_if_validators": (
+ TosaErrorValidator.evWrongInputType,
+ TosaErrorValidator.evWrongOutputType,
+ TosaErrorValidator.evWrongInputList,
+ TosaErrorValidator.evWrongOutputList,
+ TosaErrorValidator.evWrongRank,
+ TosaErrorValidator.evBatchMismatch,
+ TosaErrorValidator.evKernelNotPowerOfTwo,
+ ),
+ },
}
@@ -4717,3 +4770,26 @@ class OutputShaper:
out_dtype = rng.choice(wrong_dtypes)
return ser.addOutput(output_shape, out_dtype)
+
+ @staticmethod
+ def rfft2dOp(serializer, rng, value, error_name=None):
+ outputs = []
+
+ input_shape = value.shape
+ if error_name != ErrorIf.WrongRank:
+ assert len(input_shape) == 3
+
+ output_shape = [*input_shape[:-1], input_shape[-1] // 2 + 1]
+
+ output_dtype = value.dtype
+ if error_name == ErrorIf.WrongOutputType:
+ excludes = [DType.FP32]
+ wrong_dtypes = list(usableDTypes(excludes=excludes))
+ output_dtype = rng.choice(wrong_dtypes)
+ elif error_name == ErrorIf.BatchMismatch:
+ incorrect_batch = input_shape[0] + rng.integers(1, 10)
+ output_shape = [incorrect_batch, *input_shape[1:]]
+
+ outputs.append(serializer.addOutput(output_shape, output_dtype))
+ outputs.append(serializer.addOutput(output_shape, output_dtype))
+ return outputs