aboutsummaryrefslogtreecommitdiff
path: root/reference_model/src/main.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'reference_model/src/main.cpp')
-rw-r--r--reference_model/src/main.cpp79
1 files changed, 53 insertions, 26 deletions
diff --git a/reference_model/src/main.cpp b/reference_model/src/main.cpp
index 412894c..55a4848 100644
--- a/reference_model/src/main.cpp
+++ b/reference_model/src/main.cpp
@@ -64,12 +64,12 @@ int main(int argc, const char** argv)
SIMPLE_FATAL_ERROR("Unable to load graph");
}
- // load json first since it's easier debugging
SubgraphTraverser main_gt(tsh.GetMainBlock(), &tsh);
if (main_gt.initializeGraph())
{
- SIMPLE_FATAL_ERROR("Unable to initialize graph traverser: \"main\"");
+ WARNING("Unable to initialize main graph traverser.");
+ goto done;
}
if (main_gt.linkTensorsAndNodes())
@@ -95,49 +95,76 @@ int main(int argc, const char** argv)
if (g_func_config.eval)
{
+ // evaluateAll() returns 1 if graph evaluation is forced to be terminated earlier.
if (main_gt.evaluateAll())
{
- SIMPLE_FATAL_ERROR("Error evaluating network. Giving up.");
+ ASSERT_MSG(main_gt.getGraphStatus() != GraphStatus::TOSA_VALID,
+ "Upon evaluateAll() returning 1, graph can not be VALID.");
+ }
+ else
+ {
+ ASSERT_MSG(main_gt.getGraphStatus() == GraphStatus::TOSA_VALID ||
+ main_gt.getGraphStatus() == GraphStatus::TOSA_UNPREDICTABLE,
+ "Upon evaluateAll() returning 0, graph can only be VALID/UNPREDICTABLE.");
}
- // make sure output tensor is evaluated and show its value
- int num_output_tensors = main_gt.getNumOutputTensors();
- bool all_output_valid = true;
- for (int i = 0; i < num_output_tensors; i++)
+ // Only generate output tensor if graph is valid.
+ if (main_gt.getGraphStatus() == GraphStatus::TOSA_VALID)
{
- const Tensor* ct = main_gt.getOutputTensor(i);
- ASSERT_MEM(ct);
- if (!ct->getIsValid())
+ // make sure output tensor is evaluated and show its value
+ int num_output_tensors = main_gt.getNumOutputTensors();
+ bool all_output_valid = true;
+ for (int i = 0; i < num_output_tensors; i++)
{
- ct->dumpTensorParams(g_func_debug.func_debug_file);
- if (DEBUG_ENABLED(DEBUG_VERB_HIGH, GT))
+ const Tensor* ct = main_gt.getOutputTensor(i);
+ ASSERT_MEM(ct);
+ if (!ct->getIsValid())
{
- ct->dumpTensor(g_func_debug.func_debug_file);
+ ct->dumpTensorParams(g_func_debug.func_debug_file);
+ if (DEBUG_ENABLED(DEBUG_VERB_HIGH, GT))
+ {
+ ct->dumpTensor(g_func_debug.func_debug_file);
+ }
+ all_output_valid = false;
}
- all_output_valid = false;
}
- }
- if (!all_output_valid)
- {
- main_gt.dumpGraph(g_func_debug.func_debug_file);
- SIMPLE_FATAL_ERROR(
- "SubgraphTraverser \"main\" error: Output tensors are not all valid at the end of evaluation.");
- }
+ if (!all_output_valid)
+ {
+ main_gt.dumpGraph(g_func_debug.func_debug_file);
+ SIMPLE_FATAL_ERROR(
+ "SubgraphTraverser \"main\" error: Output tensors are not all valid at the end of evaluation.");
+ }
- if (g_func_config.output_tensors)
- {
- if (writeFinalTensors(main_gt, test_desc))
+ if (g_func_config.output_tensors)
{
- WARNING("Errors encountered in saving output tensors");
+ if (writeFinalTensors(main_gt, test_desc))
+ {
+ WARNING("Errors encountered in saving output tensors");
+ }
}
}
}
done:
+ switch (main_gt.getGraphStatus())
+ {
+ case GraphStatus::TOSA_VALID:
+ // Result is valid.
+ break;
+ case GraphStatus::TOSA_UNPREDICTABLE:
+ fprintf(stderr, "Graph result: UNPREDICTABLE.\n");
+ break;
+ case GraphStatus::TOSA_ERROR:
+ fprintf(stderr, "Graph result: ERROR.\n");
+ break;
+ default:
+ fprintf(stderr, "Unknown graph status code=%d.\n", (int)main_gt.getGraphStatus());
+ }
+
func_fini_debug(&g_func_debug);
func_model_config_cleanup();
- return 0;
+ return (int)main_gt.getGraphStatus();
}
int loadGraph(TosaSerializationHandler& tsh, json test_desc)