aboutsummaryrefslogtreecommitdiff
path: root/reference_model/src/model_runner_impl.cc
diff options
context:
space:
mode:
Diffstat (limited to 'reference_model/src/model_runner_impl.cc')
-rw-r--r--reference_model/src/model_runner_impl.cc178
1 files changed, 136 insertions, 42 deletions
diff --git a/reference_model/src/model_runner_impl.cc b/reference_model/src/model_runner_impl.cc
index 8427150..1109dd6 100644
--- a/reference_model/src/model_runner_impl.cc
+++ b/reference_model/src/model_runner_impl.cc
@@ -45,42 +45,12 @@ void ModelRunnerImpl::setFuncDebug(func_debug_t& func_debug)
GraphStatus ModelRunnerImpl::initialize(TosaSerializationHandler& serialization_handler)
{
validateTosaVersion(serialization_handler);
+ return initialize(serialization_handler.GetMainBlock(), &serialization_handler);
+}
- // Make nullptr in case ModelRunnerImpl is being initialized again with a different graph.
- _main_gt = nullptr;
- _main_gt = new SubgraphTraverser(serialization_handler.GetMainBlock(), &serialization_handler);
-
- if(_main_gt == nullptr)
- {
- WARNING("An error occurred when generating main graph traverser.");
- return GraphStatus::TOSA_ERROR;
- }
-
- if (_main_gt->initializeGraph())
- {
- WARNING("Unable to initialize main graph traverser.");
- return _main_gt->getGraphStatus();
- }
-
- if (_main_gt->linkTensorsAndNodes())
- {
- WARNING("Failed to link tensors and nodes");
- return _main_gt->getGraphStatus();
- }
-
- if (_main_gt->validateGraph())
- {
- WARNING("Failed to validate graph.");
- return _main_gt->getGraphStatus();
- }
-
- if (_main_gt->allocateTensor())
- {
- WARNING("Failed to allocate tensor.");
- return _main_gt->getGraphStatus();
- }
-
- return _main_gt->getGraphStatus();
+GraphStatus ModelRunnerImpl::initialize(TosaSerializationBasicBlock& bb)
+{
+ return initialize(&bb, nullptr);
}
GraphStatus ModelRunnerImpl::run()
@@ -156,7 +126,7 @@ done:
}
template <typename T>
-int ModelRunnerImpl::setInput(std::string input_name, std::vector<T>& vals)
+int ModelRunnerImpl::setInput(std::string input_name, ArrayProxy<T> vals)
{
if (_main_gt == nullptr)
{
@@ -197,6 +167,44 @@ int ModelRunnerImpl::setInput(std::string input_name, std::vector<T>& vals)
return 0;
}
+int ModelRunnerImpl::setInput(std::string input_name, uint8_t* raw_ptr, size_t size)
+{
+ if (_main_gt == nullptr)
+ {
+ FATAL_ERROR("ModelRunner hasn't been initialized, please invoke initialize() before setInput()");
+ }
+
+ Tensor* tensor;
+ tensor = _main_gt->getInputTensorByName(input_name);
+
+ if (!tensor)
+ {
+ WARNING("Unable to find input tensor %s", input_name.c_str());
+ return 1;
+ }
+
+ int status = 0;
+ switch (tensor->getDtype())
+ {
+ case DType::DType_FP16: {
+ auto typed_ptr = reinterpret_cast<half_float::half*>(raw_ptr);
+ const int elements = size / sizeof(half_float::half);
+ status = setInput(input_name, ArrayProxy(elements, typed_ptr));
+ break;
+ }
+ case DType::DType_FP32: {
+ auto typed_ptr = reinterpret_cast<float*>(raw_ptr);
+ const int elements = size / sizeof(float);
+ status = setInput(input_name, ArrayProxy(elements, typed_ptr));
+ break;
+ }
+ default:
+ status = 1;
+ }
+
+ return status;
+}
+
template <typename T>
std::vector<T> ModelRunnerImpl::getOutput(std::string output_name)
{
@@ -216,7 +224,7 @@ std::vector<T> ModelRunnerImpl::getOutput(std::string output_name)
std::vector<T> outputs(tensor->getElementCount());
- if (tensor->writeToVector(outputs))
+ if (tensor->writeToVector(ArrayProxy<T>(outputs)))
{
WARNING("Unable to convert output tensor %s to vector", tensor->getName().c_str());
return std::vector<T>();
@@ -225,6 +233,92 @@ std::vector<T> ModelRunnerImpl::getOutput(std::string output_name)
return outputs;
}
+int ModelRunnerImpl::getOutput(std::string output_name, uint8_t* raw_ptr, size_t size)
+{
+ if (_main_gt == nullptr)
+ {
+ FATAL_ERROR("ModelRunner hasn't been initialized, please invoke initialize() and run() before getOutput()");
+ }
+
+ Tensor* tensor;
+ tensor = _main_gt->getOutputTensorByName(output_name);
+
+ if (!tensor)
+ {
+ WARNING("Unable to find output tensor %s", output_name.c_str());
+ return 1;
+ }
+
+ int status = 0;
+ switch (tensor->getDtype())
+ {
+ case DType::DType_FP16: {
+ auto typed_ptr = reinterpret_cast<half_float::half*>(raw_ptr);
+ const int elements = size / sizeof(half_float::half);
+ status = tensor->writeToVector(ArrayProxy(elements, typed_ptr));
+ break;
+ }
+ case DType::DType_FP32: {
+ auto typed_ptr = reinterpret_cast<float*>(raw_ptr);
+ const int elements = size / sizeof(float);
+ status = tensor->writeToVector(ArrayProxy(elements, typed_ptr));
+ break;
+ }
+ default:
+ status = 1;
+ }
+ if (status)
+ {
+ WARNING("Unable to convert output tensor %s to vector", tensor->getName().c_str());
+ return 1;
+ }
+
+ return 0;
+}
+
+GraphStatus ModelRunnerImpl::initialize(TosaSerializationBasicBlock* bb,
+ TosaSerializationHandler* serialization_handler)
+{
+ if (serialization_handler != nullptr)
+ validateTosaVersion(*serialization_handler);
+
+ // Make nullptr in case ModelRunnerImpl is being initialized again with a different graph.
+ _main_gt = nullptr;
+ _main_gt = new SubgraphTraverser(bb, serialization_handler);
+
+ if (_main_gt == nullptr)
+ {
+ WARNING("An error occurred when generating main graph traverser.");
+ return GraphStatus::TOSA_ERROR;
+ }
+
+ if (_main_gt->initializeGraph())
+ {
+ WARNING("Unable to initialize main graph traverser.");
+ return _main_gt->getGraphStatus();
+ }
+
+ if (_main_gt->linkTensorsAndNodes())
+ {
+ WARNING("Failed to link tensors and nodes");
+ return _main_gt->getGraphStatus();
+ }
+
+ if (_main_gt->validateGraph())
+ {
+ WARNING("Failed to validate graph.");
+ return _main_gt->getGraphStatus();
+ }
+
+ if (_main_gt->allocateTensor())
+ {
+ WARNING("Failed to allocate tensor.");
+ return _main_gt->getGraphStatus();
+ }
+
+ return _main_gt->getGraphStatus();
+}
+
void ModelRunnerImpl::validateTosaVersion(TosaSerializationHandler& serialization_handler)
{
TosaVersion model_version(TOSA_REFERENCE_MODEL_VERSION_MAJOR,
@@ -266,11 +360,11 @@ void ModelRunnerImpl::checkGraphStatus(SubgraphTraverser& main_gt)
}
// Template explicit specialization
-template int ModelRunnerImpl::setInput<float>(std::string input_name, std::vector<float>& vals);
-template int ModelRunnerImpl::setInput<half_float::half>(std::string input_name, std::vector<half_float::half>& vals);
-template int ModelRunnerImpl::setInput<int32_t>(std::string input_name, std::vector<int32_t>& vals);
-template int ModelRunnerImpl::setInput<int64_t>(std::string input_name, std::vector<int64_t>& vals);
-template int ModelRunnerImpl::setInput<unsigned char>(std::string input_name, std::vector<unsigned char>& vals);
+template int ModelRunnerImpl::setInput<float>(std::string input_name, ArrayProxy<float> vals);
+template int ModelRunnerImpl::setInput<half_float::half>(std::string input_name, ArrayProxy<half_float::half> vals);
+template int ModelRunnerImpl::setInput<int32_t>(std::string input_name, ArrayProxy<int32_t> vals);
+template int ModelRunnerImpl::setInput<int64_t>(std::string input_name, ArrayProxy<int64_t> vals);
+template int ModelRunnerImpl::setInput<unsigned char>(std::string input_name, ArrayProxy<unsigned char> vals);
template std::vector<float> ModelRunnerImpl::getOutput<float>(std::string output_name);
template std::vector<half_float::half> ModelRunnerImpl::getOutput<half_float::half>(std::string output_name);