aboutsummaryrefslogtreecommitdiff
path: root/delegate/common/src/DelegateUtils.hpp
diff options
context:
space:
mode:
Diffstat (limited to 'delegate/common/src/DelegateUtils.hpp')
-rw-r--r--delegate/common/src/DelegateUtils.hpp31
1 files changed, 31 insertions, 0 deletions
diff --git a/delegate/common/src/DelegateUtils.hpp b/delegate/common/src/DelegateUtils.hpp
index 51c70f9ba1..37fe9b5b84 100644
--- a/delegate/common/src/DelegateUtils.hpp
+++ b/delegate/common/src/DelegateUtils.hpp
@@ -21,6 +21,8 @@
#include <tensorflow/lite/minimal_logging.h>
#include <tensorflow/lite/kernels/kernel_util.h>
+#include <numeric>
+
namespace
{
@@ -138,4 +140,33 @@ void SetupConcatViewOrigin(const armnn::TensorInfo& inputTensorInfo,
}
}
+TfLiteStatus CreateOutputTensorShape(const armnn::TensorInfo& inputTensorInfo,
+ const std::vector<int32_t>& targetShape,
+ armnn::ReshapeDescriptor& reshapeDesc)
+{
+ std::vector<unsigned int> outputDims(targetShape.begin(), targetShape.end());
+ const auto stretchDim = std::find(targetShape.begin(), targetShape.end(), -1);
+
+ if (stretchDim != targetShape.end())
+ {
+ if (std::find(std::next(stretchDim), targetShape.end(), -1) != targetShape.end())
+ {
+ // Return kTfLiteError and log the error after returning
+ return kTfLiteError;
+ }
+
+ auto targetNumElements =
+ armnn::numeric_cast<unsigned int>(
+ std::accumulate(targetShape.begin(), targetShape.end(), -1, std::multiplies<int32_t>()));
+
+ auto stretchIndex = static_cast<size_t>(std::distance(targetShape.begin(), stretchDim));
+ outputDims[stretchIndex] = inputTensorInfo.GetNumElements() / targetNumElements;
+ }
+
+ armnn::TensorShape outputShape = armnn::TensorShape(static_cast<unsigned int>(outputDims.size()),
+ outputDims.data());
+ reshapeDesc.m_TargetShape = outputShape;
+ return kTfLiteOk;
+}
+
} // namespace anonymous