aboutsummaryrefslogtreecommitdiff
path: root/src/backends/tosaCommon/operatorMappings/TosaOperatorUtils.hpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/backends/tosaCommon/operatorMappings/TosaOperatorUtils.hpp')
-rw-r--r--src/backends/tosaCommon/operatorMappings/TosaOperatorUtils.hpp101
1 files changed, 100 insertions, 1 deletions
diff --git a/src/backends/tosaCommon/operatorMappings/TosaOperatorUtils.hpp b/src/backends/tosaCommon/operatorMappings/TosaOperatorUtils.hpp
index 3e106e1fd5..e43f6ca027 100644
--- a/src/backends/tosaCommon/operatorMappings/TosaOperatorUtils.hpp
+++ b/src/backends/tosaCommon/operatorMappings/TosaOperatorUtils.hpp
@@ -1,5 +1,5 @@
//
-// Copyright © 2022 Arm Ltd and Contributors. All rights reserved.
+// Copyright © 2022-2023 Arm Ltd and Contributors. All rights reserved.
// SPDX-License-Identifier: MIT
//
@@ -354,3 +354,102 @@ inline std::vector<uint8_t> ConvertConstantTensorDataToBuffer(const std::shared_
tensorHandle->Unmap();
return uint8Data;
}
+
+
+inline std::vector<uint8_t> CreateConstTosaData(const void* value,
+ DType dtype,
+ const std::vector<int32_t>& shape)
+{
+ std::vector<uint8_t> uint8Data;
+ tosa_err_t error = tosa_err_t::TOSA_OK;
+
+ unsigned int numElements = 1;
+ for (auto s : shape)
+ {
+ if (s < 0)
+ {
+ throw armnn::Exception("CreateConstTosaData: negative shape elements unhandled.");
+ }
+ numElements = numElements * static_cast<unsigned int>(s);
+ }
+
+ switch (dtype)
+ {
+ case DType::DType_FP32:
+ {
+ std::vector<float> data(numElements, *static_cast<const float*>(value));
+ error = TosaSerializationHandler::ConvertF32toU8(data, uint8Data);
+ break;
+ }
+ case DType::DType_FP16:
+ {
+ std::vector<float> data(numElements, *static_cast<const float*>(value));
+ error = TosaSerializationHandler::ConvertF16toU8(data, uint8Data);
+ break;
+ }
+ case DType::DType_INT48:
+ {
+ std::vector<int64_t> data(numElements, *static_cast<const int64_t*>(value));
+ error = TosaSerializationHandler::ConvertI48toU8(data, uint8Data);
+ break;
+ }
+ case DType::DType_INT32:
+ {
+ std::vector<int32_t> data(numElements, *static_cast<const int32_t*>(value));
+ error = TosaSerializationHandler::ConvertI32toU8(data, uint8Data);
+ break;
+ }
+ case DType::DType_INT16:
+ {
+ std::vector<int16_t> data(numElements, *static_cast<const int16_t*>(value));
+ error = TosaSerializationHandler::ConvertI16toU8(data, uint8Data);
+ break;
+ }
+ case DType::DType_INT8:
+ {
+ std::vector<int8_t> data(numElements, *static_cast<const int8_t*>(value));
+ error = TosaSerializationHandler::ConvertI8toU8(data, uint8Data);
+ break;
+ }
+ case DType::DType_INT4:
+ {
+ std::vector<int8_t> data(numElements, *static_cast<const int8_t*>(value));
+ error = TosaSerializationHandler::ConvertI4toU8(data, uint8Data);
+ break;
+ }
+ case DType::DType_BOOL:
+ {
+ std::vector<bool> data(numElements, *static_cast<const bool*>(value));
+ error = TosaSerializationHandler::ConvertBooltoU8(data, uint8Data);
+ break;
+ }
+ default:
+ {
+ throw armnn::Exception("CreateConstTosaData: An unsupported data type was encountered.");
+ }
+ }
+
+ if(error != tosa_err_t::TOSA_OK)
+ {
+ throw armnn::Exception("CreateConstTosaData: An error occurred when converting constant data");
+ }
+
+ return uint8Data;
+}
+
+template<typename T>
+inline void CreateConstTosaOperator(const std::string& outputName,
+ const T value,
+ DType dtype,
+ const std::vector<int32_t>& shape,
+ TosaSerializationOperator*& op,
+ TosaSerializationTensor*& tensor)
+{
+ std::vector<uint8_t> uint8Data = CreateConstTosaData(static_cast<const void *>(&value), dtype, shape);
+
+ op = new TosaSerializationOperator(Op_CONST, Attribute_NONE, nullptr, {}, {outputName});
+ ARMNN_THROW_MSG_IF_FALSE(op, armnn::Exception, "CreateConstTosaOperator: failed to created operator");
+
+ tensor = new TosaSerializationTensor(outputName, shape, dtype, uint8Data);
+ ARMNN_THROW_MSG_IF_FALSE(tensor, armnn::Exception, "CreateConstTosaOperator: failed to created tensor");
+} \ No newline at end of file