diff options
Diffstat (limited to 'reference_model/src/model_runner_impl.cc')
-rw-r--r-- | reference_model/src/model_runner_impl.cc | 178 |
1 files changed, 136 insertions, 42 deletions
diff --git a/reference_model/src/model_runner_impl.cc b/reference_model/src/model_runner_impl.cc index 8427150..1109dd6 100644 --- a/reference_model/src/model_runner_impl.cc +++ b/reference_model/src/model_runner_impl.cc @@ -45,42 +45,12 @@ void ModelRunnerImpl::setFuncDebug(func_debug_t& func_debug) GraphStatus ModelRunnerImpl::initialize(TosaSerializationHandler& serialization_handler) { validateTosaVersion(serialization_handler); + return initialize(serialization_handler.GetMainBlock(), &serialization_handler); +} - // Make nullptr in case ModelRunnerImpl is being initialized again with a different graph. - _main_gt = nullptr; - _main_gt = new SubgraphTraverser(serialization_handler.GetMainBlock(), &serialization_handler); - - if(_main_gt == nullptr) - { - WARNING("An error occurred when generating main graph traverser."); - return GraphStatus::TOSA_ERROR; - } - - if (_main_gt->initializeGraph()) - { - WARNING("Unable to initialize main graph traverser."); - return _main_gt->getGraphStatus(); - } - - if (_main_gt->linkTensorsAndNodes()) - { - WARNING("Failed to link tensors and nodes"); - return _main_gt->getGraphStatus(); - } - - if (_main_gt->validateGraph()) - { - WARNING("Failed to validate graph."); - return _main_gt->getGraphStatus(); - } - - if (_main_gt->allocateTensor()) - { - WARNING("Failed to allocate tensor."); - return _main_gt->getGraphStatus(); - } - - return _main_gt->getGraphStatus(); +GraphStatus ModelRunnerImpl::initialize(TosaSerializationBasicBlock& bb) +{ + return initialize(&bb, nullptr); } GraphStatus ModelRunnerImpl::run() @@ -156,7 +126,7 @@ done: } template <typename T> -int ModelRunnerImpl::setInput(std::string input_name, std::vector<T>& vals) +int ModelRunnerImpl::setInput(std::string input_name, ArrayProxy<T> vals) { if (_main_gt == nullptr) { @@ -197,6 +167,44 @@ int ModelRunnerImpl::setInput(std::string input_name, std::vector<T>& vals) return 0; } +int ModelRunnerImpl::setInput(std::string input_name, uint8_t* raw_ptr, size_t size) +{ + if (_main_gt == nullptr) + { + FATAL_ERROR("ModelRunner hasn't been initialized, please invoke initialize() before setInput()"); + } + + Tensor* tensor; + tensor = _main_gt->getInputTensorByName(input_name); + + if (!tensor) + { + WARNING("Unable to find input tensor %s", input_name.c_str()); + return 1; + } + + int status = 0; + switch (tensor->getDtype()) + { + case DType::DType_FP16: { + auto typed_ptr = reinterpret_cast<half_float::half*>(raw_ptr); + const int elements = size / sizeof(half_float::half); + status = setInput(input_name, ArrayProxy(elements, typed_ptr)); + break; + } + case DType::DType_FP32: { + auto typed_ptr = reinterpret_cast<float*>(raw_ptr); + const int elements = size / sizeof(float); + status = setInput(input_name, ArrayProxy(elements, typed_ptr)); + break; + } + default: + status = 1; + } + + return status; +} + template <typename T> std::vector<T> ModelRunnerImpl::getOutput(std::string output_name) { @@ -216,7 +224,7 @@ std::vector<T> ModelRunnerImpl::getOutput(std::string output_name) std::vector<T> outputs(tensor->getElementCount()); - if (tensor->writeToVector(outputs)) + if (tensor->writeToVector(ArrayProxy<T>(outputs))) { WARNING("Unable to convert output tensor %s to vector", tensor->getName().c_str()); return std::vector<T>(); @@ -225,6 +233,92 @@ std::vector<T> ModelRunnerImpl::getOutput(std::string output_name) return outputs; } +int ModelRunnerImpl::getOutput(std::string output_name, uint8_t* raw_ptr, size_t size) +{ + if (_main_gt == nullptr) + { + FATAL_ERROR("ModelRunner hasn't been initialized, please invoke initialize() and run() before getOutput()"); + } + + Tensor* tensor; + tensor = _main_gt->getOutputTensorByName(output_name); + + if (!tensor) + { + WARNING("Unable to find output tensor %s", output_name.c_str()); + return 1; + } + + int status = 0; + switch (tensor->getDtype()) + { + case DType::DType_FP16: { + auto typed_ptr = reinterpret_cast<half_float::half*>(raw_ptr); + const int elements = size / sizeof(half_float::half); + status = tensor->writeToVector(ArrayProxy(elements, typed_ptr)); + break; + } + case DType::DType_FP32: { + auto typed_ptr = reinterpret_cast<float*>(raw_ptr); + const int elements = size / sizeof(float); + status = tensor->writeToVector(ArrayProxy(elements, typed_ptr)); + break; + } + default: + status = 1; + } + if (status) + { + WARNING("Unable to convert output tensor %s to vector", tensor->getName().c_str()); + return 1; + } + + return 0; +} + +GraphStatus ModelRunnerImpl::initialize(TosaSerializationBasicBlock* bb, + TosaSerializationHandler* serialization_handler) +{ + if (serialization_handler != nullptr) + validateTosaVersion(*serialization_handler); + + // Make nullptr in case ModelRunnerImpl is being initialized again with a different graph. + _main_gt = nullptr; + _main_gt = new SubgraphTraverser(bb, serialization_handler); + + if (_main_gt == nullptr) + { + WARNING("An error occurred when generating main graph traverser."); + return GraphStatus::TOSA_ERROR; + } + + if (_main_gt->initializeGraph()) + { + WARNING("Unable to initialize main graph traverser."); + return _main_gt->getGraphStatus(); + } + + if (_main_gt->linkTensorsAndNodes()) + { + WARNING("Failed to link tensors and nodes"); + return _main_gt->getGraphStatus(); + } + + if (_main_gt->validateGraph()) + { + WARNING("Failed to validate graph."); + return _main_gt->getGraphStatus(); + } + + if (_main_gt->allocateTensor()) + { + WARNING("Failed to allocate tensor."); + return _main_gt->getGraphStatus(); + } + + return _main_gt->getGraphStatus(); +} + void ModelRunnerImpl::validateTosaVersion(TosaSerializationHandler& serialization_handler) { TosaVersion model_version(TOSA_REFERENCE_MODEL_VERSION_MAJOR, @@ -266,11 +360,11 @@ void ModelRunnerImpl::checkGraphStatus(SubgraphTraverser& main_gt) } // Template explicit specialization -template int ModelRunnerImpl::setInput<float>(std::string input_name, std::vector<float>& vals); -template int ModelRunnerImpl::setInput<half_float::half>(std::string input_name, std::vector<half_float::half>& vals); -template int ModelRunnerImpl::setInput<int32_t>(std::string input_name, std::vector<int32_t>& vals); -template int ModelRunnerImpl::setInput<int64_t>(std::string input_name, std::vector<int64_t>& vals); -template int ModelRunnerImpl::setInput<unsigned char>(std::string input_name, std::vector<unsigned char>& vals); +template int ModelRunnerImpl::setInput<float>(std::string input_name, ArrayProxy<float> vals); +template int ModelRunnerImpl::setInput<half_float::half>(std::string input_name, ArrayProxy<half_float::half> vals); +template int ModelRunnerImpl::setInput<int32_t>(std::string input_name, ArrayProxy<int32_t> vals); +template int ModelRunnerImpl::setInput<int64_t>(std::string input_name, ArrayProxy<int64_t> vals); +template int ModelRunnerImpl::setInput<unsigned char>(std::string input_name, ArrayProxy<unsigned char> vals); template std::vector<float> ModelRunnerImpl::getOutput<float>(std::string output_name); template std::vector<half_float::half> ModelRunnerImpl::getOutput<half_float::half>(std::string output_name); |