diff options
Diffstat (limited to 'src/backends/tosaCommon/operatorMappings/TosaOperatorUtils.hpp')
-rw-r--r-- | src/backends/tosaCommon/operatorMappings/TosaOperatorUtils.hpp | 101 |
1 files changed, 100 insertions, 1 deletions
diff --git a/src/backends/tosaCommon/operatorMappings/TosaOperatorUtils.hpp b/src/backends/tosaCommon/operatorMappings/TosaOperatorUtils.hpp index 3e106e1fd5..e43f6ca027 100644 --- a/src/backends/tosaCommon/operatorMappings/TosaOperatorUtils.hpp +++ b/src/backends/tosaCommon/operatorMappings/TosaOperatorUtils.hpp @@ -1,5 +1,5 @@ // -// Copyright © 2022 Arm Ltd and Contributors. All rights reserved. +// Copyright © 2022-2023 Arm Ltd and Contributors. All rights reserved. // SPDX-License-Identifier: MIT // @@ -354,3 +354,102 @@ inline std::vector<uint8_t> ConvertConstantTensorDataToBuffer(const std::shared_ tensorHandle->Unmap(); return uint8Data; } + + +inline std::vector<uint8_t> CreateConstTosaData(const void* value, + DType dtype, + const std::vector<int32_t>& shape) +{ + std::vector<uint8_t> uint8Data; + tosa_err_t error = tosa_err_t::TOSA_OK; + + unsigned int numElements = 1; + for (auto s : shape) + { + if (s < 0) + { + throw armnn::Exception("CreateConstTosaData: negative shape elements unhandled."); + } + numElements = numElements * static_cast<unsigned int>(s); + } + + switch (dtype) + { + case DType::DType_FP32: + { + std::vector<float> data(numElements, *static_cast<const float*>(value)); + error = TosaSerializationHandler::ConvertF32toU8(data, uint8Data); + break; + } + case DType::DType_FP16: + { + std::vector<float> data(numElements, *static_cast<const float*>(value)); + error = TosaSerializationHandler::ConvertF16toU8(data, uint8Data); + break; + } + case DType::DType_INT48: + { + std::vector<int64_t> data(numElements, *static_cast<const int64_t*>(value)); + error = TosaSerializationHandler::ConvertI48toU8(data, uint8Data); + break; + } + case DType::DType_INT32: + { + std::vector<int32_t> data(numElements, *static_cast<const int32_t*>(value)); + error = TosaSerializationHandler::ConvertI32toU8(data, uint8Data); + break; + } + case DType::DType_INT16: + { + std::vector<int16_t> data(numElements, *static_cast<const int16_t*>(value)); + error = TosaSerializationHandler::ConvertI16toU8(data, uint8Data); + break; + } + case DType::DType_INT8: + { + std::vector<int8_t> data(numElements, *static_cast<const int8_t*>(value)); + error = TosaSerializationHandler::ConvertI8toU8(data, uint8Data); + break; + } + case DType::DType_INT4: + { + std::vector<int8_t> data(numElements, *static_cast<const int8_t*>(value)); + error = TosaSerializationHandler::ConvertI4toU8(data, uint8Data); + break; + } + case DType::DType_BOOL: + { + std::vector<bool> data(numElements, *static_cast<const bool*>(value)); + error = TosaSerializationHandler::ConvertBooltoU8(data, uint8Data); + break; + } + default: + { + throw armnn::Exception("CreateConstTosaData: An unsupported data type was encountered."); + } + } + + if(error != tosa_err_t::TOSA_OK) + { + throw armnn::Exception("CreateConstTosaData: An error occurred when converting constant data"); + } + + return uint8Data; +} + +template<typename T> +inline void CreateConstTosaOperator(const std::string& outputName, + const T value, + DType dtype, + const std::vector<int32_t>& shape, + TosaSerializationOperator*& op, + TosaSerializationTensor*& tensor) +{ + std::vector<uint8_t> uint8Data = CreateConstTosaData(static_cast<const void *>(&value), dtype, shape); + + op = new TosaSerializationOperator(Op_CONST, Attribute_NONE, nullptr, {}, {outputName}); + ARMNN_THROW_MSG_IF_FALSE(op, armnn::Exception, "CreateConstTosaOperator: failed to created operator"); + + tensor = new TosaSerializationTensor(outputName, shape, dtype, uint8Data); + ARMNN_THROW_MSG_IF_FALSE(tensor, armnn::Exception, "CreateConstTosaOperator: failed to created tensor"); +}
\ No newline at end of file |