From f176d5af107b8797d9eb74d1699a4e405e4a9a83 Mon Sep 17 00:00:00 2001 From: narpra01 Date: Sun, 18 Nov 2018 20:17:48 +0000 Subject: IVGCVSW-2127- Update HAL Policy for merger * Remove permutation when concat axis is inner most * Add additional parameter to IsMergerSupported as changed in armnn !armnn:151 Change-Id: Ie214c9573f242d8f04d58fc61621ad3831991d9a --- 1.0/HalPolicy.cpp | 37 +++++++++++++++++++++++-------------- ConversionUtils.hpp | 51 +++++++++++++++------------------------------------ 2 files changed, 38 insertions(+), 50 deletions(-) diff --git a/1.0/HalPolicy.cpp b/1.0/HalPolicy.cpp index d0bd95bb..719d1a24 100644 --- a/1.0/HalPolicy.cpp +++ b/1.0/HalPolicy.cpp @@ -241,17 +241,22 @@ bool HalPolicy::ConvertConcatenation(const Operation& operation, const Model& mo } else if (tensorDimensionsAdded == 2) { - outputShape = armnn::TensorShape({1, 1, outputShape[0], outputShape[1]}); + outputShape = armnn::TensorShape({1, 1, outputShape[0]}); } } - // Get the pair of permutations required for the concatenation + // Check if permutations is required and get the pair of permutations required for the concatenation. + // Permutation is required when the concat dimension is 2 for a 4D tensor or 1 for a 3D tensor. std::pair permutationPair = std::make_pair(IdentityPermutation4D, IdentityPermutation4D); - CreatePermutationParameters(inputShapes[0].GetNumDimensions(), concatDim, permutationPair); + bool needPermute = CreateConcatPermutationParameters(inputShapes[0].GetNumDimensions(), concatDim, permutationPair); + + if (needPermute) + { + outputShape = armnnUtils::Permuted(outputShape, permutationPair.first); + } - outputShape = armnnUtils::Permuted(outputShape, permutationPair.first); outputInfo.SetShape(outputShape); // this is no-op for identity swizzles, otherwise it replaces both @@ -260,10 +265,11 @@ bool HalPolicy::ConvertConcatenation(const Operation& operation, const Model& mo // Create an armnn merger layer descriptor - this will also perform validation on the input shapes armnn::OriginsDescriptor mergerDescriptor; + try { - // The merger descriptor is always created across the only supported concat - // dimension, which is 0 or 1 + // The merger descriptor is always created across the only supported concat dimension + // which is 0, 1 or 3 for a 4-D tensor, or 0 or 2 for a 3-D tensor. mergerDescriptor = armnn::CreateMergerDescriptorForConcatenation( inputShapes.begin(), inputShapes.end(), concatDim); @@ -274,7 +280,7 @@ bool HalPolicy::ConvertConcatenation(const Operation& operation, const Model& mo } // Validate the output shape is correct given the input shapes based on the - // only valid concat dimension which is 0 or 1 + // only valid concat dimension which is 0, 1 or 3 for a 4-D tensor, or 0 or 2 for a 3-D tensor. if (!ValidateConcatOutputShape(inputShapes, outputShape, concatDim)) { return Fail("%s: Error validating the output shape for concat", __func__); @@ -287,6 +293,7 @@ bool HalPolicy::ConvertConcatenation(const Operation& operation, const Model& mo armnn::IsMergerSupported, data.m_Compute, inputTensorInfos, + outputInfo, mergerDescriptor)) { return false; @@ -305,11 +312,14 @@ bool HalPolicy::ConvertConcatenation(const Operation& operation, const Model& mo inputHandles[static_cast(i)].Connect(layer->GetInputSlot(i)); } - // Add permutation layer and connect the output to it, the permutation becomes the output layer - armnn::IConnectableLayer& deswizzleLayer = AddPermuteLayer(*data.m_Network, - layer->GetOutputSlot(0), - permutationPair.second); - layer = &deswizzleLayer; + if (needPermute) + { + // Add permutation layer and connect the output to it, the permutation becomes the output layer + armnn::IConnectableLayer& deswizzleLayer = AddPermuteLayer(*data.m_Network, + layer->GetOutputSlot(0), + permutationPair.second); + layer = &deswizzleLayer; + } if (inputsHaveBeenReshaped) { @@ -323,8 +333,7 @@ bool HalPolicy::ConvertConcatenation(const Operation& operation, const Model& mo } else if (tensorDimensionsAdded == 2) { - afterConcatInfo.SetShape(armnn::TensorShape({ afterConcatInfo.GetShape()[2], - afterConcatInfo.GetShape()[3] })); + afterConcatInfo.SetShape(armnn::TensorShape({ afterConcatInfo.GetShape()[2] })); } layer = &AddReshapeLayer( diff --git a/ConversionUtils.hpp b/ConversionUtils.hpp index 68ce09d8..c86ad93c 100644 --- a/ConversionUtils.hpp +++ b/ConversionUtils.hpp @@ -390,50 +390,29 @@ void SwizzleInputs(armnn::INetwork& network, } } -void CreatePermutationParameters(const unsigned int numberOfDimensions, - int32_t & concatDimension, - std::pair & permutationPair) +bool CreateConcatPermutationParameters(const unsigned int numberOfDimensions, + int32_t & concatDimension, + std::pair & permutationPair) { + bool needPermute = false; BOOST_ASSERT(numberOfDimensions >= 3); // ArmNN uses Compute Library subtensors to perform concatenation - // This only works when concatenating along dimension 0 or 1 for a 4-D tensor, - // or along dimension 0 for a 3-D tensor. - if (numberOfDimensions == 4) + // This only works when concatenating along dimension 0, 1 or 3 for a 4-D tensor, + // or along dimension 0 or 2 for a 3-D tensor. + if (numberOfDimensions == 4 && concatDimension == 2) { - if (concatDimension == 3) - { - concatDimension = 1; - permutationPair = std::make_pair(NHWCToArmNN, ArmNNToNHWC); - } - else if (concatDimension == 2) - { - concatDimension = 1; - permutationPair = std::make_pair(SwapDim1And2, SwapDim1And2); - } - else - { - permutationPair = std::make_pair(IdentityPermutation4D, IdentityPermutation4D); - } - + concatDimension = 1; + permutationPair = std::make_pair(SwapDim1And2, SwapDim1And2); + needPermute = true; } - else if (numberOfDimensions == 3) + else if (numberOfDimensions == 3 && concatDimension == 1) { - if (concatDimension == 2) - { - concatDimension = 0; - permutationPair = std::make_pair(RotateTensorRight, RotateTensorLeft); - } - else if (concatDimension == 1) - { - concatDimension = 0; - permutationPair = std::make_pair(RotateTensorLeft, RotateTensorRight); - } - else - { - permutationPair = std::make_pair(IdentityPermutation3D, IdentityPermutation3D); - } + concatDimension = 0; + permutationPair = std::make_pair(RotateTensorLeft, RotateTensorRight); + needPermute = true; } + return needPermute; } } // anonymous namespace -- cgit v1.2.1