From 452274c86245082ce20563ede12b92af81dba38a Mon Sep 17 00:00:00 2001 From: Narumol Prangnawarat Date: Thu, 23 Sep 2021 16:12:19 +0100 Subject: IVGCVSW-6459 Add support of scalar and flexible output datatypes to ONNX parser Signed-off-by: Narumol Prangnawarat Change-Id: Id1e933f6ae55ddc1a57c80c9f6a5757ccb61f018 --- src/armnn/layers/GatherLayer.cpp | 8 +- src/armnnOnnxParser/OnnxParser.cpp | 165 +++++++++++++++++---------------- src/armnnOnnxParser/OnnxParser.hpp | 8 +- src/armnnOnnxParser/test/Gather.cpp | 22 ++++- src/armnnOnnxParser/test/Shape.cpp | 56 +++++------ src/armnnOnnxParser/test/Unsqueeze.cpp | 14 +++ 6 files changed, 153 insertions(+), 120 deletions(-) diff --git a/src/armnn/layers/GatherLayer.cpp b/src/armnn/layers/GatherLayer.cpp index e8b67b8348..a808c42384 100644 --- a/src/armnn/layers/GatherLayer.cpp +++ b/src/armnn/layers/GatherLayer.cpp @@ -37,6 +37,11 @@ std::vector GatherLayer::InferOutputShapes(const std::vector({ TensorShape(Dimensionality::Scalar)}); + } + const unsigned int paramsDim = params.GetNumDimensions(); const unsigned int indicesDim = indices.GetNumDimensions(); const unsigned int outputDim = paramsDim - 1 + indicesDim; @@ -78,7 +83,8 @@ void GatherLayer::ValidateTensorShapesFromInputs() {GetInputSlot(0).GetConnection()->GetTensorInfo().GetShape(), GetInputSlot(1).GetConnection()->GetTensorInfo().GetShape()}); ARMNN_ASSERT(inferredShapes.size() == 1); - ARMNN_ASSERT(inferredShapes[0].GetDimensionality() == Dimensionality::Specified); + ARMNN_ASSERT(inferredShapes[0].GetDimensionality() == Dimensionality::Specified || + inferredShapes[0].GetDimensionality() == Dimensionality::Scalar); ValidateAndCopyShape(outputShape, inferredShapes[0], m_ShapeInferenceMethod, "GatherLayer"); } diff --git a/src/armnnOnnxParser/OnnxParser.cpp b/src/armnnOnnxParser/OnnxParser.cpp index 3fcb7ab603..6caf690935 100644 --- a/src/armnnOnnxParser/OnnxParser.cpp +++ b/src/armnnOnnxParser/OnnxParser.cpp @@ -262,38 +262,38 @@ std::string ReadOptionalNodeStringAttribute(const onnx::NodeProto& node, const s armnn::TensorInfo ToTensorInfo(const std::string& name, std::vector& shape, int data_type) { - DataType type; - switch(data_type) - { - case onnx::TensorProto::FLOAT: - { - type = DataType::Float32; - break; - } - case onnx::TensorProto::INT32: - case onnx::TensorProto::INT64: - { - type = DataType::Signed32; + DataType type; + switch(data_type) + { + case onnx::TensorProto::FLOAT: + { + type = DataType::Float32; break; - } - default: - { - throw ParseException( - fmt::format("'{}' is not a currently supported datatype for tensor {}." - " Supported dataTypes are FLOAT, INT32 and INT64. {}", - onnx::TensorProto::DataType_Name(static_cast(data_type)), - name, - CHECK_LOCATION().AsString() )); - } - } + } + case onnx::TensorProto::INT32: + case onnx::TensorProto::INT64: + { + type = DataType::Signed32; + break; + } + default: + { + throw ParseException( + fmt::format("'{}' is not a currently supported datatype for tensor {}." + " Supported dataTypes are FLOAT, INT32 and INT64. {}", + onnx::TensorProto::DataType_Name(static_cast(data_type)), + name, + CHECK_LOCATION().AsString() )); + } + } - // To avoid crashes by trivial tensors - if (shape.empty()) - { - return TensorInfo(TensorShape(), type); - } + // To avoid crashes by trivial tensors + if (shape.empty()) + { + return TensorInfo(TensorShape(Dimensionality::Scalar), type); + } - return TensorInfo(TensorShape(static_cast(shape.size()), shape.data()), type); + return TensorInfo(TensorShape(static_cast(shape.size()), shape.data()), type); } armnn::TensorInfo ToTensorInfo(const onnx::ValueInfoProto& info) @@ -305,11 +305,6 @@ armnn::TensorInfo ToTensorInfo(const onnx::ValueInfoProto& info) shapeDims.push_back(CHECKED_NON_NEGATIVE(CHECKED_INT32(onnxShape.dim(i).dim_value()))); } - if (shapeDims.empty()) - { - shapeDims.push_back(1); - } - return ToTensorInfo(info.name(), shapeDims, info.type().tensor_type().elem_type()); } @@ -322,11 +317,6 @@ armnn::TensorInfo ToTensorInfo(const onnx::TensorProto& tensor) shapeDims.push_back(CHECKED_NON_NEGATIVE(CHECKED_INT32(dim))); } - if (shapeDims.empty()) - { - shapeDims.push_back(1); - } - return ToTensorInfo(tensor.name(), shapeDims, tensor.data_type()); } @@ -376,7 +366,8 @@ void CalcPadding(uint32_t inputSize, TensorInfo ComputeReshapeInfo(const TensorShape& targetShapeTensor, const TensorShape& inShape, - const std::string& outName) + const std::string& outName, + DataType dataType = DataType::Float32) { std::vector targetDims; for(uint i = 0; i < targetShapeTensor.GetNumDimensions(); ++i) @@ -420,7 +411,7 @@ TensorInfo ComputeReshapeInfo(const TensorShape& targetShapeTensor, outDims[stretchIndex] = inShape.GetNumElements() / targetNumElements; } TensorShape outShape = TensorShape{static_cast(outDims.size()), outDims.data()}; - return TensorInfo(outShape, DataType::Float32); + return TensorInfo(outShape, dataType); } } //namespace @@ -469,7 +460,8 @@ void OnnxParserImpl::ValidateInputs(const onnx::NodeProto& node, std::vector OnnxParserImpl::ComputeOutputInfo(std::vector outNames, const IConnectableLayer* layer, - std::vector inputShapes) + std::vector inputShapes, + const onnx::TensorProto::DataType& dataType) { ARMNN_ASSERT(! outNames.empty()); bool needCompute = std::any_of(outNames.begin(), @@ -478,25 +470,45 @@ std::vector OnnxParserImpl::ComputeOutputInfo(std::vector outInfo; - //if the output info(s) are not here, we need to compute them - std::vector inferredShapes; - if(needCompute) - { - inferredShapes = layer->InferOutputShapes(inputShapes); - ARMNN_ASSERT(inferredShapes.size() == outNames.size()); - } - for (uint i = 0; i < outNames.size(); ++i) - { - if(needCompute) - { - m_TensorsInfo[outNames[i]] = OnnxTensor(); - m_TensorsInfo[outNames[i]].m_info = std::make_unique( - TensorInfo(inferredShapes[i], DataType::Float32)); - } + std::vector outInfo; + //if the output info(s) are not here, we need to compute them + std::vector inferredShapes; + DataType armnnType = DataType::Float32; + if(needCompute) { + inferredShapes = layer->InferOutputShapes(inputShapes); + ARMNN_ASSERT(inferredShapes.size() == outNames.size()); + switch (dataType) { + case onnx::TensorProto::FLOAT: { + armnnType = DataType::Float32; + break; + } + case onnx::TensorProto::INT32: + case onnx::TensorProto::INT64: { + armnnType = DataType::Signed32; + break; + } + default: { + throw ParseException( + fmt::format("'{}' is not a currently supported datatype for {}." + " Supported dataTypes are FLOAT, INT32 and INT64. {}", + onnx::TensorProto::DataType_Name(static_cast(dataType)), + layer->GetName(), + CHECK_LOCATION().AsString())); + } + } + } + for (uint i = 0; i < outNames.size(); ++i) + { + if(needCompute) + { + m_TensorsInfo[outNames[i]] = OnnxTensor(); + m_TensorsInfo[outNames[i]].m_info = std::make_unique( + TensorInfo(inferredShapes[i], armnnType)); + m_TensorsInfo[outNames[i]].m_dtype = dataType; + } outInfo.push_back(*m_TensorsInfo[outNames[i]].m_info); - } - return outInfo; + } + return outInfo; } OnnxParserImpl::OnnxParserImpl() @@ -1480,7 +1492,8 @@ void OnnxParserImpl::ParseConcat(const onnx::NodeProto& node) IConnectableLayer* layer = m_Network->AddConcatLayer(concatDescriptor, node.name().c_str()); ARMNN_ASSERT(layer != nullptr); - auto outputInfo = ComputeOutputInfo({node.output(0)}, layer, inputShapes); + auto outputInfo = ComputeOutputInfo({node.output(0)}, layer, inputShapes, + m_TensorsInfo[node.input(0)].m_dtype); layer->GetOutputSlot(0).SetTensorInfo(outputInfo[0]); @@ -1774,9 +1787,10 @@ void OnnxParserImpl::ParseGather(const onnx::NodeProto& node) IConnectableLayer* layer = m_Network->AddGatherLayer(gatherDescriptor, node.name().c_str()); ARMNN_ASSERT(layer != nullptr); - TensorShape inputShape = m_TensorsInfo[node.input(0)].m_info->GetShape(); - TensorShape indicesShape = m_TensorsInfo[node.input(1)].m_info->GetShape(); - auto outputInfo = ComputeOutputInfo({node.output(0)}, layer, { inputShape, indicesShape }); + const TensorShape& inputShape = m_TensorsInfo[node.input(0)].m_info->GetShape(); + const TensorShape& indicesShape = m_TensorsInfo[node.input(1)].m_info->GetShape(); + auto outputInfo = ComputeOutputInfo({node.output(0)}, layer, { inputShape, indicesShape }, + m_TensorsInfo[node.input(0)].m_dtype); layer->GetOutputSlot(0).SetTensorInfo(outputInfo[0]); // register the input connection slots for the layer, connections are made after all layers have been created @@ -1823,16 +1837,11 @@ void OnnxParserImpl::ParseShape(const onnx::NodeProto& node) CHECK_VALID_SIZE(static_cast(node.input_size()), 1); CHECK_VALID_SIZE(static_cast(node.output_size()), 1); - // Output must be INT64 - CHECK_VALID_DATATYPE(node.name(), node.output(0), - m_TensorsInfo[node.output(0)].m_dtype, - onnx::TensorProto::INT64); - IConnectableLayer* layer = m_Network->AddShapeLayer(node.name().c_str()); ARMNN_ASSERT(layer != nullptr); TensorShape inputShape = m_TensorsInfo[node.input(0)].m_info->GetShape(); - auto outputInfo = ComputeOutputInfo({node.output(0)}, layer, {inputShape}); + auto outputInfo = ComputeOutputInfo({node.output(0)}, layer, {inputShape}, onnx::TensorProto::INT64); layer->GetOutputSlot(0).SetTensorInfo(outputInfo[0]); // register the input connection slots for the layer, connections are made after all layers have been created @@ -1900,10 +1909,6 @@ void OnnxParserImpl::ParseUnsqueeze(const onnx::NodeProto& node) CHECK_VALID_SIZE(armnn::numeric_cast(node.input_size()), 1, 2); CHECK_VALID_SIZE(armnn::numeric_cast(node.output_size()), 1); - CHECK_VALID_DATATYPE(node.name(), node.input(0), - m_TensorsInfo[node.input(0)].m_dtype, - onnx::TensorProto::FLOAT); //input - TensorShape inputShape = m_TensorsInfo[node.input(0)].m_info->GetShape(); std::vector dims; if (node.input_size() == 1 && node.attribute_size() > 0) @@ -1931,9 +1936,12 @@ void OnnxParserImpl::ParseUnsqueeze(const onnx::NodeProto& node) std::vector targetShape; - for(uint i = 0; i < inputShape.GetNumDimensions(); i++) + if (inputShape.GetDimensionality() != Dimensionality::Scalar) { - targetShape.push_back(inputShape[i]); + for(uint i = 0; i < inputShape.GetNumDimensions(); i++) + { + targetShape.push_back(inputShape[i]); + } } for(uint i = 0; i < dims.size(); i++) @@ -1941,9 +1949,10 @@ void OnnxParserImpl::ParseUnsqueeze(const onnx::NodeProto& node) targetShape.insert(targetShape.begin() + armnn::numeric_cast(dims[i]), 1); } - auto outInfo = ComputeReshapeInfo(TensorShape(armnn::numeric_cast(targetShape.size()), - targetShape.data()), inputShape, node.output(0)); + auto outInfo = ComputeReshapeInfo(TensorShape(static_cast(targetShape.size()), targetShape.data()), + inputShape, node.output(0), m_TensorsInfo[node.input(0)].m_info->GetDataType()); m_TensorsInfo[node.output(0)].m_info = std::make_unique(outInfo); + m_TensorsInfo[node.output(0)].m_dtype = m_TensorsInfo[node.input(0)].m_dtype; CreateReshapeLayer(node.input(0), node.output(0), node.name()); } diff --git a/src/armnnOnnxParser/OnnxParser.hpp b/src/armnnOnnxParser/OnnxParser.hpp index 6a0fad0ec2..d388f501d4 100644 --- a/src/armnnOnnxParser/OnnxParser.hpp +++ b/src/armnnOnnxParser/OnnxParser.hpp @@ -74,9 +74,11 @@ private: void SetupInfo(const google::protobuf::RepeatedPtrField* list); - std::vector ComputeOutputInfo(std::vector outNames, - const armnn::IConnectableLayer* layer, - std::vector inputShapes); + std::vector ComputeOutputInfo( + std::vector outNames, + const armnn::IConnectableLayer* layer, + std::vector inputShapes, + const onnx::TensorProto::DataType& type = onnx::TensorProto::FLOAT); void DetectFullyConnected(); diff --git a/src/armnnOnnxParser/test/Gather.cpp b/src/armnnOnnxParser/test/Gather.cpp index 1d214419c4..8fd9021ebc 100644 --- a/src/armnnOnnxParser/test/Gather.cpp +++ b/src/armnnOnnxParser/test/Gather.cpp @@ -85,6 +85,14 @@ struct GatherMainFixture : public armnnUtils::ParserPrototxtFixture({{"input", { 1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f, 8.0f }}}, + {{"output", { 1.0f }}}); +} + TEST_CASE_FIXTURE(Gather1dFixture, "Gather1dTest") { - RunTest<1, float>({{"input", {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f, 8.0f }}}, - {{"output", {1.0f, 3.0f, 2.0f, 6.0f}}}); + RunTest<1, float>({{"input", { 1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f, 8.0f }}}, + {{"output", { 1.0f, 3.0f, 2.0f, 6.0f }}}); } TEST_CASE_FIXTURE(Gather2dFixture, "Gather2dTest") { - RunTest<2, float>({{"input", {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f, 8.0f, 9.0f, 10.0f }}}, - {{"output", {3.0f, 4.0f, 7.0f, 8.0f, 9.0f, 10.0f }}}); + RunTest<2, float>({{"input", { 1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f, 8.0f, 9.0f, 10.0f }}}, + {{"output", { 3.0f, 4.0f, 7.0f, 8.0f, 9.0f, 10.0f }}}); } TEST_CASE_FIXTURE(Gather3dMultiIndicesFixture, "Gather3dMultiIndicesTest") diff --git a/src/armnnOnnxParser/test/Shape.cpp b/src/armnnOnnxParser/test/Shape.cpp index b033b2d8bf..e01c4b8b55 100644 --- a/src/armnnOnnxParser/test/Shape.cpp +++ b/src/armnnOnnxParser/test/Shape.cpp @@ -5,9 +5,11 @@ #include "armnnOnnxParser/IOnnxParser.hpp" #include "ParserPrototxtFixture.hpp" +#include "OnnxParserTestUtils.hpp" TEST_SUITE("OnnxParser_Shape") { + struct ShapeMainFixture : public armnnUtils::ParserPrototxtFixture { ShapeMainFixture(const std::string& inputType, @@ -31,7 +33,7 @@ struct ShapeMainFixture : public armnnUtils::ParserPrototxtFixture& shape) - { - std::string shapeStr; - for (int i : shape) - { - shapeStr = fmt::format("{} dim {{ dim_value: {} }}", shapeStr, i); - } - return shapeStr; - } }; -struct ShapeValidFloatFixture : ShapeMainFixture +struct ShapeFloatFixture : ShapeMainFixture { - ShapeValidFloatFixture() : ShapeMainFixture("1", "7", "4", { 1, 3, 1, 5 }) { + ShapeFloatFixture() : ShapeMainFixture("1", "7", "4", { 1, 3, 1, 5 }) + { Setup(); } }; -struct ShapeValidIntFixture : ShapeMainFixture +struct ShapeIntFixture : ShapeMainFixture { - ShapeValidIntFixture() : ShapeMainFixture("7", "7", "4", { 1, 3, 1, 5 }) { + ShapeIntFixture() : ShapeMainFixture("7", "7", "4", { 1, 3, 1, 5 }) + { Setup(); } }; struct Shape3DFixture : ShapeMainFixture { - Shape3DFixture() : ShapeMainFixture("1", "7", "3", { 3, 2, 3 }) { + Shape3DFixture() : ShapeMainFixture("1", "7", "3", { 3, 2, 3 }) + { Setup(); } }; struct Shape2DFixture : ShapeMainFixture { - Shape2DFixture() : ShapeMainFixture("1", "7", "2", { 2, 3 }) { + Shape2DFixture() : ShapeMainFixture("1", "7", "2", { 2, 3 }) + { Setup(); } }; struct Shape1DFixture : ShapeMainFixture { - Shape1DFixture() : ShapeMainFixture("1", "7", "1", { 5 }) { + Shape1DFixture() : ShapeMainFixture("1", "7", "1", { 5 }) + { Setup(); } }; -struct ShapeInvalidFixture : ShapeMainFixture -{ - ShapeInvalidFixture() : ShapeMainFixture("1", "1", "4", { 1, 3, 1, 5 }) {} -}; - -TEST_CASE_FIXTURE(ShapeValidFloatFixture, "FloatValidShapeTest") +TEST_CASE_FIXTURE(ShapeFloatFixture, "FloatValidShapeTest") { - RunTest<2, int>({{"Input", { 0.0f, 1.0f, 2.0f, 3.0f, 4.0f, + RunTest<1, int>({{"Input", { 0.0f, 1.0f, 2.0f, 3.0f, 4.0f, 4.0f, 3.0f, 2.0f, 1.0f, 0.0f, 0.0f, 1.0f, 2.0f, 3.0f, 4.0f }}}, {{"Output", { 1, 3, 1, 5 }}}); } -TEST_CASE_FIXTURE(ShapeValidIntFixture, "IntValidShapeTest") +TEST_CASE_FIXTURE(ShapeIntFixture, "IntValidShapeTest") { - RunTest<2, int>({{"Input", { 0, 1, 2, 3, 4, + RunTest<1, int>({{"Input", { 0, 1, 2, 3, 4, 4, 3, 2, 1, 0, 0, 1, 2, 3, 4 }}}, {{"Output", { 1, 3, 1, 5 }}}); } TEST_CASE_FIXTURE(Shape3DFixture, "Shape3DTest") { - RunTest<2, int>({{"Input", { 0.0f, 1.0f, 2.0f, 3.0f, 4.0f, 5.0f, + RunTest<1, int>({{"Input", { 0.0f, 1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 5.0f, 4.0f, 3.0f, 2.0f, 1.0f, 0.0f, 0.0f, 1.0f, 2.0f, 3.0f, 4.0f, 5.0f }}}, {{"Output", { 3, 2, 3 }}}); } TEST_CASE_FIXTURE(Shape2DFixture, "Shape2DTest") { - RunTest<2, int>({{"Input", { 0.0f, 1.0f, 2.0f, 3.0f, 4.0f, 5.0f }}}, {{"Output", { 2, 3 }}}); + RunTest<1, int>({{"Input", { 0.0f, 1.0f, 2.0f, 3.0f, 4.0f, 5.0f }}}, {{"Output", { 2, 3 }}}); } TEST_CASE_FIXTURE(Shape1DFixture, "Shape1DTest") { - RunTest<2, int>({{"Input", { 1.0f, 2.0f, 3.0f, 4.0f, 5.0f }}}, {{"Output", { 5 }}}); -} - -TEST_CASE_FIXTURE(ShapeInvalidFixture, "IncorrectOutputDataShapeTest") -{ - CHECK_THROWS_AS(Setup(), armnn::ParseException); + RunTest<1, int>({{"Input", { 1.0f, 2.0f, 3.0f, 4.0f, 5.0f }}}, {{"Output", { 5 }}}); } } diff --git a/src/armnnOnnxParser/test/Unsqueeze.cpp b/src/armnnOnnxParser/test/Unsqueeze.cpp index 95a191e46b..7ba87bc680 100644 --- a/src/armnnOnnxParser/test/Unsqueeze.cpp +++ b/src/armnnOnnxParser/test/Unsqueeze.cpp @@ -77,6 +77,14 @@ struct UnsqueezeUnsortedAxesFixture : UnsqueezeFixture } }; +struct UnsqueezeScalarFixture : UnsqueezeFixture +{ + UnsqueezeScalarFixture() : UnsqueezeFixture({ 0 }, { }, { 1 }) + { + Setup(); + } +}; + TEST_CASE_FIXTURE(UnsqueezeSingleAxesFixture, "UnsqueezeSingleAxesTest") { RunTest<3, float>({{"Input", { 1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f }}}, @@ -107,6 +115,12 @@ TEST_CASE_FIXTURE(UnsqueezeUnsortedAxesFixture, "UnsqueezeUnsortedAxesTest") 6.0f, 7.0f, 8.0f, 9.0f, 10.0f }}}); } +TEST_CASE_FIXTURE(UnsqueezeScalarFixture, "UnsqueezeScalarTest") +{ + RunTest<1, float>({{"Input", { 1.0f }}}, + {{"Output", { 1.0f }}}); +} + struct UnsqueezeInputAxesFixture : public armnnUtils::ParserPrototxtFixture { UnsqueezeInputAxesFixture() -- cgit v1.2.1