aboutsummaryrefslogtreecommitdiff
path: root/delegate/src/DelegateUtils.hpp
diff options
context:
space:
mode:
Diffstat (limited to 'delegate/src/DelegateUtils.hpp')
-rw-r--r--delegate/src/DelegateUtils.hpp24
1 files changed, 20 insertions, 4 deletions
diff --git a/delegate/src/DelegateUtils.hpp b/delegate/src/DelegateUtils.hpp
index 46b2db9d64..58d8048be3 100644
--- a/delegate/src/DelegateUtils.hpp
+++ b/delegate/src/DelegateUtils.hpp
@@ -1,5 +1,5 @@
//
-// Copyright © 2020 Arm Ltd and Contributors. All rights reserved.
+// Copyright © 2022 Arm Ltd and Contributors. All rights reserved.
// SPDX-License-Identifier: MIT
//
@@ -404,14 +404,16 @@ armnn::DataType GetDataType(const TfLiteTensor& tfLiteTensor)
}
}
-armnn::TensorInfo GetTensorInfoForTfLiteTensor(const TfLiteTensor& tfLiteTensor)
+armnn::TensorInfo GetTensorInfoForTfLiteTensor(const TfLiteTensor& tfLiteTensor, bool isOutput = false)
{
armnn::DataType type = GetDataType(tfLiteTensor);
armnn::TensorInfo ret;
auto tensorDimensionSize = tfLiteTensor.dims->size;
if (tensorDimensionSize == 0)
{
- if(tflite::IsConstantTensor(&tfLiteTensor))
+ // If input tensor does not have a shape
+ // assuming that it has 1D tensor
+ if (!isOutput)
{
std::vector<unsigned int> safeShape = { 1 };
bool dimensionsSpecificity[1] = { true };
@@ -419,7 +421,10 @@ armnn::TensorInfo GetTensorInfoForTfLiteTensor(const TfLiteTensor& tfLiteTensor)
safeShape.data(),
dimensionsSpecificity);
ret = armnn::TensorInfo(tensorShape, type);
- ret.SetConstant(true);
+ if(tflite::IsConstantTensor(&tfLiteTensor))
+ {
+ ret.SetConstant(true);
+ }
}
else
{
@@ -652,4 +657,15 @@ bool AreAllSigned32(const armnn::TensorInfo& inputInfo1,
(armnn::DataType::Signed32 == outputInfo.GetDataType());
}
+void UpdateConstantTensorOutputs(const armnn::TensorInfo& inputInfo, armnn::TensorInfo& outputInfo)
+{
+ // If input tensor info is constant and output tensor info shape is not specified
+ // set the output shape from input shape
+ if (inputInfo.IsConstant() && outputInfo.GetShape().GetDimensionality() == armnn::Dimensionality::NotSpecified)
+ {
+ outputInfo.SetShape(inputInfo.GetShape());
+ }
+ return;
+}
+
} // namespace anonymous