path: root/src/armnnTfLiteParser
author     Sadik Armagan <sadik.armagan@arm.com>  2018-10-01 11:51:37 +0100
committer  Matthew Bentham <matthew.bentham@arm.com>  2018-10-10 16:16:57 +0100
commit     479045bdcac4faddbf567aa0f73d2899881f341c (patch)
tree       5cd11ee39543ceda2730d21c1e7036543712a335 /src/armnnTfLiteParser
parent     e4ba53a85c559d4fe574305276ac815cf7995762 (diff)
download   armnn-479045bdcac4faddbf567aa0f73d2899881f341c.tar.gz
IVGCVSW-1787 Add Support for Concatenation on TfLite parser
* Concatenation Parser function added to the TfLite Parser

Change-Id: I42a42cd765ea09a30841c66b1942b9e09a876b10
Diffstat (limited to 'src/armnnTfLiteParser')
-rw-r--r--  src/armnnTfLiteParser/TfLiteParser.cpp        98
-rw-r--r--  src/armnnTfLiteParser/TfLiteParser.hpp         5
-rw-r--r--  src/armnnTfLiteParser/test/Concatenation.cpp  187
3 files changed, 286 insertions, 4 deletions
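
As context for the change below, here is a minimal sketch of how the new CONCATENATION support could be exercised through the public armnnTfLiteParser API. The model path and tensor names ("model.tflite", "inputTensor1", "inputTensor2", "outputTensor") are placeholders chosen to mirror the test fixture in this commit, not values the commit itself defines:

    // Sketch only: load a TfLite model containing a CONCATENATION operator
    // and query the binding information the parser exposes for it.
    #include <armnnTfLiteParser/ITfLiteParser.hpp>

    int main()
    {
        using namespace armnnTfLiteParser;

        ITfLiteParserPtr parser = ITfLiteParser::Create();
        armnn::INetworkPtr network = parser->CreateNetworkFromBinaryFile("model.tflite");

        // Binding info (layer binding id + TensorInfo) for the two concatenation
        // inputs and the single output of subgraph 0; names depend on the model.
        BindingPointInfo input0 = parser->GetNetworkInputBindingInfo(0, "inputTensor1");
        BindingPointInfo input1 = parser->GetNetworkInputBindingInfo(0, "inputTensor2");
        BindingPointInfo output = parser->GetNetworkOutputBindingInfo(0, "outputTensor");

        return 0;
    }
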
diff --git a/src/armnnTfLiteParser/TfLiteParser.cpp b/src/armnnTfLiteParser/TfLiteParser.cpp
index 13e4604490..66746e488b 100644
--- a/src/armnnTfLiteParser/TfLiteParser.cpp
+++ b/src/armnnTfLiteParser/TfLiteParser.cpp
@@ -10,6 +10,7 @@
#include <boost/filesystem.hpp>
// armnnUtils:
+#include <ParserHelper.hpp>
#include <Permute.hpp>
#include <VerificationHelpers.hpp>
@@ -452,13 +453,14 @@ TfLiteParser::TfLiteParser()
{
// register supported operators
m_ParserFunctions[tflite::BuiltinOperator_AVERAGE_POOL_2D] = &TfLiteParser::ParseAveragePool2D;
+ m_ParserFunctions[tflite::BuiltinOperator_CONCATENATION] = &TfLiteParser::ParseConcatenation;
m_ParserFunctions[tflite::BuiltinOperator_CONV_2D] = &TfLiteParser::ParseConv2D;
m_ParserFunctions[tflite::BuiltinOperator_DEPTHWISE_CONV_2D] = &TfLiteParser::ParseDepthwiseConv2D;
- m_ParserFunctions[tflite::BuiltinOperator_SOFTMAX] = &TfLiteParser::ParseSoftmax;
- m_ParserFunctions[tflite::BuiltinOperator_SQUEEZE] = &TfLiteParser::ParseSqueeze;
m_ParserFunctions[tflite::BuiltinOperator_RELU] = &TfLiteParser::ParseRelu;
m_ParserFunctions[tflite::BuiltinOperator_RELU6] = &TfLiteParser::ParseRelu6;
m_ParserFunctions[tflite::BuiltinOperator_RESHAPE] = &TfLiteParser::ParseReshape;
+ m_ParserFunctions[tflite::BuiltinOperator_SOFTMAX] = &TfLiteParser::ParseSoftmax;
+ m_ParserFunctions[tflite::BuiltinOperator_SQUEEZE] = &TfLiteParser::ParseSqueeze;
}
void TfLiteParser::ResetParser()
@@ -1097,6 +1099,98 @@ void TfLiteParser::ParseReshape(size_t subgraphIndex, size_t operatorIndex)
RegisterOutputSlots(subgraphIndex, operatorIndex, layer, {outputTensorIndexes[0]});
}
+void TfLiteParser::ParseConcatenation(size_t subgraphIndex, size_t operatorIndex)
+{
+ CHECK_MODEL(m_Model, subgraphIndex, operatorIndex);
+
+ const auto & operatorPtr = m_Model->subgraphs[subgraphIndex]->operators[operatorIndex];
+ const auto * options = operatorPtr->builtin_options.AsConcatenationOptions();
+
+ CHECK_SUPPORTED_FUSED_ACTIVATION(options, subgraphIndex, operatorIndex);
+
+ auto inputs = GetInputs(m_Model, subgraphIndex, operatorIndex);
+ auto outputs = GetOutputs(m_Model, subgraphIndex, operatorIndex);
+ CHECK_VALID_SIZE(outputs.size(), 1);
+
+ unsigned int numInputs = static_cast<unsigned int>(inputs.size());
+ unsigned int numConcatView = numInputs;
+
+ OriginsDescriptor concatDescriptor(static_cast<uint32_t>(numConcatView), MaxNumOfTensorDimensions);
+ std::vector<unsigned int> mergeDimSizes(MaxNumOfTensorDimensions, 0u);
+
+ unsigned int mergeDim = 0;
+
+ // concatDimInput indicates the data format: 3 is NHWC, 1 is NCHW.
+ // The axis may also be negative; a negative axis is interpreted as counting from the end of the rank,
+ // i.e. as the (axis + rank(values))-th dimension.
+ int32_t inputRank = static_cast<int32_t>(ToTensorInfo(inputs[0]).GetNumDimensions());
+ const unsigned int concatDimInput = static_cast<unsigned int>((inputRank + options->axis) % inputRank);
+
+ // ArmNN supports concatenation along the channel dimension for data formats NHWC and NCHW.
+ if (concatDimInput == 0 || concatDimInput == 2)
+ {
+ throw ParseException(
+ boost::str(
+ boost::format(
+ "Dimension %1% for concatenation is not supported by Armnn. "
+ "Node %2%")
+ % concatDimInput
+ % CHECK_LOCATION().AsString()));
+ }
+
+ for (unsigned int viewIndex = 0; viewIndex < numConcatView; ++viewIndex)
+ {
+ TensorInfo inputTensorInfo = ToTensorInfo(inputs[viewIndex]);
+
+ // process the input tensor info
+ armnnUtils::ProcessConcatInputTensorInfo(inputTensorInfo, concatDescriptor,
+ concatDimInput, viewIndex, mergeDimSizes, mergeDim);
+ }
+
+ auto layerName = boost::str(boost::format("Concatenation:%1%:%2%") % subgraphIndex % operatorIndex);
+ IConnectableLayer* layer = m_Network->AddMergerLayer(concatDescriptor, layerName.c_str());
+
+ BOOST_ASSERT(layer != nullptr);
+
+ armnn::TensorInfo outputTensorInfo = ToTensorInfo(outputs[0]);
+ auto inputTensorIndexes = AsUnsignedVector(GetInputTensorIds(m_Model, subgraphIndex, operatorIndex));
+ if (concatDimInput == 3)
+ {
+ // The fused activation layer is added after the deswizzle layer below.
+ for (unsigned int viewIndex = 0; viewIndex < numConcatView; ++viewIndex)
+ {
+ // add permute layers to swizzle the inputs
+ armnn::TensorInfo inputTensorInfo = ToTensorInfo(inputs[viewIndex]);
+ IConnectableLayer* const swizzleLayer = SwizzleIn(*m_Network, layer, viewIndex, inputTensorInfo);
+
+ BOOST_ASSERT(swizzleLayer != nullptr);
+
+ // register the input connection slots for the layer
+ // only the tensors for the inputs are relevant, exclude the const tensors
+ RegisterInputSlots(subgraphIndex, operatorIndex, swizzleLayer, {inputTensorIndexes[viewIndex]});
+ }
+
+ // add permute layer to deswizzle the output
+ IConnectableLayer* const deswizzleLayer = DeswizzleOut(*m_Network, layer, 0, outputTensorInfo);
+
+ // add fused activation layer after the trailing swizzle layer
+ layer = AddFusedActivationLayer(deswizzleLayer, 0, options->fused_activation_function);
+ }
+ else
+ {
+ // set the layer output tensor info
+ layer->GetOutputSlot(0).SetTensorInfo(outputTensorInfo);
+
+ // register the input connection slots for the layer, connections are made after all layers have been created
+ // only the tensors for the inputs are relevant, exclude the const tensors
+ RegisterInputSlots(subgraphIndex, operatorIndex, layer, {inputTensorIndexes});
+ }
+
+ // register the output connection slots for the layer, connections are made after all layers have been created
+ auto outputTensorIndexes = AsUnsignedVector(GetOutputTensorIds(m_Model, subgraphIndex, operatorIndex));
+ RegisterOutputSlots(subgraphIndex, operatorIndex, layer, {outputTensorIndexes[0]});
+}
+
armnn::IConnectableLayer* TfLiteParser::AddFusedActivationLayer(armnn::IConnectableLayer* prevLayer,
unsigned int outputSlot,
tflite::ActivationFunctionType activationType)
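
A note on the axis handling in ParseConcatenation above: the expression (inputRank + options->axis) % inputRank maps a negative TfLite axis onto the equivalent non-negative dimension index, so axis -3 on a rank-4 tensor resolves to dimension 1. A standalone sketch of that arithmetic (illustration only, not part of this commit):

    #include <cstdint>
    #include <iostream>

    // Mirrors the normalisation used in ParseConcatenation: a negative axis
    // counts back from the end of the tensor rank.
    int32_t NormaliseAxis(int32_t axis, int32_t rank)
    {
        return (rank + axis) % rank;
    }

    int main()
    {
        std::cout << NormaliseAxis(-3, 4) << std::endl; // 1 -> NCHW channel dimension
        std::cout << NormaliseAxis( 3, 4) << std::endl; // 3 -> NHWC channel dimension
        return 0;
    }
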
diff --git a/src/armnnTfLiteParser/TfLiteParser.hpp b/src/armnnTfLiteParser/TfLiteParser.hpp
index f949484a4f..620648a0c3 100644
--- a/src/armnnTfLiteParser/TfLiteParser.hpp
+++ b/src/armnnTfLiteParser/TfLiteParser.hpp
@@ -90,13 +90,14 @@ private:
void ParseUnsupportedOperator(size_t subgraphIndex, size_t operatorIndex);
void ParseAveragePool2D(size_t subgraphIndex, size_t operatorIndex);
+ void ParseConcatenation(size_t subgraphIndex, size_t operatorIndex);
void ParseConv2D(size_t subgraphIndex, size_t operatorIndex);
void ParseDepthwiseConv2D(size_t subgraphIndex, size_t operatorIndex);
- void ParseSoftmax(size_t subgraphIndex, size_t operatorIndex);
- void ParseSqueeze(size_t subgraphIndex, size_t operatorIndex);
void ParseRelu(size_t subgraphIndex, size_t operatorIndex);
void ParseRelu6(size_t subgraphIndex, size_t operatorIndex);
void ParseReshape(size_t subgraphIndex, size_t operatorIndex);
+ void ParseSoftmax(size_t subgraphIndex, size_t operatorIndex);
+ void ParseSqueeze(size_t subgraphIndex, size_t operatorIndex);
void RegisterProducerOfTensor(size_t subgraphIndex, size_t tensorIndex, armnn::IOutputSlot* slot);
void RegisterConsumerOfTensor(size_t subgraphIndex, size_t tensorIndex, armnn::IInputSlot* slot);
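
The header change above only declares ParseConcatenation; the operator is wired up through the m_ParserFunctions table populated in the TfLiteParser constructor shown earlier. A reduced sketch of that dispatch pattern (the class layout, table size and op codes below are placeholders, not the real parser):

    #include <cstddef>
    #include <vector>

    // Each TfLite builtin operator code indexes into a table of member-function
    // pointers, so supporting a new operator means one table entry plus one
    // Parse* method.
    class MiniParser
    {
    public:
        MiniParser()
        {
            m_ParserFunctions.resize(kNumOps, &MiniParser::ParseUnsupported);
            m_ParserFunctions[kConcatenationOp] = &MiniParser::ParseConcatenation;
        }

        void ParseOperator(std::size_t opCode, std::size_t subgraphIndex, std::size_t operatorIndex)
        {
            (this->*m_ParserFunctions[opCode])(subgraphIndex, operatorIndex);
        }

    private:
        static constexpr std::size_t kNumOps = 128;        // placeholder table size
        static constexpr std::size_t kConcatenationOp = 2; // placeholder op code

        using OperatorParsingFunction = void (MiniParser::*)(std::size_t, std::size_t);
        std::vector<OperatorParsingFunction> m_ParserFunctions;

        void ParseUnsupported(std::size_t, std::size_t)   { /* the real parser throws here */ }
        void ParseConcatenation(std::size_t, std::size_t) { /* the real parser builds a merger layer */ }
    };
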
diff --git a/src/armnnTfLiteParser/test/Concatenation.cpp b/src/armnnTfLiteParser/test/Concatenation.cpp
new file mode 100644
index 0000000000..8629efe3d7
--- /dev/null
+++ b/src/armnnTfLiteParser/test/Concatenation.cpp
@@ -0,0 +1,187 @@
+//
+// Copyright © 2017 Arm Ltd. All rights reserved.
+// SPDX-License-Identifier: MIT
+//
+
+#include <boost/test/unit_test.hpp>
+#include "ParserFlatbuffersFixture.hpp"
+#include "../TfLiteParser.hpp"
+
+#include <string>
+#include <iostream>
+
+BOOST_AUTO_TEST_SUITE(TensorflowLiteParser)
+
+struct ConcatenationFixture : public ParserFlatbuffersFixture
+{
+ explicit ConcatenationFixture(const std::string & inputShape1,
+ const std::string & inputShape2,
+ const std::string & outputShape,
+ const std::string & axis,
+ const std::string & activation="NONE")
+ {
+ m_JsonString = R"(
+ {
+ "version": 3,
+ "operator_codes": [ { "builtin_code": "CONCATENATION" } ],
+ "subgraphs": [ {
+ "tensors": [
+ {
+ "shape": )" + inputShape1 + R"(,
+ "type": "UINT8",
+ "buffer": 0,
+ "name": "inputTensor1",
+ "quantization": {
+ "min": [ 0.0 ],
+ "max": [ 255.0 ],
+ "scale": [ 1.0 ],
+ "zero_point": [ 0 ],
+ }
+ },
+ {
+ "shape": )" + inputShape2 + R"(,
+ "type": "UINT8",
+ "buffer": 1,
+ "name": "inputTensor2",
+ "quantization": {
+ "min": [ 0.0 ],
+ "max": [ 255.0 ],
+ "scale": [ 1.0 ],
+ "zero_point": [ 0 ],
+ }
+ },
+ {
+ "shape": )" + outputShape + R"( ,
+ "type": "UINT8",
+ "buffer": 2,
+ "name": "outputTensor",
+ "quantization": {
+ "min": [ 0.0 ],
+ "max": [ 255.0 ],
+ "scale": [ 1.0 ],
+ "zero_point": [ 0 ],
+ }
+ }
+ ],
+ "inputs": [ 0, 1 ],
+ "outputs": [ 2 ],
+ "operators": [
+ {
+ "opcode_index": 0,
+ "inputs": [ 0, 1 ],
+ "outputs": [ 2 ],
+ "builtin_options_type": "ConcatenationOptions",
+ "builtin_options": {
+ "axis": )" + axis + R"(,
+ "fused_activation_function": )" + activation + R"(
+ },
+ "custom_options_format": "FLEXBUFFERS"
+ }
+ ],
+ } ],
+ "buffers" : [
+ { },
+ { }
+ ]
+ }
+ )";
+ Setup();
+ }
+};
+
+
+struct ConcatenationFixtureNegativeDim : ConcatenationFixture
+{
+ ConcatenationFixtureNegativeDim() : ConcatenationFixture("[ 1, 1, 2, 2 ]",
+ "[ 1, 1, 2, 2 ]",
+ "[ 1, 2, 2, 2 ]",
+ "-3" ) {}
+};
+
+BOOST_FIXTURE_TEST_CASE(ParseConcatenationNegativeDim, ConcatenationFixtureNegativeDim)
+{
+ RunTest<4, uint8_t>(0,
+ {{"inputTensor1", { 0, 1, 2, 3 }},
+ {"inputTensor2", { 4, 5, 6, 7 }}},
+ {{"outputTensor", { 0, 1, 2, 3, 4, 5, 6, 7 }}});
+}
+
+struct ConcatenationFixtureNCHW : ConcatenationFixture
+{
+ ConcatenationFixtureNCHW() : ConcatenationFixture("[ 1, 1, 2, 2 ]", "[ 1, 1, 2, 2 ]", "[ 1, 2, 2, 2 ]", "1" ) {}
+};
+
+BOOST_FIXTURE_TEST_CASE(ParseConcatenationNCHW, ConcatenationFixtureNCHW)
+{
+ RunTest<4, uint8_t>(0,
+ {{"inputTensor1", { 0, 1, 2, 3 }},
+ {"inputTensor2", { 4, 5, 6, 7 }}},
+ {{"outputTensor", { 0, 1, 2, 3, 4, 5, 6, 7 }}});
+}
+
+struct ConcatenationFixtureNHWC : ConcatenationFixture
+{
+ ConcatenationFixtureNHWC() : ConcatenationFixture("[ 1, 1, 2, 2 ]", "[ 1, 1, 2, 2 ]", "[ 1, 1, 2, 4 ]", "3" ) {}
+};
+
+BOOST_FIXTURE_TEST_CASE(ParseConcatenationNHWC, ConcatenationFixtureNHWC)
+{
+ RunTest<4, uint8_t>(0,
+ {{"inputTensor1", { 0, 1, 2, 3 }},
+ {"inputTensor2", { 4, 5, 6, 7 }}},
+ {{"outputTensor", { 0, 1, 4, 5, 2, 3, 6, 7 }}});
+}
+
+struct ConcatenationFixtureDim1 : ConcatenationFixture
+{
+ ConcatenationFixtureDim1() : ConcatenationFixture("[ 1, 2, 3, 4 ]", "[ 1, 2, 3, 4 ]", "[ 1, 4, 3, 4 ]", "1" ) {}
+};
+
+BOOST_FIXTURE_TEST_CASE(ParseConcatenationDim1, ConcatenationFixtureDim1)
+{
+ RunTest<4, uint8_t>(0,
+ { { "inputTensor1", { 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11,
+ 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23 } },
+ { "inputTensor2", { 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61,
+ 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73 } } },
+ { { "outputTensor", { 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11,
+ 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23,
+ 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61,
+ 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73 } } });
+}
+
+struct ConcatenationFixtureDim3 : ConcatenationFixture
+{
+ ConcatenationFixtureDim3() : ConcatenationFixture("[ 1, 2, 3, 4 ]", "[ 1, 2, 3, 4 ]", "[ 1, 2, 3, 8 ]", "3" ) {}
+};
+
+BOOST_FIXTURE_TEST_CASE(ParseConcatenationDim3, ConcatenationFixtureDim3)
+{
+ RunTest<4, uint8_t>(0,
+ { { "inputTensor1", { 0, 1, 2, 3,
+ 4, 5, 6, 7,
+ 8, 9, 10, 11,
+ 12, 13, 14, 15,
+ 16, 17, 18, 19,
+ 20, 21, 22, 23 } },
+ { "inputTensor2", { 50, 51, 52, 53,
+ 54, 55, 56, 57,
+ 58, 59, 60, 61,
+ 62, 63, 64, 65,
+ 66, 67, 68, 69,
+ 70, 71, 72, 73 } } },
+ { { "outputTensor", { 0, 1, 2, 3,
+ 50, 51, 52, 53,
+ 4, 5, 6, 7,
+ 54, 55, 56, 57,
+ 8, 9, 10, 11,
+ 58, 59, 60, 61,
+ 12, 13, 14, 15,
+ 62, 63, 64, 65,
+ 16, 17, 18, 19,
+ 66, 67, 68, 69,
+ 20, 21, 22, 23,
+ 70, 71, 72, 73 } } });
+}
+
+BOOST_AUTO_TEST_SUITE_END()
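
A closing note on the ParseConcatenationDim3 expectation: concatenating two [ 1, 2, 3, 4 ] tensors along the last dimension appends each input's innermost slice in turn, which is why the expected outputTensor interleaves rows of four values from inputTensor1 and inputTensor2. A standalone sketch of that arithmetic (illustration only, not part of the test suite):

    #include <cstddef>
    #include <cstdint>
    #include <vector>

    // Concatenate two flattened tensors along their innermost dimension by
    // alternating slices of innerDim values from each input.
    std::vector<uint8_t> ConcatLastDim(const std::vector<uint8_t>& a,
                                       const std::vector<uint8_t>& b,
                                       std::size_t innerDim)
    {
        std::vector<uint8_t> out;
        for (std::size_t i = 0; i < a.size(); i += innerDim)
        {
            out.insert(out.end(), a.begin() + i, a.begin() + i + innerDim);
            out.insert(out.end(), b.begin() + i, b.begin() + i + innerDim);
        }
        return out; // for the Dim3 data: 0,1,2,3, 50,51,52,53, 4,5,6,7, 54,55,56,57, ...
    }
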