From f10b15a8946f39bdf3f60cebc59d2963069eedca Mon Sep 17 00:00:00 2001 From: Narumol Prangnawarat Date: Fri, 17 Sep 2021 21:08:57 +0100 Subject: IVGCVSW-6382 Add Gather operator support to ONNX parser * Add ParseGather to support Gather operator on ONNX * Add Support of int64 converted to int32 for constant * Add OnnxParserTestUtils * Refactor ValidateTensorShapesFromInputs of GatherLayer * Unit tests Signed-off-by: Narumol Prangnawarat Change-Id: Ie9dff640240e14a062fef38f7faf0ccc212de5f7 --- CMakeLists.txt | 2 + docs/01_01_parsers.dox | 6 + src/armnn/layers/GatherLayer.cpp | 38 ++- src/armnn/layers/GatherLayer.hpp | 5 + src/armnnOnnxParser/OnnxParser.cpp | 115 ++++++++- src/armnnOnnxParser/OnnxParser.hpp | 6 + src/armnnOnnxParser/test/Gather.cpp | 315 +++++++++++++++++++++++ src/armnnOnnxParser/test/OnnxParserTestUtils.hpp | 21 ++ 8 files changed, 488 insertions(+), 20 deletions(-) create mode 100644 src/armnnOnnxParser/test/Gather.cpp create mode 100644 src/armnnOnnxParser/test/OnnxParserTestUtils.hpp diff --git a/CMakeLists.txt b/CMakeLists.txt index 421afb6d18..69a68274be 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -758,7 +758,9 @@ if(BUILD_UNIT_TESTS) src/armnnOnnxParser/test/DepthConv.cpp src/armnnOnnxParser/test/Flatten.cpp src/armnnOnnxParser/test/FullyConnected.cpp + src/armnnOnnxParser/test/Gather.cpp src/armnnOnnxParser/test/GetInputsOutputs.cpp + src/armnnOnnxParser/test/OnnxParserTestUtils.hpp src/armnnOnnxParser/test/Pooling.cpp src/armnnOnnxParser/test/ProtoxtFixture.cpp src/armnnOnnxParser/test/Relu.cpp diff --git a/docs/01_01_parsers.dox b/docs/01_01_parsers.dox index 97497fe016..31b7687a7d 100644 --- a/docs/01_01_parsers.dox +++ b/docs/01_01_parsers.dox @@ -49,6 +49,9 @@ The Arm NN SDK ONNX parser currently only supports fp32 operators. - Flatten - See the ONNX [Flatten documentation](https://github.com/onnx/onnx/blob/master/docs/Operators.md#Flatten) for more information. +- Gather + - See the ONNX [Gather documentation](https://github.com/onnx/onnx/blob/master/docs/Operators.md#Gather) for more information. + - GlobalAveragePool - See the ONNX [GlobalAveragePool documentation](https://github.com/onnx/onnx/blob/master/docs/Operators.md#GlobalAveragePool) for more information. @@ -64,6 +67,9 @@ The Arm NN SDK ONNX parser currently only supports fp32 operators. - Reshape - See the ONNX [Reshape documentation](https://github.com/onnx/onnx/blob/master/docs/Operators.md#Reshape) for more information. +- Shape + - See the ONNX [Shape documentation](https://github.com/onnx/onnx/blob/master/docs/Operators.md#Shape) for more information. + - Sigmoid - See the ONNX [Sigmoid documentation](https://github.com/onnx/onnx/blob/master/docs/Operators.md#Sigmoid) for more information. diff --git a/src/armnn/layers/GatherLayer.cpp b/src/armnn/layers/GatherLayer.cpp index 9a4f9bf8f0..cdbdaabcdc 100644 --- a/src/armnn/layers/GatherLayer.cpp +++ b/src/armnn/layers/GatherLayer.cpp @@ -31,16 +31,11 @@ GatherLayer* GatherLayer::Clone(Graph& graph) const return CloneBase(graph, m_Param, GetName()); } -void GatherLayer::ValidateTensorShapesFromInputs() +std::vector GatherLayer::InferOutputShapes(const std::vector& inputShapes) const { - VerifyLayerConnections(2, CHECK_LOCATION()); - - const TensorShape& outputShape = GetOutputSlot(0).GetTensorInfo().GetShape(); - - VerifyShapeInferenceType(outputShape, m_ShapeInferenceMethod); - - const TensorInfo& params = GetInputSlot(0).GetConnection()->GetTensorInfo(); - const TensorInfo& indices = GetInputSlot(1).GetConnection()->GetTensorInfo(); + ARMNN_ASSERT(inputShapes.size() == 2); + const TensorShape& params = inputShapes[0]; + const TensorShape& indices = inputShapes[1]; const unsigned int paramsDim = params.GetNumDimensions(); const unsigned int indicesDim = indices.GetNumDimensions(); @@ -57,20 +52,35 @@ void GatherLayer::ValidateTensorShapesFromInputs() for (unsigned int i = 0; i < axis; ++i) { - dimSizes.push_back(params.GetShape()[i]); + dimSizes.push_back(params[i]); } for (unsigned int i = axis; i < indicesDim + axis; ++i) { - dimSizes.push_back(indices.GetShape()[i - axis]); + dimSizes.push_back(indices[i - axis]); } for (unsigned int i = 1 + axis; i < paramsDim; ++i) { - dimSizes.push_back(params.GetShape()[i]); + dimSizes.push_back(params[i]); } - const TensorShape& inferredShape = TensorShape(outputDim, dimSizes.data()); + return std::vector({ TensorShape({outputDim, dimSizes.data()})}); +} + +void GatherLayer::ValidateTensorShapesFromInputs() +{ + VerifyLayerConnections(2, CHECK_LOCATION()); + + const TensorShape& outputShape = GetOutputSlot(0).GetTensorInfo().GetShape(); + + VerifyShapeInferenceType(outputShape, m_ShapeInferenceMethod); + + std::vector inferredShapes = InferOutputShapes( + {GetInputSlot(0).GetConnection()->GetTensorInfo().GetShape(), + GetInputSlot(1).GetConnection()->GetTensorInfo().GetShape()}); + ARMNN_ASSERT(inferredShapes.size() == 1); + ARMNN_ASSERT(inferredShapes[0].GetDimensionality() == Dimensionality::Specified); - ValidateAndCopyShape(outputShape, inferredShape, m_ShapeInferenceMethod, "GatherLayer"); + ValidateAndCopyShape(outputShape, inferredShapes[0], m_ShapeInferenceMethod, "GatherLayer"); } void GatherLayer::Accept(ILayerVisitor& visitor) const diff --git a/src/armnn/layers/GatherLayer.hpp b/src/armnn/layers/GatherLayer.hpp index 010af37b49..3bc8c69bc4 100644 --- a/src/armnn/layers/GatherLayer.hpp +++ b/src/armnn/layers/GatherLayer.hpp @@ -24,6 +24,11 @@ public: /// @param [in] graph The graph into which this layer is being cloned. GatherLayer* Clone(Graph& graph) const override; + /// Infers the output shapes from given input shapes and layer properties. + /// @param [in] inputShapes The input shapes layer has. + /// @return A vector to the inferred output shape. + std::vector InferOutputShapes(const std::vector& inputShapes) const override; + /// Check if the input tensor shape(s). /// will lead to a valid configuration of @ref GatherLayer. /// @param [in] shapeInferenceMethod Indicates if output shape shall be overwritten or just validate. diff --git a/src/armnnOnnxParser/OnnxParser.cpp b/src/armnnOnnxParser/OnnxParser.cpp index 889c35f391..e70eb64047 100644 --- a/src/armnnOnnxParser/OnnxParser.cpp +++ b/src/armnnOnnxParser/OnnxParser.cpp @@ -427,7 +427,8 @@ const std::map OnnxParser { "Conv", &OnnxParserImpl::ParseConv }, { "Add", &OnnxParserImpl::ParseAdd }, { "Flatten", &OnnxParserImpl::ParseFlatten }, - { "Shape", &OnnxParserImpl::ParseShape } + { "Shape", &OnnxParserImpl::ParseShape }, + { "Gather", &OnnxParserImpl::ParseGather }, }; template @@ -533,6 +534,10 @@ OnnxParserImpl::CreateConstTensor(const std::string name, TensorInfo tensorInfo = *m_TensorsInfo[name].m_info; onnx::TensorProto onnxTensor = *m_TensorsInfo[name].m_tensor; + //ONNX can have Float16 and double constant nodes but ArmNN only supports float32 + CHECK_VALID_DATATYPE(name, onnxTensor.name(), + static_cast(onnxTensor.data_type()), onnx::TensorProto::FLOAT); + // Makes sure IsConstant flag is set. tensorInfo.SetConstant(); @@ -568,6 +573,65 @@ OnnxParserImpl::CreateConstTensor(const std::string name, } } +std::pair> +OnnxParserImpl::CreateInt64ConstTensor(const std::string name, + armnn::Optional permutationVector) +{ + TensorInfo tensorInfo = *m_TensorsInfo[name].m_info; + onnx::TensorProto onnxTensor = *m_TensorsInfo[name].m_tensor; + + CHECK_VALID_DATATYPE(name, onnxTensor.name(), + static_cast(onnxTensor.data_type()), onnx::TensorProto::INT64); + + // Makes sure IsConstant flag is set. + tensorInfo.SetConstant(); + uint numElements = tensorInfo.GetNumElements(); + + // Const tensors requires at least a list of values + if (numElements == 0) + { + throw ParseException(fmt::format("No tensor data found for Const tensor '{}' {}", + name, + CHECK_LOCATION().AsString())); + } + + // Copy the value list entries into the destination + if (!onnxTensor.has_raw_data()) + { + auto srcData = onnxTensor.int64_data().data(); + if(numElements != static_cast(onnxTensor.int64_data_size())) + { + throw ParseException( + fmt::format("The number of data provided ({}) does not match the tensor '{}' number of " + "elements ({}) {}", + onnxTensor.int64_data_size(), + name, + tensorInfo.GetNumElements(), + CHECK_LOCATION().AsString())); + } + + std::vector int32Data; + for(uint i = 0; i < numElements; i++) + { + int32_t int32Value = CHECKED_INT32(srcData[i]); + int32Data.push_back(int32Value); + } + + return CreateConstTensorImpl(int32Data.data(), tensorInfo, permutationVector); + } + else + { + auto srcData = reinterpret_cast(onnxTensor.raw_data().c_str()); + std::vector int32Data; + for(uint i = 0; i < numElements; i++) + { + int32_t int32Value = CHECKED_INT32(srcData[i]); + int32Data.push_back(int32Value); + } + return CreateConstTensorImpl(int32Data.data(), tensorInfo, permutationVector); + } +} + ModelPtr OnnxParserImpl::LoadModelFromTextFile(const char* graphFile) { FILE* fd = fopen(graphFile, "r"); @@ -1152,7 +1216,14 @@ std::pair OnnxParserImpl::AddPrepareBroadcast(const st void OnnxParserImpl::CreateConstantLayer(const std::string& tensorName, const std::string& layerName) { auto armnnTensor = CreateConstTensor(tensorName); + IConnectableLayer* layer = m_Network->AddConstantLayer(armnnTensor.first, layerName.c_str()); + layer->GetOutputSlot(0).SetTensorInfo(armnnTensor.first.GetInfo()); + RegisterOutputSlots(layer, {tensorName}); +} +void OnnxParserImpl::CreateInt64ConstantLayer(const std::string& tensorName, const std::string& layerName) +{ + auto armnnTensor = CreateInt64ConstTensor(tensorName); IConnectableLayer* layer = m_Network->AddConstantLayer(armnnTensor.first, layerName.c_str()); layer->GetOutputSlot(0).SetTensorInfo(armnnTensor.first.GetInfo()); RegisterOutputSlots(layer, {tensorName}); @@ -1370,16 +1441,25 @@ void OnnxParserImpl::ParseConstant(const onnx::NodeProto& node) } const onnx::TensorProto& onnxTensor = node.attribute(0).t(); - //ONNX can have Float16 and double constant nodes but ArmNN only supports float32 - CHECK_VALID_DATATYPE(node.name(), onnxTensor.name(), - static_cast(onnxTensor.data_type()), onnx::TensorProto::FLOAT); - //Register this as a m_ConstParam so we know we can use it as a constant param in future layers. m_TensorsInfo[node.output(0)].m_tensor = std::make_unique(onnxTensor); m_TensorsInfo[node.output(0)].m_info = std::make_unique(ToTensorInfo(onnxTensor)); m_TensorsInfo[node.output(0)].m_dtype = static_cast(onnxTensor.data_type()); - CreateConstantLayer(node.output(0), node.name()); + if (m_TensorsInfo[node.output(0)].m_dtype == onnx::TensorProto_DataType_FLOAT) + { + CreateConstantLayer(node.output(0), node.name()); + } + else if (m_TensorsInfo[node.output(0)].m_dtype == onnx::TensorProto_DataType_INT64) + { + CreateInt64ConstantLayer(node.output(0), node.name()); + } + else + { + throw ParseException(fmt::format("Data type not support for Constant node '{}' {}", + node.name(), + CHECK_LOCATION().AsString())); + } } void OnnxParserImpl::ParseConv(const onnx::NodeProto& node) @@ -1622,6 +1702,29 @@ void OnnxParserImpl::ParseFlatten(const onnx::NodeProto& node) CreateReshapeLayer(node.input(0), node.output(0), node.name()); } +void OnnxParserImpl::ParseGather(const onnx::NodeProto& node) +{ + CHECK_VALID_SIZE(static_cast(node.input_size()), 2); + CHECK_VALID_SIZE(static_cast(node.output_size()), 1); + + armnn::GatherDescriptor gatherDescriptor; + gatherDescriptor.m_Axis = static_cast(ReadOptionalNodeInt64Attribute(node, "axis", 0)); + + IConnectableLayer* layer = m_Network->AddGatherLayer(gatherDescriptor, node.name().c_str()); + ARMNN_ASSERT(layer != nullptr); + + TensorShape inputShape = m_TensorsInfo[node.input(0)].m_info->GetShape(); + TensorShape indicesShape = m_TensorsInfo[node.input(1)].m_info->GetShape(); + auto outputInfo = ComputeOutputInfo({node.output(0)}, layer, { inputShape, indicesShape }); + layer->GetOutputSlot(0).SetTensorInfo(outputInfo[0]); + + // register the input connection slots for the layer, connections are made after all layers have been created + RegisterInputSlots(layer, { node.input(0), node.input(1) }); + + // register the output connection slots for the layer, connections are made after all layers have been created + RegisterOutputSlots(layer, { node.output(0) }); +} + void OnnxParserImpl::ParseGlobalAveragePool(const onnx::NodeProto& node) { Pooling2dDescriptor desc = Pooling2dDescriptor(); diff --git a/src/armnnOnnxParser/OnnxParser.hpp b/src/armnnOnnxParser/OnnxParser.hpp index 101e99ff8d..b71b8dca49 100644 --- a/src/armnnOnnxParser/OnnxParser.hpp +++ b/src/armnnOnnxParser/OnnxParser.hpp @@ -98,6 +98,7 @@ private: void AddPoolingLayer(const onnx::NodeProto& nodeProto, armnn::Pooling2dDescriptor& desc); void CreateConstantLayer(const std::string& tensorName, const std::string& layerName); + void CreateInt64ConstantLayer(const std::string& tensorName, const std::string& layerName); void CreateReshapeLayer(const std::string& inputName, const std::string& outputName, const std::string& layerName); @@ -115,6 +116,7 @@ private: void ParseConstant(const onnx::NodeProto& nodeProto); void ParseConv(const onnx::NodeProto& nodeProto); void ParseFlatten(const onnx::NodeProto& node); + void ParseGather(const onnx::NodeProto& node); void ParseGlobalAveragePool(const onnx::NodeProto& node); void ParseMaxPool(const onnx::NodeProto& nodeProto); void ParseShape(const onnx::NodeProto& node); @@ -133,6 +135,10 @@ private: CreateConstTensor(const std::string name, armnn::Optional permutationVector = armnn::EmptyOptional()); + std::pair> + CreateInt64ConstTensor(const std::string name, + armnn::Optional permutationVector = armnn::EmptyOptional()); + template void ValidateInputs(const onnx::NodeProto& node, TypeList validInputs, diff --git a/src/armnnOnnxParser/test/Gather.cpp b/src/armnnOnnxParser/test/Gather.cpp new file mode 100644 index 0000000000..1d214419c4 --- /dev/null +++ b/src/armnnOnnxParser/test/Gather.cpp @@ -0,0 +1,315 @@ +// +// Copyright © 2021 Arm Ltd and Contributors. All rights reserved. +// SPDX-License-Identifier: MIT +// + +#include "armnnOnnxParser/IOnnxParser.hpp" +#include "ParserPrototxtFixture.hpp" +#include "OnnxParserTestUtils.hpp" + +TEST_SUITE("OnnxParser_Gather") +{ + +struct GatherMainFixture : public armnnUtils::ParserPrototxtFixture +{ + GatherMainFixture(const std::vector& indicesShape, + const std::vector& indices, + const std::vector& inputShape, + const std::vector& outputShape) + { + m_Prototext = R"( + ir_version: 8 + producer_name: "onnx-example" + graph { + node { + output: "indices" + op_type: "Constant" + attribute { + name: "value" + t { + data_type: 7 + )" + ConstructIndicesString(indicesShape, indices) + R"( + name: "value" + } + type: TENSOR + } + } + node { + input: "input" + input: "indices" + output: "output" + op_type: "Gather" + attribute { + name: "axis" + i: 0 + type: INT + } + } + name: "gather-model" + input { + name: "input" + type { + tensor_type { + elem_type: 1 + shape { + )" + armnnUtils::ConstructTensorShapeString(inputShape) + R"( + } + } + } + } + output { + name: "output" + type { + tensor_type { + elem_type: 1 + shape { + )" + armnnUtils::ConstructTensorShapeString(outputShape) + R"( + } + } + } + } + })"; + } + std::string ConstructIndicesString(const std::vector& indicesShape, const std::vector& indices) + { + std::string shapeStr; + for (int i : indicesShape) + { + shapeStr = fmt::format(" {} dims: {}", shapeStr, i); + } + for (int i : indices) + { + shapeStr = fmt::format(" {} int64_data: {}", shapeStr, i); + } + return shapeStr; + } +}; + +struct Gather1dFixture : GatherMainFixture +{ + Gather1dFixture() : GatherMainFixture({ 4 }, { 0, 2, 1, 5 }, { 8 }, { 4 }) + { + Setup(); + } +}; + +struct Gather2dFixture : GatherMainFixture +{ + Gather2dFixture() : GatherMainFixture({ 3 }, { 1, 3, 4 }, { 5, 2 }, { 3, 2 }) + { + Setup(); + } +}; + +struct Gather3dMultiIndicesFixture : GatherMainFixture +{ + Gather3dMultiIndicesFixture() : GatherMainFixture({ 2, 3 }, { 1, 2, 1, 2, 1, 0 }, { 3, 2, 3 }, { 2, 3, 2, 3 }) + { + Setup(); + } +}; + +struct Gather4dFixture : GatherMainFixture +{ + Gather4dFixture() : GatherMainFixture({ 3 }, { 0, 1, 3 }, { 5, 4, 3, 2 }, { 3, 4, 3, 2 }) + { + Setup(); + } +}; + +TEST_CASE_FIXTURE(Gather1dFixture, "Gather1dTest") +{ + RunTest<1, float>({{"input", {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f, 8.0f }}}, + {{"output", {1.0f, 3.0f, 2.0f, 6.0f}}}); +} + +TEST_CASE_FIXTURE(Gather2dFixture, "Gather2dTest") +{ + RunTest<2, float>({{"input", {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f, 8.0f, 9.0f, 10.0f }}}, + {{"output", {3.0f, 4.0f, 7.0f, 8.0f, 9.0f, 10.0f }}}); +} + +TEST_CASE_FIXTURE(Gather3dMultiIndicesFixture, "Gather3dMultiIndicesTest") +{ + RunTest<3, float>({{"input", { 1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, + 7.0f, 8.0f, 9.0f, 10.0f, 11.0f, 12.0f, + 13.0f, 14.0f, 15.0f, 16.0f, 17.0f, 18.0f }}}, + {{"output", { 7.0f, 8.0f, 9.0f, + 10.0f, 11.0f, 12.0f, + 13.0f, 14.0f, 15.0f, + 16.0f, 17.0f, 18.0f, + 7.0f, 8.0f, 9.0f, + 10.0f, 11.0f, 12.0f, + 13.0f, 14.0f, 15.0f, + 16.0f, 17.0f, 18.0f, + 7.0f, 8.0f, 9.0f, + 10.0f, 11.0f, 12.0f, + 1.0f, 2.0f, 3.0f, + 4.0f, 5.0f, 6.0f }}}); +} + +TEST_CASE_FIXTURE(Gather4dFixture, "Gather4dTest") +{ + RunTest<4, float>({{"input", { 1.0f, 2.0f, 3.0f, 4.0f, 5.0f, + 6.0f, 7.0f, 8.0f, 9.0f, 10.0f, + 11.0f, 12.0f, 13.0f, 14.0f, 15.0f, + 16.0f, 17.0f, 18.0f, 19.0f, 20.0f, + 21.0f, 22.0f, 23.0f, 24.0f, 25.0f, + 26.0f, 27.0f, 28.0f, 29.0f, 30.0f, + 31.0f, 32.0f, 33.0f, 34.0f, 35.0f, + 36.0f, 37.0f, 38.0f, 39.0f, 40.0f, + 41.0f, 42.0f, 43.0f, 44.0f, 45.0f, + 46.0f, 47.0f, 48.0f, 49.0f, 50.0f, + 51.0f, 52.0f, 53.0f, 54.0f, 55.0f, + 56.0f, 57.0f, 58.0f, 59.0f, 60.0f, + 61.0f, 62.0f, 63.0f, 64.0f, 65.0f, + 66.0f, 67.0f, 68.0f, 69.0f, 70.0f, + 71.0f, 72.0f, 73.0f, 74.0f, 75.0f, + 76.0f, 77.0f, 78.0f, 79.0f, 80.0f, + 81.0f, 82.0f, 83.0f, 84.0f, 85.0f, + 86.0f, 87.0f, 88.0f, 89.0f, 90.0f, + 91.0f, 92.0f, 93.0f, 94.0f, 95.0f, + 96.0f, 97.0f, 98.0f, 99.0f, 100.0f, + 101.0f, 102.0f, 103.0f, 104.0f, 105.0f, + 106.0f, 107.0f, 108.0f, 109.0f, 110.0f, + 111.0f, 112.0f, 113.0f, 114.0f, 115.0f, + 116.0f, 117.0f, 118.0f, 119.0f, 120.0f }}}, + {{"output", { 1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, + 7.0f, 8.0f, 9.0f, 10.0f, 11.0f, 12.0f, + 13.0f, 14.0f, 15.0f, 16.0f, 17.0f, 18.0f, + 19.0f, 20.0f, 21.0f, 22.0f, 23.0f, 24.0f, + 25.0f, 26.0f, 27.0f, 28.0f, 29.0f, 30.0f, + 31.0f, 32.0f, 33.0f, 34.0f, 35.0f, 36.0f, + 37.0f, 38.0f, 39.0f, 40.0f, 41.0f, 42.0f, + 43.0f, 44.0f, 45.0f, 46.0f, 47.0f, 48.0f, + 73.0f, 74.0f, 75.0f, 76.0f, 77.0f, 78.0f, + 79.0f, 80.0f, 81.0f, 82.0f, 83.0f, 84.0f, + 85.0f, 86.0f, 87.0f, 88.0f, 89.0f, 90.0f, + 91.0f, 92.0f, 93.0f, 94.0f, 95.0f, 96.0f }}}); +} + +struct GatherRawDataFixture : public armnnUtils::ParserPrototxtFixture +{ + GatherRawDataFixture() + { + m_Prototext = R"( + ir_version: 8 + producer_name: "onnx-example" + graph { + node { + output: "indices" + op_type: "Constant" + attribute { + name: "value" + t { + dims: 3 + data_type: 7 + raw_data: + "\000\000\000\000\000\000\000\000\001\000\000\000\000\000\000\000\003\000\000\000\000\000\000\000" + name: "value" + } + type: TENSOR + } + } + node { + input: "input" + input: "indices" + output: "output" + op_type: "Gather" + attribute { + name: "axis" + i: 0 + type: INT + } + } + name: "gather-model" + input { + name: "input" + type { + tensor_type { + elem_type: 1 + shape { + dim { + dim_value: 5 + } + dim { + dim_value: 4 + } + dim { + dim_value: 3 + } + dim { + dim_value: 2 + } + } + } + } + } + output { + name: "output" + type { + tensor_type { + elem_type: 1 + shape { + dim { + dim_value: 3 + } + dim { + dim_value: 4 + } + dim { + dim_value: 3 + } + dim { + dim_value: 2 + } + } + } + } + } + })"; + Setup(); + } +}; + +TEST_CASE_FIXTURE(GatherRawDataFixture, "GatherRawDataTest") +{ + RunTest<4, float>({{"input", { 1.0f, 2.0f, 3.0f, 4.0f, 5.0f, + 6.0f, 7.0f, 8.0f, 9.0f, 10.0f, + 11.0f, 12.0f, 13.0f, 14.0f, 15.0f, + 16.0f, 17.0f, 18.0f, 19.0f, 20.0f, + 21.0f, 22.0f, 23.0f, 24.0f, 25.0f, + 26.0f, 27.0f, 28.0f, 29.0f, 30.0f, + 31.0f, 32.0f, 33.0f, 34.0f, 35.0f, + 36.0f, 37.0f, 38.0f, 39.0f, 40.0f, + 41.0f, 42.0f, 43.0f, 44.0f, 45.0f, + 46.0f, 47.0f, 48.0f, 49.0f, 50.0f, + 51.0f, 52.0f, 53.0f, 54.0f, 55.0f, + 56.0f, 57.0f, 58.0f, 59.0f, 60.0f, + 61.0f, 62.0f, 63.0f, 64.0f, 65.0f, + 66.0f, 67.0f, 68.0f, 69.0f, 70.0f, + 71.0f, 72.0f, 73.0f, 74.0f, 75.0f, + 76.0f, 77.0f, 78.0f, 79.0f, 80.0f, + 81.0f, 82.0f, 83.0f, 84.0f, 85.0f, + 86.0f, 87.0f, 88.0f, 89.0f, 90.0f, + 91.0f, 92.0f, 93.0f, 94.0f, 95.0f, + 96.0f, 97.0f, 98.0f, 99.0f, 100.0f, + 101.0f, 102.0f, 103.0f, 104.0f, 105.0f, + 106.0f, 107.0f, 108.0f, 109.0f, 110.0f, + 111.0f, 112.0f, 113.0f, 114.0f, 115.0f, + 116.0f, 117.0f, 118.0f, 119.0f, 120.0f }}}, + {{"output", { 1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, + 7.0f, 8.0f, 9.0f, 10.0f, 11.0f, 12.0f, + 13.0f, 14.0f, 15.0f, 16.0f, 17.0f, 18.0f, + 19.0f, 20.0f, 21.0f, 22.0f, 23.0f, 24.0f, + 25.0f, 26.0f, 27.0f, 28.0f, 29.0f, 30.0f, + 31.0f, 32.0f, 33.0f, 34.0f, 35.0f, 36.0f, + 37.0f, 38.0f, 39.0f, 40.0f, 41.0f, 42.0f, + 43.0f, 44.0f, 45.0f, 46.0f, 47.0f, 48.0f, + 73.0f, 74.0f, 75.0f, 76.0f, 77.0f, 78.0f, + 79.0f, 80.0f, 81.0f, 82.0f, 83.0f, 84.0f, + 85.0f, 86.0f, 87.0f, 88.0f, 89.0f, 90.0f, + 91.0f, 92.0f, 93.0f, 94.0f, 95.0f, 96.0f }}}); +} + +} diff --git a/src/armnnOnnxParser/test/OnnxParserTestUtils.hpp b/src/armnnOnnxParser/test/OnnxParserTestUtils.hpp new file mode 100644 index 0000000000..4ed6543d28 --- /dev/null +++ b/src/armnnOnnxParser/test/OnnxParserTestUtils.hpp @@ -0,0 +1,21 @@ +// +// Copyright © 2021 Arm Ltd and Contributors. All rights reserved. +// SPDX-License-Identifier: MIT +// + +#pragma once + +namespace armnnUtils +{ + +std::string ConstructTensorShapeString(const std::vector& shape) +{ + std::string shapeStr; + for (int i : shape) + { + shapeStr = fmt::format("{} dim {{ dim_value: {} }}", shapeStr, i); + } + return shapeStr; +} + +} // namespace armnnUtils \ No newline at end of file -- cgit v1.2.1