aboutsummaryrefslogtreecommitdiff
path: root/src/armnnUtils
diff options
context:
space:
mode:
authorSadik Armagan <sadik.armagan@arm.com>2018-10-01 11:51:37 +0100
committerMatthew Bentham <matthew.bentham@arm.com>2018-10-10 16:16:57 +0100
commit479045bdcac4faddbf567aa0f73d2899881f341c (patch)
tree5cd11ee39543ceda2730d21c1e7036543712a335 /src/armnnUtils
parente4ba53a85c559d4fe574305276ac815cf7995762 (diff)
downloadarmnn-479045bdcac4faddbf567aa0f73d2899881f341c.tar.gz
IVGCVSW-1787 Add Support for Concatenation on TfLite parser
* Concatenation Parser function added to the TfLite Parser Change-Id: I42a42cd765ea09a30841c66b1942b9e09a876b10
Diffstat (limited to 'src/armnnUtils')
-rw-r--r--src/armnnUtils/ParserHelper.cpp64
-rw-r--r--src/armnnUtils/ParserHelper.hpp17
2 files changed, 81 insertions, 0 deletions
diff --git a/src/armnnUtils/ParserHelper.cpp b/src/armnnUtils/ParserHelper.cpp
new file mode 100644
index 0000000000..bf5ffdf0ad
--- /dev/null
+++ b/src/armnnUtils/ParserHelper.cpp
@@ -0,0 +1,64 @@
+//
+// Copyright © 2017 Arm Ltd. All rights reserved.
+// SPDX-License-Identifier: MIT
+//
+
+#include "ParserHelper.hpp"
+
+// armnnUtils
+#include "Permute.hpp"
+
+#include <boost/format.hpp>
+
+namespace armnnUtils
+{
+
+const armnn::PermutationVector NHWCToArmNN = { 0, 2, 3, 1 };
+const armnn::PermutationVector ArmNNToNHWC = { 0, 3, 1, 2 };
+
+void ProcessConcatInputTensorInfo(armnn::TensorInfo& inputTensorInfo, armnn::OriginsDescriptor& concatDescriptor,
+ const unsigned int& concatAxis, unsigned int inputIndex,
+ std::vector<unsigned int>& mergeDimSizes, unsigned int& mergeDim)
+{
+ // double check dimensions of the tensors
+ if (inputTensorInfo.GetNumDimensions() != armnn::MaxNumOfTensorDimensions)
+ {
+ throw armnn::ParseException(
+ boost::str(
+ boost::format(
+ "The number of dimensions: %1% for input tensors of the "
+ "concatenation op should be %2% %3%")
+ % inputTensorInfo.GetNumDimensions()
+ % armnn::MaxNumOfTensorDimensions
+ % CHECK_LOCATION().AsString()));
+ }
+
+ // if concatenation axis is 3 then need to be permuted
+ if (concatAxis == 3)
+ {
+ inputTensorInfo = armnnUtils::Permuted(inputTensorInfo, NHWCToArmNN);
+ }
+
+ for (unsigned int dim = 0; dim < armnn::MaxNumOfTensorDimensions; ++dim)
+ {
+ mergeDimSizes[dim] = inputTensorInfo.GetShape()[dim];
+ }
+
+ // Concatenation dimension 1 is the only dimension supported in ArmNN
+ const unsigned int concatenationDim = 1;
+
+ for (unsigned int j = 0; j < concatenationDim; ++j)
+ {
+ concatDescriptor.SetViewOriginCoord(inputIndex, j, 0);
+ }
+
+ concatDescriptor.SetViewOriginCoord(inputIndex, concatenationDim, mergeDim);
+ mergeDim += mergeDimSizes[concatenationDim];
+
+ for (unsigned int j = concatenationDim + 1; j < armnn::MaxNumOfTensorDimensions; ++j)
+ {
+ concatDescriptor.SetViewOriginCoord(inputIndex, j, 0);
+ }
+}
+
+} // namespace armnnUtils
diff --git a/src/armnnUtils/ParserHelper.hpp b/src/armnnUtils/ParserHelper.hpp
new file mode 100644
index 0000000000..93dfbf9360
--- /dev/null
+++ b/src/armnnUtils/ParserHelper.hpp
@@ -0,0 +1,17 @@
+//
+// Copyright © 2017 Arm Ltd. All rights reserved.
+// SPDX-License-Identifier: MIT
+//
+
+#pragma once
+
+#include <armnn/ArmNN.hpp>
+
+namespace armnnUtils
+{
+
+void ProcessConcatInputTensorInfo(armnn::TensorInfo& inputTensorInfo, armnn::OriginsDescriptor& concatDescriptor,
+ const unsigned int& concatAxis, unsigned int inputIndex,
+ std::vector<unsigned int>& mergeDimSizes, unsigned int& mergeDim);
+
+} // namespace armnnUtils