diff options
author | Luke Hutton <luke.hutton@arm.com> | 2023-02-06 14:54:18 +0000 |
---|---|---|
committer | Eric Kunze <eric.kunze@arm.com> | 2023-02-10 20:01:04 +0000 |
commit | 5728713fca4f6e2dff60dad3689e471545e563d2 (patch) | |
tree | 848421100f82a33ff57ee3205c369ad75737f7d3 /reference_model/src/ops | |
parent | c1e25f5755997e65ac1a360ec1e875db06040d8d (diff) | |
download | reference_model-5728713fca4f6e2dff60dad3689e471545e563d2.tar.gz |
Add FFT2d to the reference model
Includes:
* FFT2d reference implementation
* Basic TOSA tests
Change-Id: Ie79fcb713542345d550ec013646810c1e890e388
Signed-off-by: Luke Hutton <luke.hutton@arm.com>
Diffstat (limited to 'reference_model/src/ops')
-rw-r--r-- | reference_model/src/ops/op_factory.cc | 3 | ||||
-rw-r--r-- | reference_model/src/ops/tensor_ops.cc | 243 | ||||
-rw-r--r-- | reference_model/src/ops/tensor_ops.h | 23 |
3 files changed, 231 insertions, 38 deletions
diff --git a/reference_model/src/ops/op_factory.cc b/reference_model/src/ops/op_factory.cc index b1a405a..8d84135 100644 --- a/reference_model/src/ops/op_factory.cc +++ b/reference_model/src/ops/op_factory.cc @@ -89,6 +89,9 @@ GraphNode* OpFactory::newOp(SubgraphTraverser* sgt, DEF_FACTORY_THREE_TYPE(OpDepthwiseConv2d, INT8, INT8, INT32); DEF_FACTORY_THREE_TYPE(OpDepthwiseConv2d, INT16, INT8, INT48); break; + case Op_FFT2D: + DEF_FACTORY_ONE_TYPE(OpFFT2d, FP32); + break; case Op_FULLY_CONNECTED: DEF_FACTORY_THREE_TYPE(OpFullyConnected, FP16, FP16, FP16); DEF_FACTORY_THREE_TYPE(OpFullyConnected, FP16, FP16, FP32); diff --git a/reference_model/src/ops/tensor_ops.cc b/reference_model/src/ops/tensor_ops.cc index 4663c47..af808e8 100644 --- a/reference_model/src/ops/tensor_ops.cc +++ b/reference_model/src/ops/tensor_ops.cc @@ -238,6 +238,86 @@ int check_conv_attribute(tosa::TosaConvAttribute* attribute, return 0; } +int check_fft_shape(const std::vector<int32_t>& in_real, + const std::vector<int32_t>& in_imag, + const std::vector<int32_t>& out_real, + const std::vector<int32_t>& out_imag, + std::string& msg) { + const bool is_rfft = in_imag.empty(); + auto is_power_of_two = [](int32_t n) -> bool + { + return (n & (n-1)) == 0 && n > 0; + }; + + if (!is_power_of_two(in_real[1]) || !is_power_of_two(in_real[2])) + { + msg = "Input height and width must be a power of two"; + return 1; + } + + // RFFT does not have a second input + if (!is_rfft) + { + bool input_check = true; + for (size_t i = 0; i < in_real.size(); i++) + { + if (in_real[i] != in_imag[i]) + { + input_check = false; + break; + } + } + if (!input_check) + { + msg = "Mismatch between real input shape and imaginary input shape"; + return 1; + } + } + + bool output_check = true; + for (size_t i = 0; i < out_real.size(); i++) + { + if (out_real[i] != out_imag[i]) + { + output_check = false; + break; + } + } + if (!output_check) + { + msg = "Mismatch between real output shape and imaginary output shape"; + return 
1; + } + + if (in_real[0] != out_real[0]) + { + msg = "Input and output batch size don't match"; + return 1; + } + if (in_real[1] != out_real[1]) + { + msg = "Input and output height don't match"; + return 1; + } + + if (is_rfft) + { + if (in_real[2] / 2 + 1 != out_real[2]) + { + msg = "Output width is expected to match input width / 2 + 1"; + return 1; + } + } else { + if (in_real[2] != out_real[2]) + { + msg = "Input and output width don't match"; + return 1; + } + } + + return 0; +} + template <int Rank, DType Dtype> OpArgMax<Rank, Dtype>::OpArgMax(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, @@ -1448,82 +1528,167 @@ int OpMaxPool2d<Dtype>::eval() } template <DType Dtype> -OpRFFT2d<Dtype>::OpRFFT2d(SubgraphTraverser* sgt_, - TosaAttributeBase* attribute_, - uint64_t id_) - : GraphNode(sgt_, Op_RFFT2D, id_) +OpFFT2d<Dtype>::OpFFT2d(SubgraphTraverser* sgt_, + TosaAttributeBase* attribute_, + uint64_t id_) + : GraphNode(sgt_, Op_FFT2D, id_) { - setRequiredOperands(1, 2); + setRequiredOperands(2, 2); setRequiredRank(3); + + INIT_ATTRIBUTE(FFT); } template <DType Dtype> -OpRFFT2d<Dtype>::~OpRFFT2d() {} +OpFFT2d<Dtype>::~OpFFT2d() { + if (attribute) + delete attribute; +} template <DType Dtype> -int OpRFFT2d<Dtype>::checkTensorAttributes() +int OpFFT2d<Dtype>::checkTensorAttributes() { if (validateRequiredOperands()) return 1; - if (validateRequiredRank(inputs[0]) || validateRequiredRank(outputs[0]) || - validateRequiredRank(outputs[1])) + if (validateRequiredRank(inputs[0]) || validateRequiredRank(inputs[1]) || + validateRequiredRank(outputs[0]) || validateRequiredRank(outputs[1])) { return 1; } - if (inputs[0]->matchType(*outputs[0]) || inputs[0]->matchType(*outputs[1])) + if (inputs[0]->matchType(*outputs[0]) || inputs[1]->matchType(*outputs[1]) || + inputs[0]->matchType(*inputs[1])) { - printNodeValidationError("OpRFFT2d: input and output tensor type mismatch"); + printNodeValidationError("OpFFT2d: input and output tensor type mismatch"); return 1; } - 
in = dynamic_cast<TosaReference::TensorTemplate<TIn>*>(inputs[0]); + in_real = dynamic_cast<TosaReference::TensorTemplate<TIn>*>(inputs[0]); + in_imag = dynamic_cast<TosaReference::TensorTemplate<TIn>*>(inputs[1]); out_real = dynamic_cast<TosaReference::TensorTemplate<TOut>*>(outputs[0]); out_imag = dynamic_cast<TosaReference::TensorTemplate<TOut>*>(outputs[1]); - ASSERT_MEM(in && out_real && out_imag); + ASSERT_MEM(in_real && in_imag && out_real && out_imag); - auto is_power_of_two = [](int32_t n) -> bool - { - return (n & (n-1)) == 0 && n > 0; - }; - - // Input shape: [N, H, W] - if (!is_power_of_two(in->getShape()[1]) || !is_power_of_two(in->getShape()[2])) + std::string msg; + if (check_fft_shape(in_real->getShape(), in_imag->getShape(), + out_real->getShape(), out_imag->getShape(), msg)) { - printNodeValidationError("OpRFFT2d: input height and width must be a power of two"); + msg = "OpFFT2d: " + msg; + printNodeValidationError(msg.c_str()); return 1; } - // Output shape: [N, H, W / 2 + 1] - bool output_check = true; - for (int32_t i = 0; i < out_real->getRank(); i++) + return 0; +} + +template <DType Dtype> +int OpFFT2d<Dtype>::eval() +{ + int in_real_batch = this->in_real->getShape()[0]; + int in_real_height = this->in_real->getShape()[1]; + int in_real_width = this->in_real->getShape()[2]; + + int in_imag_batch = this->in_imag->getShape()[0]; + int in_imag_height = this->in_imag->getShape()[1]; + int in_imag_width = this->in_imag->getShape()[2]; + + int out_real_batch = this->out_real->getShape()[0]; + int out_real_height = this->out_real->getShape()[1]; + int out_real_width = this->out_real->getShape()[2]; + + int out_imag_batch = this->out_imag->getShape()[0]; + int out_imag_height = this->out_imag->getShape()[1]; + int out_imag_width = this->out_imag->getShape()[2]; + + DEBUG_INFO(OP, + "perform OpFFT2d, input.shapes=[[%d,%d,%d],[%d,%d,%d]], output.shapes=[[%d,%d,%d],[%d,%d,%d]]", + in_real_batch, in_real_height, in_real_width, + in_imag_batch, 
in_imag_height, in_imag_width, + out_real_batch, out_real_height, out_real_width, + out_imag_batch, out_imag_height, out_imag_width); + + OutEigenType sum_real, sum_imag, a, sign_val = 1.0; + + if (attribute->inverse()) { + sign_val = -1.0; + } + + for (int n = 0; n < in_real_batch; n++) { - if (out_real->getShape()[i] != out_imag->getShape()[i]) + for (int oy = 0; oy < out_real_height; oy++) { - output_check = false; - break; + for (int ox = 0; ox < out_real_width; ox++) + { + sum_real = 0.0; + sum_imag = 0.0; + for (int iy = 0; iy < in_real_height; iy++) + { + for (int ix = 0; ix < in_real_width; ix++) + { + OutEigenType val_real = this->in_real->getTensor()(n, iy, ix); + OutEigenType val_imag = this->in_imag->getTensor()(n, iy, ix); + // Use explicit cast to ensure intermediate calculations are completed using OutEigenType + a = sign_val * 2 * M_PI * ((iy * (OutEigenType)oy) / in_real_height + (ix * (OutEigenType)ox) / in_real_width); + sum_real += val_real * cos(a) + val_imag * sin(a); + sum_imag += -val_real * sin(a) + val_imag * cos(a); + } + } + this->out_real->getTensor()(n, oy, ox) = sum_real; + this->out_imag->getTensor()(n, oy, ox) = sum_imag; + } } } - if (!output_check) - { - printNodeValidationError( - "OpRFFT2d: Mismatch between real output shape and imaginary output shape"); + + return GraphNode::eval(); +} + +template <DType Dtype> +OpRFFT2d<Dtype>::OpRFFT2d(SubgraphTraverser* sgt_, + TosaAttributeBase* attribute_, + uint64_t id_) + : GraphNode(sgt_, Op_RFFT2D, id_) +{ + setRequiredOperands(1, 2); + setRequiredRank(3); +} + +template <DType Dtype> +OpRFFT2d<Dtype>::~OpRFFT2d() {} + + +template <DType Dtype> +int OpRFFT2d<Dtype>::checkTensorAttributes() +{ + if (validateRequiredOperands()) return 1; - } - if (in->getShape()[0] != out_real->getShape()[0]) { - printNodeValidationError("OpRFFT2d: input and output batch size don't match"); + if (validateRequiredRank(inputs[0]) || validateRequiredRank(outputs[0]) || + validateRequiredRank(outputs[1])) + 
{ return 1; } - if (in->getShape()[1] != out_real->getShape()[1]) { - printNodeValidationError("OpRFFT2d: input and output height don't match"); + + if (inputs[0]->matchType(*outputs[0]) || inputs[0]->matchType(*outputs[1])) + { + printNodeValidationError("OpRFFT2d: input and output tensor type mismatch"); return 1; } - if (in->getShape()[2] / 2 + 1 != out_real->getShape()[2]) { - printNodeValidationError("OpRFFT2d: output width is expected to match input width / 2 + 1"); + + in = dynamic_cast<TosaReference::TensorTemplate<TIn>*>(inputs[0]); + out_real = dynamic_cast<TosaReference::TensorTemplate<TOut>*>(outputs[0]); + out_imag = dynamic_cast<TosaReference::TensorTemplate<TOut>*>(outputs[1]); + + ASSERT_MEM(in && out_real && out_imag); + + std::string msg; + if (check_fft_shape(in->getShape(), {}, + out_real->getShape(), out_imag->getShape(), msg)) + { + msg = "OpRFFT2d: " + msg; + printNodeValidationError(msg.c_str()); return 1; } @@ -1843,6 +2008,8 @@ DEF_INSTANTIATE_THREE_TYPE(OpDepthwiseConv2d, INT8, INT4, INT32); DEF_INSTANTIATE_THREE_TYPE(OpDepthwiseConv2d, INT8, INT8, INT32); DEF_INSTANTIATE_THREE_TYPE(OpDepthwiseConv2d, INT16, INT8, INT48); +DEF_INSTANTIATE_ONE_TYPE(OpFFT2d, FP32); + DEF_INSTANTIATE_THREE_TYPE(OpFullyConnected, FP16, FP16, FP16); DEF_INSTANTIATE_THREE_TYPE(OpFullyConnected, FP16, FP16, FP32); DEF_INSTANTIATE_THREE_TYPE(OpFullyConnected, BF16, BF16, FP32); diff --git a/reference_model/src/ops/tensor_ops.h b/reference_model/src/ops/tensor_ops.h index 0d2b3eb..9ef4a58 100644 --- a/reference_model/src/ops/tensor_ops.h +++ b/reference_model/src/ops/tensor_ops.h @@ -249,6 +249,29 @@ protected: }; template <DType Dtype> +class OpFFT2d : public GraphNode +{ +public: + OpFFT2d(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, uint64_t id_); + virtual ~OpFFT2d(); + + virtual int checkTensorAttributes() final; + virtual int eval() final; + + using InEigenType = typename GetEigenType<Dtype>::type; + using OutEigenType = typename 
GetEigenType<Dtype>::type; + using TIn = Eigen::Tensor<InEigenType, 3>; + using TOut = Eigen::Tensor<OutEigenType, 3>; + +protected: + TosaReference::TensorTemplate<TIn>* in_real; + TosaReference::TensorTemplate<TIn>* in_imag; + TosaReference::TensorTemplate<TOut>* out_real; + TosaReference::TensorTemplate<TOut>* out_imag; + tosa::TosaFFTAttribute* attribute; +}; + +template <DType Dtype> class OpRFFT2d : public GraphNode { public: |