aboutsummaryrefslogtreecommitdiff
path: root/delegate/src/DelegateUtils.hpp
diff options
context:
space:
mode:
Diffstat (limited to 'delegate/src/DelegateUtils.hpp')
-rw-r--r--delegate/src/DelegateUtils.hpp94
1 files changed, 16 insertions, 78 deletions
diff --git a/delegate/src/DelegateUtils.hpp b/delegate/src/DelegateUtils.hpp
index fb3f998283..71222276b4 100644
--- a/delegate/src/DelegateUtils.hpp
+++ b/delegate/src/DelegateUtils.hpp
@@ -397,91 +397,29 @@ armnn::TensorInfo GetTensorInfoForTfLiteTensor(const TfLiteTensor& tfLiteTensor)
return ret;
}
-struct DataHolder
+armnn::ConstTensor CreateConstTensor(const TfLiteTensor* tfLiteTensor,
+ armnn::TensorInfo& tensorInfo,
+ armnn::Optional<armnn::PermutationVector&> permutationVector)
{
-public:
- DataHolder()
- : m_Fp32Data(nullptr), m_Uint8Data(nullptr),
- m_Int8Data(nullptr), m_Int16Data(nullptr), m_Int32Data(nullptr) {}
-
- DataHolder(std::unique_ptr<float[]>&& data)
- : m_Fp32Data(std::move(data)), m_Uint8Data(nullptr),
- m_Int8Data(nullptr), m_Int16Data(nullptr), m_Int32Data(nullptr) {}
-
- DataHolder(std::unique_ptr<uint8_t[]>&& data)
- : m_Fp32Data(nullptr), m_Uint8Data(std::move(data)),
- m_Int8Data(nullptr), m_Int16Data(nullptr), m_Int32Data(nullptr) {}
-
- DataHolder(std::unique_ptr<int8_t[]>&& data)
- : m_Fp32Data(nullptr), m_Uint8Data(nullptr),
- m_Int8Data(std::move(data)), m_Int16Data(nullptr), m_Int32Data(nullptr) {}
-
- DataHolder(std::unique_ptr<int16_t[]>&& data)
- : m_Fp32Data(nullptr), m_Uint8Data(nullptr),
- m_Int8Data(nullptr), m_Int16Data(std::move(data)), m_Int32Data(nullptr) {}
-
- DataHolder(std::unique_ptr<int32_t[]>&& data)
- : m_Fp32Data(nullptr), m_Uint8Data(nullptr),
- m_Int8Data(nullptr), m_Int16Data(nullptr), m_Int32Data(std::move(data)) {}
-
-private:
- std::unique_ptr<float[]> m_Fp32Data;
- std::unique_ptr<uint8_t[]> m_Uint8Data;
- std::unique_ptr<int8_t[]> m_Int8Data;
- std::unique_ptr<int16_t[]> m_Int16Data;
- std::unique_ptr<int32_t[]> m_Int32Data;
-};
-
-template <typename T>
-std::pair<armnn::ConstTensor, DataHolder> CreateConstTensorImpl(
- const TfLiteTensor* tensor,
- armnn::TensorInfo& tensorInfo,
- armnn::Optional<armnn::PermutationVector&> permutationVector)
-{
- std::unique_ptr<T[]> data(new T[tensorInfo.GetNumElements()]);
+ if (tfLiteTensor->allocation_type != kTfLiteMmapRo)
+ {
+ throw armnn::Exception("TfLiteArmnnDelegate: Not constant allocation type: " + tfLiteTensor->allocation_type);
+ }
+
if (permutationVector.has_value() && permutationVector.value().GetSize() > 0)
{
- tensorInfo = armnnUtils::Permuted(tensorInfo, permutationVector.value());
- armnnUtils::Permute(tensorInfo.GetShape(),
+ std::vector<uint8_t> swizzledData;
+ swizzledData.resize(tensorInfo.GetNumBytes());
+ armnnUtils::Permute(armnnUtils::Permuted(tensorInfo.GetShape(), permutationVector.value()),
permutationVector.value(),
- reinterpret_cast<const T*>(tensor->data.raw), data.get(), sizeof(T));
+ tfLiteTensor->data.data,
+ swizzledData.data(),
+ armnn::GetDataTypeSize(tensorInfo.GetDataType()));
+ return armnn::ConstTensor(armnnUtils::Permuted(tensorInfo, permutationVector.value()), swizzledData.data());
}
else
{
- ::memcpy(data.get(), tensor->data.raw, tensorInfo.GetNumBytes());
- }
-
- auto constData = std::make_pair(armnn::ConstTensor(tensorInfo, data.get()), std::move(data));
-
- DataHolder storedData(std::move(constData.second));
- return std::make_pair(constData.first, std::move(storedData));
-}
-
-std::pair<armnn::ConstTensor, DataHolder> CreateConstTensor(
- const TfLiteTensor* tfLiteTensor,
- armnn::TensorInfo& tensorInfo,
- armnn::Optional<armnn::PermutationVector&> permutationVector)
-{
- switch (tensorInfo.GetDataType())
- {
- case armnn::DataType::Float32:
- return CreateConstTensorImpl<float>(tfLiteTensor, tensorInfo, permutationVector);
- case armnn::DataType::QAsymmU8:
- return CreateConstTensorImpl<uint8_t>(tfLiteTensor, tensorInfo, permutationVector);
- case armnn::DataType::QSymmS8:
- return CreateConstTensorImpl<int8_t>(tfLiteTensor, tensorInfo, permutationVector);
- case armnn::DataType::QAsymmS8:
- return CreateConstTensorImpl<int8_t>(tfLiteTensor, tensorInfo, permutationVector);
- case armnn::DataType::QSymmS16:
- return CreateConstTensorImpl<int16_t>(tfLiteTensor, tensorInfo, permutationVector);
- case armnn::DataType::Signed32:
- return CreateConstTensorImpl<int32_t>(tfLiteTensor, tensorInfo, permutationVector);
- default:
- {
- throw armnn::Exception(
- "TfLiteArmnnDelegate: Unsupported data type when creating const tensor: "
- + std::string(armnn::GetDataTypeName(tensorInfo.GetDataType())));
- }
+ return armnn::ConstTensor(tensorInfo, tfLiteTensor->data.data);
}
}