From 307392a4962cc659f7104867a56816a011694a44 Mon Sep 17 00:00:00 2001 From: Tai Ly Date: Fri, 12 May 2023 21:42:19 +0000 Subject: Add abs calculations under precise_mode This adds a second run of reference model under precise_mode when test_desc.json contains a "compliance" dictionary which contains a "mode" entry with value "dot product". In this second run, abs_mode will be set to true, which causes: 1. evaluation will take absolute values of inputs for these operators: conv2d, conv3d, depthwise_conv2d, fully_connected, matmul, transpose_conv2d, fft2d, rfft2d reduce_sum, avg_pool2d 2. output files will have prefix "bounds_" prepended to them Signed-off-by: Tai Ly Change-Id: I7070ecc7ead2d2ea3375c44663d653c6772b88e0 --- reference_model/src/main.cpp | 147 +++++++++++++++++++++++++++---------------- 1 file changed, 92 insertions(+), 55 deletions(-) (limited to 'reference_model/src/main.cpp') diff --git a/reference_model/src/main.cpp b/reference_model/src/main.cpp index aad07cb..55e6d67 100644 --- a/reference_model/src/main.cpp +++ b/reference_model/src/main.cpp @@ -33,10 +33,12 @@ using namespace tosa; using json = nlohmann::json; int initTestDesc(json& test_desc); -int readInputTensors(SubgraphTraverser& gt, json test_desc); -int writeFinalTensors(SubgraphTraverser& gt, json test_desc); -int loadGraph(TosaSerializationHandler& tsh, json test_desc); +int readInputTensors(SubgraphTraverser& gt, json& test_desc); +int writeFinalTensors(SubgraphTraverser& gt, json& test_desc, const std::string& filename_prefix); +int loadGraph(TosaSerializationHandler& tsh, json& test_desc); void parse_value(const std::string& text, tosa_level_t& value); +const std::string getResultFilenamePrefix(); +bool isComplianceModeDotProduct(json& test_desc); int main(int argc, char** argv) { @@ -84,44 +86,51 @@ int main(int argc, char** argv) FATAL_ERROR("Unable to load graph"); } - SubgraphTraverser main_gt(tsh.GetMainRegion()->GetBlockByName("main"), &tsh, nullptr); + GraphStatus status = GraphStatus::TOSA_VALID; - if (main_gt.initializeGraph()) + // max of 2 runs, second run only happens when precise_mode is set, to do an abs_mode run + for (int run = 0; run < 2; run++) { - WARNING("Unable to initialize main graph traverser."); - goto done; - } + SubgraphTraverser main_gt(tsh.GetMainRegion()->GetBlockByName("main"), &tsh, nullptr); - if (main_gt.linkTensorsAndNodes()) - { - WARNING("Failed to link tensors and nodes"); - goto done; - } + if (main_gt.initializeGraph()) + { + WARNING("Unable to initialize main graph traverser."); + goto done; + } - if (main_gt.validateGraph()) - { - WARNING("Failed to validate graph. Evaluation aborted."); - goto done; - } + if (main_gt.linkTensorsAndNodes()) + { + WARNING("Failed to link tensors and nodes"); + goto done; + } - if (main_gt.allocateTensor()) - { - WARNING("Failed to allocate tensor. Evaluation aborted."); - goto done; - } + if (main_gt.validateGraph()) + { + WARNING("Failed to validate graph. Evaluation aborted."); + goto done; + } - if (g_func_config.validate_only) - { - goto done; - } + if (main_gt.allocateTensor()) + { + WARNING("Failed to allocate tensor. Evaluation aborted."); + goto done; + } - if (readInputTensors(main_gt, test_desc)) - { - FATAL_ERROR("Unable to read input tensors"); - } + if (g_func_config.validate_only) + { + goto done; + } - if (g_func_config.eval) - { + if (readInputTensors(main_gt, test_desc)) + { + FATAL_ERROR("Unable to read input tensors"); + } + + if (!g_func_config.eval) + { + goto done; + } // evaluateAll() returns 1 if graph evaluation is forced to be terminated earlier. if (main_gt.evaluateAll()) @@ -165,36 +174,47 @@ int main(int argc, char** argv) if (g_func_config.output_tensors) { - if (writeFinalTensors(main_gt, test_desc)) + if (writeFinalTensors(main_gt, test_desc, getResultFilenamePrefix())) { WARNING("Errors encountered in saving output tensors"); } } } - } -done: - switch (main_gt.getGraphStatus()) - { - case GraphStatus::TOSA_VALID: - // Result is valid. - break; - case GraphStatus::TOSA_UNPREDICTABLE: - fprintf(stderr, "Graph result: UNPREDICTABLE.\n"); - break; - case GraphStatus::TOSA_ERROR: - fprintf(stderr, "Graph result: ERROR.\n"); - break; - default: - fprintf(stderr, "Unknown graph status code=%d.\n", (int)main_gt.getGraphStatus()); + done: + status = main_gt.getGraphStatus(); + switch (status) + { + case GraphStatus::TOSA_VALID: + // Result is valid. + break; + case GraphStatus::TOSA_UNPREDICTABLE: + fprintf(stderr, "Graph result: UNPREDICTABLE.\n"); + break; + case GraphStatus::TOSA_ERROR: + fprintf(stderr, "Graph result: ERROR.\n"); + break; + default: + fprintf(stderr, "Unknown graph status code=%d.\n", (int)main_gt.getGraphStatus()); + } + + if (status == GraphStatus::TOSA_VALID && g_func_config.eval && g_func_config.precise_mode && + isComplianceModeDotProduct(test_desc)) + { + // first run result is valid, in precise mode and eval is true: turn on abs_mode for second run + g_func_config.abs_mode = true; + continue; + } + + // otherwise, do only one run + break; } g_func_debug.fini_debug(); - - return (int)main_gt.getGraphStatus(); + return (int)status; } -int loadGraph(TosaSerializationHandler& tsh, json test_desc) +int loadGraph(TosaSerializationHandler& tsh, json& test_desc) { char graph_fullname[1024]; const std::string error_msg1 = "Check \"tosa_file\" in .json specified by --tosa_desc"; @@ -248,7 +268,7 @@ int loadGraph(TosaSerializationHandler& tsh, json test_desc) return 0; } -int readInputTensors(SubgraphTraverser& gt, json test_desc) +int readInputTensors(SubgraphTraverser& gt, json& test_desc) { int tensorCount = gt.getNumInputTensors(); Tensor* tensor; @@ -314,7 +334,24 @@ int readInputTensors(SubgraphTraverser& gt, json test_desc) return 0; } -int writeFinalTensors(SubgraphTraverser& gt, json test_desc) +const std::string getResultFilenamePrefix() +{ + return g_func_config.abs_mode ? "bounds_" : ""; +} + +// returns true iff test_desc contains a dictionay, "compliance", +// which contains entry "mode" whose value is "dot product" +bool isComplianceModeDotProduct(json& test_desc) +{ + if (test_desc.contains("compliance") && test_desc["compliance"].contains("mode") && + test_desc["compliance"]["mode"] == "dot product") + { + return true; + } + return false; +} + +int writeFinalTensors(SubgraphTraverser& gt, json& test_desc, const std::string& filename_prefix) { int tensorCount = gt.getNumOutputTensors(); const Tensor* tensor; @@ -338,7 +375,7 @@ int writeFinalTensors(SubgraphTraverser& gt, json test_desc) return 1; } - snprintf(filename, sizeof(filename), "%s/%s", g_func_config.output_dir.c_str(), + snprintf(filename, sizeof(filename), "%s/%s%s", g_func_config.output_dir.c_str(), filename_prefix.c_str(), test_desc["ofm_file"][i].get().c_str()); DEBUG_MED(GT, "Writing output tensor[%d] %s to filename: %s", i, tensor->getName().c_str(), filename); -- cgit v1.2.1