diff options
author | narpra01 <narumol.prangnawarat@arm.com> | 2018-11-18 20:17:48 +0000 |
---|---|---|
committer | narpra01 <narumol.prangnawarat@arm.com> | 2018-11-18 20:17:48 +0000 |
commit | f176d5af107b8797d9eb74d1699a4e405e4a9a83 (patch) | |
tree | 33d79a3d55f4d6995e85c02338cc8f22a0dcf4c7 /ConversionUtils.hpp | |
parent | c743412b714a42d2e0ccbcae49698a602a6f3d94 (diff) | |
download | android-nn-driver-f176d5af107b8797d9eb74d1699a4e405e4a9a83.tar.gz |
IVGCVSW-2127- Update HAL Policy for mergerbranches/android-nn-driver_18_11
* Remove permutation when concat axis is inner most
* Add additional parameter to IsMergerSupported as changed in armnn
!armnn:151
Change-Id: Ie214c9573f242d8f04d58fc61621ad3831991d9a
Diffstat (limited to 'ConversionUtils.hpp')
-rw-r--r-- | ConversionUtils.hpp | 51 |
1 files changed, 15 insertions, 36 deletions
diff --git a/ConversionUtils.hpp b/ConversionUtils.hpp index 68ce09d8..c86ad93c 100644 --- a/ConversionUtils.hpp +++ b/ConversionUtils.hpp @@ -390,50 +390,29 @@ void SwizzleInputs(armnn::INetwork& network, } } -void CreatePermutationParameters(const unsigned int numberOfDimensions, - int32_t & concatDimension, - std::pair<armnn::PermutationVector, armnn::PermutationVector> & permutationPair) +bool CreateConcatPermutationParameters(const unsigned int numberOfDimensions, + int32_t & concatDimension, + std::pair<armnn::PermutationVector, armnn::PermutationVector> & permutationPair) { + bool needPermute = false; BOOST_ASSERT(numberOfDimensions >= 3); // ArmNN uses Compute Library subtensors to perform concatenation - // This only works when concatenating along dimension 0 or 1 for a 4-D tensor, - // or along dimension 0 for a 3-D tensor. - if (numberOfDimensions == 4) + // This only works when concatenating along dimension 0, 1 or 3 for a 4-D tensor, + // or along dimension 0 or 2 for a 3-D tensor. + if (numberOfDimensions == 4 && concatDimension == 2) { - if (concatDimension == 3) - { - concatDimension = 1; - permutationPair = std::make_pair(NHWCToArmNN, ArmNNToNHWC); - } - else if (concatDimension == 2) - { - concatDimension = 1; - permutationPair = std::make_pair(SwapDim1And2, SwapDim1And2); - } - else - { - permutationPair = std::make_pair(IdentityPermutation4D, IdentityPermutation4D); - } - + concatDimension = 1; + permutationPair = std::make_pair(SwapDim1And2, SwapDim1And2); + needPermute = true; } - else if (numberOfDimensions == 3) + else if (numberOfDimensions == 3 && concatDimension == 1) { - if (concatDimension == 2) - { - concatDimension = 0; - permutationPair = std::make_pair(RotateTensorRight, RotateTensorLeft); - } - else if (concatDimension == 1) - { - concatDimension = 0; - permutationPair = std::make_pair(RotateTensorLeft, RotateTensorRight); - } - else - { - permutationPair = std::make_pair(IdentityPermutation3D, IdentityPermutation3D); - } + concatDimension = 0; + permutationPair = std::make_pair(RotateTensorLeft, RotateTensorRight); + needPermute = true; } + return needPermute; } } // anonymous namespace |