From f10b15a8946f39bdf3f60cebc59d2963069eedca Mon Sep 17 00:00:00 2001 From: Narumol Prangnawarat Date: Fri, 17 Sep 2021 21:08:57 +0100 Subject: IVGCVSW-6382 Add Gather operator support to ONNX parser * Add ParseGather to support Gather operator on ONNX * Add Support of int64 converted to int32 for constant * Add OnnxParserTestUtils * Refactor ValidateTensorShapesFromInputs of GatherLayer * Unit tests Signed-off-by: Narumol Prangnawarat Change-Id: Ie9dff640240e14a062fef38f7faf0ccc212de5f7 --- src/armnnOnnxParser/OnnxParser.cpp | 115 +++++++++++++++++++++++++++++++++++-- 1 file changed, 109 insertions(+), 6 deletions(-) (limited to 'src/armnnOnnxParser/OnnxParser.cpp') diff --git a/src/armnnOnnxParser/OnnxParser.cpp b/src/armnnOnnxParser/OnnxParser.cpp index 889c35f391..e70eb64047 100644 --- a/src/armnnOnnxParser/OnnxParser.cpp +++ b/src/armnnOnnxParser/OnnxParser.cpp @@ -427,7 +427,8 @@ const std::map OnnxParser { "Conv", &OnnxParserImpl::ParseConv }, { "Add", &OnnxParserImpl::ParseAdd }, { "Flatten", &OnnxParserImpl::ParseFlatten }, - { "Shape", &OnnxParserImpl::ParseShape } + { "Shape", &OnnxParserImpl::ParseShape }, + { "Gather", &OnnxParserImpl::ParseGather }, }; template @@ -533,6 +534,10 @@ OnnxParserImpl::CreateConstTensor(const std::string name, TensorInfo tensorInfo = *m_TensorsInfo[name].m_info; onnx::TensorProto onnxTensor = *m_TensorsInfo[name].m_tensor; + //ONNX can have Float16 and double constant nodes but ArmNN only supports float32 + CHECK_VALID_DATATYPE(name, onnxTensor.name(), + static_cast(onnxTensor.data_type()), onnx::TensorProto::FLOAT); + // Makes sure IsConstant flag is set. 
tensorInfo.SetConstant(); @@ -568,6 +573,65 @@ OnnxParserImpl::CreateConstTensor(const std::string name, } } +std::pair> +OnnxParserImpl::CreateInt64ConstTensor(const std::string name, + armnn::Optional permutationVector) +{ + TensorInfo tensorInfo = *m_TensorsInfo[name].m_info; + onnx::TensorProto onnxTensor = *m_TensorsInfo[name].m_tensor; + + CHECK_VALID_DATATYPE(name, onnxTensor.name(), + static_cast(onnxTensor.data_type()), onnx::TensorProto::INT64); + + // Makes sure IsConstant flag is set. + tensorInfo.SetConstant(); + uint numElements = tensorInfo.GetNumElements(); + + // Const tensors requires at least a list of values + if (numElements == 0) + { + throw ParseException(fmt::format("No tensor data found for Const tensor '{}' {}", + name, + CHECK_LOCATION().AsString())); + } + + // Copy the value list entries into the destination + if (!onnxTensor.has_raw_data()) + { + auto srcData = onnxTensor.int64_data().data(); + if(numElements != static_cast(onnxTensor.int64_data_size())) + { + throw ParseException( + fmt::format("The number of data provided ({}) does not match the tensor '{}' number of " + "elements ({}) {}", + onnxTensor.int64_data_size(), + name, + tensorInfo.GetNumElements(), + CHECK_LOCATION().AsString())); + } + + std::vector int32Data; + for(uint i = 0; i < numElements; i++) + { + int32_t int32Value = CHECKED_INT32(srcData[i]); + int32Data.push_back(int32Value); + } + + return CreateConstTensorImpl(int32Data.data(), tensorInfo, permutationVector); + } + else + { + auto srcData = reinterpret_cast(onnxTensor.raw_data().c_str()); + std::vector int32Data; + for(uint i = 0; i < numElements; i++) + { + int32_t int32Value = CHECKED_INT32(srcData[i]); + int32Data.push_back(int32Value); + } + return CreateConstTensorImpl(int32Data.data(), tensorInfo, permutationVector); + } +} + ModelPtr OnnxParserImpl::LoadModelFromTextFile(const char* graphFile) { FILE* fd = fopen(graphFile, "r"); @@ -1152,7 +1216,14 @@ std::pair 
OnnxParserImpl::AddPrepareBroadcast(const st void OnnxParserImpl::CreateConstantLayer(const std::string& tensorName, const std::string& layerName) { auto armnnTensor = CreateConstTensor(tensorName); + IConnectableLayer* layer = m_Network->AddConstantLayer(armnnTensor.first, layerName.c_str()); + layer->GetOutputSlot(0).SetTensorInfo(armnnTensor.first.GetInfo()); + RegisterOutputSlots(layer, {tensorName}); +} +void OnnxParserImpl::CreateInt64ConstantLayer(const std::string& tensorName, const std::string& layerName) +{ + auto armnnTensor = CreateInt64ConstTensor(tensorName); IConnectableLayer* layer = m_Network->AddConstantLayer(armnnTensor.first, layerName.c_str()); layer->GetOutputSlot(0).SetTensorInfo(armnnTensor.first.GetInfo()); RegisterOutputSlots(layer, {tensorName}); @@ -1370,16 +1441,25 @@ void OnnxParserImpl::ParseConstant(const onnx::NodeProto& node) } const onnx::TensorProto& onnxTensor = node.attribute(0).t(); - //ONNX can have Float16 and double constant nodes but ArmNN only supports float32 - CHECK_VALID_DATATYPE(node.name(), onnxTensor.name(), - static_cast(onnxTensor.data_type()), onnx::TensorProto::FLOAT); - //Register this as a m_ConstParam so we know we can use it as a constant param in future layers. 
m_TensorsInfo[node.output(0)].m_tensor = std::make_unique<onnx::TensorProto>(onnxTensor); m_TensorsInfo[node.output(0)].m_info = std::make_unique<TensorInfo>(ToTensorInfo(onnxTensor)); m_TensorsInfo[node.output(0)].m_dtype = static_cast<onnx::TensorProto::DataType>(onnxTensor.data_type()); - CreateConstantLayer(node.output(0), node.name()); + if (m_TensorsInfo[node.output(0)].m_dtype == onnx::TensorProto_DataType_FLOAT) + { + CreateConstantLayer(node.output(0), node.name()); + } + else if (m_TensorsInfo[node.output(0)].m_dtype == onnx::TensorProto_DataType_INT64) + { + CreateInt64ConstantLayer(node.output(0), node.name()); + } + else + { + throw ParseException(fmt::format("Data type not supported for Constant node '{}' {}", + node.name(), + CHECK_LOCATION().AsString())); + } } void OnnxParserImpl::ParseConv(const onnx::NodeProto& node) @@ -1622,6 +1702,29 @@ void OnnxParserImpl::ParseFlatten(const onnx::NodeProto& node) CreateReshapeLayer(node.input(0), node.output(0), node.name()); } +void OnnxParserImpl::ParseGather(const onnx::NodeProto& node) +{ + CHECK_VALID_SIZE(static_cast<size_t>(node.input_size()), 2); + CHECK_VALID_SIZE(static_cast<size_t>(node.output_size()), 1); + + armnn::GatherDescriptor gatherDescriptor; + gatherDescriptor.m_Axis = static_cast<int>(ReadOptionalNodeInt64Attribute(node, "axis", 0)); + + IConnectableLayer* layer = m_Network->AddGatherLayer(gatherDescriptor, node.name().c_str()); + ARMNN_ASSERT(layer != nullptr); + + TensorShape inputShape = m_TensorsInfo[node.input(0)].m_info->GetShape(); + TensorShape indicesShape = m_TensorsInfo[node.input(1)].m_info->GetShape(); + auto outputInfo = ComputeOutputInfo({node.output(0)}, layer, { inputShape, indicesShape }); + layer->GetOutputSlot(0).SetTensorInfo(outputInfo[0]); + + // register the input connection slots for the layer, connections are made after all layers have been created + RegisterInputSlots(layer, { node.input(0), node.input(1) }); + + // register the output connection slots for the layer, connections are made after all layers have been created + 
RegisterOutputSlots(layer, { node.output(0) }); +} + void OnnxParserImpl::ParseGlobalAveragePool(const onnx::NodeProto& node) { Pooling2dDescriptor desc = Pooling2dDescriptor(); -- cgit v1.2.1