From ba5fad356a926d5e1c6e0fe6b546a310230cc5a8 Mon Sep 17 00:00:00 2001 From: Matthew Sloyan Date: Mon, 26 Sep 2022 13:31:43 +0100 Subject: Add IModelRunner interface to TOSA Reference Model * Added IModelRunner interface using pimpl idiom, which allows a user to initialize, configure and run the model. * Added unit tests for IModelRunner. * Added doctest as third-party submodule. * Added user options to specify paths for dependencies. * Moved general func_config functions to separate utility, which removes cxxopts dependency. Signed-off-by: Matthew Sloyan Change-Id: If42f1f82cd6dadf18911a48dcd5fa579b719aff2 --- reference_model/src/model_runner_impl.cc | 277 +++++++++++++++++++++++++++++++ 1 file changed, 277 insertions(+) create mode 100644 reference_model/src/model_runner_impl.cc (limited to 'reference_model/src/model_runner_impl.cc') diff --git a/reference_model/src/model_runner_impl.cc b/reference_model/src/model_runner_impl.cc new file mode 100644 index 0000000..e0fdc49 --- /dev/null +++ b/reference_model/src/model_runner_impl.cc @@ -0,0 +1,277 @@ + +// Copyright (c) 2022, ARM Limited. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "model_runner_impl.h" + +using namespace TosaReference; + +ModelRunnerImpl::ModelRunnerImpl() +{} + +ModelRunnerImpl::ModelRunnerImpl(const func_config_t& func_config, + const func_debug_t& func_debug) +{ + g_func_config = func_config; + g_func_debug = func_debug; +} + +ModelRunnerImpl::~ModelRunnerImpl() +{ + g_func_debug.fini_debug(); + delete _main_gt; +}; + +void ModelRunnerImpl::setFuncConfig(func_config_t& func_config) +{ + g_func_config = func_config; +} +void ModelRunnerImpl::setFuncDebug(func_debug_t& func_debug) +{ + g_func_debug = func_debug; +} + +GraphStatus ModelRunnerImpl::initialize(TosaSerializationHandler& serialization_handler) +{ + validateTosaVersion(serialization_handler); + + // Make nullptr in case ModelRunnerImpl is being initialized again with a different graph. + _main_gt = nullptr; + _main_gt = new SubgraphTraverser(serialization_handler.GetMainBlock(), &serialization_handler); + + if(_main_gt == nullptr) + { + WARNING("An error occurred when generating main graph traverser."); + return GraphStatus::TOSA_ERROR; + } + + if (_main_gt->initializeGraph()) + { + WARNING("Unable to initialize main graph traverser."); + return _main_gt->getGraphStatus(); + } + + if (_main_gt->linkTensorsAndNodes()) + { + WARNING("Failed to link tensors and nodes"); + return _main_gt->getGraphStatus(); + } + + if (_main_gt->validateGraph()) + { + WARNING("Failed to validate graph."); + return _main_gt->getGraphStatus(); + } + + if (_main_gt->allocateTensor()) + { + WARNING("Failed to allocate tensor."); + return _main_gt->getGraphStatus(); + } + + return _main_gt->getGraphStatus(); +} + +GraphStatus ModelRunnerImpl::run() +{ + if (_main_gt == nullptr) + { + FATAL_ERROR("ModelRunnerImpl hasn't been initialized, please invoke initialize() before run()"); + } + + if (g_func_config.validate_only) + { + goto done; + } + + // Validate the number of inputs matches the + if (static_cast(_main_gt->getNumInputTensors()) != n_input_tensors) + { + FATAL_ERROR("The number of inputs (%d) does not equal the number of inputs in the model (%d). " + "setInput() must be called for each input.", + n_input_tensors, _main_gt->getNumInputTensors()); + } + + if (g_func_config.eval) + { + // evaluateAll() returns 1 if graph evaluation is forced to be terminated earlier. + if (_main_gt->evaluateAll()) + { + ASSERT_MSG(_main_gt->getGraphStatus() != GraphStatus::TOSA_VALID, + "Upon evaluateAll() returning 1, graph can not be VALID."); + } + else + { + ASSERT_MSG(_main_gt->getGraphStatus() == GraphStatus::TOSA_VALID || + _main_gt->getGraphStatus() == GraphStatus::TOSA_UNPREDICTABLE, + "Upon evaluateAll() returning 0, graph can only be VALID/UNPREDICTABLE."); + } + + // Only generate output tensor if graph is valid. + if (_main_gt->getGraphStatus() == GraphStatus::TOSA_VALID) + { + // Make sure output tensor is evaluated and show its value + int num_output_tensors = _main_gt->getNumOutputTensors(); + bool all_output_valid = true; + for (int i = 0; i < num_output_tensors; i++) + { + const Tensor* ct = _main_gt->getOutputTensor(i); + ASSERT_MEM(ct); + if (!ct->getIsValid()) + { + ct->dumpTensorParams(g_func_debug.func_debug_file); + if (DEBUG_ENABLED(DEBUG_VERB_HIGH, GT)) + { + ct->dumpTensor(g_func_debug.func_debug_file); + } + all_output_valid = false; + } + } + if (!all_output_valid) + { + _main_gt->dumpGraph(g_func_debug.func_debug_file); + FATAL_ERROR( + "SubgraphTraverser \"main\" error: Output tensors are not all valid at the end of evaluation."); + } + } + } + +done: + // Print status if not valid and do cleanup. + checkGraphStatus(*_main_gt); + g_func_debug.fini_debug(); + + return _main_gt->getGraphStatus(); +} + +template +int ModelRunnerImpl::setInput(std::string input_name, std::vector vals) +{ + if (_main_gt == nullptr) + { + FATAL_ERROR("ModelRunner hasn't been initialized, please invoke initialize() before setInput()"); + } + + Tensor* tensor; + tensor = _main_gt->getInputTensorByName(input_name); + + if (!tensor) + { + WARNING("Unable to find input tensor %s", input_name.c_str()); + return 1; + } + + if (!tensor->is_allocated()) + { + WARNING("Tensor %s is not allocated before being initialized", tensor->getName().c_str()); + return 1; + } + + if (tensor->readfromVector(vals)) + { + WARNING("Unable to convert input tensor %s to Tensor", tensor->getName().c_str()); + return 1; + } + + // Push ready consumers to the next node list + for (auto gn : tensor->getConsumers()) + { + if (gn->hasAllInputsReady() && !gn->getOnNextNodeList()) + { + _main_gt->addToNextNodeList(gn); + } + } + + n_input_tensors++; + return 0; +} + +template +std::vector ModelRunnerImpl::getOutput(std::string output_name) +{ + if (_main_gt == nullptr) + { + FATAL_ERROR("ModelRunner hasn't been initialized, please invoke initialize() and run() before getOutput()"); + } + + Tensor* tensor; + tensor = _main_gt->getOutputTensorByName(output_name); + + if (!tensor) + { + WARNING("Unable to find output tensor %s", output_name.c_str()); + return std::vector(); + } + + std::vector outputs(tensor->getElementCount(), 0); + + if (tensor->writeToVector(outputs)) + { + WARNING("Unable to convert output tensor %s to vector", tensor->getName().c_str()); + return std::vector(); + } + + return outputs; +} + +void ModelRunnerImpl::validateTosaVersion(TosaSerializationHandler& serialization_handler) +{ + TosaVersion model_version(TOSA_REFERENCE_MODEL_VERSION_MAJOR, + TOSA_REFERENCE_MODEL_VERSION_MINOR, + TOSA_REFERENCE_MODEL_VERSION_PATCH, + TOSA_REFERENCE_MODEL_VERSION_DRAFT); + + TosaVersion::compat_t is_compat = model_version.is_compatible(serialization_handler.GetVersion()); + switch (is_compat) + { + case TosaVersion::compat_t::COMPLETELY_COMPATIBLE: + break; + case TosaVersion::compat_t::PARTIALLY_COMPATIBLE: + WARNING("Reference model version %s is partially compatible with serializer version %s.", + model_version.to_string().c_str(), serialization_handler.GetVersion().to_string().c_str()); + break; + case TosaVersion::compat_t::NOT_COMPATIBLE: + FATAL_ERROR("Reference model version %s is not compatible with serializer version %s.", + model_version.to_string().c_str(), serialization_handler.GetVersion().to_string().c_str()); + } +} + +void ModelRunnerImpl::checkGraphStatus(SubgraphTraverser& main_gt) +{ + switch (main_gt.getGraphStatus()) + { + case GraphStatus::TOSA_VALID: + // Result is valid. + break; + case GraphStatus::TOSA_UNPREDICTABLE: + WARNING("Graph result: UNPREDICTABLE."); + break; + case GraphStatus::TOSA_ERROR: + WARNING("Graph result: ERROR."); + break; + default: + WARNING("Unknown graph status code=%d.", (int)main_gt.getGraphStatus()); + } +} + +// Template explicit specialization +template int ModelRunnerImpl::setInput(std::string input_name, std::vector vals); +template int ModelRunnerImpl::setInput(std::string input_name, std::vector vals); +template int ModelRunnerImpl::setInput(std::string input_name, std::vector vals); +template int ModelRunnerImpl::setInput(std::string input_name, std::vector vals); + +template std::vector ModelRunnerImpl::getOutput(std::string output_name); +template std::vector ModelRunnerImpl::getOutput(std::string output_name); +template std::vector ModelRunnerImpl::getOutput(std::string output_name); +template std::vector ModelRunnerImpl::getOutput(std::string output_name); \ No newline at end of file -- cgit v1.2.1