From cd79f0e06bf53c2c0fee39ee916bb6d79f177b57 Mon Sep 17 00:00:00 2001 From: Kevin Cheng Date: Thu, 3 Jun 2021 15:00:34 -0700 Subject: Rewrite model frontend to be json-driven. Change-Id: Iac786eff96183938d2fd11cde9313c6e8e1270a5 --- reference_model/src/main.cpp | 258 ++++++++++++++++++++++++++++++------------- 1 file changed, 182 insertions(+), 76 deletions(-) (limited to 'reference_model/src/main.cpp') diff --git a/reference_model/src/main.cpp b/reference_model/src/main.cpp index 240d913..0d6d8a3 100644 --- a/reference_model/src/main.cpp +++ b/reference_model/src/main.cpp @@ -22,16 +22,21 @@ #include #include +#include +#include + using namespace TosaReference; using namespace tosa; +using json = nlohmann::json; // Global instantiation of configuration and debug objects func_config_t g_func_config; func_debug_t g_func_debug; -int readInputTensors(SubgraphTraverser& gt); -int writeFinalTensors(SubgraphTraverser& gt); -int loadGraph(TosaSerializationHandler& tsh); +int initTestDesc(json& test_desc); +int readInputTensors(SubgraphTraverser& gt, json test_desc); +int writeFinalTensors(SubgraphTraverser& gt, json test_desc); +int loadGraph(TosaSerializationHandler& tsh, json test_desc); int main(int argc, const char** argv) { @@ -46,7 +51,15 @@ int main(int argc, const char** argv) return 1; } - if (loadGraph(tsh)) + json test_desc; + + // Initialize test descriptor + if (initTestDesc(test_desc)) + { + SIMPLE_FATAL_ERROR("Unable to load test json"); + } + + if (loadGraph(tsh, test_desc)) { SIMPLE_FATAL_ERROR("Unable to load graph"); } @@ -74,7 +87,7 @@ int main(int argc, const char** argv) goto done; } - if (readInputTensors(main_gt)) + if (readInputTensors(main_gt, test_desc)) { SIMPLE_FATAL_ERROR("Unable to read input tensors"); } @@ -113,7 +126,7 @@ int main(int argc, const char** argv) if (g_func_config.output_tensors) { - if (writeFinalTensors(main_gt)) + if (writeFinalTensors(main_gt, test_desc)) { WARNING("Errors encountered in saving output tensors"); } @@ -127,16 +140,17 @@ done: return 0; } -int loadGraph(TosaSerializationHandler& tsh) +int loadGraph(TosaSerializationHandler& tsh, json test_desc) { char graph_fullname[1024]; - snprintf(graph_fullname, sizeof(graph_fullname), "%s/%s", g_func_config.subgraph_dir, g_func_config.subgraph_file); + snprintf(graph_fullname, sizeof(graph_fullname), "%s/%s", g_func_config.flatbuffer_dir, + test_desc["tosa_file"].get().c_str()); if (strlen(graph_fullname) <= 2) { func_model_print_help(stderr); - SIMPLE_FATAL_ERROR("Missing required argument: Check -Csubgraph_file="); + SIMPLE_FATAL_ERROR("Missing required argument: Check \"tosa_file\" in .json specified by -Ctosa_desc="); } const char JSON_EXT[] = ".json"; @@ -163,131 +177,223 @@ int loadGraph(TosaSerializationHandler& tsh) if (tsh.LoadFileJson(graph_fullname)) { - SIMPLE_FATAL_ERROR("\nError loading JSON graph file: %s\nCheck -Csubgraph_file= and -Csubgraph_dir=", - graph_fullname); + SIMPLE_FATAL_ERROR( + "\nError loading JSON graph file: %s\nCheck -Ctest_desc=, -Ctosa_file= and -Cflatbuffer_dir=", + graph_fullname); } } else { if (tsh.LoadFileTosaFlatbuffer(graph_fullname)) { - SIMPLE_FATAL_ERROR("\nError loading TOSA flatbuffer file: %s\nCheck -Csubgraph_file= and -Csubgraph_dir=", - graph_fullname); + SIMPLE_FATAL_ERROR( + "\nError loading TOSA flatbuffer file: %s\nCheck -Ctest_desc=, -Ctosa_file= and -Cflatbuffer_dir=", + graph_fullname); } } return 0; } -int readInputTensors(SubgraphTraverser& gt) +int readInputTensors(SubgraphTraverser& gt, json test_desc) { int tensorCount = gt.getNumInputTensors(); Tensor* tensor; char filename[1024]; - // assuming filename doesn't have colons(:) - std::map input_tensor_map; - std::string raw_str(g_func_config.input_tensor); - std::string name, npy; - bool last_pair = false; - - std::string::size_type pair_start = 0, pair_end, colons_pos; - do + try { - pair_end = raw_str.find(',', pair_start); - if (pair_end == std::string::npos) - last_pair = true; + if ((tensorCount != (int)test_desc["ifm_name"].size()) || (tensorCount != (int)test_desc["ifm_file"].size())) + { + WARNING("Number of input tensors(%d) doesn't match name(%ld)/file(%ld)in test descriptor.", tensorCount, + test_desc["ifm_name"].size(), test_desc["ifm_file"].size()); + return 1; + } - colons_pos = raw_str.find(':', pair_start); + for (int i = 0; i < tensorCount; i++) + { + tensor = gt.getInputTensorByName(test_desc["ifm_name"][i].get()); + if (!tensor) + { + WARNING("Unable to find input tensor %s", test_desc["ifm_name"][i].get().c_str()); + return 1; + } - name = raw_str.substr(pair_start, colons_pos - pair_start); - npy = raw_str.substr(colons_pos + 1, pair_end - colons_pos - 1); + snprintf(filename, sizeof(filename), "%s/%s", g_func_config.flatbuffer_dir, + test_desc["ifm_file"][i].get().c_str()); - // Empty strings can make it to here - if (name.length() == 0 || npy.length() == 0) - break; + DEBUG_MED(GT, "Loading input tensor %s from filename: %s", tensor->getName().c_str(), filename); - input_tensor_map[name] = npy; + if (tensor->allocate()) + { + WARNING("Fail to allocate tensor %s", tensor->getName().c_str()); + return 1; + } - pair_start = pair_end + 1; // skip colons - } while (!last_pair); + if (tensor->readFromNpyFile(filename)) + { + WARNING("Unable to read input tensor %s from filename: %s", tensor->getName().c_str(), filename); + tensor->dumpTensorParams(g_func_debug.func_debug_file); + return 1; + } - if ((size_t)tensorCount != input_tensor_map.size()) + // Push ready consumers to the next node list + for (auto gn : tensor->getConsumers()) + { + if (gn->hasAllInputsReady() && !gn->getOnNextNodeList()) + { + gt.addToNextNodeList(gn); + } + } + } + } + catch (nlohmann::json::type_error& e) { - WARNING("graph has %lu input placeholders, but %lu initialized", tensorCount, input_tensor_map.size()); + WARNING("Fail accessing test descriptor: %s", e.what()); return 1; } - for (auto& tensor_pair : input_tensor_map) + if (DEBUG_ENABLED(DEBUG_VERB_HIGH, GT)) { - tensor = gt.getInputTensorByName(tensor_pair.first); - if (!tensor) - { - WARNING("Unable to find input tensor %s", tensor_pair.first.c_str()); - return 1; - } + gt.dumpNextNodeList(g_func_debug.func_debug_file); + } - snprintf(filename, sizeof(filename), "%s/%s", g_func_config.input_dir, tensor_pair.second.c_str()); + return 0; +} - DEBUG_MED(GT, "Loading input tensor %s from filename: %s", tensor->getName().c_str(), filename); +int writeFinalTensors(SubgraphTraverser& gt, json test_desc) +{ + int tensorCount = gt.getNumOutputTensors(); + const Tensor* tensor; + char filename[1024]; - if (tensor->allocate()) + try + { + if ((tensorCount != (int)test_desc["ofm_name"].size()) || (tensorCount != (int)test_desc["ofm_file"].size())) { - WARNING("Fail to allocate tensor %s", tensor->getName().c_str()); + WARNING("Number of output tensors(%d) doesn't match name(%ld)/file(%ld) in test descriptor.", tensorCount, + test_desc["ofm_name"].size(), test_desc["ofm_file"].size()); return 1; } - if (tensor->readFromNpyFile(filename)) + for (int i = 0; i < tensorCount; i++) { - WARNING("Unable to read input tensor %s from filename: %s", tensor->getName().c_str(), filename); - tensor->dumpTensorParams(g_func_debug.func_debug_file); - return 1; - } + tensor = gt.getOutputTensorByName(test_desc["ofm_name"][i].get()); + if (!tensor) + { + WARNING("Unable to find output tensor %s", test_desc["ofm_name"][i].get().c_str()); + return 1; + } - // Push ready consumers to the next node list - for (auto gn : tensor->getConsumers()) - { - if (gn->hasAllInputsReady() && !gn->getOnNextNodeList()) + snprintf(filename, sizeof(filename), "%s/%s", g_func_config.flatbuffer_dir, + test_desc["ofm_file"][i].get().c_str()); + + DEBUG_MED(GT, "Writing output tensor[%d] %s to filename: %s", i, tensor->getName().c_str(), filename); + + if (tensor->writeToNpyFile(filename)) { - gt.addToNextNodeList(gn); + WARNING("Unable to write output tensor[%d] %s to filename: %s", i, tensor->getName().c_str(), filename); + return 1; } } } - - if (DEBUG_ENABLED(DEBUG_VERB_HIGH, GT)) + catch (nlohmann::json::type_error& e) { - gt.dumpNextNodeList(g_func_debug.func_debug_file); + WARNING("Fail accessing test descriptor: %s", e.what()); + return 1; } return 0; } -int writeFinalTensors(SubgraphTraverser& gt) +// Read "foo,bar,..." and return std::vector({foo, bar, ...}) +std::vector parseFromString(std::string raw_str) { - int tensorCount = gt.getNumOutputTensors(); - const Tensor* tensor; - char filename[1024]; + bool last_pair = false; + std::string::size_type start = 0, end; + std::string name; - for (int i = 0; i < tensorCount; i++) + std::vector result; + do { - tensor = gt.getOutputTensor(i); - if (!tensor) - { - WARNING("Unable to find output tensor[%d]", i); - return 1; - } + end = raw_str.find(',', start); + if (end == std::string::npos) + last_pair = true; + + name = raw_str.substr(start, end); - snprintf(filename, sizeof(filename), "%s/%s%s.npy", g_func_config.output_dir, - g_func_config.output_tensor_prefix, tensor->getName().c_str()); + result.push_back(name); - DEBUG_MED(GT, "Writing output tensor[%d] %s to filename: %s", i, tensor->getName().c_str(), filename); + start = end + 1; // skip comma + } while (!last_pair); + + return result; +} + +int initTestDesc(json& test_desc) +{ + std::ifstream ifs(g_func_config.test_desc); - if (tensor->writeToNpyFile(filename)) + if (ifs.good()) + { + try + { + test_desc = nlohmann::json::parse(ifs); + } + catch (nlohmann::json::parse_error& e) { - WARNING("Unable to write output tensor[%d] %s to filename: %s", i, tensor->getName().c_str(), filename); + WARNING("Error parsing test descriptor json: %s", e.what()); return 1; } } + // Overwrite g_func_config.flatbuffer_dir with dirname(g_func_config.test_desc) if it's not specified. + std::string flatbuffer_dir_str(g_func_config.flatbuffer_dir); + if (flatbuffer_dir_str.empty()) + { + std::string test_path(g_func_config.test_desc); + std::string test_dir = test_path.substr(0, test_path.find_last_of("/\\")); + strncpy(g_func_config.flatbuffer_dir, test_dir.c_str(), 1024); + } + + // Overwrite test_desc["tosa_file"] if -Ctosa_file= specified. + std::string tosa_file_str(g_func_config.tosa_file); + if (!tosa_file_str.empty()) + { + test_desc["tosa_file"] = tosa_file_str; + } + + // Overwrite test_desc["ifm_name"] if -Cifm_name= specified. + std::string ifm_name_str(g_func_config.ifm_name); + if (!ifm_name_str.empty()) + { + std::vector ifm_name_vec = parseFromString(ifm_name_str); + test_desc["ifm_name"] = ifm_name_vec; + } + + // Overwrite test_desc["ifm_file"] if -Cifm_file= specified. + std::string ifm_file_str(g_func_config.ifm_file); + if (!ifm_file_str.empty()) + { + std::vector ifm_file_vec = parseFromString(ifm_file_str); + test_desc["ifm_file"] = ifm_file_vec; + } + + // Overwrite test_desc["ofm_name"] if -Cofm_name= specified. + std::string ofm_name_str(g_func_config.ofm_name); + if (!ofm_name_str.empty()) + { + std::vector ofm_name_vec = parseFromString(ofm_name_str); + test_desc["ofm_name"] = ofm_name_vec; + } + + // Overwrite test_desc["ofm_file"] if -Cofm_file= specified. + std::string ofm_file_str(g_func_config.ofm_file); + if (!ofm_file_str.empty()) + { + std::vector ofm_file_vec = parseFromString(ofm_file_str); + test_desc["ofm_file"] = ofm_file_vec; + } + return 0; } -- cgit v1.2.1