From 1112b016e7ffad979b7bd0c8d54c9c679d4043e2 Mon Sep 17 00:00:00 2001 From: Narumol Prangnawarat Date: Thu, 30 Sep 2021 12:10:50 +0100 Subject: IVGCVSW-6449 Add GEMM operator support to ONNX parser Signed-off-by: Narumol Prangnawarat Change-Id: I3c6979c72d44a15fb2dc3afc22ac30d1428684b0 --- CMakeLists.txt | 1 + docs/01_01_parsers.dox | 2 + src/armnnOnnxParser/OnnxParser.cpp | 188 ++++++++++++- src/armnnOnnxParser/OnnxParser.hpp | 4 + src/armnnOnnxParser/test/Gemm.cpp | 556 +++++++++++++++++++++++++++++++++++++ 5 files changed, 750 insertions(+), 1 deletion(-) create mode 100644 src/armnnOnnxParser/test/Gemm.cpp diff --git a/CMakeLists.txt b/CMakeLists.txt index b80dcadf52..8fd71239eb 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -766,6 +766,7 @@ if(BUILD_UNIT_TESTS) src/armnnOnnxParser/test/Flatten.cpp src/armnnOnnxParser/test/FullyConnected.cpp src/armnnOnnxParser/test/Gather.cpp + src/armnnOnnxParser/test/Gemm.cpp src/armnnOnnxParser/test/GetInputsOutputs.cpp src/armnnOnnxParser/test/OnnxParserTestUtils.cpp src/armnnOnnxParser/test/OnnxParserTestUtils.hpp diff --git a/docs/01_01_parsers.dox b/docs/01_01_parsers.dox index 2304e153bd..adc3051429 100644 --- a/docs/01_01_parsers.dox +++ b/docs/01_01_parsers.dox @@ -88,6 +88,8 @@ The Arm NN SDK ONNX parser currently only supports fp32 operators. - The parser only supports 2D convolutions with a group = 1 or group = #Nb_of_channel (depthwise convolution) - BatchNormalization - The parser does not support training mode. See the ONNX [BatchNormalization documentation](https://github.com/onnx/onnx/blob/master/docs/Operators.md#BatchNormalization) for more information. +- Gemm + - The parser only supports constant bias or non-constant bias where bias dimension = 1. See the ONNX [Gemm documentation](https://github.com/onnx/onnx/blob/master/docs/Operators.md#Gemm) for more information. - MatMul - The parser only supports constant weights in a fully connected layer. 
diff --git a/src/armnnOnnxParser/OnnxParser.cpp b/src/armnnOnnxParser/OnnxParser.cpp index 6caf690935..3588975897 100644 --- a/src/armnnOnnxParser/OnnxParser.cpp +++ b/src/armnnOnnxParser/OnnxParser.cpp @@ -434,7 +434,8 @@ const std::map OnnxParser { "Shape", &OnnxParserImpl::ParseShape }, { "Gather", &OnnxParserImpl::ParseGather }, { "Unsqueeze", &OnnxParserImpl::ParseUnsqueeze }, - { "Concat", &OnnxParserImpl::ParseConcat } + { "Concat", &OnnxParserImpl::ParseConcat }, + { "Gemm", &OnnxParserImpl::ParseGemm } }; template @@ -1800,6 +1801,175 @@ void OnnxParserImpl::ParseGather(const onnx::NodeProto& node) RegisterOutputSlots(layer, { node.output(0) }); } +void OnnxParserImpl::ParseGemm(const onnx::NodeProto& node) +{ + CHECK_VALID_SIZE(static_cast(node.input_size()), 2, 3); + CHECK_VALID_SIZE(static_cast(node.output_size()), 1); + + int transA = static_cast(ReadOptionalNodeUint32Attribute(node, "transA", 0)); + int transB = static_cast(ReadOptionalNodeUint32Attribute(node, "transB", 0)); + float alpha = ReadOptionalNodeFloatAttribute(node, "alpha", 1.0); + float beta = ReadOptionalNodeFloatAttribute(node, "beta", 1.0); + bool biasEnabled = node.input_size() == 3; + + TensorShape input0Shape = m_TensorsInfo[node.input(0)].m_info->GetShape(); + TensorShape input1Shape = m_TensorsInfo[node.input(1)].m_info->GetShape(); + + // if transB != 0, add transpose to the input1 (transpose weight matrix in FullyConnected) + armnn::FullyConnectedDescriptor fullyConnectedDescriptor; + fullyConnectedDescriptor.m_BiasEnabled = biasEnabled; + fullyConnectedDescriptor.m_TransposeWeightMatrix = transB; + + IConnectableLayer* layer = nullptr; + + // Just add a FullyConnected layer, weights and biases are handled as inputs now. 
+ layer = m_Network->AddFullyConnectedLayer(fullyConnectedDescriptor, node.name().c_str()); + ARMNN_ASSERT(layer != nullptr); + + // if transA != 0, add transpose to the input0 + if (transA != 0) + { + std::string transAName = "transpose_" + node.input(0); + armnn::TransposeDescriptor transposeADescriptor; + transposeADescriptor.m_DimMappings = { 1, 0 }; + IConnectableLayer* transALayer = m_Network->AddTransposeLayer(transposeADescriptor, transAName.c_str()); + ARMNN_ASSERT(transALayer != nullptr); + auto transAInfo = ComputeOutputInfo({ transAName }, transALayer, { input0Shape }); + transALayer->GetOutputSlot(0).SetTensorInfo(transAInfo[0]); + transALayer->GetOutputSlot(0).Connect(layer->GetInputSlot(0u)); + // register the input connection slots for the layer, connections are made after all layers have been created + RegisterInputSlot(transALayer, node.input(0), 0); + input0Shape = transAInfo[0].GetShape(); + } + else + { + RegisterInputSlot(layer, node.input(0), 0); + } + + // Add constant layer to store weights/biases and connect to FullyConnected layer. 
+ if(m_TensorsInfo[node.input(1)].isConstant()) + { + IConnectableLayer* weightsLayer = m_Network->AddConstantLayer(CreateConstTensor(node.input(1)).first); + TensorInfo weightInfo = *m_TensorsInfo[node.input(1)].m_info; + weightInfo.SetConstant(); + weightsLayer->GetOutputSlot(0).SetTensorInfo(weightInfo); + + // if alpha != 1, multiply to the weight + if (alpha != 1) + { + std::string activationName = "activation_" + node.input(1); + armnn::ActivationDescriptor activationDescriptor; + activationDescriptor.m_A = alpha; + activationDescriptor.m_Function = ActivationFunction::Linear; + IConnectableLayer* actLayer = m_Network->AddActivationLayer(activationDescriptor, activationName.c_str()); + ARMNN_ASSERT(actLayer != nullptr); + + auto actInfo = ComputeOutputInfo({ activationName }, actLayer, { weightInfo.GetShape() }); + actLayer->GetOutputSlot(0).SetTensorInfo(actInfo[0]); + actLayer->GetOutputSlot(0).Connect(layer->GetInputSlot(1u)); + weightsLayer->GetOutputSlot(0).Connect(actLayer->GetInputSlot(0u)); + input1Shape = actInfo[0].GetShape(); + } + else + { + weightsLayer->GetOutputSlot(0).Connect(layer->GetInputSlot(1u)); + input1Shape = weightInfo.GetShape(); + } + } + else + { + // if alpha != 1, multiply to the weight + if (alpha != 1) + { + std::string activationName = "activation_" + node.input(1); + armnn::ActivationDescriptor activationDescriptor; + activationDescriptor.m_A = alpha; + activationDescriptor.m_Function = ActivationFunction::Linear; + IConnectableLayer* actLayer = m_Network->AddActivationLayer(activationDescriptor, activationName.c_str()); + ARMNN_ASSERT(actLayer != nullptr); + + auto actInfo = ComputeOutputInfo({ activationName }, actLayer, { input1Shape }); + actLayer->GetOutputSlot(0).SetTensorInfo(actInfo[0]); + actLayer->GetOutputSlot(0).Connect(layer->GetInputSlot(1u)); + RegisterInputSlot(actLayer, node.input(1), 0); + input1Shape = actInfo[0].GetShape(); + } + else + { + RegisterInputSlot(layer, node.input(1), 1); + } + } + + 
if(biasEnabled && m_TensorsInfo[node.input(2)].isConstant()) + { + To1DTensor(node.input(2), CHECK_LOCATION()); + IConnectableLayer* biasLayer = m_Network->AddConstantLayer(CreateConstTensor(node.input(2)).first); + TensorInfo biasInfo = *m_TensorsInfo[node.input(2)].m_info; + biasInfo.SetConstant(); + biasLayer->GetOutputSlot(0).SetTensorInfo(biasInfo); + + // if beta != 1, multiply to the bias + if (beta != 1) + { + std::string activationName = "activation_" + node.input(2); + armnn::ActivationDescriptor activationDescriptor; + activationDescriptor.m_A = beta; + activationDescriptor.m_Function = ActivationFunction::Linear; + IConnectableLayer* actLayer = m_Network->AddActivationLayer(activationDescriptor, activationName.c_str()); + ARMNN_ASSERT(actLayer != nullptr); + + auto actInfo = ComputeOutputInfo({ activationName }, actLayer, { biasInfo.GetShape() }); + actLayer->GetOutputSlot(0).SetTensorInfo(actInfo[0]); + actLayer->GetOutputSlot(0).Connect(layer->GetInputSlot(2u)); + biasLayer->GetOutputSlot(0).Connect(actLayer->GetInputSlot(0u)); + } + else + { + biasLayer->GetOutputSlot(0).Connect(layer->GetInputSlot(2u)); + } + } + else if (biasEnabled) + { + // Currently we support non-constant tensor of input C (bias) of Gemm when the dimension is 1 + if (m_TensorsInfo[node.input(2)].m_info->GetNumDimensions() != 1) + { + throw ParseException(fmt::format("The parser supports constant or non-constant with 1 dimension for " + "Input C of Gemm. 
Input '{}' in '{}' is not supported '{}'", + node.input(2), + node.name(), + CHECK_LOCATION().AsString())); + } + // if beta != 1, multiply to the bias + if (beta != 1) + { + std::string activationName = "activation_" + node.input(2); + armnn::ActivationDescriptor activationDescriptor; + activationDescriptor.m_A = beta; + activationDescriptor.m_Function = ActivationFunction::Linear; + IConnectableLayer* actLayer = m_Network->AddActivationLayer(activationDescriptor, activationName.c_str()); + ARMNN_ASSERT(actLayer != nullptr); + + auto actInfo = ComputeOutputInfo({ activationName }, + actLayer, + { m_TensorsInfo[node.input(2)].m_info->GetShape() }); + actLayer->GetOutputSlot(0).SetTensorInfo(actInfo[0]); + actLayer->GetOutputSlot(0).Connect(layer->GetInputSlot(2u)); + RegisterInputSlot(actLayer, node.input(2), 0); + } + else + { + RegisterInputSlot(layer, node.input(2), 2); + } + } + + // Set final output of the FullyConnected layer + auto outputInfo = ComputeOutputInfo({ node.output(0) }, layer, + { input0Shape, input1Shape }); + layer->GetOutputSlot(0).SetTensorInfo(outputInfo[0]); + + RegisterOutputSlots(layer, {node.output(0)}); +} + void OnnxParserImpl::ParseGlobalAveragePool(const onnx::NodeProto& node) { Pooling2dDescriptor desc = Pooling2dDescriptor(); @@ -2031,6 +2201,22 @@ void OnnxParserImpl::SetupOutputLayers() } } +void OnnxParserImpl::RegisterInputSlot(IConnectableLayer* layer, + const std::string& tensorId, + unsigned int slotIndex) +{ + armnn::IInputSlot* slot = &(layer->GetInputSlot(slotIndex)); + + auto it = m_TensorConnections.find(tensorId); + + if (it == m_TensorConnections.end()) + { + //First time seeing this tensor, we need to map it + m_TensorConnections[tensorId] = TensorSlots(); + } + m_TensorConnections[tensorId].inputSlots.push_back(slot); +} + void OnnxParserImpl::RegisterInputSlots(IConnectableLayer* layer, const std::vector& tensorIds) { ARMNN_ASSERT(layer != nullptr); diff --git a/src/armnnOnnxParser/OnnxParser.hpp 
b/src/armnnOnnxParser/OnnxParser.hpp index d388f501d4..ec19006be7 100644 --- a/src/armnnOnnxParser/OnnxParser.hpp +++ b/src/armnnOnnxParser/OnnxParser.hpp @@ -120,12 +120,16 @@ private: void ParseConv(const onnx::NodeProto& nodeProto); void ParseFlatten(const onnx::NodeProto& node); void ParseGather(const onnx::NodeProto& node); + void ParseGemm(const onnx::NodeProto& node); void ParseGlobalAveragePool(const onnx::NodeProto& node); void ParseMaxPool(const onnx::NodeProto& nodeProto); void ParseShape(const onnx::NodeProto& node); void ParseReshape(const onnx::NodeProto& nodeProto); void ParseUnsqueeze(const onnx::NodeProto& nodeProto); + void RegisterInputSlot(armnn::IConnectableLayer* layer, + const std::string& tensorId, + unsigned int slotIndex); void RegisterInputSlots(armnn::IConnectableLayer* layer, const std::vector& tensorIndexes); void RegisterOutputSlots(armnn::IConnectableLayer* layer, const std::vector& tensorIndexes); diff --git a/src/armnnOnnxParser/test/Gemm.cpp b/src/armnnOnnxParser/test/Gemm.cpp new file mode 100644 index 0000000000..f68758f42e --- /dev/null +++ b/src/armnnOnnxParser/test/Gemm.cpp @@ -0,0 +1,556 @@ +// +// Copyright © 2021 Arm Ltd and Contributors. All rights reserved. 
+// SPDX-License-Identifier: MIT +// + +#include "armnnOnnxParser/IOnnxParser.hpp" +#include "ParserPrototxtFixture.hpp" +#include "OnnxParserTestUtils.hpp" + +TEST_SUITE("OnnxParser_Gemm") +{ + +struct GemmFixture : public armnnUtils::ParserPrototxtFixture +{ + GemmFixture(const std::string& alpha, + const std::string& beta, + const std::string& transA, + const std::string& transB, + const std::vector& inputAShape, + const std::vector& inputBShape, + const std::vector& inputCShape, + const std::vector& outputShape) + { + m_Prototext = R"( + ir_version: 8 + producer_name: "onnx-example" + graph { + node { + input: "A" + input: "B" + input: "C" + output: "Output" + op_type: "Gemm" + attribute { + name: "alpha" + f: )" + alpha + R"( + type: FLOAT + } + attribute { + name: "beta" + f: )" + beta + R"( + type: FLOAT + } + attribute { + name: "transA" + i: )" + transA + R"( + type: INT + } + attribute { + name: "transB" + i: )" + transB + R"( + type: INT + } + } + name: "gem-model" + input { + name: "A" + type { + tensor_type { + elem_type: 1 + shape { + )" + armnnUtils::ConstructTensorShapeString(inputAShape) + R"( + } + } + } + } + input { + name: "B" + type { + tensor_type { + elem_type: 1 + shape { + )" + armnnUtils::ConstructTensorShapeString(inputBShape) + R"( + } + } + } + } + input { + name: "C" + type { + tensor_type { + elem_type: 1 + shape { + )" + armnnUtils::ConstructTensorShapeString(inputCShape) + R"( + } + } + } + } + output { + name: "Output" + type { + tensor_type { + elem_type: 1 + shape { + )" + armnnUtils::ConstructTensorShapeString(outputShape) + R"( + } + } + } + } + })"; + } +}; + +struct GemmAllAttributesFixture : GemmFixture +{ + GemmAllAttributesFixture() : GemmFixture("0.25", "0.35", "1", "1", { 4, 3 }, { 5, 4 }, { 5 }, { 3, 5 }) + { + Setup(); + } +}; + +struct GemmSimpleFixture : GemmFixture +{ + GemmSimpleFixture() : GemmFixture("1", "1", "0", "0", { 3, 4 }, { 4, 5 }, { 5 }, { 3, 5 }) + { + Setup(); + } +}; + +struct GemmTransAFixture : 
GemmFixture +{ + GemmTransAFixture() : GemmFixture("1", "1", "1", "0", { 4, 3 }, { 4, 5 }, { 5 }, { 3, 5 }) + { + Setup(); + } +}; + +struct GemmTransBFixture : GemmFixture +{ + GemmTransBFixture() : GemmFixture("1", "1", "0", "1", { 3, 4 }, { 5, 4 }, { 5 }, { 3, 5 }) + { + Setup(); + } +}; + +struct GemmParseExceptionFixture : GemmFixture +{ + GemmParseExceptionFixture() : GemmFixture("1", "1", "0", "1", { 3, 4 }, { 5, 4 }, { 3, 5 }, { 3, 5 }) {} +}; + +TEST_CASE_FIXTURE(GemmAllAttributesFixture, "GemmTest") +{ + RunTest<2, float>({{"A", { 12.0f, 11.0f, 10.0f, 9.0f, 8.0f, 7.0f, + 6.0f, 5.0f, 4.0f, 3.0f, 2.0f, 1.0f }}, + {"B", { 1.0f, 2.0f, 3.0f, 4.0f, 5.0f, + 6.0f, 7.0f, 8.0f, 9.0f, 10.0f, + 11.0f, 12.0f, 13.0f, 14.0f, 15.0f, + 16.0f, 17.0f, 18.0f, 19.0f, 20.0f }}, + {"C", { 0.10f, 0.20f, 0.30f, 0.40f, 0.50f }}}, + {{"Output", { 15.035f, 45.07f, 75.105f, 105.14f, 135.175f, + 12.535f, 38.57f, 64.605f, 90.64f, 116.675f, + 10.035f, 32.07f, 54.105f, 76.14f, 98.175f }}}); +} + +TEST_CASE_FIXTURE(GemmSimpleFixture, "GemmSimpleTest") +{ + RunTest<2, float>({{"A", { 12.0f, 11.0f, 10.0f, 9.0f, 8.0f, 7.0f, + 6.0f, 5.0f, 4.0f, 3.0f, 2.0f, 1.0f }}, + {"B", { 1.0f, 2.0f, 3.0f, 4.0f, 5.0f, + 6.0f, 7.0f, 8.0f, 9.0f, 10.0f, + 11.0f, 12.0f, 13.0f, 14.0f, 15.0f, + 16.0f, 17.0f, 18.0f, 19.0f, 20.0f }}, + {"C", { 0.10f, 0.20f, 0.30f, 0.40f, 0.50f }}}, + {{"Output", { 332.1f, 374.2f, 416.3f, 458.4f, 500.5f, + 196.1f, 222.2f, 248.3f, 274.4f, 300.5f, + 60.1f, 70.2f, 80.3f, 90.4f, 100.5f }}}); +} + +TEST_CASE_FIXTURE(GemmTransAFixture, "GemmTransposeATest") +{ + RunTest<2, float>({{"A", { 12.0f, 11.0f, 10.0f, 9.0f, 8.0f, 7.0f, + 6.0f, 5.0f, 4.0f, 3.0f, 2.0f, 1.0f }}, + {"B", { 1.0f, 2.0f, 3.0f, 4.0f, 5.0f, + 6.0f, 7.0f, 8.0f, 9.0f, 10.0f, + 11.0f, 12.0f, 13.0f, 14.0f, 15.0f, + 16.0f, 17.0f, 18.0f, 19.0f, 20.0f }}, + {"C", { 0.10f, 0.20f, 0.30f, 0.40f, 0.50f }}}, + {{"Output", { 180.1f, 210.2f, 240.3f, 270.4f, 300.5f, + 146.1f, 172.2f, 198.3f, 224.4f, 250.5f, + 112.1f, 134.2f, 156.3f, 
178.4f, 200.5f }}}); +} + +TEST_CASE_FIXTURE(GemmTransBFixture, "GemmTransposeBTest") +{ + RunTest<2, float>({{"A", { 12.0f, 11.0f, 10.0f, 9.0f, 8.0f, 7.0f, + 6.0f, 5.0f, 4.0f, 3.0f, 2.0f, 1.0f }}, + {"B", { 1.0f, 2.0f, 3.0f, 4.0f, 5.0f, + 6.0f, 7.0f, 8.0f, 9.0f, 10.0f, + 11.0f, 12.0f, 13.0f, 14.0f, 15.0f, + 16.0f, 17.0f, 18.0f, 19.0f, 20.0f }}, + {"C", { 0.10f, 0.20f, 0.30f, 0.40f, 0.50f }}}, + {{"Output", { 100.1f, 268.2f, 436.3f, 604.4f, 772.5f, + 60.1f, 164.2f, 268.3f, 372.4f, 476.5f, + 20.1f, 60.2f, 100.3f, 140.4f, 180.5f }}}); +} + +TEST_CASE_FIXTURE(GemmParseExceptionFixture, "GemmParseExceptionTest") +{ + // ParseException because Input C is non-constant and has 2 dimension (should be 1 dimension) + CHECK_THROWS_AS(Setup(), armnn::ParseException); +} + +struct GemmConstantFixture : public armnnUtils::ParserPrototxtFixture +{ + GemmConstantFixture() + { + m_Prototext = R"( + ir_version: 8 + producer_name: "onnx-example" + graph { + node { + input: "A" + input: "B" + input: "C" + output: "Output" + op_type: "Gemm" + attribute { + name: "alpha" + f: 0.25 + type: FLOAT + } + attribute { + name: "beta" + f: 0.35 + type: FLOAT + } + attribute { + name: "transA" + i: 1 + type: INT + } + attribute { + name: "transB" + i: 1 + type: INT + } + } + name: "gem-model" + initializer { + dims: 5 + dims: 4 + data_type: 1 + float_data: 1.0 + float_data: 2.0 + float_data: 3.0 + float_data: 4.0 + float_data: 5.0 + float_data: 6.0 + float_data: 7.0 + float_data: 8.0 + float_data: 9.0 + float_data: 10.0 + float_data: 11.0 + float_data: 12.0 + float_data: 13.0 + float_data: 14.0 + float_data: 15.0 + float_data: 16.0 + float_data: 17.0 + float_data: 18.0 + float_data: 19.0 + float_data: 20.0 + name: "B" + } + initializer { + dims: 1 + dims: 5 + data_type: 1 + float_data: 0.1 + float_data: 0.2 + float_data: 0.3 + float_data: 0.4 + float_data: 0.5 + name: "C" + } + input { + name: "A" + type { + tensor_type { + elem_type: 1 + shape { + dim { + dim_value: 4 + } + dim { + dim_value: 3 
+ } + } + } + } + } + output { + name: "Output" + type { + tensor_type { + elem_type: 1 + shape { + dim { + dim_value: 3 + } + dim { + dim_value: 5 + } + } + } + } + } + })"; + Setup(); + } +}; + +TEST_CASE_FIXTURE(GemmConstantFixture, "GemmConstantTest") +{ + RunTest<2, float>({{"A", { 12.0f, 11.0f, 10.0f, 9.0f, 8.0f, 7.0f, + 6.0f, 5.0f, 4.0f, 3.0f, 2.0f, 1.0f }}}, + {{"Output", { 15.035f, 45.07f, 75.105f, 105.14f, 135.175f, + 12.535f, 38.57f, 64.605f, 90.64f, 116.675f, + 10.035f, 32.07f, 54.105f, 76.14f, 98.175f }}}); +} + +struct GemmConstantSimpleFixture : public armnnUtils::ParserPrototxtFixture +{ + GemmConstantSimpleFixture() + { + m_Prototext = R"( + ir_version: 8 + producer_name: "onnx-example" + graph { + node { + input: "A" + input: "B" + input: "C" + output: "Output" + op_type: "Gemm" + attribute { + name: "alpha" + f: 1 + type: FLOAT + } + attribute { + name: "beta" + f: 1 + type: FLOAT + } + attribute { + name: "transA" + i: 0 + type: INT + } + attribute { + name: "transB" + i: 0 + type: INT + } + } + name: "gem-model" + initializer { + dims: 4 + dims: 5 + data_type: 1 + float_data: 1.0 + float_data: 2.0 + float_data: 3.0 + float_data: 4.0 + float_data: 5.0 + float_data: 6.0 + float_data: 7.0 + float_data: 8.0 + float_data: 9.0 + float_data: 10.0 + float_data: 11.0 + float_data: 12.0 + float_data: 13.0 + float_data: 14.0 + float_data: 15.0 + float_data: 16.0 + float_data: 17.0 + float_data: 18.0 + float_data: 19.0 + float_data: 20.0 + name: "B" + } + initializer { + dims: 1 + dims: 5 + data_type: 1 + float_data: 0.1 + float_data: 0.2 + float_data: 0.3 + float_data: 0.4 + float_data: 0.5 + name: "C" + } + input { + name: "A" + type { + tensor_type { + elem_type: 1 + shape { + dim { + dim_value: 3 + } + dim { + dim_value: 4 + } + } + } + } + } + output { + name: "Output" + type { + tensor_type { + elem_type: 1 + shape { + dim { + dim_value: 3 + } + dim { + dim_value: 5 + } + } + } + } + } + })"; + Setup(); + } +}; + 
+TEST_CASE_FIXTURE(GemmConstantSimpleFixture, "GemmConstantSimpleTest") +{ + RunTest<2, float>({{"A", { 12.0f, 11.0f, 10.0f, 9.0f, 8.0f, 7.0f, + 6.0f, 5.0f, 4.0f, 3.0f, 2.0f, 1.0f }}}, + {{"Output", { 332.1f, 374.2f, 416.3f, 458.4f, 500.5f, + 196.1f, 222.2f, 248.3f, 274.4f, 300.5f, + 60.1f, 70.2f, 80.3f, 90.4f, 100.5f }}}); +} + +struct GemmABFixture : public armnnUtils::ParserPrototxtFixture +{ + GemmABFixture(const std::string& alpha, + const std::string& beta, + const std::string& transA, + const std::string& transB, + const std::vector& inputAShape, + const std::vector& inputBShape, + const std::vector& outputShape) + { + m_Prototext = R"( + ir_version: 8 + producer_name: "onnx-example" + graph { + node { + input: "A" + input: "B" + output: "Output" + op_type: "Gemm" + attribute { + name: "alpha" + f: )" + alpha + R"( + type: FLOAT + } + attribute { + name: "beta" + f: )" + beta + R"( + type: FLOAT + } + attribute { + name: "transA" + i: )" + transA + R"( + type: INT + } + attribute { + name: "transB" + i: )" + transB + R"( + type: INT + } + } + name: "gem-model" + input { + name: "A" + type { + tensor_type { + elem_type: 1 + shape { + )" + armnnUtils::ConstructTensorShapeString(inputAShape) + R"( + } + } + } + } + input { + name: "B" + type { + tensor_type { + elem_type: 1 + shape { + )" + armnnUtils::ConstructTensorShapeString(inputBShape) + R"( + } + } + } + } + output { + name: "Output" + type { + tensor_type { + elem_type: 1 + shape { + )" + armnnUtils::ConstructTensorShapeString(outputShape) + R"( + } + } + } + } + })"; + Setup(); + } +}; + +struct GemmAlphaTransAFixture : GemmABFixture +{ + GemmAlphaTransAFixture() : GemmABFixture("0.25", "0.35", "1", "0", { 4, 3 }, { 4, 5 }, { 3, 5 }) {} +}; + +struct GemmAlphaTransBFixture : GemmABFixture +{ + GemmAlphaTransBFixture() : GemmABFixture("0.25", "0.35", "0", "1", { 3, 4 }, { 5, 4 }, { 3, 5 }) {} +}; + +TEST_CASE_FIXTURE(GemmAlphaTransAFixture, "GemmAlphaTransATest") +{ + RunTest<2, float>({{"A", { 12.0f, 
11.0f, 10.0f, 9.0f, 8.0f, 7.0f, + 6.0f, 5.0f, 4.0f, 3.0f, 2.0f, 1.0f }}, + {"B", { 1.0f, 2.0f, 3.0f, 4.0f, 5.0f, + 6.0f, 7.0f, 8.0f, 9.0f, 10.0f, + 11.0f, 12.0f, 13.0f, 14.0f, 15.0f, + 16.0f, 17.0f, 18.0f, 19.0f, 20.0f }}}, + {{"Output", { 45.0f, 52.5f, 60.0f, 67.5f, 75.0f, + 36.5f, 43.0f, 49.5f, 56.0f, 62.5f, + 28.0f, 33.5f, 39.0f, 44.5f, 50.0f }}}); +} + +TEST_CASE_FIXTURE(GemmAlphaTransBFixture, "GemmAlphaTransBTest") +{ + RunTest<2, float>({{"A", { 12.0f, 11.0f, 10.0f, 9.0f, 8.0f, 7.0f, + 6.0f, 5.0f, 4.0f, 3.0f, 2.0f, 1.0f }}, + {"B", { 1.0f, 2.0f, 3.0f, 4.0f, 5.0f, + 6.0f, 7.0f, 8.0f, 9.0f, 10.0f, + 11.0f, 12.0f, 13.0f, 14.0f, 15.0f, + 16.0f, 17.0f, 18.0f, 19.0f, 20.0f }}}, + {{"Output", { 25.0f, 67.0f, 109.0f, 151.0f, 193.0f, + 15.0f, 41.0f, 67.0f, 93.0f, 119.0f, + 5.0f, 15.0f, 25.0f, 35.0f, 45.0f }}}); +} + +} -- cgit v1.2.1