aboutsummaryrefslogtreecommitdiff
path: root/reference_model/src/model_runner_impl.cc
diff options
context:
space:
mode:
authorMatthew Sloyan <matthew.sloyan@arm.com>2022-09-26 13:31:43 +0100
committerMatthew Sloyan <matthew.sloyan@arm.com>2022-10-07 16:02:07 +0100
commitba5fad356a926d5e1c6e0fe6b546a310230cc5a8 (patch)
tree10e3e127c091da90d591253fd55e8566c0b61e7a /reference_model/src/model_runner_impl.cc
parenta0848c6edbf37034e280a670bdd2f990fdf796da (diff)
downloadreference_model-ba5fad356a926d5e1c6e0fe6b546a310230cc5a8.tar.gz
Add IModelRunner interface to TOSA Reference Model
* Added IModelRunner interface using pimpl idiom, which allows a user to initialize, configure and run the model. * Added unit tests for IModelRunner. * Added doctest as third-party submodule. * Added user options to specify paths for dependencies. * Moved general func_config functions to separate utility, which removes cxxopts dependency. Signed-off-by: Matthew Sloyan <matthew.sloyan@arm.com> Change-Id: If42f1f82cd6dadf18911a48dcd5fa579b719aff2
Diffstat (limited to 'reference_model/src/model_runner_impl.cc')
-rw-r--r--reference_model/src/model_runner_impl.cc277
1 files changed, 277 insertions, 0 deletions
diff --git a/reference_model/src/model_runner_impl.cc b/reference_model/src/model_runner_impl.cc
new file mode 100644
index 0000000..e0fdc49
--- /dev/null
+++ b/reference_model/src/model_runner_impl.cc
@@ -0,0 +1,277 @@
+
+// Copyright (c) 2022, ARM Limited.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "model_runner_impl.h"
+
+using namespace TosaReference;
+
+ModelRunnerImpl::ModelRunnerImpl()
+{}
+
+ModelRunnerImpl::ModelRunnerImpl(const func_config_t& func_config,
+ const func_debug_t& func_debug)
+{
+ g_func_config = func_config;
+ g_func_debug = func_debug;
+}
+
+ModelRunnerImpl::~ModelRunnerImpl()
+{
+ g_func_debug.fini_debug();
+ delete _main_gt;
+};
+
+void ModelRunnerImpl::setFuncConfig(func_config_t& func_config)
+{
+ g_func_config = func_config;
+}
+void ModelRunnerImpl::setFuncDebug(func_debug_t& func_debug)
+{
+ g_func_debug = func_debug;
+}
+
+GraphStatus ModelRunnerImpl::initialize(TosaSerializationHandler& serialization_handler)
+{
+ validateTosaVersion(serialization_handler);
+
+ // Make nullptr in case ModelRunnerImpl is being initialized again with a different graph.
+ _main_gt = nullptr;
+ _main_gt = new SubgraphTraverser(serialization_handler.GetMainBlock(), &serialization_handler);
+
+ if(_main_gt == nullptr)
+ {
+ WARNING("An error occurred when generating main graph traverser.");
+ return GraphStatus::TOSA_ERROR;
+ }
+
+ if (_main_gt->initializeGraph())
+ {
+ WARNING("Unable to initialize main graph traverser.");
+ return _main_gt->getGraphStatus();
+ }
+
+ if (_main_gt->linkTensorsAndNodes())
+ {
+ WARNING("Failed to link tensors and nodes");
+ return _main_gt->getGraphStatus();
+ }
+
+ if (_main_gt->validateGraph())
+ {
+ WARNING("Failed to validate graph.");
+ return _main_gt->getGraphStatus();
+ }
+
+ if (_main_gt->allocateTensor())
+ {
+ WARNING("Failed to allocate tensor.");
+ return _main_gt->getGraphStatus();
+ }
+
+ return _main_gt->getGraphStatus();
+}
+
+GraphStatus ModelRunnerImpl::run()
+{
+ if (_main_gt == nullptr)
+ {
+ FATAL_ERROR("ModelRunnerImpl hasn't been initialized, please invoke initialize() before run()");
+ }
+
+ if (g_func_config.validate_only)
+ {
+ goto done;
+ }
+
+ // Validate the number of inputs matches the
+ if (static_cast<uint32_t>(_main_gt->getNumInputTensors()) != n_input_tensors)
+ {
+ FATAL_ERROR("The number of inputs (%d) does not equal the number of inputs in the model (%d). "
+ "setInput() must be called for each input.",
+ n_input_tensors, _main_gt->getNumInputTensors());
+ }
+
+ if (g_func_config.eval)
+ {
+ // evaluateAll() returns 1 if graph evaluation is forced to be terminated earlier.
+ if (_main_gt->evaluateAll())
+ {
+ ASSERT_MSG(_main_gt->getGraphStatus() != GraphStatus::TOSA_VALID,
+ "Upon evaluateAll() returning 1, graph can not be VALID.");
+ }
+ else
+ {
+ ASSERT_MSG(_main_gt->getGraphStatus() == GraphStatus::TOSA_VALID ||
+ _main_gt->getGraphStatus() == GraphStatus::TOSA_UNPREDICTABLE,
+ "Upon evaluateAll() returning 0, graph can only be VALID/UNPREDICTABLE.");
+ }
+
+ // Only generate output tensor if graph is valid.
+ if (_main_gt->getGraphStatus() == GraphStatus::TOSA_VALID)
+ {
+ // Make sure output tensor is evaluated and show its value
+ int num_output_tensors = _main_gt->getNumOutputTensors();
+ bool all_output_valid = true;
+ for (int i = 0; i < num_output_tensors; i++)
+ {
+ const Tensor* ct = _main_gt->getOutputTensor(i);
+ ASSERT_MEM(ct);
+ if (!ct->getIsValid())
+ {
+ ct->dumpTensorParams(g_func_debug.func_debug_file);
+ if (DEBUG_ENABLED(DEBUG_VERB_HIGH, GT))
+ {
+ ct->dumpTensor(g_func_debug.func_debug_file);
+ }
+ all_output_valid = false;
+ }
+ }
+ if (!all_output_valid)
+ {
+ _main_gt->dumpGraph(g_func_debug.func_debug_file);
+ FATAL_ERROR(
+ "SubgraphTraverser \"main\" error: Output tensors are not all valid at the end of evaluation.");
+ }
+ }
+ }
+
+done:
+ // Print status if not valid and do cleanup.
+ checkGraphStatus(*_main_gt);
+ g_func_debug.fini_debug();
+
+ return _main_gt->getGraphStatus();
+}
+
+template <typename T>
+int ModelRunnerImpl::setInput(std::string input_name, std::vector<T> vals)
+{
+ if (_main_gt == nullptr)
+ {
+ FATAL_ERROR("ModelRunner hasn't been initialized, please invoke initialize() before setInput()");
+ }
+
+ Tensor* tensor;
+ tensor = _main_gt->getInputTensorByName(input_name);
+
+ if (!tensor)
+ {
+ WARNING("Unable to find input tensor %s", input_name.c_str());
+ return 1;
+ }
+
+ if (!tensor->is_allocated())
+ {
+ WARNING("Tensor %s is not allocated before being initialized", tensor->getName().c_str());
+ return 1;
+ }
+
+ if (tensor->readfromVector(vals))
+ {
+ WARNING("Unable to convert input tensor %s to Tensor", tensor->getName().c_str());
+ return 1;
+ }
+
+ // Push ready consumers to the next node list
+ for (auto gn : tensor->getConsumers())
+ {
+ if (gn->hasAllInputsReady() && !gn->getOnNextNodeList())
+ {
+ _main_gt->addToNextNodeList(gn);
+ }
+ }
+
+ n_input_tensors++;
+ return 0;
+}
+
+template <typename T>
+std::vector<T> ModelRunnerImpl::getOutput(std::string output_name)
+{
+ if (_main_gt == nullptr)
+ {
+ FATAL_ERROR("ModelRunner hasn't been initialized, please invoke initialize() and run() before getOutput()");
+ }
+
+ Tensor* tensor;
+ tensor = _main_gt->getOutputTensorByName(output_name);
+
+ if (!tensor)
+ {
+ WARNING("Unable to find output tensor %s", output_name.c_str());
+ return std::vector<T>();
+ }
+
+ std::vector<T> outputs(tensor->getElementCount(), 0);
+
+ if (tensor->writeToVector(outputs))
+ {
+ WARNING("Unable to convert output tensor %s to vector", tensor->getName().c_str());
+ return std::vector<T>();
+ }
+
+ return outputs;
+}
+
+void ModelRunnerImpl::validateTosaVersion(TosaSerializationHandler& serialization_handler)
+{
+ TosaVersion model_version(TOSA_REFERENCE_MODEL_VERSION_MAJOR,
+ TOSA_REFERENCE_MODEL_VERSION_MINOR,
+ TOSA_REFERENCE_MODEL_VERSION_PATCH,
+ TOSA_REFERENCE_MODEL_VERSION_DRAFT);
+
+ TosaVersion::compat_t is_compat = model_version.is_compatible(serialization_handler.GetVersion());
+ switch (is_compat)
+ {
+ case TosaVersion::compat_t::COMPLETELY_COMPATIBLE:
+ break;
+ case TosaVersion::compat_t::PARTIALLY_COMPATIBLE:
+ WARNING("Reference model version %s is partially compatible with serializer version %s.",
+ model_version.to_string().c_str(), serialization_handler.GetVersion().to_string().c_str());
+ break;
+ case TosaVersion::compat_t::NOT_COMPATIBLE:
+ FATAL_ERROR("Reference model version %s is not compatible with serializer version %s.",
+ model_version.to_string().c_str(), serialization_handler.GetVersion().to_string().c_str());
+ }
+}
+
+void ModelRunnerImpl::checkGraphStatus(SubgraphTraverser& main_gt)
+{
+ switch (main_gt.getGraphStatus())
+ {
+ case GraphStatus::TOSA_VALID:
+ // Result is valid.
+ break;
+ case GraphStatus::TOSA_UNPREDICTABLE:
+ WARNING("Graph result: UNPREDICTABLE.");
+ break;
+ case GraphStatus::TOSA_ERROR:
+ WARNING("Graph result: ERROR.");
+ break;
+ default:
+ WARNING("Unknown graph status code=%d.", (int)main_gt.getGraphStatus());
+ }
+}
+
+// Template explicit specialization
+template int ModelRunnerImpl::setInput<float>(std::string input_name, std::vector<float> vals);
+template int ModelRunnerImpl::setInput<int32_t>(std::string input_name, std::vector<int32_t> vals);
+template int ModelRunnerImpl::setInput<int64_t>(std::string input_name, std::vector<int64_t> vals);
+template int ModelRunnerImpl::setInput<unsigned char>(std::string input_name, std::vector<unsigned char> vals);
+
+template std::vector<float> ModelRunnerImpl::getOutput<float>(std::string output_name);
+template std::vector<int32_t> ModelRunnerImpl::getOutput<int32_t>(std::string output_name);
+template std::vector<int64_t> ModelRunnerImpl::getOutput<int64_t>(std::string output_name);
+template std::vector<unsigned char> ModelRunnerImpl::getOutput<unsigned char>(std::string output_name); \ No newline at end of file