diff options
Diffstat (limited to 'delegate/src/DelegateUtils.hpp')
-rw-r--r-- | delegate/src/DelegateUtils.hpp | 24 |
1 files changed, 20 insertions, 4 deletions
diff --git a/delegate/src/DelegateUtils.hpp b/delegate/src/DelegateUtils.hpp index 46b2db9d64..58d8048be3 100644 --- a/delegate/src/DelegateUtils.hpp +++ b/delegate/src/DelegateUtils.hpp @@ -1,5 +1,5 @@ // -// Copyright © 2020 Arm Ltd and Contributors. All rights reserved. +// Copyright © 2022 Arm Ltd and Contributors. All rights reserved. // SPDX-License-Identifier: MIT // @@ -404,14 +404,16 @@ armnn::DataType GetDataType(const TfLiteTensor& tfLiteTensor) } } -armnn::TensorInfo GetTensorInfoForTfLiteTensor(const TfLiteTensor& tfLiteTensor) +armnn::TensorInfo GetTensorInfoForTfLiteTensor(const TfLiteTensor& tfLiteTensor, bool isOutput = false) { armnn::DataType type = GetDataType(tfLiteTensor); armnn::TensorInfo ret; auto tensorDimensionSize = tfLiteTensor.dims->size; if (tensorDimensionSize == 0) { - if(tflite::IsConstantTensor(&tfLiteTensor)) + // If input tensor does not have a shape + // assuming that it has 1D tensor + if (!isOutput) { std::vector<unsigned int> safeShape = { 1 }; bool dimensionsSpecificity[1] = { true }; @@ -419,7 +421,10 @@ armnn::TensorInfo GetTensorInfoForTfLiteTensor(const TfLiteTensor& tfLiteTensor) safeShape.data(), dimensionsSpecificity); ret = armnn::TensorInfo(tensorShape, type); - ret.SetConstant(true); + if(tflite::IsConstantTensor(&tfLiteTensor)) + { + ret.SetConstant(true); + } } else { @@ -652,4 +657,15 @@ bool AreAllSigned32(const armnn::TensorInfo& inputInfo1, (armnn::DataType::Signed32 == outputInfo.GetDataType()); } +void UpdateConstantTensorOutputs(const armnn::TensorInfo& inputInfo, armnn::TensorInfo& outputInfo) +{ + // If input tensor info is constant and output tensor info shape is not specified + // set the output shape from input shape + if (inputInfo.IsConstant() && outputInfo.GetShape().GetDimensionality() == armnn::Dimensionality::NotSpecified) + { + outputInfo.SetShape(inputInfo.GetShape()); + } + return; +} + } // namespace anonymous |