// // Copyright © 2017 Arm Ltd. All rights reserved. // SPDX-License-Identifier: MIT // #pragma once #include "CommonTestUtils.hpp" #include #include namespace{ armnn::INetworkPtr CreateGatherNetwork(const armnn::TensorInfo& paramsInfo, const armnn::TensorInfo& indicesInfo, const armnn::TensorInfo& outputInfo, const std::vector& indicesData) { armnn::INetworkPtr net(armnn::INetwork::Create()); armnn::IConnectableLayer* paramsLayer = net->AddInputLayer(0); armnn::IConnectableLayer* indicesLayer = net->AddConstantLayer(armnn::ConstTensor(indicesInfo, indicesData)); armnn::IConnectableLayer* gatherLayer = net->AddGatherLayer("gather"); armnn::IConnectableLayer* outputLayer = net->AddOutputLayer(0, "output"); Connect(paramsLayer, gatherLayer, paramsInfo, 0, 0); Connect(indicesLayer, gatherLayer, indicesInfo, 0, 1); Connect(gatherLayer, outputLayer, outputInfo, 0, 0); return net; } template> void GatherEndToEnd(const std::vector& backends) { armnn::TensorInfo paramsInfo({ 8 }, ArmnnType); armnn::TensorInfo indicesInfo({ 3 }, armnn::DataType::Signed32); armnn::TensorInfo outputInfo({ 3 }, ArmnnType); paramsInfo.SetQuantizationScale(1.0f); paramsInfo.SetQuantizationOffset(0); outputInfo.SetQuantizationScale(1.0f); outputInfo.SetQuantizationOffset(0); // Creates structures for input & output. std::vector paramsData{ 1, 2, 3, 4, 5, 6, 7, 8 }; std::vector indicesData{ 7, 6, 5 }; std::vector expectedOutput{ 8, 7, 6 }; // Builds up the structure of the network armnn::INetworkPtr net = CreateGatherNetwork(paramsInfo, indicesInfo, outputInfo, indicesData); BOOST_TEST_CHECKPOINT("create a network"); std::map> inputTensorData = {{ 0, paramsData }}; std::map> expectedOutputData = {{ 0, expectedOutput }}; EndToEndLayerTestImpl(move(net), inputTensorData, expectedOutputData, backends); } template> void GatherMultiDimEndToEnd(const std::vector& backends) { armnn::TensorInfo paramsInfo({ 3, 2, 3}, ArmnnType); armnn::TensorInfo indicesInfo({ 2, 3 }, armnn::DataType::Signed32); armnn::TensorInfo outputInfo({ 2, 3, 2, 3 }, ArmnnType); paramsInfo.SetQuantizationScale(1.0f); paramsInfo.SetQuantizationOffset(0); outputInfo.SetQuantizationScale(1.0f); outputInfo.SetQuantizationOffset(0); // Creates structures for input & output. std::vector paramsData{ 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18 }; std::vector indicesData{ 1, 2, 1, 2, 1, 0 }; std::vector expectedOutput{ 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 7, 8, 9, 10, 11, 12, 1, 2, 3, 4, 5, 6 }; // Builds up the structure of the network armnn::INetworkPtr net = CreateGatherNetwork(paramsInfo, indicesInfo, outputInfo, indicesData); BOOST_TEST_CHECKPOINT("create a network"); std::map> inputTensorData = {{ 0, paramsData }}; std::map> expectedOutputData = {{ 0, expectedOutput }}; EndToEndLayerTestImpl(move(net), inputTensorData, expectedOutputData, backends); } } // anonymous namespace