author      Narumol Prangnawarat <narumol.prangnawarat@arm.com>   2021-09-23 16:11:17 +0100
committer   Jim Flynn <jim.flynn@arm.com>                         2021-09-29 16:24:21 +0000
commit      fe6aa2f1ec991bb42356ed4068158395e8c78a7c (patch)
tree        4246c8e850a1b66914eec0da5a7b4c613e07543d /src/armnnOnnxParser/OnnxParser.cpp
parent      cd20385b5511ffddb025066edc6988b824dfc1c4 (diff)
download    armnn-fe6aa2f1ec991bb42356ed4068158395e8c78a7c.tar.gz
IVGCVSW-6382 Add Unsqueeze operator support to ONNX parser
Signed-off-by: Narumol Prangnawarat <narumol.prangnawarat@arm.com>
Change-Id: Ie0b68b08fc31444c58b0ffc9babdd456bbb51f35
Diffstat (limited to 'src/armnnOnnxParser/OnnxParser.cpp')
-rw-r--r--  src/armnnOnnxParser/OnnxParser.cpp  54
1 file changed, 54 insertions, 0 deletions
diff --git a/src/armnnOnnxParser/OnnxParser.cpp b/src/armnnOnnxParser/OnnxParser.cpp
index e70eb64047..91ba52f32c 100644
--- a/src/armnnOnnxParser/OnnxParser.cpp
+++ b/src/armnnOnnxParser/OnnxParser.cpp
@@ -429,6 +429,7 @@ const std::map<std::string, OnnxParserImpl::OperationParsingFunction> OnnxParser
     { "Flatten", &OnnxParserImpl::ParseFlatten },
     { "Shape", &OnnxParserImpl::ParseShape },
     { "Gather", &OnnxParserImpl::ParseGather },
+    { "Unsqueeze", &OnnxParserImpl::ParseUnsqueeze }
 };
 
 template<typename TypePair, typename Location>
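
Note: the table in the hunk above is the parser's operator dispatch map, pairing ONNX op-type strings with member-function pointers, so supporting a new operator is one map entry plus one Parse* method. Below is a minimal, self-contained sketch of that pattern; the Parser class and all names in it are illustrative stand-ins, not ArmNN code.

#include <iostream>
#include <map>
#include <stdexcept>
#include <string>

// Illustrative sketch of a string-to-member-function dispatch map.
class Parser
{
public:
    using ParsingFunction = void (Parser::*)(const std::string& node);

    void ParseNode(const std::string& opType, const std::string& node)
    {
        auto it = m_ParserFunctions.find(opType);
        if (it == m_ParserFunctions.end())
        {
            throw std::runtime_error("Unsupported operator: " + opType);
        }
        (this->*(it->second))(node); // invoke the registered member function
    }

private:
    void ParseUnsqueeze(const std::string& node)
    {
        std::cout << "Unsqueeze: " << node << "\n";
    }

    const std::map<std::string, ParsingFunction> m_ParserFunctions =
    {
        { "Unsqueeze", &Parser::ParseUnsqueeze }
    };
};

int main()
{
    Parser p;
    p.ParseNode("Unsqueeze", "node_0");
}

Registering member-function pointers keeps per-operator support additive, which is why this commit's first hunk is a single new map entry.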
@@ -1834,6 +1835,59 @@ void OnnxParserImpl::ParseReshape(const onnx::NodeProto& node)
     }
 }
 
+void OnnxParserImpl::ParseUnsqueeze(const onnx::NodeProto& node)
+{
+    CHECK_VALID_SIZE(armnn::numeric_cast<size_t>(node.input_size()), 1, 2);
+    CHECK_VALID_SIZE(armnn::numeric_cast<size_t>(node.output_size()), 1);
+
+    CHECK_VALID_DATATYPE(node.name(), node.input(0),
+                         m_TensorsInfo[node.input(0)].m_dtype,
+                         onnx::TensorProto::FLOAT); //input
+
+    TensorShape inputShape = m_TensorsInfo[node.input(0)].m_info->GetShape();
+    std::vector<uint32_t> dims;
+    if (node.input_size() == 1 && node.attribute_size() > 0)
+    {
+        dims = ReadMandatoryNodeUint32ListAttribute(node, "axes");
+    }
+    else
+    {
+        CHECK_VALID_DATATYPE(node.name(), node.input(1),
+                             m_TensorsInfo[node.input(1)].m_dtype,
+                             onnx::TensorProto::INT64); //axes
+
+        auto int64Axes = m_TensorsInfo[node.input(1)].m_tensor->int64_data().data();
+        uint numDim = armnn::numeric_cast<uint>(m_TensorsInfo[node.input(1)].m_tensor->int64_data_size());
+
+        for(uint i = 0; i < numDim; i++)
+        {
+            uint32_t uint32Value = CHECKED_NON_NEGATIVE(CHECKED_INT32(int64Axes[i]));
+            dims.push_back(uint32Value);
+        }
+    }
+
+    // Ensure that the axes are sorted
+    std::sort(dims.begin(), dims.end());
+
+    std::vector<unsigned int> targetShape;
+
+    for(uint i = 0; i < inputShape.GetNumDimensions(); i++)
+    {
+        targetShape.push_back(inputShape[i]);
+    }
+
+    for(uint i = 0; i < dims.size(); i++)
+    {
+        targetShape.insert(targetShape.begin() + armnn::numeric_cast<int>(dims[i]), 1);
+    }
+
+    auto outInfo = ComputeReshapeInfo(TensorShape(armnn::numeric_cast<unsigned int>(targetShape.size()),
+                                                  targetShape.data()), inputShape, node.output(0));
+    m_TensorsInfo[node.output(0)].m_info = std::make_unique<TensorInfo>(outInfo);
+
+    CreateReshapeLayer(node.input(0), node.output(0), node.name());
+}
+
 void OnnxParserImpl::PrependForBroadcast(const std::string& outputName,
                                          const std::string& input0,
                                          const std::string& input1)
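
Note: for readers tracing the new function, here is its core shape computation as a standalone sketch. The helper name UnsqueezeShape and the use of plain std::vector are illustrative assumptions; the parser itself works on armnn::TensorShape and defers validation and layer creation to ComputeReshapeInfo and CreateReshapeLayer.

#include <algorithm>
#include <cstdint>
#include <iostream>
#include <vector>

// Compute the Unsqueeze output shape: insert a size-1 dimension at each axis.
// ONNX Unsqueeze axes index into the *output* shape, so sorting them first
// makes every insertion index valid in the partially expanded vector.
std::vector<unsigned int> UnsqueezeShape(std::vector<unsigned int> shape,
                                         std::vector<uint32_t> axes)
{
    std::sort(axes.begin(), axes.end());
    for (uint32_t axis : axes)
    {
        // Earlier insertions shift later dimensions right; ascending order
        // already accounts for that shift.
        shape.insert(shape.begin() + static_cast<int>(axis), 1u);
    }
    return shape;
}

int main()
{
    // Input shape [3, 4] with axes {0, 2} yields [1, 3, 1, 4].
    for (unsigned int d : UnsqueezeShape({3, 4}, {2, 0}))
    {
        std::cout << d << " ";
    }
    std::cout << "\n";
}

Because an Unsqueeze never moves or copies data, the commit lowers the whole operator to a single ArmNN Reshape layer once this target shape is known.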