diff options
author | Sadik Armagan <sadik.armagan@arm.com> | 2020-11-11 18:01:48 +0000 |
---|---|---|
committer | Sadik Armagan <sadik.armagan@arm.com> | 2020-11-12 11:35:11 +0000 |
commit | 4189cc5ca4bb12e02c5e7f86ec6079f76d845b59 (patch) | |
tree | bb5426804692e11abf000ffe9c64f7d95e104beb /delegate/src/DelegateUtils.hpp | |
parent | 8081536d24291794b4e189e6d5532d913a4525cb (diff) | |
download | armnn-4189cc5ca4bb12e02c5e7f86ec6079f76d845b59.tar.gz |
IVGCVSW-5504 'TfLiteDelegate: Introduce FP16 and BackendOptions'
* Added BackendOptions creations of armnn_delegate
* Included armnn/third-party the armnn_delegate unit tests
* Updated the CreateConstTensor function
Signed-off-by: Sadik Armagan <sadik.armagan@arm.com>
Change-Id: I8e2099a465766b905bff701413307e5850b68e42
Diffstat (limited to 'delegate/src/DelegateUtils.hpp')
-rw-r--r-- | delegate/src/DelegateUtils.hpp | 94 |
1 files changed, 16 insertions, 78 deletions
diff --git a/delegate/src/DelegateUtils.hpp b/delegate/src/DelegateUtils.hpp index fb3f998283..71222276b4 100644 --- a/delegate/src/DelegateUtils.hpp +++ b/delegate/src/DelegateUtils.hpp @@ -397,91 +397,29 @@ armnn::TensorInfo GetTensorInfoForTfLiteTensor(const TfLiteTensor& tfLiteTensor) return ret; } -struct DataHolder +armnn::ConstTensor CreateConstTensor(const TfLiteTensor* tfLiteTensor, + armnn::TensorInfo& tensorInfo, + armnn::Optional<armnn::PermutationVector&> permutationVector) { -public: - DataHolder() - : m_Fp32Data(nullptr), m_Uint8Data(nullptr), - m_Int8Data(nullptr), m_Int16Data(nullptr), m_Int32Data(nullptr) {} - - DataHolder(std::unique_ptr<float[]>&& data) - : m_Fp32Data(std::move(data)), m_Uint8Data(nullptr), - m_Int8Data(nullptr), m_Int16Data(nullptr), m_Int32Data(nullptr) {} - - DataHolder(std::unique_ptr<uint8_t[]>&& data) - : m_Fp32Data(nullptr), m_Uint8Data(std::move(data)), - m_Int8Data(nullptr), m_Int16Data(nullptr), m_Int32Data(nullptr) {} - - DataHolder(std::unique_ptr<int8_t[]>&& data) - : m_Fp32Data(nullptr), m_Uint8Data(nullptr), - m_Int8Data(std::move(data)), m_Int16Data(nullptr), m_Int32Data(nullptr) {} - - DataHolder(std::unique_ptr<int16_t[]>&& data) - : m_Fp32Data(nullptr), m_Uint8Data(nullptr), - m_Int8Data(nullptr), m_Int16Data(std::move(data)), m_Int32Data(nullptr) {} - - DataHolder(std::unique_ptr<int32_t[]>&& data) - : m_Fp32Data(nullptr), m_Uint8Data(nullptr), - m_Int8Data(nullptr), m_Int16Data(nullptr), m_Int32Data(std::move(data)) {} - -private: - std::unique_ptr<float[]> m_Fp32Data; - std::unique_ptr<uint8_t[]> m_Uint8Data; - std::unique_ptr<int8_t[]> m_Int8Data; - std::unique_ptr<int16_t[]> m_Int16Data; - std::unique_ptr<int32_t[]> m_Int32Data; -}; - -template <typename T> -std::pair<armnn::ConstTensor, DataHolder> CreateConstTensorImpl( - const TfLiteTensor* tensor, - armnn::TensorInfo& tensorInfo, - armnn::Optional<armnn::PermutationVector&> permutationVector) -{ - std::unique_ptr<T[]> data(new T[tensorInfo.GetNumElements()]); + if (tfLiteTensor->allocation_type != kTfLiteMmapRo) + { + throw armnn::Exception("TfLiteArmnnDelegate: Not constant allocation type: " + tfLiteTensor->allocation_type); + } + if (permutationVector.has_value() && permutationVector.value().GetSize() > 0) { - tensorInfo = armnnUtils::Permuted(tensorInfo, permutationVector.value()); - armnnUtils::Permute(tensorInfo.GetShape(), + std::vector<uint8_t> swizzledData; + swizzledData.resize(tensorInfo.GetNumBytes()); + armnnUtils::Permute(armnnUtils::Permuted(tensorInfo.GetShape(), permutationVector.value()), permutationVector.value(), - reinterpret_cast<const T*>(tensor->data.raw), data.get(), sizeof(T)); + tfLiteTensor->data.data, + swizzledData.data(), + armnn::GetDataTypeSize(tensorInfo.GetDataType())); + return armnn::ConstTensor(armnnUtils::Permuted(tensorInfo, permutationVector.value()), swizzledData.data()); } else { - ::memcpy(data.get(), tensor->data.raw, tensorInfo.GetNumBytes()); - } - - auto constData = std::make_pair(armnn::ConstTensor(tensorInfo, data.get()), std::move(data)); - - DataHolder storedData(std::move(constData.second)); - return std::make_pair(constData.first, std::move(storedData)); -} - -std::pair<armnn::ConstTensor, DataHolder> CreateConstTensor( - const TfLiteTensor* tfLiteTensor, - armnn::TensorInfo& tensorInfo, - armnn::Optional<armnn::PermutationVector&> permutationVector) -{ - switch (tensorInfo.GetDataType()) - { - case armnn::DataType::Float32: - return CreateConstTensorImpl<float>(tfLiteTensor, tensorInfo, permutationVector); - case armnn::DataType::QAsymmU8: - return CreateConstTensorImpl<uint8_t>(tfLiteTensor, tensorInfo, permutationVector); - case armnn::DataType::QSymmS8: - return CreateConstTensorImpl<int8_t>(tfLiteTensor, tensorInfo, permutationVector); - case armnn::DataType::QAsymmS8: - return CreateConstTensorImpl<int8_t>(tfLiteTensor, tensorInfo, permutationVector); - case armnn::DataType::QSymmS16: - return CreateConstTensorImpl<int16_t>(tfLiteTensor, tensorInfo, permutationVector); - case armnn::DataType::Signed32: - return CreateConstTensorImpl<int32_t>(tfLiteTensor, tensorInfo, permutationVector); - default: - { - throw armnn::Exception( - "TfLiteArmnnDelegate: Unsupported data type when creating const tensor: " - + std::string(armnn::GetDataTypeName(tensorInfo.GetDataType()))); - } + return armnn::ConstTensor(tensorInfo, tfLiteTensor->data.data); } } |