aboutsummaryrefslogtreecommitdiff
path: root/reference_model/src/ops/data_layout.cc
diff options
context:
space:
mode:
Diffstat (limited to 'reference_model/src/ops/data_layout.cc')
-rw-r--r--reference_model/src/ops/data_layout.cc142
1 files changed, 104 insertions, 38 deletions
diff --git a/reference_model/src/ops/data_layout.cc b/reference_model/src/ops/data_layout.cc
index 909c567..ce5b5af 100644
--- a/reference_model/src/ops/data_layout.cc
+++ b/reference_model/src/ops/data_layout.cc
@@ -1,5 +1,5 @@
-// Copyright (c) 2020-2022, ARM Limited.
+// Copyright (c) 2020-2023, ARM Limited.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
@@ -127,7 +127,7 @@ OpPad<Rank, Dtype>::OpPad(SubgraphTraverser* sgt_,
: GraphNode(sgt_, Op_PAD, id_)
{
setRequiredOperands(1, 1);
- setRequiredRank(0, 6);
+ setRequiredRank(1, 6);
INIT_ATTRIBUTE(Pad);
}
@@ -374,7 +374,7 @@ OpSlice<Rank, Dtype>::OpSlice(SubgraphTraverser* sgt_,
: GraphNode(sgt_, Op_SLICE, id_)
{
setRequiredOperands(1, 1);
- setRequiredRank(1, 4);
+ setRequiredRank(1, 6);
INIT_ATTRIBUTE(Slice);
}
@@ -398,15 +398,17 @@ int OpSlice<Rank, Dtype>::checkTensorAttributes()
}
// output and input must be the same types
- if (inputs[0]->matchType(*outputs[0]))
+ if (inputs[0]->matchRankType(*outputs[0]))
{
- printNodeValidationError("Failure to match input and output type");
+ printNodeValidationError("Failure to match input and output rank or type");
return 1;
}
in = dynamic_cast<TosaReference::TensorTemplate<TIn>*>(inputs[0]);
out = dynamic_cast<TosaReference::TensorTemplate<TOut>*>(outputs[0]);
+ ASSERT_MEM(in && out);
+
ERROR_IF((int32_t)attribute->start().size() != in->getRank(),
"OpSlice: begin array length needs to be rank(input)");
ERROR_IF((int32_t)attribute->size().size() != in->getRank(), "OpSlice: size array length needs to be rank(input)");
@@ -441,7 +443,7 @@ OpTileBase<Rank, Dtype>::OpTileBase(SubgraphTraverser* sgt_,
: GraphNode(sgt_, Op_TILE, id_)
{
setRequiredOperands(1, 1);
- setRequiredRank(0, 6);
+ setRequiredRank(1, 6);
INIT_ATTRIBUTE(Tile);
}
@@ -474,6 +476,8 @@ int OpTileBase<Rank, Dtype>::checkTensorAttributes()
in = dynamic_cast<TosaReference::TensorTemplate<TIn>*>(inputs[0]);
out = dynamic_cast<TosaReference::TensorTemplate<TOut>*>(outputs[0]);
+ ASSERT_MEM(in && out);
+
if (attribute->multiples().size() != Rank)
{
printNodeValidationError("1D list 'multiples' must have size equal to input rank");
@@ -568,6 +572,68 @@ int OpTile<4, Dtype>::eval()
return GraphNode::eval();
}
+template <DType Dtype>
+int OpTile<5, Dtype>::eval()
+{
+ for (int32_t od0 = 0; od0 < this->out->getShape()[0]; od0++)
+ {
+ int32_t id0 = od0 % this->in->getShape()[0];
+ for (int32_t od1 = 0; od1 < this->out->getShape()[1]; od1++)
+ {
+ int32_t id1 = od1 % this->in->getShape()[1];
+ for (int32_t od2 = 0; od2 < this->out->getShape()[2]; od2++)
+ {
+ int32_t id2 = od2 % this->in->getShape()[2];
+ for (int32_t od3 = 0; od3 < this->out->getShape()[3]; od3++)
+ {
+ int32_t id3 = od3 % this->in->getShape()[3];
+ for (int32_t od4 = 0; od4 < this->out->getShape()[4]; od4++)
+ {
+ int32_t id4 = od4 % this->in->getShape()[4];
+ this->out->getTensor()(od0, od1, od2, od3, od4) =
+ this->in->getTensor()(id0, id1, id2, id3, id4);
+ }
+ }
+ }
+ }
+ }
+
+ return GraphNode::eval();
+}
+
+template <DType Dtype>
+int OpTile<6, Dtype>::eval()
+{
+ for (int32_t od0 = 0; od0 < this->out->getShape()[0]; od0++)
+ {
+ int32_t id0 = od0 % this->in->getShape()[0];
+ for (int32_t od1 = 0; od1 < this->out->getShape()[1]; od1++)
+ {
+ int32_t id1 = od1 % this->in->getShape()[1];
+ for (int32_t od2 = 0; od2 < this->out->getShape()[2]; od2++)
+ {
+ int32_t id2 = od2 % this->in->getShape()[2];
+ for (int32_t od3 = 0; od3 < this->out->getShape()[3]; od3++)
+ {
+ int32_t id3 = od3 % this->in->getShape()[3];
+ for (int32_t od4 = 0; od4 < this->out->getShape()[4]; od4++)
+ {
+ int32_t id4 = od4 % this->in->getShape()[4];
+ for (int32_t od5 = 0; od5 < this->out->getShape()[5]; od5++)
+ {
+ int32_t id5 = od5 % this->in->getShape()[5];
+ this->out->getTensor()(od0, od1, od2, od3, od4, od5) =
+ this->in->getTensor()(id0, id1, id2, id3, id4, id5);
+ }
+ }
+ }
+ }
+ }
+ }
+
+ return GraphNode::eval();
+}
+
template <int Rank, DType Dtype>
OpTranspose<Rank, Dtype>::OpTranspose(SubgraphTraverser* sgt_,
TosaAttributeBase* attribute_,
@@ -575,7 +641,7 @@ OpTranspose<Rank, Dtype>::OpTranspose(SubgraphTraverser* sgt_,
: GraphNode(sgt_, Op_TRANSPOSE, id_)
{
setRequiredOperands(1, 1);
- setRequiredRank(0, 6);
+ setRequiredRank(1, 6);
INIT_ATTRIBUTE(Transpose);
}
@@ -673,34 +739,34 @@ DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpReverse, INT16);
DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpReverse, INT32);
DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpReverse, BOOL);
-DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpSlice, FP16);
-DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpSlice, BF16);
-DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpSlice, FP32);
-DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpSlice, INT8);
-DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpSlice, INT16);
-DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpSlice, INT32);
-DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpSlice, BOOL);
-
-DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpTileBase, FP16);
-DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpTileBase, BF16);
-DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpTileBase, FP32);
-DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpTileBase, INT8);
-DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpTileBase, INT16);
-DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpTileBase, INT32);
-DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpTileBase, BOOL);
-
-DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpTile, FP16);
-DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpTile, BF16);
-DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpTile, FP32);
-DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpTile, INT8);
-DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpTile, INT16);
-DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpTile, INT32);
-DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpTile, BOOL);
-
-DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpTranspose, FP16);
-DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpTranspose, BF16);
-DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpTranspose, FP32);
-DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpTranspose, INT8);
-DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpTranspose, INT16);
-DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpTranspose, INT32);
-DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpTranspose, BOOL);
+DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpSlice, FP16);
+DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpSlice, BF16);
+DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpSlice, FP32);
+DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpSlice, INT8);
+DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpSlice, INT16);
+DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpSlice, INT32);
+DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpSlice, BOOL);
+
+DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpTileBase, FP16);
+DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpTileBase, BF16);
+DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpTileBase, FP32);
+DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpTileBase, INT8);
+DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpTileBase, INT16);
+DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpTileBase, INT32);
+DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpTileBase, BOOL);
+
+DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpTile, FP16);
+DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpTile, BF16);
+DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpTile, FP32);
+DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpTile, INT8);
+DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpTile, INT16);
+DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpTile, INT32);
+DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpTile, BOOL);
+
+DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpTranspose, FP16);
+DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpTranspose, BF16);
+DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpTranspose, FP32);
+DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpTranspose, INT8);
+DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpTranspose, INT16);
+DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpTranspose, INT32);
+DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpTranspose, BOOL);