aboutsummaryrefslogtreecommitdiff
path: root/ConversionUtils.hpp
diff options
context:
space:
mode:
authornarpra01 <narumol.prangnawarat@arm.com>2018-11-18 20:17:48 +0000
committernarpra01 <narumol.prangnawarat@arm.com>2018-11-18 20:17:48 +0000
commitf176d5af107b8797d9eb74d1699a4e405e4a9a83 (patch)
tree33d79a3d55f4d6995e85c02338cc8f22a0dcf4c7 /ConversionUtils.hpp
parentc743412b714a42d2e0ccbcae49698a602a6f3d94 (diff)
downloadandroid-nn-driver-f176d5af107b8797d9eb74d1699a4e405e4a9a83.tar.gz
IVGCVSW-2127- Update HAL Policy for mergerbranches/android-nn-driver_18_11
* Remove permutation when concat axis is inner most * Add additional parameter to IsMergerSupported as changed in armnn !armnn:151 Change-Id: Ie214c9573f242d8f04d58fc61621ad3831991d9a
Diffstat (limited to 'ConversionUtils.hpp')
-rw-r--r--ConversionUtils.hpp51
1 files changed, 15 insertions, 36 deletions
diff --git a/ConversionUtils.hpp b/ConversionUtils.hpp
index 68ce09d8..c86ad93c 100644
--- a/ConversionUtils.hpp
+++ b/ConversionUtils.hpp
@@ -390,50 +390,29 @@ void SwizzleInputs(armnn::INetwork& network,
}
}
-void CreatePermutationParameters(const unsigned int numberOfDimensions,
- int32_t & concatDimension,
- std::pair<armnn::PermutationVector, armnn::PermutationVector> & permutationPair)
+bool CreateConcatPermutationParameters(const unsigned int numberOfDimensions,
+ int32_t & concatDimension,
+ std::pair<armnn::PermutationVector, armnn::PermutationVector> & permutationPair)
{
+ bool needPermute = false;
BOOST_ASSERT(numberOfDimensions >= 3);
// ArmNN uses Compute Library subtensors to perform concatenation
- // This only works when concatenating along dimension 0 or 1 for a 4-D tensor,
- // or along dimension 0 for a 3-D tensor.
- if (numberOfDimensions == 4)
+ // This only works when concatenating along dimension 0, 1 or 3 for a 4-D tensor,
+ // or along dimension 0 or 2 for a 3-D tensor.
+ if (numberOfDimensions == 4 && concatDimension == 2)
{
- if (concatDimension == 3)
- {
- concatDimension = 1;
- permutationPair = std::make_pair(NHWCToArmNN, ArmNNToNHWC);
- }
- else if (concatDimension == 2)
- {
- concatDimension = 1;
- permutationPair = std::make_pair(SwapDim1And2, SwapDim1And2);
- }
- else
- {
- permutationPair = std::make_pair(IdentityPermutation4D, IdentityPermutation4D);
- }
-
+ concatDimension = 1;
+ permutationPair = std::make_pair(SwapDim1And2, SwapDim1And2);
+ needPermute = true;
}
- else if (numberOfDimensions == 3)
+ else if (numberOfDimensions == 3 && concatDimension == 1)
{
- if (concatDimension == 2)
- {
- concatDimension = 0;
- permutationPair = std::make_pair(RotateTensorRight, RotateTensorLeft);
- }
- else if (concatDimension == 1)
- {
- concatDimension = 0;
- permutationPair = std::make_pair(RotateTensorLeft, RotateTensorRight);
- }
- else
- {
- permutationPair = std::make_pair(IdentityPermutation3D, IdentityPermutation3D);
- }
+ concatDimension = 0;
+ permutationPair = std::make_pair(RotateTensorLeft, RotateTensorRight);
+ needPermute = true;
}
+ return needPermute;
}
} // anonymous namespace