From 77d0f76b42c6353459d3271ccba9e41728d066cd Mon Sep 17 00:00:00 2001 From: Kevin Cheng Date: Tue, 24 Nov 2020 10:26:32 -0800 Subject: Update reference model/serialization library to 0.21.0 with unit tests added/updated - update tosa.GATHER - update tosa.RESIZE - add tosa.SCATTER Signed-off-by: Kevin Cheng Change-Id: I1c3247b831a64e35a85c4044b24c6c29b8e18d25 --- reference_model/src/ops/scatter_gather.h | 45 +++++++++++++++++++++++--------- 1 file changed, 32 insertions(+), 13 deletions(-) (limited to 'reference_model/src/ops/scatter_gather.h') diff --git a/reference_model/src/ops/scatter_gather.h b/reference_model/src/ops/scatter_gather.h index d9b1263..17ea723 100644 --- a/reference_model/src/ops/scatter_gather.h +++ b/reference_model/src/ops/scatter_gather.h @@ -23,9 +23,7 @@ using namespace tosa; namespace TosaReference { -// input and index can have different rank -// and infer OutRank statically -template +template class OpGather : public GraphNode { public: @@ -35,18 +33,39 @@ public: virtual int checkTensorAttributes(); virtual int eval(); - static constexpr int OutRank = InRank - 1 + IndexRank; - using InEigenType = typename GetEigenType::type; - using OutEigenType = typename GetEigenType::type; - using TIn = Eigen::Tensor; - using TIndex = Eigen::Tensor; - using TOut = Eigen::Tensor; + using EigenType = typename GetEigenType::type; + using TValue = Eigen::Tensor; + using TIndex = Eigen::Tensor; + using TOutput = Eigen::Tensor; protected: - TosaAxisAttribute* attribute; - TosaReference::TensorTemplate* in; - TosaReference::TensorTemplate* index; - TosaReference::TensorTemplate* out; + int32_t N, W, K, C; + TosaReference::TensorTemplate* values; + TosaReference::TensorTemplate* indices; + TosaReference::TensorTemplate* output; +}; + +template +class OpScatter : public GraphNode +{ +public: + OpScatter(TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_); + virtual ~OpScatter(); + + virtual int checkTensorAttributes(); + virtual int eval(); + + using EigenType = typename GetEigenType::type; + using TValue = Eigen::Tensor; + using TIndex = Eigen::Tensor; + using TOutput = Eigen::Tensor; + +protected: + int32_t N, W, K, C; + TosaReference::TensorTemplate* values_in; + TosaReference::TensorTemplate* indices; + TosaReference::TensorTemplate* input; + TosaReference::TensorTemplate* values_out; }; }; // namespace TosaReference -- cgit v1.2.1