diff options
author | Kevin Cheng <kevin.cheng@arm.com> | 2021-09-23 15:25:24 -0700 |
---|---|---|
committer | Eric Kunze <eric.kunze@arm.com> | 2021-09-27 15:12:04 +0000 |
commit | 6097c3db9a74a55d017e5168465c4e10b5793783 (patch) | |
tree | 8b5eee42d63c7e341741e7dc8890b12e1ca89c9f /reference_model/src | |
parent | e86fd34cb3881d5a9c65c1efdbda437314fb83cb (diff) | |
download | reference_model-6097c3db9a74a55d017e5168465c4e10b5793783.tar.gz |
Op that violates rank requirement now runs to the end and return ERROR instead of bailing out.
Signed-off-by: Kevin Cheng <kevin.cheng@arm.com>
Change-Id: I61e163cfdb54057f65dc967394decc3fad53eb89
Diffstat (limited to 'reference_model/src')
-rw-r--r-- | reference_model/src/graph_node.cc | 32 | ||||
-rw-r--r-- | reference_model/src/main.cpp | 6 |
2 files changed, 18 insertions, 20 deletions
diff --git a/reference_model/src/graph_node.cc b/reference_model/src/graph_node.cc index f765700..4629156 100644 --- a/reference_model/src/graph_node.cc +++ b/reference_model/src/graph_node.cc @@ -198,29 +198,23 @@ int GraphNode::validateRequiredRank(const Tensor* t) { if (requiredRankMin >= 0 && requiredRankMax >= 0) { - if (t->checkRequiredRank(requiredRankMin, requiredRankMax)) - { - printNodeValidationError(std::string(EnumNamesOp()[nodeType]) + - " operand has illegal rank=" + std::to_string(t->getRank()) + " not in range [" + - std::to_string(requiredRankMin) + "," + std::to_string(requiredRankMax) + - "]. tensorName: " + t->getName()); - return 1; - } - else - { - return 0; - } + std::string err_message = std::string(EnumNamesOp()[nodeType]) + + " operand has illegal rank=" + std::to_string(t->getRank()) + " not in range [" + + std::to_string(requiredRankMin) + "," + std::to_string(requiredRankMax) + + "]. tensorName: " + t->getName(); + ERROR_IF(t->checkRequiredRank(requiredRankMin, requiredRankMax), "%s", err_message.c_str()); + + return 0; } if (requiredRankMin >= 0) { - if (t->checkRequiredRank(requiredRankMin)) - { - printNodeValidationError(std::string(EnumNamesOp()[nodeType]) + - " operand has illegal rank=" + std::to_string(t->getRank()) + " not equal to " + - std::to_string(requiredRankMin) + ". tensorName: " + t->getName()); - return 1; - } + std::string err_message = std::string(EnumNamesOp()[nodeType]) + + " operand has illegal rank=" + std::to_string(t->getRank()) + " not equal to " + + std::to_string(requiredRankMin) + ". tensorName: " + t->getName(); + ERROR_IF(t->checkRequiredRank(requiredRankMin), "%s", err_message.c_str()); + + return 0; } return 0; diff --git a/reference_model/src/main.cpp b/reference_model/src/main.cpp index 55a4848..cfae010 100644 --- a/reference_model/src/main.cpp +++ b/reference_model/src/main.cpp @@ -79,7 +79,11 @@ int main(int argc, const char** argv) if (main_gt.validateGraph()) { - SIMPLE_FATAL_ERROR("Failed to validate graph"); + WARNING("Failed to validate graph. Evaluation aborted."); + ASSERT_MSG(main_gt.getGraphStatus() == GraphStatus::TOSA_ERROR || + main_gt.getGraphStatus() == GraphStatus::TOSA_UNPREDICTABLE, + "Upon validateGraph() returning 1, graph can only be ERROR/UNPREDICTABLE."); + goto done; } if (g_func_config.validate_only) |