From cc61be36c3b0f5cd1ea719e129a54fd48a6ee9a2 Mon Sep 17 00:00:00 2001 From: Kevin Cheng Date: Thu, 14 Oct 2021 17:09:57 -0700 Subject: More ERROR_IF supports - Also delay tensor allocation after operator being validated ERROR_IF can be caught first before 0 or negative dimension set the graph_status to UNPREDICTABLE - Rescale, Argmax, FullyConnected, Matmul, Pad, Reshape, Slice, Transpose, Clamp, Concat, Equal, Greater, GreaterEqual, Table Signed-off-by: Kevin Cheng Change-Id: I4e1b3e5794fe195ce1a37e28443ae584645a3b91 --- reference_model/src/main.cpp | 10 +- reference_model/src/ops/activation_funcs.cc | 8 +- reference_model/src/ops/data_layout.cc | 93 ++++++++++++----- reference_model/src/ops/data_layout.h | 1 + reference_model/src/ops/ewise_binary.cc | 20 ++-- reference_model/src/ops/tensor_ops.cc | 156 +++++++++++++++++++++++++--- reference_model/src/ops/type_conversion.cc | 26 ++++- reference_model/src/subgraph_traverser.cc | 121 +++++++++++---------- reference_model/src/subgraph_traverser.h | 5 + thirdparty/serialization_lib | 2 +- 10 files changed, 324 insertions(+), 118 deletions(-) diff --git a/reference_model/src/main.cpp b/reference_model/src/main.cpp index e04a20b..0bf0697 100644 --- a/reference_model/src/main.cpp +++ b/reference_model/src/main.cpp @@ -84,6 +84,12 @@ int main(int argc, const char** argv) goto done; } + if (main_gt.allocateTensor()) + { + WARNING("Failed to allocate tensor. Evaluation aborted."); + goto done; + } + if (g_func_config.validate_only) { goto done; @@ -251,9 +257,9 @@ int readInputTensors(SubgraphTraverser& gt, json test_desc) DEBUG_MED(GT, "Loading input tensor %s from filename: %s", tensor->getName().c_str(), filename); - if (tensor->allocate()) + if (!tensor->is_allocated()) { - WARNING("Fail to allocate tensor %s", tensor->getName().c_str()); + WARNING("Tensor %s is not allocated before being initialized", tensor->getName().c_str()); return 1; } diff --git a/reference_model/src/ops/activation_funcs.cc b/reference_model/src/ops/activation_funcs.cc index 21677d5..c344bcb 100644 --- a/reference_model/src/ops/activation_funcs.cc +++ b/reference_model/src/ops/activation_funcs.cc @@ -25,14 +25,15 @@ using namespace tosa; template int OpClamp::register_fcn() { - switch (Dtype) { case DType_FLOAT: { InEigenType min = (InEigenType)attribute->min_fp(); InEigenType max = (InEigenType)attribute->max_fp(); - this->fcn = [min, max](InEigenType a) -> OutEigenType { return a <= min ? min : a >= max ? max : a; }; + ERROR_IF(max < min, "OpClamp: max smaller than min"); + + this->fcn = [min, max](InEigenType a) -> OutEigenType { return a <= min ? min : a >= max ? max : a; }; } break; case DType_INT8: @@ -40,7 +41,8 @@ int OpClamp::register_fcn() { InEigenType min = (InEigenType)attribute->min_int(); InEigenType max = (InEigenType)attribute->max_int(); - this->fcn = [min, max](InEigenType a) -> OutEigenType { return a <= min ? min : a >= max ? max : a; }; + ERROR_IF(max < min, "OpClamp: max smaller than min"); + this->fcn = [min, max](InEigenType a) -> OutEigenType { return a <= min ? min : a >= max ? max : a; }; } break; default: diff --git a/reference_model/src/ops/data_layout.cc b/reference_model/src/ops/data_layout.cc index 86326f5..f3e80f3 100644 --- a/reference_model/src/ops/data_layout.cc +++ b/reference_model/src/ops/data_layout.cc @@ -51,25 +51,49 @@ int OpConcat::checkTensorAttributes() printNodeValidationError("Concat operator must have at least one input tensor"); return 1; } + + int32_t num_inputs = inputs.size(); + // output and input must be the same types and rank - for (size_t i = 0; i < inputs.size(); i++) + for (int32_t i = 0; i < num_inputs; i++) { if (inputs[i]->matchRankType(*outputs[0])) { - printNodeValidationError("Concat operator input ranks and types must match"); + printNodeValidationError("OpConcat: input ranks and types must match"); return 1; } ins.push_back(dynamic_cast*>(inputs[i])); } - out = dynamic_cast*>(outputs[0]); - - if (attribute->axis() < 0 || (size_t)attribute->axis() >= inputs[0]->getShape().size()) + if (attribute->axis() < 0 || (size_t)attribute->axis() >= Rank) { - printNodeValidationError("Axis is beyond input tensor rank"); + printNodeValidationError("OpConcat: axis is beyond output tensor rank"); return 1; } + int32_t output_dim_on_axis = 0; + for (int32_t j = 0; j < num_inputs; j++) + { + for (int32_t i = 0; i < Rank; i++) + { + int32_t input_dim = inputs[j]->getShape()[i]; + if (i == attribute->axis()) + { + output_dim_on_axis += input_dim; + } + else if (input_dim != outputs[0]->getShape()[i]) + { + printNodeValidationError("OpConcat: input dimension not matching output dimension"); + return 1; + } + } + } + + ERROR_IF(output_dim_on_axis == outputs[0]->getShape()[attribute->axis()], + "OpConcat: sum of input dimension on axis not equal to output dimension on axis"); + + out = dynamic_cast*>(outputs[0]); + return 0; } @@ -135,14 +159,13 @@ int OpPad::checkTensorAttributes() return 1; } - in = dynamic_cast*>(inputs[0]); - out = dynamic_cast*>(outputs[0]); - TosaReference::TensorTemplate>* paddings = - dynamic_cast>*>(inputs[1]); + in = dynamic_cast*>(inputs[0]); + out = dynamic_cast*>(outputs[0]); + paddings = dynamic_cast>*>(inputs[1]); - for (int i = 0; i < Rank; i++) + if (this->qinfo && Dtype != DType_INT8) { - paddings_array[i] = std::make_pair(paddings->getTensor()(i, 0), paddings->getTensor()(i, 1)); + ERROR_IF(this->qinfo->input_zp() != 0, "OpPad: zeropoint should be 0"); } return 0; @@ -151,6 +174,14 @@ int OpPad::checkTensorAttributes() template int OpPad::eval() { + // Move this to + for (int i = 0; i < Rank; i++) + { + ERROR_IF((paddings->getTensor()(i, 0) < 0) || (paddings->getTensor()(i, 1) < 0), + "OpPad: padding can't be smaller than 0"); + paddings_array[i] = std::make_pair(paddings->getTensor()(i, 0), paddings->getTensor()(i, 1)); + } + InEigenType pad_value = 0; if (this->qinfo) { @@ -202,12 +233,20 @@ int OpReshape::checkTensorAttributes() return 1; } + ERROR_IF(inputs[0]->getElementCount() != outputs[0]->getElementCount(), + "Input tensor size does not match output tensor size"); + for (uint32_t d = 0; d < OutRank; d++) { if (attribute->shape()[d] == -1) { minusOneCount++; } + else + { + ERROR_IF(attribute->shape()[d] != outputs[0]->getShape()[d], + "OpReshape: new_shape doesn't match output shape"); + } } if (minusOneCount > 1) @@ -358,7 +397,7 @@ OpSlice::OpSlice(SubgraphTraverser* sgt_, : GraphNode(sgt_, Op_SLICE, id_) { setRequiredOperands(1, 1); - setRequiredRank(0, 6); + setRequiredRank(1, 4); INIT_ATTRIBUTE(Slice); } @@ -391,23 +430,20 @@ int OpSlice::checkTensorAttributes() in = dynamic_cast*>(inputs[0]); out = dynamic_cast*>(outputs[0]); - for (size_t i = 0; i < attribute->begin().size(); i++) - { - begin_array[i] = attribute->begin()[i]; - } + ERROR_IF((int32_t)attribute->begin().size() != in->getRank(), + "OpSlice: begin array length needs to be rank(input)"); + ERROR_IF((int32_t)attribute->size().size() != in->getRank(), "OpSlice: size array length needs to be rank(input)"); - for (size_t i = 0; i < attribute->size().size(); i++) + for (int32_t i = 0; i < in->getRank(); i++) { - if (attribute->size()[i] != 0) - { - size_array[i] = attribute->size()[i]; - } - else - { - // Tensorflow assigns a zero size to dimensions that are kept - // Eigen expects size to be the full size of the dimension - size_array[i] = in->getTensor().dimension(0); - } + int32_t b = attribute->begin()[i]; + int32_t s = attribute->size()[i]; + ERROR_IF(b < 0 || b >= in->getShape()[i], "OpSlice: start out of boundary"); + ERROR_IF((b + s) < 0 || (b + s) > in->getShape()[i], "OpSlice: (start+size) out of boundary"); + ERROR_IF(s <= 0, "OpSlice: output must be positive"); + ERROR_IF(s != out->getShape()[i], "OpSlice: size doesn't match output tensor dimension"); + begin_array[i] = b; + size_array[i] = s; } return 0; @@ -611,6 +647,7 @@ int OpTranspose::eval() for (int32_t d = 0; d < Rank; d++) { perm_array[d] = this->perm_tensor->getTensor().data()[d]; + ERROR_IF(perm_array[d] < 0 or perm_array[d] >= Rank, "OpTranspose: index out of boundary"); } out->getTensor() = in->getTensor().shuffle(perm_array); diff --git a/reference_model/src/ops/data_layout.h b/reference_model/src/ops/data_layout.h index c9c2602..9f44fc7 100644 --- a/reference_model/src/ops/data_layout.h +++ b/reference_model/src/ops/data_layout.h @@ -63,6 +63,7 @@ protected: Eigen::array, Rank> paddings_array; TosaReference::TensorTemplate* in; TosaReference::TensorTemplate* out; + TosaReference::TensorTemplate>* paddings; TosaPadQuantInfo* qinfo; }; diff --git a/reference_model/src/ops/ewise_binary.cc b/reference_model/src/ops/ewise_binary.cc index 023158c..6808604 100644 --- a/reference_model/src/ops/ewise_binary.cc +++ b/reference_model/src/ops/ewise_binary.cc @@ -60,26 +60,16 @@ int BinaryNodeBase::checkTensorAttributes() return 1; } - // In some ops, only rank of input and output tensor needs to match - if (nodeType == Op_MUL || nodeType == Op_GREATER || nodeType == Op_EQUAL || nodeType == Op_GREATER_EQUAL) - { - if (inputs[0]->matchRank(*outputs[0])) - { - std::string err = - "Binary operators " + std::string(EnumNamesOp()[nodeType]) + " input and output rank must match"; - printNodeValidationError(err.c_str()); - return 1; - } - } - // Otherwise both rand/type of input and output must match - else if (inputs[0]->matchRankType(*outputs[0])) + if (inputs[0]->matchRank(*outputs[0])) { std::string err = - "Binary operators " + std::string(EnumNamesOp()[nodeType]) + " input and output rank and type must match"; + "Binary operators " + std::string(EnumNamesOp()[nodeType]) + " input and output rank must match"; printNodeValidationError(err.c_str()); return 1; } + ERROR_IF(outputs[0]->getDtype() != OutDtype, "Binary operator type doesn't match"); + a = dynamic_cast*>(inputs[0]); b = dynamic_cast*>(inputs[1]); result = dynamic_cast*>(outputs[0]); @@ -532,6 +522,7 @@ int OpTable::checkTensorAttributes() printNodeValidationError("OpTable: Table must be INT8[256] if input is INT8"); return 1; } + ERROR_IF(outputs[0]->getDtype() != DType_INT8, "OpTable: output tensor must be INT8"); } else if (inputs[0]->getDtype() == DType_INT16) { @@ -540,6 +531,7 @@ int OpTable::checkTensorAttributes() printNodeValidationError("OpTable: Table must be INT16[513] if input is INT16"); return 1; } + ERROR_IF(outputs[0]->getDtype() != DType_INT32, "OpTable: output tensor must be INT32"); } in = dynamic_cast*>(inputs[0]); diff --git a/reference_model/src/ops/tensor_ops.cc b/reference_model/src/ops/tensor_ops.cc index 118d048..be4e4aa 100644 --- a/reference_model/src/ops/tensor_ops.cc +++ b/reference_model/src/ops/tensor_ops.cc @@ -115,7 +115,7 @@ OpArgMax::OpArgMax(SubgraphTraverser* sgt_, : GraphNode(sgt_, Op_ARGMAX, id_) { setRequiredOperands(1, 1); - setRequiredRank(0, 6); + setRequiredRank(1, 4); INIT_ATTRIBUTE(Axis); } @@ -133,14 +133,60 @@ int OpArgMax::checkTensorAttributes() if (validateRequiredOperands()) return 1; - if (validateRequiredRank(inputs[0]) || validateRequiredRank(outputs[0])) + if (validateRequiredRank(inputs[0])) + { + return 1; + } + + int32_t output_rank = inputs[0]->getRank() - 1; + if (output_rank != outputs[0]->getRank()) { + printNodeValidationError("OpArgMax: Output rank needs to be rank(input) - 1"); + return 1; + } + + if (outputs[0]->getDtype() != DType_INT32) + { + printNodeValidationError("OpArgMax: Output data type not supported for this configuration of operator"); return 1; } input = dynamic_cast*>(inputs[0]); output = dynamic_cast*>(outputs[0]); + if (attribute->axis() < 0 || attribute->axis() >= input->getRank()) + { + printNodeValidationError("OpArgMax: Axis needs to be within [0, rank(input)]"); + return 1; + } + + bool shape_check = true; + for (int32_t i = 0; i < input->getRank(); i++) + { + if (i < attribute->axis()) + { + if (input->getShape()[i] != output->getShape()[i]) + { + shape_check = false; + break; + } + } + else if (i > attribute->axis()) + { + if (input->getShape()[i] != output->getShape()[i - 1]) + { + shape_check = false; + break; + } + } + // No need to check i == axis + } + if (!shape_check) + { + printNodeValidationError("OpArgMax: Mismatch between output shape provided and expected output shape"); + return 1; + } + return 0; } @@ -411,6 +457,9 @@ int OpConv2d::checkTensorAttributes() printNodeValidationError("OpConv2d: bias tensor must be rank 1"); } + ERROR_IF(outputs[0]->getDtype() != AccDtype, + "OpFullyConnected: Output data type not supported for this configuration of operator"); + input = dynamic_cast*>(inputs[0]); weight = dynamic_cast*>(inputs[1]); bias = dynamic_cast*>(inputs[2]); @@ -434,6 +483,18 @@ int OpConv2d::checkTensorAttributes() return 1; } + if (this->qinfo) + { + if (InDtype != DType_INT8) + { + ERROR_IF(this->qinfo->input_zp() != 0, "OpConv2d: zeropoint only for int8_t"); + } + if (WeightDtype != DType_INT8) + { + ERROR_IF(this->qinfo->weight_zp() != 0, "OpConv2d: zeropoint only for int8_t"); + } + } + return 0; } @@ -603,6 +664,9 @@ int OpConv3d::checkTensorAttributes() printNodeValidationError("OpConv3d: bias tensor must be rank 1"); } + ERROR_IF(outputs[0]->getDtype() != AccDtype, + "OpFullyConnected: Output data type not supported for this configuration of operator"); + input = dynamic_cast*>(inputs[0]); weight = dynamic_cast*>(inputs[1]); bias = dynamic_cast*>(inputs[2]); @@ -626,6 +690,18 @@ int OpConv3d::checkTensorAttributes() return 1; } + if (this->qinfo) + { + if (InDtype != DType_INT8) + { + ERROR_IF(this->qinfo->input_zp() != 0, "OpConv3d: zeropoint only for int8_t"); + } + if (WeightDtype != DType_INT8) + { + ERROR_IF(this->qinfo->weight_zp() != 0, "OpConv3d: zeropoint only for int8_t"); + } + } + return 0; } @@ -798,6 +874,9 @@ int OpDepthwiseConv2d::checkTensorAttributes() printNodeValidationError("OpDepthwiseConv2d: bias tensor must be rank 1"); } + ERROR_IF(outputs[0]->getDtype() != AccDtype, + "OpFullyConnected: Output data type not supported for this configuration of operator"); + input = dynamic_cast*>(inputs[0]); weight = dynamic_cast*>(inputs[1]); bias = dynamic_cast*>(inputs[2]); @@ -821,6 +900,18 @@ int OpDepthwiseConv2d::checkTensorAttributes() return 1; } + if (this->qinfo) + { + if (InDtype != DType_INT8) + { + ERROR_IF(this->qinfo->input_zp() != 0, "OpDepthwiseConv2d: zeropoint only for int8_t"); + } + if (WeightDtype != DType_INT8) + { + ERROR_IF(this->qinfo->weight_zp() != 0, "OpDepthwiseConv2d: zeropoint only for int8_t"); + } + } + return 0; } @@ -987,8 +1078,23 @@ int OpFullyConnected::checkTensorAttributes() return 1; } + ERROR_IF(outputs[0]->getDtype() != AccDtype, + "OpFullyConnected: Output data type not supported for this configuration of operator"); + output = dynamic_cast*>(outputs[0]); + if (this->qinfo) + { + if (InDtype != DType_INT8) + { + ERROR_IF(this->qinfo->input_zp() != 0, "OpFullyConnected: zeropoint only for int8_t"); + } + if (WeightDtype != DType_INT8) + { + ERROR_IF(this->qinfo->weight_zp() != 0, "OpFullyConnected: zeropoint only for int8_t"); + } + } + return 0; } @@ -1059,6 +1165,9 @@ int OpMatMul::checkTensorAttributes() return 1; } + ERROR_IF(outputs[0]->getDtype() != AccDtype, + "OpFullyConnected: Output data type not supported for this configuration of operator"); + a = dynamic_cast*>(inputs[0]); b = dynamic_cast*>(inputs[1]); output = dynamic_cast*>(outputs[0]); @@ -1101,6 +1210,12 @@ int OpMatMul::checkTensorAttributes() } W = b->getShape()[2]; + if (Dtype != DType_INT8) + { + ERROR_IF(this->qinfo->a_zp() != 0, "OpMatMul: zeropoint only for int8_t"); + ERROR_IF(this->qinfo->b_zp() != 0, "OpMatMul: zeropoint only for int8_t"); + } + return 0; } @@ -1291,11 +1406,11 @@ int OpMaxPool2d::eval() return GraphNode::eval(); } -template -OpTransposeConv2d::OpTransposeConv2d(SubgraphTraverser* sgt_, - TosaAttributeBase* attribute_, - TosaQuantInfoBase* qinfo_, - uint64_t id_) +template +OpTransposeConv2d::OpTransposeConv2d(SubgraphTraverser* sgt_, + TosaAttributeBase* attribute_, + TosaQuantInfoBase* qinfo_, + uint64_t id_) : GraphNode(sgt_, Op_TRANSPOSE_CONV2D, id_) { setRequiredOperands(3, 1); @@ -1305,8 +1420,8 @@ OpTransposeConv2d::OpTransposeConv2d(SubgraphTraverser* sgt_, INIT_QINFO(Conv); } -template -OpTransposeConv2d::~OpTransposeConv2d() +template +OpTransposeConv2d::~OpTransposeConv2d() { if (attribute) delete attribute; @@ -1314,8 +1429,8 @@ OpTransposeConv2d::~OpTransposeConv2d() delete qinfo; } -template -int OpTransposeConv2d::checkTensorAttributes() +template +int OpTransposeConv2d::checkTensorAttributes() { if (validateRequiredOperands()) return 1; @@ -1325,6 +1440,9 @@ int OpTransposeConv2d::checkTensorAttributes() return 1; } + ERROR_IF(outputs[0]->getDtype() != AccDtype, + "OpFullyConnected: Output data type not supported for this configuration of operator"); + input = dynamic_cast*>(inputs[0]); weight = dynamic_cast*>(inputs[1]); bias = dynamic_cast*>(inputs[2]); @@ -1363,11 +1481,23 @@ int OpTransposeConv2d::checkTensorAttributes() } } + if (this->qinfo) + { + if (InDtype != DType_INT8) + { + ERROR_IF(this->qinfo->input_zp() != 0, "OpTransposeConv2d: zeropoint only for int8_t"); + } + if (WeightDtype != DType_INT8) + { + ERROR_IF(this->qinfo->weight_zp() != 0, "OpTransposeConv2d: zeropoint only for int8_t"); + } + } + return 0; } -template -int OpTransposeConv2d::eval() +template +int OpTransposeConv2d::eval() { int in_batch = this->input->getShape()[0]; int in_height = this->input->getShape()[1]; diff --git a/reference_model/src/ops/type_conversion.cc b/reference_model/src/ops/type_conversion.cc index 657eebf..e46ab38 100644 --- a/reference_model/src/ops/type_conversion.cc +++ b/reference_model/src/ops/type_conversion.cc @@ -30,7 +30,7 @@ OpRescale::OpRescale(SubgraphTraverser* sgt_, : GraphNode(sgt_, Op_RESCALE, id_) { setRequiredOperands(1, 1); - setRequiredRank(0, 6); + setRequiredRank(0, 4); INIT_ATTRIBUTE(Rescale); } @@ -64,6 +64,30 @@ int OpRescale::checkTensorAttributes() ASSERT_MEM(in && out); + if ((InDtype != DType_INT8) && (InDtype != DType_UINT8) && (attribute->input_zp() != 0)) + { + printNodeValidationError("OpRescale: Input DType not INT8/UINT8 and zero point not 0"); + return 1; + } + + if ((OutDtype != DType_INT8) && (OutDtype != DType_UINT8) && (attribute->output_zp() != 0)) + { + printNodeValidationError("OpRescale: Output DType not INT8/UINT8 and zero point not 0"); + return 1; + } + + if (attribute->scale32() && (InDtype == DType_INT48)) + { + printNodeValidationError("OpRescale: Scale set to true but input type is INT48"); + return 1; + } + + if ((!attribute->scale32()) && attribute->double_round()) + { + printNodeValidationError("OpRescale: Scale set to false but double round set to true"); + return 1; + } + return 0; } diff --git a/reference_model/src/subgraph_traverser.cc b/reference_model/src/subgraph_traverser.cc index 82de69c..36e0a63 100644 --- a/reference_model/src/subgraph_traverser.cc +++ b/reference_model/src/subgraph_traverser.cc @@ -14,7 +14,6 @@ // limitations under the License. #include "subgraph_traverser.h" -#include #ifndef SUBGRAPH_ERROR_IF #define SUBGRAPH_ERROR_IF(COND, fmt, ...) \ @@ -119,9 +118,6 @@ int SubgraphTraverser::initializeGraph() { int idx = 0; - // tensor name set which contains all the name used by operator - std::unordered_set used_tensor_name_set; - for (auto op : block->GetOperators()) { // translated TosaSerializationOperator to GraphNode @@ -264,6 +260,63 @@ int SubgraphTraverser::initializeGraph() idx++; } + for (auto ts : block->GetTensors()) + { + DEBUG_INFO(GT, "Creating tensor %s", ts->GetName().c_str()); + TosaReference::Tensor* tensor = + TensorFactory::newTensor(ts->GetName(), ts->GetDtype(), ts->GetShape(), ts->GetShape().size()); + + SUBGRAPH_ERROR_IF(!tensor, "SubgraphTraverser::initializeGraph(): Unsupported tensor name=%s, type=%s, rank=%d", + ts->GetName().c_str(), EnumNamesDType()[ts->GetDtype()], (int)ts->GetShape().size()); + + // update this->tensors + addTensor(tensor); + } + + DEBUG_INFO(GT, "Enumerating block %s graph inputs", block->GetName().c_str()); + for (auto& input_name : block->GetInputs()) + { + TosaReference::Tensor* tensor = findTensorByName(input_name); + DEBUG_INFO(GT, "input tensor name=%s", input_name.c_str()); + if (tensor) + { + tensor->setIsSubgraphInput(); + inputTensors.push_back(tensor); + } + else + { + SUBGRAPH_ERROR_IF(true, "SubgraphTraverser::initializeGraph(): Failed to find input tensor by name %s", + input_name.c_str()); + } + } + + DEBUG_INFO(GT, "Enumerating block %s graph outputs", block->GetName().c_str()); + for (auto& output_name : block->GetOutputs()) + { + TosaReference::Tensor* tensor = findTensorByName(output_name); + DEBUG_INFO(GT, "output tensor name=%s\n", output_name.c_str()); + if (tensor) + { + tensor->setIsSubgraphOutput(); + outputTensors.push_back(tensor); + } + else + { + SUBGRAPH_ERROR_IF(true, "SubgraphTraverser::initializeGraph(): Failed to find output tensor by name %s", + output_name.c_str()); + } + } + + if (DEBUG_ENABLED(DEBUG_VERB_HIGH, GT)) + { + dumpNextNodeList(g_func_debug.func_debug_file); + } + + return 0; +} + +int SubgraphTraverser::allocateTensor() +{ for (auto ts : block->GetTensors()) { // Bail out if tensor is used and any of its dimension is invalid. @@ -280,20 +333,18 @@ int SubgraphTraverser::initializeGraph() } } - DEBUG_INFO(GT, "Creating tensor %s", ts->GetName().c_str()); - TosaReference::Tensor* tensor = - TensorFactory::newTensor(ts->GetName(), ts->GetDtype(), ts->GetShape(), ts->GetShape().size()); + TosaReference::Tensor* tensor = findTensorByName(ts->GetName()); + SUBGRAPH_ERROR_IF(!tensor, "SubgraphTraverser::allocateTensor(): can't find tensor %s.", ts->GetName().c_str()); - SUBGRAPH_ERROR_IF(!tensor, "SubgraphTraverser::initializeGraph(): Unsupported tensor name=%s, type=%s, rank=%d", - ts->GetName().c_str(), EnumNamesDType()[ts->GetDtype()], (int)ts->GetShape().size()); + DEBUG_INFO(GT, "Allocating tensor %s", tensor->getName().c_str()); + if (tensor->allocate()) + { + FATAL_ERROR("Failed to allocate tensor %s", tensor->getName().c_str()); + } if (!ts->GetData().empty()) { - if (tensor->allocate()) - { - FATAL_ERROR("Failed to allocate tensor %s", tensor->getName().c_str()); - } - + DEBUG_INFO(GT, "Allocating tensor %s", tensor->getName().c_str()); switch (ts->GetDtype()) { case DType_INT4: @@ -361,48 +412,6 @@ int SubgraphTraverser::initializeGraph() EnumNamesDType()[ts->GetDtype()]); } } - - // update this->tensors - addTensor(tensor); - } - - DEBUG_INFO(GT, "Enumerating block %s graph inputs", block->GetName().c_str()); - for (auto& input_name : block->GetInputs()) - { - TosaReference::Tensor* tensor = findTensorByName(input_name); - DEBUG_INFO(GT, "input tensor name=%s", input_name.c_str()); - if (tensor) - { - tensor->setIsSubgraphInput(); - inputTensors.push_back(tensor); - } - else - { - SUBGRAPH_ERROR_IF(true, "SubgraphTraverser::initializeGraph(): Failed to find input tensor by name %s", - input_name.c_str()); - } - } - - DEBUG_INFO(GT, "Enumerating block %s graph outputs", block->GetName().c_str()); - for (auto& output_name : block->GetOutputs()) - { - TosaReference::Tensor* tensor = findTensorByName(output_name); - DEBUG_INFO(GT, "output tensor name=%s\n", output_name.c_str()); - if (tensor) - { - tensor->setIsSubgraphOutput(); - outputTensors.push_back(tensor); - } - else - { - SUBGRAPH_ERROR_IF(true, "SubgraphTraverser::initializeGraph(): Failed to find output tensor by name %s", - output_name.c_str()); - } - } - - if (DEBUG_ENABLED(DEBUG_VERB_HIGH, GT)) - { - dumpNextNodeList(g_func_debug.func_debug_file); } return 0; diff --git a/reference_model/src/subgraph_traverser.h b/reference_model/src/subgraph_traverser.h index 4be6c1f..d53a4c0 100644 --- a/reference_model/src/subgraph_traverser.h +++ b/reference_model/src/subgraph_traverser.h @@ -21,6 +21,7 @@ #include "ops/op_factory.h" #include "tensor.h" #include "tosa_serialization_handler.h" +#include namespace TosaReference { @@ -54,6 +55,7 @@ public: int linkTensorsAndNodes(); int validateGraph(); + int allocateTensor(); int dumpGraph(FILE* out) const; int dumpNextNodeList(FILE* out) const; @@ -99,6 +101,9 @@ private: // lifetime, although the list itself should only contain unique nodes. std::list nextNodeList; + // tensor name set which contains all the name used by operator + std::unordered_set used_tensor_name_set; + // Maximum number of times to evalute a node before // warning. const int MAX_EVAL_COUNT = 10000; diff --git a/thirdparty/serialization_lib b/thirdparty/serialization_lib index fea5a37..6b078ca 160000 --- a/thirdparty/serialization_lib +++ b/thirdparty/serialization_lib @@ -1 +1 @@ -Subproject commit fea5a3736d18cb44a8bfb080b8e61d283c3e317c +Subproject commit 6b078cac3ff2b33fd6d01c5e849424fbd9b2ac58 -- cgit v1.2.1