diff options
author | Nattapat Chaimanowong <nattapat.chaimanowong@arm.com> | 2019-01-25 13:20:39 +0000 |
---|---|---|
committer | Nattapat Chaimanowong <nattapat.chaimanowong@arm.com> | 2019-01-25 13:20:39 +0000 |
commit | 5e9d29802e2cfbb13adc49c2a0ac9ba952dc7650 (patch) | |
tree | 31baaf01ff767159fb6a4e594405b4acf03a6f51 /src/armnnUtils | |
parent | 6e2f60674cbe77c2a1da94ab71e35c298a1924de (diff) | |
download | armnn-5e9d29802e2cfbb13adc49c2a0ac9ba952dc7650.tar.gz |
IVGCVSW-2563 Fix bug in TfLiteParser::ParseConcatenation
Change-Id: I8fbf27b383a821e062f72809cc2e269fcd18851c
Diffstat (limited to 'src/armnnUtils')
-rw-r--r-- | src/armnnUtils/ParserHelper.cpp | 36 | ||||
-rw-r--r-- | src/armnnUtils/ParserHelper.hpp | 8 |
2 files changed, 18 insertions, 26 deletions
diff --git a/src/armnnUtils/ParserHelper.cpp b/src/armnnUtils/ParserHelper.cpp index 9d633cfc42..2286f8b6ed 100644 --- a/src/armnnUtils/ParserHelper.cpp +++ b/src/armnnUtils/ParserHelper.cpp @@ -16,12 +16,16 @@ namespace armnnUtils const armnn::PermutationVector NHWCToArmNN = { 0, 2, 3, 1 }; const armnn::PermutationVector ArmNNToNHWC = { 0, 3, 1, 2 }; -void ProcessConcatInputTensorInfo(armnn::TensorInfo& inputTensorInfo, armnn::OriginsDescriptor& concatDescriptor, - const unsigned int& concatAxis, unsigned int inputIndex, - std::vector<unsigned int>& mergeDimSizes, unsigned int& mergeDim) +void ProcessConcatInputTensorInfo(armnn::TensorInfo& inputTensorInfo, + armnn::OriginsDescriptor& concatDescriptor, + const unsigned int& concatAxis, + unsigned int inputIndex, + unsigned int& mergeDimOrigin) { + const uint32_t inputRank = concatDescriptor.GetNumDimensions(); + // double check dimensions of the tensors - if (inputTensorInfo.GetNumDimensions() != armnn::MaxNumOfTensorDimensions) + if (inputTensorInfo.GetNumDimensions() != inputRank) { throw armnn::ParseException( boost::str( @@ -29,33 +33,19 @@ void ProcessConcatInputTensorInfo(armnn::TensorInfo& inputTensorInfo, armnn::Ori "The number of dimensions: %1% for input tensors of the " "concatenation op should be %2% %3%") % inputTensorInfo.GetNumDimensions() - % armnn::MaxNumOfTensorDimensions + % inputRank % CHECK_LOCATION().AsString())); } - // if concatenation axis is 3 then need to be permuted - if (concatAxis == 3) - { - inputTensorInfo = armnnUtils::Permuted(inputTensorInfo, NHWCToArmNN); - } - - for (unsigned int dim = 0; dim < armnn::MaxNumOfTensorDimensions; ++dim) - { - mergeDimSizes[dim] = inputTensorInfo.GetShape()[dim]; - } - - // Concatenation dimension 1 is the only dimension supported in ArmNN - const unsigned int concatenationDim = 1; - - for (unsigned int j = 0; j < concatenationDim; ++j) + for (unsigned int j = 0; j < concatAxis; ++j) { concatDescriptor.SetViewOriginCoord(inputIndex, j, 0); } - concatDescriptor.SetViewOriginCoord(inputIndex, concatenationDim, mergeDim); - mergeDim += mergeDimSizes[concatenationDim]; + concatDescriptor.SetViewOriginCoord(inputIndex, concatAxis, mergeDimOrigin); + mergeDimOrigin += inputTensorInfo.GetShape()[concatAxis]; - for (unsigned int j = concatenationDim + 1; j < armnn::MaxNumOfTensorDimensions; ++j) + for (unsigned int j = concatAxis + 1; j < inputRank; ++j) { concatDescriptor.SetViewOriginCoord(inputIndex, j, 0); } diff --git a/src/armnnUtils/ParserHelper.hpp b/src/armnnUtils/ParserHelper.hpp index 24369dc521..bcc1e5b2cc 100644 --- a/src/armnnUtils/ParserHelper.hpp +++ b/src/armnnUtils/ParserHelper.hpp @@ -10,9 +10,11 @@ namespace armnnUtils { -void ProcessConcatInputTensorInfo(armnn::TensorInfo& inputTensorInfo, armnn::OriginsDescriptor& concatDescriptor, - const unsigned int& concatAxis, unsigned int inputIndex, - std::vector<unsigned int>& mergeDimSizes, unsigned int& mergeDim); +void ProcessConcatInputTensorInfo(armnn::TensorInfo& inputTensorInfo, + armnn::OriginsDescriptor& concatDescriptor, + const unsigned int& concatAxis, + unsigned int inputIndex, + unsigned int& mergeDimOrigin); /// Creates a tensor info after reducing the dimensions mentioned in axisData. void CalculateReducedOutputTensoInfo(const armnn::TensorInfo& inputTensorInfo, const armnn::TensorInfo& axisTensorInfo, |