aboutsummaryrefslogtreecommitdiff
path: root/src/armnnOnnxParser/OnnxParser.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/armnnOnnxParser/OnnxParser.cpp')
-rw-r--r--src/armnnOnnxParser/OnnxParser.cpp28
1 files changed, 13 insertions, 15 deletions
diff --git a/src/armnnOnnxParser/OnnxParser.cpp b/src/armnnOnnxParser/OnnxParser.cpp
index 60bd962db7..63fb60382c 100644
--- a/src/armnnOnnxParser/OnnxParser.cpp
+++ b/src/armnnOnnxParser/OnnxParser.cpp
@@ -1762,11 +1762,16 @@ void OnnxParserImpl::ParseConv(const onnx::NodeProto& node)
}
}
- armnn::IConnectableLayer* layer;
+ node.input_size() == 3 ? desc.m_BiasEnabled = true : desc.m_BiasEnabled = false;
+ armnn::IConnectableLayer* layer = m_Network->AddConvolution2dLayer(desc, node.name().c_str());
std::vector<std::string> tensorIndexes= {node.input(0), node.input(1)};
auto weightTensor = CreateConstTensor(node.input(1));
+ IConnectableLayer* weightsLayer = m_Network->AddConstantLayer(weightTensor.first);
+ weightsLayer->GetOutputSlot(0).SetTensorInfo(weightTensor.first.GetInfo());
+ weightsLayer->GetOutputSlot(0).Connect(layer->GetInputSlot(1u));
+
if (node.input_size() == 3)
{
if(!m_TensorsInfo[node.input(2)].isConstant())
@@ -1777,22 +1782,15 @@ void OnnxParserImpl::ParseConv(const onnx::NodeProto& node)
CHECK_LOCATION().AsString()));
}
desc.m_BiasEnabled = true;
- tensorIndexes.emplace_back(node.input(2));
auto biasTensor = CreateConstTensor(node.input(2));
- ARMNN_NO_DEPRECATE_WARN_BEGIN
- layer = m_Network->AddConvolution2dLayer(desc,
- weightTensor.first,
- Optional<ConstTensor>(biasTensor.first),
- node.name().c_str());
- }
- else
- {
- layer = m_Network->AddConvolution2dLayer(desc,
- weightTensor.first,
- EmptyOptional(),
- node.name().c_str());
- ARMNN_NO_DEPRECATE_WARN_END
+
+ IConnectableLayer* biasLayer = m_Network->AddConstantLayer(biasTensor.first);
+ biasLayer->GetOutputSlot(0).SetTensorInfo(biasTensor.first.GetInfo());
+ biasLayer->GetOutputSlot(0).Connect(layer->GetInputSlot(2u));
+
+ tensorIndexes.emplace_back(node.input(2));
}
+
ARMNN_ASSERT(layer != nullptr);
auto outputInfo = ComputeOutputInfo({ node.output(0) }, layer,