aboutsummaryrefslogtreecommitdiff
path: root/reference_model/src/subgraph_traverser.cc
diff options
context:
space:
mode:
Diffstat (limited to 'reference_model/src/subgraph_traverser.cc')
-rw-r--r--reference_model/src/subgraph_traverser.cc26
1 files changed, 19 insertions, 7 deletions
diff --git a/reference_model/src/subgraph_traverser.cc b/reference_model/src/subgraph_traverser.cc
index bdf6fbc..ef7bae6 100644
--- a/reference_model/src/subgraph_traverser.cc
+++ b/reference_model/src/subgraph_traverser.cc
@@ -21,6 +21,8 @@ using namespace tosa;
SubgraphTraverser::SubgraphTraverser(TosaSerializationBasicBlock* _block, TosaSerializationHandler* _tsh)
{
+ graph_status = GraphStatus::TOSA_VALID;
+
block = _block;
tsh = _tsh;
@@ -166,7 +168,7 @@ int SubgraphTraverser::initializeGraph()
DEBUG_INFO(GT, "Creating operator id_%03u, %8s, %lu input tensors, %lu output tensors", idx,
EnumNamesOp()[op->GetOp()], op->GetInputTensorNames().size(), op->GetOutputTensorNames().size());
- GraphNode* node = OpFactory::newOp(tsh, op->GetOp(), op->GetAttribute(), op->GetQInfo(), idx, input_dtype,
+ GraphNode* node = OpFactory::newOp(this, tsh, op->GetOp(), op->GetAttribute(), op->GetQInfo(), idx, input_dtype,
input_rank, output_dtype, output_rank, weight_dtype, weight_rank);
if (!node)
{
@@ -221,16 +223,25 @@ int SubgraphTraverser::initializeGraph()
for (auto ts : block->GetTensors())
{
+ // Bail out if any dimension is invalid.
+ for (auto& dim : ts->GetShape())
+ {
+ if (dim <= 0)
+ {
+ this->setGraphStatus(GraphStatus::TOSA_UNPREDICTABLE);
+ return 1;
+ }
+ }
DEBUG_INFO(GT, "Creating tensor %s", ts->GetName().c_str());
TosaReference::Tensor* tensor =
TensorFactory::newTensor(ts->GetName(), ts->GetDtype(), ts->GetShape(), ts->GetShape().size());
+
if (!ts->GetData().empty())
{
if (tensor->allocate())
{
- WARNING("Fail to allocate tensor %s", tensor->getName().c_str());
- return 1;
+ SIMPLE_FATAL_ERROR("Failed to allocate tensor %s", tensor->getName().c_str());
}
switch (ts->GetDtype())
@@ -316,7 +327,7 @@ int SubgraphTraverser::initializeGraph()
}
else
{
- FATAL_ERROR("loadGraphJson: Fail to find input tensor by name %s", input_name.c_str());
+ FATAL_ERROR("loadGraphJson: Failed to find input tensor by name %s", input_name.c_str());
}
}
@@ -332,7 +343,7 @@ int SubgraphTraverser::initializeGraph()
}
else
{
- FATAL_ERROR("loadGraphJson: Fail to find output tensor by name %s", output_name.c_str());
+ FATAL_ERROR("loadGraphJson: Failed to find output tensor by name %s", output_name.c_str());
}
}
@@ -395,13 +406,14 @@ int SubgraphTraverser::evaluateNextNode()
if (!tensor->is_allocated())
if (tensor->allocate())
{
- FATAL_ERROR("Fail to allocate Eigen tensor %s", tensor->getName().c_str());
+ FATAL_ERROR("Failed to allocate Eigen tensor %s", tensor->getName().c_str());
}
}
if (currNode->eval())
{
- FATAL_ERROR("Error evaluating node: %lu\n", currNode->getID());
+ WARNING("Failed to evaluate node: %lu", currNode->getID());
+ return 1;
}
// free input tensor if all of its consumers have all of their outputs ready and it's not block's output