aboutsummaryrefslogtreecommitdiff
path: root/reference_model/src/subgraph_traverser.cc
diff options
context:
space:
mode:
authorKevin Cheng <kevin.cheng@arm.com>2021-10-14 17:09:57 -0700
committerKevin Cheng <kevin.cheng@arm.com>2021-10-18 18:50:08 +0000
commitcc61be36c3b0f5cd1ea719e129a54fd48a6ee9a2 (patch)
tree2d664f87e3fdd75de8c6794f6f6c8d6364ece6bb /reference_model/src/subgraph_traverser.cc
parente807aae606a78d923a2565052f7c2179e3050650 (diff)
downloadreference_model-cc61be36c3b0f5cd1ea719e129a54fd48a6ee9a2.tar.gz
More ERROR_IF supports
- Also delay tensor allocation after operator being validated ERROR_IF can be caught first before 0 or negative dimension set the graph_status to UNPREDICTABLE - Rescale, Argmax, FullyConnected, Matmul, Pad, Reshape, Slice, Transpose, Clamp, Concat, Equal, Greater, GreaterEqual, Table Signed-off-by: Kevin Cheng <kevin.cheng@arm.com> Change-Id: I4e1b3e5794fe195ce1a37e28443ae584645a3b91
Diffstat (limited to 'reference_model/src/subgraph_traverser.cc')
-rw-r--r--reference_model/src/subgraph_traverser.cc121
1 files changed, 65 insertions, 56 deletions
diff --git a/reference_model/src/subgraph_traverser.cc b/reference_model/src/subgraph_traverser.cc
index 82de69c..36e0a63 100644
--- a/reference_model/src/subgraph_traverser.cc
+++ b/reference_model/src/subgraph_traverser.cc
@@ -14,7 +14,6 @@
// limitations under the License.
#include "subgraph_traverser.h"
-#include <unordered_set>
#ifndef SUBGRAPH_ERROR_IF
#define SUBGRAPH_ERROR_IF(COND, fmt, ...) \
@@ -119,9 +118,6 @@ int SubgraphTraverser::initializeGraph()
{
int idx = 0;
- // tensor name set which contains all the name used by operator
- std::unordered_set<std::string> used_tensor_name_set;
-
for (auto op : block->GetOperators())
{
// translated TosaSerializationOperator to GraphNode
@@ -266,6 +262,63 @@ int SubgraphTraverser::initializeGraph()
for (auto ts : block->GetTensors())
{
+ DEBUG_INFO(GT, "Creating tensor %s", ts->GetName().c_str());
+ TosaReference::Tensor* tensor =
+ TensorFactory::newTensor(ts->GetName(), ts->GetDtype(), ts->GetShape(), ts->GetShape().size());
+
+ SUBGRAPH_ERROR_IF(!tensor, "SubgraphTraverser::initializeGraph(): Unsupported tensor name=%s, type=%s, rank=%d",
+ ts->GetName().c_str(), EnumNamesDType()[ts->GetDtype()], (int)ts->GetShape().size());
+
+ // update this->tensors
+ addTensor(tensor);
+ }
+
+ DEBUG_INFO(GT, "Enumerating block %s graph inputs", block->GetName().c_str());
+ for (auto& input_name : block->GetInputs())
+ {
+ TosaReference::Tensor* tensor = findTensorByName(input_name);
+ DEBUG_INFO(GT, "input tensor name=%s", input_name.c_str());
+ if (tensor)
+ {
+ tensor->setIsSubgraphInput();
+ inputTensors.push_back(tensor);
+ }
+ else
+ {
+ SUBGRAPH_ERROR_IF(true, "SubgraphTraverser::initializeGraph(): Failed to find input tensor by name %s",
+ input_name.c_str());
+ }
+ }
+
+ DEBUG_INFO(GT, "Enumerating block %s graph outputs", block->GetName().c_str());
+ for (auto& output_name : block->GetOutputs())
+ {
+ TosaReference::Tensor* tensor = findTensorByName(output_name);
+ DEBUG_INFO(GT, "output tensor name=%s\n", output_name.c_str());
+ if (tensor)
+ {
+ tensor->setIsSubgraphOutput();
+ outputTensors.push_back(tensor);
+ }
+ else
+ {
+ SUBGRAPH_ERROR_IF(true, "SubgraphTraverser::initializeGraph(): Failed to find output tensor by name %s",
+ output_name.c_str());
+ }
+ }
+
+ if (DEBUG_ENABLED(DEBUG_VERB_HIGH, GT))
+ {
+ dumpNextNodeList(g_func_debug.func_debug_file);
+ }
+
+ return 0;
+}
+
+int SubgraphTraverser::allocateTensor()
+{
+ for (auto ts : block->GetTensors())
+ {
// Bail out if tensor is used and any of its dimension is invalid.
auto got = used_tensor_name_set.find(ts->GetName());
if (got != used_tensor_name_set.end())
@@ -280,20 +333,18 @@ int SubgraphTraverser::initializeGraph()
}
}
- DEBUG_INFO(GT, "Creating tensor %s", ts->GetName().c_str());
- TosaReference::Tensor* tensor =
- TensorFactory::newTensor(ts->GetName(), ts->GetDtype(), ts->GetShape(), ts->GetShape().size());
+ TosaReference::Tensor* tensor = findTensorByName(ts->GetName());
+ SUBGRAPH_ERROR_IF(!tensor, "SubgraphTraverser::allocateTensor(): can't find tensor %s.", ts->GetName().c_str());
- SUBGRAPH_ERROR_IF(!tensor, "SubgraphTraverser::initializeGraph(): Unsupported tensor name=%s, type=%s, rank=%d",
- ts->GetName().c_str(), EnumNamesDType()[ts->GetDtype()], (int)ts->GetShape().size());
+ DEBUG_INFO(GT, "Allocating tensor %s", tensor->getName().c_str());
+ if (tensor->allocate())
+ {
+ FATAL_ERROR("Failed to allocate tensor %s", tensor->getName().c_str());
+ }
if (!ts->GetData().empty())
{
- if (tensor->allocate())
- {
- FATAL_ERROR("Failed to allocate tensor %s", tensor->getName().c_str());
- }
-
+ DEBUG_INFO(GT, "Allocating tensor %s", tensor->getName().c_str());
switch (ts->GetDtype())
{
case DType_INT4:
@@ -361,48 +412,6 @@ int SubgraphTraverser::initializeGraph()
EnumNamesDType()[ts->GetDtype()]);
}
}
-
- // update this->tensors
- addTensor(tensor);
- }
-
- DEBUG_INFO(GT, "Enumerating block %s graph inputs", block->GetName().c_str());
- for (auto& input_name : block->GetInputs())
- {
- TosaReference::Tensor* tensor = findTensorByName(input_name);
- DEBUG_INFO(GT, "input tensor name=%s", input_name.c_str());
- if (tensor)
- {
- tensor->setIsSubgraphInput();
- inputTensors.push_back(tensor);
- }
- else
- {
- SUBGRAPH_ERROR_IF(true, "SubgraphTraverser::initializeGraph(): Failed to find input tensor by name %s",
- input_name.c_str());
- }
- }
-
- DEBUG_INFO(GT, "Enumerating block %s graph outputs", block->GetName().c_str());
- for (auto& output_name : block->GetOutputs())
- {
- TosaReference::Tensor* tensor = findTensorByName(output_name);
- DEBUG_INFO(GT, "output tensor name=%s\n", output_name.c_str());
- if (tensor)
- {
- tensor->setIsSubgraphOutput();
- outputTensors.push_back(tensor);
- }
- else
- {
- SUBGRAPH_ERROR_IF(true, "SubgraphTraverser::initializeGraph(): Failed to find output tensor by name %s",
- output_name.c_str());
- }
- }
-
- if (DEBUG_ENABLED(DEBUG_VERB_HIGH, GT))
- {
- dumpNextNodeList(g_func_debug.func_debug_file);
}
return 0;