about summary refs log tree commit diff
path: root/src/armnnOnnxParser/OnnxParser.cpp
diff options
context:
space:
mode:
authorNarumol Prangnawarat <narumol.prangnawarat@arm.com>2021-09-16 18:13:39 +0100
committerNarumol Prangnawarat <narumol.prangnawarat@arm.com>2021-09-16 18:13:39 +0100
commitcdc495ea61a94ced93e877b62bcca5fa68f52f9b (patch)
tree49b3dfeeb10afe08d5fda22d75c076ac17374f1c /src/armnnOnnxParser/OnnxParser.cpp
parentf106ab745a12a5c773a9c315dcddef0c8bf11225 (diff)
downloadarmnn-cdc495ea61a94ced93e877b62bcca5fa68f52f9b.tar.gz
IVGCVSW-6382 Add Shape operator support to ONNX parser
Signed-off-by: Narumol Prangnawarat <narumol.prangnawarat@arm.com> Change-Id: I3547effcbebf1ebc02d3b20f5db394a26991424d
Diffstat (limited to 'src/armnnOnnxParser/OnnxParser.cpp')
-rw-r--r--src/armnnOnnxParser/OnnxParser.cpp27
1 file changed, 26 insertions, 1 deletion
diff --git a/src/armnnOnnxParser/OnnxParser.cpp b/src/armnnOnnxParser/OnnxParser.cpp
index 49f0271aeb..889c35f391 100644
--- a/src/armnnOnnxParser/OnnxParser.cpp
+++ b/src/armnnOnnxParser/OnnxParser.cpp
@@ -426,7 +426,8 @@ const std::map<std::string, OnnxParserImpl::OperationParsingFunction> OnnxParser
{ "LeakyRelu", &OnnxParserImpl::ParseLeakyRelu },
{ "Conv", &OnnxParserImpl::ParseConv },
{ "Add", &OnnxParserImpl::ParseAdd },
- { "Flatten", &OnnxParserImpl::ParseFlatten},
+ { "Flatten", &OnnxParserImpl::ParseFlatten },
+ { "Shape", &OnnxParserImpl::ParseShape }
};
template<typename TypePair, typename Location>
@@ -1653,6 +1654,30 @@ void OnnxParserImpl::ParseMaxPool(const onnx::NodeProto& node)
AddPoolingLayer(node, desc);
}
+void OnnxParserImpl::ParseShape(const onnx::NodeProto& node)
+{
+ CHECK_VALID_SIZE(static_cast<size_t>(node.input_size()), 1);
+ CHECK_VALID_SIZE(static_cast<size_t>(node.output_size()), 1);
+
+ // Output must be INT64
+ CHECK_VALID_DATATYPE(node.name(), node.output(0),
+ m_TensorsInfo[node.output(0)].m_dtype,
+ onnx::TensorProto::INT64);
+
+ IConnectableLayer* layer = m_Network->AddShapeLayer(node.name().c_str());
+ ARMNN_ASSERT(layer != nullptr);
+
+ TensorShape inputShape = m_TensorsInfo[node.input(0)].m_info->GetShape();
+ auto outputInfo = ComputeOutputInfo({node.output(0)}, layer, {inputShape});
+ layer->GetOutputSlot(0).SetTensorInfo(outputInfo[0]);
+
+ // register the input connection slots for the layer, connections are made after all layers have been created
+ RegisterInputSlots(layer, {node.input(0)});
+
+ // register the output connection slots for the layer, connections are made after all layers have been created
+ RegisterOutputSlots(layer, {node.output(0)});
+}
+
void OnnxParserImpl::ParseReshape(const onnx::NodeProto& node)
{
CHECK_VALID_SIZE(static_cast<size_t>(node.input_size()), 2);