From 5728713fca4f6e2dff60dad3689e471545e563d2 Mon Sep 17 00:00:00 2001 From: Luke Hutton Date: Mon, 6 Feb 2023 14:54:18 +0000 Subject: Add FFT2d to the reference model Includes: * FFT2d reference implementation * Basic TOSA tests Change-Id: Ie79fcb713542345d550ec013646810c1e890e388 Signed-off-by: Luke Hutton --- reference_model/src/ops/op_factory.cc | 3 + reference_model/src/ops/tensor_ops.cc | 243 ++++++++++++++++++++++++++++------ reference_model/src/ops/tensor_ops.h | 23 ++++ verif/generator/tosa_arg_gen.py | 48 +++++++ verif/generator/tosa_error_if.py | 64 ++++++++- verif/generator/tosa_test_gen.py | 111 +++++++++++++++- 6 files changed, 450 insertions(+), 42 deletions(-) diff --git a/reference_model/src/ops/op_factory.cc b/reference_model/src/ops/op_factory.cc index b1a405a..8d84135 100644 --- a/reference_model/src/ops/op_factory.cc +++ b/reference_model/src/ops/op_factory.cc @@ -89,6 +89,9 @@ GraphNode* OpFactory::newOp(SubgraphTraverser* sgt, DEF_FACTORY_THREE_TYPE(OpDepthwiseConv2d, INT8, INT8, INT32); DEF_FACTORY_THREE_TYPE(OpDepthwiseConv2d, INT16, INT8, INT48); break; + case Op_FFT2D: + DEF_FACTORY_ONE_TYPE(OpFFT2d, FP32); + break; case Op_FULLY_CONNECTED: DEF_FACTORY_THREE_TYPE(OpFullyConnected, FP16, FP16, FP16); DEF_FACTORY_THREE_TYPE(OpFullyConnected, FP16, FP16, FP32); diff --git a/reference_model/src/ops/tensor_ops.cc b/reference_model/src/ops/tensor_ops.cc index 4663c47..af808e8 100644 --- a/reference_model/src/ops/tensor_ops.cc +++ b/reference_model/src/ops/tensor_ops.cc @@ -238,6 +238,86 @@ int check_conv_attribute(tosa::TosaConvAttribute* attribute, return 0; } +int check_fft_shape(const std::vector& in_real, + const std::vector& in_imag, + const std::vector& out_real, + const std::vector& out_imag, + std::string& msg) { + const bool is_rfft = in_imag.empty(); + auto is_power_of_two = [](int32_t n) -> bool + { + return (n & (n-1)) == 0 && n > 0; + }; + + if (!is_power_of_two(in_real[1]) || !is_power_of_two(in_real[2])) + { + msg = "Input height 
and width must be a power of two"; + return 1; + } + + // RFFT does not have a second input + if (!is_rfft) + { + bool input_check = true; + for (size_t i = 0; i < in_real.size(); i++) + { + if (in_real[i] != in_imag[i]) + { + input_check = false; + break; + } + } + if (!input_check) + { + msg = "Mismatch between real input shape and imaginary input shape"; + return 1; + } + } + + bool output_check = true; + for (size_t i = 0; i < out_real.size(); i++) + { + if (out_real[i] != out_imag[i]) + { + output_check = false; + break; + } + } + if (!output_check) + { + msg = "Mismatch between real output shape and imaginary output shape"; + return 1; + } + + if (in_real[0] != out_real[0]) + { + msg = "Input and output batch size don't match"; + return 1; + } + if (in_real[1] != out_real[1]) + { + msg = "Input and output height don't match"; + return 1; + } + + if (is_rfft) + { + if (in_real[2] / 2 + 1 != out_real[2]) + { + msg = "Output width is expected to match input width / 2 + 1"; + return 1; + } + } else { + if (in_real[2] != out_real[2]) + { + msg = "Input and output width don't match"; + return 1; + } + } + + return 0; +} + template OpArgMax::OpArgMax(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, @@ -1448,82 +1528,167 @@ int OpMaxPool2d::eval() } template -OpRFFT2d::OpRFFT2d(SubgraphTraverser* sgt_, - TosaAttributeBase* attribute_, - uint64_t id_) - : GraphNode(sgt_, Op_RFFT2D, id_) +OpFFT2d::OpFFT2d(SubgraphTraverser* sgt_, + TosaAttributeBase* attribute_, + uint64_t id_) + : GraphNode(sgt_, Op_FFT2D, id_) { - setRequiredOperands(1, 2); + setRequiredOperands(2, 2); setRequiredRank(3); + + INIT_ATTRIBUTE(FFT); } template -OpRFFT2d::~OpRFFT2d() {} +OpFFT2d::~OpFFT2d() { + if (attribute) + delete attribute; +} template -int OpRFFT2d::checkTensorAttributes() +int OpFFT2d::checkTensorAttributes() { if (validateRequiredOperands()) return 1; - if (validateRequiredRank(inputs[0]) || validateRequiredRank(outputs[0]) || - validateRequiredRank(outputs[1])) + if 
(validateRequiredRank(inputs[0]) || validateRequiredRank(inputs[1]) || + validateRequiredRank(outputs[0]) || validateRequiredRank(outputs[1])) { return 1; } - if (inputs[0]->matchType(*outputs[0]) || inputs[0]->matchType(*outputs[1])) + if (inputs[0]->matchType(*outputs[0]) || inputs[1]->matchType(*outputs[1]) || + inputs[0]->matchType(*inputs[1])) { - printNodeValidationError("OpRFFT2d: input and output tensor type mismatch"); + printNodeValidationError("OpFFT2d: input and output tensor type mismatch"); return 1; } - in = dynamic_cast*>(inputs[0]); + in_real = dynamic_cast*>(inputs[0]); + in_imag = dynamic_cast*>(inputs[1]); out_real = dynamic_cast*>(outputs[0]); out_imag = dynamic_cast*>(outputs[1]); - ASSERT_MEM(in && out_real && out_imag); + ASSERT_MEM(in_real && in_imag && out_real && out_imag); - auto is_power_of_two = [](int32_t n) -> bool - { - return (n & (n-1)) == 0 && n > 0; - }; - - // Input shape: [N, H, W] - if (!is_power_of_two(in->getShape()[1]) || !is_power_of_two(in->getShape()[2])) + std::string msg; + if (check_fft_shape(in_real->getShape(), in_imag->getShape(), + out_real->getShape(), out_imag->getShape(), msg)) { - printNodeValidationError("OpRFFT2d: input height and width must be a power of two"); + msg = "OpFFT2d: " + msg; + printNodeValidationError(msg.c_str()); return 1; } - // Output shape: [N, H, W / 2 + 1] - bool output_check = true; - for (int32_t i = 0; i < out_real->getRank(); i++) + return 0; +} + +template +int OpFFT2d::eval() +{ + int in_real_batch = this->in_real->getShape()[0]; + int in_real_height = this->in_real->getShape()[1]; + int in_real_width = this->in_real->getShape()[2]; + + int in_imag_batch = this->in_imag->getShape()[0]; + int in_imag_height = this->in_imag->getShape()[1]; + int in_imag_width = this->in_imag->getShape()[2]; + + int out_real_batch = this->out_real->getShape()[0]; + int out_real_height = this->out_real->getShape()[1]; + int out_real_width = this->out_real->getShape()[2]; + + int out_imag_batch = 
this->out_imag->getShape()[0]; + int out_imag_height = this->out_imag->getShape()[1]; + int out_imag_width = this->out_imag->getShape()[2]; + + DEBUG_INFO(OP, + "perform OpFFT2d, input.shapes=[[%d,%d,%d],[%d,%d,%d]], output.shapes=[[%d,%d,%d],[%d,%d,%d]]", + in_real_batch, in_real_height, in_real_width, + in_imag_batch, in_imag_height, in_imag_width, + out_real_batch, out_real_height, out_real_width, + out_imag_batch, out_imag_height, out_imag_width); + + OutEigenType sum_real, sum_imag, a, sign_val = 1.0; + + if (attribute->inverse()) { + sign_val = -1.0; + } + + for (int n = 0; n < in_real_batch; n++) { - if (out_real->getShape()[i] != out_imag->getShape()[i]) + for (int oy = 0; oy < out_real_height; oy++) { - output_check = false; - break; + for (int ox = 0; ox < out_real_width; ox++) + { + sum_real = 0.0; + sum_imag = 0.0; + for (int iy = 0; iy < in_real_height; iy++) + { + for (int ix = 0; ix < in_real_width; ix++) + { + OutEigenType val_real = this->in_real->getTensor()(n, iy, ix); + OutEigenType val_imag = this->in_imag->getTensor()(n, iy, ix); + // Use explicit cast to ensure intermediate calculations are completed using OutEigenType + a = sign_val * 2 * M_PI * ((iy * (OutEigenType)oy) / in_real_height + (ix * (OutEigenType)ox) / in_real_width); + sum_real += val_real * cos(a) + val_imag * sin(a); + sum_imag += -val_real * sin(a) + val_imag * cos(a); + } + } + this->out_real->getTensor()(n, oy, ox) = sum_real; + this->out_imag->getTensor()(n, oy, ox) = sum_imag; + } } } - if (!output_check) - { - printNodeValidationError( - "OpRFFT2d: Mismatch between real output shape and imaginary output shape"); + + return GraphNode::eval(); +} + +template +OpRFFT2d::OpRFFT2d(SubgraphTraverser* sgt_, + TosaAttributeBase* attribute_, + uint64_t id_) + : GraphNode(sgt_, Op_RFFT2D, id_) +{ + setRequiredOperands(1, 2); + setRequiredRank(3); +} + +template +OpRFFT2d::~OpRFFT2d() {} + + +template +int OpRFFT2d::checkTensorAttributes() +{ + if (validateRequiredOperands()) 
return 1; - } - if (in->getShape()[0] != out_real->getShape()[0]) { - printNodeValidationError("OpRFFT2d: input and output batch size don't match"); + if (validateRequiredRank(inputs[0]) || validateRequiredRank(outputs[0]) || + validateRequiredRank(outputs[1])) + { return 1; } - if (in->getShape()[1] != out_real->getShape()[1]) { - printNodeValidationError("OpRFFT2d: input and output height don't match"); + + if (inputs[0]->matchType(*outputs[0]) || inputs[0]->matchType(*outputs[1])) + { + printNodeValidationError("OpRFFT2d: input and output tensor type mismatch"); return 1; } - if (in->getShape()[2] / 2 + 1 != out_real->getShape()[2]) { - printNodeValidationError("OpRFFT2d: output width is expected to match input width / 2 + 1"); + + in = dynamic_cast*>(inputs[0]); + out_real = dynamic_cast*>(outputs[0]); + out_imag = dynamic_cast*>(outputs[1]); + + ASSERT_MEM(in && out_real && out_imag); + + std::string msg; + if (check_fft_shape(in->getShape(), {}, + out_real->getShape(), out_imag->getShape(), msg)) + { + msg = "OpRFFT2d: " + msg; + printNodeValidationError(msg.c_str()); return 1; } @@ -1843,6 +2008,8 @@ DEF_INSTANTIATE_THREE_TYPE(OpDepthwiseConv2d, INT8, INT4, INT32); DEF_INSTANTIATE_THREE_TYPE(OpDepthwiseConv2d, INT8, INT8, INT32); DEF_INSTANTIATE_THREE_TYPE(OpDepthwiseConv2d, INT16, INT8, INT48); +DEF_INSTANTIATE_ONE_TYPE(OpFFT2d, FP32); + DEF_INSTANTIATE_THREE_TYPE(OpFullyConnected, FP16, FP16, FP16); DEF_INSTANTIATE_THREE_TYPE(OpFullyConnected, FP16, FP16, FP32); DEF_INSTANTIATE_THREE_TYPE(OpFullyConnected, BF16, BF16, FP32); diff --git a/reference_model/src/ops/tensor_ops.h b/reference_model/src/ops/tensor_ops.h index 0d2b3eb..9ef4a58 100644 --- a/reference_model/src/ops/tensor_ops.h +++ b/reference_model/src/ops/tensor_ops.h @@ -248,6 +248,29 @@ protected: tosa::TosaPoolAttribute* attribute; }; +template +class OpFFT2d : public GraphNode +{ +public: + OpFFT2d(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, uint64_t id_); + virtual ~OpFFT2d(); + + 
virtual int checkTensorAttributes() final; + virtual int eval() final; + + using InEigenType = typename GetEigenType::type; + using OutEigenType = typename GetEigenType::type; + using TIn = Eigen::Tensor; + using TOut = Eigen::Tensor; + +protected: + TosaReference::TensorTemplate* in_real; + TosaReference::TensorTemplate* in_imag; + TosaReference::TensorTemplate* out_real; + TosaReference::TensorTemplate* out_imag; + tosa::TosaFFTAttribute* attribute; +}; + template class OpRFFT2d : public GraphNode { diff --git a/verif/generator/tosa_arg_gen.py b/verif/generator/tosa_arg_gen.py index 05a7d2b..370570c 100644 --- a/verif/generator/tosa_arg_gen.py +++ b/verif/generator/tosa_arg_gen.py @@ -416,6 +416,45 @@ class TosaTensorGen: return [ifm_shape, filter_shape, bias_shape] + @staticmethod + def tgFFT2d(testGen, op, rank, error_name=None): + pl, const = op["operands"] + + if error_name != ErrorIf.WrongRank: + assert rank == 3 + assert pl == 2 and const == 0 + + # IFM dimensions are NHW + ifm_shape = testGen.makeShape(rank) + + # Select nearest lower power of two from input height and width + ifm_shape[1] = 2 ** int(math.log(ifm_shape[1], 2)) + ifm_shape[2] = 2 ** int(math.log(ifm_shape[2], 2)) + + # Constrict the overall size of the shape when creating ERROR_IF tests + if error_name: + ifm_shape = TosaErrorIfArgGen.eiRestrictDimensions(ifm_shape) + + # Generate an invalid kernel that is not a power of two + if error_name == ErrorIf.KernelNotPowerOfTwo: + inc_h = 2 if ifm_shape[1] == 1 else 1 + inc_w = 2 if ifm_shape[2] == 1 else 1 + inc_choices = [(inc_h, 0), (0, inc_w), (inc_h, inc_w)] + selected_inc = testGen.rng.choice(inc_choices) + ifm_shape[1] += selected_inc[0] + ifm_shape[2] += selected_inc[1] + + ifm_shape = testGen.constrictBatchSize(ifm_shape) + + ifm_shapes = [ifm_shape.copy(), ifm_shape.copy()] + if error_name == ErrorIf.FFTInputShapeMismatch: + modify_shape = testGen.rng.choice([0, 1]) + # Only modify kernel (H, W) + modify_dim = testGen.rng.choice([1, 2]) 
+ ifm_shapes[modify_shape][modify_dim] *= 2 + + return [ifm_shapes[0], ifm_shapes[1]] + @staticmethod def tgRFFT2d(testGen, op, rank, error_name=None): pl, const = op["operands"] @@ -1613,6 +1652,15 @@ class TosaArgGen: return arg_list + @staticmethod + def agFFT2d(testGen, opName, shapeList, dtype, error_name=None): + arg_list = [] + + arg_list.append(("inverseTrue", [True])) + arg_list.append(("inverseFalse", [False])) + + return arg_list + # Helper function for reshape. Gets some factors of a larger number. @staticmethod def getFactors(val, start=1): diff --git a/verif/generator/tosa_error_if.py b/verif/generator/tosa_error_if.py index 93f975d..ee227b3 100644 --- a/verif/generator/tosa_error_if.py +++ b/verif/generator/tosa_error_if.py @@ -79,6 +79,8 @@ class ErrorIf(object): CondIfCondShapeNotSizeOne = "CondIfCondShapeNotSizeOne" CondGraphOutputShapeNotSizeOne = "CondGraphOutputShapeNotSizeOne" KernelNotPowerOfTwo = "KernelNotPowerOfTwo" + FFTInputShapeMismatch = "FFTInputShapeMismatch" + FFTOutputShapeMismatch = "FFTOutputShapeMismatch" class TosaErrorIfArgGen: @@ -562,7 +564,7 @@ class TosaErrorValidator: ): error_result = True - elif op["op"] == Op.RFFT2D: + elif op["op"] in [Op.FFT2D, Op.RFFT2D]: if not all([ty == input_dtype for ty in output_dtype]): error_result = True @@ -686,7 +688,7 @@ class TosaErrorValidator: op = kwargs["op"] output_list = kwargs["output_list"] expected_length = 1 - if op["op"] == Op.RFFT2D: + if op["op"] in [Op.FFT2D, Op.RFFT2D]: expected_length = 2 if len(output_list) != expected_length: @@ -2446,6 +2448,64 @@ class TosaErrorValidator: } return info_dict + @staticmethod + def evFFTInputShapeMismatch(check=False, **kwargs): + error_name = ErrorIf.FFTInputShapeMismatch + param_reqs = {"rank": None, "dtype": None, "shape": None} + error_result = False + error_reason = "Mismatch between real and imaginary input shapes" + + if check: + input1 = kwargs["input1"] + input2 = kwargs["input2"] + + if input1.shape != input2.shape: + 
error_result = True + + info_dict = { + "error_name": error_name, + "error_result": error_result, + "error_reason": error_reason, + "param_reqs": param_reqs, + } + return info_dict + + @staticmethod + def evFFTOutputShapeMismatch(check=False, **kwargs): + error_name = ErrorIf.FFTOutputShapeMismatch + param_reqs = {"rank": None, "dtype": None, "shape": None} + error_result = False + error_reason = ( + "Mismatch between provided and expected output kernel (H, W) shape" + ) + + if check: + op = kwargs["op"] + input_shape = kwargs["input_shape"] + + if len(input_shape) == 3: + output_shapes = kwargs["output_shape"] + + # Ignoring batch size (N) from input shape + expected_shape = input_shape[1:] + if op["op"] == Op.RFFT2D: + expected_shape[1] = expected_shape[1] // 2 + 1 + + # Ignoring batch size (N) from output shapes + output_shape_0 = output_shapes[0][1:] + output_shape_1 = output_shapes[1][1:] + # Ensure the kernel sizes (H, W) of both outputs match the expected + if output_shape_0 != output_shape_1 or output_shape_0 != expected_shape: + error_result = True + + info_dict = { + "error_name": error_name, + "error_result": error_result, + "error_reason": error_reason, + "param_reqs": param_reqs, + } + return info_dict + + class TosaInvalidValidator: @staticmethod diff --git a/verif/generator/tosa_test_gen.py b/verif/generator/tosa_test_gen.py index 5f9e2c1..2b762aa 100644 --- a/verif/generator/tosa_test_gen.py +++ b/verif/generator/tosa_test_gen.py @@ -213,6 +213,12 @@ class TosaTestGen: else: raise Exception(f"Unknown dtype, cannot determine width: {dtype}") + def constrictBatchSize(self, shape): + # Limit the batch size unless an explicit target shape is set + if self.args.max_batch_size and not self.args.target_shapes: + shape[0] = min(shape[0], self.args.max_batch_size) + return shape + # Argument generators # Returns a list of tuples (stringDescriptor, [build_fcn_arg_list]) # Where the string descriptor is used to generate the test name and @@ -2081,6 +2087,48 @@ 
class TosaTestGen: return acc_out + def build_fft2d( + self, op, val1, val2, inverse, validator_fcns=None, error_name=None + ): + results = OutputShaper.fft2dOp(self.ser, self.rng, val1, val2, error_name) + + input_names = [val1.name, val2.name] + pCount, cCount = op["operands"] + num_operands = pCount + cCount + + output_names = [res.name for res in results] + output_shapes = [res.shape for res in results] + output_dtypes = [res.dtype for res in results] + + input_names, output_names = TosaErrorIfArgGen.eiInvalidateInputOutputList( + self, error_name, input_names, output_names + ) + + if not TosaErrorValidator.evValidateErrorIfs( + self.ser, + validator_fcns, + error_name, + op=op, + inverse=inverse, + input1=val1, + input2=val2, + input_shape=val1.shape, + input_dtype=val1.dtype, + output_shape=output_shapes, + output_dtype=output_dtypes, + result_tensors=results, + input_list=input_names, + output_list=output_names, + num_operands=num_operands, + ): + return None + + attr = ts.TosaSerializerAttribute() + attr.FFTAttribute(inverse) + + self.ser.addOperator(op["op"], input_names, output_names, attr) + return results + def build_rfft2d(self, op, val, validator_fcns=None, error_name=None): results = OutputShaper.rfft2dOp(self.ser, self.rng, val, error_name) @@ -2089,6 +2137,7 @@ class TosaTestGen: num_operands = pCount + cCount output_names = [res.name for res in results] + output_shapes = [res.shape for res in results] output_dtypes = [res.dtype for res in results] input_names, output_names = TosaErrorIfArgGen.eiInvalidateInputOutputList( @@ -2102,6 +2151,7 @@ class TosaTestGen: op=op, input_shape=val.shape, input_dtype=val.dtype, + output_shape=output_shapes, output_dtype=output_dtypes, result_tensors=results, input_list=input_names, @@ -3927,6 +3977,29 @@ class TosaTestGen: TosaErrorValidator.evCondGraphOutputShapeNotSizeOne, ), }, + "fft2d": { + "op": Op.FFT2D, + "operands": (2, 0), + "rank": (3, 3), + "build_fcn": ( + build_fft2d, + TosaTensorGen.tgFFT2d, + 
TosaTensorValuesGen.tvgDefault, + TosaArgGen.agFFT2d, + ), + "types": [DType.FP32], + "error_if_validators": ( + TosaErrorValidator.evWrongInputType, + TosaErrorValidator.evWrongOutputType, + TosaErrorValidator.evWrongInputList, + TosaErrorValidator.evWrongOutputList, + TosaErrorValidator.evWrongRank, + TosaErrorValidator.evBatchMismatch, + TosaErrorValidator.evKernelNotPowerOfTwo, + TosaErrorValidator.evFFTInputShapeMismatch, + TosaErrorValidator.evFFTOutputShapeMismatch, + ), + }, "rfft2d": { "op": Op.RFFT2D, "operands": (1, 0), @@ -3946,6 +4019,7 @@ class TosaTestGen: TosaErrorValidator.evWrongRank, TosaErrorValidator.evBatchMismatch, TosaErrorValidator.evKernelNotPowerOfTwo, + TosaErrorValidator.evFFTOutputShapeMismatch, ), }, } @@ -4769,6 +4843,37 @@ class OutputShaper: return ser.addOutput(output_shape, out_dtype) + @staticmethod + def fft2dOp(serializer, rng, ifm1, ifm2, error_name=None): + outputs = [] + + assert ifm1.dtype == ifm2.dtype + input_dtype = ifm1.dtype + + if error_name != ErrorIf.FFTInputShapeMismatch: + assert ifm1.shape == ifm2.shape + + input_shape = ifm1.shape + if error_name != ErrorIf.WrongRank: + assert len(input_shape) == 3 + + output_shape = input_shape.copy() + output_dtype = input_dtype + + if error_name == ErrorIf.WrongOutputType: + excludes = [DType.FP32] + wrong_dtypes = list(usableDTypes(excludes=excludes)) + output_dtype = rng.choice(wrong_dtypes) + elif error_name == ErrorIf.BatchMismatch: + output_shape[0] += rng.integers(1, 10) + elif error_name == ErrorIf.FFTOutputShapeMismatch: + modify_dim = rng.choice([1, 2]) + output_shape[modify_dim] += rng.integers(1, 10) + + outputs.append(serializer.addOutput(output_shape, output_dtype)) + outputs.append(serializer.addOutput(output_shape, output_dtype)) + return outputs + @staticmethod def rfft2dOp(serializer, rng, value, error_name=None): outputs = [] @@ -4785,8 +4890,10 @@ class OutputShaper: wrong_dtypes = list(usableDTypes(excludes=excludes)) output_dtype = 
rng.choice(wrong_dtypes) elif error_name == ErrorIf.BatchMismatch: - incorrect_batch = input_shape[0] + rng.integers(1, 10) - output_shape = [incorrect_batch, *input_shape[1:]] + output_shape[0] += rng.integers(1, 10) + elif error_name == ErrorIf.FFTOutputShapeMismatch: + modify_dim = rng.choice([1, 2]) + output_shape[modify_dim] += rng.integers(1, 10) outputs.append(serializer.addOutput(output_shape, output_dtype)) outputs.append(serializer.addOutput(output_shape, output_dtype)) -- cgit v1.2.1