aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorNattapat Chaimanowong <nattapat.chaimanowong@arm.com>2019-01-25 13:20:39 +0000
committerNattapat Chaimanowong <nattapat.chaimanowong@arm.com>2019-01-25 13:20:39 +0000
commit5e9d29802e2cfbb13adc49c2a0ac9ba952dc7650 (patch)
tree31baaf01ff767159fb6a4e594405b4acf03a6f51
parent6e2f60674cbe77c2a1da94ab71e35c298a1924de (diff)
downloadarmnn-5e9d29802e2cfbb13adc49c2a0ac9ba952dc7650.tar.gz
IVGCVSW-2563 Fix bug in TfLiteParser::ParseConcatenation
Change-Id: I8fbf27b383a821e062f72809cc2e269fcd18851c
-rw-r--r--src/armnnTfLiteParser/TfLiteParser.cpp110
-rw-r--r--src/armnnTfLiteParser/test/Concatenation.cpp51
-rw-r--r--src/armnnUtils/ParserHelper.cpp36
-rw-r--r--src/armnnUtils/ParserHelper.hpp8
4 files changed, 83 insertions, 122 deletions
diff --git a/src/armnnTfLiteParser/TfLiteParser.cpp b/src/armnnTfLiteParser/TfLiteParser.cpp
index 359695b94d..8b2a818e6d 100644
--- a/src/armnnTfLiteParser/TfLiteParser.cpp
+++ b/src/armnnTfLiteParser/TfLiteParser.cpp
@@ -33,48 +33,6 @@ namespace armnnTfLiteParser
{
namespace
{
-const PermutationVector NHWCToArmNN = { 0, 2, 3, 1 };
-const PermutationVector ArmNNToNHWC = { 0, 3, 1, 2 };
-
-IConnectableLayer* SwizzleIn(INetwork& network,
- IConnectableLayer* layer,
- unsigned int inputSlotIndex,
- const TensorInfo & inputInfo)
-{
- BOOST_ASSERT(layer != nullptr);
- // Add swizzle layer
- std::stringstream name;
- name << "swizzle_for-" << layer->GetName() << ":in" << inputSlotIndex;
- IConnectableLayer* const swizzleLayer = network.AddPermuteLayer(NHWCToArmNN, name.str().c_str());
- // Set swizzled output shape
- const TensorInfo swizzleOutInfo = armnnUtils::Permuted(inputInfo, NHWCToArmNN);
- swizzleLayer->GetOutputSlot(0).SetTensorInfo(swizzleOutInfo);
- // Connect the swizzle layer to the actual layer
- swizzleLayer->GetOutputSlot(0).Connect(layer->GetInputSlot(inputSlotIndex));
-
- return swizzleLayer;
-}
-
-IConnectableLayer* DeswizzleOut(INetwork& network,
- IConnectableLayer* layer,
- unsigned int outputSlotIndex,
- const TensorInfo & outputInfo)
-{
- BOOST_ASSERT(layer != nullptr);
- // Add deswizzle layer
- std::stringstream name;
- name << "deswizzle_for-" << layer->GetName() << ":out" << outputSlotIndex;
- IConnectableLayer* const deswizzleLayer = network.AddPermuteLayer(ArmNNToNHWC, name.str().c_str());
- // Set deswizzled output shape
- deswizzleLayer->GetOutputSlot(0).SetTensorInfo(outputInfo);
- // Set original layer output shape
- const TensorInfo deswizzleOutInfo = armnnUtils::Permuted(outputInfo, NHWCToArmNN);
- layer->GetOutputSlot(outputSlotIndex).SetTensorInfo(deswizzleOutInfo);
- // Connect the actual layer to the deswizzle layer
- layer->GetOutputSlot(outputSlotIndex).Connect(deswizzleLayer->GetInputSlot(0));
-
- return deswizzleLayer;
-}
const uint32_t VIRTUAL_OPERATOR_ID = std::numeric_limits<uint32_t>::max();
@@ -1383,39 +1341,24 @@ void TfLiteParser::ParseConcatenation(size_t subgraphIndex, size_t operatorIndex
auto outputs = GetOutputs(m_Model, subgraphIndex, operatorIndex);
CHECK_VALID_SIZE(outputs.size(), 1);
- unsigned int numInputs = static_cast<unsigned int>(inputs.size());
- unsigned int numConcatView = numInputs;
+ unsigned int numConcatView = static_cast<unsigned int>(inputs.size());
+ uint32_t inputRank = ToTensorInfo(inputs[0]).GetNumDimensions();
- OriginsDescriptor concatDescriptor(static_cast<uint32_t>(numConcatView), MaxNumOfTensorDimensions);
- std::vector<unsigned int>mergeDimSizes(MaxNumOfTensorDimensions, 0u);
+ const unsigned int concatDimInput = static_cast<unsigned int>(
+ (static_cast<int>(inputRank) + options->axis) % static_cast<int>(inputRank));
- unsigned int mergeDim = 0;
+ OriginsDescriptor concatDescriptor(static_cast<uint32_t>(numConcatView), inputRank);
+ concatDescriptor.SetConcatAxis(concatDimInput);
- // This concatDim indicates the data format: 3 is the NHWC, 1 is the NCHW.
- // axis could also be negative numbers. Negative axis are interpreted as counting from the end of the rank,
- // i.e., axis + rank(values)-th dimension.
- int32_t inputRank = static_cast<int32_t>(ToTensorInfo(inputs[0]).GetNumDimensions());
- const unsigned int concatDimInput = static_cast<unsigned int>((inputRank + options->axis) % inputRank);
-
- // ArmNN supports concatenation along the channel dimension for data formats NHWC and NCHW.
- if (concatDimInput == 0 || concatDimInput == 2)
- {
- throw ParseException(
- boost::str(
- boost::format(
- "Dimension %1% for concatenation is not supported by Armnn. "
- "Node %2%")
- % concatDimInput
- % CHECK_LOCATION().AsString()));
- }
+ unsigned int mergeDimOrigin = 0;
for (unsigned int viewIndex = 0; viewIndex < numConcatView; ++viewIndex)
{
TensorInfo inputTensorInfo = ToTensorInfo(inputs[viewIndex]);
- // process the input tensor info
- armnnUtils::ProcessConcatInputTensorInfo(inputTensorInfo, concatDescriptor,
- concatDimInput, viewIndex, mergeDimSizes, mergeDim);
+ // This set up concatDescriptor view origin
+ armnnUtils::ProcessConcatInputTensorInfo(
+ inputTensorInfo, concatDescriptor, concatDimInput, viewIndex, mergeDimOrigin);
}
auto layerName = boost::str(boost::format("Concatenation:%1%:%2%") % subgraphIndex % operatorIndex);
@@ -1425,39 +1368,14 @@ void TfLiteParser::ParseConcatenation(size_t subgraphIndex, size_t operatorIndex
armnn::TensorInfo outputTensorInfo = ToTensorInfo(outputs[0]);
auto inputTensorIndexes = AsUnsignedVector(GetInputTensorIds(m_Model, subgraphIndex, operatorIndex));
- if (concatDimInput == 3)
- {
- // Adding Fused Activation Layer after this moment....
- for (unsigned int viewIndex = 0; viewIndex < numConcatView; ++viewIndex)
- {
- // add permute layers to swizzle the inputs
- armnn::TensorInfo inputTensorInfo = ToTensorInfo(inputs[viewIndex]);
- IConnectableLayer* const swizzleLayer = SwizzleIn(*m_Network, layer, viewIndex, inputTensorInfo);
-
- BOOST_ASSERT(swizzleLayer != nullptr);
-
- // register the input connection slots for the layer
- // only the tensors for the inputs are relevant, exclude the const tensors
- RegisterInputSlots(subgraphIndex, operatorIndex, swizzleLayer, {inputTensorIndexes[viewIndex]});
- }
- // add permute layer to deswizzle the output
- IConnectableLayer* const deswizzleLayer = DeswizzleOut(*m_Network, layer, 0, outputTensorInfo);
+ layer->GetOutputSlot(0).SetTensorInfo(outputTensorInfo);
- // add fused activation layer after the trailing swizzle layer
- layer = AddFusedActivationLayer(deswizzleLayer, 0, options->fused_activation_function);
- }
- else
- {
- // set the layer output tensor info
- layer->GetOutputSlot(0).SetTensorInfo(outputTensorInfo);
+ RegisterInputSlots(subgraphIndex, operatorIndex, layer, {inputTensorIndexes});
- // register the input connection slots for the layer, connections are made after all layers have been created
- // only the tensors for the inputs are relevant, exclude the const tensors
- RegisterInputSlots(subgraphIndex, operatorIndex, layer, {inputTensorIndexes});
- }
+ // add fused activation layer
+ layer = AddFusedActivationLayer(layer, 0, options->fused_activation_function);
- // register the output connection slots for the layer, connections are made after all layers have been created
auto outputTensorIndexes = AsUnsignedVector(GetOutputTensorIds(m_Model, subgraphIndex, operatorIndex));
RegisterOutputSlots(subgraphIndex, operatorIndex, layer, {outputTensorIndexes[0]});
}
diff --git a/src/armnnTfLiteParser/test/Concatenation.cpp b/src/armnnTfLiteParser/test/Concatenation.cpp
index bb5aebf39c..d3d571f174 100644
--- a/src/armnnTfLiteParser/test/Concatenation.cpp
+++ b/src/armnnTfLiteParser/test/Concatenation.cpp
@@ -189,4 +189,55 @@ BOOST_FIXTURE_TEST_CASE(ParseConcatenationDim3, ConcatenationFixtureDim3)
70, 71, 72, 73 } } });
}
+struct ConcatenationFixture3DDim0 : ConcatenationFixture
+{
+ ConcatenationFixture3DDim0() : ConcatenationFixture("[ 1, 2, 3]", "[ 2, 2, 3]", "[ 3, 2, 3]", "0" ) {}
+};
+
+BOOST_FIXTURE_TEST_CASE(ParseConcatenation3DDim0, ConcatenationFixture3DDim0)
+{
+ RunTest<3, armnn::DataType::QuantisedAsymm8>(
+ 0,
+ { { "inputTensor1", { 0, 1, 2, 3, 4, 5 } },
+ { "inputTensor2", { 6, 7, 8, 9, 10, 11,
+ 12, 13, 14, 15, 16, 17 } } },
+ { { "outputTensor", { 0, 1, 2, 3, 4, 5,
+ 6, 7, 8, 9, 10, 11,
+ 12, 13, 14, 15, 16, 17 } } });
+}
+
+struct ConcatenationFixture3DDim1 : ConcatenationFixture
+{
+ ConcatenationFixture3DDim1() : ConcatenationFixture("[ 1, 2, 3]", "[ 1, 4, 3]", "[ 1, 6, 3]", "1" ) {}
+};
+
+BOOST_FIXTURE_TEST_CASE(ParseConcatenation3DDim1, ConcatenationFixture3DDim1)
+{
+ RunTest<3, armnn::DataType::QuantisedAsymm8>(
+ 0,
+ { { "inputTensor1", { 0, 1, 2, 3, 4, 5 } },
+ { "inputTensor2", { 6, 7, 8, 9, 10, 11,
+ 12, 13, 14, 15, 16, 17 } } },
+ { { "outputTensor", { 0, 1, 2, 3, 4, 5,
+ 6, 7, 8, 9, 10, 11,
+ 12, 13, 14, 15, 16, 17 } } });
+}
+
+struct ConcatenationFixture3DDim2 : ConcatenationFixture
+{
+ ConcatenationFixture3DDim2() : ConcatenationFixture("[ 1, 2, 3]", "[ 1, 2, 6]", "[ 1, 2, 9]", "2" ) {}
+};
+
+BOOST_FIXTURE_TEST_CASE(ParseConcatenation3DDim2, ConcatenationFixture3DDim2)
+{
+ RunTest<3, armnn::DataType::QuantisedAsymm8>(
+ 0,
+ { { "inputTensor1", { 0, 1, 2,
+ 3, 4, 5 } },
+ { "inputTensor2", { 6, 7, 8, 9, 10, 11,
+ 12, 13, 14, 15, 16, 17 } } },
+ { { "outputTensor", { 0, 1, 2, 6, 7, 8, 9, 10, 11,
+ 3, 4, 5, 12, 13, 14, 15, 16, 17 } } });
+}
+
BOOST_AUTO_TEST_SUITE_END()
diff --git a/src/armnnUtils/ParserHelper.cpp b/src/armnnUtils/ParserHelper.cpp
index 9d633cfc42..2286f8b6ed 100644
--- a/src/armnnUtils/ParserHelper.cpp
+++ b/src/armnnUtils/ParserHelper.cpp
@@ -16,12 +16,16 @@ namespace armnnUtils
const armnn::PermutationVector NHWCToArmNN = { 0, 2, 3, 1 };
const armnn::PermutationVector ArmNNToNHWC = { 0, 3, 1, 2 };
-void ProcessConcatInputTensorInfo(armnn::TensorInfo& inputTensorInfo, armnn::OriginsDescriptor& concatDescriptor,
- const unsigned int& concatAxis, unsigned int inputIndex,
- std::vector<unsigned int>& mergeDimSizes, unsigned int& mergeDim)
+void ProcessConcatInputTensorInfo(armnn::TensorInfo& inputTensorInfo,
+ armnn::OriginsDescriptor& concatDescriptor,
+ const unsigned int& concatAxis,
+ unsigned int inputIndex,
+ unsigned int& mergeDimOrigin)
{
+ const uint32_t inputRank = concatDescriptor.GetNumDimensions();
+
// double check dimensions of the tensors
- if (inputTensorInfo.GetNumDimensions() != armnn::MaxNumOfTensorDimensions)
+ if (inputTensorInfo.GetNumDimensions() != inputRank)
{
throw armnn::ParseException(
boost::str(
@@ -29,33 +33,19 @@ void ProcessConcatInputTensorInfo(armnn::TensorInfo& inputTensorInfo, armnn::Ori
"The number of dimensions: %1% for input tensors of the "
"concatenation op should be %2% %3%")
% inputTensorInfo.GetNumDimensions()
- % armnn::MaxNumOfTensorDimensions
+ % inputRank
% CHECK_LOCATION().AsString()));
}
- // if concatenation axis is 3 then need to be permuted
- if (concatAxis == 3)
- {
- inputTensorInfo = armnnUtils::Permuted(inputTensorInfo, NHWCToArmNN);
- }
-
- for (unsigned int dim = 0; dim < armnn::MaxNumOfTensorDimensions; ++dim)
- {
- mergeDimSizes[dim] = inputTensorInfo.GetShape()[dim];
- }
-
- // Concatenation dimension 1 is the only dimension supported in ArmNN
- const unsigned int concatenationDim = 1;
-
- for (unsigned int j = 0; j < concatenationDim; ++j)
+ for (unsigned int j = 0; j < concatAxis; ++j)
{
concatDescriptor.SetViewOriginCoord(inputIndex, j, 0);
}
- concatDescriptor.SetViewOriginCoord(inputIndex, concatenationDim, mergeDim);
- mergeDim += mergeDimSizes[concatenationDim];
+ concatDescriptor.SetViewOriginCoord(inputIndex, concatAxis, mergeDimOrigin);
+ mergeDimOrigin += inputTensorInfo.GetShape()[concatAxis];
- for (unsigned int j = concatenationDim + 1; j < armnn::MaxNumOfTensorDimensions; ++j)
+ for (unsigned int j = concatAxis + 1; j < inputRank; ++j)
{
concatDescriptor.SetViewOriginCoord(inputIndex, j, 0);
}
diff --git a/src/armnnUtils/ParserHelper.hpp b/src/armnnUtils/ParserHelper.hpp
index 24369dc521..bcc1e5b2cc 100644
--- a/src/armnnUtils/ParserHelper.hpp
+++ b/src/armnnUtils/ParserHelper.hpp
@@ -10,9 +10,11 @@
namespace armnnUtils
{
-void ProcessConcatInputTensorInfo(armnn::TensorInfo& inputTensorInfo, armnn::OriginsDescriptor& concatDescriptor,
- const unsigned int& concatAxis, unsigned int inputIndex,
- std::vector<unsigned int>& mergeDimSizes, unsigned int& mergeDim);
+void ProcessConcatInputTensorInfo(armnn::TensorInfo& inputTensorInfo,
+ armnn::OriginsDescriptor& concatDescriptor,
+ const unsigned int& concatAxis,
+ unsigned int inputIndex,
+ unsigned int& mergeDimOrigin);
/// Creates a tensor info after reducing the dimensions mentioned in axisData.
void CalculateReducedOutputTensoInfo(const armnn::TensorInfo& inputTensorInfo, const armnn::TensorInfo& axisTensorInfo,