aboutsummaryrefslogtreecommitdiff
path: root/src/armnnOnnxParser/OnnxParser.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/armnnOnnxParser/OnnxParser.cpp')
-rw-r--r--src/armnnOnnxParser/OnnxParser.cpp23
1 files changed, 16 insertions, 7 deletions
diff --git a/src/armnnOnnxParser/OnnxParser.cpp b/src/armnnOnnxParser/OnnxParser.cpp
index 4eaf63653b..60bd962db7 100644
--- a/src/armnnOnnxParser/OnnxParser.cpp
+++ b/src/armnnOnnxParser/OnnxParser.cpp
@@ -1043,15 +1043,24 @@ void OnnxParserImpl::AddConvLayerWithDepthwiseConv(const onnx::NodeProto& node,
desc.m_BiasEnabled = convDesc.m_BiasEnabled;
armnn::IConnectableLayer* layer = m_Network->AddDepthwiseConvolution2dLayer(desc, node.name().c_str());
- std::vector<std::string> tensorIndexes= {node.input(0), node.input(1)};
-
- // weights come in as [O,1,H,W] from ONNX and need to be converted to ArmNNs dephtwise weights layout [1,H,W,O]
- armnn::PermutationVector perVec {3,0,1,2};
- auto weightTensor = CreateConstTensor(node.input(1), perVec);
+ std::string permuteStr = "permute_" + node.input(1);
+ std::vector<std::string> tensorIndexes= {node.input(0), permuteStr};
+ auto weightTensor = CreateConstTensor(node.input(1));
IConnectableLayer* weightsLayer = m_Network->AddConstantLayer(weightTensor.first);
+
+ // weights come in as [O,1,H,W] from ONNX and need to be converted to ArmNNs depthwise weights layout [1,H,W,O]
+ armnn::PermutationVector perVec {3, 0, 1, 2};
+ TensorInfo weightsPermuted = armnnUtils::Permuted(weightTensor.first.GetInfo(), perVec);
+
+ // Inserts NewLayer so layers don't need to be re-sorted.
+ IConnectableLayer* permuteLayer = m_Network->AddPermuteLayer(PermuteDescriptor(perVec),
+ "permute_layer");
+ permuteLayer->GetOutputSlot(0).SetTensorInfo(weightsPermuted);
+ permuteLayer->GetOutputSlot(0).Connect(layer->GetInputSlot(1u));
+
weightsLayer->GetOutputSlot(0).SetTensorInfo(weightTensor.first.GetInfo());
- weightsLayer->GetOutputSlot(0).Connect(layer->GetInputSlot(1u));
+ weightsLayer->GetOutputSlot(0).Connect(permuteLayer->GetInputSlot(0u));
if (node.input_size() == 3)
{
@@ -1076,7 +1085,7 @@ void OnnxParserImpl::AddConvLayerWithDepthwiseConv(const onnx::NodeProto& node,
auto outputInfo = ComputeOutputInfo({ node.output(0) }, layer,
{ m_TensorsInfo[node.input(0)].m_info->GetShape(),
- weightTensor.first.GetInfo().GetShape() });
+ weightsPermuted.GetShape() });
layer->GetOutputSlot(0).SetTensorInfo(outputInfo[0]);