diff options
Diffstat (limited to 'reference_model/src/ops/scatter_gather.h')
-rw-r--r-- | reference_model/src/ops/scatter_gather.h | 45 |
1 files changed, 32 insertions, 13 deletions
diff --git a/reference_model/src/ops/scatter_gather.h b/reference_model/src/ops/scatter_gather.h index d9b1263..17ea723 100644 --- a/reference_model/src/ops/scatter_gather.h +++ b/reference_model/src/ops/scatter_gather.h @@ -23,9 +23,7 @@ using namespace tosa; namespace TosaReference { -// input and index can have different rank -// and infer OutRank statically -template <int InRank, int IndexRank, DType Dtype> +template <DType Dtype> class OpGather : public GraphNode { public: @@ -35,18 +33,39 @@ public: virtual int checkTensorAttributes(); virtual int eval(); - static constexpr int OutRank = InRank - 1 + IndexRank; - using InEigenType = typename GetEigenType<Dtype>::type; - using OutEigenType = typename GetEigenType<Dtype>::type; - using TIn = Eigen::Tensor<InEigenType, InRank>; - using TIndex = Eigen::Tensor<int32_t, IndexRank>; - using TOut = Eigen::Tensor<OutEigenType, OutRank>; + using EigenType = typename GetEigenType<Dtype>::type; + using TValue = Eigen::Tensor<EigenType, 3>; + using TIndex = Eigen::Tensor<int32_t, 2>; + using TOutput = Eigen::Tensor<EigenType, 3>; protected: - TosaAxisAttribute* attribute; - TosaReference::TensorTemplate<TIn>* in; - TosaReference::TensorTemplate<TIndex>* index; - TosaReference::TensorTemplate<TOut>* out; + int32_t N, W, K, C; + TosaReference::TensorTemplate<TValue>* values; + TosaReference::TensorTemplate<TIndex>* indices; + TosaReference::TensorTemplate<TOutput>* output; +}; + +template <DType Dtype> +class OpScatter : public GraphNode +{ +public: + OpScatter(TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_); + virtual ~OpScatter(); + + virtual int checkTensorAttributes(); + virtual int eval(); + + using EigenType = typename GetEigenType<Dtype>::type; + using TValue = Eigen::Tensor<EigenType, 3>; + using TIndex = Eigen::Tensor<int32_t, 2>; + using TOutput = Eigen::Tensor<EigenType, 3>; + +protected: + int32_t N, W, K, C; + TosaReference::TensorTemplate<TValue>* values_in; + TosaReference::TensorTemplate<TIndex>* indices; + TosaReference::TensorTemplate<TValue>* input; + TosaReference::TensorTemplate<TOutput>* values_out; }; }; // namespace TosaReference |