aboutsummaryrefslogtreecommitdiff
path: root/ConversionUtils.hpp
diff options
context:
space:
mode:
Diffstat (limited to 'ConversionUtils.hpp')
-rw-r--r--ConversionUtils.hpp51
1 files changed, 15 insertions, 36 deletions
diff --git a/ConversionUtils.hpp b/ConversionUtils.hpp
index 68ce09d8..c86ad93c 100644
--- a/ConversionUtils.hpp
+++ b/ConversionUtils.hpp
@@ -390,50 +390,29 @@ void SwizzleInputs(armnn::INetwork& network,
}
}
-void CreatePermutationParameters(const unsigned int numberOfDimensions,
- int32_t & concatDimension,
- std::pair<armnn::PermutationVector, armnn::PermutationVector> & permutationPair)
+bool CreateConcatPermutationParameters(const unsigned int numberOfDimensions,
+ int32_t & concatDimension,
+ std::pair<armnn::PermutationVector, armnn::PermutationVector> & permutationPair)
{
+ bool needPermute = false;
BOOST_ASSERT(numberOfDimensions >= 3);
// ArmNN uses Compute Library subtensors to perform concatenation
- // This only works when concatenating along dimension 0 or 1 for a 4-D tensor,
- // or along dimension 0 for a 3-D tensor.
- if (numberOfDimensions == 4)
+ // This only works when concatenating along dimension 0, 1 or 3 for a 4-D tensor,
+ // or along dimension 0 or 2 for a 3-D tensor.
+ if (numberOfDimensions == 4 && concatDimension == 2)
{
- if (concatDimension == 3)
- {
- concatDimension = 1;
- permutationPair = std::make_pair(NHWCToArmNN, ArmNNToNHWC);
- }
- else if (concatDimension == 2)
- {
- concatDimension = 1;
- permutationPair = std::make_pair(SwapDim1And2, SwapDim1And2);
- }
- else
- {
- permutationPair = std::make_pair(IdentityPermutation4D, IdentityPermutation4D);
- }
-
+ concatDimension = 1;
+ permutationPair = std::make_pair(SwapDim1And2, SwapDim1And2);
+ needPermute = true;
}
- else if (numberOfDimensions == 3)
+ else if (numberOfDimensions == 3 && concatDimension == 1)
{
- if (concatDimension == 2)
- {
- concatDimension = 0;
- permutationPair = std::make_pair(RotateTensorRight, RotateTensorLeft);
- }
- else if (concatDimension == 1)
- {
- concatDimension = 0;
- permutationPair = std::make_pair(RotateTensorLeft, RotateTensorRight);
- }
- else
- {
- permutationPair = std::make_pair(IdentityPermutation3D, IdentityPermutation3D);
- }
+ concatDimension = 0;
+ permutationPair = std::make_pair(RotateTensorLeft, RotateTensorRight);
+ needPermute = true;
}
+ return needPermute;
}
} // anonymous namespace