aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorKevin Cheng <kevin.cheng@arm.com>2021-09-23 15:25:24 -0700
committerEric Kunze <eric.kunze@arm.com>2021-09-27 15:12:04 +0000
commit6097c3db9a74a55d017e5168465c4e10b5793783 (patch)
tree8b5eee42d63c7e341741e7dc8890b12e1ca89c9f
parente86fd34cb3881d5a9c65c1efdbda437314fb83cb (diff)
downloadreference_model-6097c3db9a74a55d017e5168465c4e10b5793783.tar.gz
Op that violates rank requirement now runs to the end and return ERROR instead of bailing out.
Signed-off-by: Kevin Cheng <kevin.cheng@arm.com> Change-Id: I61e163cfdb54057f65dc967394decc3fad53eb89
-rw-r--r--reference_model/src/graph_node.cc32
-rw-r--r--reference_model/src/main.cpp6
2 files changed, 18 insertions, 20 deletions
diff --git a/reference_model/src/graph_node.cc b/reference_model/src/graph_node.cc
index f765700..4629156 100644
--- a/reference_model/src/graph_node.cc
+++ b/reference_model/src/graph_node.cc
@@ -198,29 +198,23 @@ int GraphNode::validateRequiredRank(const Tensor* t)
{
if (requiredRankMin >= 0 && requiredRankMax >= 0)
{
- if (t->checkRequiredRank(requiredRankMin, requiredRankMax))
- {
- printNodeValidationError(std::string(EnumNamesOp()[nodeType]) +
- " operand has illegal rank=" + std::to_string(t->getRank()) + " not in range [" +
- std::to_string(requiredRankMin) + "," + std::to_string(requiredRankMax) +
- "]. tensorName: " + t->getName());
- return 1;
- }
- else
- {
- return 0;
- }
+ std::string err_message = std::string(EnumNamesOp()[nodeType]) +
+ " operand has illegal rank=" + std::to_string(t->getRank()) + " not in range [" +
+ std::to_string(requiredRankMin) + "," + std::to_string(requiredRankMax) +
+ "]. tensorName: " + t->getName();
+ ERROR_IF(t->checkRequiredRank(requiredRankMin, requiredRankMax), "%s", err_message.c_str());
+
+ return 0;
}
if (requiredRankMin >= 0)
{
- if (t->checkRequiredRank(requiredRankMin))
- {
- printNodeValidationError(std::string(EnumNamesOp()[nodeType]) +
- " operand has illegal rank=" + std::to_string(t->getRank()) + " not equal to " +
- std::to_string(requiredRankMin) + ". tensorName: " + t->getName());
- return 1;
- }
+ std::string err_message = std::string(EnumNamesOp()[nodeType]) +
+ " operand has illegal rank=" + std::to_string(t->getRank()) + " not equal to " +
+ std::to_string(requiredRankMin) + ". tensorName: " + t->getName();
+ ERROR_IF(t->checkRequiredRank(requiredRankMin), "%s", err_message.c_str());
+
+ return 0;
}
return 0;
diff --git a/reference_model/src/main.cpp b/reference_model/src/main.cpp
index 55a4848..cfae010 100644
--- a/reference_model/src/main.cpp
+++ b/reference_model/src/main.cpp
@@ -79,7 +79,11 @@ int main(int argc, const char** argv)
if (main_gt.validateGraph())
{
- SIMPLE_FATAL_ERROR("Failed to validate graph");
+ WARNING("Failed to validate graph. Evaluation aborted.");
+ ASSERT_MSG(main_gt.getGraphStatus() == GraphStatus::TOSA_ERROR ||
+ main_gt.getGraphStatus() == GraphStatus::TOSA_UNPREDICTABLE,
+ "Upon validateGraph() returning 1, graph can only be ERROR/UNPREDICTABLE.");
+ goto done;
}
if (g_func_config.validate_only)