diff options
author | Teresa Charlin <teresa.charlinreyes@arm.com> | 2023-12-15 14:20:47 +0000 |
---|---|---|
committer | john.mcloughlin <john.mcloughlin@arm.com> | 2023-12-21 11:14:54 +0000 |
commit | ca5a23a7cbe46b8da8de432d80889c47a745ca4c (patch) | |
tree | c375efe1f4ecdad708d6e8e771a13c07ddbc0257 /src/backends/tosaCommon/operatorMappings/TosaOperatorUtils.hpp | |
parent | 0587dd01272199a36210bd0ccc266185b113df75 (diff) | |
download | armnn-ca5a23a7cbe46b8da8de432d80889c47a745ca4c.tar.gz |
Add Quantize Support to TOSA Ref Backend
* Adding a one to many tosa mapping for Quantize
* Added tests
* Resolves IVGCVSW-7175
Signed-off-by: John Mcloughlin <john.mcloughlin@arm.com>
Signed-off-by: Teresa Charlin <teresa.charlinreyes@arm.com>
Change-Id: Ia0852fefb618b4a29c2601b9de8b6b2731229801
Diffstat (limited to 'src/backends/tosaCommon/operatorMappings/TosaOperatorUtils.hpp')
-rw-r--r-- | src/backends/tosaCommon/operatorMappings/TosaOperatorUtils.hpp | 101 |
1 files changed, 100 insertions, 1 deletions
diff --git a/src/backends/tosaCommon/operatorMappings/TosaOperatorUtils.hpp b/src/backends/tosaCommon/operatorMappings/TosaOperatorUtils.hpp index 3e106e1fd5..e43f6ca027 100644 --- a/src/backends/tosaCommon/operatorMappings/TosaOperatorUtils.hpp +++ b/src/backends/tosaCommon/operatorMappings/TosaOperatorUtils.hpp @@ -1,5 +1,5 @@ // -// Copyright © 2022 Arm Ltd and Contributors. All rights reserved. +// Copyright © 2022-2023 Arm Ltd and Contributors. All rights reserved. // SPDX-License-Identifier: MIT // @@ -354,3 +354,102 @@ inline std::vector<uint8_t> ConvertConstantTensorDataToBuffer(const std::shared_ tensorHandle->Unmap(); return uint8Data; } + + +inline std::vector<uint8_t> CreateConstTosaData(const void* value, + DType dtype, + const std::vector<int32_t>& shape) +{ + std::vector<uint8_t> uint8Data; + tosa_err_t error = tosa_err_t::TOSA_OK; + + unsigned int numElements = 1; + for (auto s : shape) + { + if (s < 0) + { + throw armnn::Exception("CreateConstTosaData: negative shape elements unhandled."); + } + numElements = numElements * static_cast<unsigned int>(s); + } + + switch (dtype) + { + case DType::DType_FP32: + { + std::vector<float> data(numElements, *static_cast<const float*>(value)); + error = TosaSerializationHandler::ConvertF32toU8(data, uint8Data); + break; + } + case DType::DType_FP16: + { + std::vector<float> data(numElements, *static_cast<const float*>(value)); + error = TosaSerializationHandler::ConvertF16toU8(data, uint8Data); + break; + } + case DType::DType_INT48: + { + std::vector<int64_t> data(numElements, *static_cast<const int64_t*>(value)); + error = TosaSerializationHandler::ConvertI48toU8(data, uint8Data); + break; + } + case DType::DType_INT32: + { + std::vector<int32_t> data(numElements, *static_cast<const int32_t*>(value)); + error = TosaSerializationHandler::ConvertI32toU8(data, uint8Data); + break; + } + case DType::DType_INT16: + { + std::vector<int16_t> data(numElements, *static_cast<const int16_t*>(value)); + error = TosaSerializationHandler::ConvertI16toU8(data, uint8Data); + break; + } + case DType::DType_INT8: + { + std::vector<int8_t> data(numElements, *static_cast<const int8_t*>(value)); + error = TosaSerializationHandler::ConvertI8toU8(data, uint8Data); + break; + } + case DType::DType_INT4: + { + std::vector<int8_t> data(numElements, *static_cast<const int8_t*>(value)); + error = TosaSerializationHandler::ConvertI4toU8(data, uint8Data); + break; + } + case DType::DType_BOOL: + { + std::vector<bool> data(numElements, *static_cast<const bool*>(value)); + error = TosaSerializationHandler::ConvertBooltoU8(data, uint8Data); + break; + } + default: + { + throw armnn::Exception("CreateConstTosaData: An unsupported data type was encountered."); + } + } + + if(error != tosa_err_t::TOSA_OK) + { + throw armnn::Exception("CreateConstTosaData: An error occurred when converting constant data"); + } + + return uint8Data; +} + +template<typename T> +inline void CreateConstTosaOperator(const std::string& outputName, + const T value, + DType dtype, + const std::vector<int32_t>& shape, + TosaSerializationOperator*& op, + TosaSerializationTensor*& tensor) +{ + std::vector<uint8_t> uint8Data = CreateConstTosaData(static_cast<const void *>(&value), dtype, shape); + + op = new TosaSerializationOperator(Op_CONST, Attribute_NONE, nullptr, {}, {outputName}); + ARMNN_THROW_MSG_IF_FALSE(op, armnn::Exception, "CreateConstTosaOperator: failed to created operator"); + + tensor = new TosaSerializationTensor(outputName, shape, dtype, uint8Data); + ARMNN_THROW_MSG_IF_FALSE(tensor, armnn::Exception, "CreateConstTosaOperator: failed to created tensor"); +}
\ No newline at end of file |