aboutsummaryrefslogtreecommitdiff
path: root/ConversionUtils.hpp
diff options
context:
space:
mode:
Diffstat (limited to 'ConversionUtils.hpp')
-rw-r--r--ConversionUtils.hpp22
1 files changed, 13 insertions, 9 deletions
diff --git a/ConversionUtils.hpp b/ConversionUtils.hpp
index 3b01b40f..ebfc43b7 100644
--- a/ConversionUtils.hpp
+++ b/ConversionUtils.hpp
@@ -579,20 +579,21 @@ void SwizzleInputs(armnn::INetwork& network,
}
}
-bool CheckReshapeSupported(ConversionData& data,
- std::vector<LayerInputHandle>& inputs,
- std::vector<armnn::TensorShape>& inputShapes,
- const armnn::PermutationVector& mapping,
- const armnn::TensorInfo& outputInfo)
+bool TransposeInputTensors(ConversionData& data,
+ std::vector<LayerInputHandle>& inputs,
+ std::vector<armnn::TensorShape>& inputShapes,
+ const armnn::PermutationVector& mapping)
{
if (!mapping.IsEqual(IdentityPermutation4D))
{
+ armnn::TensorInfo outputTransposeInfo;
size_t nInputs = inputs.size();
for (size_t i=0; i<nInputs; ++i)
{
// check permute layer
armnn::TransposeDescriptor transposeDesc;
transposeDesc.m_DimMappings = mapping;
+ outputTransposeInfo = armnnUtils::TransposeTensorShape(inputs[i].GetTensorInfo(), mapping);
bool isSupported = false;
FORWARD_LAYER_SUPPORT_FUNC(__func__,
@@ -600,7 +601,7 @@ bool CheckReshapeSupported(ConversionData& data,
data.m_Backends,
isSupported,
inputs[i].GetTensorInfo(),
- outputInfo,
+ outputTransposeInfo,
transposeDesc);
if (!isSupported)
{
@@ -1985,7 +1986,7 @@ bool ConvertConcatenation(const HalOperation& operation, const HalModel& model,
// this is no-op for identity swizzles, otherwise it replaces both
// the handles and shapes with the swizzled layer output handles and shapes
- if (!CheckReshapeSupported(data, inputHandles, inputShapes, permutationPair.first, outputInfo))
+ if (!TransposeInputTensors(data, inputHandles, inputShapes, permutationPair.first))
{
return false;
}
@@ -2046,14 +2047,17 @@ bool ConvertConcatenation(const HalOperation& operation, const HalModel& model,
{
armnn::TransposeDescriptor transposeDesc;
transposeDesc.m_DimMappings = permutationPair.second;
+ armnn::TensorInfo inputTransposeInfo = layer->GetOutputSlot(0).GetTensorInfo();
+ armnn::TensorInfo outputTransposeInfo = armnnUtils::TransposeTensorShape(inputTransposeInfo,
+ permutationPair.second);
bool isSupported = false;
FORWARD_LAYER_SUPPORT_FUNC(__func__,
IsTransposeSupported,
data.m_Backends,
isSupported,
- layer->GetOutputSlot(0).GetTensorInfo(),
- outputInfo,
+ inputTransposeInfo,
+ outputTransposeInfo,
transposeDesc);
if (!isSupported)
{