aboutsummaryrefslogtreecommitdiff
path: root/reference_model/src/ops/tensor_ops.cc
diff options
context:
space:
mode:
authorLuke Hutton <luke.hutton@arm.com>2023-01-10 14:50:31 +0000
committerLuke Hutton <luke.hutton@arm.com>2023-01-24 13:40:17 +0000
commit261b7b62b959a6c7312d810d9152069fdff69f3e (patch)
tree2be25cefa14cd21379a9fc6f6c499622b6de8bf8 /reference_model/src/ops/tensor_ops.cc
parentc253e64710f22016894c0e3ac4e9eb76d62cb2f9 (diff)
downloadreference_model-261b7b62b959a6c7312d810d9152069fdff69f3e.tar.gz
Add RFFT2d to the reference model
Includes: * RFFT2d reference implementation * TFLite framework tests * Basic TOSA tests * Serialization submodule upgrade with support for FFT/RFFT Signed-off-by: Luke Hutton <luke.hutton@arm.com> Change-Id: I2a687e9cf87fb62a26160ea52439ba9830bea36e
Diffstat (limited to 'reference_model/src/ops/tensor_ops.cc')
-rw-r--r--reference_model/src/ops/tensor_ops.cc138
1 files changed, 137 insertions, 1 deletions
diff --git a/reference_model/src/ops/tensor_ops.cc b/reference_model/src/ops/tensor_ops.cc
index b9ac94a..dff9e08 100644
--- a/reference_model/src/ops/tensor_ops.cc
+++ b/reference_model/src/ops/tensor_ops.cc
@@ -1,5 +1,5 @@
-// Copyright (c) 2020-2022, ARM Limited.
+// Copyright (c) 2020-2023, ARM Limited.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
@@ -1453,6 +1453,140 @@ int OpMaxPool2d<Dtype>::eval()
return GraphNode::eval();
}
+template <DType Dtype>
+OpRFFT2d<Dtype>::OpRFFT2d(SubgraphTraverser* sgt_,
+ TosaAttributeBase* attribute_,
+ uint64_t id_)
+ : GraphNode(sgt_, Op_RFFT2D, id_)
+{
+ setRequiredOperands(1, 2);
+ setRequiredRank(3);
+}
+
+template <DType Dtype>
+OpRFFT2d<Dtype>::~OpRFFT2d() {}
+
+
+template <DType Dtype>
+int OpRFFT2d<Dtype>::checkTensorAttributes()
+{
+ if (validateRequiredOperands())
+ return 1;
+
+ if (validateRequiredRank(inputs[0]) || validateRequiredRank(outputs[0]) ||
+ validateRequiredRank(outputs[1]))
+ {
+ return 1;
+ }
+
+ if (inputs[0]->matchType(*outputs[0]) || inputs[0]->matchType(*outputs[1]))
+ {
+ printNodeValidationError("OpRFFT2d: input and output tensor type mismatch");
+ return 1;
+ }
+
+ in = dynamic_cast<TosaReference::TensorTemplate<TIn>*>(inputs[0]);
+ out_real = dynamic_cast<TosaReference::TensorTemplate<TOut>*>(outputs[0]);
+ out_imag = dynamic_cast<TosaReference::TensorTemplate<TOut>*>(outputs[1]);
+
+ ASSERT_MEM(in && out_real && out_imag);
+
+ auto is_power_of_two = [](int32_t n) -> bool
+ {
+ return (n & (n-1)) == 0 && n > 0;
+ };
+
+ // Input shape: [N, H, W]
+ if (!is_power_of_two(in->getShape()[1]) || !is_power_of_two(in->getShape()[2]))
+ {
+ printNodeValidationError("OpRFFT2d: input height and width must be a power of two");
+ return 1;
+ }
+
+ // Output shape: [N, H, W / 2 + 1]
+ bool output_check = true;
+ for (int32_t i = 0; i < out_real->getRank(); i++)
+ {
+ if (out_real->getShape()[i] != out_imag->getShape()[i])
+ {
+ output_check = false;
+ break;
+ }
+ }
+ if (!output_check)
+ {
+ printNodeValidationError(
+ "OpRFFT2d: Mismatch between real output shape and imaginary output shape");
+ return 1;
+ }
+
+ if (in->getShape()[0] != out_real->getShape()[0]) {
+ printNodeValidationError("OpRFFT2d: input and output batch size don't match");
+ return 1;
+ }
+ if (in->getShape()[1] != out_real->getShape()[1]) {
+ printNodeValidationError("OpRFFT2d: input and output height don't match");
+ return 1;
+ }
+ if (in->getShape()[2] / 2 + 1 != out_real->getShape()[2]) {
+ printNodeValidationError("OpRFFT2d: output width is expected to match input width / 2 + 1");
+ return 1;
+ }
+
+ return 0;
+}
+
+template <DType Dtype>
+int OpRFFT2d<Dtype>::eval()
+{
+ int32_t in_batch = in->getShape()[0];
+ int32_t in_height = in->getShape()[1];
+ int32_t in_width = in->getShape()[2];
+
+ int32_t out_real_batch = out_real->getShape()[0];
+ int32_t out_real_height = out_real->getShape()[1];
+ int32_t out_real_width = out_real->getShape()[2];
+
+ int32_t out_imag_batch = out_imag->getShape()[0];
+ int32_t out_imag_height = out_imag->getShape()[1];
+ int32_t out_imag_width = out_imag->getShape()[2];
+
+ DEBUG_INFO(OP,
+ "perform OpRFFT2d, input.shape=[%d,%d,%d], output_real.shape=[%d,%d,%d], "
+ "output_imag.shape=[%d,%d,%d]",
+ in_batch, in_height, in_width,
+ out_real_batch, out_real_height, out_real_width,
+ out_imag_batch, out_imag_height, out_imag_width);
+
+ OutEigenType sum_real, sum_imag, a;
+
+ for (int n = 0; n < in_batch; n++)
+ {
+ for (int oy = 0; oy < out_real_height; oy++)
+ {
+ for (int ox = 0; ox < out_real_width; ox++)
+ {
+ sum_real = 0.0;
+ sum_imag = 0.0;
+ for (int iy = 0; iy < in_height; iy++)
+ {
+ for (int ix = 0; ix < in_width; ix++)
+ {
+ // Use explicit cast to ensure intermmediate calculations are completed using OutEigenType
+ a = 2 * M_PI * ((iy * (OutEigenType)oy) / in_height + (ix * (OutEigenType)ox) / in_width);
+ sum_real += this->in->getTensor()(n, iy, ix) * cos(a);
+ sum_imag += -this->in->getTensor()(n, iy, ix) * sin(a);
+ }
+ }
+ this->out_real->getTensor()(n, oy, ox) = sum_real;
+ this->out_imag->getTensor()(n, oy, ox) = sum_imag;
+ }
+ }
+ }
+
+ return GraphNode::eval();
+}
+
template <DType InDtype, DType WeightDtype, DType AccDtype>
OpTransposeConv2d<InDtype, WeightDtype, AccDtype>::OpTransposeConv2d(SubgraphTraverser* sgt_,
TosaAttributeBase* attribute_,
@@ -1738,6 +1872,8 @@ DEF_INSTANTIATE_ONE_TYPE(OpMaxPool2d, FP32);
DEF_INSTANTIATE_ONE_TYPE(OpMaxPool2d, INT8);
DEF_INSTANTIATE_ONE_TYPE(OpMaxPool2d, INT16);
+DEF_INSTANTIATE_ONE_TYPE(OpRFFT2d, FP32);
+
DEF_INSTANTIATE_TWO_TYPE_ONE_ACCUM(OpTransposeConv2d, FP16, FP16, FP16);
DEF_INSTANTIATE_TWO_TYPE_ONE_ACCUM(OpTransposeConv2d, FP16, FP16, FP32);
DEF_INSTANTIATE_TWO_TYPE_ONE_ACCUM(OpTransposeConv2d, BF16, BF16, FP32);