From 261b7b62b959a6c7312d810d9152069fdff69f3e Mon Sep 17 00:00:00 2001 From: Luke Hutton Date: Tue, 10 Jan 2023 14:50:31 +0000 Subject: Add RFFT2d to the reference model Includes: * RFFT2d reference implementation * TFLite framework tests * Basic TOSA tests * Serialization submodule upgrade with support for FFT/RFFT Signed-off-by: Luke Hutton Change-Id: I2a687e9cf87fb62a26160ea52439ba9830bea36e --- reference_model/src/ops/op_factory.cc | 5 +- reference_model/src/ops/tensor_ops.cc | 138 ++++++++++++++++++++- reference_model/src/ops/tensor_ops.h | 23 +++- thirdparty/serialization_lib | 2 +- verif/frameworks/arg_gen.py | 30 ++++- verif/frameworks/tensor_gen.py | 9 ++ verif/frameworks/test_builder.py | 8 ++ .../tosa_verif_framework_compiler_runner.py | 16 ++- verif/frameworks/tosa_verif_framework_generator.py | 10 +- verif/generator/tosa_arg_gen.py | 37 +++++- verif/generator/tosa_error_if.py | 103 ++++++++++----- verif/generator/tosa_test_gen.py | 130 +++++++++++++++---- 12 files changed, 446 insertions(+), 65 deletions(-) diff --git a/reference_model/src/ops/op_factory.cc b/reference_model/src/ops/op_factory.cc index 0121ccf..0d56161 100644 --- a/reference_model/src/ops/op_factory.cc +++ b/reference_model/src/ops/op_factory.cc @@ -1,5 +1,5 @@ -// Copyright (c) 2020-2022, ARM Limited. +// Copyright (c) 2020-2023, ARM Limited. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -113,6 +113,9 @@ GraphNode* OpFactory::newOp(SubgraphTraverser* sgt, DEF_FACTORY_ONE_TYPE(OpMaxPool2d, INT8); DEF_FACTORY_ONE_TYPE(OpMaxPool2d, INT16); break; + case Op_RFFT2D: + DEF_FACTORY_ONE_TYPE(OpRFFT2d, FP32); + break; case Op_TRANSPOSE_CONV2D: DEF_FACTORY_TWO_TYPE_ONE_ACCUM(OpTransposeConv2d, TransposeConv, FP16, FP16, FP16); DEF_FACTORY_TWO_TYPE_ONE_ACCUM(OpTransposeConv2d, TransposeConv, FP16, FP16, FP32); diff --git a/reference_model/src/ops/tensor_ops.cc b/reference_model/src/ops/tensor_ops.cc index b9ac94a..dff9e08 100644 --- a/reference_model/src/ops/tensor_ops.cc +++ b/reference_model/src/ops/tensor_ops.cc @@ -1,5 +1,5 @@ -// Copyright (c) 2020-2022, ARM Limited. +// Copyright (c) 2020-2023, ARM Limited. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -1453,6 +1453,140 @@ int OpMaxPool2d::eval() return GraphNode::eval(); } +template +OpRFFT2d::OpRFFT2d(SubgraphTraverser* sgt_, + TosaAttributeBase* attribute_, + uint64_t id_) + : GraphNode(sgt_, Op_RFFT2D, id_) +{ + setRequiredOperands(1, 2); + setRequiredRank(3); +} + +template +OpRFFT2d::~OpRFFT2d() {} + + +template +int OpRFFT2d::checkTensorAttributes() +{ + if (validateRequiredOperands()) + return 1; + + if (validateRequiredRank(inputs[0]) || validateRequiredRank(outputs[0]) || + validateRequiredRank(outputs[1])) + { + return 1; + } + + if (inputs[0]->matchType(*outputs[0]) || inputs[0]->matchType(*outputs[1])) + { + printNodeValidationError("OpRFFT2d: input and output tensor type mismatch"); + return 1; + } + + in = dynamic_cast*>(inputs[0]); + out_real = dynamic_cast*>(outputs[0]); + out_imag = dynamic_cast*>(outputs[1]); + + ASSERT_MEM(in && out_real && out_imag); + + auto is_power_of_two = [](int32_t n) -> bool + { + return (n & (n-1)) == 0 && n > 0; + }; + + // Input shape: [N, H, W] + if (!is_power_of_two(in->getShape()[1]) || !is_power_of_two(in->getShape()[2])) + { + printNodeValidationError("OpRFFT2d: input height and width must be a power of two"); + return 1; + } + + // Output shape: [N, H, W / 2 + 1] + bool output_check = true; + for (int32_t i = 0; i < out_real->getRank(); i++) + { + if (out_real->getShape()[i] != out_imag->getShape()[i]) + { + output_check = false; + break; + } + } + if (!output_check) + { + printNodeValidationError( + "OpRFFT2d: Mismatch between real output shape and imaginary output shape"); + return 1; + } + + if (in->getShape()[0] != out_real->getShape()[0]) { + printNodeValidationError("OpRFFT2d: input and output batch size don't match"); + return 1; + } + if (in->getShape()[1] != out_real->getShape()[1]) { + printNodeValidationError("OpRFFT2d: input and output height don't match"); + return 1; + } + if (in->getShape()[2] / 2 + 1 != out_real->getShape()[2]) { + printNodeValidationError("OpRFFT2d: output width is expected to match input width / 2 + 1"); + return 1; + } + + return 0; +} + +template +int OpRFFT2d::eval() +{ + int32_t in_batch = in->getShape()[0]; + int32_t in_height = in->getShape()[1]; + int32_t in_width = in->getShape()[2]; + + int32_t out_real_batch = out_real->getShape()[0]; + int32_t out_real_height = out_real->getShape()[1]; + int32_t out_real_width = out_real->getShape()[2]; + + int32_t out_imag_batch = out_imag->getShape()[0]; + int32_t out_imag_height = out_imag->getShape()[1]; + int32_t out_imag_width = out_imag->getShape()[2]; + + DEBUG_INFO(OP, + "perform OpRFFT2d, input.shape=[%d,%d,%d], output_real.shape=[%d,%d,%d], " + "output_imag.shape=[%d,%d,%d]", + in_batch, in_height, in_width, + out_real_batch, out_real_height, out_real_width, + out_imag_batch, out_imag_height, out_imag_width); + + OutEigenType sum_real, sum_imag, a; + + for (int n = 0; n < in_batch; n++) + { + for (int oy = 0; oy < out_real_height; oy++) + { + for (int ox = 0; ox < out_real_width; ox++) + { + sum_real = 0.0; + sum_imag = 0.0; + for (int iy = 0; iy < in_height; iy++) + { + for (int ix = 0; ix < in_width; ix++) + { + // Use explicit cast to ensure intermmediate calculations are completed using OutEigenType + a = 2 * M_PI * ((iy * (OutEigenType)oy) / in_height + (ix * (OutEigenType)ox) / in_width); + sum_real += this->in->getTensor()(n, iy, ix) * cos(a); + sum_imag += -this->in->getTensor()(n, iy, ix) * sin(a); + } + } + this->out_real->getTensor()(n, oy, ox) = sum_real; + this->out_imag->getTensor()(n, oy, ox) = sum_imag; + } + } + } + + return GraphNode::eval(); +} + template OpTransposeConv2d::OpTransposeConv2d(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, @@ -1738,6 +1872,8 @@ DEF_INSTANTIATE_ONE_TYPE(OpMaxPool2d, FP32); DEF_INSTANTIATE_ONE_TYPE(OpMaxPool2d, INT8); DEF_INSTANTIATE_ONE_TYPE(OpMaxPool2d, INT16); +DEF_INSTANTIATE_ONE_TYPE(OpRFFT2d, FP32); + DEF_INSTANTIATE_TWO_TYPE_ONE_ACCUM(OpTransposeConv2d, FP16, FP16, FP16); DEF_INSTANTIATE_TWO_TYPE_ONE_ACCUM(OpTransposeConv2d, FP16, FP16, FP32); DEF_INSTANTIATE_TWO_TYPE_ONE_ACCUM(OpTransposeConv2d, BF16, BF16, FP32); diff --git a/reference_model/src/ops/tensor_ops.h b/reference_model/src/ops/tensor_ops.h index fd6dd25..ed9a55c 100644 --- a/reference_model/src/ops/tensor_ops.h +++ b/reference_model/src/ops/tensor_ops.h @@ -1,5 +1,5 @@ -// Copyright (c) 2020-2022, ARM Limited. +// Copyright (c) 2020-2023, ARM Limited. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -248,6 +248,27 @@ protected: tosa::TosaPoolAttribute* attribute; }; +template +class OpRFFT2d : public GraphNode +{ +public: + OpRFFT2d(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, uint64_t id_); + virtual ~OpRFFT2d(); + + virtual int checkTensorAttributes() final; + virtual int eval() final; + + using InEigenType = typename GetEigenType::type; + using OutEigenType = typename GetEigenType::type; + using TIn = Eigen::Tensor; + using TOut = Eigen::Tensor; + +protected: + TosaReference::TensorTemplate* in; + TosaReference::TensorTemplate* out_real; + TosaReference::TensorTemplate* out_imag; +}; + template class OpTransposeConv2d : public GraphNode { diff --git a/thirdparty/serialization_lib b/thirdparty/serialization_lib index e36f4f7..c15f7d5 160000 --- a/thirdparty/serialization_lib +++ b/thirdparty/serialization_lib @@ -1 +1 @@ -Subproject commit e36f4f70b51c03712db96ea284e6e54b3e60a74c +Subproject commit c15f7d52aa4f360eba2344449baa418b7608ac7c diff --git a/verif/frameworks/arg_gen.py b/verif/frameworks/arg_gen.py index d81c3dd..61a1de0 100644 --- a/verif/frameworks/arg_gen.py +++ b/verif/frameworks/arg_gen.py @@ -1,5 +1,7 @@ -# Copyright (c) 2020-2022, ARM Limited. +# Copyright (c) 2020-2023, ARM Limited. # SPDX-License-Identifier: Apache-2.0 +import math + import numpy as np @@ -851,3 +853,29 @@ class ArgGen: else: axes.append(["_axis_m{}".format(-i), [i]]) return axes + + def agRFFT2d(op, shape, rng): + args = [] + + # Must be rank 3 input tensor + if len(shape) != 3: + return [] + + # Check rfft2d with enforced fft_length + for fft_length_h in [2, 32]: + for fft_length_w in [2, 8, 16]: + fft_length = [fft_length_h, fft_length_w] + args.append(["_fft_length_{}x{}".format(*fft_length), [fft_length]]) + + # Check rfft2d with no fft_length provided (fft_length=None). + # In this case, the height and width of the input should be + # used for the calculation. Therefore, we need to check that + # the input shape is already a power of two. + def is_power_of_two(x): + return math.log(x, 2).is_integer() + + height, width = shape[1:3] + if is_power_of_two(height) and is_power_of_two(width): + args.append(["_fft_length_None", [None]]) + + return args diff --git a/verif/frameworks/tensor_gen.py b/verif/frameworks/tensor_gen.py index 767989e..c534a58 100644 --- a/verif/frameworks/tensor_gen.py +++ b/verif/frameworks/tensor_gen.py @@ -274,3 +274,12 @@ class TGen: ) return tf_placeholders, tf_consts + + @staticmethod + def tgRFFT2d(op, shape, dtype, rng): + # Require rank 3 shape + if len(shape) != 3: + return [], [] + + tf_placeholders = [("placeholder_0", TGen.getRand(shape, dtype, rng))] + return tf_placeholders, [] diff --git a/verif/frameworks/test_builder.py b/verif/frameworks/test_builder.py index 8870f41..6e7b6a5 100644 --- a/verif/frameworks/test_builder.py +++ b/verif/frameworks/test_builder.py @@ -1243,3 +1243,11 @@ class TBuilder: def eval(self, a): return self.dense(a) + + class RFFT2d: + def __init__(self, fft_length, name): + self.fft_length = fft_length + self.result_name = name + + def eval(self, a): + return tf.signal.rfft2d(a, self.fft_length, name=self.result_name) diff --git a/verif/frameworks/tosa_verif_framework_compiler_runner.py b/verif/frameworks/tosa_verif_framework_compiler_runner.py index 3597f2a..c55864a 100755 --- a/verif/frameworks/tosa_verif_framework_compiler_runner.py +++ b/verif/frameworks/tosa_verif_framework_compiler_runner.py @@ -1,5 +1,5 @@ #!/usr/bin/env python3 -# Copyright (c) 2020-2022, ARM Limited. +# Copyright (c) 2020-2023, ARM Limited. # SPDX-License-Identifier: Apache-2.0 import argparse import glob @@ -483,6 +483,20 @@ def run_test(args, test, framework): except KeyError: assert 0, "fail to load tflite result numpy" + # TOSA has no notion of complex datatypes, it represents complex values using two + # fp32 output tensors representing real and imaginary values. When legalizing + # complex operations from frameworks, these two output tensors are combined into + # a single tensor of shape [?, ..., ?, 2] whereby each inner pair of values + # represents the real and imaginary parts of a complex value. This is completed + # by inserting reshape and concatenate TOSA operations during the legalization to + # maintain a one-to-one correspondance with framework outputs, thus simplifying + # legalization. Here tf_result should also match this format before being + # compared to the ref model output. + if tf_result.dtype == np.complex64: + ifm_shape = tf_result.shape + (2,) + tf_result = tf_result.view(np.float32) + tf_result = tf_result.reshape(ifm_shape) + # Generate test descriptor per flatbuffer generation # Input .npy will be shared across different frameworks # Output .npy will be generated in its corresponding flatbuffer diff --git a/verif/frameworks/tosa_verif_framework_generator.py b/verif/frameworks/tosa_verif_framework_generator.py index 5b8856d..36ddda5 100755 --- a/verif/frameworks/tosa_verif_framework_generator.py +++ b/verif/frameworks/tosa_verif_framework_generator.py @@ -1,5 +1,5 @@ #!/usr/bin/env python3 -# Copyright (c) 2020-2022, ARM Limited. +# Copyright (c) 2020-2023, ARM Limited. # SPDX-License-Identifier: Apache-2.0 import argparse import os @@ -839,6 +839,13 @@ TF_OP_LIST = { ] }, }, + "rfft2d": { + "operands": (1, 0), + "build_fcn": (TBuilder.RFFT2d, TGen.tgRFFT2d, ArgGen.agRFFT2d), + "types": { + "tflite": TYPE_F, + }, + }, } # Shapes to be tested; default can be overwritten @@ -847,6 +854,7 @@ shape_list = [ (64,), (14, 19), (13, 21, 3), + (1, 8, 16), (1, 4, 4, 4), (1, 8, 4, 17), (1, 4, 8, 19), diff --git a/verif/generator/tosa_arg_gen.py b/verif/generator/tosa_arg_gen.py index 4e15b06..fed91f6 100644 --- a/verif/generator/tosa_arg_gen.py +++ b/verif/generator/tosa_arg_gen.py @@ -1,4 +1,4 @@ -# Copyright (c) 2021-2022, ARM Limited. +# Copyright (c) 2021-2023, ARM Limited. # SPDX-License-Identifier: Apache-2.0 import itertools import math @@ -416,6 +416,41 @@ class TosaTensorGen: return [ifm_shape, filter_shape, bias_shape] + @staticmethod + def tgRFFT2d(testGen, op, rank, error_name=None): + pl, const = op["operands"] + + if error_name != ErrorIf.WrongRank: + assert rank == 3 + assert pl == 1 and const == 0 + + # IFM dimensions are NHW + ifm_shape = testGen.makeShape(rank) + + # Select nearest lower power of two from input height and width + ifm_shape[1] = 2 ** int(math.log(ifm_shape[1], 2)) + ifm_shape[2] = 2 ** int(math.log(ifm_shape[2], 2)) + + # Constrict the overall size of the shape when creating ERROR_IF tests + if error_name: + ifm_shape = TosaErrorIfArgGen.eiRestrictDimensions(ifm_shape) + + # Generate an invalid kernel that is not a power of two + if error_name == ErrorIf.KernelNotPowerOfTwo: + # We must increment by 2 if current size is 1 + inc_h = 2 if ifm_shape[1] == 1 else 1 + inc_w = 2 if ifm_shape[2] == 1 else 1 + inc_choices = [(inc_h, 0), (0, inc_w), (inc_h, inc_w)] + selected_inc = testGen.rng.choice(inc_choices) + ifm_shape[1] += selected_inc[0] + ifm_shape[2] += selected_inc[1] + + # Constrict the batch size + if testGen.args.max_batch_size: + ifm_shape[0] = (ifm_shape[0] % testGen.args.max_batch_size) + 1 + + return [ifm_shape] + @staticmethod def tgFullyConnected(testGen, op, rank, error_name=None): pl, const = op["operands"] diff --git a/verif/generator/tosa_error_if.py b/verif/generator/tosa_error_if.py index c9d35c7..40c5d13 100644 --- a/verif/generator/tosa_error_if.py +++ b/verif/generator/tosa_error_if.py @@ -1,5 +1,7 @@ -# Copyright (c) 2021-2022, ARM Limited. +# Copyright (c) 2021-2023, ARM Limited. # SPDX-License-Identifier: Apache-2.0 +import math + import numpy as np from generator.tosa_utils import MAX_RESIZE_DIMENSION from generator.tosa_utils import product @@ -76,6 +78,7 @@ class ErrorIf(object): CondIfCondNotMatchingBool = "CondIfCondNotMatchingBool" CondIfCondShapeNotSizeOne = "CondIfCondShapeNotSizeOne" CondGraphOutputShapeNotSizeOne = "CondGraphOutputShapeNotSizeOne" + KernelNotPowerOfTwo = "KernelNotPowerOfTwo" class TosaErrorIfArgGen: @@ -548,6 +551,10 @@ class TosaErrorValidator: ): error_result = True + elif op["op"] == Op.RFFT2D: + if not all([ty == input_dtype for ty in output_dtype]): + error_result = True + elif op["op"] in { Op.CONV2D, Op.CONV3D, @@ -665,9 +672,13 @@ class TosaErrorValidator: error_reason = "Op output list does not match expected output" if check: + op = kwargs["op"] output_list = kwargs["output_list"] - # Note this will be incorrect if an operator returns more than one output - if len(output_list) != 1: + expected_length = 1 + if op["op"] == Op.RFFT2D: + expected_length = 2 + + if len(output_list) != expected_length: error_result = True info_dict = { @@ -711,7 +722,7 @@ class TosaErrorValidator: @staticmethod def evBatchMismatch(check=False, **kwargs): error_name = ErrorIf.BatchMismatch - param_reqs = {"rank": [4, 4], "dtype": None, "shape": None} + param_reqs = {"rank": None, "dtype": None, "shape": None} error_result = False error_reason = "Input batch size not equal to output batch size" @@ -722,12 +733,15 @@ class TosaErrorValidator: if check: input_shape = kwargs["input_shape"] - output_shape = kwargs[ - "result_tensor" - ].shape # Note this is just (N, OH, OW, C) - if (len(input_shape) in rank_range) and (input_shape[0] != output_shape[0]): - error_result = True + for output in kwargs["result_tensors"]: + output_shape = ( + output.shape + ) # Note batch is expected to be the first dim + if (len(input_shape) in rank_range) and ( + input_shape[0] != output_shape[0] + ): + error_result = True info_dict = { "error_name": error_name, @@ -751,11 +765,12 @@ class TosaErrorValidator: if check: input_shape = kwargs["input_shape"] - output_shape = kwargs[ - "result_tensor" - ].shape # Note this is just (N, OH, OW, C) - if (len(input_shape) in rank_range) and (input_shape[3] != output_shape[3]): - error_result = True + for output in kwargs["result_tensors"]: + output_shape = output.shape # Note this is just (N, OH, OW, C) + if (len(input_shape) in rank_range) and ( + input_shape[3] != output_shape[3] + ): + error_result = True info_dict = { "error_name": error_name, @@ -1044,13 +1059,15 @@ class TosaErrorValidator: input3_shape = ( kwargs["input3"].shape if "input3" in kwargs else input2_shape ) - output_shape = kwargs["result_tensor"].shape - if ( - (len(input1_shape) != len(output_shape)) - or (len(input2_shape) != len(output_shape)) - or (len(input3_shape) != len(output_shape)) - ): - error_result = True + + for output in kwargs["result_tensors"]: + output_shape = output.shape + if ( + (len(input1_shape) != len(output_shape)) + or (len(input2_shape) != len(output_shape)) + or (len(input3_shape) != len(output_shape)) + ): + error_result = True info_dict = { "error_name": error_name, @@ -1074,16 +1091,18 @@ class TosaErrorValidator: input3_shape = ( kwargs["input3"].shape if "input3" in kwargs else input2_shape ) - output_shape = kwargs["result_tensor"].shape - for i in range( - min(len(input1_shape), len(input2_shape), len(input3_shape)) - ): - if ( - (input1_shape[i] != 1 and input1_shape[i] != output_shape[i]) - or (input2_shape[i] != 1 and input2_shape[i] != output_shape[i]) - or (input3_shape[i] != 1 and input3_shape[i] != output_shape[i]) + + for output in kwargs["result_tensors"]: + output_shape = output.shape + for i in range( + min(len(input1_shape), len(input2_shape), len(input3_shape)) ): - error_result = True + if ( + (input1_shape[i] != 1 and input1_shape[i] != output_shape[i]) + or (input2_shape[i] != 1 and input2_shape[i] != output_shape[i]) + or (input3_shape[i] != 1 and input3_shape[i] != output_shape[i]) + ): + error_result = True info_dict = { "error_name": error_name, @@ -2392,6 +2411,30 @@ class TosaErrorValidator: } return info_dict + @staticmethod + def evKernelNotPowerOfTwo(check=False, **kwargs): + error_name = ErrorIf.KernelNotPowerOfTwo + param_reqs = {"rank": None, "dtype": None, "shape": None} + error_result = False + error_reason = "kernel height and/or width not a power of two" + + def is_power_of_two(x): + return math.log(x, 2).is_integer() + + if check: + shape = kwargs["input_shape"] + if len(shape) == 3: + valid_kernel = is_power_of_two(shape[1]) and is_power_of_two(shape[2]) + error_result = not valid_kernel + + info_dict = { + "error_name": error_name, + "error_result": error_result, + "error_reason": error_reason, + "param_reqs": param_reqs, + } + return info_dict + class TosaInvalidValidator: @staticmethod diff --git a/verif/generator/tosa_test_gen.py b/verif/generator/tosa_test_gen.py index c29763b..fddf942 100644 --- a/verif/generator/tosa_test_gen.py +++ b/verif/generator/tosa_test_gen.py @@ -255,7 +255,7 @@ class TosaTestGen: input_dtype=a.dtype, output_dtype=result_tens.dtype, qinfo=qinfo, - result_tensor=result_tens, + result_tensors=[result_tens], input_list=input_list, output_list=output_list, num_operands=num_operands, @@ -293,7 +293,7 @@ class TosaTestGen: input2=b, input_dtype=a.dtype, output_dtype=result_tens.dtype, - result_tensor=result_tens, + result_tensors=[result_tens], input_list=input_list, output_list=output_list, num_operands=num_operands, @@ -333,7 +333,7 @@ class TosaTestGen: input2=b, input_dtype=a.dtype, output_dtype=result_tens.dtype, - result_tensor=result_tens, + result_tensors=[result_tens], input_list=input_list, output_list=output_list, num_operands=num_operands, @@ -378,7 +378,7 @@ class TosaTestGen: input2=b, input_dtype=a.dtype, output_dtype=result_tens.dtype, - result_tensor=result_tens, + result_tensors=[result_tens], input_list=input_list, output_list=output_list, num_operands=num_operands, @@ -414,7 +414,7 @@ class TosaTestGen: input_shape=a.shape, input_dtype=a.dtype, output_dtype=result_tens.dtype, - result_tensor=result_tens, + result_tensors=[result_tens], input_list=input_list, output_list=output_list, num_operands=num_operands, @@ -448,7 +448,7 @@ class TosaTestGen: input_shape=a.shape, input_dtype=a.dtype, output_dtype=result_tens.dtype, - result_tensor=result_tens, + result_tensors=[result_tens], input_list=input_list, output_list=output_list, num_operands=num_operands, @@ -487,7 +487,7 @@ class TosaTestGen: input_dtype=a.dtype, output_shape=result_tens.shape, output_dtype=result_tens.dtype, - result_tensor=result_tens, + result_tensors=[result_tens], input_list=input_list, output_list=output_list, num_operands=num_operands, @@ -523,7 +523,7 @@ class TosaTestGen: input_dtype=a.dtype, output_shape=result_tens.shape, output_dtype=result_tens.dtype, - result_tensor=result_tens, + result_tensors=[result_tens], input_list=input_list, output_list=output_list, num_operands=num_operands, @@ -582,7 +582,7 @@ class TosaTestGen: stride=stride, pad=pad, qinfo=qinfo, - result_tensor=result_tens, + result_tensors=[result_tens], input_list=input_list, output_list=output_list, num_operands=num_operands, @@ -938,7 +938,7 @@ class TosaTestGen: output_shape=result_tens.shape, output_dtype=result_tens.dtype, qinfo=qinfo, - result_tensor=result_tens, + result_tensors=[result_tens], input_list=input_list, output_list=output_list, num_operands=num_operands, @@ -980,7 +980,7 @@ class TosaTestGen: output_shape=result_tens.shape, output_dtype=result_tens.dtype, qinfo=qinfo, - result_tensor=result_tens, + result_tensors=[result_tens], input_list=input_list, output_list=output_list, num_operands=num_operands, @@ -1016,7 +1016,7 @@ class TosaTestGen: output_shape=result_tens.shape, input_dtype=a.dtype, output_dtype=result_tens.dtype, - result_tensor=result_tens, + result_tensors=[result_tens], input_list=input_list, output_list=output_list, num_operands=num_operands, @@ -1064,7 +1064,7 @@ class TosaTestGen: output_shape=result_tens.shape, input_dtype=a.dtype, output_dtype=result_tens.dtype, - result_tensor=result_tens, + result_tensors=[result_tens], input_list=input_list, output_list=output_list, num_operands=num_operands, @@ -1122,7 +1122,7 @@ class TosaTestGen: output_shape=result_tens.shape, input_dtype=a.dtype, output_dtype=result_tens.dtype, - result_tensor=result_tens, + result_tensors=[result_tens], input_list=input_list, output_list=output_list, num_operands=num_operands, @@ -1153,7 +1153,7 @@ class TosaTestGen: output_shape=result_tens.shape, input_dtype=a.dtype, output_dtype=result_tens.dtype, - result_tensor=result_tens, + result_tensors=[result_tens], input_list=input_list, output_list=output_list, num_operands=num_operands, @@ -1199,7 +1199,7 @@ class TosaTestGen: input_dtype=a[0].dtype, output_dtype=result_tens.dtype, inputs=a, - result_tensor=result_tens, + result_tensors=[result_tens], input_list=input_list, output_list=output_list, num_operands=num_operands, @@ -1250,7 +1250,7 @@ class TosaTestGen: output_dtype=result_tens.dtype, pad=padding, qinfo=qinfo, - result_tensor=result_tens, + result_tensors=[result_tens], input_list=input_list, output_list=output_list, num_operands=num_operands, @@ -1283,7 +1283,7 @@ class TosaTestGen: output_shape=result_tens.shape, input_dtype=a.dtype, output_dtype=result_tens.dtype, - result_tensor=result_tens, + result_tensors=[result_tens], input_list=input_list, output_list=output_list, num_operands=num_operands, @@ -1318,7 +1318,7 @@ class TosaTestGen: output_shape=result_tens.shape, input_dtype=a.dtype, output_dtype=result_tens.dtype, - result_tensor=result_tens, + result_tensors=[result_tens], input_list=input_list, output_list=output_list, num_operands=num_operands, @@ -1356,7 +1356,7 @@ class TosaTestGen: perms=perms, input_dtype=a.dtype, output_dtype=result_tens.dtype, - result_tensor=result_tens, + result_tensors=[result_tens], input_list=input_list, output_list=output_list, num_operands=num_operands, @@ -1391,7 +1391,7 @@ class TosaTestGen: output_dtype=result_tens.dtype, start=start, size=size, - result_tensor=result_tens, + result_tensors=[result_tens], input_list=input_list, output_list=output_list, num_operands=num_operands, @@ -1425,7 +1425,7 @@ class TosaTestGen: output_shape=result_tens.shape, input_dtype=a.dtype, output_dtype=result_tens.dtype, - result_tensor=result_tens, + result_tensors=[result_tens], input_list=input_list, output_list=output_list, num_operands=num_operands, @@ -1474,7 +1474,7 @@ class TosaTestGen: output_shape=result_tens.shape, input_dtype=values.dtype, output_dtype=result_tens.dtype, - result_tensor=result_tens, + result_tensors=[result_tens], input_list=input_list, output_list=output_list, num_operands=num_operands, @@ -1519,7 +1519,7 @@ class TosaTestGen: output_shape=result_tens.shape, input_dtype=values_in.dtype, output_dtype=result_tens.dtype, - result_tensor=result_tens, + result_tensors=[result_tens], input_list=input_list, output_list=output_list, num_operands=num_operands, @@ -1580,7 +1580,7 @@ class TosaTestGen: border=border, input_list=input_list, output_list=output_list, - result_tensor=result_tens, + result_tensors=[result_tens], num_operands=num_operands, ): return None @@ -1628,7 +1628,7 @@ class TosaTestGen: output_shape=result_tens.shape, input_dtype=val.dtype, output_dtype=result_tens.dtype, - result_tensor=result_tens, + result_tensors=[result_tens], input_list=input_list, output_list=output_list, num_operands=num_operands, @@ -1774,7 +1774,7 @@ class TosaTestGen: double_round=double_round, input_list=input_list, output_list=output_list, - result_tensor=result_tens, + result_tensors=[result_tens], num_operands=num_operands, ): return None @@ -2083,6 +2083,38 @@ class TosaTestGen: return acc_out + def build_rfft2d(self, op, val, validator_fcns=None, error_name=None): + results = OutputShaper.rfft2dOp(self.ser, self.rng, val, error_name) + + input_names = [val.name] + pCount, cCount = op["operands"] + num_operands = pCount + cCount + + output_names = [res.name for res in results] + output_dtypes = [res.dtype for res in results] + + input_names, output_names = TosaErrorIfArgGen.eiInvalidateInputOutputList( + self, error_name, input_names, output_names + ) + + if not TosaErrorValidator.evValidateErrorIfs( + self.ser, + validator_fcns, + error_name, + op=op, + input_shape=val.shape, + input_dtype=val.dtype, + output_dtype=output_dtypes, + result_tensors=results, + input_list=input_names, + output_list=output_names, + num_operands=num_operands, + ): + return None + + self.ser.addOperator(op["op"], input_names, output_names) + return results + def create_filter_lists( self, op, shapeFilter, rankFilter, dtypeFilter, testType, validator=None ): @@ -3897,6 +3929,27 @@ class TosaTestGen: TosaErrorValidator.evCondGraphOutputShapeNotSizeOne, ), }, + "rfft2d": { + "op": Op.RFFT2D, + "operands": (1, 0), + "rank": (3, 3), + "build_fcn": ( + build_rfft2d, + TosaTensorGen.tgRFFT2d, + TosaTensorValuesGen.tvgDefault, + TosaArgGen.agNone, + ), + "types": [DType.FP32], + "error_if_validators": ( + TosaErrorValidator.evWrongInputType, + TosaErrorValidator.evWrongOutputType, + TosaErrorValidator.evWrongInputList, + TosaErrorValidator.evWrongOutputList, + TosaErrorValidator.evWrongRank, + TosaErrorValidator.evBatchMismatch, + TosaErrorValidator.evKernelNotPowerOfTwo, + ), + }, } @@ -4717,3 +4770,26 @@ class OutputShaper: out_dtype = rng.choice(wrong_dtypes) return ser.addOutput(output_shape, out_dtype) + + @staticmethod + def rfft2dOp(serializer, rng, value, error_name=None): + outputs = [] + + input_shape = value.shape + if error_name != ErrorIf.WrongRank: + assert len(input_shape) == 3 + + output_shape = [*input_shape[:-1], input_shape[-1] // 2 + 1] + + output_dtype = value.dtype + if error_name == ErrorIf.WrongOutputType: + excludes = [DType.FP32] + wrong_dtypes = list(usableDTypes(excludes=excludes)) + output_dtype = rng.choice(wrong_dtypes) + elif error_name == ErrorIf.BatchMismatch: + incorrect_batch = input_shape[0] + rng.integers(1, 10) + output_shape = [incorrect_batch, *input_shape[1:]] + + outputs.append(serializer.addOutput(output_shape, output_dtype)) + outputs.append(serializer.addOutput(output_shape, output_dtype)) + return outputs -- cgit v1.2.1