From 64285a1f25e2c7b85ed1f00b7947403e92baea00 Mon Sep 17 00:00:00 2001 From: Grant Watson Date: Wed, 16 Nov 2022 15:32:39 +0000 Subject: Extend reference model API with eager operator execution entrypoints - Adds a script to generate operators.h and operators.cc - Adds jinja2 templates for generating operators.h and operators.cc - Adds unit tests for a subset of the operators generated - Includes the TOSA specification as a submodule - Adds supporting C++ and header files Signed-off-by: Grant Watson Change-Id: I5b60db4c56113110d8e75fe1152525d258233f9c --- reference_model/src/model_runner_impl.cc | 178 +++++++++++++++++++++++-------- 1 file changed, 136 insertions(+), 42 deletions(-) (limited to 'reference_model/src/model_runner_impl.cc') diff --git a/reference_model/src/model_runner_impl.cc b/reference_model/src/model_runner_impl.cc index 8427150..1109dd6 100644 --- a/reference_model/src/model_runner_impl.cc +++ b/reference_model/src/model_runner_impl.cc @@ -45,42 +45,12 @@ void ModelRunnerImpl::setFuncDebug(func_debug_t& func_debug) GraphStatus ModelRunnerImpl::initialize(TosaSerializationHandler& serialization_handler) { validateTosaVersion(serialization_handler); + return initialize(serialization_handler.GetMainBlock(), &serialization_handler); +} - // Make nullptr in case ModelRunnerImpl is being initialized again with a different graph. - _main_gt = nullptr; - _main_gt = new SubgraphTraverser(serialization_handler.GetMainBlock(), &serialization_handler); - - if(_main_gt == nullptr) - { - WARNING("An error occurred when generating main graph traverser."); - return GraphStatus::TOSA_ERROR; - } - - if (_main_gt->initializeGraph()) - { - WARNING("Unable to initialize main graph traverser."); - return _main_gt->getGraphStatus(); - } - - if (_main_gt->linkTensorsAndNodes()) - { - WARNING("Failed to link tensors and nodes"); - return _main_gt->getGraphStatus(); - } - - if (_main_gt->validateGraph()) - { - WARNING("Failed to validate graph."); - return _main_gt->getGraphStatus(); - } - - if (_main_gt->allocateTensor()) - { - WARNING("Failed to allocate tensor."); - return _main_gt->getGraphStatus(); - } - - return _main_gt->getGraphStatus(); +GraphStatus ModelRunnerImpl::initialize(TosaSerializationBasicBlock& bb) +{ + return initialize(&bb, nullptr); } GraphStatus ModelRunnerImpl::run() @@ -156,7 +126,7 @@ done: } template -int ModelRunnerImpl::setInput(std::string input_name, std::vector& vals) +int ModelRunnerImpl::setInput(std::string input_name, ArrayProxy vals) { if (_main_gt == nullptr) { @@ -197,6 +167,44 @@ int ModelRunnerImpl::setInput(std::string input_name, std::vector& vals) return 0; } +int ModelRunnerImpl::setInput(std::string input_name, uint8_t* raw_ptr, size_t size) +{ + if (_main_gt == nullptr) + { + FATAL_ERROR("ModelRunner hasn't been initialized, please invoke initialize() before setInput()"); + } + + Tensor* tensor; + tensor = _main_gt->getInputTensorByName(input_name); + + if (!tensor) + { + WARNING("Unable to find input tensor %s", input_name.c_str()); + return 1; + } + + int status = 0; + switch (tensor->getDtype()) + { + case DType::DType_FP16: { + auto typed_ptr = reinterpret_cast(raw_ptr); + const int elements = size / sizeof(half_float::half); + status = setInput(input_name, ArrayProxy(elements, typed_ptr)); + break; + } + case DType::DType_FP32: { + auto typed_ptr = reinterpret_cast(raw_ptr); + const int elements = size / sizeof(float); + status = setInput(input_name, ArrayProxy(elements, typed_ptr)); + break; + } + default: + status = 1; + } + + return status; +} + template std::vector ModelRunnerImpl::getOutput(std::string output_name) { @@ -216,7 +224,7 @@ std::vector ModelRunnerImpl::getOutput(std::string output_name) std::vector outputs(tensor->getElementCount()); - if (tensor->writeToVector(outputs)) + if (tensor->writeToVector(ArrayProxy(outputs))) { WARNING("Unable to convert output tensor %s to vector", tensor->getName().c_str()); return std::vector(); @@ -225,6 +233,92 @@ std::vector ModelRunnerImpl::getOutput(std::string output_name) return outputs; } +int ModelRunnerImpl::getOutput(std::string output_name, uint8_t* raw_ptr, size_t size) +{ + if (_main_gt == nullptr) + { + FATAL_ERROR("ModelRunner hasn't been initialized, please invoke initialize() and run() before getOutput()"); + } + + Tensor* tensor; + tensor = _main_gt->getOutputTensorByName(output_name); + + if (!tensor) + { + WARNING("Unable to find output tensor %s", output_name.c_str()); + return 1; + } + + int status = 0; + switch (tensor->getDtype()) + { + case DType::DType_FP16: { + auto typed_ptr = reinterpret_cast(raw_ptr); + const int elements = size / sizeof(half_float::half); + status = tensor->writeToVector(ArrayProxy(elements, typed_ptr)); + break; + } + case DType::DType_FP32: { + auto typed_ptr = reinterpret_cast(raw_ptr); + const int elements = size / sizeof(float); + status = tensor->writeToVector(ArrayProxy(elements, typed_ptr)); + break; + } + default: + status = 1; + } + if (status) + { + WARNING("Unable to convert output tensor %s to vector", tensor->getName().c_str()); + return 1; + } + + return 0; +} + +GraphStatus ModelRunnerImpl::initialize(TosaSerializationBasicBlock* bb, + TosaSerializationHandler* serialization_handler) +{ + if (serialization_handler != nullptr) + validateTosaVersion(*serialization_handler); + + // Make nullptr in case ModelRunnerImpl is being initialized again with a different graph. + _main_gt = nullptr; + _main_gt = new SubgraphTraverser(bb, serialization_handler); + + if (_main_gt == nullptr) + { + WARNING("An error occurred when generating main graph traverser."); + return GraphStatus::TOSA_ERROR; + } + + if (_main_gt->initializeGraph()) + { + WARNING("Unable to initialize main graph traverser."); + return _main_gt->getGraphStatus(); + } + + if (_main_gt->linkTensorsAndNodes()) + { + WARNING("Failed to link tensors and nodes"); + return _main_gt->getGraphStatus(); + } + + if (_main_gt->validateGraph()) + { + WARNING("Failed to validate graph."); + return _main_gt->getGraphStatus(); + } + + if (_main_gt->allocateTensor()) + { + WARNING("Failed to allocate tensor."); + return _main_gt->getGraphStatus(); + } + + return _main_gt->getGraphStatus(); +} + void ModelRunnerImpl::validateTosaVersion(TosaSerializationHandler& serialization_handler) { TosaVersion model_version(TOSA_REFERENCE_MODEL_VERSION_MAJOR, @@ -266,11 +360,11 @@ void ModelRunnerImpl::checkGraphStatus(SubgraphTraverser& main_gt) } // Template explicit specialization -template int ModelRunnerImpl::setInput(std::string input_name, std::vector& vals); -template int ModelRunnerImpl::setInput(std::string input_name, std::vector& vals); -template int ModelRunnerImpl::setInput(std::string input_name, std::vector& vals); -template int ModelRunnerImpl::setInput(std::string input_name, std::vector& vals); -template int ModelRunnerImpl::setInput(std::string input_name, std::vector& vals); +template int ModelRunnerImpl::setInput(std::string input_name, ArrayProxy vals); +template int ModelRunnerImpl::setInput(std::string input_name, ArrayProxy vals); +template int ModelRunnerImpl::setInput(std::string input_name, ArrayProxy vals); +template int ModelRunnerImpl::setInput(std::string input_name, ArrayProxy vals); +template int ModelRunnerImpl::setInput(std::string input_name, ArrayProxy vals); template std::vector ModelRunnerImpl::getOutput(std::string output_name); template std::vector ModelRunnerImpl::getOutput(std::string output_name); -- cgit v1.2.1