aboutsummaryrefslogtreecommitdiff
path: root/OutputShapeUtils.cpp
diff options
context:
space:
mode:
authorAron Virginas-Tar <Aron.Virginas-Tar@arm.com>2019-07-09 17:44:24 +0100
committerÁron Virginás-Tar <aron.virginas-tar@arm.com>2019-07-09 16:55:01 +0000
commitf03fcf0dd180ba2c87648a524fcca9214e1f979b (patch)
tree99eb4643dec16db121b8702a4829f9c8323f5b45 /OutputShapeUtils.cpp
parent177fa0ba936eaf9de96f04bb91aa51d7656dd655 (diff)
downloadandroid-nn-driver-f03fcf0dd180ba2c87648a524fcca9214e1f979b.tar.gz
IVGCVSW-3456 Add support for dynamic output shape in ConvertPrelu
Signed-off-by: Aron Virginas-Tar <Aron.Virginas-Tar@arm.com> Change-Id: I8fc7a716455be3f51b51177f6896a73790a41fc3
Diffstat (limited to 'OutputShapeUtils.cpp')
-rw-r--r--OutputShapeUtils.cpp43
1 files changed, 43 insertions, 0 deletions
diff --git a/OutputShapeUtils.cpp b/OutputShapeUtils.cpp
new file mode 100644
index 00000000..de27630e
--- /dev/null
+++ b/OutputShapeUtils.cpp
@@ -0,0 +1,43 @@
+//
+// Copyright © 2017 Arm Ltd. All rights reserved.
+// SPDX-License-Identifier: MIT
+//
+
+#include "OutputShapeUtils.hpp"
+
+#include <algorithm>
+
+namespace armnn_driver
+{
+
+using namespace armnn;
+
+TensorShape InferPreluOutputShape(const TensorShape& inputShape, const TensorShape& alphaShape)
+{
+ // NOTE: The inferred PReLU output size will be the maximum size along each dimension
+ // of input and alpha, starting with the trailing dimensions, and working its way forward.
+ //
+ // Example: inputShape={4, 1, 2}, alphaShape={5, 4, 3, 1} => outputShape={5, 4, 3, 2}
+
+ const unsigned int numInputDims = inputShape.GetNumDimensions();
+ const unsigned int numAlphaDims = alphaShape.GetNumDimensions();
+
+ const unsigned int maxNumDims = std::max(numInputDims, numAlphaDims);
+
+ TensorShape outputShape = TensorShape(maxNumDims);
+ for (unsigned int reverseIdx = 1u; reverseIdx <= maxNumDims; ++reverseIdx)
+ {
+ const int inputIdx = numInputDims - reverseIdx;
+ const int alphaIdx = numAlphaDims - reverseIdx;
+
+ const unsigned int inputDimSize = inputIdx >= 0 ? inputShape[inputIdx] : 0u;
+ const unsigned int alphaDimSize = alphaIdx >= 0 ? alphaShape[alphaIdx] : 0u;
+
+ const unsigned int outputIdx = maxNumDims - reverseIdx;
+ outputShape[outputIdx] = std::max(inputDimSize, alphaDimSize);
+ }
+
+ return outputShape;
+}
+
+} // namespace armnn_driver \ No newline at end of file