diff options
author | Narumol Prangnawarat <narumol.prangnawarat@arm.com> | 2021-09-16 18:13:39 +0100 |
---|---|---|
committer | Narumol Prangnawarat <narumol.prangnawarat@arm.com> | 2021-09-16 18:13:39 +0100 |
commit | cdc495ea61a94ced93e877b62bcca5fa68f52f9b (patch) | |
tree | 49b3dfeeb10afe08d5fda22d75c076ac17374f1c /src/armnnOnnxParser/OnnxParser.cpp | |
parent | f106ab745a12a5c773a9c315dcddef0c8bf11225 (diff) | |
download | armnn-cdc495ea61a94ced93e877b62bcca5fa68f52f9b.tar.gz |
IVGCVSW-6382 Add Shape operator support to ONNX parser
Signed-off-by: Narumol Prangnawarat <narumol.prangnawarat@arm.com>
Change-Id: I3547effcbebf1ebc02d3b20f5db394a26991424d
Diffstat (limited to 'src/armnnOnnxParser/OnnxParser.cpp')
-rw-r--r-- | src/armnnOnnxParser/OnnxParser.cpp | 27 |
1 files changed, 26 insertions, 1 deletions
diff --git a/src/armnnOnnxParser/OnnxParser.cpp b/src/armnnOnnxParser/OnnxParser.cpp index 49f0271aeb..889c35f391 100644 --- a/src/armnnOnnxParser/OnnxParser.cpp +++ b/src/armnnOnnxParser/OnnxParser.cpp @@ -426,7 +426,8 @@ const std::map<std::string, OnnxParserImpl::OperationParsingFunction> OnnxParser { "LeakyRelu", &OnnxParserImpl::ParseLeakyRelu }, { "Conv", &OnnxParserImpl::ParseConv }, { "Add", &OnnxParserImpl::ParseAdd }, - { "Flatten", &OnnxParserImpl::ParseFlatten}, + { "Flatten", &OnnxParserImpl::ParseFlatten }, + { "Shape", &OnnxParserImpl::ParseShape } }; template<typename TypePair, typename Location> @@ -1653,6 +1654,30 @@ void OnnxParserImpl::ParseMaxPool(const onnx::NodeProto& node) AddPoolingLayer(node, desc); } +void OnnxParserImpl::ParseShape(const onnx::NodeProto& node) +{ + CHECK_VALID_SIZE(static_cast<size_t>(node.input_size()), 1); + CHECK_VALID_SIZE(static_cast<size_t>(node.output_size()), 1); + + // Output must be INT64 + CHECK_VALID_DATATYPE(node.name(), node.output(0), + m_TensorsInfo[node.output(0)].m_dtype, + onnx::TensorProto::INT64); + + IConnectableLayer* layer = m_Network->AddShapeLayer(node.name().c_str()); + ARMNN_ASSERT(layer != nullptr); + + TensorShape inputShape = m_TensorsInfo[node.input(0)].m_info->GetShape(); + auto outputInfo = ComputeOutputInfo({node.output(0)}, layer, {inputShape}); + layer->GetOutputSlot(0).SetTensorInfo(outputInfo[0]); + + // register the input connection slots for the layer, connections are made after all layers have been created + RegisterInputSlots(layer, {node.input(0)}); + + // register the output connection slots for the layer, connections are made after all layers have been created + RegisterOutputSlots(layer, {node.output(0)}); +} + void OnnxParserImpl::ParseReshape(const onnx::NodeProto& node) { CHECK_VALID_SIZE(static_cast<size_t>(node.input_size()), 2); |