aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authornarpra01 <narumol.prangnawarat@arm.com>2018-11-18 20:17:48 +0000
committernarpra01 <narumol.prangnawarat@arm.com>2018-11-18 20:17:48 +0000
commitf176d5af107b8797d9eb74d1699a4e405e4a9a83 (patch)
tree33d79a3d55f4d6995e85c02338cc8f22a0dcf4c7
parentc743412b714a42d2e0ccbcae49698a602a6f3d94 (diff)
downloadandroid-nn-driver-branches/android-nn-driver_18_11.tar.gz
IVGCVSW-2127- Update HAL Policy for mergerbranches/android-nn-driver_18_11
* Remove permutation when concat axis is inner most * Add additional parameter to IsMergerSupported as changed in armnn !armnn:151 Change-Id: Ie214c9573f242d8f04d58fc61621ad3831991d9a
-rw-r--r--1.0/HalPolicy.cpp37
-rw-r--r--ConversionUtils.hpp51
2 files changed, 38 insertions, 50 deletions
diff --git a/1.0/HalPolicy.cpp b/1.0/HalPolicy.cpp
index d0bd95bb..719d1a24 100644
--- a/1.0/HalPolicy.cpp
+++ b/1.0/HalPolicy.cpp
@@ -241,17 +241,22 @@ bool HalPolicy::ConvertConcatenation(const Operation& operation, const Model& mo
}
else if (tensorDimensionsAdded == 2)
{
- outputShape = armnn::TensorShape({1, 1, outputShape[0], outputShape[1]});
+ outputShape = armnn::TensorShape({1, 1, outputShape[0]});
}
}
- // Get the pair of permutations required for the concatenation
+ // Check if permutations is required and get the pair of permutations required for the concatenation.
+ // Permutation is required when the concat dimension is 2 for a 4D tensor or 1 for a 3D tensor.
std::pair<armnn::PermutationVector, armnn::PermutationVector> permutationPair =
std::make_pair(IdentityPermutation4D, IdentityPermutation4D);
- CreatePermutationParameters(inputShapes[0].GetNumDimensions(), concatDim, permutationPair);
+ bool needPermute = CreateConcatPermutationParameters(inputShapes[0].GetNumDimensions(), concatDim, permutationPair);
+
+ if (needPermute)
+ {
+ outputShape = armnnUtils::Permuted(outputShape, permutationPair.first);
+ }
- outputShape = armnnUtils::Permuted(outputShape, permutationPair.first);
outputInfo.SetShape(outputShape);
// this is no-op for identity swizzles, otherwise it replaces both
@@ -260,10 +265,11 @@ bool HalPolicy::ConvertConcatenation(const Operation& operation, const Model& mo
// Create an armnn merger layer descriptor - this will also perform validation on the input shapes
armnn::OriginsDescriptor mergerDescriptor;
+
try
{
- // The merger descriptor is always created across the only supported concat
- // dimension, which is 0 or 1
+ // The merger descriptor is always created across the only supported concat dimension
+ // which is 0, 1 or 3 for a 4-D tensor, or 0 or 2 for a 3-D tensor.
mergerDescriptor =
armnn::CreateMergerDescriptorForConcatenation(
inputShapes.begin(), inputShapes.end(), concatDim);
@@ -274,7 +280,7 @@ bool HalPolicy::ConvertConcatenation(const Operation& operation, const Model& mo
}
// Validate the output shape is correct given the input shapes based on the
- // only valid concat dimension which is 0 or 1
+ // only valid concat dimension which is 0, 1 or 3 for a 4-D tensor, or 0 or 2 for a 3-D tensor.
if (!ValidateConcatOutputShape(inputShapes, outputShape, concatDim))
{
return Fail("%s: Error validating the output shape for concat", __func__);
@@ -287,6 +293,7 @@ bool HalPolicy::ConvertConcatenation(const Operation& operation, const Model& mo
armnn::IsMergerSupported,
data.m_Compute,
inputTensorInfos,
+ outputInfo,
mergerDescriptor))
{
return false;
@@ -305,11 +312,14 @@ bool HalPolicy::ConvertConcatenation(const Operation& operation, const Model& mo
inputHandles[static_cast<unsigned int>(i)].Connect(layer->GetInputSlot(i));
}
- // Add permutation layer and connect the output to it, the permutation becomes the output layer
- armnn::IConnectableLayer& deswizzleLayer = AddPermuteLayer(*data.m_Network,
- layer->GetOutputSlot(0),
- permutationPair.second);
- layer = &deswizzleLayer;
+ if (needPermute)
+ {
+ // Add permutation layer and connect the output to it, the permutation becomes the output layer
+ armnn::IConnectableLayer& deswizzleLayer = AddPermuteLayer(*data.m_Network,
+ layer->GetOutputSlot(0),
+ permutationPair.second);
+ layer = &deswizzleLayer;
+ }
if (inputsHaveBeenReshaped)
{
@@ -323,8 +333,7 @@ bool HalPolicy::ConvertConcatenation(const Operation& operation, const Model& mo
}
else if (tensorDimensionsAdded == 2)
{
- afterConcatInfo.SetShape(armnn::TensorShape({ afterConcatInfo.GetShape()[2],
- afterConcatInfo.GetShape()[3] }));
+ afterConcatInfo.SetShape(armnn::TensorShape({ afterConcatInfo.GetShape()[2] }));
}
layer = &AddReshapeLayer(
diff --git a/ConversionUtils.hpp b/ConversionUtils.hpp
index 68ce09d8..c86ad93c 100644
--- a/ConversionUtils.hpp
+++ b/ConversionUtils.hpp
@@ -390,50 +390,29 @@ void SwizzleInputs(armnn::INetwork& network,
}
}
-void CreatePermutationParameters(const unsigned int numberOfDimensions,
- int32_t & concatDimension,
- std::pair<armnn::PermutationVector, armnn::PermutationVector> & permutationPair)
+bool CreateConcatPermutationParameters(const unsigned int numberOfDimensions,
+ int32_t & concatDimension,
+ std::pair<armnn::PermutationVector, armnn::PermutationVector> & permutationPair)
{
+ bool needPermute = false;
BOOST_ASSERT(numberOfDimensions >= 3);
// ArmNN uses Compute Library subtensors to perform concatenation
- // This only works when concatenating along dimension 0 or 1 for a 4-D tensor,
- // or along dimension 0 for a 3-D tensor.
- if (numberOfDimensions == 4)
+ // This only works when concatenating along dimension 0, 1 or 3 for a 4-D tensor,
+ // or along dimension 0 or 2 for a 3-D tensor.
+ if (numberOfDimensions == 4 && concatDimension == 2)
{
- if (concatDimension == 3)
- {
- concatDimension = 1;
- permutationPair = std::make_pair(NHWCToArmNN, ArmNNToNHWC);
- }
- else if (concatDimension == 2)
- {
- concatDimension = 1;
- permutationPair = std::make_pair(SwapDim1And2, SwapDim1And2);
- }
- else
- {
- permutationPair = std::make_pair(IdentityPermutation4D, IdentityPermutation4D);
- }
-
+ concatDimension = 1;
+ permutationPair = std::make_pair(SwapDim1And2, SwapDim1And2);
+ needPermute = true;
}
- else if (numberOfDimensions == 3)
+ else if (numberOfDimensions == 3 && concatDimension == 1)
{
- if (concatDimension == 2)
- {
- concatDimension = 0;
- permutationPair = std::make_pair(RotateTensorRight, RotateTensorLeft);
- }
- else if (concatDimension == 1)
- {
- concatDimension = 0;
- permutationPair = std::make_pair(RotateTensorLeft, RotateTensorRight);
- }
- else
- {
- permutationPair = std::make_pair(IdentityPermutation3D, IdentityPermutation3D);
- }
+ concatDimension = 0;
+ permutationPair = std::make_pair(RotateTensorLeft, RotateTensorRight);
+ needPermute = true;
}
+ return needPermute;
}
} // anonymous namespace