From 7203835d65afb8a06a5cf98072718d93a9b71567 Mon Sep 17 00:00:00 2001
From: Fabrizio Indirli
Date: Mon, 11 Dec 2023 11:15:32 +0000
Subject: Add support for precise mode in eager runner

Add support for FP64 tensors in the eager runner's helper functions
when precise mode is enabled.

Signed-off-by: Fabrizio Indirli
Change-Id: Ib737c0d18fb1c7ac40ce6ea03a4fbcefae88ba5c
---
 reference_model/src/model_runner_impl.cc |  56 ++++++++++-
 reference_model/src/model_runner_impl.h  |   3 +-
 reference_model/src/tensor.cc            | 164 ++++++++++++++++++++++++++++++-
 3 files changed, 219 insertions(+), 4 deletions(-)

diff --git a/reference_model/src/model_runner_impl.cc b/reference_model/src/model_runner_impl.cc
index 311db7c..bf23bac 100644
--- a/reference_model/src/model_runner_impl.cc
+++ b/reference_model/src/model_runner_impl.cc
@@ -1,5 +1,5 @@
-// Copyright (c) 2022-2023, ARM Limited.
+// Copyright (c) 2022-2024, ARM Limited.
 //
 // Licensed under the Apache License, Version 2.0 (the "License");
 // you may not use this file except in compliance with the License.
@@ -166,6 +166,40 @@ int ModelRunnerImpl::setInput(std::string input_name, ArrayProxy<T> vals)
     return 0;
 }
 
+int ModelRunnerImpl::setInputForPrecMode(Tensor* tensor, std::string input_name, uint8_t* raw_ptr, size_t size)
+{
+    ASSERT_MSG(tensor, "Tensor not provided!");
+    if (!g_func_config.precise_mode)
+    {
+        WARNING("Cannot set input tensor %s using precise mode setters when not running in precise mode!",
+                input_name.c_str());
+        return 1;
+    }
+
+    DType ser_dtype = tensor->getSerializationDtype();
+    int status;
+
+    switch (ser_dtype)
+    {
+        case DType::DType_FP16: {
+            auto typed_ptr = reinterpret_cast<half_float::half*>(raw_ptr);
+            const int elements = size / sizeof(half_float::half);
+            status = setInput(input_name, ArrayProxy(elements, typed_ptr));
+            break;
+        }
+        case DType::DType_FP32: {
+            auto typed_ptr = reinterpret_cast<float*>(raw_ptr);
+            const int elements = size / sizeof(float);
+            status = setInput(input_name, ArrayProxy(elements, typed_ptr));
+            break;
+        }
+        default:
+            status = 1;
+    }
+
+    return status;
+}
+
 int ModelRunnerImpl::setInput(std::string input_name, uint8_t* raw_ptr, size_t size)
 {
     if (_main_gt == nullptr)
@@ -197,6 +231,18 @@ int ModelRunnerImpl::setInput(std::string input_name, uint8_t* raw_ptr, size_t s
             status = setInput(input_name, ArrayProxy(elements, typed_ptr));
             break;
         }
+        case TOSA_REF_TYPE_FP64:
+            if (g_func_config.precise_mode)
+            {
+                status = setInputForPrecMode(tensor, input_name, raw_ptr, size);
+            }
+            else
+            {
+                auto typed_ptr = reinterpret_cast<double*>(raw_ptr);
+                const int elements = size / sizeof(double);
+                status = setInput(input_name, ArrayProxy(elements, typed_ptr));
+            }
+            break;
         case TOSA_REF_TYPE_INT16: {
             auto typed_ptr = reinterpret_cast<int16_t*>(raw_ptr);
             const int elements = size / sizeof(int16_t);
             status = setInput(input_name, ArrayProxy(elements, typed_ptr));
             break;
         }
@@ -281,6 +327,12 @@ int ModelRunnerImpl::getOutput(std::string output_name, uint8_t* raw_ptr, size_t
             status = tensor->writeToVector(ArrayProxy(elements, typed_ptr));
             break;
         }
+        case TOSA_REF_TYPE_FP64: {
+            auto typed_ptr = reinterpret_cast<double*>(raw_ptr);
+            const int elements = size / sizeof(double);
+            status = tensor->writeToVector(ArrayProxy(elements, typed_ptr));
+            break;
+        }
         case TOSA_REF_TYPE_BOOL: {
             auto typed_ptr = reinterpret_cast<unsigned char*>(raw_ptr);
             const int elements = size / sizeof(unsigned char);
@@ -394,12 +446,14 @@ void ModelRunnerImpl::checkGraphStatus(SubgraphTraverser& main_gt)
 }
 
 // Template explicit specialization
+template int ModelRunnerImpl::setInput(std::string input_name, ArrayProxy<double> vals);
 template int ModelRunnerImpl::setInput(std::string input_name,
                                        ArrayProxy<half_float::half> vals);
 template int ModelRunnerImpl::setInput(std::string input_name, ArrayProxy<float> vals);
 template int ModelRunnerImpl::setInput(std::string input_name, ArrayProxy<int32_t> vals);
 template int ModelRunnerImpl::setInput(std::string input_name, ArrayProxy<int64_t> vals);
 template int ModelRunnerImpl::setInput(std::string input_name, ArrayProxy<unsigned char> vals);
 
+template std::vector<double> ModelRunnerImpl::getOutput(std::string output_name);
 template std::vector<half_float::half> ModelRunnerImpl::getOutput(std::string output_name);
 template std::vector<float> ModelRunnerImpl::getOutput(std::string output_name);
 template std::vector<int32_t> ModelRunnerImpl::getOutput(std::string output_name);
diff --git a/reference_model/src/model_runner_impl.h b/reference_model/src/model_runner_impl.h
index aed8a1e..db9755c 100644
--- a/reference_model/src/model_runner_impl.h
+++ b/reference_model/src/model_runner_impl.h
@@ -1,5 +1,5 @@
-// Copyright (c) 2022-2023, ARM Limited.
+// Copyright (c) 2022-2024, ARM Limited.
 //
 // Licensed under the Apache License, Version 2.0 (the "License");
 // you may not use this file except in compliance with the License.
@@ -64,6 +64,7 @@ private:
     GraphStatus initialize(TosaSerializationBasicBlock* bb, TosaSerializationHandler* serialization_handler);
     void validateTosaVersion(TosaSerializationHandler& serialization_handler);
     void checkGraphStatus(SubgraphTraverser& main_gt);
+    int setInputForPrecMode(Tensor* tensor, std::string input_name, uint8_t* raw_ptr, size_t size);
 };
 
 };    // namespace TosaReference
diff --git a/reference_model/src/tensor.cc b/reference_model/src/tensor.cc
index 645b55f..e84507b 100644
--- a/reference_model/src/tensor.cc
+++ b/reference_model/src/tensor.cc
@@ -1,5 +1,5 @@
-// Copyright (c) 2020-2023, ARM Limited.
+// Copyright (c) 2020-2024, ARM Limited.
 //
 // Licensed under the Apache License, Version 2.0 (the "License");
 // you may not use this file except in compliance with the License.
@@ -580,6 +580,14 @@ int TosaReference::Tensor::readfromVector(const ArrayProxy<float> vals)
     uint32_t elements = getElementCount();
     switch (getDtype())
     {
+        case TOSA_REF_TYPE_FP64:
+            if (!g_func_config.precise_mode)
+            {
+                WARNING("The input type (float) doesn't match the data type assigned to the tensor (%s).",
+                        EnumNameTOSAREFTYPE(getDtype()));
+                return -2;
+            }
+            // continue with setting float vals in the tensor
         case TOSA_REF_TYPE_FP16:
         case TOSA_REF_TYPE_FP32:
             if (vals.size() != elements)
@@ -622,6 +630,14 @@ int TosaReference::Tensor::readfromVector(const ArrayProxy<half_float::half> val
 
     switch (getDtype())
     {
+        case TOSA_REF_TYPE_FP64:
+            if (!g_func_config.precise_mode)
+            {
+                WARNING("The input type (float) doesn't match the data type assigned to the tensor (%s).",
+                        EnumNameTOSAREFTYPE(getDtype()));
+                return -2;
+            }
+            // continue with setting float vals in the tensor
         case TOSA_REF_TYPE_FP16:
             if (vals.size() != elements)
             {
@@ -953,7 +969,7 @@ int TosaReference::Tensor::writeToVector(ArrayProxy vals)
 template <class T>
 int TosaReference::TensorTemplate<T>::setTensorValueDouble(const size_t buflen, const double* vals)
 {
-    FATAL_ERROR("TensorTemplate::setTensorValueFloat should not be called. "
+    FATAL_ERROR("TensorTemplate::setTensorValueDouble should not be called. "
                 "Implement template specialization version.");
     return 0;
 }
@@ -1254,6 +1270,150 @@ int TosaReference::Tensor6<float>::setTensorValueFloat(const size_t bufLen, cons
     return 0;
 }
 
+template <>
+int TosaReference::Tensor0<double>::setTensorValueFloat(const size_t bufLen, const float* vals)
+{
+    ASSERT_MSG(bufLen == getElementCount(), "Total elements must match");
+
+    (*tensor)(0) = vals[0];
+
+    return 0;
+}
+
+template <>
+int TosaReference::Tensor1<double>::setTensorValueFloat(const size_t bufLen, const float* vals)
+{
+    uint32_t idx = 0;
+
+    ASSERT_MSG(bufLen == getElementCount(), "Total elements must match");
+
+    for (int i0 = 0; i0 < shape[0]; i0++)
+    {
+        (*tensor)(i0) = vals[idx++];
+    }
+
+    return 0;
+}
+
+template <>
+int TosaReference::Tensor2<double>::setTensorValueFloat(const size_t bufLen, const float* vals)
+{
+    uint32_t idx = 0;
+
+    ASSERT_MSG(bufLen == getElementCount(), "Total elements must match");
+
+    for (int i0 = 0; i0 < shape[0]; i0++)
+    {
+        for (int i1 = 0; i1 < shape[1]; i1++)
+        {
+            (*tensor)(i0, i1) = vals[idx++];
+        }
+    }
+
+    return 0;
+}
+
+template <>
+int TosaReference::Tensor3<double>::setTensorValueFloat(const size_t bufLen, const float* vals)
+{
+    uint32_t idx = 0;
+
+    ASSERT_MSG(bufLen == getElementCount(), "Total elements must match");
+
+    for (int i0 = 0; i0 < shape[0]; i0++)
+    {
+        for (int i1 = 0; i1 < shape[1]; i1++)
+        {
+            for (int i2 = 0; i2 < shape[2]; i2++)
+            {
+                (*tensor)(i0, i1, i2) = vals[idx++];
+            }
+        }
+    }
+
+    return 0;
+}
+
+template <>
+int TosaReference::Tensor4<double>::setTensorValueFloat(const size_t bufLen, const float* vals)
+{
+    uint32_t idx = 0;
+
+    ASSERT_MSG(bufLen == getElementCount(), "Total elements must match");
+
+    for (int i0 = 0; i0 < shape[0]; i0++)
+    {
+        for (int i1 = 0; i1 < shape[1]; i1++)
+        {
+            for (int i2 = 0; i2 < shape[2]; i2++)
+            {
+                for (int i3 = 0; i3 < shape[3]; i3++)
+                {
+                    (*tensor)(i0, i1, i2, i3) = vals[idx++];
+                }
+            }
+        }
+    }
+
+    return 0;
+}
+
+template <>
+int TosaReference::Tensor5<double>::setTensorValueFloat(const size_t bufLen, const float* vals)
+{
+    uint32_t idx = 0;
+
+    ASSERT_MSG(bufLen == getElementCount(), "Total elements must match");
+
+    for (int i0 = 0; i0 < shape[0]; i0++)
+    {
+        for (int i1 = 0; i1 < shape[1]; i1++)
+        {
+            for (int i2 = 0; i2 < shape[2]; i2++)
+            {
+                for (int i3 = 0; i3 < shape[3]; i3++)
+                {
+                    for (int i4 = 0; i4 < shape[4]; i4++)
+                    {
+                        (*tensor)(i0, i1, i2, i3, i4) = vals[idx++];
+                    }
+                }
+            }
+        }
+    }
+
+    return 0;
+}
+
+template <>
+int TosaReference::Tensor6<double>::setTensorValueFloat(const size_t bufLen, const float* vals)
+{
+    uint32_t idx = 0;
+
+    ASSERT_MSG(bufLen == getElementCount(), "Total elements must match");
+
+    for (int i0 = 0; i0 < shape[0]; i0++)
+    {
+        for (int i1 = 0; i1 < shape[1]; i1++)
+        {
+            for (int i2 = 0; i2 < shape[2]; i2++)
+            {
+                for (int i3 = 0; i3 < shape[3]; i3++)
+                {
+                    for (int i4 = 0; i4 < shape[4]; i4++)
+                    {
+                        for (int i5 = 0; i5 < shape[5]; i5++)
+                        {
+                            (*tensor)(i0, i1, i2, i3, i4, i5) = vals[idx++];
+                        }
+                    }
+                }
+            }
+        }
+    }
+    return 0;
+}
+
 template <class T>
 int TosaReference::TensorTemplate<T>::setTensorValueInt16(const size_t bufLen, const int16_t* vals)
 {
--
cgit v1.2.1