author | Narumol Prangnawarat <narumol.prangnawarat@arm.com> | 2021-09-23 16:12:19 +0100
---|---|---
committer | Narumol Prangnawarat <narumol.prangnawarat@arm.com> | 2021-10-07 14:43:09 +0000
commit | 452274c86245082ce20563ede12b92af81dba38a (patch) |
tree | 79718c6cf86acbb21138068c17aae15c4b172306 /src/armnnOnnxParser/OnnxParser.cpp |
parent | 4d217c02fe2c0a32ff9da69d8fe375a75173c0f3 (diff) |
download | armnn-452274c86245082ce20563ede12b92af81dba38a.tar.gz |
IVGCVSW-6459 Add support of scalar and flexible output datatypes to ONNX parser
Signed-off-by: Narumol Prangnawarat <narumol.prangnawarat@arm.com>
Change-Id: Id1e933f6ae55ddc1a57c80c9f6a5757ccb61f018
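For context, a minimal sketch (not part of this commit) of the user-visible effect at the public parser API; the model path and output name below are hypothetical placeholders. After this change, an ONNX output such as the INT64 result of a Shape node parses as a Signed32 binding instead of being rejected or assumed Float32, and genuinely scalar tensors keep a scalar TensorInfo instead of being padded to shape {1}:

```cpp
// Sketch only: "model.onnx" and "output" are hypothetical, not from the commit.
#include <armnn/ArmNN.hpp>
#include <armnnOnnxParser/IOnnxParser.hpp>
#include <iostream>

int main()
{
    // Parse an ONNX model; the parser no longer requires FLOAT everywhere.
    armnnOnnxParser::IOnnxParserPtr parser = armnnOnnxParser::IOnnxParser::Create();
    armnn::INetworkPtr network = parser->CreateNetworkFromBinaryFile("model.onnx"); // hypothetical path

    // The output binding now carries the datatype mapped from the ONNX model
    // (FLOAT -> Float32, INT32/INT64 -> Signed32) rather than assuming Float32.
    armnnOnnxParser::BindingPointInfo binding = parser->GetNetworkOutputBindingInfo("output"); // hypothetical name
    const armnn::TensorInfo& info = binding.second;
    std::cout << std::boolalpha
              << "signed32 output: "
              << (info.GetDataType() == armnn::DataType::Signed32) << "\n"
              << "scalar output:   "
              << (info.GetShape().GetDimensionality() == armnn::Dimensionality::Scalar) << "\n";
    return 0;
}
```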
Diffstat (limited to 'src/armnnOnnxParser/OnnxParser.cpp')
-rw-r--r-- | src/armnnOnnxParser/OnnxParser.cpp | 165 |
1 file changed, 87 insertions(+), 78 deletions(-)
```diff
diff --git a/src/armnnOnnxParser/OnnxParser.cpp b/src/armnnOnnxParser/OnnxParser.cpp
index 3fcb7ab603..6caf690935 100644
--- a/src/armnnOnnxParser/OnnxParser.cpp
+++ b/src/armnnOnnxParser/OnnxParser.cpp
@@ -262,38 +262,38 @@ std::string ReadOptionalNodeStringAttribute(const onnx::NodeProto& node, const s
 
 armnn::TensorInfo ToTensorInfo(const std::string& name, std::vector<unsigned int>& shape, int data_type)
 {
-  DataType type;
-  switch(data_type)
-  {
-      case onnx::TensorProto::FLOAT:
-      {
-          type = DataType::Float32;
-          break;
-      }
-      case onnx::TensorProto::INT32:
-      case onnx::TensorProto::INT64:
-      {
-          type = DataType::Signed32;
+    DataType type;
+    switch(data_type)
+    {
+        case onnx::TensorProto::FLOAT:
+        {
+            type = DataType::Float32;
             break;
-      }
-      default:
-      {
-          throw ParseException(
-              fmt::format("'{}' is not a currently supported datatype for tensor {}."
-                          " Supported dataTypes are FLOAT, INT32 and INT64. {}",
-                          onnx::TensorProto::DataType_Name(static_cast<onnx::TensorProto::DataType>(data_type)),
-                          name,
-                          CHECK_LOCATION().AsString() ));
-      }
-  }
+        }
+        case onnx::TensorProto::INT32:
+        case onnx::TensorProto::INT64:
+        {
+            type = DataType::Signed32;
+            break;
+        }
+        default:
+        {
+            throw ParseException(
+                fmt::format("'{}' is not a currently supported datatype for tensor {}."
+                            " Supported dataTypes are FLOAT, INT32 and INT64. {}",
+                            onnx::TensorProto::DataType_Name(static_cast<onnx::TensorProto::DataType>(data_type)),
+                            name,
+                            CHECK_LOCATION().AsString() ));
+        }
+    }
 
-  // To avoid crashes by trivial tensors
-  if (shape.empty())
-  {
-      return TensorInfo(TensorShape(), type);
-  }
+    // To avoid crashes by trivial tensors
+    if (shape.empty())
+    {
+        return TensorInfo(TensorShape(Dimensionality::Scalar), type);
+    }
 
-  return TensorInfo(TensorShape(static_cast<unsigned int>(shape.size()), shape.data()), type);
+    return TensorInfo(TensorShape(static_cast<unsigned int>(shape.size()), shape.data()), type);
 }
 
 armnn::TensorInfo ToTensorInfo(const onnx::ValueInfoProto& info)
@@ -305,11 +305,6 @@ armnn::TensorInfo ToTensorInfo(const onnx::ValueInfoProto& info)
         shapeDims.push_back(CHECKED_NON_NEGATIVE(CHECKED_INT32(onnxShape.dim(i).dim_value())));
     }
 
-    if (shapeDims.empty())
-    {
-        shapeDims.push_back(1);
-    }
-
     return ToTensorInfo(info.name(), shapeDims, info.type().tensor_type().elem_type());
 }
 
@@ -322,11 +317,6 @@ armnn::TensorInfo ToTensorInfo(const onnx::TensorProto& tensor)
         shapeDims.push_back(CHECKED_NON_NEGATIVE(CHECKED_INT32(dim)));
     }
 
-    if (shapeDims.empty())
-    {
-        shapeDims.push_back(1);
-    }
-
     return ToTensorInfo(tensor.name(), shapeDims, tensor.data_type());
 }
 
@@ -376,7 +366,8 @@ void CalcPadding(uint32_t inputSize,
 
 TensorInfo ComputeReshapeInfo(const TensorShape& targetShapeTensor,
                               const TensorShape& inShape,
-                              const std::string& outName)
+                              const std::string& outName,
+                              DataType dataType = DataType::Float32)
 {
     std::vector<int> targetDims;
     for(uint i = 0; i < targetShapeTensor.GetNumDimensions(); ++i)
@@ -420,7 +411,7 @@ TensorInfo ComputeReshapeInfo(const TensorShape& targetShapeTensor,
         outDims[stretchIndex] = inShape.GetNumElements() / targetNumElements;
     }
     TensorShape outShape = TensorShape{static_cast<unsigned int>(outDims.size()), outDims.data()};
-    return TensorInfo(outShape, DataType::Float32);
+    return TensorInfo(outShape, dataType);
 }
 
 } //namespace
@@ -469,7 +460,8 @@ void OnnxParserImpl::ValidateInputs(const onnx::NodeProto& node,
 
 std::vector<TensorInfo> OnnxParserImpl::ComputeOutputInfo(std::vector<std::string> outNames,
                                                           const IConnectableLayer* layer,
-                                                          std::vector<TensorShape> inputShapes)
+                                                          std::vector<TensorShape> inputShapes,
+                                                          const onnx::TensorProto::DataType& dataType)
 {
     ARMNN_ASSERT(! outNames.empty());
     bool needCompute = std::any_of(outNames.begin(),
@@ -478,25 +470,45 @@ std::vector<TensorInfo> OnnxParserImpl::ComputeOutputInfo(std::vector<std::strin
                                    {
                                        return (m_TensorsInfo.count(name) == 0 || m_TensorsInfo[name].m_info == nullptr);
                                    });
-     std::vector<TensorInfo> outInfo;
-     //if the output info(s) are not here, we need to compute them
-     std::vector<TensorShape> inferredShapes;
-     if(needCompute)
-     {
-         inferredShapes = layer->InferOutputShapes(inputShapes);
-         ARMNN_ASSERT(inferredShapes.size() == outNames.size());
-     }
-     for (uint i = 0; i < outNames.size(); ++i)
-     {
-         if(needCompute)
-         {
-             m_TensorsInfo[outNames[i]] = OnnxTensor();
-             m_TensorsInfo[outNames[i]].m_info = std::make_unique<TensorInfo>(
-                TensorInfo(inferredShapes[i], DataType::Float32));
-         }
+    std::vector<TensorInfo> outInfo;
+    //if the output info(s) are not here, we need to compute them
+    std::vector<TensorShape> inferredShapes;
+    DataType armnnType = DataType::Float32;
+    if(needCompute) {
+        inferredShapes = layer->InferOutputShapes(inputShapes);
+        ARMNN_ASSERT(inferredShapes.size() == outNames.size());
+        switch (dataType) {
+            case onnx::TensorProto::FLOAT: {
+                armnnType = DataType::Float32;
+                break;
+            }
+            case onnx::TensorProto::INT32:
+            case onnx::TensorProto::INT64: {
+                armnnType = DataType::Signed32;
+                break;
+            }
+            default: {
+                throw ParseException(
+                    fmt::format("'{}' is not a currently supported datatype for {}."
+                                " Supported dataTypes are FLOAT, INT32 and INT64. {}",
+                                onnx::TensorProto::DataType_Name(static_cast<onnx::TensorProto::DataType>(dataType)),
+                                layer->GetName(),
+                                CHECK_LOCATION().AsString()));
+            }
+        }
+    }
+    for (uint i = 0; i < outNames.size(); ++i)
+    {
+        if(needCompute)
+        {
+            m_TensorsInfo[outNames[i]] = OnnxTensor();
+            m_TensorsInfo[outNames[i]].m_info = std::make_unique<TensorInfo>(
+                TensorInfo(inferredShapes[i], armnnType));
+            m_TensorsInfo[outNames[i]].m_dtype = dataType;
+        }
         outInfo.push_back(*m_TensorsInfo[outNames[i]].m_info);
-     }
-     return outInfo;
+    }
+    return outInfo;
 }
 
 OnnxParserImpl::OnnxParserImpl()
@@ -1480,7 +1492,8 @@ void OnnxParserImpl::ParseConcat(const onnx::NodeProto& node)
     IConnectableLayer* layer = m_Network->AddConcatLayer(concatDescriptor, node.name().c_str());
     ARMNN_ASSERT(layer != nullptr);
 
-    auto outputInfo = ComputeOutputInfo({node.output(0)}, layer, inputShapes);
+    auto outputInfo = ComputeOutputInfo({node.output(0)}, layer, inputShapes,
+                                        m_TensorsInfo[node.input(0)].m_dtype);
 
     layer->GetOutputSlot(0).SetTensorInfo(outputInfo[0]);
 
@@ -1774,9 +1787,10 @@ void OnnxParserImpl::ParseGather(const onnx::NodeProto& node)
     IConnectableLayer* layer = m_Network->AddGatherLayer(gatherDescriptor, node.name().c_str());
    ARMNN_ASSERT(layer != nullptr);
 
-    TensorShape inputShape = m_TensorsInfo[node.input(0)].m_info->GetShape();
-    TensorShape indicesShape = m_TensorsInfo[node.input(1)].m_info->GetShape();
-    auto outputInfo = ComputeOutputInfo({node.output(0)}, layer, { inputShape, indicesShape });
+    const TensorShape& inputShape = m_TensorsInfo[node.input(0)].m_info->GetShape();
+    const TensorShape& indicesShape = m_TensorsInfo[node.input(1)].m_info->GetShape();
+    auto outputInfo = ComputeOutputInfo({node.output(0)}, layer, { inputShape, indicesShape },
+                                        m_TensorsInfo[node.input(0)].m_dtype);
     layer->GetOutputSlot(0).SetTensorInfo(outputInfo[0]);
 
     // register the input connection slots for the layer, connections are made after all layers have been created
@@ -1823,16 +1837,11 @@ void OnnxParserImpl::ParseShape(const onnx::NodeProto& node)
     CHECK_VALID_SIZE(static_cast<size_t>(node.input_size()), 1);
     CHECK_VALID_SIZE(static_cast<size_t>(node.output_size()), 1);
 
-    // Output must be INT64
-    CHECK_VALID_DATATYPE(node.name(), node.output(0),
-                         m_TensorsInfo[node.output(0)].m_dtype,
-                         onnx::TensorProto::INT64);
-
     IConnectableLayer* layer = m_Network->AddShapeLayer(node.name().c_str());
     ARMNN_ASSERT(layer != nullptr);
 
     TensorShape inputShape = m_TensorsInfo[node.input(0)].m_info->GetShape();
-    auto outputInfo = ComputeOutputInfo({node.output(0)}, layer, {inputShape});
+    auto outputInfo = ComputeOutputInfo({node.output(0)}, layer, {inputShape}, onnx::TensorProto::INT64);
     layer->GetOutputSlot(0).SetTensorInfo(outputInfo[0]);
 
     // register the input connection slots for the layer, connections are made after all layers have been created
@@ -1900,10 +1909,6 @@ void OnnxParserImpl::ParseUnsqueeze(const onnx::NodeProto& node)
     CHECK_VALID_SIZE(armnn::numeric_cast<size_t>(node.input_size()), 1, 2);
     CHECK_VALID_SIZE(armnn::numeric_cast<size_t>(node.output_size()), 1);
 
-    CHECK_VALID_DATATYPE(node.name(), node.input(0),
-                         m_TensorsInfo[node.input(0)].m_dtype,
-                         onnx::TensorProto::FLOAT); //input
-
     TensorShape inputShape = m_TensorsInfo[node.input(0)].m_info->GetShape();
     std::vector<uint32_t> dims;
     if (node.input_size() == 1 && node.attribute_size() > 0)
@@ -1931,9 +1936,12 @@ void OnnxParserImpl::ParseUnsqueeze(const onnx::NodeProto& node)
 
     std::vector<unsigned int> targetShape;
-    for(uint i = 0; i < inputShape.GetNumDimensions(); i++)
+    if (inputShape.GetDimensionality() != Dimensionality::Scalar)
     {
-        targetShape.push_back(inputShape[i]);
+        for(uint i = 0; i < inputShape.GetNumDimensions(); i++)
+        {
+            targetShape.push_back(inputShape[i]);
+        }
     }
 
     for(uint i = 0; i < dims.size(); i++)
@@ -1941,9 +1949,10 @@ void OnnxParserImpl::ParseUnsqueeze(const onnx::NodeProto& node)
         targetShape.insert(targetShape.begin() + armnn::numeric_cast<int>(dims[i]), 1);
     }
 
-    auto outInfo = ComputeReshapeInfo(TensorShape(armnn::numeric_cast<unsigned int>(targetShape.size()),
-                                      targetShape.data()), inputShape, node.output(0));
+    auto outInfo = ComputeReshapeInfo(TensorShape(static_cast<unsigned int>(targetShape.size()), targetShape.data()),
+                                      inputShape, node.output(0), m_TensorsInfo[node.input(0)].m_info->GetDataType());
     m_TensorsInfo[node.output(0)].m_info = std::make_unique<TensorInfo>(outInfo);
+    m_TensorsInfo[node.output(0)].m_dtype = m_TensorsInfo[node.input(0)].m_dtype;
 
     CreateReshapeLayer(node.input(0), node.output(0), node.name());
 }
```
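As a side note, here is a minimal standalone sketch (not part of the patch) contrasting the old `shapeDims.push_back(1)` padding with the `TensorShape(Dimensionality::Scalar)` construction the patch introduces in `ToTensorInfo`; code such as `ParseUnsqueeze` above can then detect true scalars via `GetDimensionality()`:

```cpp
#include <armnn/Tensor.hpp>
#include <armnn/Types.hpp>
#include <iostream>

int main()
{
    using namespace armnn;

    // Old behaviour: an empty ONNX shape was padded to a one-element 1-D tensor.
    TensorInfo padded(TensorShape({1}), DataType::Signed32);

    // New behaviour: an empty shape is kept as a genuine scalar.
    TensorInfo scalar(TensorShape(Dimensionality::Scalar), DataType::Signed32);

    std::cout << std::boolalpha
              << "padded is scalar: "
              << (padded.GetShape().GetDimensionality() == Dimensionality::Scalar) << "\n" // false
              << "scalar is scalar: "
              << (scalar.GetShape().GetDimensionality() == Dimensionality::Scalar) << "\n" // true
              << "both hold one element: "
              << (padded.GetNumElements() == 1 && scalar.GetNumElements() == 1) << "\n";   // true
    return 0;
}
```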