aboutsummaryrefslogtreecommitdiff
path: root/reference_model/src/ops/tensor_ops.h
diff options
context:
space:
mode:
Diffstat (limited to 'reference_model/src/ops/tensor_ops.h')
-rw-r--r--reference_model/src/ops/tensor_ops.h23
1 files changed, 23 insertions, 0 deletions
diff --git a/reference_model/src/ops/tensor_ops.h b/reference_model/src/ops/tensor_ops.h
index 0d2b3eb..9ef4a58 100644
--- a/reference_model/src/ops/tensor_ops.h
+++ b/reference_model/src/ops/tensor_ops.h
@@ -249,6 +249,29 @@ protected:
};
template <DType Dtype>
+class OpFFT2d : public GraphNode
+{
+public:
+ OpFFT2d(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, uint64_t id_);
+ virtual ~OpFFT2d();
+
+ virtual int checkTensorAttributes() final;
+ virtual int eval() final;
+
+ using InEigenType = typename GetEigenType<Dtype>::type;
+ using OutEigenType = typename GetEigenType<Dtype>::type;
+ using TIn = Eigen::Tensor<InEigenType, 3>;
+ using TOut = Eigen::Tensor<OutEigenType, 3>;
+
+protected:
+ TosaReference::TensorTemplate<TIn>* in_real;
+ TosaReference::TensorTemplate<TIn>* in_imag;
+ TosaReference::TensorTemplate<TOut>* out_real;
+ TosaReference::TensorTemplate<TOut>* out_imag;
+ tosa::TosaFFTAttribute* attribute;
+};
+
+template <DType Dtype>
class OpRFFT2d : public GraphNode
{
public: