diff options
Diffstat (limited to 'ConversionUtils.hpp')
-rw-r--r-- | ConversionUtils.hpp | 16 |
1 files changed, 10 insertions, 6 deletions
diff --git a/ConversionUtils.hpp b/ConversionUtils.hpp index 5837d3df..f139383e 100644 --- a/ConversionUtils.hpp +++ b/ConversionUtils.hpp @@ -489,6 +489,7 @@ void SanitizeBiasQuantizationScale(armnn::TensorInfo& biasInfo, // 4D Tensor Permutations const armnn::PermutationVector IdentityPermutation4D({ 0U, 1U, 2U, 3U }); +const armnn::PermutationVector IdentityPermutation3D({ 0U, 1U, 2U }); const armnn::PermutationVector SwapDim1And2({ 0U, 2U, 1U, 3U }); // 3D Permutation Vectors @@ -588,7 +589,8 @@ bool TransposeInputTensors(ConversionData& data, std::vector<armnn::TensorShape>& inputShapes, const armnn::PermutationVector& mapping) { - if (!mapping.IsEqual(IdentityPermutation4D)) + // If we have a IdentityPermutation4D or IdentityPermutation3D then we are not permuting + if (!mapping.IsEqual(IdentityPermutation4D) && !mapping.IsEqual(IdentityPermutation3D)) { armnn::TensorInfo outputTransposeInfo; size_t nInputs = inputs.size(); @@ -641,6 +643,12 @@ bool CreateConcatPermutationParameters(const unsigned int numberOfDimensions, permutationPair = std::make_pair(RotateTensorLeft, RotateTensorRight); needPermute = true; } + // If the tensor is 3-D and the concat dimension is 2 then we don't need to permute but we do need to change the + // permutation identity to only have 3 dimensions + else if (numberOfDimensions == 3 && concatDimension == 2) + { + permutationPair = std::make_pair(IdentityPermutation3D, IdentityPermutation3D); + } return needPermute; } @@ -2134,7 +2142,6 @@ bool ConvertConcatenation(const HalOperation& operation, const HalModel& model, // Permutation is required when the concat dimension is 2 for a 4D tensor or 1 for a 3D tensor. std::pair<armnn::PermutationVector, armnn::PermutationVector> permutationPair = std::make_pair(IdentityPermutation4D, IdentityPermutation4D); - bool needPermute = CreateConcatPermutationParameters(inputShapes[0].GetNumDimensions(), concatDim, permutationPair); @@ -2232,7 +2239,6 @@ bool ConvertConcatenation(const HalOperation& operation, const HalModel& model, inputTransposeInfo, outputTransposeInfo, transposeDesc); - if (!isSupported) { return false; @@ -2255,7 +2261,7 @@ bool ConvertConcatenation(const HalOperation& operation, const HalModel& model, if (isDynamicTensor) { // Infer the output shapes of concat if outputs are type 1 dynamic - layer->GetOutputSlot(0).IsTensorInfoSet(); + ARMNN_ASSERT(layer->GetOutputSlot(0).IsTensorInfoSet()); if (!ValidateConcatOutputShape(inputShapes, layer->GetOutputSlot(0).GetTensorInfo().GetShape(), concatDim)) @@ -2266,7 +2272,6 @@ bool ConvertConcatenation(const HalOperation& operation, const HalModel& model, } armnn::TensorInfo afterConcatInfo = layer->GetOutputSlot(0).GetTensorInfo(); - // Undo the reshape knowing the amount of dimensions added if (tensorDimensionsAdded == 1) { @@ -2306,7 +2311,6 @@ bool ConvertConcatenation(const HalOperation& operation, const HalModel& model, { return false; } - layer = &AddReshapeLayer(*data.m_Network, layer->GetOutputSlot(0), afterConcatInfo); return SetupAndTrackLayerOutputSlot<HalPolicy>(operation, 0, |