aboutsummaryrefslogtreecommitdiff
path: root/reference_model/src/subgraph_traverser.cc
diff options
context:
space:
mode:
authorKevin Cheng <kevin.cheng@arm.com>2021-09-28 16:14:52 -0700
committerKevin Cheng <kevin.cheng@arm.com>2021-09-30 15:09:13 -0700
commit903763c07f1c8a77783735b05a6a9d722bee1639 (patch)
treeee76f447927e1d07b814391d8f2fbf9a0a7094ab /reference_model/src/subgraph_traverser.cc
parent7fb8fa1826812c305cfcc64e3df256f408fea5a0 (diff)
downloadreference_model-903763c07f1c8a77783735b05a6a9d722bee1639.tar.gz
Add SUBGRAPH_ERROR_IF() to catch graph-level error.
- Also replace SIMPLE_FATAL_ERROR() with FATAL_ERROR() since they're duplicate - Replace FATAL_ERROR()/ASSERT_MSG() with ERROR_IF_SUBGRAPH() if the condition is a graph error FATAL_ERROR()/ASSERT() should only be used by model internal/runtime error like file reading. Signed-off-by: Kevin Cheng <kevin.cheng@arm.com> Change-Id: If1e1e2488054a0ecd800fb0f2ea6487019282500
Diffstat (limited to 'reference_model/src/subgraph_traverser.cc')
-rw-r--r--reference_model/src/subgraph_traverser.cc160
1 files changed, 86 insertions, 74 deletions
diff --git a/reference_model/src/subgraph_traverser.cc b/reference_model/src/subgraph_traverser.cc
index 4dba669..0002b7b 100644
--- a/reference_model/src/subgraph_traverser.cc
+++ b/reference_model/src/subgraph_traverser.cc
@@ -15,6 +15,22 @@
#include "subgraph_traverser.h"
+#ifndef SUBGRAPH_ERROR_IF
+#define SUBGRAPH_ERROR_IF(COND, fmt, ...) \
+ if ((COND)) \
+ { \
+ if (this->getGraphStatus() != GraphStatus::TOSA_UNPREDICTABLE) \
+ { \
+ this->setGraphStatus(GraphStatus::TOSA_ERROR); \
+ } \
+ fprintf(g_func_debug.func_debug_file, COL_FATAL("SUBGRAPH_ERROR_IF() fails AT %s:%d %s(): (%s)\n"), __FILE__, \
+ __LINE__, __func__, #COND); \
+ fprintf(g_func_debug.func_debug_file, COL_FATAL(fmt) "\n", ##__VA_ARGS__); \
+ func_print_backtrace(g_func_debug.func_debug_file); \
+ return 1; \
+ }
+#endif
+
using namespace TosaReference;
using namespace Eigen;
using namespace tosa;
@@ -134,35 +150,43 @@ int SubgraphTraverser::initializeGraph()
if (input_index != -1)
{
- ASSERT_MSG((size_t)input_index < op->GetInputTensorNames().size(),
- "Op=%s, input_index %d must be within [0, num_input - 1]", EnumNamesOp()[op->GetOp()],
- input_index);
+ SUBGRAPH_ERROR_IF(
+ (size_t)input_index >= op->GetInputTensorNames().size(),
+ "SubgraphTraverser::initializeGraph(): Op=%s, input_index %d must be within [0, num_input - 1]",
+ EnumNamesOp()[op->GetOp()], input_index);
std::string input_name = op->GetInputTensorNames()[input_index];
TosaSerializationTensor* input_tensor = block->GetTensorByName(input_name);
- ASSERT_MSG(input_tensor, "SubgraphTraverser: fail to get input tensor %s from TosaSerializationHandler",
- input_name.c_str());
+ SUBGRAPH_ERROR_IF(
+ !input_tensor,
+ "SubgraphTraverser::initializeGraph(): fail to get input tensor %s from TosaSerializationHandler",
+ input_name.c_str());
input_dtype = input_tensor->GetDtype();
input_rank = input_tensor->GetShape().size();
}
if (weight_index != -1)
{
- ASSERT_MSG((size_t)weight_index < op->GetInputTensorNames().size(),
- "Op=%s, weight_index %d must be within [0, num_input - 1]", EnumNamesOp()[op->GetOp()],
- weight_index);
+ SUBGRAPH_ERROR_IF(
+ (size_t)weight_index >= op->GetInputTensorNames().size(),
+ "SubgraphTraverser::initializeGraph(): Op=%s, weight_index %d must be within [0, num_input - 1]",
+ EnumNamesOp()[op->GetOp()], weight_index);
std::string weight_name = op->GetInputTensorNames()[weight_index];
TosaSerializationTensor* weight_tensor = block->GetTensorByName(weight_name);
- ASSERT_MSG(weight_tensor, "SubgraphTraverser: fail to get weight tensor %s from TosaSerializationHandler",
- weight_name.c_str());
+ SUBGRAPH_ERROR_IF(
+ !weight_tensor,
+ "SubgraphTraverser::initializeGraph(): fail to get weight tensor %s from TosaSerializationHandler",
+ weight_name.c_str());
weight_dtype = weight_tensor->GetDtype();
weight_rank = weight_tensor->GetShape().size();
}
std::string output_name = op->GetOutputTensorNames()[0];
TosaSerializationTensor* output_tensor = block->GetTensorByName(output_name);
- ASSERT_MSG(output_tensor, "SubgraphTraverser: fail to get output tensor %s from TosaSerializationHandler",
- output_name.c_str());
+ SUBGRAPH_ERROR_IF(
+ !output_tensor,
+ "SubgraphTraverser::initializeGraph(): fail to get output tensor %s from TosaSerializationHandler",
+ output_name.c_str());
output_dtype = output_tensor->GetDtype();
output_rank = output_tensor->GetShape().size();
@@ -176,28 +200,30 @@ int SubgraphTraverser::initializeGraph()
if (weight_index == -1)
{
fprintf(g_func_debug.func_debug_file,
- "OpFactory could not allocate op %8s input=(%s rank %d) -> (%s rank %d)",
+ "SubgraphTraverser::initializeGraph(): OpFactory could not allocate op %8s input=(%s rank %d) "
+ "-> (%s rank %d)",
EnumNamesOp()[op->GetOp()], EnumNamesDType()[input_dtype], input_rank,
EnumNamesDType()[output_dtype], output_rank);
}
else
{
fprintf(g_func_debug.func_debug_file,
- "OpFactory could not allocate op %8s input=(%s rank %d), weight=(%s rank %d) -> (%s rank %d)",
+ "SubgraphTraverser::initializeGraph(): OpFactory could not allocate op %8s input=(%s rank %d), "
+ "weight=(%s rank %d) -> (%s rank %d)",
EnumNamesOp()[op->GetOp()], EnumNamesDType()[input_dtype], input_rank,
EnumNamesDType()[weight_dtype], weight_rank, EnumNamesDType()[output_dtype], output_rank);
}
for (auto& ts : op->GetInputTensorNames())
{
- fprintf(g_func_debug.func_debug_file, "Input: %s\n", ts.c_str());
+ fprintf(g_func_debug.func_debug_file, "SubgraphTraverser::initializeGraph(): Input: %s\n", ts.c_str());
}
for (auto& ts : op->GetOutputTensorNames())
{
- fprintf(g_func_debug.func_debug_file, "Output: %s\n", ts.c_str());
+ fprintf(g_func_debug.func_debug_file, "SubgraphTraverser::initializeGraph(): Output: %s\n", ts.c_str());
}
- FATAL_ERROR("Unsupported operation type or rank.");
+ SUBGRAPH_ERROR_IF(true, "SubgraphTraverser::initializeGraph(): Unsupported operation type or rank.");
}
for (auto& name : op->GetInputTensorNames())
@@ -238,11 +264,14 @@ int SubgraphTraverser::initializeGraph()
TosaReference::Tensor* tensor =
TensorFactory::newTensor(ts->GetName(), ts->GetDtype(), ts->GetShape(), ts->GetShape().size());
+ SUBGRAPH_ERROR_IF(!tensor, "SubgraphTraverser::initializeGraph(): Unsupported tensor name=%s, type=%s, rank=%d",
+ ts->GetName().c_str(), EnumNamesDType()[ts->GetDtype()], (int)ts->GetShape().size());
+
if (!ts->GetData().empty())
{
if (tensor->allocate())
{
- SIMPLE_FATAL_ERROR("Failed to allocate tensor %s", tensor->getName().c_str());
+ FATAL_ERROR("Failed to allocate tensor %s", tensor->getName().c_str());
}
switch (ts->GetDtype())
@@ -308,7 +337,8 @@ int SubgraphTraverser::initializeGraph()
}
break;
default:
- FATAL_ERROR("Unsupported tensor type %s.", EnumNamesDType()[ts->GetDtype()]);
+ SUBGRAPH_ERROR_IF(true, "SubgraphTraverser::initializeGraph(): Unsupported tensor type %s.",
+ EnumNamesDType()[ts->GetDtype()]);
}
}
@@ -328,7 +358,8 @@ int SubgraphTraverser::initializeGraph()
}
else
{
- FATAL_ERROR("loadGraphJson: Failed to find input tensor by name %s", input_name.c_str());
+ SUBGRAPH_ERROR_IF(true, "SubgraphTraverser::initializeGraph(): Failed to find input tensor by name %s",
+ input_name.c_str());
}
}
@@ -344,7 +375,8 @@ int SubgraphTraverser::initializeGraph()
}
else
{
- FATAL_ERROR("loadGraphJson: Failed to find output tensor by name %s", output_name.c_str());
+ SUBGRAPH_ERROR_IF(true, "SubgraphTraverser::initializeGraph(): Failed to find output tensor by name %s",
+ output_name.c_str());
}
}
@@ -399,7 +431,8 @@ int SubgraphTraverser::evaluateNextNode()
// Sanity check for never-ending loops
if (currNode->getEvalCount() >= MAX_EVAL_COUNT && (currNode->getEvalCount() % MAX_EVAL_COUNT) == 0)
{
- WARNING("Node %lu has been evaluated %d times. Loop suspected.", currNode->getID(), currNode->getEvalCount());
+ WARNING("SubgraphTraverser::evaluateNextNode(): Node %lu has been evaluated %d times. Loop suspected.",
+ currNode->getID(), currNode->getEvalCount());
}
for (auto tensor : currNode->getOutputs())
@@ -407,13 +440,14 @@ int SubgraphTraverser::evaluateNextNode()
if (!tensor->is_allocated())
if (tensor->allocate())
{
- FATAL_ERROR("Failed to allocate Eigen tensor %s", tensor->getName().c_str());
+ FATAL_ERROR("SubgraphTraverser::evaluateNextNode(): Failed to allocate Eigen tensor %s",
+ tensor->getName().c_str());
}
}
if (currNode->eval())
{
- WARNING("Failed to evaluate node: %lu", currNode->getID());
+ WARNING("SubgraphTraverser::evaluateNextNode(): Failed to evaluate node: %lu", currNode->getID());
return 1;
}
@@ -511,7 +545,8 @@ int SubgraphTraverser::addTensor(TosaReference::Tensor* tensor)
{
if (tensor == currTensor || currTensor->getName() == tensor->getName())
{
- FATAL_ERROR("Error: Duplicate tensor or tensor name being added to graph: %s\n", tensor->getName().c_str());
+ FATAL_ERROR("SubgraphTraverser::addTensor(): Duplicate tensor or tensor name being added to graph: %s\n",
+ tensor->getName().c_str());
return 1;
}
}
@@ -537,7 +572,7 @@ int SubgraphTraverser::addNode(GraphNode* newNode)
{
if (currNode == newNode)
{
- FATAL_ERROR("Error: duplicate node being added to graph");
+ FATAL_ERROR("SubgraphTraverser::addTensor(): duplicate node being added to graph");
return 1;
}
}
@@ -557,7 +592,7 @@ TosaReference::Tensor* SubgraphTraverser::findTensorByName(const std::string& na
}
}
- WARNING("Unable to find tensor with name: %s\n", name.c_str());
+ WARNING("SubgraphTraverser::findTensorByName(): Unable to find tensor with name: %s\n", name.c_str());
return nullptr;
}
@@ -572,52 +607,30 @@ int SubgraphTraverser::linkTensorsAndNodes()
for (std::string& name : currNode->getInputNames())
{
TosaReference::Tensor* t = findTensorByName(name);
- if (!t)
- {
- FATAL_ERROR("linkTensorsAndNodes: Cannot find tensor %s in node %lu\n", name.c_str(),
- currNode->getID());
- return 1;
- }
-
- if (currNode->addInputTensor(t))
- {
- FATAL_ERROR("linkTensorsAndNodes: cannot link tensor %s to node %lu\n", name.c_str(),
- currNode->getID());
- return 1;
- }
-
- if (t->addConsumer(currNode))
- {
- FATAL_ERROR("linkTensorsAndNodes: cannot link consumer node %lu to tensor %s\n", currNode->getID(),
- name.c_str());
- return 1;
- }
+ SUBGRAPH_ERROR_IF(!t, "SubgraphTraverser::linkTensorsAndNodes(): Cannot find tensor %s in node %lu\n",
+ name.c_str(), currNode->getID());
+ SUBGRAPH_ERROR_IF(currNode->addInputTensor(t),
+ "SubgraphTraverser::linkTensorsAndNodes(): cannot link tensor %s to node %lu\n",
+ name.c_str(), currNode->getID());
+ SUBGRAPH_ERROR_IF(t->addConsumer(currNode),
+ "SubgraphTraverser::linkTensorsAndNodes(): cannot link consumer node %lu to tensor %s\n",
+ currNode->getID(), name.c_str());
}
// Link outputs/producing nodes
for (std::string& name : currNode->getOutputNames())
{
TosaReference::Tensor* t = findTensorByName(name);
- if (!t)
- {
- FATAL_ERROR("linkTensorsAndNodes: Cannot find tensor %s in node %lu\n", name.c_str(),
- currNode->getID());
- return 1;
- }
-
- if (currNode->addOutputTensor(t))
- {
- FATAL_ERROR("linkTensorsAndNodes: cannot link tensor %s to node %lu\n", name.c_str(),
- currNode->getID());
- return 1;
- }
+ SUBGRAPH_ERROR_IF(!t, "SubgraphTraverser::linkTensorsAndNodes(): Cannot find tensor %s in node %lu\n",
+ name.c_str(), currNode->getID());
+ SUBGRAPH_ERROR_IF(currNode->addOutputTensor(t),
+ "SubgraphTraverser::linkTensorsAndNodes(): cannot link tensor %s to node %lu\n",
+ name.c_str(), currNode->getID());
- if (t->setProducer(currNode))
- {
- FATAL_ERROR("linkTensorsAndNodes: cannot link producer node %lu to tensor tensor %s\n",
- currNode->getID(), name.c_str());
- return 1;
- }
+ SUBGRAPH_ERROR_IF(
+ t->setProducer(currNode),
+ "SubgraphTraverser::linkTensorsAndNodes(): cannot link producer node %lu to tensor tensor %s\n",
+ currNode->getID(), name.c_str());
}
}
@@ -640,7 +653,7 @@ int SubgraphTraverser::validateGraph()
{
if (!currTensor->getProducer() && currTensor->getConsumers().empty())
{
- WARNING("Graph inconsistency: TosaReference::Tensor %s has no producers or consumers\n",
+ WARNING("SubgraphTraverser::validateGraph(): TosaReference::Tensor %s has no producers or consumers\n",
currTensor->getName().c_str());
return 1;
}
@@ -653,7 +666,8 @@ int SubgraphTraverser::validateGraph()
// Float-point disallowed
if (dtype == DType_FLOAT)
{
- WARNING("TOSA Base Inference profile selected: All floating point disabled, but %s tensor %s found\n",
+ WARNING("SubgraphTraverser::validateGraph(): TOSA Base Inference profile selected: All floating point "
+ "disabled, but %s tensor %s found\n",
EnumNamesDType()[dtype], currTensor->getName().c_str());
return 1;
}
@@ -665,17 +679,15 @@ int SubgraphTraverser::validateGraph()
}
else
{
- FATAL_ERROR("TOSA profile not recognized: %d", g_func_config.tosa_profile);
+ FATAL_ERROR("SubgraphTraverser::validateGraph(): TOSA profile not recognized: %d",
+ g_func_config.tosa_profile);
}
}
for (GraphNode* currNode : nodes)
{
- if (currNode->checkTensorAttributes())
- {
- WARNING("TosaReference::Tensor attribute check failed");
- return 1;
- }
+ SUBGRAPH_ERROR_IF(currNode->checkTensorAttributes(),
+ "SubgraphTraverser::validateGraph(): TosaReference::Tensor attribute check failed");
}
if (outputTensors.size() <= 0)