about summary refs log tree commit diff
path: root/reference_model/src/model_runner_impl.cc
diff options
context:
space:
mode:
Diffstat (limited to 'reference_model/src/model_runner_impl.cc')
-rw-r--r-- reference_model/src/model_runner_impl.cc | 56
1 file changed, 55 insertions(+), 1 deletion(-)
diff --git a/reference_model/src/model_runner_impl.cc b/reference_model/src/model_runner_impl.cc
index 311db7c..bf23bac 100644
--- a/reference_model/src/model_runner_impl.cc
+++ b/reference_model/src/model_runner_impl.cc
@@ -1,5 +1,5 @@
-// Copyright (c) 2022-2023, ARM Limited.
+// Copyright (c) 2022-2024, ARM Limited.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
@@ -166,6 +166,40 @@ int ModelRunnerImpl::setInput(std::string input_name, ArrayProxy<T> vals)
return 0;
}
+int ModelRunnerImpl::setInputForPrecMode(Tensor* tensor, std::string input_name, uint8_t* raw_ptr, size_t size)
+{
+ ASSERT_MSG(tensor, "Tensor not provided!");
+ if (!g_func_config.precise_mode)
+ {
+ WARNING("Cannot set input tensor %s using precise mode setters when not running in precise mode!",
+ input_name.c_str());
+ return 1;
+ }
+
+ DType ser_dtype = tensor->getSerializationDtype();
+ int status;
+
+ switch (ser_dtype)
+ {
+ case DType::DType_FP16: {
+ auto typed_ptr = reinterpret_cast<half_float::half*>(raw_ptr);
+ const int elements = size / sizeof(half_float::half);
+ status = setInput(input_name, ArrayProxy(elements, typed_ptr));
+ break;
+ }
+ case DType::DType_FP32: {
+ auto typed_ptr = reinterpret_cast<float*>(raw_ptr);
+ const int elements = size / sizeof(float);
+ status = setInput(input_name, ArrayProxy(elements, typed_ptr));
+ break;
+ }
+ default:
+ status = 1;
+ }
+
+ return status;
+}
+
int ModelRunnerImpl::setInput(std::string input_name, uint8_t* raw_ptr, size_t size)
{
if (_main_gt == nullptr)
@@ -197,6 +231,18 @@ int ModelRunnerImpl::setInput(std::string input_name, uint8_t* raw_ptr, size_t s
status = setInput(input_name, ArrayProxy(elements, typed_ptr));
break;
}
+ case TOSA_REF_TYPE_FP64:
+ if (g_func_config.precise_mode)
+ {
+ status = setInputForPrecMode(tensor, input_name, raw_ptr, size);
+ }
+ else
+ {
+ auto typed_ptr = reinterpret_cast<double*>(raw_ptr);
+ const int elements = size / sizeof(double);
+ status = setInput(input_name, ArrayProxy(elements, typed_ptr));
+ }
+ break;
case TOSA_REF_TYPE_INT16: {
auto typed_ptr = reinterpret_cast<int16_t*>(raw_ptr);
const int elements = size / sizeof(int16_t);
@@ -281,6 +327,12 @@ int ModelRunnerImpl::getOutput(std::string output_name, uint8_t* raw_ptr, size_t
status = tensor->writeToVector(ArrayProxy(elements, typed_ptr));
break;
}
+ case TOSA_REF_TYPE_FP64: {
+ auto typed_ptr = reinterpret_cast<double*>(raw_ptr);
+ const int elements = size / sizeof(double);
+ status = tensor->writeToVector(ArrayProxy(elements, typed_ptr));
+ break;
+ }
case TOSA_REF_TYPE_BOOL: {
auto typed_ptr = reinterpret_cast<unsigned char*>(raw_ptr);
const int elements = size / sizeof(unsigned char);
@@ -394,12 +446,14 @@ void ModelRunnerImpl::checkGraphStatus(SubgraphTraverser& main_gt)
}
// Template explicit specialization
+template int ModelRunnerImpl::setInput<double>(std::string input_name, ArrayProxy<double> vals);
template int ModelRunnerImpl::setInput<float>(std::string input_name, ArrayProxy<float> vals);
template int ModelRunnerImpl::setInput<half_float::half>(std::string input_name, ArrayProxy<half_float::half> vals);
template int ModelRunnerImpl::setInput<int32_t>(std::string input_name, ArrayProxy<int32_t> vals);
template int ModelRunnerImpl::setInput<int64_t>(std::string input_name, ArrayProxy<int64_t> vals);
template int ModelRunnerImpl::setInput<unsigned char>(std::string input_name, ArrayProxy<unsigned char> vals);
+template std::vector<double> ModelRunnerImpl::getOutput<double>(std::string output_name);
template std::vector<float> ModelRunnerImpl::getOutput<float>(std::string output_name);
template std::vector<half_float::half> ModelRunnerImpl::getOutput<half_float::half>(std::string output_name);
template std::vector<int32_t> ModelRunnerImpl::getOutput<int32_t>(std::string output_name);