diff options
author | TatWai Chong <tatwai.chong@arm.com> | 2024-01-24 22:57:07 -0800 |
---|---|---|
committer | Eric Kunze <eric.kunze@arm.com> | 2024-01-31 19:09:14 +0000 |
commit | 01f937a27a3b56bca622f94af7201c98dfebeb43 (patch) | |
tree | f5ca356ee8795ef12208d0210d68fed051b0c88e | |
parent | 39431cbbbc12d065225a9622ce49b4eaff6c934c (diff) | |
download | reference_model-01f937a27a3b56bca622f94af7201c98dfebeb43.tar.gz |
Change the start and size of slice to tosa shape type
This offers dynamism support for slice op.
Change-Id: I4521c072c663a01e03e575e0cbbc8671c832f646
Signed-off-by: TatWai Chong <tatwai.chong@arm.com>
-rw-r--r-- | reference_model/src/ops/data_layout.cc | 36 | ||||
-rw-r--r-- | reference_model/src/ops/data_layout.h | 13 |
2 files changed, 24 insertions, 25 deletions
diff --git a/reference_model/src/ops/data_layout.cc b/reference_model/src/ops/data_layout.cc index 14f2918..ec9614a 100644 --- a/reference_model/src/ops/data_layout.cc +++ b/reference_model/src/ops/data_layout.cc @@ -422,18 +422,13 @@ template <int Rank, TOSA_REF_TYPE Dtype> OpSlice<Rank, Dtype>::OpSlice(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, uint64_t id_) : GraphNode(sgt_, Op_SLICE, id_) { - setRequiredOperands(1, 1); + setRequiredOperands(3, 1); setRequiredRank(1); - - INIT_ATTRIBUTE(Slice); } template <int Rank, TOSA_REF_TYPE Dtype> OpSlice<Rank, Dtype>::~OpSlice() -{ - if (attribute) - delete attribute; -} +{} template <int Rank, TOSA_REF_TYPE Dtype> int OpSlice<Rank, Dtype>::checkTensorAttributes() @@ -457,19 +452,26 @@ int OpSlice<Rank, Dtype>::checkTensorAttributes() return 1; } - in = dynamic_cast<TosaReference::TensorTemplate<TIn>*>(inputs[0]); - out = dynamic_cast<TosaReference::TensorTemplate<TOut>*>(outputs[0]); + in = dynamic_cast<TosaReference::TensorTemplate<TIn>*>(inputs[0]); + start = dynamic_cast<TosaReference::TensorTemplate<TInShape>*>(inputs[1]); + size = dynamic_cast<TosaReference::TensorTemplate<TInShape>*>(inputs[2]); + out = dynamic_cast<TosaReference::TensorTemplate<TOut>*>(outputs[0]); ASSERT_MEM(in && out); - ERROR_IF((int32_t)attribute->start().size() != in->getRank(), - "OpSlice: begin array length needs to be rank(input)"); - ERROR_IF((int32_t)attribute->size().size() != in->getRank(), "OpSlice: size array length needs to be rank(input)"); + return 0; +} + +template <int Rank, TOSA_REF_TYPE Dtype> +int OpSlice<Rank, Dtype>::eval() +{ + ERROR_IF(start->getElementCount() != in->getRank(), "OpSlice: start array length needs to be rank(input)"); + ERROR_IF(size->getElementCount() != in->getRank(), "OpSlice: size array length needs to be rank(input)"); for (int32_t i = 0; i < in->getRank(); i++) { - int32_t b = attribute->start()[i]; - int32_t s = attribute->size()[i]; + int32_t b = start->getTensor()(i); + int32_t s = size->getTensor()(i); ERROR_IF(b < 0 || b >= in->getShape()[i], "OpSlice: start out of boundary"); ERROR_IF((b + s) < 0 || (b + s) > in->getShape()[i], "OpSlice: (start+size) out of boundary"); ERROR_IF(s <= 0, "OpSlice: output must be positive"); @@ -478,12 +480,6 @@ int OpSlice<Rank, Dtype>::checkTensorAttributes() size_array[i] = s; } - return 0; -} - -template <int Rank, TOSA_REF_TYPE Dtype> -int OpSlice<Rank, Dtype>::eval() -{ out->getTensor() = in->getTensor().slice(begin_array, size_array); return GraphNode::eval(); diff --git a/reference_model/src/ops/data_layout.h b/reference_model/src/ops/data_layout.h index e085b8e..6ab5ebd 100644 --- a/reference_model/src/ops/data_layout.h +++ b/reference_model/src/ops/data_layout.h @@ -145,15 +145,18 @@ public: virtual int checkTensorAttributes(); virtual int eval(); - using InEigenType = typename GetEigenType<Dtype>::type; - using OutEigenType = typename GetEigenType<Dtype>::type; - using TIn = Eigen::Tensor<InEigenType, Rank>; - using TOut = Eigen::Tensor<OutEigenType, Rank>; + using InEigenType = typename GetEigenType<Dtype>::type; + using InEigenShapeType = typename GetEigenType<TOSA_REF_TYPE_SHAPE>::type; + using OutEigenType = typename GetEigenType<Dtype>::type; + using TIn = Eigen::Tensor<InEigenType, Rank>; + using TInShape = Eigen::Tensor<InEigenShapeType, 1>; + using TOut = Eigen::Tensor<OutEigenType, Rank>; protected: - TosaSliceAttribute* attribute; Eigen::array<Eigen::Index, Rank> begin_array; Eigen::array<Eigen::Index, Rank> size_array; + TosaReference::TensorTemplate<TInShape>* start; + TosaReference::TensorTemplate<TInShape>* size; TosaReference::TensorTemplate<TIn>* in; TosaReference::TensorTemplate<TOut>* out; }; |