author     Sadik Armagan <sadik.armagan@arm.com>        2018-10-01 11:51:37 +0100
committer  Matthew Bentham <matthew.bentham@arm.com>    2018-10-10 16:16:57 +0100
commit     479045bdcac4faddbf567aa0f73d2899881f341c (patch)
tree       5cd11ee39543ceda2730d21c1e7036543712a335
parent     e4ba53a85c559d4fe574305276ac815cf7995762 (diff)
download   armnn-479045bdcac4faddbf567aa0f73d2899881f341c.tar.gz
IVGCVSW-1787 Add Support for Concatenation on TfLite parser
* Concatenation Parser function added to the TfLite Parser

Change-Id: I42a42cd765ea09a30841c66b1942b9e09a876b10
-rw-r--r--  Android.mk                                      1
-rw-r--r--  CMakeLists.txt                                  7
-rw-r--r--  src/armnnTfLiteParser/TfLiteParser.cpp         98
-rw-r--r--  src/armnnTfLiteParser/TfLiteParser.hpp          5
-rw-r--r--  src/armnnTfLiteParser/test/Concatenation.cpp  187
-rw-r--r--  src/armnnTfParser/TfParser.cpp                 39
-rw-r--r--  src/armnnUtils/ParserHelper.cpp                64
-rw-r--r--  src/armnnUtils/ParserHelper.hpp                17
8 files changed, 377 insertions, 41 deletions
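
As context for the change below, a minimal sketch of how the new CONCATENATION support is exercised through the public TfLite parser API. The file name "model.tflite" and the tensor names are placeholders, not part of this patch.

    // Sketch only: load a flatbuffers model that contains a CONCATENATION operator.
    // "model.tflite", "inputTensor1" and "outputTensor" are hypothetical names.
    #include <armnn/ArmNN.hpp>
    #include <armnnTfLiteParser/ITfLiteParser.hpp>

    int main()
    {
        using namespace armnnTfLiteParser;

        // The parser now maps CONCATENATION operators onto ArmNN merger layers.
        ITfLiteParserPtr parser = ITfLiteParser::Create();
        armnn::INetworkPtr network = parser->CreateNetworkFromBinaryFile("model.tflite");

        // Input/output binding info can then be queried by tensor name per subgraph.
        auto inputInfo  = parser->GetNetworkInputBindingInfo(0, "inputTensor1");
        auto outputInfo = parser->GetNetworkOutputBindingInfo(0, "outputTensor");
        (void)inputInfo;
        (void)outputInfo;
        return 0;
    }
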
diff --git a/Android.mk b/Android.mk
index 8f86ecbfb8..d181054525 100644
--- a/Android.mk
+++ b/Android.mk
@@ -73,6 +73,7 @@ LOCAL_SRC_FILES := \
src/armnnUtils/FloatingPointConverter.cpp \
src/armnnUtils/Logging.cpp \
src/armnnUtils/Permute.cpp \
+ src/armnnUtils/ParserHelper.cpp \
src/armnn/layers/ActivationLayer.cpp \
src/armnn/layers/AdditionLayer.cpp \
src/armnn/layers/ArithmeticBaseLayer.cpp \
diff --git a/CMakeLists.txt b/CMakeLists.txt
index cfbbae6808..8d77d91c49 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -43,6 +43,8 @@ list(APPEND armnnUtils_sources
src/armnnUtils/FloatingPointConverter.hpp
src/armnnUtils/VerificationHelpers.hpp
src/armnnUtils/VerificationHelpers.cpp
+ src/armnnUtils/ParserHelper.hpp
+ src/armnnUtils/ParserHelper.cpp
)
if(BUILD_TF_PARSER OR BUILD_CAFFE_PARSER)
list(APPEND armnnUtils_sources
@@ -452,9 +454,12 @@ if(BUILD_UNIT_TESTS)
if(BUILD_TF_LITE_PARSER)
list(APPEND unittest_sources
src/armnnTfLiteParser/test/ParserFlatbuffersFixture.hpp
+ src/armnnTfLiteParser/test/Activations.cpp
src/armnnTfLiteParser/test/AvgPool2D.cpp
+ src/armnnTfLiteParser/test/Concatenation.cpp
src/armnnTfLiteParser/test/Conv2D.cpp
src/armnnTfLiteParser/test/DepthwiseConvolution2D.cpp
+ src/armnnTfLiteParser/test/Reshape.cpp
src/armnnTfLiteParser/test/Softmax.cpp
src/armnnTfLiteParser/test/Squeeze.cpp
src/armnnTfLiteParser/test/LoadModel.cpp
@@ -464,8 +469,6 @@ if(BUILD_UNIT_TESTS)
src/armnnTfLiteParser/test/GetTensorIds.cpp
src/armnnTfLiteParser/test/GetSubgraphInputsOutputs.cpp
src/armnnTfLiteParser/test/GetInputsOutputs.cpp
- src/armnnTfLiteParser/test/Activations.cpp
- src/armnnTfLiteParser/test/Reshape.cpp
)
endif()
diff --git a/src/armnnTfLiteParser/TfLiteParser.cpp b/src/armnnTfLiteParser/TfLiteParser.cpp
index 13e4604490..66746e488b 100644
--- a/src/armnnTfLiteParser/TfLiteParser.cpp
+++ b/src/armnnTfLiteParser/TfLiteParser.cpp
@@ -10,6 +10,7 @@
#include <boost/filesystem.hpp>
// armnnUtils:
+#include <ParserHelper.hpp>
#include <Permute.hpp>
#include <VerificationHelpers.hpp>
@@ -452,13 +453,14 @@ TfLiteParser::TfLiteParser()
{
// register supported operators
m_ParserFunctions[tflite::BuiltinOperator_AVERAGE_POOL_2D] = &TfLiteParser::ParseAveragePool2D;
+ m_ParserFunctions[tflite::BuiltinOperator_CONCATENATION] = &TfLiteParser::ParseConcatenation;
m_ParserFunctions[tflite::BuiltinOperator_CONV_2D] = &TfLiteParser::ParseConv2D;
m_ParserFunctions[tflite::BuiltinOperator_DEPTHWISE_CONV_2D] = &TfLiteParser::ParseDepthwiseConv2D;
- m_ParserFunctions[tflite::BuiltinOperator_SOFTMAX] = &TfLiteParser::ParseSoftmax;
- m_ParserFunctions[tflite::BuiltinOperator_SQUEEZE] = &TfLiteParser::ParseSqueeze;
m_ParserFunctions[tflite::BuiltinOperator_RELU] = &TfLiteParser::ParseRelu;
m_ParserFunctions[tflite::BuiltinOperator_RELU6] = &TfLiteParser::ParseRelu6;
m_ParserFunctions[tflite::BuiltinOperator_RESHAPE] = &TfLiteParser::ParseReshape;
+ m_ParserFunctions[tflite::BuiltinOperator_SOFTMAX] = &TfLiteParser::ParseSoftmax;
+ m_ParserFunctions[tflite::BuiltinOperator_SQUEEZE] = &TfLiteParser::ParseSqueeze;
}
void TfLiteParser::ResetParser()
@@ -1097,6 +1099,98 @@ void TfLiteParser::ParseReshape(size_t subgraphIndex, size_t operatorIndex)
RegisterOutputSlots(subgraphIndex, operatorIndex, layer, {outputTensorIndexes[0]});
}
+void TfLiteParser::ParseConcatenation(size_t subgraphIndex, size_t operatorIndex)
+{
+ CHECK_MODEL(m_Model, subgraphIndex, operatorIndex);
+
+ const auto & operatorPtr = m_Model->subgraphs[subgraphIndex]->operators[operatorIndex];
+ const auto * options = operatorPtr->builtin_options.AsConcatenationOptions();
+
+ CHECK_SUPPORTED_FUSED_ACTIVATION(options, subgraphIndex, operatorIndex);
+
+ auto inputs = GetInputs(m_Model, subgraphIndex, operatorIndex);
+ auto outputs = GetOutputs(m_Model, subgraphIndex, operatorIndex);
+ CHECK_VALID_SIZE(outputs.size(), 1);
+
+ unsigned int numInputs = static_cast<unsigned int>(inputs.size());
+ unsigned int numConcatView = numInputs;
+
+ OriginsDescriptor concatDescriptor(static_cast<uint32_t>(numConcatView), MaxNumOfTensorDimensions);
+ std::vector<unsigned int>mergeDimSizes(MaxNumOfTensorDimensions, 0u);
+
+ unsigned int mergeDim = 0;
+
+ // The axis value gives the concatenation dimension: 3 is the channel dimension for NHWC, 1 for NCHW.
+ // The axis may also be negative; a negative axis counts back from the end of the shape,
+ // i.e. it refers to the (axis + rank(values))-th dimension.
+ int32_t inputRank = static_cast<int32_t>(ToTensorInfo(inputs[0]).GetNumDimensions());
+ const unsigned int concatDimInput = static_cast<unsigned int>((inputRank + options->axis) % inputRank);
+
+ // ArmNN supports concatenation only along the channel dimension (axis 1 for NCHW, axis 3 for NHWC).
+ if (concatDimInput == 0 || concatDimInput == 2)
+ {
+ throw ParseException(
+ boost::str(
+ boost::format(
+ "Dimension %1% for concatenation is not supported by Armnn. "
+ "Node %2%")
+ % concatDimInput
+ % CHECK_LOCATION().AsString()));
+ }
+
+ for (unsigned int viewIndex = 0; viewIndex < numConcatView; ++viewIndex)
+ {
+ TensorInfo inputTensorInfo = ToTensorInfo(inputs[viewIndex]);
+
+ // process the input tensor info
+ armnnUtils::ProcessConcatInputTensorInfo(inputTensorInfo, concatDescriptor,
+ concatDimInput, viewIndex, mergeDimSizes, mergeDim);
+ }
+
+ auto layerName = boost::str(boost::format("Concatenation:%1%:%2%") % subgraphIndex % operatorIndex);
+ IConnectableLayer* layer = m_Network->AddMergerLayer(concatDescriptor, layerName.c_str());
+
+ BOOST_ASSERT(layer != nullptr);
+
+ armnn::TensorInfo outputTensorInfo = ToTensorInfo(outputs[0]);
+ auto inputTensorIndexes = AsUnsignedVector(GetInputTensorIds(m_Model, subgraphIndex, operatorIndex));
+ if (concatDimInput == 3)
+ {
+ // the fused activation layer (if any) is added after the trailing deswizzle layer below
+ for (unsigned int viewIndex = 0; viewIndex < numConcatView; ++viewIndex)
+ {
+ // add permute layers to swizzle the inputs
+ armnn::TensorInfo inputTensorInfo = ToTensorInfo(inputs[viewIndex]);
+ IConnectableLayer* const swizzleLayer = SwizzleIn(*m_Network, layer, viewIndex, inputTensorInfo);
+
+ BOOST_ASSERT(swizzleLayer != nullptr);
+
+ // register the input connection slots for the layer
+ // only the tensors for the inputs are relevant, exclude the const tensors
+ RegisterInputSlots(subgraphIndex, operatorIndex, swizzleLayer, {inputTensorIndexes[viewIndex]});
+ }
+
+ // add permute layer to deswizzle the output
+ IConnectableLayer* const deswizzleLayer = DeswizzleOut(*m_Network, layer, 0, outputTensorInfo);
+
+ // add fused activation layer after the trailing deswizzle layer
+ layer = AddFusedActivationLayer(deswizzleLayer, 0, options->fused_activation_function);
+ }
+ else
+ {
+ // set the layer output tensor info
+ layer->GetOutputSlot(0).SetTensorInfo(outputTensorInfo);
+
+ // register the input connection slots for the layer, connections are made after all layers have been created
+ // only the tensors for the inputs are relevant, exclude the const tensors
+ RegisterInputSlots(subgraphIndex, operatorIndex, layer, {inputTensorIndexes});
+ }
+
+ // register the output connection slots for the layer, connections are made after all layers have been created
+ auto outputTensorIndexes = AsUnsignedVector(GetOutputTensorIds(m_Model, subgraphIndex, operatorIndex));
+ RegisterOutputSlots(subgraphIndex, operatorIndex, layer, {outputTensorIndexes[0]});
+}
+
armnn::IConnectableLayer* TfLiteParser::AddFusedActivationLayer(armnn::IConnectableLayer* prevLayer,
unsigned int outputSlot,
tflite::ActivationFunctionType activationType)
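
As a side note on ParseConcatenation above, the expression (inputRank + options->axis) % inputRank normalises a possibly negative TfLite axis to a forward dimension index. A minimal stand-alone sketch; the helper name NormaliseConcatAxis is hypothetical and mirrors that expression only.

    // Illustration only: mirrors (inputRank + options->axis) % inputRank in ParseConcatenation.
    #include <cassert>
    #include <cstdint>

    uint32_t NormaliseConcatAxis(int32_t axis, int32_t rank) // hypothetical helper
    {
        // A negative axis counts back from the end of the shape.
        return static_cast<uint32_t>((rank + axis) % rank);
    }

    int main()
    {
        assert(NormaliseConcatAxis(-3, 4) == 1); // NCHW channel dim, as in ConcatenationFixtureNegativeDim
        assert(NormaliseConcatAxis( 1, 4) == 1); // already the NCHW channel dimension
        assert(NormaliseConcatAxis( 3, 4) == 3); // NHWC channel dimension; inputs are swizzled to NCHW
        return 0;
    }
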
diff --git a/src/armnnTfLiteParser/TfLiteParser.hpp b/src/armnnTfLiteParser/TfLiteParser.hpp
index f949484a4f..620648a0c3 100644
--- a/src/armnnTfLiteParser/TfLiteParser.hpp
+++ b/src/armnnTfLiteParser/TfLiteParser.hpp
@@ -90,13 +90,14 @@ private:
void ParseUnsupportedOperator(size_t subgraphIndex, size_t operatorIndex);
void ParseAveragePool2D(size_t subgraphIndex, size_t operatorIndex);
+ void ParseConcatenation(size_t subgraphIndex, size_t operatorIndex);
void ParseConv2D(size_t subgraphIndex, size_t operatorIndex);
void ParseDepthwiseConv2D(size_t subgraphIndex, size_t operatorIndex);
- void ParseSoftmax(size_t subgraphIndex, size_t operatorIndex);
- void ParseSqueeze(size_t subgraphIndex, size_t operatorIndex);
void ParseRelu(size_t subgraphIndex, size_t operatorIndex);
void ParseRelu6(size_t subgraphIndex, size_t operatorIndex);
void ParseReshape(size_t subgraphIndex, size_t operatorIndex);
+ void ParseSoftmax(size_t subgraphIndex, size_t operatorIndex);
+ void ParseSqueeze(size_t subgraphIndex, size_t operatorIndex);
void RegisterProducerOfTensor(size_t subgraphIndex, size_t tensorIndex, armnn::IOutputSlot* slot);
void RegisterConsumerOfTensor(size_t subgraphIndex, size_t tensorIndex, armnn::IInputSlot* slot);
diff --git a/src/armnnTfLiteParser/test/Concatenation.cpp b/src/armnnTfLiteParser/test/Concatenation.cpp
new file mode 100644
index 0000000000..8629efe3d7
--- /dev/null
+++ b/src/armnnTfLiteParser/test/Concatenation.cpp
@@ -0,0 +1,187 @@
+//
+// Copyright © 2017 Arm Ltd. All rights reserved.
+// SPDX-License-Identifier: MIT
+//
+
+#include <boost/test/unit_test.hpp>
+#include "ParserFlatbuffersFixture.hpp"
+#include "../TfLiteParser.hpp"
+
+#include <string>
+#include <iostream>
+
+BOOST_AUTO_TEST_SUITE(TensorflowLiteParser)
+
+struct ConcatenationFixture : public ParserFlatbuffersFixture
+{
+ explicit ConcatenationFixture(const std::string & inputShape1,
+ const std::string & inputShape2,
+ const std::string & outputShape,
+ const std::string & axis,
+ const std::string & activation="NONE")
+ {
+ m_JsonString = R"(
+ {
+ "version": 3,
+ "operator_codes": [ { "builtin_code": "CONCATENATION" } ],
+ "subgraphs": [ {
+ "tensors": [
+ {
+ "shape": )" + inputShape1 + R"(,
+ "type": "UINT8",
+ "buffer": 0,
+ "name": "inputTensor1",
+ "quantization": {
+ "min": [ 0.0 ],
+ "max": [ 255.0 ],
+ "scale": [ 1.0 ],
+ "zero_point": [ 0 ],
+ }
+ },
+ {
+ "shape": )" + inputShape2 + R"(,
+ "type": "UINT8",
+ "buffer": 1,
+ "name": "inputTensor2",
+ "quantization": {
+ "min": [ 0.0 ],
+ "max": [ 255.0 ],
+ "scale": [ 1.0 ],
+ "zero_point": [ 0 ],
+ }
+ },
+ {
+ "shape": )" + outputShape + R"( ,
+ "type": "UINT8",
+ "buffer": 2,
+ "name": "outputTensor",
+ "quantization": {
+ "min": [ 0.0 ],
+ "max": [ 255.0 ],
+ "scale": [ 1.0 ],
+ "zero_point": [ 0 ],
+ }
+ }
+ ],
+ "inputs": [ 0, 1 ],
+ "outputs": [ 2 ],
+ "operators": [
+ {
+ "opcode_index": 0,
+ "inputs": [ 0, 1 ],
+ "outputs": [ 2 ],
+ "builtin_options_type": "ConcatenationOptions",
+ "builtin_options": {
+ "axis": )" + axis + R"(,
+ "fused_activation_function": )" + activation + R"(
+ },
+ "custom_options_format": "FLEXBUFFERS"
+ }
+ ],
+ } ],
+ "buffers" : [
+ { },
+ { }
+ ]
+ }
+ )";
+ Setup();
+ }
+};
+
+
+struct ConcatenationFixtureNegativeDim : ConcatenationFixture
+{
+ ConcatenationFixtureNegativeDim() : ConcatenationFixture("[ 1, 1, 2, 2 ]",
+ "[ 1, 1, 2, 2 ]",
+ "[ 1, 2, 2, 2 ]",
+ "-3" ) {}
+};
+
+BOOST_FIXTURE_TEST_CASE(ParseConcatenationNegativeDim, ConcatenationFixtureNegativeDim)
+{
+ RunTest<4, uint8_t>(0,
+ {{"inputTensor1", { 0, 1, 2, 3 }},
+ {"inputTensor2", { 4, 5, 6, 7 }}},
+ {{"outputTensor", { 0, 1, 2, 3, 4, 5, 6, 7 }}});
+}
+
+struct ConcatenationFixtureNCHW : ConcatenationFixture
+{
+ ConcatenationFixtureNCHW() : ConcatenationFixture("[ 1, 1, 2, 2 ]", "[ 1, 1, 2, 2 ]", "[ 1, 2, 2, 2 ]", "1" ) {}
+};
+
+BOOST_FIXTURE_TEST_CASE(ParseConcatenationNCHW, ConcatenationFixtureNCHW)
+{
+ RunTest<4, uint8_t>(0,
+ {{"inputTensor1", { 0, 1, 2, 3 }},
+ {"inputTensor2", { 4, 5, 6, 7 }}},
+ {{"outputTensor", { 0, 1, 2, 3, 4, 5, 6, 7 }}});
+}
+
+struct ConcatenationFixtureNHWC : ConcatenationFixture
+{
+ ConcatenationFixtureNHWC() : ConcatenationFixture("[ 1, 1, 2, 2 ]", "[ 1, 1, 2, 2 ]", "[ 1, 1, 2, 4 ]", "3" ) {}
+};
+
+BOOST_FIXTURE_TEST_CASE(ParseConcatenationNHWC, ConcatenationFixtureNHWC)
+{
+ RunTest<4, uint8_t>(0,
+ {{"inputTensor1", { 0, 1, 2, 3 }},
+ {"inputTensor2", { 4, 5, 6, 7 }}},
+ {{"outputTensor", { 0, 1, 4, 5, 2, 3, 6, 7 }}});
+}
+
+struct ConcatenationFixtureDim1 : ConcatenationFixture
+{
+ ConcatenationFixtureDim1() : ConcatenationFixture("[ 1, 2, 3, 4 ]", "[ 1, 2, 3, 4 ]", "[ 1, 4, 3, 4 ]", "1" ) {}
+};
+
+BOOST_FIXTURE_TEST_CASE(ParseConcatenationDim1, ConcatenationFixtureDim1)
+{
+ RunTest<4, uint8_t>(0,
+ { { "inputTensor1", { 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11,
+ 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23 } },
+ { "inputTensor2", { 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61,
+ 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73 } } },
+ { { "outputTensor", { 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11,
+ 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23,
+ 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61,
+ 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73 } } });
+}
+
+struct ConcatenationFixtureDim3 : ConcatenationFixture
+{
+ ConcatenationFixtureDim3() : ConcatenationFixture("[ 1, 2, 3, 4 ]", "[ 1, 2, 3, 4 ]", "[ 1, 2, 3, 8 ]", "3" ) {}
+};
+
+BOOST_FIXTURE_TEST_CASE(ParseConcatenationDim3, ConcatenationFixtureDim3)
+{
+ RunTest<4, uint8_t>(0,
+ { { "inputTensor1", { 0, 1, 2, 3,
+ 4, 5, 6, 7,
+ 8, 9, 10, 11,
+ 12, 13, 14, 15,
+ 16, 17, 18, 19,
+ 20, 21, 22, 23 } },
+ { "inputTensor2", { 50, 51, 52, 53,
+ 54, 55, 56, 57,
+ 58, 59, 60, 61,
+ 62, 63, 64, 65,
+ 66, 67, 68, 69,
+ 70, 71, 72, 73 } } },
+ { { "outputTensor", { 0, 1, 2, 3,
+ 50, 51, 52, 53,
+ 4, 5, 6, 7,
+ 54, 55, 56, 57,
+ 8, 9, 10, 11,
+ 58, 59, 60, 61,
+ 12, 13, 14, 15,
+ 62, 63, 64, 65,
+ 16, 17, 18, 19,
+ 66, 67, 68, 69,
+ 20, 21, 22, 23,
+ 70, 71, 72, 73 } } });
+}
+
+BOOST_AUTO_TEST_SUITE_END()
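
A note on the NHWC expectations above: concatenating two [ 1, 1, 2, 2 ] tensors along axis 3 appends the second input's channel values after the first at each spatial position, which is how { 0, 1, 4, 5, 2, 3, 6, 7 } is derived. A small stand-alone sketch of that reference computation; the helper name ConcatLastDim is illustrative only.

    // Illustrative reference computation for a last-dimension (channel) concatenation.
    #include <cassert>
    #include <cstdint>
    #include <vector>

    std::vector<uint8_t> ConcatLastDim(const std::vector<uint8_t>& a,
                                       const std::vector<uint8_t>& b,
                                       size_t outerCount, size_t innerSize)
    {
        std::vector<uint8_t> out;
        for (size_t i = 0; i < outerCount; ++i)
        {
            // Copy one channel run from each input in turn.
            out.insert(out.end(), a.begin() + i * innerSize, a.begin() + (i + 1) * innerSize);
            out.insert(out.end(), b.begin() + i * innerSize, b.begin() + (i + 1) * innerSize);
        }
        return out;
    }

    int main()
    {
        // Two [ 1, 1, 2, 2 ] inputs: 2 spatial positions with 2 channels each.
        std::vector<uint8_t> expected = { 0, 1, 4, 5, 2, 3, 6, 7 };
        assert(ConcatLastDim({ 0, 1, 2, 3 }, { 4, 5, 6, 7 }, 2, 2) == expected);
        return 0;
    }
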
diff --git a/src/armnnTfParser/TfParser.cpp b/src/armnnTfParser/TfParser.cpp
index eca393663b..52ba92cad5 100644
--- a/src/armnnTfParser/TfParser.cpp
+++ b/src/armnnTfParser/TfParser.cpp
@@ -11,6 +11,7 @@
#include <armnn/Descriptors.hpp>
#include <GraphTopologicalSort.hpp>
+#include <ParserHelper.hpp>
#include <Permute.hpp>
#include <VerificationHelpers.hpp>
@@ -1478,41 +1479,9 @@ ParsedTfOperationPtr TfParser::ParseConcat(const tensorflow::NodeDef& nodeDef,
inputs[viewIndex].m_IndexedValue->ResolveArmnnOutputSlot(inputs[viewIndex].m_Index);
TensorInfo inputTensorInfo = inputSlot.GetTensorInfo();
- if (inputTensorInfo.GetNumDimensions() != MaxNumOfTensorDimensions)
- {
- throw ParseException(
- boost::str(
- boost::format(
- "The number of dimensions: %1% for input tensors of the "
- "concatenation op should be %2% for Node %3% %4%")
- % inputTensorInfo.GetNumDimensions()
- % MaxNumOfTensorDimensions
- % nodeDef.name()
- % CHECK_LOCATION().AsString()));
- }
-
- if (concatDimInput == 3)
- {
- inputTensorInfo = armnnUtils::Permuted(inputTensorInfo, NHWCToArmNN);
- }
-
- for (unsigned int dim = 0; dim < MaxNumOfTensorDimensions; ++dim)
- {
- mergeDimSizes[dim] = inputTensorInfo.GetShape()[dim];
- }
-
- for (unsigned int j = 0; j < concatDim; ++j)
- {
- concatDescriptor.SetViewOriginCoord(viewIndex, j, 0);
- }
-
- concatDescriptor.SetViewOriginCoord(viewIndex, concatDim, mergeDim);
- mergeDim += mergeDimSizes[concatDim];
-
- for (unsigned int j = concatDim+1; j < MaxNumOfTensorDimensions; ++j)
- {
- concatDescriptor.SetViewOriginCoord(viewIndex, j, 0);
- }
+ // process the input tensor info
+ armnnUtils::ProcessConcatInputTensorInfo(inputTensorInfo, concatDescriptor,
+ concatDimInput, viewIndex, mergeDimSizes, mergeDim);
}
mergeDimSizes[concatDim] = mergeDim;
diff --git a/src/armnnUtils/ParserHelper.cpp b/src/armnnUtils/ParserHelper.cpp
new file mode 100644
index 0000000000..bf5ffdf0ad
--- /dev/null
+++ b/src/armnnUtils/ParserHelper.cpp
@@ -0,0 +1,64 @@
+//
+// Copyright © 2017 Arm Ltd. All rights reserved.
+// SPDX-License-Identifier: MIT
+//
+
+#include "ParserHelper.hpp"
+
+// armnnUtils
+#include "Permute.hpp"
+
+#include <boost/format.hpp>
+
+namespace armnnUtils
+{
+
+const armnn::PermutationVector NHWCToArmNN = { 0, 2, 3, 1 };
+const armnn::PermutationVector ArmNNToNHWC = { 0, 3, 1, 2 };
+
+void ProcessConcatInputTensorInfo(armnn::TensorInfo& inputTensorInfo, armnn::OriginsDescriptor& concatDescriptor,
+ const unsigned int& concatAxis, unsigned int inputIndex,
+ std::vector<unsigned int>& mergeDimSizes, unsigned int& mergeDim)
+{
+ // sanity check the number of dimensions of the input tensor
+ if (inputTensorInfo.GetNumDimensions() != armnn::MaxNumOfTensorDimensions)
+ {
+ throw armnn::ParseException(
+ boost::str(
+ boost::format(
+ "The number of dimensions: %1% for input tensors of the "
+ "concatenation op should be %2% %3%")
+ % inputTensorInfo.GetNumDimensions()
+ % armnn::MaxNumOfTensorDimensions
+ % CHECK_LOCATION().AsString()));
+ }
+
+ // if the concatenation axis is 3 the input needs to be permuted from NHWC into the ArmNN (NCHW) order
+ if (concatAxis == 3)
+ {
+ inputTensorInfo = armnnUtils::Permuted(inputTensorInfo, NHWCToArmNN);
+ }
+
+ for (unsigned int dim = 0; dim < armnn::MaxNumOfTensorDimensions; ++dim)
+ {
+ mergeDimSizes[dim] = inputTensorInfo.GetShape()[dim];
+ }
+
+ // Concatenation dimension 1 is the only dimension supported in ArmNN
+ const unsigned int concatenationDim = 1;
+
+ for (unsigned int j = 0; j < concatenationDim; ++j)
+ {
+ concatDescriptor.SetViewOriginCoord(inputIndex, j, 0);
+ }
+
+ concatDescriptor.SetViewOriginCoord(inputIndex, concatenationDim, mergeDim);
+ mergeDim += mergeDimSizes[concatenationDim];
+
+ for (unsigned int j = concatenationDim + 1; j < armnn::MaxNumOfTensorDimensions; ++j)
+ {
+ concatDescriptor.SetViewOriginCoord(inputIndex, j, 0);
+ }
+}
+
+} // namespace armnnUtils
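
To make the shared helper concrete, a sketch of how ProcessConcatInputTensorInfo accumulates view origins for two NCHW inputs. This is not part of the patch, and DataType::QuantisedAsymm8 is assumed to be the quantised 8-bit type name at this point in the codebase.

    // Sketch only: build an OriginsDescriptor for two views concatenated along the channel dimension.
    #include <armnn/ArmNN.hpp>
    #include <ParserHelper.hpp>
    #include <vector>

    int main()
    {
        using namespace armnn;

        // Two [ 1, 1, 2, 2 ] inputs already in NCHW order, so the concat axis is 1.
        TensorInfo input0(TensorShape({ 1, 1, 2, 2 }), DataType::QuantisedAsymm8, 1.0f, 0);
        TensorInfo input1(TensorShape({ 1, 1, 2, 2 }), DataType::QuantisedAsymm8, 1.0f, 0);

        OriginsDescriptor concatDescriptor(2, MaxNumOfTensorDimensions);
        std::vector<unsigned int> mergeDimSizes(MaxNumOfTensorDimensions, 0u);
        unsigned int mergeDim = 0;

        // View 0 starts at channel origin 0; mergeDim then holds its channel size (1),
        // so view 1 is placed at channel origin 1. After both calls mergeDim is 2,
        // the channel size of the concatenated output.
        armnnUtils::ProcessConcatInputTensorInfo(input0, concatDescriptor, 1, 0, mergeDimSizes, mergeDim);
        armnnUtils::ProcessConcatInputTensorInfo(input1, concatDescriptor, 1, 1, mergeDimSizes, mergeDim);
        return 0;
    }
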
diff --git a/src/armnnUtils/ParserHelper.hpp b/src/armnnUtils/ParserHelper.hpp
new file mode 100644
index 0000000000..93dfbf9360
--- /dev/null
+++ b/src/armnnUtils/ParserHelper.hpp
@@ -0,0 +1,17 @@
+//
+// Copyright © 2017 Arm Ltd. All rights reserved.
+// SPDX-License-Identifier: MIT
+//
+
+#pragma once
+
+#include <armnn/ArmNN.hpp>
+
+namespace armnnUtils
+{
+
+void ProcessConcatInputTensorInfo(armnn::TensorInfo& inputTensorInfo, armnn::OriginsDescriptor& concatDescriptor,
+ const unsigned int& concatAxis, unsigned int inputIndex,
+ std::vector<unsigned int>& mergeDimSizes, unsigned int& mergeDim);
+
+} // namespace armnnUtils