author     Narumol Prangnawarat <narumol.prangnawarat@arm.com>  2021-09-17 21:08:57 +0100
committer  Jim Flynn <jim.flynn@arm.com>                        2021-09-24 14:17:31 +0000
commit     f10b15a8946f39bdf3f60cebc59d2963069eedca (patch)
tree       9cba39db69acad2bd5728cefbad578161e6ba63c /src/armnnOnnxParser/OnnxParser.cpp
parent     4fcc8632aaa64e683d98199659093d1aa99ffb08 (diff)
IVGCVSW-6382 Add Gather operator support to ONNX parser
* Add ParseGather to support Gather operator on ONNX
* Add support of int64 converted to int32 for constant
* Add OnnxParserTestUtils
* Refactor ValidateTensorShapesFromInputs of GatherLayer
* Unit tests

Signed-off-by: Narumol Prangnawarat <narumol.prangnawarat@arm.com>
Change-Id: Ie9dff640240e14a062fef38f7faf0ccc212de5f7
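For context, the practical effect of this patch is that a model containing a Gather node can now be loaded through the public parser API. A minimal sketch of that flow, assuming the standard armnnOnnxParser headers and a hypothetical "model.onnx" file that contains a Gather node:

    #include <armnn/INetwork.hpp>
    #include <armnnOnnxParser/IOnnxParser.hpp>

    armnn::INetworkPtr LoadGatherModel()
    {
        // The parser now routes ONNX "Gather" nodes to ParseGather internally.
        armnnOnnxParser::IOnnxParserPtr parser = armnnOnnxParser::IOnnxParser::Create();
        return parser->CreateNetworkFromBinaryFile("model.onnx");
    }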
Diffstat (limited to 'src/armnnOnnxParser/OnnxParser.cpp')
-rw-r--r--  src/armnnOnnxParser/OnnxParser.cpp | 115
1 file changed, 109 insertions(+), 6 deletions(-)
diff --git a/src/armnnOnnxParser/OnnxParser.cpp b/src/armnnOnnxParser/OnnxParser.cpp
index 889c35f391..e70eb64047 100644
--- a/src/armnnOnnxParser/OnnxParser.cpp
+++ b/src/armnnOnnxParser/OnnxParser.cpp
@@ -427,7 +427,8 @@ const std::map<std::string, OnnxParserImpl::OperationParsingFunction> OnnxParser
{ "Conv", &OnnxParserImpl::ParseConv },
{ "Add", &OnnxParserImpl::ParseAdd },
{ "Flatten", &OnnxParserImpl::ParseFlatten },
- { "Shape", &OnnxParserImpl::ParseShape }
+ { "Shape", &OnnxParserImpl::ParseShape },
+ { "Gather", &OnnxParserImpl::ParseGather },
};
template<typename TypePair, typename Location>
@@ -533,6 +534,10 @@ OnnxParserImpl::CreateConstTensor(const std::string name,
TensorInfo tensorInfo = *m_TensorsInfo[name].m_info;
onnx::TensorProto onnxTensor = *m_TensorsInfo[name].m_tensor;
+ // ONNX can have float16 and double constant nodes, but ArmNN only supports float32
+ CHECK_VALID_DATATYPE(name, onnxTensor.name(),
+ static_cast<onnx::TensorProto::DataType>(onnxTensor.data_type()), onnx::TensorProto::FLOAT);
+
// Makes sure IsConstant flag is set.
tensorInfo.SetConstant();
@@ -568,6 +573,65 @@ OnnxParserImpl::CreateConstTensor(const std::string name,
}
}
+std::pair<ConstTensor, std::unique_ptr<int32_t[]>>
+OnnxParserImpl::CreateInt64ConstTensor(const std::string name,
+ armnn::Optional<armnn::PermutationVector&> permutationVector)
+{
+ TensorInfo tensorInfo = *m_TensorsInfo[name].m_info;
+ onnx::TensorProto onnxTensor = *m_TensorsInfo[name].m_tensor;
+
+ CHECK_VALID_DATATYPE(name, onnxTensor.name(),
+ static_cast<onnx::TensorProto::DataType>(onnxTensor.data_type()), onnx::TensorProto::INT64);
+
+ // Makes sure IsConstant flag is set.
+ tensorInfo.SetConstant();
+ uint numElements = tensorInfo.GetNumElements();
+
+ // Const tensors require at least a list of values
+ if (numElements == 0)
+ {
+ throw ParseException(fmt::format("No tensor data found for Const tensor '{}' {}",
+ name,
+ CHECK_LOCATION().AsString()));
+ }
+
+ // Copy the value list entries into the destination
+ if (!onnxTensor.has_raw_data())
+ {
+ auto srcData = onnxTensor.int64_data().data();
+ if(numElements != static_cast<uint>(onnxTensor.int64_data_size()))
+ {
+ throw ParseException(
+ fmt::format("The number of data provided ({}) does not match the tensor '{}' number of "
+ "elements ({}) {}",
+ onnxTensor.int64_data_size(),
+ name,
+ tensorInfo.GetNumElements(),
+ CHECK_LOCATION().AsString()));
+ }
+
+ std::vector<int32_t> int32Data;
+ for(uint i = 0; i < numElements; i++)
+ {
+ int32_t int32Value = CHECKED_INT32(srcData[i]);
+ int32Data.push_back(int32Value);
+ }
+
+ return CreateConstTensorImpl<int32_t>(int32Data.data(), tensorInfo, permutationVector);
+ }
+ else
+ {
+ auto srcData = reinterpret_cast<const int64_t*>(onnxTensor.raw_data().c_str());
+ std::vector<int32_t> int32Data;
+ for(uint i = 0; i < numElements; i++)
+ {
+ int32_t int32Value = CHECKED_INT32(srcData[i]);
+ int32Data.push_back(int32Value);
+ }
+ return CreateConstTensorImpl<int32_t>(int32Data.data(), tensorInfo, permutationVector);
+ }
+}
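CreateInt64ConstTensor narrows each int64 value to int32 because ArmNN works with int32 tensor data; the CHECKED_INT32 macro is assumed to reject values outside the int32 range rather than silently wrapping. A standalone sketch of that guard (the macro's actual definition lives elsewhere in the parser sources):

    #include <cstdint>
    #include <limits>
    #include <stdexcept>

    int32_t CheckedInt32(int64_t value)
    {
        // Reject anything that cannot be represented exactly as int32.
        if (value > std::numeric_limits<int32_t>::max() ||
            value < std::numeric_limits<int32_t>::min())
        {
            throw std::out_of_range("int64 constant value does not fit in int32");
        }
        return static_cast<int32_t>(value);
    }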
+
ModelPtr OnnxParserImpl::LoadModelFromTextFile(const char* graphFile)
{
FILE* fd = fopen(graphFile, "r");
@@ -1152,7 +1216,14 @@ std::pair<std::string, std::string> OnnxParserImpl::AddPrepareBroadcast(const st
void OnnxParserImpl::CreateConstantLayer(const std::string& tensorName, const std::string& layerName)
{
auto armnnTensor = CreateConstTensor(tensorName);
+ IConnectableLayer* layer = m_Network->AddConstantLayer(armnnTensor.first, layerName.c_str());
+ layer->GetOutputSlot(0).SetTensorInfo(armnnTensor.first.GetInfo());
+ RegisterOutputSlots(layer, {tensorName});
+}
+void OnnxParserImpl::CreateInt64ConstantLayer(const std::string& tensorName, const std::string& layerName)
+{
+ auto armnnTensor = CreateInt64ConstTensor(tensorName);
IConnectableLayer* layer = m_Network->AddConstantLayer(armnnTensor.first, layerName.c_str());
layer->GetOutputSlot(0).SetTensorInfo(armnnTensor.first.GetInfo());
RegisterOutputSlots(layer, {tensorName});
@@ -1370,16 +1441,25 @@ void OnnxParserImpl::ParseConstant(const onnx::NodeProto& node)
}
const onnx::TensorProto& onnxTensor = node.attribute(0).t();
- //ONNX can have Float16 and double constant nodes but ArmNN only supports float32
- CHECK_VALID_DATATYPE(node.name(), onnxTensor.name(),
- static_cast<onnx::TensorProto::DataType>(onnxTensor.data_type()), onnx::TensorProto::FLOAT);
-
//Register this as a m_ConstParam so we know we can use it as a constant param in future layers.
m_TensorsInfo[node.output(0)].m_tensor = std::make_unique<const onnx::TensorProto>(onnxTensor);
m_TensorsInfo[node.output(0)].m_info = std::make_unique<TensorInfo>(ToTensorInfo(onnxTensor));
m_TensorsInfo[node.output(0)].m_dtype = static_cast<onnx::TensorProto::DataType>(onnxTensor.data_type());
- CreateConstantLayer(node.output(0), node.name());
+ if (m_TensorsInfo[node.output(0)].m_dtype == onnx::TensorProto_DataType_FLOAT)
+ {
+ CreateConstantLayer(node.output(0), node.name());
+ }
+ else if (m_TensorsInfo[node.output(0)].m_dtype == onnx::TensorProto_DataType_INT64)
+ {
+ CreateInt64ConstantLayer(node.output(0), node.name());
+ }
+ else
+ {
+ throw ParseException(fmt::format("Data type not support for Constant node '{}' {}",
+ node.name(),
+ CHECK_LOCATION().AsString()));
+ }
}
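With this dispatch in place, ParseConstant accepts FLOAT and INT64 constants and rejects everything else. For illustration, a hedged sketch of the kind of int64 constant the new branch handles, built directly with the generated onnx protobuf classes (header path assumed to match the parser's own build):

    #include <onnx/onnx.pb.h>

    onnx::TensorProto MakeInt64Scalar(int64_t value)
    {
        onnx::TensorProto tensor;
        tensor.set_data_type(onnx::TensorProto::INT64);
        tensor.add_dims(1);
        // Populates int64_data, i.e. the non-raw-data path in CreateInt64ConstTensor.
        tensor.add_int64_data(value);
        return tensor;
    }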
void OnnxParserImpl::ParseConv(const onnx::NodeProto& node)
@@ -1622,6 +1702,29 @@ void OnnxParserImpl::ParseFlatten(const onnx::NodeProto& node)
CreateReshapeLayer(node.input(0), node.output(0), node.name());
}
+void OnnxParserImpl::ParseGather(const onnx::NodeProto& node)
+{
+ CHECK_VALID_SIZE(static_cast<size_t>(node.input_size()), 2);
+ CHECK_VALID_SIZE(static_cast<size_t>(node.output_size()), 1);
+
+ armnn::GatherDescriptor gatherDescriptor;
+ gatherDescriptor.m_Axis = static_cast<int>(ReadOptionalNodeInt64Attribute(node, "axis", 0));
+
+ IConnectableLayer* layer = m_Network->AddGatherLayer(gatherDescriptor, node.name().c_str());
+ ARMNN_ASSERT(layer != nullptr);
+
+ TensorShape inputShape = m_TensorsInfo[node.input(0)].m_info->GetShape();
+ TensorShape indicesShape = m_TensorsInfo[node.input(1)].m_info->GetShape();
+ auto outputInfo = ComputeOutputInfo({node.output(0)}, layer, { inputShape, indicesShape });
+ layer->GetOutputSlot(0).SetTensorInfo(outputInfo[0]);
+
+ // register the input connection slots for the layer, connections are made after all layers have been created
+ RegisterInputSlots(layer, { node.input(0), node.input(1) });
+
+ // register the output connection slots for the layer, connections are made after all layers have been created
+ RegisterOutputSlots(layer, { node.output(0) });
+}
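ComputeOutputInfo is expected to apply the standard ONNX Gather shape rule here: the axis dimension of the input is replaced by the full indices shape, so rank(output) = rank(input) + rank(indices) - 1. A sketch of that rule, assuming these semantics (e.g. input {3,4,5} with indices {2,2} on axis 1 yields {3,2,2,5}):

    #include <cstddef>
    #include <vector>

    std::vector<unsigned int> GatherOutputShape(const std::vector<unsigned int>& inputShape,
                                                const std::vector<unsigned int>& indicesShape,
                                                std::size_t axis)
    {
        // Dimensions before the axis, then the indices shape, then the rest.
        std::vector<unsigned int> out(inputShape.begin(),
                                      inputShape.begin() + static_cast<std::ptrdiff_t>(axis));
        out.insert(out.end(), indicesShape.begin(), indicesShape.end());
        out.insert(out.end(),
                   inputShape.begin() + static_cast<std::ptrdiff_t>(axis) + 1,
                   inputShape.end());
        return out;
    }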
+
void OnnxParserImpl::ParseGlobalAveragePool(const onnx::NodeProto& node)
{
Pooling2dDescriptor desc = Pooling2dDescriptor();