From e5e2676409a936431f87d31fb74d825257b20804 Mon Sep 17 00:00:00 2001 From: Eric Kunze Date: Tue, 13 Oct 2020 16:11:07 -0700 Subject: Initial checkin of TOSA reference_model and tests Change-Id: I2f8e7fa63e2ae40203e57d2cc8814bde3b312cb6 Signed-off-by: Eric Kunze --- reference_model/src/main.cpp | 295 +++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 295 insertions(+) create mode 100644 reference_model/src/main.cpp (limited to 'reference_model/src/main.cpp') diff --git a/reference_model/src/main.cpp b/reference_model/src/main.cpp new file mode 100644 index 0000000..ec2fdc9 --- /dev/null +++ b/reference_model/src/main.cpp @@ -0,0 +1,295 @@ + +// Copyright (c) 2020, ARM Limited. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include + +#include "flatbuffers/idl.h" +#include "flatbuffers/util.h" +#include "model_common.h" +#include "ops/op_factory.h" +#include "subgraph_traverser.h" +#include "tosa_serialization_handler.h" +#include +#include + +using namespace TosaReference; +using namespace tosa; + +// Global instantiation of configuration and debug objects +func_config_t g_func_config; +func_debug_t g_func_debug; + +int readInputTensors(SubgraphTraverser& gt); +int writeFinalTensors(SubgraphTraverser& gt); +int loadGraph(TosaSerializationHandler& tsh); + +int main(int argc, const char** argv) +{ + // Initialize configuration and debug subsystems + func_model_init_config(); + func_model_set_default_config(&g_func_config); + func_init_debug(&g_func_debug, 0); + TosaSerializationHandler tsh; + + if (func_model_parse_cmd_line(&g_func_config, &g_func_debug, argc, argv)) + { + return 1; + } + + if (loadGraph(tsh)) + { + SIMPLE_FATAL_ERROR("Unable to load graph"); + } + + // load json first since it's easier debugging + SubgraphTraverser main_gt(tsh.GetMainBlock(), &tsh); + + if (main_gt.initializeGraph()) + { + SIMPLE_FATAL_ERROR("Unable to initialize graph traverser: \"main\""); + } + + if (main_gt.linkTensorsAndNodes()) + { + SIMPLE_FATAL_ERROR("Failed to link tensors and nodes"); + } + + if (main_gt.validateGraph()) + { + SIMPLE_FATAL_ERROR("Failed to validate graph"); + } + + if (g_func_config.validate_only) + { + goto done; + } + + if (readInputTensors(main_gt)) + { + SIMPLE_FATAL_ERROR("Unable to read input tensors"); + } + + if (g_func_config.eval) + { + + if (main_gt.evaluateAll()) + { + SIMPLE_FATAL_ERROR("Error evaluating network. Giving up."); + } + + // make sure output tensor is evaluated and show its value + int num_output_tensors = main_gt.getNumOutputTensors(); + bool all_output_valid = true; + for (int i = 0; i < num_output_tensors; i++) + { + const Tensor* ct = main_gt.getOutputTensor(i); + ASSERT_MEM(ct); + if (!ct->getIsValid()) + { + ct->dumpTensorParams(g_func_debug.func_debug_file); + if (DEBUG_ENABLED(DEBUG_VERB_HIGH, GT)) + { + ct->dumpTensor(g_func_debug.func_debug_file); + } + all_output_valid = false; + } + } + if (!all_output_valid) + { + main_gt.dumpGraph(g_func_debug.func_debug_file); + SIMPLE_FATAL_ERROR( + "SubgraphTraverser \"main\" error: Output tensors are not all valid at the end of evaluation."); + } + + if (g_func_config.output_tensors) + { + if (writeFinalTensors(main_gt)) + { + WARNING("Errors encountered in saving output tensors"); + } + } + } + +done: + func_fini_debug(&g_func_debug); + func_model_config_cleanup(); + + return 0; +} + +int loadGraph(TosaSerializationHandler& tsh) +{ + char graph_fullname[1024]; + + snprintf(graph_fullname, sizeof(graph_fullname), "%s/%s", g_func_config.subgraph_dir, g_func_config.subgraph_file); + + if (strlen(graph_fullname) <= 2) + { + func_model_print_help(stderr); + SIMPLE_FATAL_ERROR("Missing required argument: Check -Csubgraph_file="); + } + + const char JSON_EXT[] = ".json"; + int is_json = 0; + { + // look for JSON file extension + size_t suffix_len = strlen(JSON_EXT); + size_t str_len = strlen(graph_fullname); + + if (str_len > suffix_len && strncasecmp(graph_fullname + (str_len - suffix_len), JSON_EXT, suffix_len) == 0) + { + is_json = 1; + } + } + + if (is_json) + { + if (tsh.LoadFileSchema(g_func_config.operator_fbs)) + { + SIMPLE_FATAL_ERROR( + "\nJSON file detected. Unable to load TOSA flatbuffer schema from: %s\nCheck -Coperator_fbs=", + g_func_config.operator_fbs); + } + + if (tsh.LoadFileJson(graph_fullname)) + { + SIMPLE_FATAL_ERROR("\nError loading JSON graph file: %s\nCheck -Csubgraph_file= and -Csubgraph_dir=", + graph_fullname); + } + } + else + { + if (tsh.LoadFileTosaFlatbuffer(graph_fullname)) + { + SIMPLE_FATAL_ERROR("\nError loading TOSA flatbuffer file: %s\nCheck -Csubgraph_file= and -Csubgraph_dir=", + graph_fullname); + } + } + + return 0; +} + +int readInputTensors(SubgraphTraverser& gt) +{ + int tensorCount = gt.getNumInputTensors(); + Tensor* tensor; + char filename[1024]; + + // assuming filename doesn't have colons(:) + std::map input_tensor_map; + std::string raw_str(g_func_config.input_tensor); + std::string name, npy; + bool last_pair = false; + + std::string::size_type pair_start = 0, pair_end, colons_pos; + do + { + pair_end = raw_str.find(',', pair_start); + if (pair_end == std::string::npos) + last_pair = true; + + colons_pos = raw_str.find(':', pair_start); + + name = raw_str.substr(pair_start, colons_pos - pair_start); + npy = raw_str.substr(colons_pos + 1, pair_end - colons_pos - 1); + + // Empty strings can make it to here + if (name.length() == 0 || npy.length() == 0) + break; + + input_tensor_map[name] = npy; + + pair_start = pair_end + 1; // skip colons + } while (!last_pair); + + if ((size_t)tensorCount != input_tensor_map.size()) + { + WARNING("graph has %lu input placeholders, but %lu initialized", tensorCount, input_tensor_map.size()); + return 1; + } + + for (auto& tensor_pair : input_tensor_map) + { + tensor = gt.getInputTensorByName(tensor_pair.first); + if (!tensor) + { + WARNING("Unable to find input tensor %s", tensor_pair.first.c_str()); + return 1; + } + + snprintf(filename, sizeof(filename), "%s/%s", g_func_config.input_dir, tensor_pair.second.c_str()); + + DEBUG_MED(GT, "Loading input tensor %s from filename: %s", tensor->getName().c_str(), filename); + + if (tensor->allocate()) + { + WARNING("Fail to allocate tensor %s", tensor->getName().c_str()); + return 1; + } + + if (tensor->readFromNpyFile(filename)) + { + WARNING("Unable to read input tensor %s from filename: %s", tensor->getName().c_str(), filename); + tensor->dumpTensorParams(g_func_debug.func_debug_file); + return 1; + } + + // Push ready consumers to the next node list + for (auto gn : tensor->getConsumers()) + { + if (gn->hasAllInputsReady() && !gn->getOnNextNodeList()) + { + gt.addToNextNodeList(gn); + } + } + } + + if (DEBUG_ENABLED(DEBUG_VERB_HIGH, GT)) + { + gt.dumpNextNodeList(g_func_debug.func_debug_file); + } + + return 0; +} + +int writeFinalTensors(SubgraphTraverser& gt) +{ + int tensorCount = gt.getNumOutputTensors(); + const Tensor* tensor; + char filename[1024]; + + for (int i = 0; i < tensorCount; i++) + { + tensor = gt.getOutputTensor(i); + if (!tensor) + { + WARNING("Unable to find output tensor[%d]", i); + return 1; + } + + snprintf(filename, sizeof(filename), "%s/%s%s.npy", g_func_config.output_dir, + g_func_config.output_tensor_prefix, tensor->getName().c_str()); + + DEBUG_MED(GT, "Writing output tensor[%d] %s to filename: %s", i, tensor->getName().c_str(), filename); + + if (tensor->writeToNpyFile(filename)) + { + WARNING("Unable to write output tensor[%d] %s to filename: %s", i, tensor->getName().c_str(), filename); + return 1; + } + } + + return 0; +} -- cgit v1.2.1