From bc3bb62c2d5b881ca7f0b3973a533134196fc802 Mon Sep 17 00:00:00 2001 From: Narumol Prangnawarat Date: Fri, 24 Sep 2021 16:08:34 +0100 Subject: IVGCVSW-6382 Add Concat operator support to ONNX parser Signed-off-by: Narumol Prangnawarat Change-Id: I435723160e9b639a70e0b48ee9d722d306461291 --- CMakeLists.txt | 1 + docs/01_01_parsers.dox | 3 + src/armnnOnnxParser/OnnxParser.cpp | 62 +++++++- src/armnnOnnxParser/OnnxParser.hpp | 1 + src/armnnOnnxParser/test/Concat.cpp | 274 ++++++++++++++++++++++++++++++++++++ 5 files changed, 340 insertions(+), 1 deletion(-) create mode 100644 src/armnnOnnxParser/test/Concat.cpp diff --git a/CMakeLists.txt b/CMakeLists.txt index bebee7ffb8..ebea9ae751 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -756,6 +756,7 @@ if(BUILD_UNIT_TESTS) src/armnnOnnxParser/test/Addition.cpp src/armnnOnnxParser/test/BatchNorm.cpp src/armnnOnnxParser/test/Clip.cpp + src/armnnOnnxParser/test/Concat.cpp src/armnnOnnxParser/test/Const.cpp src/armnnOnnxParser/test/Constructor.cpp src/armnnOnnxParser/test/Conv2D.cpp diff --git a/docs/01_01_parsers.dox b/docs/01_01_parsers.dox index 689c062a9a..244b21ead5 100644 --- a/docs/01_01_parsers.dox +++ b/docs/01_01_parsers.dox @@ -40,6 +40,9 @@ The Arm NN SDK ONNX parser currently only supports fp32 operators. - AveragePool - See the ONNX [AveragePool documentation](https://github.com/onnx/onnx/blob/master/docs/Operators.md#AveragePool) for more information. +- Concat + - See the ONNX [Concat documentation](https://github.com/onnx/onnx/blob/master/docs/Operators.md#Concat) for more information. + - Constant - See the ONNX [Constant documentation](https://github.com/onnx/onnx/blob/master/docs/Operators.md#Constant) for more information. diff --git a/src/armnnOnnxParser/OnnxParser.cpp b/src/armnnOnnxParser/OnnxParser.cpp index 91ba52f32c..3fcb7ab603 100644 --- a/src/armnnOnnxParser/OnnxParser.cpp +++ b/src/armnnOnnxParser/OnnxParser.cpp @@ -9,6 +9,7 @@ #include #include #include +#include #include #include @@ -166,6 +167,18 @@ void ReadOptionalNodeAttributeImpl(const onnx::NodeProto& node, } } +int ReadMandatoryNodeIntAttribute(const onnx::NodeProto& node, + const std::string& name) +{ + int attribValue = 0; + ReadMandatoryNodeAttributeImpl(node, name, onnx::AttributeProto::INT, + [&attribValue](const onnx::AttributeProto& attrValue) + { + attribValue = CHECKED_INT32(attrValue.i()); + }); + return attribValue; +} + int64_t ReadOptionalNodeInt64Attribute(const onnx::NodeProto& node, const std::string& name, const int64_t defaultValue = 0) @@ -429,7 +442,8 @@ const std::map OnnxParser { "Flatten", &OnnxParserImpl::ParseFlatten }, { "Shape", &OnnxParserImpl::ParseShape }, { "Gather", &OnnxParserImpl::ParseGather }, - { "Unsqueeze", &OnnxParserImpl::ParseUnsqueeze } + { "Unsqueeze", &OnnxParserImpl::ParseUnsqueeze }, + { "Concat", &OnnxParserImpl::ParseConcat } }; template @@ -1431,6 +1445,52 @@ void OnnxParserImpl::ParseBatchNormalization(const onnx::NodeProto& node) RegisterOutputSlots(layer, {node.output(0)}); } +void OnnxParserImpl::ParseConcat(const onnx::NodeProto& node) +{ + CHECK_VALID_SIZE(static_cast(node.output_size()), 1); + + uint32_t numConcatView = static_cast(node.input_size()); + uint32_t inputRank = m_TensorsInfo[node.input(0)].m_info->GetNumDimensions(); + + int axisInt = ReadMandatoryNodeIntAttribute(node, "axis"); + + unsigned int concatDimInput = static_cast( + (static_cast(inputRank) + axisInt) % static_cast(inputRank)); + + OriginsDescriptor concatDescriptor(numConcatView, inputRank); + concatDescriptor.SetConcatAxis(concatDimInput); + + unsigned int mergeDimOrigin = 0; + + std::vector inputShapes; + std::vector tensorIds; + + for (unsigned int viewIndex = 0; viewIndex < numConcatView; ++viewIndex) + { + std::string nodeName = node.input(static_cast(viewIndex)); + auto inputTensorInfo = *m_TensorsInfo[nodeName].m_info; + inputShapes.push_back(inputTensorInfo.GetShape()); + tensorIds.push_back(nodeName); + + // Set up concatDescriptor view origin + armnnUtils::ProcessConcatInputTensorInfo( + inputTensorInfo, concatDescriptor, concatDimInput, viewIndex, mergeDimOrigin); + } + + IConnectableLayer* layer = m_Network->AddConcatLayer(concatDescriptor, node.name().c_str()); + ARMNN_ASSERT(layer != nullptr); + + auto outputInfo = ComputeOutputInfo({node.output(0)}, layer, inputShapes); + + layer->GetOutputSlot(0).SetTensorInfo(outputInfo[0]); + + // register the input connection slots for the layer, connections are made after all layers have been created + RegisterInputSlots(layer, tensorIds); + + // register the output connection slots for the layer, connections are made after all layers have been created + RegisterOutputSlots(layer, { node.output(0) }); +} + void OnnxParserImpl::ParseConstant(const onnx::NodeProto& node) { CHECK_VALID_SIZE(static_cast(node.attribute_size()), 1); diff --git a/src/armnnOnnxParser/OnnxParser.hpp b/src/armnnOnnxParser/OnnxParser.hpp index 196e903ed5..6a0fad0ec2 100644 --- a/src/armnnOnnxParser/OnnxParser.hpp +++ b/src/armnnOnnxParser/OnnxParser.hpp @@ -113,6 +113,7 @@ private: void ParseAdd(const onnx::NodeProto& nodeProto); void ParseAveragePool(const onnx::NodeProto& nodeProto); void ParseBatchNormalization(const onnx::NodeProto& node); + void ParseConcat(const onnx::NodeProto& nodeProto); void ParseConstant(const onnx::NodeProto& nodeProto); void ParseConv(const onnx::NodeProto& nodeProto); void ParseFlatten(const onnx::NodeProto& node); diff --git a/src/armnnOnnxParser/test/Concat.cpp b/src/armnnOnnxParser/test/Concat.cpp new file mode 100644 index 0000000000..85ebcc307b --- /dev/null +++ b/src/armnnOnnxParser/test/Concat.cpp @@ -0,0 +1,274 @@ +// +// Copyright © 2021 Arm Ltd and Contributors. All rights reserved. +// SPDX-License-Identifier: MIT +// + +#include "armnnOnnxParser/IOnnxParser.hpp" +#include "ParserPrototxtFixture.hpp" +#include "OnnxParserTestUtils.hpp" + +TEST_SUITE("OnnxParser_Concat") +{ + +struct ConcatFixture : public armnnUtils::ParserPrototxtFixture +{ + ConcatFixture(const std::string& axis, + const std::vector& input0Shape, + const std::vector& input1Shape, + const std::vector& outputShape) + { + m_Prototext = R"( + ir_version: 8 + producer_name: "onnx-example" + graph { + node { + input: "Input0" + input: "Input1" + output: "Output" + op_type: "Concat" + attribute { + name: "axis" + i: )" + axis + R"( + type: INT + } + } + name: "concat-model" + input { + name: "Input0" + type { + tensor_type { + elem_type: 1 + shape { + )" + armnnUtils::ConstructTensorShapeString(input0Shape) + R"( + } + } + } + } + input { + name: "Input1" + type { + tensor_type { + elem_type: 1 + shape { + )" + armnnUtils::ConstructTensorShapeString(input1Shape) + R"( + } + } + } + } + output { + name: "Output" + type { + tensor_type { + elem_type: 1 + shape { + )" + armnnUtils::ConstructTensorShapeString(outputShape) + R"( + } + } + } + } + })"; + Setup(); + } +}; + +struct ConcatAxis0Fixture : ConcatFixture +{ + ConcatAxis0Fixture() : ConcatFixture("0", { 1, 3, 2, 5 }, { 1, 3, 2, 5 }, { 2, 3, 2, 5 }) {} +}; + +struct ConcatAxis1Fixture : ConcatFixture +{ + ConcatAxis1Fixture() : ConcatFixture("1", { 2, 2, 1, 3 }, { 2, 1, 1, 3 }, { 2, 3, 1, 3 }) {} +}; + +struct ConcatAxis2Fixture : ConcatFixture +{ + ConcatAxis2Fixture() : ConcatFixture("2", { 2, 3, 1, 1 }, { 2, 3, 2, 1 }, { 2, 3, 3, 1 }) {} +}; + +struct ConcatAxis3Fixture : ConcatFixture +{ + ConcatAxis3Fixture() : ConcatFixture("3", { 1, 3, 2, 2 }, { 1, 3, 2, 2 }, { 1, 3, 2, 4 }) {} +}; + +struct ConcatNegativeAxisFixture : ConcatFixture +{ + ConcatNegativeAxisFixture() : ConcatFixture("-1", { 1, 2, 5 }, { 1, 2, 3 }, { 1, 2, 8 }) {} +}; + +TEST_CASE_FIXTURE(ConcatAxis0Fixture, "ConcatAxis0Test") +{ + RunTest<4, float>({{"Input0", { 1.0f, 2.0f, 3.0f, 4.0f, 5.0f, + 6.0f, 7.0f, 8.0f, 9.0f, 10.0f, + 11.0f, 12.0f, 13.0f, 14.0f, 15.0f, + 16.0f, 17.0f, 18.0f, 19.0f, 20.0f, + 21.0f, 22.0f, 23.0f, 24.0f, 25.0f, + 26.0f, 27.0f, 28.0f, 29.0f, 30.0f }}, + {"Input1", { 31.0f, 32.0f, 33.0f, 34.0f, 35.0f, + 36.0f, 37.0f, 38.0f, 39.0f, 40.0f, + 41.0f, 42.0f, 43.0f, 44.0f, 45.0f, + 46.0f, 47.0f, 48.0f, 49.0f, 50.0f, + 51.0f, 52.0f, 53.0f, 54.0f, 55.0f, + 56.0f, 57.0f, 58.0f, 59.0f, 60.0f }}}, + {{"Output", { 1.0f, 2.0f, 3.0f, 4.0f, 5.0f, + 6.0f, 7.0f, 8.0f, 9.0f, 10.0f, + 11.0f, 12.0f, 13.0f, 14.0f, 15.0f, + 16.0f, 17.0f, 18.0f, 19.0f, 20.0f, + 21.0f, 22.0f, 23.0f, 24.0f, 25.0f, + 26.0f, 27.0f, 28.0f, 29.0f, 30.0f, + 31.0f, 32.0f, 33.0f, 34.0f, 35.0f, + 36.0f, 37.0f, 38.0f, 39.0f, 40.0f, + 41.0f, 42.0f, 43.0f, 44.0f, 45.0f, + 46.0f, 47.0f, 48.0f, 49.0f, 50.0f, + 51.0f, 52.0f, 53.0f, 54.0f, 55.0f, + 56.0f, 57.0f, 58.0f, 59.0f, 60.0f }}}); +} + +TEST_CASE_FIXTURE(ConcatAxis1Fixture, "ConcatAxis1est") +{ + RunTest<4, float>({{"Input0", { 1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f, 8.0f, 9.0f, 10.0f, 11.0f, 12.0f }}, + {"Input1", { 13.0f, 14.0f, 15.0f, 16.0f, 17.0f, 18.0f }}}, + {{"Output", { 1.0f, 2.0f, 3.0f, + 4.0f, 5.0f, 6.0f, + 13.0f, 14.0f, 15.0f, + 7.0f, 8.0f, 9.0f, + 10.0f, 11.0f, 12.0f, + 16.0f, 17.0f, 18.0f }}}); +} + +TEST_CASE_FIXTURE(ConcatAxis2Fixture, "ConcatAxis2Test") +{ + RunTest<4, float>({{"Input0", { 1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f }}, + {"Input1", { 7.0f, 8.0f, 9.0f, 10.0f, 11.0f, 12.0f, 13.0f, 14.0f, 15.0f, 16.0f, 17.0f, 18.0f }}}, + {{"Output", { 1.0f, 7.0f, 8.0f, + 2.0f, 9.0f, 10.0f, + 3.0f, 11.0f, 12.0f, + 4.0f, 13.0f, 14.0f, + 5.0f, 15.0f, 16.0f, + 6.0f, 17.0f, 18.0f }}}); +} + +TEST_CASE_FIXTURE(ConcatAxis3Fixture, "ConcatAxis3Test") +{ + RunTest<4, float>({{"Input0", { 1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, + 7.0f, 8.0f, 9.0f, 10.0f, 11.0f, 12.0f }}, + {"Input1", { 13.0f, 14.0f, 15.0f, 16.0f, 17.0f, 18.0f, + 19.0f, 20.0f, 21.0f, 22.0f, 23.0f, 24.0f }}}, + {{"Output", { 1.0f, 2.0f, 13.0f, 14.0f, + 3.0f, 4.0f, 15.0f, 16.0f, + 5.0f, 6.0f, 17.0f, 18.0f, + 7.0f, 8.0f, 19.0f, 20.0f, + 9.0f, 10.0f, 21.0f, 22.0f, + 11.0f, 12.0f, 23.0f, 24.0f }}}); +} + +TEST_CASE_FIXTURE(ConcatNegativeAxisFixture, "ConcatNegativeAxisTest") +{ + RunTest<3, float>({{"Input0", { 1.0f, 2.0f, 3.0f, 4.0f, 5.0f, + 6.0f, 7.0f, 8.0f, 9.0f, 10.0f }}, + {"Input1", { 11.0f, 12.0f, 13.0f, 14.0f, 15.0f, 16.0f }}}, + {{"Output", { 1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 11.0f, 12.0f, 13.0f, + 6.0f, 7.0f, 8.0f, 9.0f, 10.0f, 14.0f, 15.0f, 16.0f }}}); +} + +struct ConcatMultipleInputsFixture : public armnnUtils::ParserPrototxtFixture +{ + ConcatMultipleInputsFixture() + { + m_Prototext = R"( + ir_version: 8 + producer_name: "onnx-example" + graph { + node { + input: "Input0" + input: "Input1" + input: "Input2" + output: "Output" + op_type: "Concat" + attribute { + name: "axis" + i: 1 + type: INT + } + } + name: "concat-model" + input { + name: "Input0" + type { + tensor_type { + elem_type: 1 + shape { + dim { + dim_value: 3 + } + dim { + dim_value: 2 + } + } + } + } + } + input { + name: "Input1" + type { + tensor_type { + elem_type: 1 + shape { + dim { + dim_value: 3 + } + dim { + dim_value: 3 + } + } + } + } + } + input { + name: "Input2" + type { + tensor_type { + elem_type: 1 + shape { + dim { + dim_value: 3 + } + dim { + dim_value: 1 + } + } + } + } + } + output { + name: "Output" + type { + tensor_type { + elem_type: 1 + shape { + dim { + dim_value: 3 + } + dim { + dim_value: 6 + } + } + } + } + } + })"; + Setup(); + } +}; + +TEST_CASE_FIXTURE(ConcatMultipleInputsFixture, "ConcatMultipleInputsTest") +{ + RunTest<2, float>({{"Input0", { 1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f }}, + {"Input1", { 7.0f, 8.0f, 9.0f, 10.0f, 11.0f, 12.0f, 13.0f, 14.0f, 15.0f }}, + {"Input2", { 16.0f, 17.0f, 18.0f }}}, + {{"Output", { 1.0f, 2.0f, 7.0f, 8.0f, 9.0f, 16.0f, + 3.0f, 4.0f, 10.0f, 11.0f, 12.0f, 17.0f, + 5.0f, 6.0f, 13.0f, 14.0f, 15.0f, 18.0f }}}); +} + +} -- cgit v1.2.1