aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorTatWai Chong <tatwai.chong@arm.com>2024-01-24 22:57:07 -0800
committerEric Kunze <eric.kunze@arm.com>2024-01-31 19:09:14 +0000
commit01f937a27a3b56bca622f94af7201c98dfebeb43 (patch)
treef5ca356ee8795ef12208d0210d68fed051b0c88e
parent39431cbbbc12d065225a9622ce49b4eaff6c934c (diff)
downloadreference_model-01f937a27a3b56bca622f94af7201c98dfebeb43.tar.gz
Change the start and size of slice to tosa shape type
This offers dynamism support for slice op. Change-Id: I4521c072c663a01e03e575e0cbbc8671c832f646 Signed-off-by: TatWai Chong <tatwai.chong@arm.com>
-rw-r--r--reference_model/src/ops/data_layout.cc36
-rw-r--r--reference_model/src/ops/data_layout.h13
2 files changed, 24 insertions, 25 deletions
diff --git a/reference_model/src/ops/data_layout.cc b/reference_model/src/ops/data_layout.cc
index 14f2918..ec9614a 100644
--- a/reference_model/src/ops/data_layout.cc
+++ b/reference_model/src/ops/data_layout.cc
@@ -422,18 +422,13 @@ template <int Rank, TOSA_REF_TYPE Dtype>
OpSlice<Rank, Dtype>::OpSlice(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, uint64_t id_)
: GraphNode(sgt_, Op_SLICE, id_)
{
- setRequiredOperands(1, 1);
+ setRequiredOperands(3, 1);
setRequiredRank(1);
-
- INIT_ATTRIBUTE(Slice);
}
template <int Rank, TOSA_REF_TYPE Dtype>
OpSlice<Rank, Dtype>::~OpSlice()
-{
- if (attribute)
- delete attribute;
-}
+{}
template <int Rank, TOSA_REF_TYPE Dtype>
int OpSlice<Rank, Dtype>::checkTensorAttributes()
@@ -457,19 +452,26 @@ int OpSlice<Rank, Dtype>::checkTensorAttributes()
return 1;
}
- in = dynamic_cast<TosaReference::TensorTemplate<TIn>*>(inputs[0]);
- out = dynamic_cast<TosaReference::TensorTemplate<TOut>*>(outputs[0]);
+ in = dynamic_cast<TosaReference::TensorTemplate<TIn>*>(inputs[0]);
+ start = dynamic_cast<TosaReference::TensorTemplate<TInShape>*>(inputs[1]);
+ size = dynamic_cast<TosaReference::TensorTemplate<TInShape>*>(inputs[2]);
+ out = dynamic_cast<TosaReference::TensorTemplate<TOut>*>(outputs[0]);
ASSERT_MEM(in && out);
- ERROR_IF((int32_t)attribute->start().size() != in->getRank(),
- "OpSlice: begin array length needs to be rank(input)");
- ERROR_IF((int32_t)attribute->size().size() != in->getRank(), "OpSlice: size array length needs to be rank(input)");
+ return 0;
+}
+
+template <int Rank, TOSA_REF_TYPE Dtype>
+int OpSlice<Rank, Dtype>::eval()
+{
+ ERROR_IF(start->getElementCount() != in->getRank(), "OpSlice: start array length needs to be rank(input)");
+ ERROR_IF(size->getElementCount() != in->getRank(), "OpSlice: size array length needs to be rank(input)");
for (int32_t i = 0; i < in->getRank(); i++)
{
- int32_t b = attribute->start()[i];
- int32_t s = attribute->size()[i];
+ int32_t b = start->getTensor()(i);
+ int32_t s = size->getTensor()(i);
ERROR_IF(b < 0 || b >= in->getShape()[i], "OpSlice: start out of boundary");
ERROR_IF((b + s) < 0 || (b + s) > in->getShape()[i], "OpSlice: (start+size) out of boundary");
ERROR_IF(s <= 0, "OpSlice: output must be positive");
@@ -478,12 +480,6 @@ int OpSlice<Rank, Dtype>::checkTensorAttributes()
size_array[i] = s;
}
- return 0;
-}
-
-template <int Rank, TOSA_REF_TYPE Dtype>
-int OpSlice<Rank, Dtype>::eval()
-{
out->getTensor() = in->getTensor().slice(begin_array, size_array);
return GraphNode::eval();
diff --git a/reference_model/src/ops/data_layout.h b/reference_model/src/ops/data_layout.h
index e085b8e..6ab5ebd 100644
--- a/reference_model/src/ops/data_layout.h
+++ b/reference_model/src/ops/data_layout.h
@@ -145,15 +145,18 @@ public:
virtual int checkTensorAttributes();
virtual int eval();
- using InEigenType = typename GetEigenType<Dtype>::type;
- using OutEigenType = typename GetEigenType<Dtype>::type;
- using TIn = Eigen::Tensor<InEigenType, Rank>;
- using TOut = Eigen::Tensor<OutEigenType, Rank>;
+ using InEigenType = typename GetEigenType<Dtype>::type;
+ using InEigenShapeType = typename GetEigenType<TOSA_REF_TYPE_SHAPE>::type;
+ using OutEigenType = typename GetEigenType<Dtype>::type;
+ using TIn = Eigen::Tensor<InEigenType, Rank>;
+ using TInShape = Eigen::Tensor<InEigenShapeType, 1>;
+ using TOut = Eigen::Tensor<OutEigenType, Rank>;
protected:
- TosaSliceAttribute* attribute;
Eigen::array<Eigen::Index, Rank> begin_array;
Eigen::array<Eigen::Index, Rank> size_array;
+ TosaReference::TensorTemplate<TInShape>* start;
+ TosaReference::TensorTemplate<TInShape>* size;
TosaReference::TensorTemplate<TIn>* in;
TosaReference::TensorTemplate<TOut>* out;
};