From cf84bc9cccbd5dc2fceae1a81c579e41be3c9a06 Mon Sep 17 00:00:00 2001 From: Tai Ly Date: Thu, 7 Sep 2023 20:49:09 +0000 Subject: [reference_model] Support StatefulOps and the tests for CallOnceOp Signed-off-by: Jerry Ge Change-Id: I03cb878736ccd7e1f5e1f780d7171949a19a9de2 --- reference_model/src/main.cpp | 152 ++++++++++++++++++++++++++++++++++++++++++- 1 file changed, 151 insertions(+), 1 deletion(-) (limited to 'reference_model/src/main.cpp') diff --git a/reference_model/src/main.cpp b/reference_model/src/main.cpp index cb7f0a2..62b8f6f 100644 --- a/reference_model/src/main.cpp +++ b/reference_model/src/main.cpp @@ -33,8 +33,11 @@ using namespace tosa; using json = nlohmann::json; int initTestDesc(json& test_desc); + int readInputTensors(SubgraphTraverser& gt, json& test_desc); int writeFinalTensors(SubgraphTraverser& gt, json& test_desc, const std::string& filename_prefix); +int readVariableTensors(SubgraphTraverser& gt, json test_desc); +int writeVariableTensors(SubgraphTraverser& gt, json test_desc); int loadGraph(TosaSerializationHandler& tsh, json& test_desc); void parse_value(const std::string& text, tosa_level_t& value); const std::string getResultFilenamePrefix(); @@ -131,6 +134,14 @@ int main(int argc, char** argv) goto done; } + if (g_func_config.initialize_variable_tensor_from_numpy) + { + if (readVariableTensors(main_gt, test_desc)) + { + FATAL_ERROR("Unable to read variable tensors"); + } + } + // evaluateAll() returns 1 if graph evaluation is forced to be terminated earlier. if (main_gt.evaluateAll()) { @@ -177,6 +188,11 @@ int main(int argc, char** argv) { WARNING("Errors encountered in saving output tensors"); } + + if (writeVariableTensors(main_gt, test_desc)) + { + WARNING("Errors encountered in writing variable tensors"); + } } } @@ -312,7 +328,7 @@ int readInputTensors(SubgraphTraverser& gt, json& test_desc) // Push ready consumers to the next node list for (auto gn : tensor->getConsumers()) { - if (gn->hasAllInputsReady() && !gn->getOnNextNodeList()) + if (gn->hasAllInputsReady() && !gn->getOnNextNodeList() && !gn->getEvaluated()) { gt.addToNextNodeList(gn); } @@ -395,6 +411,124 @@ int writeFinalTensors(SubgraphTraverser& gt, json& test_desc, const std::string& return 0; } +int readVariableTensors(SubgraphTraverser& gt, json test_desc) +{ + int tensorCount = gt.getNumVariableTensors(); + Tensor* tensor; + char filename[1024]; + + try + { + if ((tensorCount != (int)test_desc["variable_name"].size()) || + (tensorCount != (int)test_desc["variable_file"].size())) + { + WARNING("Number of variable tensors(%d) doesn't match name(%ld)/file(%ld)in test descriptor.", tensorCount, + test_desc["variable_name"].size(), test_desc["variable_file"].size()); + return 1; + } + + for (int i = 0; i < tensorCount; i++) + { + tensor = gt.getVariableTensorByName(test_desc["variable_name"][i].get()); + if (!tensor) + { + WARNING("Unable to find variable tensor %s", test_desc["variable_name"][i].get().c_str()); + return 1; + } + + snprintf(filename, sizeof(filename), "%s/%s", g_func_config.flatbuffer_dir.c_str(), + test_desc["variable_file"][i].get().c_str()); + + DEBUG_MED(GT, "Loading variable tensor %s from filename: %s", tensor->getName().c_str(), filename); + + if (!tensor->is_allocated()) + { + WARNING("Tensor %s is not allocated before being initialized", tensor->getName().c_str()); + return 1; + } + + if (tensor->readFromNpyFile(filename)) + { + WARNING("Unable to read variable tensor %s from filename: %s", tensor->getName().c_str(), filename); + tensor->dumpTensorParams(g_func_debug.func_debug_file); + return 1; + } + + // Push ready consumers to the next node list + for (auto gn : tensor->getConsumers()) + { + if (gn->hasAllInputsReady() && !gn->getOnNextNodeList() && !gn->getEvaluated()) + { + gt.addToNextNodeList(gn); + } + } + } + } + catch (nlohmann::json::type_error& e) + { + WARNING("Fail accessing test descriptor: %s", e.what()); + return 1; + } + + if (DEBUG_ENABLED(DEBUG_VERB_HIGH, GT)) + { + gt.dumpNextNodeList(g_func_debug.func_debug_file); + } + + return 0; +} + +int writeVariableTensors(SubgraphTraverser& gt, json test_desc) +{ + int tensorCount = gt.getNumVariableTensors(); + const Tensor* tensor; + char filename[1024]; + + try + { + if ((tensorCount != (int)test_desc["variable_name"].size()) || + (tensorCount != (int)test_desc["variable_file"].size())) + { + WARNING("Number of variable tensors(%d) doesn't match name(%ld)/file(%ld) in test descriptor.", tensorCount, + test_desc["variable_name"].size(), test_desc["variable_file"].size()); + return 1; + } + + for (int i = 0; i < tensorCount; i++) + { + tensor = gt.getVariableTensorByName(test_desc["variable_name"][i].get()); + if (!tensor) + { + WARNING("Unable to find variable tensor %s", test_desc["variable_name"][i].get().c_str()); + return 1; + } + + snprintf(filename, sizeof(filename), "%s/%s", g_func_config.output_dir.c_str(), + test_desc["variable_file"][i].get().c_str()); + + DEBUG_MED(GT, "Writing variable tensor[%d] %s to filename: %s", i, tensor->getName().c_str(), filename); + if (!tensor->is_allocated()) + { + WARNING("Tensor %s is no longer allocated", tensor->getName().c_str()); + return 1; + } + if (tensor->writeToNpyFile(filename)) + { + WARNING("Unable to write variable tensor[%d] %s to filename: %s", i, tensor->getName().c_str(), + filename); + return 1; + } + } + } + catch (nlohmann::json::type_error& e) + { + WARNING("Fail accessing test descriptor: %s", e.what()); + return 1; + } + + return 0; +} + // Read "foo,bar,..." and return std::vector({foo, bar, ...}) std::vector parseFromString(std::string raw_str) { @@ -489,6 +623,22 @@ int initTestDesc(json& test_desc) test_desc["ofm_file"] = ofm_file_vec; } + // Overwrite test_desc["variable_name"] if --variable_name= specified. + std::string variable_name_str(g_func_config.variable_name); + if (!variable_name_str.empty()) + { + std::vector variable_name_vec = parseFromString(variable_name_str); + test_desc["variable_name"] = variable_name_vec; + } + + // Overwrite test_desc["variable_file"] if --variable_file= specified. + std::string variable_file_str(g_func_config.variable_file); + if (!variable_file_str.empty()) + { + std::vector variable_file_vec = parseFromString(variable_file_str); + test_desc["variable_file"] = variable_file_vec; + } + return 0; } -- cgit v1.2.1