From c49aacc83370e89435129650a30ef1b384712dfe Mon Sep 17 00:00:00 2001 From: Matthew Sloyan Date: Fri, 28 Apr 2023 17:27:26 +0100 Subject: IVGCVSW-7603 Implement Reshape operators for Opaque Delegate * Moved CreateOutputTensorShape function to common DelegateUtils.hpp Signed-off-by: Matthew Sloyan Change-Id: I3d8a9834ecd6b7cda170cce958677a0dde62824a --- delegate/common/src/DelegateUtils.hpp | 31 +++++++++++++++++++++++++++++++ 1 file changed, 31 insertions(+) (limited to 'delegate/common') diff --git a/delegate/common/src/DelegateUtils.hpp b/delegate/common/src/DelegateUtils.hpp index 51c70f9ba1..37fe9b5b84 100644 --- a/delegate/common/src/DelegateUtils.hpp +++ b/delegate/common/src/DelegateUtils.hpp @@ -21,6 +21,8 @@ #include #include +#include + namespace { @@ -138,4 +140,33 @@ void SetupConcatViewOrigin(const armnn::TensorInfo& inputTensorInfo, } } +TfLiteStatus CreateOutputTensorShape(const armnn::TensorInfo& inputTensorInfo, + const std::vector& targetShape, + armnn::ReshapeDescriptor& reshapeDesc) +{ + std::vector outputDims(targetShape.begin(), targetShape.end()); + const auto stretchDim = std::find(targetShape.begin(), targetShape.end(), -1); + + if (stretchDim != targetShape.end()) + { + if (std::find(std::next(stretchDim), targetShape.end(), -1) != targetShape.end()) + { + // Return kTfLiteError and log the error after returning + return kTfLiteError; + } + + auto targetNumElements = + armnn::numeric_cast( + std::accumulate(targetShape.begin(), targetShape.end(), -1, std::multiplies())); + + auto stretchIndex = static_cast(std::distance(targetShape.begin(), stretchDim)); + outputDims[stretchIndex] = inputTensorInfo.GetNumElements() / targetNumElements; + } + + armnn::TensorShape outputShape = armnn::TensorShape(static_cast(outputDims.size()), + outputDims.data()); + reshapeDesc.m_TargetShape = outputShape; + return kTfLiteOk; +} + } // namespace anonymous -- cgit v1.2.1