diff options
Diffstat (limited to 'src/armnnOnnxParser/OnnxParser.cpp')
-rw-r--r-- | src/armnnOnnxParser/OnnxParser.cpp | 27 |
1 files changed, 26 insertions, 1 deletions
diff --git a/src/armnnOnnxParser/OnnxParser.cpp b/src/armnnOnnxParser/OnnxParser.cpp index 49f0271aeb..889c35f391 100644 --- a/src/armnnOnnxParser/OnnxParser.cpp +++ b/src/armnnOnnxParser/OnnxParser.cpp @@ -426,7 +426,8 @@ const std::map<std::string, OnnxParserImpl::OperationParsingFunction> OnnxParser { "LeakyRelu", &OnnxParserImpl::ParseLeakyRelu }, { "Conv", &OnnxParserImpl::ParseConv }, { "Add", &OnnxParserImpl::ParseAdd }, - { "Flatten", &OnnxParserImpl::ParseFlatten}, + { "Flatten", &OnnxParserImpl::ParseFlatten }, + { "Shape", &OnnxParserImpl::ParseShape } }; template<typename TypePair, typename Location> @@ -1653,6 +1654,30 @@ void OnnxParserImpl::ParseMaxPool(const onnx::NodeProto& node) AddPoolingLayer(node, desc); } +void OnnxParserImpl::ParseShape(const onnx::NodeProto& node) +{ + CHECK_VALID_SIZE(static_cast<size_t>(node.input_size()), 1); + CHECK_VALID_SIZE(static_cast<size_t>(node.output_size()), 1); + + // Output must be INT64 + CHECK_VALID_DATATYPE(node.name(), node.output(0), + m_TensorsInfo[node.output(0)].m_dtype, + onnx::TensorProto::INT64); + + IConnectableLayer* layer = m_Network->AddShapeLayer(node.name().c_str()); + ARMNN_ASSERT(layer != nullptr); + + TensorShape inputShape = m_TensorsInfo[node.input(0)].m_info->GetShape(); + auto outputInfo = ComputeOutputInfo({node.output(0)}, layer, {inputShape}); + layer->GetOutputSlot(0).SetTensorInfo(outputInfo[0]); + + // register the input connection slots for the layer, connections are made after all layers have been created + RegisterInputSlots(layer, {node.input(0)}); + + // register the output connection slots for the layer, connections are made after all layers have been created + RegisterOutputSlots(layer, {node.output(0)}); +} + void OnnxParserImpl::ParseReshape(const onnx::NodeProto& node) { CHECK_VALID_SIZE(static_cast<size_t>(node.input_size()), 2); |