diff options
author | Luke Hutton <luke.hutton@arm.com> | 2023-01-10 14:50:31 +0000 |
---|---|---|
committer | Luke Hutton <luke.hutton@arm.com> | 2023-01-24 13:40:17 +0000 |
commit | 261b7b62b959a6c7312d810d9152069fdff69f3e (patch) | |
tree | 2be25cefa14cd21379a9fc6f6c499622b6de8bf8 /reference_model/src/ops/tensor_ops.h | |
parent | c253e64710f22016894c0e3ac4e9eb76d62cb2f9 (diff) | |
download | reference_model-261b7b62b959a6c7312d810d9152069fdff69f3e.tar.gz |
Add RFFT2d to the reference model
Includes:
* RFFT2d reference implementation
* TFLite framework tests
* Basic TOSA tests
* Serialization submodule upgrade with support for FFT/RFFT
Signed-off-by: Luke Hutton <luke.hutton@arm.com>
Change-Id: I2a687e9cf87fb62a26160ea52439ba9830bea36e
Diffstat (limited to 'reference_model/src/ops/tensor_ops.h')
-rw-r--r-- | reference_model/src/ops/tensor_ops.h | 23 |
1 files changed, 22 insertions, 1 deletions
diff --git a/reference_model/src/ops/tensor_ops.h b/reference_model/src/ops/tensor_ops.h index fd6dd25..ed9a55c 100644 --- a/reference_model/src/ops/tensor_ops.h +++ b/reference_model/src/ops/tensor_ops.h @@ -1,5 +1,5 @@ -// Copyright (c) 2020-2022, ARM Limited. +// Copyright (c) 2020-2023, ARM Limited. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -248,6 +248,27 @@ protected: tosa::TosaPoolAttribute* attribute; }; +template <DType Dtype> +class OpRFFT2d : public GraphNode +{ +public: + OpRFFT2d(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, uint64_t id_); + virtual ~OpRFFT2d(); + + virtual int checkTensorAttributes() final; + virtual int eval() final; + + using InEigenType = typename GetEigenType<Dtype>::type; + using OutEigenType = typename GetEigenType<Dtype>::type; + using TIn = Eigen::Tensor<InEigenType, 3>; + using TOut = Eigen::Tensor<OutEigenType, 3>; + +protected: + TosaReference::TensorTemplate<TIn>* in; + TosaReference::TensorTemplate<TOut>* out_real; + TosaReference::TensorTemplate<TOut>* out_imag; +}; + template <DType InDtype, DType WeightDtype, DType AccDtype> class OpTransposeConv2d : public GraphNode { |