From f176d5af107b8797d9eb74d1699a4e405e4a9a83 Mon Sep 17 00:00:00 2001 From: narpra01 Date: Sun, 18 Nov 2018 20:17:48 +0000 Subject: IVGCVSW-2127- Update HAL Policy for merger * Remove permutation when concat axis is inner most * Add additional parameter to IsMergerSupported as changed in armnn !armnn:151 Change-Id: Ie214c9573f242d8f04d58fc61621ad3831991d9a --- ConversionUtils.hpp | 51 +++++++++++++++------------------------------------ 1 file changed, 15 insertions(+), 36 deletions(-) (limited to 'ConversionUtils.hpp') diff --git a/ConversionUtils.hpp b/ConversionUtils.hpp index 68ce09d8..c86ad93c 100644 --- a/ConversionUtils.hpp +++ b/ConversionUtils.hpp @@ -390,50 +390,29 @@ void SwizzleInputs(armnn::INetwork& network, } } -void CreatePermutationParameters(const unsigned int numberOfDimensions, - int32_t & concatDimension, - std::pair & permutationPair) +bool CreateConcatPermutationParameters(const unsigned int numberOfDimensions, + int32_t & concatDimension, + std::pair & permutationPair) { + bool needPermute = false; BOOST_ASSERT(numberOfDimensions >= 3); // ArmNN uses Compute Library subtensors to perform concatenation - // This only works when concatenating along dimension 0 or 1 for a 4-D tensor, - // or along dimension 0 for a 3-D tensor. - if (numberOfDimensions == 4) + // This only works when concatenating along dimension 0, 1 or 3 for a 4-D tensor, + // or along dimension 0 or 2 for a 3-D tensor. + if (numberOfDimensions == 4 && concatDimension == 2) { - if (concatDimension == 3) - { - concatDimension = 1; - permutationPair = std::make_pair(NHWCToArmNN, ArmNNToNHWC); - } - else if (concatDimension == 2) - { - concatDimension = 1; - permutationPair = std::make_pair(SwapDim1And2, SwapDim1And2); - } - else - { - permutationPair = std::make_pair(IdentityPermutation4D, IdentityPermutation4D); - } - + concatDimension = 1; + permutationPair = std::make_pair(SwapDim1And2, SwapDim1And2); + needPermute = true; } - else if (numberOfDimensions == 3) + else if (numberOfDimensions == 3 && concatDimension == 1) { - if (concatDimension == 2) - { - concatDimension = 0; - permutationPair = std::make_pair(RotateTensorRight, RotateTensorLeft); - } - else if (concatDimension == 1) - { - concatDimension = 0; - permutationPair = std::make_pair(RotateTensorLeft, RotateTensorRight); - } - else - { - permutationPair = std::make_pair(IdentityPermutation3D, IdentityPermutation3D); - } + concatDimension = 0; + permutationPair = std::make_pair(RotateTensorLeft, RotateTensorRight); + needPermute = true; } + return needPermute; } } // anonymous namespace -- cgit v1.2.1