aboutsummaryrefslogtreecommitdiff
path: root/reference_model
diff options
context:
space:
mode:
authorLuke Hutton <luke.hutton@arm.com>2023-02-22 11:53:48 +0000
committerEric Kunze <eric.kunze@arm.com>2023-02-28 20:08:57 +0000
commita4e48ca7b032992ca0110900935c08d7cf860cd3 (patch)
treea58c8617390225ecc107721d9b5ff87c2bdb01b0 /reference_model
parent2226f90d5a6c48a975045bc9e0419113ce764aaf (diff)
downloadreference_model-a4e48ca7b032992ca0110900935c08d7cf860cd3.tar.gz
Update rank limits for SLICE, TILE and TRANSPOSE
Updated to align with corresponding changes to the spec. In addition, some ERROR_IF tests have been updated to match the checks specified by the spec, including: PAD, SLICE, TILE, TRANSPOSE. Signed-off-by: Luke Hutton <luke.hutton@arm.com> Change-Id: Ie2c5f48e79a5610eb82739170e25057a63dac1d8
Diffstat (limited to 'reference_model')
-rw-r--r--reference_model/src/ops/data_layout.cc142
-rw-r--r--reference_model/src/ops/data_layout.h4
-rw-r--r--reference_model/src/ops/op_factory.cc42
3 files changed, 128 insertions, 60 deletions
diff --git a/reference_model/src/ops/data_layout.cc b/reference_model/src/ops/data_layout.cc
index 909c567..ce5b5af 100644
--- a/reference_model/src/ops/data_layout.cc
+++ b/reference_model/src/ops/data_layout.cc
@@ -1,5 +1,5 @@
-// Copyright (c) 2020-2022, ARM Limited.
+// Copyright (c) 2020-2023, ARM Limited.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
@@ -127,7 +127,7 @@ OpPad<Rank, Dtype>::OpPad(SubgraphTraverser* sgt_,
: GraphNode(sgt_, Op_PAD, id_)
{
setRequiredOperands(1, 1);
- setRequiredRank(0, 6);
+ setRequiredRank(1, 6);
INIT_ATTRIBUTE(Pad);
}
@@ -374,7 +374,7 @@ OpSlice<Rank, Dtype>::OpSlice(SubgraphTraverser* sgt_,
: GraphNode(sgt_, Op_SLICE, id_)
{
setRequiredOperands(1, 1);
- setRequiredRank(1, 4);
+ setRequiredRank(1, 6);
INIT_ATTRIBUTE(Slice);
}
@@ -398,15 +398,17 @@ int OpSlice<Rank, Dtype>::checkTensorAttributes()
}
// output and input must be the same types
- if (inputs[0]->matchType(*outputs[0]))
+ if (inputs[0]->matchRankType(*outputs[0]))
{
- printNodeValidationError("Failure to match input and output type");
+ printNodeValidationError("Failure to match input and output rank or type");
return 1;
}
in = dynamic_cast<TosaReference::TensorTemplate<TIn>*>(inputs[0]);
out = dynamic_cast<TosaReference::TensorTemplate<TOut>*>(outputs[0]);
+ ASSERT_MEM(in && out);
+
ERROR_IF((int32_t)attribute->start().size() != in->getRank(),
"OpSlice: begin array length needs to be rank(input)");
ERROR_IF((int32_t)attribute->size().size() != in->getRank(), "OpSlice: size array length needs to be rank(input)");
@@ -441,7 +443,7 @@ OpTileBase<Rank, Dtype>::OpTileBase(SubgraphTraverser* sgt_,
: GraphNode(sgt_, Op_TILE, id_)
{
setRequiredOperands(1, 1);
- setRequiredRank(0, 6);
+ setRequiredRank(1, 6);
INIT_ATTRIBUTE(Tile);
}
@@ -474,6 +476,8 @@ int OpTileBase<Rank, Dtype>::checkTensorAttributes()
in = dynamic_cast<TosaReference::TensorTemplate<TIn>*>(inputs[0]);
out = dynamic_cast<TosaReference::TensorTemplate<TOut>*>(outputs[0]);
+ ASSERT_MEM(in && out);
+
if (attribute->multiples().size() != Rank)
{
printNodeValidationError("1D list 'multiples' must have size equal to input rank");
@@ -568,6 +572,68 @@ int OpTile<4, Dtype>::eval()
return GraphNode::eval();
}
+template <DType Dtype>
+int OpTile<5, Dtype>::eval()
+{
+ for (int32_t od0 = 0; od0 < this->out->getShape()[0]; od0++)
+ {
+ int32_t id0 = od0 % this->in->getShape()[0];
+ for (int32_t od1 = 0; od1 < this->out->getShape()[1]; od1++)
+ {
+ int32_t id1 = od1 % this->in->getShape()[1];
+ for (int32_t od2 = 0; od2 < this->out->getShape()[2]; od2++)
+ {
+ int32_t id2 = od2 % this->in->getShape()[2];
+ for (int32_t od3 = 0; od3 < this->out->getShape()[3]; od3++)
+ {
+ int32_t id3 = od3 % this->in->getShape()[3];
+ for (int32_t od4 = 0; od4 < this->out->getShape()[4]; od4++)
+ {
+ int32_t id4 = od4 % this->in->getShape()[4];
+ this->out->getTensor()(od0, od1, od2, od3, od4) =
+ this->in->getTensor()(id0, id1, id2, id3, id4);
+ }
+ }
+ }
+ }
+ }
+
+ return GraphNode::eval();
+}
+
+template <DType Dtype>
+int OpTile<6, Dtype>::eval()
+{
+ for (int32_t od0 = 0; od0 < this->out->getShape()[0]; od0++)
+ {
+ int32_t id0 = od0 % this->in->getShape()[0];
+ for (int32_t od1 = 0; od1 < this->out->getShape()[1]; od1++)
+ {
+ int32_t id1 = od1 % this->in->getShape()[1];
+ for (int32_t od2 = 0; od2 < this->out->getShape()[2]; od2++)
+ {
+ int32_t id2 = od2 % this->in->getShape()[2];
+ for (int32_t od3 = 0; od3 < this->out->getShape()[3]; od3++)
+ {
+ int32_t id3 = od3 % this->in->getShape()[3];
+ for (int32_t od4 = 0; od4 < this->out->getShape()[4]; od4++)
+ {
+ int32_t id4 = od4 % this->in->getShape()[4];
+ for (int32_t od5 = 0; od5 < this->out->getShape()[5]; od5++)
+ {
+ int32_t id5 = od5 % this->in->getShape()[5];
+ this->out->getTensor()(od0, od1, od2, od3, od4, od5) =
+ this->in->getTensor()(id0, id1, id2, id3, id4, id5);
+ }
+ }
+ }
+ }
+ }
+ }
+
+ return GraphNode::eval();
+}
+
template <int Rank, DType Dtype>
OpTranspose<Rank, Dtype>::OpTranspose(SubgraphTraverser* sgt_,
TosaAttributeBase* attribute_,
@@ -575,7 +641,7 @@ OpTranspose<Rank, Dtype>::OpTranspose(SubgraphTraverser* sgt_,
: GraphNode(sgt_, Op_TRANSPOSE, id_)
{
setRequiredOperands(1, 1);
- setRequiredRank(0, 6);
+ setRequiredRank(1, 6);
INIT_ATTRIBUTE(Transpose);
}
@@ -673,34 +739,34 @@ DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpReverse, INT16);
DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpReverse, INT32);
DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpReverse, BOOL);
-DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpSlice, FP16);
-DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpSlice, BF16);
-DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpSlice, FP32);
-DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpSlice, INT8);
-DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpSlice, INT16);
-DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpSlice, INT32);
-DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpSlice, BOOL);
-
-DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpTileBase, FP16);
-DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpTileBase, BF16);
-DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpTileBase, FP32);
-DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpTileBase, INT8);
-DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpTileBase, INT16);
-DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpTileBase, INT32);
-DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpTileBase, BOOL);
-
-DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpTile, FP16);
-DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpTile, BF16);
-DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpTile, FP32);
-DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpTile, INT8);
-DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpTile, INT16);
-DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpTile, INT32);
-DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpTile, BOOL);
-
-DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpTranspose, FP16);
-DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpTranspose, BF16);
-DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpTranspose, FP32);
-DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpTranspose, INT8);
-DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpTranspose, INT16);
-DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpTranspose, INT32);
-DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpTranspose, BOOL);
+DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpSlice, FP16);
+DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpSlice, BF16);
+DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpSlice, FP32);
+DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpSlice, INT8);
+DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpSlice, INT16);
+DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpSlice, INT32);
+DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpSlice, BOOL);
+
+DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpTileBase, FP16);
+DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpTileBase, BF16);
+DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpTileBase, FP32);
+DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpTileBase, INT8);
+DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpTileBase, INT16);
+DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpTileBase, INT32);
+DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpTileBase, BOOL);
+
+DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpTile, FP16);
+DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpTile, BF16);
+DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpTile, FP32);
+DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpTile, INT8);
+DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpTile, INT16);
+DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpTile, INT32);
+DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpTile, BOOL);
+
+DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpTranspose, FP16);
+DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpTranspose, BF16);
+DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpTranspose, FP32);
+DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpTranspose, INT8);
+DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpTranspose, INT16);
+DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpTranspose, INT32);
+DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpTranspose, BOOL);
diff --git a/reference_model/src/ops/data_layout.h b/reference_model/src/ops/data_layout.h
index c6513ae..3a6cb0d 100644
--- a/reference_model/src/ops/data_layout.h
+++ b/reference_model/src/ops/data_layout.h
@@ -1,5 +1,5 @@
-// Copyright (c) 2020, ARM Limited.
+// Copyright (c) 2020-2023, ARM Limited.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
@@ -186,6 +186,8 @@ DEF_OP_TILE_RANK(1)
DEF_OP_TILE_RANK(2)
DEF_OP_TILE_RANK(3)
DEF_OP_TILE_RANK(4)
+DEF_OP_TILE_RANK(5)
+DEF_OP_TILE_RANK(6)
#undef DEF_OP_TILE_RANK
diff --git a/reference_model/src/ops/op_factory.cc b/reference_model/src/ops/op_factory.cc
index 8d84135..1db3974 100644
--- a/reference_model/src/ops/op_factory.cc
+++ b/reference_model/src/ops/op_factory.cc
@@ -394,31 +394,31 @@ GraphNode* OpFactory::newOp(SubgraphTraverser* sgt,
DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpReverse, BOOL);
break;
case Op_SLICE:
- DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpSlice, FP16);
- DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpSlice, BF16);
- DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpSlice, FP32);
- DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpSlice, INT8);
- DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpSlice, INT16);
- DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpSlice, INT32);
- DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpSlice, BOOL);
+ DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpSlice, FP16);
+ DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpSlice, BF16);
+ DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpSlice, FP32);
+ DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpSlice, INT8);
+ DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpSlice, INT16);
+ DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpSlice, INT32);
+ DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpSlice, BOOL);
break;
case Op_TILE:
- DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpTile, FP16);
- DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpTile, BF16);
- DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpTile, FP32);
- DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpTile, INT8);
- DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpTile, INT16);
- DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpTile, INT32);
- DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpTile, BOOL);
+ DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpTile, FP16);
+ DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpTile, BF16);
+ DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpTile, FP32);
+ DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpTile, INT8);
+ DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpTile, INT16);
+ DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpTile, INT32);
+ DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpTile, BOOL);
break;
case Op_TRANSPOSE:
- DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpTranspose, BOOL);
- DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpTranspose, FP16);
- DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpTranspose, BF16);
- DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpTranspose, FP32);
- DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpTranspose, INT8);
- DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpTranspose, INT16);
- DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpTranspose, INT32);
+ DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpTranspose, BOOL);
+ DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpTranspose, FP16);
+ DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpTranspose, BF16);
+ DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpTranspose, FP32);
+ DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpTranspose, INT8);
+ DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpTranspose, INT16);
+ DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpTranspose, INT32);
break;
// scatter_gather