aboutsummaryrefslogtreecommitdiff
path: root/reference_model/src/ops/tensor_ops.h
diff options
context:
space:
mode:
authorLuke Hutton <luke.hutton@arm.com>2023-02-06 14:54:18 +0000
committerEric Kunze <eric.kunze@arm.com>2023-02-10 20:01:04 +0000
commit5728713fca4f6e2dff60dad3689e471545e563d2 (patch)
tree848421100f82a33ff57ee3205c369ad75737f7d3 /reference_model/src/ops/tensor_ops.h
parentc1e25f5755997e65ac1a360ec1e875db06040d8d (diff)
downloadreference_model-5728713fca4f6e2dff60dad3689e471545e563d2.tar.gz
Add FFT2d to the reference model
Includes: * FFT2d reference implementation * Basic TOSA tests Change-Id: Ie79fcb713542345d550ec013646810c1e890e388 Signed-off-by: Luke Hutton <luke.hutton@arm.com>
Diffstat (limited to 'reference_model/src/ops/tensor_ops.h')
-rw-r--r--reference_model/src/ops/tensor_ops.h23
1 files changed, 23 insertions, 0 deletions
diff --git a/reference_model/src/ops/tensor_ops.h b/reference_model/src/ops/tensor_ops.h
index 0d2b3eb..9ef4a58 100644
--- a/reference_model/src/ops/tensor_ops.h
+++ b/reference_model/src/ops/tensor_ops.h
@@ -249,6 +249,29 @@ protected:
};
template <DType Dtype>
+class OpFFT2d : public GraphNode
+{
+public:
+ OpFFT2d(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, uint64_t id_);
+ virtual ~OpFFT2d();
+
+ virtual int checkTensorAttributes() final;
+ virtual int eval() final;
+
+ using InEigenType = typename GetEigenType<Dtype>::type;
+ using OutEigenType = typename GetEigenType<Dtype>::type;
+ using TIn = Eigen::Tensor<InEigenType, 3>;
+ using TOut = Eigen::Tensor<OutEigenType, 3>;
+
+protected:
+ TosaReference::TensorTemplate<TIn>* in_real;
+ TosaReference::TensorTemplate<TIn>* in_imag;
+ TosaReference::TensorTemplate<TOut>* out_real;
+ TosaReference::TensorTemplate<TOut>* out_imag;
+ tosa::TosaFFTAttribute* attribute;
+};
+
+template <DType Dtype>
class OpRFFT2d : public GraphNode
{
public: