path: root/src/armnnOnnxParser/OnnxParser.cpp
author    Narumol Prangnawarat <narumol.prangnawarat@arm.com>  2021-09-24 16:08:34 +0100
committer Narumol Prangnawarat <narumol.prangnawarat@arm.com>  2021-10-05 15:33:24 +0100
commit    bc3bb62c2d5b881ca7f0b3973a533134196fc802 (patch)
tree      365bb5c0a0c63aaf518fd46b4ddc5634521a5571 /src/armnnOnnxParser/OnnxParser.cpp
parent    1b2654fb799c3d25ffcef4d31b5d026d359e2f8f (diff)
IVGCVSW-6382 Add Concat operator support to ONNX parser
Signed-off-by: Narumol Prangnawarat <narumol.prangnawarat@arm.com>
Change-Id: I435723160e9b639a70e0b48ee9d722d306461291
Diffstat (limited to 'src/armnnOnnxParser/OnnxParser.cpp')
-rw-r--r--  src/armnnOnnxParser/OnnxParser.cpp | 62
1 file changed, 61 insertions, 1 deletion
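
For context, here is a minimal sketch (not part of this patch) of how a model containing the newly supported Concat operator could be consumed through the public armnnOnnxParser API. The model path "model.onnx" and the tensor names "input0" and "output" are hypothetical placeholders.

// Sketch only: drives the ONNX parser front end that this change extends.
// The model path and tensor names below are assumptions, not taken from the patch.
#include <armnnOnnxParser/IOnnxParser.hpp>
#include <armnn/INetwork.hpp>

int main()
{
    using namespace armnnOnnxParser;

    // Build an armnn INetwork from an ONNX binary file; a Concat node in the
    // model graph is now dispatched to OnnxParserImpl::ParseConcat.
    IOnnxParserPtr parser = IOnnxParser::Create();
    armnn::INetworkPtr network = parser->CreateNetworkFromBinaryFile("model.onnx");

    // Input/output binding info can then be queried by tensor name.
    auto inputBinding  = parser->GetNetworkInputBindingInfo("input0");
    auto outputBinding = parser->GetNetworkOutputBindingInfo("output");

    return 0;
}
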
diff --git a/src/armnnOnnxParser/OnnxParser.cpp b/src/armnnOnnxParser/OnnxParser.cpp
index 91ba52f32c..3fcb7ab603 100644
--- a/src/armnnOnnxParser/OnnxParser.cpp
+++ b/src/armnnOnnxParser/OnnxParser.cpp
@@ -9,6 +9,7 @@
#include <armnn/Descriptors.hpp>
#include <armnn/utility/Assert.hpp>
#include <armnn/utility/NumericCast.hpp>
+#include <ParserHelper.hpp>
#include <VerificationHelpers.hpp>
#include <fmt/format.h>
@@ -166,6 +167,18 @@ void ReadOptionalNodeAttributeImpl(const onnx::NodeProto& node,
}
}
+int ReadMandatoryNodeIntAttribute(const onnx::NodeProto& node,
+ const std::string& name)
+{
+ int attribValue = 0;
+ ReadMandatoryNodeAttributeImpl(node, name, onnx::AttributeProto::INT,
+ [&attribValue](const onnx::AttributeProto& attrValue)
+ {
+ attribValue = CHECKED_INT32(attrValue.i());
+ });
+ return attribValue;
+}
+
int64_t ReadOptionalNodeInt64Attribute(const onnx::NodeProto& node,
const std::string& name,
const int64_t defaultValue = 0)
@@ -429,7 +442,8 @@ const std::map<std::string, OnnxParserImpl::OperationParsingFunction> OnnxParser
{ "Flatten", &OnnxParserImpl::ParseFlatten },
{ "Shape", &OnnxParserImpl::ParseShape },
{ "Gather", &OnnxParserImpl::ParseGather },
- { "Unsqueeze", &OnnxParserImpl::ParseUnsqueeze }
+ { "Unsqueeze", &OnnxParserImpl::ParseUnsqueeze },
+ { "Concat", &OnnxParserImpl::ParseConcat }
};
template<typename TypePair, typename Location>
@@ -1431,6 +1445,52 @@ void OnnxParserImpl::ParseBatchNormalization(const onnx::NodeProto& node)
RegisterOutputSlots(layer, {node.output(0)});
}
+void OnnxParserImpl::ParseConcat(const onnx::NodeProto& node)
+{
+ CHECK_VALID_SIZE(static_cast<size_t>(node.output_size()), 1);
+
+ uint32_t numConcatView = static_cast<uint32_t>(node.input_size());
+ uint32_t inputRank = m_TensorsInfo[node.input(0)].m_info->GetNumDimensions();
+
+ int axisInt = ReadMandatoryNodeIntAttribute(node, "axis");
+
+ unsigned int concatDimInput = static_cast<unsigned int>(
+ (static_cast<int>(inputRank) + axisInt) % static_cast<int>(inputRank));
+
+ OriginsDescriptor concatDescriptor(numConcatView, inputRank);
+ concatDescriptor.SetConcatAxis(concatDimInput);
+
+ unsigned int mergeDimOrigin = 0;
+
+ std::vector<TensorShape> inputShapes;
+ std::vector<std::string> tensorIds;
+
+ for (unsigned int viewIndex = 0; viewIndex < numConcatView; ++viewIndex)
+ {
+ std::string nodeName = node.input(static_cast<int>(viewIndex));
+ auto inputTensorInfo = *m_TensorsInfo[nodeName].m_info;
+ inputShapes.push_back(inputTensorInfo.GetShape());
+ tensorIds.push_back(nodeName);
+
+ // Set up concatDescriptor view origin
+ armnnUtils::ProcessConcatInputTensorInfo(
+ inputTensorInfo, concatDescriptor, concatDimInput, viewIndex, mergeDimOrigin);
+ }
+
+ IConnectableLayer* layer = m_Network->AddConcatLayer(concatDescriptor, node.name().c_str());
+ ARMNN_ASSERT(layer != nullptr);
+
+ auto outputInfo = ComputeOutputInfo({node.output(0)}, layer, inputShapes);
+
+ layer->GetOutputSlot(0).SetTensorInfo(outputInfo[0]);
+
+ // register the input connection slots for the layer, connections are made after all layers have been created
+ RegisterInputSlots(layer, tensorIds);
+
+ // register the output connection slots for the layer, connections are made after all layers have been created
+ RegisterOutputSlots(layer, { node.output(0) });
+}
+
void OnnxParserImpl::ParseConstant(const onnx::NodeProto& node)
{
CHECK_VALID_SIZE(static_cast<size_t>(node.attribute_size()), 1);
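
As an aside on the axis handling in ParseConcat above: ONNX allows the mandatory axis attribute to be negative (counting from the last dimension), and the parser folds it into an unsigned armnn concatenation axis with (rank + axis) % rank. A small standalone sketch of that arithmetic, assuming axis lies in [-rank, rank-1] as the ONNX Concat specification requires; the helper name is hypothetical.

// Standalone sketch of the axis normalisation performed in ParseConcat above.
// Assumes axis is in [-rank, rank-1]; NormaliseConcatAxis is a made-up name.
#include <cassert>

unsigned int NormaliseConcatAxis(int axis, unsigned int rank)
{
    return static_cast<unsigned int>((static_cast<int>(rank) + axis) % static_cast<int>(rank));
}

int main()
{
    assert(NormaliseConcatAxis(-1, 4) == 3u); // negative axis counts from the back
    assert(NormaliseConcatAxis( 1, 4) == 1u); // non-negative axes map to themselves
    return 0;
}
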