aboutsummaryrefslogtreecommitdiff
path: root/delegate/src/DelegateUtils.hpp
diff options
context:
space:
mode:
Diffstat (limited to 'delegate/src/DelegateUtils.hpp')
-rw-r--r--delegate/src/DelegateUtils.hpp132
1 files changed, 113 insertions, 19 deletions
diff --git a/delegate/src/DelegateUtils.hpp b/delegate/src/DelegateUtils.hpp
index 729a8b4e98..fb3f998283 100644
--- a/delegate/src/DelegateUtils.hpp
+++ b/delegate/src/DelegateUtils.hpp
@@ -10,6 +10,8 @@
#include <armnn/utility/Assert.hpp>
#include <armnn/utility/NumericCast.hpp>
+#include <armnnUtils/Permute.hpp>
+
#include <tensorflow/lite/builtin_ops.h>
#include <tensorflow/lite/c/builtin_op_data.h>
#include <tensorflow/lite/c/common.h>
@@ -94,6 +96,11 @@ TfLiteStatus ValidateNumOutputs(TfLiteContext* tfLiteContext,
return kTfLiteOk;
}
+bool IsValid(const TfLiteTensor* tfLiteTensor)
+{
+ return tfLiteTensor == nullptr ? false : true;
+}
+
bool IsDynamicTensor(const TfLiteTensor& tfLiteTensor)
{
auto tensorAllocationType = tfLiteTensor.allocation_type;
@@ -118,13 +125,15 @@ TfLiteStatus Connect(armnn::IConnectableLayer* layer,
TfLiteNode* tfLiteNode,
armnnDelegate::DelegateData& data)
{
- ARMNN_ASSERT(tfLiteNode->inputs->size == layer->GetNumInputSlots());
ARMNN_ASSERT(tfLiteNode->outputs->size == layer->GetNumOutputSlots());
// Connect the input slots
for (unsigned int inputIndex = 0; inputIndex < layer->GetNumInputSlots(); ++inputIndex)
{
- data.m_OutputSlotForNode[tfLiteNode->inputs->data[inputIndex]]->Connect(layer->GetInputSlot(inputIndex));
+ if (data.m_OutputSlotForNode[tfLiteNode->inputs->data[inputIndex]] != nullptr)
+ {
+ data.m_OutputSlotForNode[tfLiteNode->inputs->data[inputIndex]]->Connect(layer->GetInputSlot(inputIndex));
+ }
}
// Prepare output slots
@@ -133,6 +142,7 @@ TfLiteStatus Connect(armnn::IConnectableLayer* layer,
armnn::IOutputSlot& outputSlot = layer->GetOutputSlot(outputIndex);
data.m_OutputSlotForNode[tfLiteNode->outputs->data[outputIndex]] = &outputSlot;
}
+
return kTfLiteOk;
}
@@ -299,43 +309,39 @@ TfLiteStatus FusedActivation(TfLiteContext* tfLiteContext,
return kTfLiteOk;
}
-armnn::TensorInfo GetTensorInfoForTfLiteTensor(const TfLiteTensor& tfLiteTensor)
+armnn::DataType GetDataType(const TfLiteTensor& tfLiteTensor)
{
- armnn::DataType type;
switch (tfLiteTensor.type)
{
case kTfLiteBool:
- type = armnn::DataType::Boolean;
- break;
+ return armnn::DataType::Boolean;
case kTfLiteFloat32:
- type = armnn::DataType::Float32;
- break;
+ return armnn::DataType::Float32;
case kTfLiteFloat16:
- type = armnn::DataType::Float16;
- break;
+ return armnn::DataType::Float16;
case kTfLiteUInt8:
- type = armnn::DataType::QAsymmU8;
- break;
+ return armnn::DataType::QAsymmU8;
case kTfLiteInt8:
if (tfLiteTensor.params.zero_point == 0)
{
- type = armnn::DataType::QSymmS8;
+ return armnn::DataType::QSymmS8;
}
else
{
- type = armnn::DataType::QAsymmS8;
+ return armnn::DataType::QAsymmS8;
}
- break;
case kTfLiteInt16:
- type = armnn::DataType::QSymmS16;
- break;
+ return armnn::DataType::QSymmS16;
case kTfLiteInt32:
- type = armnn::DataType::Signed32;
- break;
+ return armnn::DataType::Signed32;
default:
throw armnn::Exception("TfLiteArmnnDelegate: Unsupported data type: " + tfLiteTensor.type);
}
+}
+armnn::TensorInfo GetTensorInfoForTfLiteTensor(const TfLiteTensor& tfLiteTensor)
+{
+ armnn::DataType type = GetDataType(tfLiteTensor);
armnn::TensorInfo ret;
auto tensorDimensionSize = tfLiteTensor.dims->size;
if (tensorDimensionSize == 0)
@@ -391,4 +397,92 @@ armnn::TensorInfo GetTensorInfoForTfLiteTensor(const TfLiteTensor& tfLiteTensor)
return ret;
}
+struct DataHolder
+{
+public:
+ DataHolder()
+ : m_Fp32Data(nullptr), m_Uint8Data(nullptr),
+ m_Int8Data(nullptr), m_Int16Data(nullptr), m_Int32Data(nullptr) {}
+
+ DataHolder(std::unique_ptr<float[]>&& data)
+ : m_Fp32Data(std::move(data)), m_Uint8Data(nullptr),
+ m_Int8Data(nullptr), m_Int16Data(nullptr), m_Int32Data(nullptr) {}
+
+ DataHolder(std::unique_ptr<uint8_t[]>&& data)
+ : m_Fp32Data(nullptr), m_Uint8Data(std::move(data)),
+ m_Int8Data(nullptr), m_Int16Data(nullptr), m_Int32Data(nullptr) {}
+
+ DataHolder(std::unique_ptr<int8_t[]>&& data)
+ : m_Fp32Data(nullptr), m_Uint8Data(nullptr),
+ m_Int8Data(std::move(data)), m_Int16Data(nullptr), m_Int32Data(nullptr) {}
+
+ DataHolder(std::unique_ptr<int16_t[]>&& data)
+ : m_Fp32Data(nullptr), m_Uint8Data(nullptr),
+ m_Int8Data(nullptr), m_Int16Data(std::move(data)), m_Int32Data(nullptr) {}
+
+ DataHolder(std::unique_ptr<int32_t[]>&& data)
+ : m_Fp32Data(nullptr), m_Uint8Data(nullptr),
+ m_Int8Data(nullptr), m_Int16Data(nullptr), m_Int32Data(std::move(data)) {}
+
+private:
+ std::unique_ptr<float[]> m_Fp32Data;
+ std::unique_ptr<uint8_t[]> m_Uint8Data;
+ std::unique_ptr<int8_t[]> m_Int8Data;
+ std::unique_ptr<int16_t[]> m_Int16Data;
+ std::unique_ptr<int32_t[]> m_Int32Data;
+};
+
+template <typename T>
+std::pair<armnn::ConstTensor, DataHolder> CreateConstTensorImpl(
+ const TfLiteTensor* tensor,
+ armnn::TensorInfo& tensorInfo,
+ armnn::Optional<armnn::PermutationVector&> permutationVector)
+{
+ std::unique_ptr<T[]> data(new T[tensorInfo.GetNumElements()]);
+ if (permutationVector.has_value() && permutationVector.value().GetSize() > 0)
+ {
+ tensorInfo = armnnUtils::Permuted(tensorInfo, permutationVector.value());
+ armnnUtils::Permute(tensorInfo.GetShape(),
+ permutationVector.value(),
+ reinterpret_cast<const T*>(tensor->data.raw), data.get(), sizeof(T));
+ }
+ else
+ {
+ ::memcpy(data.get(), tensor->data.raw, tensorInfo.GetNumBytes());
+ }
+
+ auto constData = std::make_pair(armnn::ConstTensor(tensorInfo, data.get()), std::move(data));
+
+ DataHolder storedData(std::move(constData.second));
+ return std::make_pair(constData.first, std::move(storedData));
+}
+
+std::pair<armnn::ConstTensor, DataHolder> CreateConstTensor(
+ const TfLiteTensor* tfLiteTensor,
+ armnn::TensorInfo& tensorInfo,
+ armnn::Optional<armnn::PermutationVector&> permutationVector)
+{
+ switch (tensorInfo.GetDataType())
+ {
+ case armnn::DataType::Float32:
+ return CreateConstTensorImpl<float>(tfLiteTensor, tensorInfo, permutationVector);
+ case armnn::DataType::QAsymmU8:
+ return CreateConstTensorImpl<uint8_t>(tfLiteTensor, tensorInfo, permutationVector);
+ case armnn::DataType::QSymmS8:
+ return CreateConstTensorImpl<int8_t>(tfLiteTensor, tensorInfo, permutationVector);
+ case armnn::DataType::QAsymmS8:
+ return CreateConstTensorImpl<int8_t>(tfLiteTensor, tensorInfo, permutationVector);
+ case armnn::DataType::QSymmS16:
+ return CreateConstTensorImpl<int16_t>(tfLiteTensor, tensorInfo, permutationVector);
+ case armnn::DataType::Signed32:
+ return CreateConstTensorImpl<int32_t>(tfLiteTensor, tensorInfo, permutationVector);
+ default:
+ {
+ throw armnn::Exception(
+ "TfLiteArmnnDelegate: Unsupported data type when creating const tensor: "
+ + std::string(armnn::GetDataTypeName(tensorInfo.GetDataType())));
+ }
+ }
+}
+
} // namespace anonymous