From f176d5af107b8797d9eb74d1699a4e405e4a9a83 Mon Sep 17 00:00:00 2001 From: narpra01 Date: Sun, 18 Nov 2018 20:17:48 +0000 Subject: IVGCVSW-2127- Update HAL Policy for merger * Remove permutation when concat axis is inner most * Add additional parameter to IsMergerSupported as changed in armnn !armnn:151 Change-Id: Ie214c9573f242d8f04d58fc61621ad3831991d9a --- 1.0/HalPolicy.cpp | 37 +++++++++++++++++++++++-------------- 1 file changed, 23 insertions(+), 14 deletions(-) (limited to '1.0') diff --git a/1.0/HalPolicy.cpp b/1.0/HalPolicy.cpp index d0bd95bb..719d1a24 100644 --- a/1.0/HalPolicy.cpp +++ b/1.0/HalPolicy.cpp @@ -241,17 +241,22 @@ bool HalPolicy::ConvertConcatenation(const Operation& operation, const Model& mo } else if (tensorDimensionsAdded == 2) { - outputShape = armnn::TensorShape({1, 1, outputShape[0], outputShape[1]}); + outputShape = armnn::TensorShape({1, 1, outputShape[0]}); } } - // Get the pair of permutations required for the concatenation + // Check if permutations is required and get the pair of permutations required for the concatenation. + // Permutation is required when the concat dimension is 2 for a 4D tensor or 1 for a 3D tensor. std::pair permutationPair = std::make_pair(IdentityPermutation4D, IdentityPermutation4D); - CreatePermutationParameters(inputShapes[0].GetNumDimensions(), concatDim, permutationPair); + bool needPermute = CreateConcatPermutationParameters(inputShapes[0].GetNumDimensions(), concatDim, permutationPair); + + if (needPermute) + { + outputShape = armnnUtils::Permuted(outputShape, permutationPair.first); + } - outputShape = armnnUtils::Permuted(outputShape, permutationPair.first); outputInfo.SetShape(outputShape); // this is no-op for identity swizzles, otherwise it replaces both @@ -260,10 +265,11 @@ bool HalPolicy::ConvertConcatenation(const Operation& operation, const Model& mo // Create an armnn merger layer descriptor - this will also perform validation on the input shapes armnn::OriginsDescriptor mergerDescriptor; + try { - // The merger descriptor is always created across the only supported concat - // dimension, which is 0 or 1 + // The merger descriptor is always created across the only supported concat dimension + // which is 0, 1 or 3 for a 4-D tensor, or 0 or 2 for a 3-D tensor. mergerDescriptor = armnn::CreateMergerDescriptorForConcatenation( inputShapes.begin(), inputShapes.end(), concatDim); @@ -274,7 +280,7 @@ bool HalPolicy::ConvertConcatenation(const Operation& operation, const Model& mo } // Validate the output shape is correct given the input shapes based on the - // only valid concat dimension which is 0 or 1 + // only valid concat dimension which is 0, 1 or 3 for a 4-D tensor, or 0 or 2 for a 3-D tensor. if (!ValidateConcatOutputShape(inputShapes, outputShape, concatDim)) { return Fail("%s: Error validating the output shape for concat", __func__); @@ -287,6 +293,7 @@ bool HalPolicy::ConvertConcatenation(const Operation& operation, const Model& mo armnn::IsMergerSupported, data.m_Compute, inputTensorInfos, + outputInfo, mergerDescriptor)) { return false; @@ -305,11 +312,14 @@ bool HalPolicy::ConvertConcatenation(const Operation& operation, const Model& mo inputHandles[static_cast(i)].Connect(layer->GetInputSlot(i)); } - // Add permutation layer and connect the output to it, the permutation becomes the output layer - armnn::IConnectableLayer& deswizzleLayer = AddPermuteLayer(*data.m_Network, - layer->GetOutputSlot(0), - permutationPair.second); - layer = &deswizzleLayer; + if (needPermute) + { + // Add permutation layer and connect the output to it, the permutation becomes the output layer + armnn::IConnectableLayer& deswizzleLayer = AddPermuteLayer(*data.m_Network, + layer->GetOutputSlot(0), + permutationPair.second); + layer = &deswizzleLayer; + } if (inputsHaveBeenReshaped) { @@ -323,8 +333,7 @@ bool HalPolicy::ConvertConcatenation(const Operation& operation, const Model& mo } else if (tensorDimensionsAdded == 2) { - afterConcatInfo.SetShape(armnn::TensorShape({ afterConcatInfo.GetShape()[2], - afterConcatInfo.GetShape()[3] })); + afterConcatInfo.SetShape(armnn::TensorShape({ afterConcatInfo.GetShape()[2] })); } layer = &AddReshapeLayer( -- cgit v1.2.1