aboutsummaryrefslogtreecommitdiff
path: root/reference_model/src/main.cpp
diff options
context:
space:
mode:
authorKevin Cheng <kevin.cheng@arm.com>2021-06-29 15:32:19 -0700
committerKevin Cheng <kevin.cheng@arm.com>2021-08-20 18:07:06 +0100
commitacb550f4410ae861e53cae27a9feb4b11d45769f (patch)
treeae2f4ec558c2cdf1afa020b80a09d7ab4be5ef6d /reference_model/src/main.cpp
parent68e7aee65bda5ac03fa7def753b7dc7462554793 (diff)
downloadreference_model-acb550f4410ae861e53cae27a9feb4b11d45769f.tar.gz
Replace node level check ASSERT_MSG_NODE()/FATAL_ERROR_NODE() with REQUIRE() or ERROR_IF()
- Adding return code enum class: {VALID, UNPREDICTABLE, ERROR} - Runtime errors (e.g. memory allocation failure) will abort immediately, or will return one of the three return codes Part of the codes are re-written to pass REQUIRE() to the top-level (e.g. apply_scale_32/16()) - Update setExpectedFailure() to setExpectedReturnCode() on test generation script - Update test regression script to interface with reference model change Signed-off-by: Kevin Cheng <kevin.cheng@arm.com> Change-Id: Ia063c936bcb2a54d6e379a5bb6801aa72d1186f1
Diffstat (limited to 'reference_model/src/main.cpp')
-rw-r--r--reference_model/src/main.cpp79
1 files changed, 53 insertions, 26 deletions
diff --git a/reference_model/src/main.cpp b/reference_model/src/main.cpp
index 412894c..55a4848 100644
--- a/reference_model/src/main.cpp
+++ b/reference_model/src/main.cpp
@@ -64,12 +64,12 @@ int main(int argc, const char** argv)
SIMPLE_FATAL_ERROR("Unable to load graph");
}
- // load json first since it's easier debugging
SubgraphTraverser main_gt(tsh.GetMainBlock(), &tsh);
if (main_gt.initializeGraph())
{
- SIMPLE_FATAL_ERROR("Unable to initialize graph traverser: \"main\"");
+ WARNING("Unable to initialize main graph traverser.");
+ goto done;
}
if (main_gt.linkTensorsAndNodes())
@@ -95,49 +95,76 @@ int main(int argc, const char** argv)
if (g_func_config.eval)
{
+ // evaluateAll() returns 1 if graph evaluation is forced to be terminated earlier.
if (main_gt.evaluateAll())
{
- SIMPLE_FATAL_ERROR("Error evaluating network. Giving up.");
+ ASSERT_MSG(main_gt.getGraphStatus() != GraphStatus::TOSA_VALID,
+ "Upon evaluateAll() returning 1, graph can not be VALID.");
+ }
+ else
+ {
+ ASSERT_MSG(main_gt.getGraphStatus() == GraphStatus::TOSA_VALID ||
+ main_gt.getGraphStatus() == GraphStatus::TOSA_UNPREDICTABLE,
+ "Upon evaluateAll() returning 0, graph can only be VALID/UNPREDICTABLE.");
}
- // make sure output tensor is evaluated and show its value
- int num_output_tensors = main_gt.getNumOutputTensors();
- bool all_output_valid = true;
- for (int i = 0; i < num_output_tensors; i++)
+ // Only generate output tensor if graph is valid.
+ if (main_gt.getGraphStatus() == GraphStatus::TOSA_VALID)
{
- const Tensor* ct = main_gt.getOutputTensor(i);
- ASSERT_MEM(ct);
- if (!ct->getIsValid())
+ // make sure output tensor is evaluated and show its value
+ int num_output_tensors = main_gt.getNumOutputTensors();
+ bool all_output_valid = true;
+ for (int i = 0; i < num_output_tensors; i++)
{
- ct->dumpTensorParams(g_func_debug.func_debug_file);
- if (DEBUG_ENABLED(DEBUG_VERB_HIGH, GT))
+ const Tensor* ct = main_gt.getOutputTensor(i);
+ ASSERT_MEM(ct);
+ if (!ct->getIsValid())
{
- ct->dumpTensor(g_func_debug.func_debug_file);
+ ct->dumpTensorParams(g_func_debug.func_debug_file);
+ if (DEBUG_ENABLED(DEBUG_VERB_HIGH, GT))
+ {
+ ct->dumpTensor(g_func_debug.func_debug_file);
+ }
+ all_output_valid = false;
}
- all_output_valid = false;
}
- }
- if (!all_output_valid)
- {
- main_gt.dumpGraph(g_func_debug.func_debug_file);
- SIMPLE_FATAL_ERROR(
- "SubgraphTraverser \"main\" error: Output tensors are not all valid at the end of evaluation.");
- }
+ if (!all_output_valid)
+ {
+ main_gt.dumpGraph(g_func_debug.func_debug_file);
+ SIMPLE_FATAL_ERROR(
+ "SubgraphTraverser \"main\" error: Output tensors are not all valid at the end of evaluation.");
+ }
- if (g_func_config.output_tensors)
- {
- if (writeFinalTensors(main_gt, test_desc))
+ if (g_func_config.output_tensors)
{
- WARNING("Errors encountered in saving output tensors");
+ if (writeFinalTensors(main_gt, test_desc))
+ {
+ WARNING("Errors encountered in saving output tensors");
+ }
}
}
}
done:
+ switch (main_gt.getGraphStatus())
+ {
+ case GraphStatus::TOSA_VALID:
+ // Result is valid.
+ break;
+ case GraphStatus::TOSA_UNPREDICTABLE:
+ fprintf(stderr, "Graph result: UNPREDICTABLE.\n");
+ break;
+ case GraphStatus::TOSA_ERROR:
+ fprintf(stderr, "Graph result: ERROR.\n");
+ break;
+ default:
+ fprintf(stderr, "Unknown graph status code=%d.\n", (int)main_gt.getGraphStatus());
+ }
+
func_fini_debug(&g_func_debug);
func_model_config_cleanup();
- return 0;
+ return (int)main_gt.getGraphStatus();
}
int loadGraph(TosaSerializationHandler& tsh, json test_desc)