diff options
Diffstat (limited to 'reference_model/src/model_runner_impl.cc')
-rw-r--r-- | reference_model/src/model_runner_impl.cc | 56 |
1 files changed, 55 insertions, 1 deletions
diff --git a/reference_model/src/model_runner_impl.cc b/reference_model/src/model_runner_impl.cc index 311db7c..bf23bac 100644 --- a/reference_model/src/model_runner_impl.cc +++ b/reference_model/src/model_runner_impl.cc @@ -1,5 +1,5 @@ -// Copyright (c) 2022-2023, ARM Limited. +// Copyright (c) 2022-2024, ARM Limited. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -166,6 +166,40 @@ int ModelRunnerImpl::setInput(std::string input_name, ArrayProxy<T> vals) return 0; } +int ModelRunnerImpl::setInputForPrecMode(Tensor* tensor, std::string input_name, uint8_t* raw_ptr, size_t size) +{ + ASSERT_MSG(tensor, "Tensor not provided!"); + if (!g_func_config.precise_mode) + { + WARNING("Cannot set input tensor %s using precise mode setters when not running in precise mode!", + input_name.c_str()); + return 1; + } + + DType ser_dtype = tensor->getSerializationDtype(); + int status; + + switch (ser_dtype) + { + case DType::DType_FP16: { + auto typed_ptr = reinterpret_cast<half_float::half*>(raw_ptr); + const int elements = size / sizeof(half_float::half); + status = setInput(input_name, ArrayProxy(elements, typed_ptr)); + break; + } + case DType::DType_FP32: { + auto typed_ptr = reinterpret_cast<float*>(raw_ptr); + const int elements = size / sizeof(float); + status = setInput(input_name, ArrayProxy(elements, typed_ptr)); + break; + } + default: + status = 1; + } + + return status; +} + int ModelRunnerImpl::setInput(std::string input_name, uint8_t* raw_ptr, size_t size) { if (_main_gt == nullptr) @@ -197,6 +231,18 @@ int ModelRunnerImpl::setInput(std::string input_name, uint8_t* raw_ptr, size_t s status = setInput(input_name, ArrayProxy(elements, typed_ptr)); break; } + case TOSA_REF_TYPE_FP64: + if (g_func_config.precise_mode) + { + status = setInputForPrecMode(tensor, input_name, raw_ptr, size); + } + else + { + auto typed_ptr = reinterpret_cast<double*>(raw_ptr); + const int elements = size / sizeof(double); + status = setInput(input_name, ArrayProxy(elements, typed_ptr)); + } + break; case TOSA_REF_TYPE_INT16: { auto typed_ptr = reinterpret_cast<int16_t*>(raw_ptr); const int elements = size / sizeof(int16_t); @@ -281,6 +327,12 @@ int ModelRunnerImpl::getOutput(std::string output_name, uint8_t* raw_ptr, size_t status = tensor->writeToVector(ArrayProxy(elements, typed_ptr)); break; } + case TOSA_REF_TYPE_FP64: { + auto typed_ptr = reinterpret_cast<double*>(raw_ptr); + const int elements = size / sizeof(double); + status = tensor->writeToVector(ArrayProxy(elements, typed_ptr)); + break; + } case TOSA_REF_TYPE_BOOL: { auto typed_ptr = reinterpret_cast<unsigned char*>(raw_ptr); const int elements = size / sizeof(unsigned char); @@ -394,12 +446,14 @@ void ModelRunnerImpl::checkGraphStatus(SubgraphTraverser& main_gt) } // Template explicit specialization +template int ModelRunnerImpl::setInput<double>(std::string input_name, ArrayProxy<double> vals); template int ModelRunnerImpl::setInput<float>(std::string input_name, ArrayProxy<float> vals); template int ModelRunnerImpl::setInput<half_float::half>(std::string input_name, ArrayProxy<half_float::half> vals); template int ModelRunnerImpl::setInput<int32_t>(std::string input_name, ArrayProxy<int32_t> vals); template int ModelRunnerImpl::setInput<int64_t>(std::string input_name, ArrayProxy<int64_t> vals); template int ModelRunnerImpl::setInput<unsigned char>(std::string input_name, ArrayProxy<unsigned char> vals); +template std::vector<double> ModelRunnerImpl::getOutput<double>(std::string output_name); template std::vector<float> ModelRunnerImpl::getOutput<float>(std::string output_name); template std::vector<half_float::half> ModelRunnerImpl::getOutput<half_float::half>(std::string output_name); template std::vector<int32_t> ModelRunnerImpl::getOutput<int32_t>(std::string output_name); |