aboutsummaryrefslogtreecommitdiff
path: root/delegate/src/DelegateUtils.hpp
diff options
context:
space:
mode:
Diffstat (limited to 'delegate/src/DelegateUtils.hpp')
-rw-r--r--delegate/src/DelegateUtils.hpp47
1 files changed, 45 insertions, 2 deletions
diff --git a/delegate/src/DelegateUtils.hpp b/delegate/src/DelegateUtils.hpp
index e9f579b699..0537ba911b 100644
--- a/delegate/src/DelegateUtils.hpp
+++ b/delegate/src/DelegateUtils.hpp
@@ -17,6 +17,8 @@
#include <tensorflow/lite/c/common.h>
#include <tensorflow/lite/minimal_logging.h>
+#include "tensorflow/lite/kernels/kernel_util.h"
+
namespace
{
@@ -365,8 +367,20 @@ armnn::TensorInfo GetTensorInfoForTfLiteTensor(const TfLiteTensor& tfLiteTensor,
auto tensorDimensionSize = tfLiteTensor.dims->size;
if (tensorDimensionSize == 0)
{
- armnn::TensorShape tensorShape(armnn::Dimensionality::NotSpecified);
- ret = armnn::TensorInfo(tensorShape, type);
+ if(tflite::IsConstantTensor(&tfLiteTensor))
+ {
+ std::vector<unsigned int> safeShape = { 1 };
+ bool dimensionsSpecificity[1] = { true };
+ armnn::TensorShape tensorShape(armnn::numeric_cast<unsigned int>(safeShape.size()),
+ safeShape.data(),
+ dimensionsSpecificity);
+ ret = armnn::TensorInfo(tensorShape, type);
+ }
+ else
+ {
+ armnn::TensorShape tensorShape(armnn::Dimensionality::NotSpecified);
+ ret = armnn::TensorInfo(tensorShape, type);
+ }
}
else
{
@@ -468,6 +482,35 @@ void CalcPadding(uint32_t inputSize,
}
}
+TfLiteStatus ConnectConstant(armnn::IConnectableLayer* layer,
+ armnn::TensorInfo& constTensorInfo,
+ TfLiteContext* tfLiteContext,
+ const TfLiteTensor& tfLiteTensor,
+ armnnDelegate::DelegateData& data,
+ unsigned int slotIndex)
+{
+ bool isSupported = false;
+ FORWARD_LAYER_SUPPORT_FUNC(__func__,
+ tfLiteContext,
+ IsConstantSupported,
+ data.m_Backends,
+ isSupported,
+ constTensorInfo);
+ if (!isSupported)
+ {
+ return kTfLiteError;
+ }
+
+ auto constantInput = CreateConstTensor(&tfLiteTensor,
+ constTensorInfo,
+ armnn::Optional<armnn::PermutationVector&>());
+ armnn::IConnectableLayer* constantLayer = data.m_Network->AddConstantLayer(constantInput);
+ armnn::IOutputSlot& outputSlot = constantLayer->GetOutputSlot(0);
+ outputSlot.SetTensorInfo(constTensorInfo);
+
+ data.m_OutputSlotForNode[static_cast<unsigned long>(slotIndex)] = &outputSlot;
+ return kTfLiteOk;
+}
} // namespace anonymous