diff options
author | Kevin Cheng <kevin.cheng@arm.com> | 2021-06-29 15:32:19 -0700 |
---|---|---|
committer | Kevin Cheng <kevin.cheng@arm.com> | 2021-08-20 18:07:06 +0100 |
commit | acb550f4410ae861e53cae27a9feb4b11d45769f (patch) | |
tree | ae2f4ec558c2cdf1afa020b80a09d7ab4be5ef6d /reference_model/src/subgraph_traverser.cc | |
parent | 68e7aee65bda5ac03fa7def753b7dc7462554793 (diff) | |
download | reference_model-acb550f4410ae861e53cae27a9feb4b11d45769f.tar.gz |
Replace node level check ASSERT_MSG_NODE()/FATAL_ERROR_NODE() with REQUIRE() or ERROR_IF()
- Adding return code enum class: {VALID, UNPREDICTABLE, ERROR}
- Runtime errors (e.g. memory allocation failure) will abort immediately, or will return one of the three return codes
Part of the codes are re-written to pass REQUIRE() to the top-level (e.g. apply_scale_32/16())
- Update setExpectedFailure() to setExpectedReturnCode() on test generation script
- Update test regression script to interface with reference model change
Signed-off-by: Kevin Cheng <kevin.cheng@arm.com>
Change-Id: Ia063c936bcb2a54d6e379a5bb6801aa72d1186f1
Diffstat (limited to 'reference_model/src/subgraph_traverser.cc')
-rw-r--r-- | reference_model/src/subgraph_traverser.cc | 26 |
1 files changed, 19 insertions, 7 deletions
diff --git a/reference_model/src/subgraph_traverser.cc b/reference_model/src/subgraph_traverser.cc index bdf6fbc..ef7bae6 100644 --- a/reference_model/src/subgraph_traverser.cc +++ b/reference_model/src/subgraph_traverser.cc @@ -21,6 +21,8 @@ using namespace tosa; SubgraphTraverser::SubgraphTraverser(TosaSerializationBasicBlock* _block, TosaSerializationHandler* _tsh) { + graph_status = GraphStatus::TOSA_VALID; + block = _block; tsh = _tsh; @@ -166,7 +168,7 @@ int SubgraphTraverser::initializeGraph() DEBUG_INFO(GT, "Creating operator id_%03u, %8s, %lu input tensors, %lu output tensors", idx, EnumNamesOp()[op->GetOp()], op->GetInputTensorNames().size(), op->GetOutputTensorNames().size()); - GraphNode* node = OpFactory::newOp(tsh, op->GetOp(), op->GetAttribute(), op->GetQInfo(), idx, input_dtype, + GraphNode* node = OpFactory::newOp(this, tsh, op->GetOp(), op->GetAttribute(), op->GetQInfo(), idx, input_dtype, input_rank, output_dtype, output_rank, weight_dtype, weight_rank); if (!node) { @@ -221,16 +223,25 @@ int SubgraphTraverser::initializeGraph() for (auto ts : block->GetTensors()) { + // Bail out if any dimension is invalid. + for (auto& dim : ts->GetShape()) + { + if (dim <= 0) + { + this->setGraphStatus(GraphStatus::TOSA_UNPREDICTABLE); + return 1; + } + } DEBUG_INFO(GT, "Creating tensor %s", ts->GetName().c_str()); TosaReference::Tensor* tensor = TensorFactory::newTensor(ts->GetName(), ts->GetDtype(), ts->GetShape(), ts->GetShape().size()); + if (!ts->GetData().empty()) { if (tensor->allocate()) { - WARNING("Fail to allocate tensor %s", tensor->getName().c_str()); - return 1; + SIMPLE_FATAL_ERROR("Failed to allocate tensor %s", tensor->getName().c_str()); } switch (ts->GetDtype()) @@ -316,7 +327,7 @@ int SubgraphTraverser::initializeGraph() } else { - FATAL_ERROR("loadGraphJson: Fail to find input tensor by name %s", input_name.c_str()); + FATAL_ERROR("loadGraphJson: Failed to find input tensor by name %s", input_name.c_str()); } } @@ -332,7 +343,7 @@ int SubgraphTraverser::initializeGraph() } else { - FATAL_ERROR("loadGraphJson: Fail to find output tensor by name %s", output_name.c_str()); + FATAL_ERROR("loadGraphJson: Failed to find output tensor by name %s", output_name.c_str()); } } @@ -395,13 +406,14 @@ int SubgraphTraverser::evaluateNextNode() if (!tensor->is_allocated()) if (tensor->allocate()) { - FATAL_ERROR("Fail to allocate Eigen tensor %s", tensor->getName().c_str()); + FATAL_ERROR("Failed to allocate Eigen tensor %s", tensor->getName().c_str()); } } if (currNode->eval()) { - FATAL_ERROR("Error evaluating node: %lu\n", currNode->getID()); + WARNING("Failed to evaluate node: %lu", currNode->getID()); + return 1; } // free input tensor if all of its consumers have all of their outputs ready and it's not block's output |