aboutsummaryrefslogtreecommitdiff
path: root/reference_model/src/main.cpp
diff options
context:
space:
mode:
authorKevin Cheng <kevin.cheng@arm.com>2021-06-03 15:00:34 -0700
committerKevin Cheng <kevin.cheng@arm.com>2021-06-04 16:38:40 -0700
commitcd79f0e06bf53c2c0fee39ee916bb6d79f177b57 (patch)
tree367078aeef8fd376711abfe6e52de7bfe491e527 /reference_model/src/main.cpp
parent571f7182a10a974f1ce993d83b01070153f142cc (diff)
downloadreference_model-cd79f0e06bf53c2c0fee39ee916bb6d79f177b57.tar.gz
Rewrite model frontend to be json-driven.
Change-Id: Iac786eff96183938d2fd11cde9313c6e8e1270a5
Diffstat (limited to 'reference_model/src/main.cpp')
-rw-r--r--reference_model/src/main.cpp258
1 files changed, 182 insertions, 76 deletions
diff --git a/reference_model/src/main.cpp b/reference_model/src/main.cpp
index 240d913..0d6d8a3 100644
--- a/reference_model/src/main.cpp
+++ b/reference_model/src/main.cpp
@@ -22,16 +22,21 @@
#include <Eigen/CXX11/Tensor>
#include <iostream>
+#include <fstream>
+#include <nlohmann/json.hpp>
+
using namespace TosaReference;
using namespace tosa;
+using json = nlohmann::json;
// Global instantiation of configuration and debug objects
func_config_t g_func_config;
func_debug_t g_func_debug;
-int readInputTensors(SubgraphTraverser& gt);
-int writeFinalTensors(SubgraphTraverser& gt);
-int loadGraph(TosaSerializationHandler& tsh);
+int initTestDesc(json& test_desc);
+int readInputTensors(SubgraphTraverser& gt, json test_desc);
+int writeFinalTensors(SubgraphTraverser& gt, json test_desc);
+int loadGraph(TosaSerializationHandler& tsh, json test_desc);
int main(int argc, const char** argv)
{
@@ -46,7 +51,15 @@ int main(int argc, const char** argv)
return 1;
}
- if (loadGraph(tsh))
+ json test_desc;
+
+ // Initialize test descriptor
+ if (initTestDesc(test_desc))
+ {
+ SIMPLE_FATAL_ERROR("Unable to load test json");
+ }
+
+ if (loadGraph(tsh, test_desc))
{
SIMPLE_FATAL_ERROR("Unable to load graph");
}
@@ -74,7 +87,7 @@ int main(int argc, const char** argv)
goto done;
}
- if (readInputTensors(main_gt))
+ if (readInputTensors(main_gt, test_desc))
{
SIMPLE_FATAL_ERROR("Unable to read input tensors");
}
@@ -113,7 +126,7 @@ int main(int argc, const char** argv)
if (g_func_config.output_tensors)
{
- if (writeFinalTensors(main_gt))
+ if (writeFinalTensors(main_gt, test_desc))
{
WARNING("Errors encountered in saving output tensors");
}
@@ -127,16 +140,17 @@ done:
return 0;
}
-int loadGraph(TosaSerializationHandler& tsh)
+int loadGraph(TosaSerializationHandler& tsh, json test_desc)
{
char graph_fullname[1024];
- snprintf(graph_fullname, sizeof(graph_fullname), "%s/%s", g_func_config.subgraph_dir, g_func_config.subgraph_file);
+ snprintf(graph_fullname, sizeof(graph_fullname), "%s/%s", g_func_config.flatbuffer_dir,
+ test_desc["tosa_file"].get<std::string>().c_str());
if (strlen(graph_fullname) <= 2)
{
func_model_print_help(stderr);
- SIMPLE_FATAL_ERROR("Missing required argument: Check -Csubgraph_file=");
+ SIMPLE_FATAL_ERROR("Missing required argument: Check \"tosa_file\" in .json specified by -Ctosa_desc=");
}
const char JSON_EXT[] = ".json";
@@ -163,131 +177,223 @@ int loadGraph(TosaSerializationHandler& tsh)
if (tsh.LoadFileJson(graph_fullname))
{
- SIMPLE_FATAL_ERROR("\nError loading JSON graph file: %s\nCheck -Csubgraph_file= and -Csubgraph_dir=",
- graph_fullname);
+ SIMPLE_FATAL_ERROR(
+ "\nError loading JSON graph file: %s\nCheck -Ctest_desc=, -Ctosa_file= and -Cflatbuffer_dir=",
+ graph_fullname);
}
}
else
{
if (tsh.LoadFileTosaFlatbuffer(graph_fullname))
{
- SIMPLE_FATAL_ERROR("\nError loading TOSA flatbuffer file: %s\nCheck -Csubgraph_file= and -Csubgraph_dir=",
- graph_fullname);
+ SIMPLE_FATAL_ERROR(
+ "\nError loading TOSA flatbuffer file: %s\nCheck -Ctest_desc=, -Ctosa_file= and -Cflatbuffer_dir=",
+ graph_fullname);
}
}
return 0;
}
-int readInputTensors(SubgraphTraverser& gt)
+int readInputTensors(SubgraphTraverser& gt, json test_desc)
{
int tensorCount = gt.getNumInputTensors();
Tensor* tensor;
char filename[1024];
- // assuming filename doesn't have colons(:)
- std::map<std::string, std::string> input_tensor_map;
- std::string raw_str(g_func_config.input_tensor);
- std::string name, npy;
- bool last_pair = false;
-
- std::string::size_type pair_start = 0, pair_end, colons_pos;
- do
+ try
{
- pair_end = raw_str.find(',', pair_start);
- if (pair_end == std::string::npos)
- last_pair = true;
+ if ((tensorCount != (int)test_desc["ifm_name"].size()) || (tensorCount != (int)test_desc["ifm_file"].size()))
+ {
+ WARNING("Number of input tensors(%d) doesn't match name(%ld)/file(%ld)in test descriptor.", tensorCount,
+ test_desc["ifm_name"].size(), test_desc["ifm_file"].size());
+ return 1;
+ }
- colons_pos = raw_str.find(':', pair_start);
+ for (int i = 0; i < tensorCount; i++)
+ {
+ tensor = gt.getInputTensorByName(test_desc["ifm_name"][i].get<std::string>());
+ if (!tensor)
+ {
+ WARNING("Unable to find input tensor %s", test_desc["ifm_name"][i].get<std::string>().c_str());
+ return 1;
+ }
- name = raw_str.substr(pair_start, colons_pos - pair_start);
- npy = raw_str.substr(colons_pos + 1, pair_end - colons_pos - 1);
+ snprintf(filename, sizeof(filename), "%s/%s", g_func_config.flatbuffer_dir,
+ test_desc["ifm_file"][i].get<std::string>().c_str());
- // Empty strings can make it to here
- if (name.length() == 0 || npy.length() == 0)
- break;
+ DEBUG_MED(GT, "Loading input tensor %s from filename: %s", tensor->getName().c_str(), filename);
- input_tensor_map[name] = npy;
+ if (tensor->allocate())
+ {
+ WARNING("Fail to allocate tensor %s", tensor->getName().c_str());
+ return 1;
+ }
- pair_start = pair_end + 1; // skip colons
- } while (!last_pair);
+ if (tensor->readFromNpyFile(filename))
+ {
+ WARNING("Unable to read input tensor %s from filename: %s", tensor->getName().c_str(), filename);
+ tensor->dumpTensorParams(g_func_debug.func_debug_file);
+ return 1;
+ }
- if ((size_t)tensorCount != input_tensor_map.size())
+ // Push ready consumers to the next node list
+ for (auto gn : tensor->getConsumers())
+ {
+ if (gn->hasAllInputsReady() && !gn->getOnNextNodeList())
+ {
+ gt.addToNextNodeList(gn);
+ }
+ }
+ }
+ }
+ catch (nlohmann::json::type_error& e)
{
- WARNING("graph has %lu input placeholders, but %lu initialized", tensorCount, input_tensor_map.size());
+ WARNING("Fail accessing test descriptor: %s", e.what());
return 1;
}
- for (auto& tensor_pair : input_tensor_map)
+ if (DEBUG_ENABLED(DEBUG_VERB_HIGH, GT))
{
- tensor = gt.getInputTensorByName(tensor_pair.first);
- if (!tensor)
- {
- WARNING("Unable to find input tensor %s", tensor_pair.first.c_str());
- return 1;
- }
+ gt.dumpNextNodeList(g_func_debug.func_debug_file);
+ }
- snprintf(filename, sizeof(filename), "%s/%s", g_func_config.input_dir, tensor_pair.second.c_str());
+ return 0;
+}
- DEBUG_MED(GT, "Loading input tensor %s from filename: %s", tensor->getName().c_str(), filename);
+int writeFinalTensors(SubgraphTraverser& gt, json test_desc)
+{
+ int tensorCount = gt.getNumOutputTensors();
+ const Tensor* tensor;
+ char filename[1024];
- if (tensor->allocate())
+ try
+ {
+ if ((tensorCount != (int)test_desc["ofm_name"].size()) || (tensorCount != (int)test_desc["ofm_file"].size()))
{
- WARNING("Fail to allocate tensor %s", tensor->getName().c_str());
+ WARNING("Number of output tensors(%d) doesn't match name(%ld)/file(%ld) in test descriptor.", tensorCount,
+ test_desc["ofm_name"].size(), test_desc["ofm_file"].size());
return 1;
}
- if (tensor->readFromNpyFile(filename))
+ for (int i = 0; i < tensorCount; i++)
{
- WARNING("Unable to read input tensor %s from filename: %s", tensor->getName().c_str(), filename);
- tensor->dumpTensorParams(g_func_debug.func_debug_file);
- return 1;
- }
+ tensor = gt.getOutputTensorByName(test_desc["ofm_name"][i].get<std::string>());
+ if (!tensor)
+ {
+ WARNING("Unable to find output tensor %s", test_desc["ofm_name"][i].get<std::string>().c_str());
+ return 1;
+ }
- // Push ready consumers to the next node list
- for (auto gn : tensor->getConsumers())
- {
- if (gn->hasAllInputsReady() && !gn->getOnNextNodeList())
+ snprintf(filename, sizeof(filename), "%s/%s", g_func_config.flatbuffer_dir,
+ test_desc["ofm_file"][i].get<std::string>().c_str());
+
+ DEBUG_MED(GT, "Writing output tensor[%d] %s to filename: %s", i, tensor->getName().c_str(), filename);
+
+ if (tensor->writeToNpyFile(filename))
{
- gt.addToNextNodeList(gn);
+ WARNING("Unable to write output tensor[%d] %s to filename: %s", i, tensor->getName().c_str(), filename);
+ return 1;
}
}
}
-
- if (DEBUG_ENABLED(DEBUG_VERB_HIGH, GT))
+ catch (nlohmann::json::type_error& e)
{
- gt.dumpNextNodeList(g_func_debug.func_debug_file);
+ WARNING("Fail accessing test descriptor: %s", e.what());
+ return 1;
}
return 0;
}
-int writeFinalTensors(SubgraphTraverser& gt)
+// Read "foo,bar,..." and return std::vector({foo, bar, ...})
+std::vector<std::string> parseFromString(std::string raw_str)
{
- int tensorCount = gt.getNumOutputTensors();
- const Tensor* tensor;
- char filename[1024];
+ bool last_pair = false;
+ std::string::size_type start = 0, end;
+ std::string name;
- for (int i = 0; i < tensorCount; i++)
+ std::vector<std::string> result;
+ do
{
- tensor = gt.getOutputTensor(i);
- if (!tensor)
- {
- WARNING("Unable to find output tensor[%d]", i);
- return 1;
- }
+ end = raw_str.find(',', start);
+ if (end == std::string::npos)
+ last_pair = true;
+
+ name = raw_str.substr(start, end);
- snprintf(filename, sizeof(filename), "%s/%s%s.npy", g_func_config.output_dir,
- g_func_config.output_tensor_prefix, tensor->getName().c_str());
+ result.push_back(name);
- DEBUG_MED(GT, "Writing output tensor[%d] %s to filename: %s", i, tensor->getName().c_str(), filename);
+ start = end + 1; // skip comma
+ } while (!last_pair);
+
+ return result;
+}
+
+int initTestDesc(json& test_desc)
+{
+ std::ifstream ifs(g_func_config.test_desc);
- if (tensor->writeToNpyFile(filename))
+ if (ifs.good())
+ {
+ try
+ {
+ test_desc = nlohmann::json::parse(ifs);
+ }
+ catch (nlohmann::json::parse_error& e)
{
- WARNING("Unable to write output tensor[%d] %s to filename: %s", i, tensor->getName().c_str(), filename);
+ WARNING("Error parsing test descriptor json: %s", e.what());
return 1;
}
}
+ // Overwrite g_func_config.flatbuffer_dir with dirname(g_func_config.test_desc) if it's not specified.
+ std::string flatbuffer_dir_str(g_func_config.flatbuffer_dir);
+ if (flatbuffer_dir_str.empty())
+ {
+ std::string test_path(g_func_config.test_desc);
+ std::string test_dir = test_path.substr(0, test_path.find_last_of("/\\"));
+ strncpy(g_func_config.flatbuffer_dir, test_dir.c_str(), 1024);
+ }
+
+ // Overwrite test_desc["tosa_file"] if -Ctosa_file= specified.
+ std::string tosa_file_str(g_func_config.tosa_file);
+ if (!tosa_file_str.empty())
+ {
+ test_desc["tosa_file"] = tosa_file_str;
+ }
+
+ // Overwrite test_desc["ifm_name"] if -Cifm_name= specified.
+ std::string ifm_name_str(g_func_config.ifm_name);
+ if (!ifm_name_str.empty())
+ {
+ std::vector<std::string> ifm_name_vec = parseFromString(ifm_name_str);
+ test_desc["ifm_name"] = ifm_name_vec;
+ }
+
+ // Overwrite test_desc["ifm_file"] if -Cifm_file= specified.
+ std::string ifm_file_str(g_func_config.ifm_file);
+ if (!ifm_file_str.empty())
+ {
+ std::vector<std::string> ifm_file_vec = parseFromString(ifm_file_str);
+ test_desc["ifm_file"] = ifm_file_vec;
+ }
+
+ // Overwrite test_desc["ofm_name"] if -Cofm_name= specified.
+ std::string ofm_name_str(g_func_config.ofm_name);
+ if (!ofm_name_str.empty())
+ {
+ std::vector<std::string> ofm_name_vec = parseFromString(ofm_name_str);
+ test_desc["ofm_name"] = ofm_name_vec;
+ }
+
+ // Overwrite test_desc["ofm_file"] if -Cofm_file= specified.
+ std::string ofm_file_str(g_func_config.ofm_file);
+ if (!ofm_file_str.empty())
+ {
+ std::vector<std::string> ofm_file_vec = parseFromString(ofm_file_str);
+ test_desc["ofm_file"] = ofm_file_vec;
+ }
+
return 0;
}