From 6097c3db9a74a55d017e5168465c4e10b5793783 Mon Sep 17 00:00:00 2001 From: Kevin Cheng Date: Thu, 23 Sep 2021 15:25:24 -0700 Subject: Op that violates rank requirement now runs to the end and return ERROR instead of bailing out. Signed-off-by: Kevin Cheng Change-Id: I61e163cfdb54057f65dc967394decc3fad53eb89 --- reference_model/src/graph_node.cc | 32 +++++++++++++------------------- reference_model/src/main.cpp | 6 +++++- 2 files changed, 18 insertions(+), 20 deletions(-) diff --git a/reference_model/src/graph_node.cc b/reference_model/src/graph_node.cc index f765700..4629156 100644 --- a/reference_model/src/graph_node.cc +++ b/reference_model/src/graph_node.cc @@ -198,29 +198,23 @@ int GraphNode::validateRequiredRank(const Tensor* t) { if (requiredRankMin >= 0 && requiredRankMax >= 0) { - if (t->checkRequiredRank(requiredRankMin, requiredRankMax)) - { - printNodeValidationError(std::string(EnumNamesOp()[nodeType]) + - " operand has illegal rank=" + std::to_string(t->getRank()) + " not in range [" + - std::to_string(requiredRankMin) + "," + std::to_string(requiredRankMax) + - "]. tensorName: " + t->getName()); - return 1; - } - else - { - return 0; - } + std::string err_message = std::string(EnumNamesOp()[nodeType]) + + " operand has illegal rank=" + std::to_string(t->getRank()) + " not in range [" + + std::to_string(requiredRankMin) + "," + std::to_string(requiredRankMax) + + "]. tensorName: " + t->getName(); + ERROR_IF(t->checkRequiredRank(requiredRankMin, requiredRankMax), "%s", err_message.c_str()); + + return 0; } if (requiredRankMin >= 0) { - if (t->checkRequiredRank(requiredRankMin)) - { - printNodeValidationError(std::string(EnumNamesOp()[nodeType]) + - " operand has illegal rank=" + std::to_string(t->getRank()) + " not equal to " + - std::to_string(requiredRankMin) + ". tensorName: " + t->getName()); - return 1; - } + std::string err_message = std::string(EnumNamesOp()[nodeType]) + + " operand has illegal rank=" + std::to_string(t->getRank()) + " not equal to " + + std::to_string(requiredRankMin) + ". tensorName: " + t->getName(); + ERROR_IF(t->checkRequiredRank(requiredRankMin), "%s", err_message.c_str()); + + return 0; } return 0; diff --git a/reference_model/src/main.cpp b/reference_model/src/main.cpp index 55a4848..cfae010 100644 --- a/reference_model/src/main.cpp +++ b/reference_model/src/main.cpp @@ -79,7 +79,11 @@ int main(int argc, const char** argv) if (main_gt.validateGraph()) { - SIMPLE_FATAL_ERROR("Failed to validate graph"); + WARNING("Failed to validate graph. Evaluation aborted."); + ASSERT_MSG(main_gt.getGraphStatus() == GraphStatus::TOSA_ERROR || + main_gt.getGraphStatus() == GraphStatus::TOSA_UNPREDICTABLE, + "Upon validateGraph() returning 1, graph can only be ERROR/UNPREDICTABLE."); + goto done; } if (g_func_config.validate_only) -- cgit v1.2.1