aboutsummaryrefslogtreecommitdiff
path: root/reference_model/src/ops/scatter_gather.h
diff options
context:
space:
mode:
Diffstat (limited to 'reference_model/src/ops/scatter_gather.h')
-rw-r--r--reference_model/src/ops/scatter_gather.h45
1 files changed, 32 insertions, 13 deletions
diff --git a/reference_model/src/ops/scatter_gather.h b/reference_model/src/ops/scatter_gather.h
index d9b1263..17ea723 100644
--- a/reference_model/src/ops/scatter_gather.h
+++ b/reference_model/src/ops/scatter_gather.h
@@ -23,9 +23,7 @@ using namespace tosa;
namespace TosaReference
{
-// input and index can have different rank
-// and infer OutRank statically
-template <int InRank, int IndexRank, DType Dtype>
+template <DType Dtype>
class OpGather : public GraphNode
{
public:
@@ -35,18 +33,39 @@ public:
virtual int checkTensorAttributes();
virtual int eval();
- static constexpr int OutRank = InRank - 1 + IndexRank;
- using InEigenType = typename GetEigenType<Dtype>::type;
- using OutEigenType = typename GetEigenType<Dtype>::type;
- using TIn = Eigen::Tensor<InEigenType, InRank>;
- using TIndex = Eigen::Tensor<int32_t, IndexRank>;
- using TOut = Eigen::Tensor<OutEigenType, OutRank>;
+ using EigenType = typename GetEigenType<Dtype>::type;
+ using TValue = Eigen::Tensor<EigenType, 3>;
+ using TIndex = Eigen::Tensor<int32_t, 2>;
+ using TOutput = Eigen::Tensor<EigenType, 3>;
protected:
- TosaAxisAttribute* attribute;
- TosaReference::TensorTemplate<TIn>* in;
- TosaReference::TensorTemplate<TIndex>* index;
- TosaReference::TensorTemplate<TOut>* out;
+ int32_t N, W, K, C;
+ TosaReference::TensorTemplate<TValue>* values;
+ TosaReference::TensorTemplate<TIndex>* indices;
+ TosaReference::TensorTemplate<TOutput>* output;
+};
+
+template <DType Dtype>
+class OpScatter : public GraphNode
+{
+public:
+ OpScatter(TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_);
+ virtual ~OpScatter();
+
+ virtual int checkTensorAttributes();
+ virtual int eval();
+
+ using EigenType = typename GetEigenType<Dtype>::type;
+ using TValue = Eigen::Tensor<EigenType, 3>;
+ using TIndex = Eigen::Tensor<int32_t, 2>;
+ using TOutput = Eigen::Tensor<EigenType, 3>;
+
+protected:
+ int32_t N, W, K, C;
+ TosaReference::TensorTemplate<TValue>* values_in;
+ TosaReference::TensorTemplate<TIndex>* indices;
+ TosaReference::TensorTemplate<TValue>* input;
+ TosaReference::TensorTemplate<TOutput>* values_out;
};
}; // namespace TosaReference