aboutsummaryrefslogtreecommitdiff
path: root/reference_model/src/main.cpp
diff options
context:
space:
mode:
authorTai Ly <tai.ly@arm.com>2023-05-12 21:42:19 +0000
committerEric Kunze <eric.kunze@arm.com>2023-05-18 01:18:18 +0000
commit307392a4962cc659f7104867a56816a011694a44 (patch)
tree9957c65666b4be69e9920a4ae7e1925ad0f254d7 /reference_model/src/main.cpp
parent264f7faa59709ffa8117541f5d55c99c5dba967d (diff)
downloadreference_model-307392a4962cc659f7104867a56816a011694a44.tar.gz
Add abs calculations under precise_mode
This adds a second run of reference model under precise_mode when test_desc.json contains a "compliance" dictionary which contains a "mode" entry with value "dot product". In this second run, abs_mode will be set to true, which causes: 1. evaluation will take absolute values of inputs for these operators: conv2d, conv3d, depthwise_conv2d, fully_connected, matmul, transpose_conv2d, fft2d, rfft2d reduce_sum, avg_pool2d 2. output files will have prefix "bounds_" prepended to them Signed-off-by: Tai Ly <tai.ly@arm.com> Change-Id: I7070ecc7ead2d2ea3375c44663d653c6772b88e0
Diffstat (limited to 'reference_model/src/main.cpp')
-rw-r--r--reference_model/src/main.cpp147
1 files changed, 92 insertions, 55 deletions
diff --git a/reference_model/src/main.cpp b/reference_model/src/main.cpp
index aad07cb..55e6d67 100644
--- a/reference_model/src/main.cpp
+++ b/reference_model/src/main.cpp
@@ -33,10 +33,12 @@ using namespace tosa;
using json = nlohmann::json;
int initTestDesc(json& test_desc);
-int readInputTensors(SubgraphTraverser& gt, json test_desc);
-int writeFinalTensors(SubgraphTraverser& gt, json test_desc);
-int loadGraph(TosaSerializationHandler& tsh, json test_desc);
+int readInputTensors(SubgraphTraverser& gt, json& test_desc);
+int writeFinalTensors(SubgraphTraverser& gt, json& test_desc, const std::string& filename_prefix);
+int loadGraph(TosaSerializationHandler& tsh, json& test_desc);
void parse_value(const std::string& text, tosa_level_t& value);
+const std::string getResultFilenamePrefix();
+bool isComplianceModeDotProduct(json& test_desc);
int main(int argc, char** argv)
{
@@ -84,44 +86,51 @@ int main(int argc, char** argv)
FATAL_ERROR("Unable to load graph");
}
- SubgraphTraverser main_gt(tsh.GetMainRegion()->GetBlockByName("main"), &tsh, nullptr);
+ GraphStatus status = GraphStatus::TOSA_VALID;
- if (main_gt.initializeGraph())
+ // max of 2 runs, second run only happens when precise_mode is set, to do an abs_mode run
+ for (int run = 0; run < 2; run++)
{
- WARNING("Unable to initialize main graph traverser.");
- goto done;
- }
+ SubgraphTraverser main_gt(tsh.GetMainRegion()->GetBlockByName("main"), &tsh, nullptr);
- if (main_gt.linkTensorsAndNodes())
- {
- WARNING("Failed to link tensors and nodes");
- goto done;
- }
+ if (main_gt.initializeGraph())
+ {
+ WARNING("Unable to initialize main graph traverser.");
+ goto done;
+ }
- if (main_gt.validateGraph())
- {
- WARNING("Failed to validate graph. Evaluation aborted.");
- goto done;
- }
+ if (main_gt.linkTensorsAndNodes())
+ {
+ WARNING("Failed to link tensors and nodes");
+ goto done;
+ }
- if (main_gt.allocateTensor())
- {
- WARNING("Failed to allocate tensor. Evaluation aborted.");
- goto done;
- }
+ if (main_gt.validateGraph())
+ {
+ WARNING("Failed to validate graph. Evaluation aborted.");
+ goto done;
+ }
- if (g_func_config.validate_only)
- {
- goto done;
- }
+ if (main_gt.allocateTensor())
+ {
+ WARNING("Failed to allocate tensor. Evaluation aborted.");
+ goto done;
+ }
- if (readInputTensors(main_gt, test_desc))
- {
- FATAL_ERROR("Unable to read input tensors");
- }
+ if (g_func_config.validate_only)
+ {
+ goto done;
+ }
- if (g_func_config.eval)
- {
+ if (readInputTensors(main_gt, test_desc))
+ {
+ FATAL_ERROR("Unable to read input tensors");
+ }
+
+ if (!g_func_config.eval)
+ {
+ goto done;
+ }
// evaluateAll() returns 1 if graph evaluation is forced to be terminated earlier.
if (main_gt.evaluateAll())
@@ -165,36 +174,47 @@ int main(int argc, char** argv)
if (g_func_config.output_tensors)
{
- if (writeFinalTensors(main_gt, test_desc))
+ if (writeFinalTensors(main_gt, test_desc, getResultFilenamePrefix()))
{
WARNING("Errors encountered in saving output tensors");
}
}
}
- }
-done:
- switch (main_gt.getGraphStatus())
- {
- case GraphStatus::TOSA_VALID:
- // Result is valid.
- break;
- case GraphStatus::TOSA_UNPREDICTABLE:
- fprintf(stderr, "Graph result: UNPREDICTABLE.\n");
- break;
- case GraphStatus::TOSA_ERROR:
- fprintf(stderr, "Graph result: ERROR.\n");
- break;
- default:
- fprintf(stderr, "Unknown graph status code=%d.\n", (int)main_gt.getGraphStatus());
+ done:
+ status = main_gt.getGraphStatus();
+ switch (status)
+ {
+ case GraphStatus::TOSA_VALID:
+ // Result is valid.
+ break;
+ case GraphStatus::TOSA_UNPREDICTABLE:
+ fprintf(stderr, "Graph result: UNPREDICTABLE.\n");
+ break;
+ case GraphStatus::TOSA_ERROR:
+ fprintf(stderr, "Graph result: ERROR.\n");
+ break;
+ default:
+ fprintf(stderr, "Unknown graph status code=%d.\n", (int)main_gt.getGraphStatus());
+ }
+
+ if (status == GraphStatus::TOSA_VALID && g_func_config.eval && g_func_config.precise_mode &&
+ isComplianceModeDotProduct(test_desc))
+ {
+ // first run result is valid, in precise mode and eval is true: turn on abs_mode for second run
+ g_func_config.abs_mode = true;
+ continue;
+ }
+
+ // otherwise, do only one run
+ break;
}
g_func_debug.fini_debug();
-
- return (int)main_gt.getGraphStatus();
+ return (int)status;
}
-int loadGraph(TosaSerializationHandler& tsh, json test_desc)
+int loadGraph(TosaSerializationHandler& tsh, json& test_desc)
{
char graph_fullname[1024];
const std::string error_msg1 = "Check \"tosa_file\" in .json specified by --tosa_desc";
@@ -248,7 +268,7 @@ int loadGraph(TosaSerializationHandler& tsh, json test_desc)
return 0;
}
-int readInputTensors(SubgraphTraverser& gt, json test_desc)
+int readInputTensors(SubgraphTraverser& gt, json& test_desc)
{
int tensorCount = gt.getNumInputTensors();
Tensor* tensor;
@@ -314,7 +334,24 @@ int readInputTensors(SubgraphTraverser& gt, json test_desc)
return 0;
}
-int writeFinalTensors(SubgraphTraverser& gt, json test_desc)
+const std::string getResultFilenamePrefix()
+{
+ return g_func_config.abs_mode ? "bounds_" : "";
+}
+
+// returns true iff test_desc contains a dictionay, "compliance",
+// which contains entry "mode" whose value is "dot product"
+bool isComplianceModeDotProduct(json& test_desc)
+{
+ if (test_desc.contains("compliance") && test_desc["compliance"].contains("mode") &&
+ test_desc["compliance"]["mode"] == "dot product")
+ {
+ return true;
+ }
+ return false;
+}
+
+int writeFinalTensors(SubgraphTraverser& gt, json& test_desc, const std::string& filename_prefix)
{
int tensorCount = gt.getNumOutputTensors();
const Tensor* tensor;
@@ -338,7 +375,7 @@ int writeFinalTensors(SubgraphTraverser& gt, json test_desc)
return 1;
}
- snprintf(filename, sizeof(filename), "%s/%s", g_func_config.output_dir.c_str(),
+ snprintf(filename, sizeof(filename), "%s/%s%s", g_func_config.output_dir.c_str(), filename_prefix.c_str(),
test_desc["ofm_file"][i].get<std::string>().c_str());
DEBUG_MED(GT, "Writing output tensor[%d] %s to filename: %s", i, tensor->getName().c_str(), filename);