diff options
author | Kevin Cheng <kevin.cheng@arm.com> | 2021-10-14 17:09:57 -0700 |
---|---|---|
committer | Kevin Cheng <kevin.cheng@arm.com> | 2021-10-18 18:50:08 +0000 |
commit | cc61be36c3b0f5cd1ea719e129a54fd48a6ee9a2 (patch) | |
tree | 2d664f87e3fdd75de8c6794f6f6c8d6364ece6bb /reference_model/src/subgraph_traverser.cc | |
parent | e807aae606a78d923a2565052f7c2179e3050650 (diff) | |
download | reference_model-cc61be36c3b0f5cd1ea719e129a54fd48a6ee9a2.tar.gz |
More ERROR_IF supports
- Also delay tensor allocation after operator being validated
ERROR_IF can be caught first before 0 or negative dimension set the graph_status to UNPREDICTABLE
- Rescale, Argmax, FullyConnected, Matmul, Pad, Reshape, Slice, Transpose, Clamp, Concat, Equal, Greater, GreaterEqual, Table
Signed-off-by: Kevin Cheng <kevin.cheng@arm.com>
Change-Id: I4e1b3e5794fe195ce1a37e28443ae584645a3b91
Diffstat (limited to 'reference_model/src/subgraph_traverser.cc')
-rw-r--r-- | reference_model/src/subgraph_traverser.cc | 121 |
1 files changed, 65 insertions, 56 deletions
diff --git a/reference_model/src/subgraph_traverser.cc b/reference_model/src/subgraph_traverser.cc index 82de69c..36e0a63 100644 --- a/reference_model/src/subgraph_traverser.cc +++ b/reference_model/src/subgraph_traverser.cc @@ -14,7 +14,6 @@ // limitations under the License. #include "subgraph_traverser.h" -#include <unordered_set> #ifndef SUBGRAPH_ERROR_IF #define SUBGRAPH_ERROR_IF(COND, fmt, ...) \ @@ -119,9 +118,6 @@ int SubgraphTraverser::initializeGraph() { int idx = 0; - // tensor name set which contains all the name used by operator - std::unordered_set<std::string> used_tensor_name_set; - for (auto op : block->GetOperators()) { // translated TosaSerializationOperator to GraphNode @@ -266,6 +262,63 @@ int SubgraphTraverser::initializeGraph() for (auto ts : block->GetTensors()) { + DEBUG_INFO(GT, "Creating tensor %s", ts->GetName().c_str()); + TosaReference::Tensor* tensor = + TensorFactory::newTensor(ts->GetName(), ts->GetDtype(), ts->GetShape(), ts->GetShape().size()); + + SUBGRAPH_ERROR_IF(!tensor, "SubgraphTraverser::initializeGraph(): Unsupported tensor name=%s, type=%s, rank=%d", + ts->GetName().c_str(), EnumNamesDType()[ts->GetDtype()], (int)ts->GetShape().size()); + + // update this->tensors + addTensor(tensor); + } + + DEBUG_INFO(GT, "Enumerating block %s graph inputs", block->GetName().c_str()); + for (auto& input_name : block->GetInputs()) + { + TosaReference::Tensor* tensor = findTensorByName(input_name); + DEBUG_INFO(GT, "input tensor name=%s", input_name.c_str()); + if (tensor) + { + tensor->setIsSubgraphInput(); + inputTensors.push_back(tensor); + } + else + { + SUBGRAPH_ERROR_IF(true, "SubgraphTraverser::initializeGraph(): Failed to find input tensor by name %s", + input_name.c_str()); + } + } + + DEBUG_INFO(GT, "Enumerating block %s graph outputs", block->GetName().c_str()); + for (auto& output_name : block->GetOutputs()) + { + TosaReference::Tensor* tensor = findTensorByName(output_name); + DEBUG_INFO(GT, "output tensor name=%s\n", output_name.c_str()); + if (tensor) + { + tensor->setIsSubgraphOutput(); + outputTensors.push_back(tensor); + } + else + { + SUBGRAPH_ERROR_IF(true, "SubgraphTraverser::initializeGraph(): Failed to find output tensor by name %s", + output_name.c_str()); + } + } + + if (DEBUG_ENABLED(DEBUG_VERB_HIGH, GT)) + { + dumpNextNodeList(g_func_debug.func_debug_file); + } + + return 0; +} + +int SubgraphTraverser::allocateTensor() +{ + for (auto ts : block->GetTensors()) + { // Bail out if tensor is used and any of its dimension is invalid. auto got = used_tensor_name_set.find(ts->GetName()); if (got != used_tensor_name_set.end()) @@ -280,20 +333,18 @@ int SubgraphTraverser::initializeGraph() } } - DEBUG_INFO(GT, "Creating tensor %s", ts->GetName().c_str()); - TosaReference::Tensor* tensor = - TensorFactory::newTensor(ts->GetName(), ts->GetDtype(), ts->GetShape(), ts->GetShape().size()); + TosaReference::Tensor* tensor = findTensorByName(ts->GetName()); + SUBGRAPH_ERROR_IF(!tensor, "SubgraphTraverser::allocateTensor(): can't find tensor %s.", ts->GetName().c_str()); - SUBGRAPH_ERROR_IF(!tensor, "SubgraphTraverser::initializeGraph(): Unsupported tensor name=%s, type=%s, rank=%d", - ts->GetName().c_str(), EnumNamesDType()[ts->GetDtype()], (int)ts->GetShape().size()); + DEBUG_INFO(GT, "Allocating tensor %s", tensor->getName().c_str()); + if (tensor->allocate()) + { + FATAL_ERROR("Failed to allocate tensor %s", tensor->getName().c_str()); + } if (!ts->GetData().empty()) { - if (tensor->allocate()) - { - FATAL_ERROR("Failed to allocate tensor %s", tensor->getName().c_str()); - } - + DEBUG_INFO(GT, "Allocating tensor %s", tensor->getName().c_str()); switch (ts->GetDtype()) { case DType_INT4: @@ -361,48 +412,6 @@ int SubgraphTraverser::initializeGraph() EnumNamesDType()[ts->GetDtype()]); } } - - // update this->tensors - addTensor(tensor); - } - - DEBUG_INFO(GT, "Enumerating block %s graph inputs", block->GetName().c_str()); - for (auto& input_name : block->GetInputs()) - { - TosaReference::Tensor* tensor = findTensorByName(input_name); - DEBUG_INFO(GT, "input tensor name=%s", input_name.c_str()); - if (tensor) - { - tensor->setIsSubgraphInput(); - inputTensors.push_back(tensor); - } - else - { - SUBGRAPH_ERROR_IF(true, "SubgraphTraverser::initializeGraph(): Failed to find input tensor by name %s", - input_name.c_str()); - } - } - - DEBUG_INFO(GT, "Enumerating block %s graph outputs", block->GetName().c_str()); - for (auto& output_name : block->GetOutputs()) - { - TosaReference::Tensor* tensor = findTensorByName(output_name); - DEBUG_INFO(GT, "output tensor name=%s\n", output_name.c_str()); - if (tensor) - { - tensor->setIsSubgraphOutput(); - outputTensors.push_back(tensor); - } - else - { - SUBGRAPH_ERROR_IF(true, "SubgraphTraverser::initializeGraph(): Failed to find output tensor by name %s", - output_name.c_str()); - } - } - - if (DEBUG_ENABLED(DEBUG_VERB_HIGH, GT)) - { - dumpNextNodeList(g_func_debug.func_debug_file); } return 0; |