diff options
-rw-r--r-- | ConversionUtils.hpp | 22 |
1 files changed, 13 insertions, 9 deletions
diff --git a/ConversionUtils.hpp b/ConversionUtils.hpp index 3b01b40f..ebfc43b7 100644 --- a/ConversionUtils.hpp +++ b/ConversionUtils.hpp @@ -579,20 +579,21 @@ void SwizzleInputs(armnn::INetwork& network, } } -bool CheckReshapeSupported(ConversionData& data, - std::vector<LayerInputHandle>& inputs, - std::vector<armnn::TensorShape>& inputShapes, - const armnn::PermutationVector& mapping, - const armnn::TensorInfo& outputInfo) +bool TransposeInputTensors(ConversionData& data, + std::vector<LayerInputHandle>& inputs, + std::vector<armnn::TensorShape>& inputShapes, + const armnn::PermutationVector& mapping) { if (!mapping.IsEqual(IdentityPermutation4D)) { + armnn::TensorInfo outputTransposeInfo; size_t nInputs = inputs.size(); for (size_t i=0; i<nInputs; ++i) { // check permute layer armnn::TransposeDescriptor transposeDesc; transposeDesc.m_DimMappings = mapping; + outputTransposeInfo = armnnUtils::TransposeTensorShape(inputs[i].GetTensorInfo(), mapping); bool isSupported = false; FORWARD_LAYER_SUPPORT_FUNC(__func__, @@ -600,7 +601,7 @@ bool CheckReshapeSupported(ConversionData& data, data.m_Backends, isSupported, inputs[i].GetTensorInfo(), - outputInfo, + outputTransposeInfo, transposeDesc); if (!isSupported) { @@ -1985,7 +1986,7 @@ bool ConvertConcatenation(const HalOperation& operation, const HalModel& model, // this is no-op for identity swizzles, otherwise it replaces both // the handles and shapes with the swizzled layer output handles and shapes - if (!CheckReshapeSupported(data, inputHandles, inputShapes, permutationPair.first, outputInfo)) + if (!TransposeInputTensors(data, inputHandles, inputShapes, permutationPair.first)) { return false; } @@ -2046,14 +2047,17 @@ bool ConvertConcatenation(const HalOperation& operation, const HalModel& model, { armnn::TransposeDescriptor transposeDesc; transposeDesc.m_DimMappings = permutationPair.second; + armnn::TensorInfo inputTransposeInfo = layer->GetOutputSlot(0).GetTensorInfo(); + armnn::TensorInfo outputTransposeInfo = armnnUtils::TransposeTensorShape(inputTransposeInfo, + permutationPair.second); bool isSupported = false; FORWARD_LAYER_SUPPORT_FUNC(__func__, IsTransposeSupported, data.m_Backends, isSupported, - layer->GetOutputSlot(0).GetTensorInfo(), - outputInfo, + inputTransposeInfo, + outputTransposeInfo, transposeDesc); if (!isSupported) { |