diff options
author | Kevin Cheng <kevin.cheng@arm.com> | 2021-06-29 15:32:19 -0700 |
---|---|---|
committer | Kevin Cheng <kevin.cheng@arm.com> | 2021-08-20 18:07:06 +0100 |
commit | acb550f4410ae861e53cae27a9feb4b11d45769f (patch) | |
tree | ae2f4ec558c2cdf1afa020b80a09d7ab4be5ef6d /reference_model/src/main.cpp | |
parent | 68e7aee65bda5ac03fa7def753b7dc7462554793 (diff) | |
download | reference_model-acb550f4410ae861e53cae27a9feb4b11d45769f.tar.gz |
Replace node level check ASSERT_MSG_NODE()/FATAL_ERROR_NODE() with REQUIRE() or ERROR_IF()
- Adding return code enum class: {VALID, UNPREDICTABLE, ERROR}
- Runtime errors (e.g. memory allocation failure) will abort immediately, or will return one of the three return codes
Part of the codes are re-written to pass REQUIRE() to the top-level (e.g. apply_scale_32/16())
- Update setExpectedFailure() to setExpectedReturnCode() on test generation script
- Update test regression script to interface with reference model change
Signed-off-by: Kevin Cheng <kevin.cheng@arm.com>
Change-Id: Ia063c936bcb2a54d6e379a5bb6801aa72d1186f1
Diffstat (limited to 'reference_model/src/main.cpp')
-rw-r--r-- | reference_model/src/main.cpp | 79 |
1 files changed, 53 insertions, 26 deletions
diff --git a/reference_model/src/main.cpp b/reference_model/src/main.cpp index 412894c..55a4848 100644 --- a/reference_model/src/main.cpp +++ b/reference_model/src/main.cpp @@ -64,12 +64,12 @@ int main(int argc, const char** argv) SIMPLE_FATAL_ERROR("Unable to load graph"); } - // load json first since it's easier debugging SubgraphTraverser main_gt(tsh.GetMainBlock(), &tsh); if (main_gt.initializeGraph()) { - SIMPLE_FATAL_ERROR("Unable to initialize graph traverser: \"main\""); + WARNING("Unable to initialize main graph traverser."); + goto done; } if (main_gt.linkTensorsAndNodes()) @@ -95,49 +95,76 @@ int main(int argc, const char** argv) if (g_func_config.eval) { + // evaluateAll() returns 1 if graph evaluation is forced to be terminated earlier. if (main_gt.evaluateAll()) { - SIMPLE_FATAL_ERROR("Error evaluating network. Giving up."); + ASSERT_MSG(main_gt.getGraphStatus() != GraphStatus::TOSA_VALID, + "Upon evaluateAll() returning 1, graph can not be VALID."); + } + else + { + ASSERT_MSG(main_gt.getGraphStatus() == GraphStatus::TOSA_VALID || + main_gt.getGraphStatus() == GraphStatus::TOSA_UNPREDICTABLE, + "Upon evaluateAll() returning 0, graph can only be VALID/UNPREDICTABLE."); } - // make sure output tensor is evaluated and show its value - int num_output_tensors = main_gt.getNumOutputTensors(); - bool all_output_valid = true; - for (int i = 0; i < num_output_tensors; i++) + // Only generate output tensor if graph is valid. + if (main_gt.getGraphStatus() == GraphStatus::TOSA_VALID) { - const Tensor* ct = main_gt.getOutputTensor(i); - ASSERT_MEM(ct); - if (!ct->getIsValid()) + // make sure output tensor is evaluated and show its value + int num_output_tensors = main_gt.getNumOutputTensors(); + bool all_output_valid = true; + for (int i = 0; i < num_output_tensors; i++) { - ct->dumpTensorParams(g_func_debug.func_debug_file); - if (DEBUG_ENABLED(DEBUG_VERB_HIGH, GT)) + const Tensor* ct = main_gt.getOutputTensor(i); + ASSERT_MEM(ct); + if (!ct->getIsValid()) { - ct->dumpTensor(g_func_debug.func_debug_file); + ct->dumpTensorParams(g_func_debug.func_debug_file); + if (DEBUG_ENABLED(DEBUG_VERB_HIGH, GT)) + { + ct->dumpTensor(g_func_debug.func_debug_file); + } + all_output_valid = false; } - all_output_valid = false; } - } - if (!all_output_valid) - { - main_gt.dumpGraph(g_func_debug.func_debug_file); - SIMPLE_FATAL_ERROR( - "SubgraphTraverser \"main\" error: Output tensors are not all valid at the end of evaluation."); - } + if (!all_output_valid) + { + main_gt.dumpGraph(g_func_debug.func_debug_file); + SIMPLE_FATAL_ERROR( + "SubgraphTraverser \"main\" error: Output tensors are not all valid at the end of evaluation."); + } - if (g_func_config.output_tensors) - { - if (writeFinalTensors(main_gt, test_desc)) + if (g_func_config.output_tensors) { - WARNING("Errors encountered in saving output tensors"); + if (writeFinalTensors(main_gt, test_desc)) + { + WARNING("Errors encountered in saving output tensors"); + } } } } done: + switch (main_gt.getGraphStatus()) + { + case GraphStatus::TOSA_VALID: + // Result is valid. + break; + case GraphStatus::TOSA_UNPREDICTABLE: + fprintf(stderr, "Graph result: UNPREDICTABLE.\n"); + break; + case GraphStatus::TOSA_ERROR: + fprintf(stderr, "Graph result: ERROR.\n"); + break; + default: + fprintf(stderr, "Unknown graph status code=%d.\n", (int)main_gt.getGraphStatus()); + } + func_fini_debug(&g_func_debug); func_model_config_cleanup(); - return 0; + return (int)main_gt.getGraphStatus(); } int loadGraph(TosaSerializationHandler& tsh, json test_desc) |