aboutsummaryrefslogtreecommitdiff
path: root/src/armnnUtils
diff options
context:
space:
mode:
authorNattapat Chaimanowong <nattapat.chaimanowong@arm.com>2019-01-25 13:20:39 +0000
committerNattapat Chaimanowong <nattapat.chaimanowong@arm.com>2019-01-25 13:20:39 +0000
commit5e9d29802e2cfbb13adc49c2a0ac9ba952dc7650 (patch)
tree31baaf01ff767159fb6a4e594405b4acf03a6f51 /src/armnnUtils
parent6e2f60674cbe77c2a1da94ab71e35c298a1924de (diff)
downloadarmnn-5e9d29802e2cfbb13adc49c2a0ac9ba952dc7650.tar.gz
IVGCVSW-2563 Fix bug in TfLiteParser::ParseConcatenation
Change-Id: I8fbf27b383a821e062f72809cc2e269fcd18851c
Diffstat (limited to 'src/armnnUtils')
-rw-r--r--src/armnnUtils/ParserHelper.cpp36
-rw-r--r--src/armnnUtils/ParserHelper.hpp8
2 files changed, 18 insertions, 26 deletions
diff --git a/src/armnnUtils/ParserHelper.cpp b/src/armnnUtils/ParserHelper.cpp
index 9d633cfc42..2286f8b6ed 100644
--- a/src/armnnUtils/ParserHelper.cpp
+++ b/src/armnnUtils/ParserHelper.cpp
@@ -16,12 +16,16 @@ namespace armnnUtils
const armnn::PermutationVector NHWCToArmNN = { 0, 2, 3, 1 };
const armnn::PermutationVector ArmNNToNHWC = { 0, 3, 1, 2 };
-void ProcessConcatInputTensorInfo(armnn::TensorInfo& inputTensorInfo, armnn::OriginsDescriptor& concatDescriptor,
- const unsigned int& concatAxis, unsigned int inputIndex,
- std::vector<unsigned int>& mergeDimSizes, unsigned int& mergeDim)
+void ProcessConcatInputTensorInfo(armnn::TensorInfo& inputTensorInfo,
+ armnn::OriginsDescriptor& concatDescriptor,
+ const unsigned int& concatAxis,
+ unsigned int inputIndex,
+ unsigned int& mergeDimOrigin)
{
+ const uint32_t inputRank = concatDescriptor.GetNumDimensions();
+
// double check dimensions of the tensors
- if (inputTensorInfo.GetNumDimensions() != armnn::MaxNumOfTensorDimensions)
+ if (inputTensorInfo.GetNumDimensions() != inputRank)
{
throw armnn::ParseException(
boost::str(
@@ -29,33 +33,19 @@ void ProcessConcatInputTensorInfo(armnn::TensorInfo& inputTensorInfo, armnn::Ori
"The number of dimensions: %1% for input tensors of the "
"concatenation op should be %2% %3%")
% inputTensorInfo.GetNumDimensions()
- % armnn::MaxNumOfTensorDimensions
+ % inputRank
% CHECK_LOCATION().AsString()));
}
- // if concatenation axis is 3 then need to be permuted
- if (concatAxis == 3)
- {
- inputTensorInfo = armnnUtils::Permuted(inputTensorInfo, NHWCToArmNN);
- }
-
- for (unsigned int dim = 0; dim < armnn::MaxNumOfTensorDimensions; ++dim)
- {
- mergeDimSizes[dim] = inputTensorInfo.GetShape()[dim];
- }
-
- // Concatenation dimension 1 is the only dimension supported in ArmNN
- const unsigned int concatenationDim = 1;
-
- for (unsigned int j = 0; j < concatenationDim; ++j)
+ for (unsigned int j = 0; j < concatAxis; ++j)
{
concatDescriptor.SetViewOriginCoord(inputIndex, j, 0);
}
- concatDescriptor.SetViewOriginCoord(inputIndex, concatenationDim, mergeDim);
- mergeDim += mergeDimSizes[concatenationDim];
+ concatDescriptor.SetViewOriginCoord(inputIndex, concatAxis, mergeDimOrigin);
+ mergeDimOrigin += inputTensorInfo.GetShape()[concatAxis];
- for (unsigned int j = concatenationDim + 1; j < armnn::MaxNumOfTensorDimensions; ++j)
+ for (unsigned int j = concatAxis + 1; j < inputRank; ++j)
{
concatDescriptor.SetViewOriginCoord(inputIndex, j, 0);
}
diff --git a/src/armnnUtils/ParserHelper.hpp b/src/armnnUtils/ParserHelper.hpp
index 24369dc521..bcc1e5b2cc 100644
--- a/src/armnnUtils/ParserHelper.hpp
+++ b/src/armnnUtils/ParserHelper.hpp
@@ -10,9 +10,11 @@
namespace armnnUtils
{
-void ProcessConcatInputTensorInfo(armnn::TensorInfo& inputTensorInfo, armnn::OriginsDescriptor& concatDescriptor,
- const unsigned int& concatAxis, unsigned int inputIndex,
- std::vector<unsigned int>& mergeDimSizes, unsigned int& mergeDim);
+void ProcessConcatInputTensorInfo(armnn::TensorInfo& inputTensorInfo,
+ armnn::OriginsDescriptor& concatDescriptor,
+ const unsigned int& concatAxis,
+ unsigned int inputIndex,
+ unsigned int& mergeDimOrigin);
/// Creates a tensor info after reducing the dimensions mentioned in axisData.
void CalculateReducedOutputTensoInfo(const armnn::TensorInfo& inputTensorInfo, const armnn::TensorInfo& axisTensorInfo,