From 261b7b62b959a6c7312d810d9152069fdff69f3e Mon Sep 17 00:00:00 2001 From: Luke Hutton Date: Tue, 10 Jan 2023 14:50:31 +0000 Subject: Add RFFT2d to the reference model Includes: * RFFT2d reference implementation * TFLite framework tests * Basic TOSA tests * Serialization submodule upgrade with support for FFT/RFFT Signed-off-by: Luke Hutton Change-Id: I2a687e9cf87fb62a26160ea52439ba9830bea36e --- reference_model/src/ops/tensor_ops.cc | 138 +++++++++++++++++++++++++++++++++- 1 file changed, 137 insertions(+), 1 deletion(-) (limited to 'reference_model/src/ops/tensor_ops.cc') diff --git a/reference_model/src/ops/tensor_ops.cc b/reference_model/src/ops/tensor_ops.cc index b9ac94a..dff9e08 100644 --- a/reference_model/src/ops/tensor_ops.cc +++ b/reference_model/src/ops/tensor_ops.cc @@ -1,5 +1,5 @@ -// Copyright (c) 2020-2022, ARM Limited. +// Copyright (c) 2020-2023, ARM Limited. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -1453,6 +1453,140 @@ int OpMaxPool2d::eval() return GraphNode::eval(); } +template +OpRFFT2d::OpRFFT2d(SubgraphTraverser* sgt_, + TosaAttributeBase* attribute_, + uint64_t id_) + : GraphNode(sgt_, Op_RFFT2D, id_) +{ + setRequiredOperands(1, 2); + setRequiredRank(3); +} + +template +OpRFFT2d::~OpRFFT2d() {} + + +template +int OpRFFT2d::checkTensorAttributes() +{ + if (validateRequiredOperands()) + return 1; + + if (validateRequiredRank(inputs[0]) || validateRequiredRank(outputs[0]) || + validateRequiredRank(outputs[1])) + { + return 1; + } + + if (inputs[0]->matchType(*outputs[0]) || inputs[0]->matchType(*outputs[1])) + { + printNodeValidationError("OpRFFT2d: input and output tensor type mismatch"); + return 1; + } + + in = dynamic_cast*>(inputs[0]); + out_real = dynamic_cast*>(outputs[0]); + out_imag = dynamic_cast*>(outputs[1]); + + ASSERT_MEM(in && out_real && out_imag); + + auto is_power_of_two = [](int32_t n) -> bool + { + return (n & (n-1)) == 0 && n > 0; + }; + + // Input shape: [N, H, W] + if (!is_power_of_two(in->getShape()[1]) || !is_power_of_two(in->getShape()[2])) + { + printNodeValidationError("OpRFFT2d: input height and width must be a power of two"); + return 1; + } + + // Output shape: [N, H, W / 2 + 1] + bool output_check = true; + for (int32_t i = 0; i < out_real->getRank(); i++) + { + if (out_real->getShape()[i] != out_imag->getShape()[i]) + { + output_check = false; + break; + } + } + if (!output_check) + { + printNodeValidationError( + "OpRFFT2d: Mismatch between real output shape and imaginary output shape"); + return 1; + } + + if (in->getShape()[0] != out_real->getShape()[0]) { + printNodeValidationError("OpRFFT2d: input and output batch size don't match"); + return 1; + } + if (in->getShape()[1] != out_real->getShape()[1]) { + printNodeValidationError("OpRFFT2d: input and output height don't match"); + return 1; + } + if (in->getShape()[2] / 2 + 1 != out_real->getShape()[2]) { + printNodeValidationError("OpRFFT2d: output width is expected to match input width / 2 + 1"); + return 1; + } + + return 0; +} + +template +int OpRFFT2d::eval() +{ + int32_t in_batch = in->getShape()[0]; + int32_t in_height = in->getShape()[1]; + int32_t in_width = in->getShape()[2]; + + int32_t out_real_batch = out_real->getShape()[0]; + int32_t out_real_height = out_real->getShape()[1]; + int32_t out_real_width = out_real->getShape()[2]; + + int32_t out_imag_batch = out_imag->getShape()[0]; + int32_t out_imag_height = out_imag->getShape()[1]; + int32_t out_imag_width = out_imag->getShape()[2]; + + DEBUG_INFO(OP, + "perform OpRFFT2d, input.shape=[%d,%d,%d], output_real.shape=[%d,%d,%d], " + "output_imag.shape=[%d,%d,%d]", + in_batch, in_height, in_width, + out_real_batch, out_real_height, out_real_width, + out_imag_batch, out_imag_height, out_imag_width); + + OutEigenType sum_real, sum_imag, a; + + for (int n = 0; n < in_batch; n++) + { + for (int oy = 0; oy < out_real_height; oy++) + { + for (int ox = 0; ox < out_real_width; ox++) + { + sum_real = 0.0; + sum_imag = 0.0; + for (int iy = 0; iy < in_height; iy++) + { + for (int ix = 0; ix < in_width; ix++) + { + // Use explicit cast to ensure intermmediate calculations are completed using OutEigenType + a = 2 * M_PI * ((iy * (OutEigenType)oy) / in_height + (ix * (OutEigenType)ox) / in_width); + sum_real += this->in->getTensor()(n, iy, ix) * cos(a); + sum_imag += -this->in->getTensor()(n, iy, ix) * sin(a); + } + } + this->out_real->getTensor()(n, oy, ox) = sum_real; + this->out_imag->getTensor()(n, oy, ox) = sum_imag; + } + } + } + + return GraphNode::eval(); +} + template OpTransposeConv2d::OpTransposeConv2d(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, @@ -1738,6 +1872,8 @@ DEF_INSTANTIATE_ONE_TYPE(OpMaxPool2d, FP32); DEF_INSTANTIATE_ONE_TYPE(OpMaxPool2d, INT8); DEF_INSTANTIATE_ONE_TYPE(OpMaxPool2d, INT16); +DEF_INSTANTIATE_ONE_TYPE(OpRFFT2d, FP32); + DEF_INSTANTIATE_TWO_TYPE_ONE_ACCUM(OpTransposeConv2d, FP16, FP16, FP16); DEF_INSTANTIATE_TWO_TYPE_ONE_ACCUM(OpTransposeConv2d, FP16, FP16, FP32); DEF_INSTANTIATE_TWO_TYPE_ONE_ACCUM(OpTransposeConv2d, BF16, BF16, FP32); -- cgit v1.2.1