diff options
author | narpra01 <narumol.prangnawarat@arm.com> | 2018-11-18 20:17:48 +0000 |
---|---|---|
committer | narpra01 <narumol.prangnawarat@arm.com> | 2018-11-18 20:17:48 +0000 |
commit | f176d5af107b8797d9eb74d1699a4e405e4a9a83 (patch) | |
tree | 33d79a3d55f4d6995e85c02338cc8f22a0dcf4c7 /1.0/HalPolicy.cpp | |
parent | c743412b714a42d2e0ccbcae49698a602a6f3d94 (diff) | |
download | android-nn-driver-f176d5af107b8797d9eb74d1699a4e405e4a9a83.tar.gz |
IVGCVSW-2127- Update HAL Policy for mergerbranches/android-nn-driver_18_11
* Remove permutation when concat axis is inner most
* Add additional parameter to IsMergerSupported as changed in armnn
!armnn:151
Change-Id: Ie214c9573f242d8f04d58fc61621ad3831991d9a
Diffstat (limited to '1.0/HalPolicy.cpp')
-rw-r--r-- | 1.0/HalPolicy.cpp | 37 |
1 files changed, 23 insertions, 14 deletions
diff --git a/1.0/HalPolicy.cpp b/1.0/HalPolicy.cpp index d0bd95bb..719d1a24 100644 --- a/1.0/HalPolicy.cpp +++ b/1.0/HalPolicy.cpp @@ -241,17 +241,22 @@ bool HalPolicy::ConvertConcatenation(const Operation& operation, const Model& mo } else if (tensorDimensionsAdded == 2) { - outputShape = armnn::TensorShape({1, 1, outputShape[0], outputShape[1]}); + outputShape = armnn::TensorShape({1, 1, outputShape[0]}); } } - // Get the pair of permutations required for the concatenation + // Check if permutations is required and get the pair of permutations required for the concatenation. + // Permutation is required when the concat dimension is 2 for a 4D tensor or 1 for a 3D tensor. std::pair<armnn::PermutationVector, armnn::PermutationVector> permutationPair = std::make_pair(IdentityPermutation4D, IdentityPermutation4D); - CreatePermutationParameters(inputShapes[0].GetNumDimensions(), concatDim, permutationPair); + bool needPermute = CreateConcatPermutationParameters(inputShapes[0].GetNumDimensions(), concatDim, permutationPair); + + if (needPermute) + { + outputShape = armnnUtils::Permuted(outputShape, permutationPair.first); + } - outputShape = armnnUtils::Permuted(outputShape, permutationPair.first); outputInfo.SetShape(outputShape); // this is no-op for identity swizzles, otherwise it replaces both @@ -260,10 +265,11 @@ bool HalPolicy::ConvertConcatenation(const Operation& operation, const Model& mo // Create an armnn merger layer descriptor - this will also perform validation on the input shapes armnn::OriginsDescriptor mergerDescriptor; + try { - // The merger descriptor is always created across the only supported concat - // dimension, which is 0 or 1 + // The merger descriptor is always created across the only supported concat dimension + // which is 0, 1 or 3 for a 4-D tensor, or 0 or 2 for a 3-D tensor. mergerDescriptor = armnn::CreateMergerDescriptorForConcatenation( inputShapes.begin(), inputShapes.end(), concatDim); @@ -274,7 +280,7 @@ bool HalPolicy::ConvertConcatenation(const Operation& operation, const Model& mo } // Validate the output shape is correct given the input shapes based on the - // only valid concat dimension which is 0 or 1 + // only valid concat dimension which is 0, 1 or 3 for a 4-D tensor, or 0 or 2 for a 3-D tensor. if (!ValidateConcatOutputShape(inputShapes, outputShape, concatDim)) { return Fail("%s: Error validating the output shape for concat", __func__); @@ -287,6 +293,7 @@ bool HalPolicy::ConvertConcatenation(const Operation& operation, const Model& mo armnn::IsMergerSupported, data.m_Compute, inputTensorInfos, + outputInfo, mergerDescriptor)) { return false; @@ -305,11 +312,14 @@ bool HalPolicy::ConvertConcatenation(const Operation& operation, const Model& mo inputHandles[static_cast<unsigned int>(i)].Connect(layer->GetInputSlot(i)); } - // Add permutation layer and connect the output to it, the permutation becomes the output layer - armnn::IConnectableLayer& deswizzleLayer = AddPermuteLayer(*data.m_Network, - layer->GetOutputSlot(0), - permutationPair.second); - layer = &deswizzleLayer; + if (needPermute) + { + // Add permutation layer and connect the output to it, the permutation becomes the output layer + armnn::IConnectableLayer& deswizzleLayer = AddPermuteLayer(*data.m_Network, + layer->GetOutputSlot(0), + permutationPair.second); + layer = &deswizzleLayer; + } if (inputsHaveBeenReshaped) { @@ -323,8 +333,7 @@ bool HalPolicy::ConvertConcatenation(const Operation& operation, const Model& mo } else if (tensorDimensionsAdded == 2) { - afterConcatInfo.SetShape(armnn::TensorShape({ afterConcatInfo.GetShape()[2], - afterConcatInfo.GetShape()[3] })); + afterConcatInfo.SetShape(armnn::TensorShape({ afterConcatInfo.GetShape()[2] })); } layer = &AddReshapeLayer( |