From fe6aa2f1ec991bb42356ed4068158395e8c78a7c Mon Sep 17 00:00:00 2001 From: Narumol Prangnawarat Date: Thu, 23 Sep 2021 16:11:17 +0100 Subject: IVGCVSW-6382 Add Unsqueeze operator support to ONNX parser Signed-off-by: Narumol Prangnawarat Change-Id: Ie0b68b08fc31444c58b0ffc9babdd456bbb51f35 --- src/armnnOnnxParser/OnnxParser.cpp | 54 ++++++++++++++++++++++++++++++++++++++ 1 file changed, 54 insertions(+) (limited to 'src/armnnOnnxParser/OnnxParser.cpp') diff --git a/src/armnnOnnxParser/OnnxParser.cpp b/src/armnnOnnxParser/OnnxParser.cpp index e70eb64047..91ba52f32c 100644 --- a/src/armnnOnnxParser/OnnxParser.cpp +++ b/src/armnnOnnxParser/OnnxParser.cpp @@ -429,6 +429,7 @@ const std::map OnnxParser { "Flatten", &OnnxParserImpl::ParseFlatten }, { "Shape", &OnnxParserImpl::ParseShape }, { "Gather", &OnnxParserImpl::ParseGather }, + { "Unsqueeze", &OnnxParserImpl::ParseUnsqueeze } }; template @@ -1834,6 +1835,59 @@ void OnnxParserImpl::ParseReshape(const onnx::NodeProto& node) } } +/* Parse an ONNX Unsqueeze node and realise it as an ArmNN Reshape layer:
   the output shape is the input shape with a dimension of size 1 inserted
   at each requested axis. Accepts 1 or 2 inputs (checked below): the axes
   come either from the "axes" attribute (single-input form) or from a
   second constant tensor input of INT64 values. Only FLOAT data input is
   supported. */ +void OnnxParserImpl::ParseUnsqueeze(const onnx::NodeProto& node) +{ + CHECK_VALID_SIZE(armnn::numeric_cast(node.input_size()), 1, 2); + CHECK_VALID_SIZE(armnn::numeric_cast(node.output_size()), 1); + + CHECK_VALID_DATATYPE(node.name(), node.input(0), + m_TensorsInfo[node.input(0)].m_dtype, + onnx::TensorProto::FLOAT); //input + + TensorShape inputShape = m_TensorsInfo[node.input(0)].m_info->GetShape(); + std::vector dims; /* axes at which to insert size-1 dimensions */ + /* Attribute form: one input and at least one attribute present. */ + if (node.input_size() == 1 && node.attribute_size() > 0) + { + dims = ReadMandatoryNodeUint32ListAttribute(node, "axes"); + } + else + /* Input form: axes supplied as a constant INT64 tensor (second input).
      NOTE(review): CHECK_VALID_SIZE above also admits input_size()==1 with
      no attributes, in which case this branch reads node.input(1), which
      does not exist -- confirm upstream/model validation excludes that
      combination. */ + { + CHECK_VALID_DATATYPE(node.name(), node.input(1), + m_TensorsInfo[node.input(1)].m_dtype, + onnx::TensorProto::INT64); //axes + + auto int64Axes = m_TensorsInfo[node.input(1)].m_tensor->int64_data().data(); + uint numDim = armnn::numeric_cast(m_TensorsInfo[node.input(1)].m_tensor->int64_data_size()); + + /* Narrow each int64 axis to uint32, rejecting negative values. */ + for(uint i = 0; i < numDim; i++) + { + uint32_t uint32Value = CHECKED_NON_NEGATIVE(CHECKED_INT32(int64Axes[i])); + 
 dims.push_back(uint32Value); + } + } + + // Ensure that the axes are sorted + std::sort(dims.begin(), dims.end()); + + std::vector targetShape; + + /* Start from the existing input dimensions... */ + for(uint i = 0; i < inputShape.GetNumDimensions(); i++) + { + targetShape.push_back(inputShape[i]); + } + + /* ...then insert a 1 at each axis position. Because the axes were sorted
      ascending, each insertion keeps the indices of the later (larger)
      axes valid in the growing target shape. */ + for(uint i = 0; i < dims.size(); i++) + { + targetShape.insert(targetShape.begin() + armnn::numeric_cast(dims[i]), 1); + } + + /* Record the computed output tensor info, then emit the actual reshape
      layer connecting input(0) to output(0). */ + auto outInfo = ComputeReshapeInfo(TensorShape(armnn::numeric_cast(targetShape.size()), + targetShape.data()), inputShape, node.output(0)); + m_TensorsInfo[node.output(0)].m_info = std::make_unique(outInfo); + + CreateReshapeLayer(node.input(0), node.output(0), node.name()); +} + void OnnxParserImpl::PrependForBroadcast(const std::string& outputName, const std::string& input0, const std::string& input1) -- cgit v1.2.1