aboutsummaryrefslogtreecommitdiff
path: root/reference_model/src/main.cpp
diff options
context:
space:
mode:
authorTai Ly <tai.ly@arm.com>2023-09-07 20:49:09 +0000
committerEric Kunze <eric.kunze@arm.com>2023-09-15 18:10:01 +0000
commitcf84bc9cccbd5dc2fceae1a81c579e41be3c9a06 (patch)
treeaff6bab02c36c095a62381ac8f68d185bdccbe73 /reference_model/src/main.cpp
parent00f55bf46fe36bebe44e1365becbeb1e0d9e90c9 (diff)
downloadreference_model-cf84bc9cccbd5dc2fceae1a81c579e41be3c9a06.tar.gz
[reference_model] Support StatefulOps and the tests for CallOnceOp
Signed-off-by: Jerry Ge <jerry.ge@arm.com> Change-Id: I03cb878736ccd7e1f5e1f780d7171949a19a9de2
Diffstat (limited to 'reference_model/src/main.cpp')
-rw-r--r--reference_model/src/main.cpp152
1 files changed, 151 insertions, 1 deletions
diff --git a/reference_model/src/main.cpp b/reference_model/src/main.cpp
index cb7f0a2..62b8f6f 100644
--- a/reference_model/src/main.cpp
+++ b/reference_model/src/main.cpp
@@ -33,8 +33,11 @@ using namespace tosa;
using json = nlohmann::json;
int initTestDesc(json& test_desc);
+
int readInputTensors(SubgraphTraverser& gt, json& test_desc);
int writeFinalTensors(SubgraphTraverser& gt, json& test_desc, const std::string& filename_prefix);
+int readVariableTensors(SubgraphTraverser& gt, json test_desc);
+int writeVariableTensors(SubgraphTraverser& gt, json test_desc);
int loadGraph(TosaSerializationHandler& tsh, json& test_desc);
void parse_value(const std::string& text, tosa_level_t& value);
const std::string getResultFilenamePrefix();
@@ -131,6 +134,14 @@ int main(int argc, char** argv)
goto done;
}
+ if (g_func_config.initialize_variable_tensor_from_numpy)
+ {
+ if (readVariableTensors(main_gt, test_desc))
+ {
+ FATAL_ERROR("Unable to read variable tensors");
+ }
+ }
+
// evaluateAll() returns 1 if graph evaluation is forced to be terminated earlier.
if (main_gt.evaluateAll())
{
@@ -177,6 +188,11 @@ int main(int argc, char** argv)
{
WARNING("Errors encountered in saving output tensors");
}
+
+ if (writeVariableTensors(main_gt, test_desc))
+ {
+ WARNING("Errors encountered in writing variable tensors");
+ }
}
}
@@ -312,7 +328,7 @@ int readInputTensors(SubgraphTraverser& gt, json& test_desc)
// Push ready consumers to the next node list
for (auto gn : tensor->getConsumers())
{
- if (gn->hasAllInputsReady() && !gn->getOnNextNodeList())
+ if (gn->hasAllInputsReady() && !gn->getOnNextNodeList() && !gn->getEvaluated())
{
gt.addToNextNodeList(gn);
}
@@ -395,6 +411,124 @@ int writeFinalTensors(SubgraphTraverser& gt, json& test_desc, const std::string&
return 0;
}
+int readVariableTensors(SubgraphTraverser& gt, json test_desc)
+{
+ int tensorCount = gt.getNumVariableTensors();
+ Tensor* tensor;
+ char filename[1024];
+
+ try
+ {
+ if ((tensorCount != (int)test_desc["variable_name"].size()) ||
+ (tensorCount != (int)test_desc["variable_file"].size()))
+ {
+ WARNING("Number of variable tensors(%d) doesn't match name(%ld)/file(%ld)in test descriptor.", tensorCount,
+ test_desc["variable_name"].size(), test_desc["variable_file"].size());
+ return 1;
+ }
+
+ for (int i = 0; i < tensorCount; i++)
+ {
+ tensor = gt.getVariableTensorByName(test_desc["variable_name"][i].get<std::string>());
+ if (!tensor)
+ {
+ WARNING("Unable to find variable tensor %s", test_desc["variable_name"][i].get<std::string>().c_str());
+ return 1;
+ }
+
+ snprintf(filename, sizeof(filename), "%s/%s", g_func_config.flatbuffer_dir.c_str(),
+ test_desc["variable_file"][i].get<std::string>().c_str());
+
+ DEBUG_MED(GT, "Loading variable tensor %s from filename: %s", tensor->getName().c_str(), filename);
+
+ if (!tensor->is_allocated())
+ {
+ WARNING("Tensor %s is not allocated before being initialized", tensor->getName().c_str());
+ return 1;
+ }
+
+ if (tensor->readFromNpyFile(filename))
+ {
+ WARNING("Unable to read variable tensor %s from filename: %s", tensor->getName().c_str(), filename);
+ tensor->dumpTensorParams(g_func_debug.func_debug_file);
+ return 1;
+ }
+
+ // Push ready consumers to the next node list
+ for (auto gn : tensor->getConsumers())
+ {
+ if (gn->hasAllInputsReady() && !gn->getOnNextNodeList() && !gn->getEvaluated())
+ {
+ gt.addToNextNodeList(gn);
+ }
+ }
+ }
+ }
+ catch (nlohmann::json::type_error& e)
+ {
+ WARNING("Fail accessing test descriptor: %s", e.what());
+ return 1;
+ }
+
+ if (DEBUG_ENABLED(DEBUG_VERB_HIGH, GT))
+ {
+ gt.dumpNextNodeList(g_func_debug.func_debug_file);
+ }
+
+ return 0;
+}
+
+int writeVariableTensors(SubgraphTraverser& gt, json test_desc)
+{
+ int tensorCount = gt.getNumVariableTensors();
+ const Tensor* tensor;
+ char filename[1024];
+
+ try
+ {
+ if ((tensorCount != (int)test_desc["variable_name"].size()) ||
+ (tensorCount != (int)test_desc["variable_file"].size()))
+ {
+ WARNING("Number of variable tensors(%d) doesn't match name(%ld)/file(%ld) in test descriptor.", tensorCount,
+ test_desc["variable_name"].size(), test_desc["variable_file"].size());
+ return 1;
+ }
+
+ for (int i = 0; i < tensorCount; i++)
+ {
+ tensor = gt.getVariableTensorByName(test_desc["variable_name"][i].get<std::string>());
+ if (!tensor)
+ {
+ WARNING("Unable to find variable tensor %s", test_desc["variable_name"][i].get<std::string>().c_str());
+ return 1;
+ }
+
+ snprintf(filename, sizeof(filename), "%s/%s", g_func_config.output_dir.c_str(),
+ test_desc["variable_file"][i].get<std::string>().c_str());
+
+ DEBUG_MED(GT, "Writing variable tensor[%d] %s to filename: %s", i, tensor->getName().c_str(), filename);
+ if (!tensor->is_allocated())
+ {
+ WARNING("Tensor %s is no longer allocated", tensor->getName().c_str());
+ return 1;
+ }
+ if (tensor->writeToNpyFile(filename))
+ {
+ WARNING("Unable to write variable tensor[%d] %s to filename: %s", i, tensor->getName().c_str(),
+ filename);
+ return 1;
+ }
+ }
+ }
+ catch (nlohmann::json::type_error& e)
+ {
+ WARNING("Fail accessing test descriptor: %s", e.what());
+ return 1;
+ }
+
+ return 0;
+}
+
// Read "foo,bar,..." and return std::vector({foo, bar, ...})
std::vector<std::string> parseFromString(std::string raw_str)
{
@@ -489,6 +623,22 @@ int initTestDesc(json& test_desc)
test_desc["ofm_file"] = ofm_file_vec;
}
+ // Overwrite test_desc["variable_name"] if --variable_name= specified.
+ std::string variable_name_str(g_func_config.variable_name);
+ if (!variable_name_str.empty())
+ {
+ std::vector<std::string> variable_name_vec = parseFromString(variable_name_str);
+ test_desc["variable_name"] = variable_name_vec;
+ }
+
+ // Overwrite test_desc["variable_file"] if --variable_file= specified.
+ std::string variable_file_str(g_func_config.variable_file);
+ if (!variable_file_str.empty())
+ {
+ std::vector<std::string> variable_file_vec = parseFromString(variable_file_str);
+ test_desc["variable_file"] = variable_file_vec;
+ }
+
return 0;
}