diff options
author | Narumol Prangnawarat <narumol.prangnawarat@arm.com> | 2021-09-24 16:08:34 +0100 |
---|---|---|
committer | Narumol Prangnawarat <narumol.prangnawarat@arm.com> | 2021-10-05 15:33:24 +0100 |
commit | bc3bb62c2d5b881ca7f0b3973a533134196fc802 (patch) | |
tree | 365bb5c0a0c63aaf518fd46b4ddc5634521a5571 /src/armnnOnnxParser/OnnxParser.cpp | |
parent | 1b2654fb799c3d25ffcef4d31b5d026d359e2f8f (diff) | |
download | armnn-bc3bb62c2d5b881ca7f0b3973a533134196fc802.tar.gz |
IVGCVSW-6382 Add Concat operator support to ONNX parser
Signed-off-by: Narumol Prangnawarat <narumol.prangnawarat@arm.com>
Change-Id: I435723160e9b639a70e0b48ee9d722d306461291
Diffstat (limited to 'src/armnnOnnxParser/OnnxParser.cpp')
-rw-r--r-- | src/armnnOnnxParser/OnnxParser.cpp | 62 |
1 files changed, 61 insertions, 1 deletions
diff --git a/src/armnnOnnxParser/OnnxParser.cpp b/src/armnnOnnxParser/OnnxParser.cpp index 91ba52f32c..3fcb7ab603 100644 --- a/src/armnnOnnxParser/OnnxParser.cpp +++ b/src/armnnOnnxParser/OnnxParser.cpp @@ -9,6 +9,7 @@ #include <armnn/Descriptors.hpp> #include <armnn/utility/Assert.hpp> #include <armnn/utility/NumericCast.hpp> +#include <ParserHelper.hpp> #include <VerificationHelpers.hpp> #include <fmt/format.h> @@ -166,6 +167,18 @@ void ReadOptionalNodeAttributeImpl(const onnx::NodeProto& node, } } +int ReadMandatoryNodeIntAttribute(const onnx::NodeProto& node, + const std::string& name) +{ + int attribValue = 0; + ReadMandatoryNodeAttributeImpl(node, name, onnx::AttributeProto::INT, + [&attribValue](const onnx::AttributeProto& attrValue) + { + attribValue = CHECKED_INT32(attrValue.i()); + }); + return attribValue; +} + int64_t ReadOptionalNodeInt64Attribute(const onnx::NodeProto& node, const std::string& name, const int64_t defaultValue = 0) @@ -429,7 +442,8 @@ const std::map<std::string, OnnxParserImpl::OperationParsingFunction> OnnxParser { "Flatten", &OnnxParserImpl::ParseFlatten }, { "Shape", &OnnxParserImpl::ParseShape }, { "Gather", &OnnxParserImpl::ParseGather }, - { "Unsqueeze", &OnnxParserImpl::ParseUnsqueeze } + { "Unsqueeze", &OnnxParserImpl::ParseUnsqueeze }, + { "Concat", &OnnxParserImpl::ParseConcat } }; template<typename TypePair, typename Location> @@ -1431,6 +1445,52 @@ void OnnxParserImpl::ParseBatchNormalization(const onnx::NodeProto& node) RegisterOutputSlots(layer, {node.output(0)}); } +void OnnxParserImpl::ParseConcat(const onnx::NodeProto& node) +{ + CHECK_VALID_SIZE(static_cast<size_t>(node.output_size()), 1); + + uint32_t numConcatView = static_cast<uint32_t>(node.input_size()); + uint32_t inputRank = m_TensorsInfo[node.input(0)].m_info->GetNumDimensions(); + + int axisInt = ReadMandatoryNodeIntAttribute(node, "axis"); + + unsigned int concatDimInput = static_cast<unsigned int>( + (static_cast<int>(inputRank) + axisInt) % static_cast<int>(inputRank)); + + OriginsDescriptor concatDescriptor(numConcatView, inputRank); + concatDescriptor.SetConcatAxis(concatDimInput); + + unsigned int mergeDimOrigin = 0; + + std::vector<TensorShape> inputShapes; + std::vector<std::string> tensorIds; + + for (unsigned int viewIndex = 0; viewIndex < numConcatView; ++viewIndex) + { + std::string nodeName = node.input(static_cast<int>(viewIndex)); + auto inputTensorInfo = *m_TensorsInfo[nodeName].m_info; + inputShapes.push_back(inputTensorInfo.GetShape()); + tensorIds.push_back(nodeName); + + // Set up concatDescriptor view origin + armnnUtils::ProcessConcatInputTensorInfo( + inputTensorInfo, concatDescriptor, concatDimInput, viewIndex, mergeDimOrigin); + } + + IConnectableLayer* layer = m_Network->AddConcatLayer(concatDescriptor, node.name().c_str()); + ARMNN_ASSERT(layer != nullptr); + + auto outputInfo = ComputeOutputInfo({node.output(0)}, layer, inputShapes); + + layer->GetOutputSlot(0).SetTensorInfo(outputInfo[0]); + + // register the input connection slots for the layer, connections are made after all layers have been created + RegisterInputSlots(layer, tensorIds); + + // register the output connection slots for the layer, connections are made after all layers have been created + RegisterOutputSlots(layer, { node.output(0) }); +} + void OnnxParserImpl::ParseConstant(const onnx::NodeProto& node) { CHECK_VALID_SIZE(static_cast<size_t>(node.attribute_size()), 1); |