From acb550f4410ae861e53cae27a9feb4b11d45769f Mon Sep 17 00:00:00 2001 From: Kevin Cheng Date: Tue, 29 Jun 2021 15:32:19 -0700 Subject: Replace node level check ASSERT_MSG_NODE()/FATAL_ERROR_NODE() with REQUIRE() or ERROR_IF() - Adding return code enum class: {VALID, UNPREDICTABLE, ERROR} - Runtime errors (e.g. memory allocation failure) will abort immediately, or will return one of the three return codes Part of the codes are re-written to pass REQUIRE() to the top-level (e.g. apply_scale_32/16()) - Update setExpectedFailure() to setExpectedReturnCode() on test generation script - Update test regression script to interface with reference model change Signed-off-by: Kevin Cheng Change-Id: Ia063c936bcb2a54d6e379a5bb6801aa72d1186f1 --- reference_model/src/main.cpp | 79 +++++++++++++++++++++++++++++--------------- 1 file changed, 53 insertions(+), 26 deletions(-) (limited to 'reference_model/src/main.cpp') diff --git a/reference_model/src/main.cpp b/reference_model/src/main.cpp index 412894c..55a4848 100644 --- a/reference_model/src/main.cpp +++ b/reference_model/src/main.cpp @@ -64,12 +64,12 @@ int main(int argc, const char** argv) SIMPLE_FATAL_ERROR("Unable to load graph"); } - // load json first since it's easier debugging SubgraphTraverser main_gt(tsh.GetMainBlock(), &tsh); if (main_gt.initializeGraph()) { - SIMPLE_FATAL_ERROR("Unable to initialize graph traverser: \"main\""); + WARNING("Unable to initialize main graph traverser."); + goto done; } if (main_gt.linkTensorsAndNodes()) @@ -95,49 +95,76 @@ int main(int argc, const char** argv) if (g_func_config.eval) { + // evaluateAll() returns 1 if graph evaluation is forced to be terminated earlier. if (main_gt.evaluateAll()) { - SIMPLE_FATAL_ERROR("Error evaluating network. Giving up."); + ASSERT_MSG(main_gt.getGraphStatus() != GraphStatus::TOSA_VALID, + "Upon evaluateAll() returning 1, graph can not be VALID."); + } + else + { + ASSERT_MSG(main_gt.getGraphStatus() == GraphStatus::TOSA_VALID || + main_gt.getGraphStatus() == GraphStatus::TOSA_UNPREDICTABLE, + "Upon evaluateAll() returning 0, graph can only be VALID/UNPREDICTABLE."); } - // make sure output tensor is evaluated and show its value - int num_output_tensors = main_gt.getNumOutputTensors(); - bool all_output_valid = true; - for (int i = 0; i < num_output_tensors; i++) + // Only generate output tensor if graph is valid. + if (main_gt.getGraphStatus() == GraphStatus::TOSA_VALID) { - const Tensor* ct = main_gt.getOutputTensor(i); - ASSERT_MEM(ct); - if (!ct->getIsValid()) + // make sure output tensor is evaluated and show its value + int num_output_tensors = main_gt.getNumOutputTensors(); + bool all_output_valid = true; + for (int i = 0; i < num_output_tensors; i++) { - ct->dumpTensorParams(g_func_debug.func_debug_file); - if (DEBUG_ENABLED(DEBUG_VERB_HIGH, GT)) + const Tensor* ct = main_gt.getOutputTensor(i); + ASSERT_MEM(ct); + if (!ct->getIsValid()) { - ct->dumpTensor(g_func_debug.func_debug_file); + ct->dumpTensorParams(g_func_debug.func_debug_file); + if (DEBUG_ENABLED(DEBUG_VERB_HIGH, GT)) + { + ct->dumpTensor(g_func_debug.func_debug_file); + } + all_output_valid = false; } - all_output_valid = false; } - } - if (!all_output_valid) - { - main_gt.dumpGraph(g_func_debug.func_debug_file); - SIMPLE_FATAL_ERROR( - "SubgraphTraverser \"main\" error: Output tensors are not all valid at the end of evaluation."); - } + if (!all_output_valid) + { + main_gt.dumpGraph(g_func_debug.func_debug_file); + SIMPLE_FATAL_ERROR( + "SubgraphTraverser \"main\" error: Output tensors are not all valid at the end of evaluation."); + } - if (g_func_config.output_tensors) - { - if (writeFinalTensors(main_gt, test_desc)) + if (g_func_config.output_tensors) { - WARNING("Errors encountered in saving output tensors"); + if (writeFinalTensors(main_gt, test_desc)) + { + WARNING("Errors encountered in saving output tensors"); + } } } } done: + switch (main_gt.getGraphStatus()) + { + case GraphStatus::TOSA_VALID: + // Result is valid. + break; + case GraphStatus::TOSA_UNPREDICTABLE: + fprintf(stderr, "Graph result: UNPREDICTABLE.\n"); + break; + case GraphStatus::TOSA_ERROR: + fprintf(stderr, "Graph result: ERROR.\n"); + break; + default: + fprintf(stderr, "Unknown graph status code=%d.\n", (int)main_gt.getGraphStatus()); + } + func_fini_debug(&g_func_debug); func_model_config_cleanup(); - return 0; + return (int)main_gt.getGraphStatus(); } int loadGraph(TosaSerializationHandler& tsh, json test_desc) -- cgit v1.2.1