diff options
Diffstat (limited to 'reference_model/src/ops/tensor_ops.h')
-rw-r--r-- | reference_model/src/ops/tensor_ops.h | 23 |
1 files changed, 23 insertions, 0 deletions
diff --git a/reference_model/src/ops/tensor_ops.h b/reference_model/src/ops/tensor_ops.h index 0d2b3eb..9ef4a58 100644 --- a/reference_model/src/ops/tensor_ops.h +++ b/reference_model/src/ops/tensor_ops.h @@ -249,6 +249,29 @@ protected: }; template <DType Dtype> +class OpFFT2d : public GraphNode +{ +public: + OpFFT2d(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, uint64_t id_); + virtual ~OpFFT2d(); + + virtual int checkTensorAttributes() final; + virtual int eval() final; + + using InEigenType = typename GetEigenType<Dtype>::type; + using OutEigenType = typename GetEigenType<Dtype>::type; + using TIn = Eigen::Tensor<InEigenType, 3>; + using TOut = Eigen::Tensor<OutEigenType, 3>; + +protected: + TosaReference::TensorTemplate<TIn>* in_real; + TosaReference::TensorTemplate<TIn>* in_imag; + TosaReference::TensorTemplate<TOut>* out_real; + TosaReference::TensorTemplate<TOut>* out_imag; + tosa::TosaFFTAttribute* attribute; +}; + +template <DType Dtype> class OpRFFT2d : public GraphNode { public: |