diff options
Diffstat (limited to 'delegate/src/DelegateUtils.hpp')
-rw-r--r-- | delegate/src/DelegateUtils.hpp | 94 |
1 files changed, 16 insertions, 78 deletions
diff --git a/delegate/src/DelegateUtils.hpp b/delegate/src/DelegateUtils.hpp index fb3f998283..71222276b4 100644 --- a/delegate/src/DelegateUtils.hpp +++ b/delegate/src/DelegateUtils.hpp @@ -397,91 +397,29 @@ armnn::TensorInfo GetTensorInfoForTfLiteTensor(const TfLiteTensor& tfLiteTensor) return ret; } -struct DataHolder +armnn::ConstTensor CreateConstTensor(const TfLiteTensor* tfLiteTensor, + armnn::TensorInfo& tensorInfo, + armnn::Optional<armnn::PermutationVector&> permutationVector) { -public: - DataHolder() - : m_Fp32Data(nullptr), m_Uint8Data(nullptr), - m_Int8Data(nullptr), m_Int16Data(nullptr), m_Int32Data(nullptr) {} - - DataHolder(std::unique_ptr<float[]>&& data) - : m_Fp32Data(std::move(data)), m_Uint8Data(nullptr), - m_Int8Data(nullptr), m_Int16Data(nullptr), m_Int32Data(nullptr) {} - - DataHolder(std::unique_ptr<uint8_t[]>&& data) - : m_Fp32Data(nullptr), m_Uint8Data(std::move(data)), - m_Int8Data(nullptr), m_Int16Data(nullptr), m_Int32Data(nullptr) {} - - DataHolder(std::unique_ptr<int8_t[]>&& data) - : m_Fp32Data(nullptr), m_Uint8Data(nullptr), - m_Int8Data(std::move(data)), m_Int16Data(nullptr), m_Int32Data(nullptr) {} - - DataHolder(std::unique_ptr<int16_t[]>&& data) - : m_Fp32Data(nullptr), m_Uint8Data(nullptr), - m_Int8Data(nullptr), m_Int16Data(std::move(data)), m_Int32Data(nullptr) {} - - DataHolder(std::unique_ptr<int32_t[]>&& data) - : m_Fp32Data(nullptr), m_Uint8Data(nullptr), - m_Int8Data(nullptr), m_Int16Data(nullptr), m_Int32Data(std::move(data)) {} - -private: - std::unique_ptr<float[]> m_Fp32Data; - std::unique_ptr<uint8_t[]> m_Uint8Data; - std::unique_ptr<int8_t[]> m_Int8Data; - std::unique_ptr<int16_t[]> m_Int16Data; - std::unique_ptr<int32_t[]> m_Int32Data; -}; - -template <typename T> -std::pair<armnn::ConstTensor, DataHolder> CreateConstTensorImpl( - const TfLiteTensor* tensor, - armnn::TensorInfo& tensorInfo, - armnn::Optional<armnn::PermutationVector&> permutationVector) -{ - std::unique_ptr<T[]> data(new T[tensorInfo.GetNumElements()]); + if (tfLiteTensor->allocation_type != kTfLiteMmapRo) + { + throw armnn::Exception("TfLiteArmnnDelegate: Not constant allocation type: " + tfLiteTensor->allocation_type); + } + if (permutationVector.has_value() && permutationVector.value().GetSize() > 0) { - tensorInfo = armnnUtils::Permuted(tensorInfo, permutationVector.value()); - armnnUtils::Permute(tensorInfo.GetShape(), + std::vector<uint8_t> swizzledData; + swizzledData.resize(tensorInfo.GetNumBytes()); + armnnUtils::Permute(armnnUtils::Permuted(tensorInfo.GetShape(), permutationVector.value()), permutationVector.value(), - reinterpret_cast<const T*>(tensor->data.raw), data.get(), sizeof(T)); + tfLiteTensor->data.data, + swizzledData.data(), + armnn::GetDataTypeSize(tensorInfo.GetDataType())); + return armnn::ConstTensor(armnnUtils::Permuted(tensorInfo, permutationVector.value()), swizzledData.data()); } else { - ::memcpy(data.get(), tensor->data.raw, tensorInfo.GetNumBytes()); - } - - auto constData = std::make_pair(armnn::ConstTensor(tensorInfo, data.get()), std::move(data)); - - DataHolder storedData(std::move(constData.second)); - return std::make_pair(constData.first, std::move(storedData)); -} - -std::pair<armnn::ConstTensor, DataHolder> CreateConstTensor( - const TfLiteTensor* tfLiteTensor, - armnn::TensorInfo& tensorInfo, - armnn::Optional<armnn::PermutationVector&> permutationVector) -{ - switch (tensorInfo.GetDataType()) - { - case armnn::DataType::Float32: - return CreateConstTensorImpl<float>(tfLiteTensor, tensorInfo, permutationVector); - case armnn::DataType::QAsymmU8: - return CreateConstTensorImpl<uint8_t>(tfLiteTensor, tensorInfo, permutationVector); - case armnn::DataType::QSymmS8: - return CreateConstTensorImpl<int8_t>(tfLiteTensor, tensorInfo, permutationVector); - case armnn::DataType::QAsymmS8: - return CreateConstTensorImpl<int8_t>(tfLiteTensor, tensorInfo, permutationVector); - case armnn::DataType::QSymmS16: - return CreateConstTensorImpl<int16_t>(tfLiteTensor, tensorInfo, permutationVector); - case armnn::DataType::Signed32: - return CreateConstTensorImpl<int32_t>(tfLiteTensor, tensorInfo, permutationVector); - default: - { - throw armnn::Exception( - "TfLiteArmnnDelegate: Unsupported data type when creating const tensor: " - + std::string(armnn::GetDataTypeName(tensorInfo.GetDataType()))); - } + return armnn::ConstTensor(tensorInfo, tfLiteTensor->data.data); } } |