From a4e48ca7b032992ca0110900935c08d7cf860cd3 Mon Sep 17 00:00:00 2001 From: Luke Hutton Date: Wed, 22 Feb 2023 11:53:48 +0000 Subject: Update rank limits for SLICE, TILE and TRANSPOSE Updated to align with corresponding changes to the spec. In addition, some ERROR_IF tests have been updated to match the checks specified by the spec, including: PAD, SLICE, TILE, TRANSPOSE. Signed-off-by: Luke Hutton Change-Id: Ie2c5f48e79a5610eb82739170e25057a63dac1d8 --- reference_model/src/ops/data_layout.cc | 142 ++++++++++++++++++++++++--------- 1 file changed, 104 insertions(+), 38 deletions(-) (limited to 'reference_model/src/ops/data_layout.cc') diff --git a/reference_model/src/ops/data_layout.cc b/reference_model/src/ops/data_layout.cc index 909c567..ce5b5af 100644 --- a/reference_model/src/ops/data_layout.cc +++ b/reference_model/src/ops/data_layout.cc @@ -1,5 +1,5 @@ -// Copyright (c) 2020-2022, ARM Limited. +// Copyright (c) 2020-2023, ARM Limited. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -127,7 +127,7 @@ OpPad::OpPad(SubgraphTraverser* sgt_, : GraphNode(sgt_, Op_PAD, id_) { setRequiredOperands(1, 1); - setRequiredRank(0, 6); + setRequiredRank(1, 6); INIT_ATTRIBUTE(Pad); } @@ -374,7 +374,7 @@ OpSlice::OpSlice(SubgraphTraverser* sgt_, : GraphNode(sgt_, Op_SLICE, id_) { setRequiredOperands(1, 1); - setRequiredRank(1, 4); + setRequiredRank(1, 6); INIT_ATTRIBUTE(Slice); } @@ -398,15 +398,17 @@ int OpSlice::checkTensorAttributes() } // output and input must be the same types - if (inputs[0]->matchType(*outputs[0])) + if (inputs[0]->matchRankType(*outputs[0])) { - printNodeValidationError("Failure to match input and output type"); + printNodeValidationError("Failure to match input and output rank or type"); return 1; } in = dynamic_cast*>(inputs[0]); out = dynamic_cast*>(outputs[0]); + ASSERT_MEM(in && out); + ERROR_IF((int32_t)attribute->start().size() != in->getRank(), "OpSlice: begin array length needs to be rank(input)"); ERROR_IF((int32_t)attribute->size().size() != in->getRank(), "OpSlice: size array length needs to be rank(input)"); @@ -441,7 +443,7 @@ OpTileBase::OpTileBase(SubgraphTraverser* sgt_, : GraphNode(sgt_, Op_TILE, id_) { setRequiredOperands(1, 1); - setRequiredRank(0, 6); + setRequiredRank(1, 6); INIT_ATTRIBUTE(Tile); } @@ -474,6 +476,8 @@ int OpTileBase::checkTensorAttributes() in = dynamic_cast*>(inputs[0]); out = dynamic_cast*>(outputs[0]); + ASSERT_MEM(in && out); + if (attribute->multiples().size() != Rank) { printNodeValidationError("1D list 'multiples' must have size equal to input rank"); @@ -568,6 +572,68 @@ int OpTile<4, Dtype>::eval() return GraphNode::eval(); } +template +int OpTile<5, Dtype>::eval() +{ + for (int32_t od0 = 0; od0 < this->out->getShape()[0]; od0++) + { + int32_t id0 = od0 % this->in->getShape()[0]; + for (int32_t od1 = 0; od1 < this->out->getShape()[1]; od1++) + { + int32_t id1 = od1 % this->in->getShape()[1]; + for (int32_t od2 = 0; od2 < this->out->getShape()[2]; od2++) + { + int32_t id2 = od2 % this->in->getShape()[2]; + for (int32_t od3 = 0; od3 < this->out->getShape()[3]; od3++) + { + int32_t id3 = od3 % this->in->getShape()[3]; + for (int32_t od4 = 0; od4 < this->out->getShape()[4]; od4++) + { + int32_t id4 = od4 % this->in->getShape()[4]; + this->out->getTensor()(od0, od1, od2, od3, od4) = + this->in->getTensor()(id0, id1, id2, id3, id4); + } + } + } + } + } + + return GraphNode::eval(); +} + +template +int OpTile<6, Dtype>::eval() +{ + for (int32_t od0 = 0; od0 < this->out->getShape()[0]; od0++) + { + int32_t id0 = od0 % this->in->getShape()[0]; + for (int32_t od1 = 0; od1 < this->out->getShape()[1]; od1++) + { + int32_t id1 = od1 % this->in->getShape()[1]; + for (int32_t od2 = 0; od2 < this->out->getShape()[2]; od2++) + { + int32_t id2 = od2 % this->in->getShape()[2]; + for (int32_t od3 = 0; od3 < this->out->getShape()[3]; od3++) + { + int32_t id3 = od3 % this->in->getShape()[3]; + for (int32_t od4 = 0; od4 < this->out->getShape()[4]; od4++) + { + int32_t id4 = od4 % this->in->getShape()[4]; + for (int32_t od5 = 0; od5 < this->out->getShape()[5]; od5++) + { + int32_t id5 = od5 % this->in->getShape()[5]; + this->out->getTensor()(od0, od1, od2, od3, od4, od5) = + this->in->getTensor()(id0, id1, id2, id3, id4, id5); + } + } + } + } + } + } + + return GraphNode::eval(); +} + template OpTranspose::OpTranspose(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, @@ -575,7 +641,7 @@ OpTranspose::OpTranspose(SubgraphTraverser* sgt_, : GraphNode(sgt_, Op_TRANSPOSE, id_) { setRequiredOperands(1, 1); - setRequiredRank(0, 6); + setRequiredRank(1, 6); INIT_ATTRIBUTE(Transpose); } @@ -673,34 +739,34 @@ DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpReverse, INT16); DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpReverse, INT32); DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpReverse, BOOL); -DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpSlice, FP16); -DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpSlice, BF16); -DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpSlice, FP32); -DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpSlice, INT8); -DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpSlice, INT16); -DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpSlice, INT32); -DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpSlice, BOOL); - -DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpTileBase, FP16); -DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpTileBase, BF16); -DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpTileBase, FP32); -DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpTileBase, INT8); -DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpTileBase, INT16); -DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpTileBase, INT32); -DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpTileBase, BOOL); - -DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpTile, FP16); -DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpTile, BF16); -DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpTile, FP32); -DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpTile, INT8); -DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpTile, INT16); -DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpTile, INT32); -DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpTile, BOOL); - -DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpTranspose, FP16); -DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpTranspose, BF16); -DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpTranspose, FP32); -DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpTranspose, INT8); -DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpTranspose, INT16); -DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpTranspose, INT32); -DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpTranspose, BOOL); +DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpSlice, FP16); +DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpSlice, BF16); +DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpSlice, FP32); +DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpSlice, INT8); +DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpSlice, INT16); +DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpSlice, INT32); +DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpSlice, BOOL); + +DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpTileBase, FP16); +DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpTileBase, BF16); +DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpTileBase, FP32); +DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpTileBase, INT8); +DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpTileBase, INT16); +DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpTileBase, INT32); +DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpTileBase, BOOL); + +DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpTile, FP16); +DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpTile, BF16); +DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpTile, FP32); +DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpTile, INT8); +DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpTile, INT16); +DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpTile, INT32); +DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpTile, BOOL); + +DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpTranspose, FP16); +DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpTranspose, BF16); +DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpTranspose, FP32); +DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpTranspose, INT8); +DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpTranspose, INT16); +DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpTranspose, INT32); +DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpTranspose, BOOL); -- cgit v1.2.1