aboutsummaryrefslogtreecommitdiff
path: root/reference_model/src/ops/scatter_gather.h
diff options
context:
space:
mode:
authorKevin Cheng <kevin.cheng@arm.com>2020-11-24 10:26:32 -0800
committerKevin Cheng <kevin.cheng@arm.com>2020-11-24 14:06:04 -0800
commit77d0f76b42c6353459d3271ccba9e41728d066cd (patch)
tree70f95376de80e249a15b15dba6faa27fb4d48f45 /reference_model/src/ops/scatter_gather.h
parentaee1facbde25caf27cc34e5ec08eb8bba6af8e18 (diff)
downloadreference_model-77d0f76b42c6353459d3271ccba9e41728d066cd.tar.gz
Update reference model/serialization library to 0.21.0 with unit tests added/updated
- update tosa.GATHER - update tosa.RESIZE - add tosa.SCATTER Signed-off-by: Kevin Cheng <kevin.cheng@arm.com> Change-Id: I1c3247b831a64e35a85c4044b24c6c29b8e18d25
Diffstat (limited to 'reference_model/src/ops/scatter_gather.h')
-rw-r--r--reference_model/src/ops/scatter_gather.h45
1 files changed, 32 insertions, 13 deletions
diff --git a/reference_model/src/ops/scatter_gather.h b/reference_model/src/ops/scatter_gather.h
index d9b1263..17ea723 100644
--- a/reference_model/src/ops/scatter_gather.h
+++ b/reference_model/src/ops/scatter_gather.h
@@ -23,9 +23,7 @@ using namespace tosa;
namespace TosaReference
{
-// input and index can have different rank
-// and infer OutRank statically
-template <int InRank, int IndexRank, DType Dtype>
+template <DType Dtype>
class OpGather : public GraphNode
{
public:
@@ -35,18 +33,39 @@ public:
virtual int checkTensorAttributes();
virtual int eval();
- static constexpr int OutRank = InRank - 1 + IndexRank;
- using InEigenType = typename GetEigenType<Dtype>::type;
- using OutEigenType = typename GetEigenType<Dtype>::type;
- using TIn = Eigen::Tensor<InEigenType, InRank>;
- using TIndex = Eigen::Tensor<int32_t, IndexRank>;
- using TOut = Eigen::Tensor<OutEigenType, OutRank>;
+ using EigenType = typename GetEigenType<Dtype>::type;
+ using TValue = Eigen::Tensor<EigenType, 3>;
+ using TIndex = Eigen::Tensor<int32_t, 2>;
+ using TOutput = Eigen::Tensor<EigenType, 3>;
protected:
- TosaAxisAttribute* attribute;
- TosaReference::TensorTemplate<TIn>* in;
- TosaReference::TensorTemplate<TIndex>* index;
- TosaReference::TensorTemplate<TOut>* out;
+ int32_t N, W, K, C;
+ TosaReference::TensorTemplate<TValue>* values;
+ TosaReference::TensorTemplate<TIndex>* indices;
+ TosaReference::TensorTemplate<TOutput>* output;
+};
+
+template <DType Dtype>
+class OpScatter : public GraphNode
+{
+public:
+ OpScatter(TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_);
+ virtual ~OpScatter();
+
+ virtual int checkTensorAttributes();
+ virtual int eval();
+
+ using EigenType = typename GetEigenType<Dtype>::type;
+ using TValue = Eigen::Tensor<EigenType, 3>;
+ using TIndex = Eigen::Tensor<int32_t, 2>;
+ using TOutput = Eigen::Tensor<EigenType, 3>;
+
+protected:
+ int32_t N, W, K, C;
+ TosaReference::TensorTemplate<TValue>* values_in;
+ TosaReference::TensorTemplate<TIndex>* indices;
+ TosaReference::TensorTemplate<TValue>* input;
+ TosaReference::TensorTemplate<TOutput>* values_out;
};
}; // namespace TosaReference