diff options
author | Narumol Prangnawarat <narumol.prangnawarat@arm.com> | 2021-09-17 21:08:57 +0100 |
---|---|---|
committer | Jim Flynn <jim.flynn@arm.com> | 2021-09-24 14:17:31 +0000 |
commit | f10b15a8946f39bdf3f60cebc59d2963069eedca (patch) | |
tree | 9cba39db69acad2bd5728cefbad578161e6ba63c /src/armnnOnnxParser/OnnxParser.cpp | |
parent | 4fcc8632aaa64e683d98199659093d1aa99ffb08 (diff) | |
download | armnn-f10b15a8946f39bdf3f60cebc59d2963069eedca.tar.gz |
IVGCVSW-6382 Add Gather operator support to ONNX parser
* Add ParseGather to support Gather operator on ONNX
* Add Support of int64 converted to int32 for constant
* Add OnnxParserTestUtils
* Refactor ValidateTensorShapesFromInputs of GatherLayer
* Unit tests
Signed-off-by: Narumol Prangnawarat <narumol.prangnawarat@arm.com>
Change-Id: Ie9dff640240e14a062fef38f7faf0ccc212de5f7
Diffstat (limited to 'src/armnnOnnxParser/OnnxParser.cpp')
-rw-r--r-- | src/armnnOnnxParser/OnnxParser.cpp | 115 |
1 files changed, 109 insertions, 6 deletions
diff --git a/src/armnnOnnxParser/OnnxParser.cpp b/src/armnnOnnxParser/OnnxParser.cpp index 889c35f391..e70eb64047 100644 --- a/src/armnnOnnxParser/OnnxParser.cpp +++ b/src/armnnOnnxParser/OnnxParser.cpp @@ -427,7 +427,8 @@ const std::map<std::string, OnnxParserImpl::OperationParsingFunction> OnnxParser { "Conv", &OnnxParserImpl::ParseConv }, { "Add", &OnnxParserImpl::ParseAdd }, { "Flatten", &OnnxParserImpl::ParseFlatten }, - { "Shape", &OnnxParserImpl::ParseShape } + { "Shape", &OnnxParserImpl::ParseShape }, + { "Gather", &OnnxParserImpl::ParseGather }, }; template<typename TypePair, typename Location> @@ -533,6 +534,10 @@ OnnxParserImpl::CreateConstTensor(const std::string name, TensorInfo tensorInfo = *m_TensorsInfo[name].m_info; onnx::TensorProto onnxTensor = *m_TensorsInfo[name].m_tensor; + //ONNX can have Float16 and double constant nodes but ArmNN only supports float32 + CHECK_VALID_DATATYPE(name, onnxTensor.name(), + static_cast<onnx::TensorProto::DataType>(onnxTensor.data_type()), onnx::TensorProto::FLOAT); + // Makes sure IsConstant flag is set. tensorInfo.SetConstant(); @@ -568,6 +573,65 @@ OnnxParserImpl::CreateConstTensor(const std::string name, } } +std::pair<ConstTensor, std::unique_ptr<int32_t[]>> +OnnxParserImpl::CreateInt64ConstTensor(const std::string name, + armnn::Optional<armnn::PermutationVector&> permutationVector) +{ + TensorInfo tensorInfo = *m_TensorsInfo[name].m_info; + onnx::TensorProto onnxTensor = *m_TensorsInfo[name].m_tensor; + + CHECK_VALID_DATATYPE(name, onnxTensor.name(), + static_cast<onnx::TensorProto::DataType>(onnxTensor.data_type()), onnx::TensorProto::INT64); + + // Makes sure IsConstant flag is set. + tensorInfo.SetConstant(); + uint numElements = tensorInfo.GetNumElements(); + + // Const tensors requires at least a list of values + if (numElements == 0) + { + throw ParseException(fmt::format("No tensor data found for Const tensor '{}' {}", + name, + CHECK_LOCATION().AsString())); + } + + // Copy the value list entries into the destination + if (!onnxTensor.has_raw_data()) + { + auto srcData = onnxTensor.int64_data().data(); + if(numElements != static_cast<uint>(onnxTensor.int64_data_size())) + { + throw ParseException( + fmt::format("The number of data provided ({}) does not match the tensor '{}' number of " + "elements ({}) {}", + onnxTensor.int64_data_size(), + name, + tensorInfo.GetNumElements(), + CHECK_LOCATION().AsString())); + } + + std::vector<int32_t> int32Data; + for(uint i = 0; i < numElements; i++) + { + int32_t int32Value = CHECKED_INT32(srcData[i]); + int32Data.push_back(int32Value); + } + + return CreateConstTensorImpl<int32_t>(int32Data.data(), tensorInfo, permutationVector); + } + else + { + auto srcData = reinterpret_cast<const int64_t*>(onnxTensor.raw_data().c_str()); + std::vector<int32_t> int32Data; + for(uint i = 0; i < numElements; i++) + { + int32_t int32Value = CHECKED_INT32(srcData[i]); + int32Data.push_back(int32Value); + } + return CreateConstTensorImpl<int32_t>(int32Data.data(), tensorInfo, permutationVector); + } +} + ModelPtr OnnxParserImpl::LoadModelFromTextFile(const char* graphFile) { FILE* fd = fopen(graphFile, "r"); @@ -1152,7 +1216,14 @@ std::pair<std::string, std::string> OnnxParserImpl::AddPrepareBroadcast(const st void OnnxParserImpl::CreateConstantLayer(const std::string& tensorName, const std::string& layerName) { auto armnnTensor = CreateConstTensor(tensorName); + IConnectableLayer* layer = m_Network->AddConstantLayer(armnnTensor.first, layerName.c_str()); + layer->GetOutputSlot(0).SetTensorInfo(armnnTensor.first.GetInfo()); + RegisterOutputSlots(layer, {tensorName}); +} +void OnnxParserImpl::CreateInt64ConstantLayer(const std::string& tensorName, const std::string& layerName) +{ + auto armnnTensor = CreateInt64ConstTensor(tensorName); IConnectableLayer* layer = m_Network->AddConstantLayer(armnnTensor.first, layerName.c_str()); layer->GetOutputSlot(0).SetTensorInfo(armnnTensor.first.GetInfo()); RegisterOutputSlots(layer, {tensorName}); @@ -1370,16 +1441,25 @@ void OnnxParserImpl::ParseConstant(const onnx::NodeProto& node) } const onnx::TensorProto& onnxTensor = node.attribute(0).t(); - //ONNX can have Float16 and double constant nodes but ArmNN only supports float32 - CHECK_VALID_DATATYPE(node.name(), onnxTensor.name(), - static_cast<onnx::TensorProto::DataType>(onnxTensor.data_type()), onnx::TensorProto::FLOAT); - //Register this as a m_ConstParam so we know we can use it as a constant param in future layers. m_TensorsInfo[node.output(0)].m_tensor = std::make_unique<const onnx::TensorProto>(onnxTensor); m_TensorsInfo[node.output(0)].m_info = std::make_unique<TensorInfo>(ToTensorInfo(onnxTensor)); m_TensorsInfo[node.output(0)].m_dtype = static_cast<onnx::TensorProto::DataType>(onnxTensor.data_type()); - CreateConstantLayer(node.output(0), node.name()); + if (m_TensorsInfo[node.output(0)].m_dtype == onnx::TensorProto_DataType_FLOAT) + { + CreateConstantLayer(node.output(0), node.name()); + } + else if (m_TensorsInfo[node.output(0)].m_dtype == onnx::TensorProto_DataType_INT64) + { + CreateInt64ConstantLayer(node.output(0), node.name()); + } + else + { + throw ParseException(fmt::format("Data type not support for Constant node '{}' {}", + node.name(), + CHECK_LOCATION().AsString())); + } } void OnnxParserImpl::ParseConv(const onnx::NodeProto& node) @@ -1622,6 +1702,29 @@ void OnnxParserImpl::ParseFlatten(const onnx::NodeProto& node) CreateReshapeLayer(node.input(0), node.output(0), node.name()); } +void OnnxParserImpl::ParseGather(const onnx::NodeProto& node) +{ + CHECK_VALID_SIZE(static_cast<size_t>(node.input_size()), 2); + CHECK_VALID_SIZE(static_cast<size_t>(node.output_size()), 1); + + armnn::GatherDescriptor gatherDescriptor; + gatherDescriptor.m_Axis = static_cast<int>(ReadOptionalNodeInt64Attribute(node, "axis", 0)); + + IConnectableLayer* layer = m_Network->AddGatherLayer(gatherDescriptor, node.name().c_str()); + ARMNN_ASSERT(layer != nullptr); + + TensorShape inputShape = m_TensorsInfo[node.input(0)].m_info->GetShape(); + TensorShape indicesShape = m_TensorsInfo[node.input(1)].m_info->GetShape(); + auto outputInfo = ComputeOutputInfo({node.output(0)}, layer, { inputShape, indicesShape }); + layer->GetOutputSlot(0).SetTensorInfo(outputInfo[0]); + + // register the input connection slots for the layer, connections are made after all layers have been created + RegisterInputSlots(layer, { node.input(0), node.input(1) }); + + // register the output connection slots for the layer, connections are made after all layers have been created + RegisterOutputSlots(layer, { node.output(0) }); +} + void OnnxParserImpl::ParseGlobalAveragePool(const onnx::NodeProto& node) { Pooling2dDescriptor desc = Pooling2dDescriptor(); |