diff options
author | Sadik Armagan <sadik.armagan@arm.com> | 2020-11-10 21:18:41 +0000 |
---|---|---|
committer | Sadik Armagan <sadik.armagan@arm.com> | 2020-11-10 21:18:41 +0000 |
commit | 6e36a64e26520e3f169bb2a92972a24e1be915a7 (patch) | |
tree | 5b1f9187fc297d34cb34328ac6aa447b9026e657 /delegate/src/DelegateUtils.hpp | |
parent | 50c87d39173cb48fc216ccb585714b669b095611 (diff) | |
download | armnn-6e36a64e26520e3f169bb2a92972a24e1be915a7.tar.gz |
IVGCVSW-5389 'TfLiteDelegate: Implement the FullyConnected operator'
* Added FullyConnected operator support to delegate
Signed-off-by: Sadik Armagan <sadik.armagan@arm.com>
Change-Id: Iae9c0980a4bfd6aa4d90f107f329dfa782baeefe
Diffstat (limited to 'delegate/src/DelegateUtils.hpp')
-rw-r--r-- | delegate/src/DelegateUtils.hpp | 132 |
1 files changed, 113 insertions, 19 deletions
diff --git a/delegate/src/DelegateUtils.hpp b/delegate/src/DelegateUtils.hpp index 729a8b4e98..fb3f998283 100644 --- a/delegate/src/DelegateUtils.hpp +++ b/delegate/src/DelegateUtils.hpp @@ -10,6 +10,8 @@ #include <armnn/utility/Assert.hpp> #include <armnn/utility/NumericCast.hpp> +#include <armnnUtils/Permute.hpp> + #include <tensorflow/lite/builtin_ops.h> #include <tensorflow/lite/c/builtin_op_data.h> #include <tensorflow/lite/c/common.h> @@ -94,6 +96,11 @@ TfLiteStatus ValidateNumOutputs(TfLiteContext* tfLiteContext, return kTfLiteOk; } +bool IsValid(const TfLiteTensor* tfLiteTensor) +{ + return tfLiteTensor == nullptr ? false : true; +} + bool IsDynamicTensor(const TfLiteTensor& tfLiteTensor) { auto tensorAllocationType = tfLiteTensor.allocation_type; @@ -118,13 +125,15 @@ TfLiteStatus Connect(armnn::IConnectableLayer* layer, TfLiteNode* tfLiteNode, armnnDelegate::DelegateData& data) { - ARMNN_ASSERT(tfLiteNode->inputs->size == layer->GetNumInputSlots()); ARMNN_ASSERT(tfLiteNode->outputs->size == layer->GetNumOutputSlots()); // Connect the input slots for (unsigned int inputIndex = 0; inputIndex < layer->GetNumInputSlots(); ++inputIndex) { - data.m_OutputSlotForNode[tfLiteNode->inputs->data[inputIndex]]->Connect(layer->GetInputSlot(inputIndex)); + if (data.m_OutputSlotForNode[tfLiteNode->inputs->data[inputIndex]] != nullptr) + { + data.m_OutputSlotForNode[tfLiteNode->inputs->data[inputIndex]]->Connect(layer->GetInputSlot(inputIndex)); + } } // Prepare output slots @@ -133,6 +142,7 @@ TfLiteStatus Connect(armnn::IConnectableLayer* layer, armnn::IOutputSlot& outputSlot = layer->GetOutputSlot(outputIndex); data.m_OutputSlotForNode[tfLiteNode->outputs->data[outputIndex]] = &outputSlot; } + return kTfLiteOk; } @@ -299,43 +309,39 @@ TfLiteStatus FusedActivation(TfLiteContext* tfLiteContext, return kTfLiteOk; } -armnn::TensorInfo GetTensorInfoForTfLiteTensor(const TfLiteTensor& tfLiteTensor) +armnn::DataType GetDataType(const TfLiteTensor& tfLiteTensor) { - armnn::DataType type; switch (tfLiteTensor.type) { case kTfLiteBool: - type = armnn::DataType::Boolean; - break; + return armnn::DataType::Boolean; case kTfLiteFloat32: - type = armnn::DataType::Float32; - break; + return armnn::DataType::Float32; case kTfLiteFloat16: - type = armnn::DataType::Float16; - break; + return armnn::DataType::Float16; case kTfLiteUInt8: - type = armnn::DataType::QAsymmU8; - break; + return armnn::DataType::QAsymmU8; case kTfLiteInt8: if (tfLiteTensor.params.zero_point == 0) { - type = armnn::DataType::QSymmS8; + return armnn::DataType::QSymmS8; } else { - type = armnn::DataType::QAsymmS8; + return armnn::DataType::QAsymmS8; } - break; case kTfLiteInt16: - type = armnn::DataType::QSymmS16; - break; + return armnn::DataType::QSymmS16; case kTfLiteInt32: - type = armnn::DataType::Signed32; - break; + return armnn::DataType::Signed32; default: throw armnn::Exception("TfLiteArmnnDelegate: Unsupported data type: " + tfLiteTensor.type); } +} +armnn::TensorInfo GetTensorInfoForTfLiteTensor(const TfLiteTensor& tfLiteTensor) +{ + armnn::DataType type = GetDataType(tfLiteTensor); armnn::TensorInfo ret; auto tensorDimensionSize = tfLiteTensor.dims->size; if (tensorDimensionSize == 0) @@ -391,4 +397,92 @@ armnn::TensorInfo GetTensorInfoForTfLiteTensor(const TfLiteTensor& tfLiteTensor) return ret; } +struct DataHolder +{ +public: + DataHolder() + : m_Fp32Data(nullptr), m_Uint8Data(nullptr), + m_Int8Data(nullptr), m_Int16Data(nullptr), m_Int32Data(nullptr) {} + + DataHolder(std::unique_ptr<float[]>&& data) + : m_Fp32Data(std::move(data)), m_Uint8Data(nullptr), + m_Int8Data(nullptr), m_Int16Data(nullptr), m_Int32Data(nullptr) {} + + DataHolder(std::unique_ptr<uint8_t[]>&& data) + : m_Fp32Data(nullptr), m_Uint8Data(std::move(data)), + m_Int8Data(nullptr), m_Int16Data(nullptr), m_Int32Data(nullptr) {} + + DataHolder(std::unique_ptr<int8_t[]>&& data) + : m_Fp32Data(nullptr), m_Uint8Data(nullptr), + m_Int8Data(std::move(data)), m_Int16Data(nullptr), m_Int32Data(nullptr) {} + + DataHolder(std::unique_ptr<int16_t[]>&& data) + : m_Fp32Data(nullptr), m_Uint8Data(nullptr), + m_Int8Data(nullptr), m_Int16Data(std::move(data)), m_Int32Data(nullptr) {} + + DataHolder(std::unique_ptr<int32_t[]>&& data) + : m_Fp32Data(nullptr), m_Uint8Data(nullptr), + m_Int8Data(nullptr), m_Int16Data(nullptr), m_Int32Data(std::move(data)) {} + +private: + std::unique_ptr<float[]> m_Fp32Data; + std::unique_ptr<uint8_t[]> m_Uint8Data; + std::unique_ptr<int8_t[]> m_Int8Data; + std::unique_ptr<int16_t[]> m_Int16Data; + std::unique_ptr<int32_t[]> m_Int32Data; +}; + +template <typename T> +std::pair<armnn::ConstTensor, DataHolder> CreateConstTensorImpl( + const TfLiteTensor* tensor, + armnn::TensorInfo& tensorInfo, + armnn::Optional<armnn::PermutationVector&> permutationVector) +{ + std::unique_ptr<T[]> data(new T[tensorInfo.GetNumElements()]); + if (permutationVector.has_value() && permutationVector.value().GetSize() > 0) + { + tensorInfo = armnnUtils::Permuted(tensorInfo, permutationVector.value()); + armnnUtils::Permute(tensorInfo.GetShape(), + permutationVector.value(), + reinterpret_cast<const T*>(tensor->data.raw), data.get(), sizeof(T)); + } + else + { + ::memcpy(data.get(), tensor->data.raw, tensorInfo.GetNumBytes()); + } + + auto constData = std::make_pair(armnn::ConstTensor(tensorInfo, data.get()), std::move(data)); + + DataHolder storedData(std::move(constData.second)); + return std::make_pair(constData.first, std::move(storedData)); +} + +std::pair<armnn::ConstTensor, DataHolder> CreateConstTensor( + const TfLiteTensor* tfLiteTensor, + armnn::TensorInfo& tensorInfo, + armnn::Optional<armnn::PermutationVector&> permutationVector) +{ + switch (tensorInfo.GetDataType()) + { + case armnn::DataType::Float32: + return CreateConstTensorImpl<float>(tfLiteTensor, tensorInfo, permutationVector); + case armnn::DataType::QAsymmU8: + return CreateConstTensorImpl<uint8_t>(tfLiteTensor, tensorInfo, permutationVector); + case armnn::DataType::QSymmS8: + return CreateConstTensorImpl<int8_t>(tfLiteTensor, tensorInfo, permutationVector); + case armnn::DataType::QAsymmS8: + return CreateConstTensorImpl<int8_t>(tfLiteTensor, tensorInfo, permutationVector); + case armnn::DataType::QSymmS16: + return CreateConstTensorImpl<int16_t>(tfLiteTensor, tensorInfo, permutationVector); + case armnn::DataType::Signed32: + return CreateConstTensorImpl<int32_t>(tfLiteTensor, tensorInfo, permutationVector); + default: + { + throw armnn::Exception( + "TfLiteArmnnDelegate: Unsupported data type when creating const tensor: " + + std::string(armnn::GetDataTypeName(tensorInfo.GetDataType()))); + } + } +} + } // namespace anonymous |