From bc3bb62c2d5b881ca7f0b3973a533134196fc802 Mon Sep 17 00:00:00 2001 From: Narumol Prangnawarat Date: Fri, 24 Sep 2021 16:08:34 +0100 Subject: IVGCVSW-6382 Add Concat operator support to ONNX parser Signed-off-by: Narumol Prangnawarat Change-Id: I435723160e9b639a70e0b48ee9d722d306461291 --- src/armnnOnnxParser/OnnxParser.cpp | 62 +++++++++++++++++++++++++++++++++++++- 1 file changed, 61 insertions(+), 1 deletion(-) (limited to 'src/armnnOnnxParser/OnnxParser.cpp') diff --git a/src/armnnOnnxParser/OnnxParser.cpp b/src/armnnOnnxParser/OnnxParser.cpp index 91ba52f32c..3fcb7ab603 100644 --- a/src/armnnOnnxParser/OnnxParser.cpp +++ b/src/armnnOnnxParser/OnnxParser.cpp @@ -9,6 +9,7 @@ #include #include #include +#include #include #include @@ -166,6 +167,18 @@ void ReadOptionalNodeAttributeImpl(const onnx::NodeProto& node, } } +int ReadMandatoryNodeIntAttribute(const onnx::NodeProto& node, + const std::string& name) +{ + int attribValue = 0; + ReadMandatoryNodeAttributeImpl(node, name, onnx::AttributeProto::INT, + [&attribValue](const onnx::AttributeProto& attrValue) + { + attribValue = CHECKED_INT32(attrValue.i()); + }); + return attribValue; +} + int64_t ReadOptionalNodeInt64Attribute(const onnx::NodeProto& node, const std::string& name, const int64_t defaultValue = 0) @@ -429,7 +442,8 @@ const std::map OnnxParser { "Flatten", &OnnxParserImpl::ParseFlatten }, { "Shape", &OnnxParserImpl::ParseShape }, { "Gather", &OnnxParserImpl::ParseGather }, - { "Unsqueeze", &OnnxParserImpl::ParseUnsqueeze } + { "Unsqueeze", &OnnxParserImpl::ParseUnsqueeze }, + { "Concat", &OnnxParserImpl::ParseConcat } }; template @@ -1431,6 +1445,52 @@ void OnnxParserImpl::ParseBatchNormalization(const onnx::NodeProto& node) RegisterOutputSlots(layer, {node.output(0)}); } +void OnnxParserImpl::ParseConcat(const onnx::NodeProto& node) +{ + CHECK_VALID_SIZE(static_cast(node.output_size()), 1); + + uint32_t numConcatView = static_cast(node.input_size()); + uint32_t inputRank = m_TensorsInfo[node.input(0)].m_info->GetNumDimensions(); + + int axisInt = ReadMandatoryNodeIntAttribute(node, "axis"); + + unsigned int concatDimInput = static_cast( + (static_cast(inputRank) + axisInt) % static_cast(inputRank)); + + OriginsDescriptor concatDescriptor(numConcatView, inputRank); + concatDescriptor.SetConcatAxis(concatDimInput); + + unsigned int mergeDimOrigin = 0; + + std::vector inputShapes; + std::vector tensorIds; + + for (unsigned int viewIndex = 0; viewIndex < numConcatView; ++viewIndex) + { + std::string nodeName = node.input(static_cast(viewIndex)); + auto inputTensorInfo = *m_TensorsInfo[nodeName].m_info; + inputShapes.push_back(inputTensorInfo.GetShape()); + tensorIds.push_back(nodeName); + + // Set up concatDescriptor view origin + armnnUtils::ProcessConcatInputTensorInfo( + inputTensorInfo, concatDescriptor, concatDimInput, viewIndex, mergeDimOrigin); + } + + IConnectableLayer* layer = m_Network->AddConcatLayer(concatDescriptor, node.name().c_str()); + ARMNN_ASSERT(layer != nullptr); + + auto outputInfo = ComputeOutputInfo({node.output(0)}, layer, inputShapes); + + layer->GetOutputSlot(0).SetTensorInfo(outputInfo[0]); + + // register the input connection slots for the layer, connections are made after all layers have been created + RegisterInputSlots(layer, tensorIds); + + // register the output connection slots for the layer, connections are made after all layers have been created + RegisterOutputSlots(layer, { node.output(0) }); +} + void OnnxParserImpl::ParseConstant(const onnx::NodeProto& node) { CHECK_VALID_SIZE(static_cast(node.attribute_size()), 1); -- cgit v1.2.1