diff options
Diffstat (limited to 'reference_model/src/subgraph_traverser.cc')
-rw-r--r-- | reference_model/src/subgraph_traverser.cc | 26 |
1 files changed, 19 insertions, 7 deletions
diff --git a/reference_model/src/subgraph_traverser.cc b/reference_model/src/subgraph_traverser.cc index bdf6fbc..ef7bae6 100644 --- a/reference_model/src/subgraph_traverser.cc +++ b/reference_model/src/subgraph_traverser.cc @@ -21,6 +21,8 @@ using namespace tosa; SubgraphTraverser::SubgraphTraverser(TosaSerializationBasicBlock* _block, TosaSerializationHandler* _tsh) { + graph_status = GraphStatus::TOSA_VALID; + block = _block; tsh = _tsh; @@ -166,7 +168,7 @@ int SubgraphTraverser::initializeGraph() DEBUG_INFO(GT, "Creating operator id_%03u, %8s, %lu input tensors, %lu output tensors", idx, EnumNamesOp()[op->GetOp()], op->GetInputTensorNames().size(), op->GetOutputTensorNames().size()); - GraphNode* node = OpFactory::newOp(tsh, op->GetOp(), op->GetAttribute(), op->GetQInfo(), idx, input_dtype, + GraphNode* node = OpFactory::newOp(this, tsh, op->GetOp(), op->GetAttribute(), op->GetQInfo(), idx, input_dtype, input_rank, output_dtype, output_rank, weight_dtype, weight_rank); if (!node) { @@ -221,16 +223,25 @@ int SubgraphTraverser::initializeGraph() for (auto ts : block->GetTensors()) { + // Bail out if any dimension is invalid. + for (auto& dim : ts->GetShape()) + { + if (dim <= 0) + { + this->setGraphStatus(GraphStatus::TOSA_UNPREDICTABLE); + return 1; + } + } DEBUG_INFO(GT, "Creating tensor %s", ts->GetName().c_str()); TosaReference::Tensor* tensor = TensorFactory::newTensor(ts->GetName(), ts->GetDtype(), ts->GetShape(), ts->GetShape().size()); + if (!ts->GetData().empty()) { if (tensor->allocate()) { - WARNING("Fail to allocate tensor %s", tensor->getName().c_str()); - return 1; + SIMPLE_FATAL_ERROR("Failed to allocate tensor %s", tensor->getName().c_str()); } switch (ts->GetDtype()) @@ -316,7 +327,7 @@ int SubgraphTraverser::initializeGraph() } else { - FATAL_ERROR("loadGraphJson: Fail to find input tensor by name %s", input_name.c_str()); + FATAL_ERROR("loadGraphJson: Failed to find input tensor by name %s", input_name.c_str()); } } @@ -332,7 +343,7 @@ int SubgraphTraverser::initializeGraph() } else { - FATAL_ERROR("loadGraphJson: Fail to find output tensor by name %s", output_name.c_str()); + FATAL_ERROR("loadGraphJson: Failed to find output tensor by name %s", output_name.c_str()); } } @@ -395,13 +406,14 @@ int SubgraphTraverser::evaluateNextNode() if (!tensor->is_allocated()) if (tensor->allocate()) { - FATAL_ERROR("Fail to allocate Eigen tensor %s", tensor->getName().c_str()); + FATAL_ERROR("Failed to allocate Eigen tensor %s", tensor->getName().c_str()); } } if (currNode->eval()) { - FATAL_ERROR("Error evaluating node: %lu\n", currNode->getID()); + WARNING("Failed to evaluate node: %lu", currNode->getID()); + return 1; } // free input tensor if all of its consumers have all of their outputs ready and it's not block's output |