diff options
author | Kevin Cheng <kevin.cheng@arm.com> | 2020-11-24 10:26:32 -0800 |
---|---|---|
committer | Kevin Cheng <kevin.cheng@arm.com> | 2020-11-24 14:06:04 -0800 |
commit | 77d0f76b42c6353459d3271ccba9e41728d066cd (patch) | |
tree | 70f95376de80e249a15b15dba6faa27fb4d48f45 /reference_model/src/ops/scatter_gather.h | |
parent | aee1facbde25caf27cc34e5ec08eb8bba6af8e18 (diff) | |
download | reference_model-77d0f76b42c6353459d3271ccba9e41728d066cd.tar.gz |
Update reference model/serialization library to 0.21.0 with unit tests added/updated
- update tosa.GATHER
- update tosa.RESIZE
- add tosa.SCATTER
Signed-off-by: Kevin Cheng <kevin.cheng@arm.com>
Change-Id: I1c3247b831a64e35a85c4044b24c6c29b8e18d25
Diffstat (limited to 'reference_model/src/ops/scatter_gather.h')
-rw-r--r-- | reference_model/src/ops/scatter_gather.h | 45 |
1 files changed, 32 insertions, 13 deletions
diff --git a/reference_model/src/ops/scatter_gather.h b/reference_model/src/ops/scatter_gather.h index d9b1263..17ea723 100644 --- a/reference_model/src/ops/scatter_gather.h +++ b/reference_model/src/ops/scatter_gather.h @@ -23,9 +23,7 @@ using namespace tosa; namespace TosaReference { -// input and index can have different rank -// and infer OutRank statically -template <int InRank, int IndexRank, DType Dtype> +template <DType Dtype> class OpGather : public GraphNode { public: @@ -35,18 +33,39 @@ public: virtual int checkTensorAttributes(); virtual int eval(); - static constexpr int OutRank = InRank - 1 + IndexRank; - using InEigenType = typename GetEigenType<Dtype>::type; - using OutEigenType = typename GetEigenType<Dtype>::type; - using TIn = Eigen::Tensor<InEigenType, InRank>; - using TIndex = Eigen::Tensor<int32_t, IndexRank>; - using TOut = Eigen::Tensor<OutEigenType, OutRank>; + using EigenType = typename GetEigenType<Dtype>::type; + using TValue = Eigen::Tensor<EigenType, 3>; + using TIndex = Eigen::Tensor<int32_t, 2>; + using TOutput = Eigen::Tensor<EigenType, 3>; protected: - TosaAxisAttribute* attribute; - TosaReference::TensorTemplate<TIn>* in; - TosaReference::TensorTemplate<TIndex>* index; - TosaReference::TensorTemplate<TOut>* out; + int32_t N, W, K, C; + TosaReference::TensorTemplate<TValue>* values; + TosaReference::TensorTemplate<TIndex>* indices; + TosaReference::TensorTemplate<TOutput>* output; +}; + +template <DType Dtype> +class OpScatter : public GraphNode +{ +public: + OpScatter(TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_); + virtual ~OpScatter(); + + virtual int checkTensorAttributes(); + virtual int eval(); + + using EigenType = typename GetEigenType<Dtype>::type; + using TValue = Eigen::Tensor<EigenType, 3>; + using TIndex = Eigen::Tensor<int32_t, 2>; + using TOutput = Eigen::Tensor<EigenType, 3>; + +protected: + int32_t N, W, K, C; + TosaReference::TensorTemplate<TValue>* values_in; + TosaReference::TensorTemplate<TIndex>* indices; + TosaReference::TensorTemplate<TValue>* input; + TosaReference::TensorTemplate<TOutput>* values_out; }; }; // namespace TosaReference |