aboutsummaryrefslogtreecommitdiff
path: root/src/armnnUtils
diff options
context:
space:
mode:
Diffstat (limited to 'src/armnnUtils')
-rw-r--r--src/armnnUtils/ParserHelper.cpp64
-rw-r--r--src/armnnUtils/ParserHelper.hpp17
2 files changed, 81 insertions, 0 deletions
diff --git a/src/armnnUtils/ParserHelper.cpp b/src/armnnUtils/ParserHelper.cpp
new file mode 100644
index 0000000000..bf5ffdf0ad
--- /dev/null
+++ b/src/armnnUtils/ParserHelper.cpp
@@ -0,0 +1,64 @@
+//
+// Copyright © 2017 Arm Ltd. All rights reserved.
+// SPDX-License-Identifier: MIT
+//
+
+#include "ParserHelper.hpp"
+
+// armnnUtils
+#include "Permute.hpp"
+
+#include <boost/format.hpp>
+
+namespace armnnUtils
+{
+
+const armnn::PermutationVector NHWCToArmNN = { 0, 2, 3, 1 };
+const armnn::PermutationVector ArmNNToNHWC = { 0, 3, 1, 2 };
+
+void ProcessConcatInputTensorInfo(armnn::TensorInfo& inputTensorInfo, armnn::OriginsDescriptor& concatDescriptor,
+ const unsigned int& concatAxis, unsigned int inputIndex,
+ std::vector<unsigned int>& mergeDimSizes, unsigned int& mergeDim)
+{
+ // double check dimensions of the tensors
+ if (inputTensorInfo.GetNumDimensions() != armnn::MaxNumOfTensorDimensions)
+ {
+ throw armnn::ParseException(
+ boost::str(
+ boost::format(
+ "The number of dimensions: %1% for input tensors of the "
+ "concatenation op should be %2% %3%")
+ % inputTensorInfo.GetNumDimensions()
+ % armnn::MaxNumOfTensorDimensions
+ % CHECK_LOCATION().AsString()));
+ }
+
+ // if concatenation axis is 3 then need to be permuted
+ if (concatAxis == 3)
+ {
+ inputTensorInfo = armnnUtils::Permuted(inputTensorInfo, NHWCToArmNN);
+ }
+
+ for (unsigned int dim = 0; dim < armnn::MaxNumOfTensorDimensions; ++dim)
+ {
+ mergeDimSizes[dim] = inputTensorInfo.GetShape()[dim];
+ }
+
+ // Concatenation dimension 1 is the only dimension supported in ArmNN
+ const unsigned int concatenationDim = 1;
+
+ for (unsigned int j = 0; j < concatenationDim; ++j)
+ {
+ concatDescriptor.SetViewOriginCoord(inputIndex, j, 0);
+ }
+
+ concatDescriptor.SetViewOriginCoord(inputIndex, concatenationDim, mergeDim);
+ mergeDim += mergeDimSizes[concatenationDim];
+
+ for (unsigned int j = concatenationDim + 1; j < armnn::MaxNumOfTensorDimensions; ++j)
+ {
+ concatDescriptor.SetViewOriginCoord(inputIndex, j, 0);
+ }
+}
+
+} // namespace armnnUtils
diff --git a/src/armnnUtils/ParserHelper.hpp b/src/armnnUtils/ParserHelper.hpp
new file mode 100644
index 0000000000..93dfbf9360
--- /dev/null
+++ b/src/armnnUtils/ParserHelper.hpp
@@ -0,0 +1,17 @@
+//
+// Copyright © 2017 Arm Ltd. All rights reserved.
+// SPDX-License-Identifier: MIT
+//
+
+#pragma once
+
+#include <armnn/ArmNN.hpp>
+
+namespace armnnUtils
+{
+
+void ProcessConcatInputTensorInfo(armnn::TensorInfo& inputTensorInfo, armnn::OriginsDescriptor& concatDescriptor,
+ const unsigned int& concatAxis, unsigned int inputIndex,
+ std::vector<unsigned int>& mergeDimSizes, unsigned int& mergeDim);
+
+} // namespace armnnUtils