ArmNN
 22.08
GatherNdEndToEndTestImpl.hpp
Go to the documentation of this file.
1 //
2 // Copyright © 2022 Arm Ltd and Contributors. All rights reserved.
3 // SPDX-License-Identifier: MIT
4 //
5 
6 #pragma once
7 
8 #include <CommonTestUtils.hpp>
9 
10 #include <armnn/INetwork.hpp>
11 #include <ResolveType.hpp>
12 
13 #include <doctest/doctest.h>
14 
15 namespace{
16 
17 armnn::INetworkPtr CreateGatherNdNetwork(const armnn::TensorInfo& paramsInfo,
18  const armnn::TensorInfo& indicesInfo,
19  const armnn::TensorInfo& outputInfo,
20  const std::vector<int32_t>& indicesData)
21 {
23 
24  armnn::IConnectableLayer* paramsLayer = net->AddInputLayer(0);
25  armnn::IConnectableLayer* indicesLayer = net->AddConstantLayer(armnn::ConstTensor(indicesInfo, indicesData));
26  armnn::IConnectableLayer* gatherNdLayer = net->AddGatherNdLayer("gatherNd");
27  armnn::IConnectableLayer* outputLayer = net->AddOutputLayer(0, "output");
28  Connect(paramsLayer, gatherNdLayer, paramsInfo, 0, 0);
29  Connect(indicesLayer, gatherNdLayer, indicesInfo, 0, 1);
30  Connect(gatherNdLayer, outputLayer, outputInfo, 0, 0);
31 
32  return net;
33 }
34 
35 template<armnn::DataType ArmnnType, typename T = armnn::ResolveType<ArmnnType>>
36 void GatherNdEndToEnd(const std::vector<BackendId>& backends)
37 {
38  armnn::TensorInfo paramsInfo({ 2, 3, 8, 4 }, ArmnnType);
39  armnn::TensorInfo indicesInfo({ 2, 2 }, armnn::DataType::Signed32);
40  armnn::TensorInfo outputInfo({ 2, 8, 4 }, ArmnnType);
41 
42  paramsInfo.SetQuantizationScale(1.0f);
43  paramsInfo.SetQuantizationOffset(0);
44  paramsInfo.SetConstant(true);
45  indicesInfo.SetConstant(true);
46  outputInfo.SetQuantizationScale(1.0f);
47  outputInfo.SetQuantizationOffset(0);
48 
49  // Creates structures for input & output.
50  std::vector<T> paramsData{
51  0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15,
52  16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31,
53 
54  32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47,
55  48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63,
56 
57  64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79,
58  80, 81, 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95,
59 
60  96, 97, 98, 99, 100, 101, 102, 103, 104, 105, 106, 107, 108, 109, 110, 111,
61  112, 113, 114, 115, 116, 117, 118, 119, 120, 121, 122, 123, 124, 125, 126, 127,
62 
63  128, 129, 130, 131, 132, 133, 134, 135, 136, 137, 138, 139, 140, 141, 142, 143,
64  144, 145, 146, 147, 148, 149, 150, 151, 152, 153, 154, 155, 156, 157, 158, 159,
65 
66  160, 161, 162, 163, 164, 165, 166, 167, 168, 169, 170, 171, 172, 173, 174, 175,
67  176, 177, 178, 179, 180, 181, 182, 183, 184, 185, 186, 187, 188, 189, 190, 191
68  };
69 
70  std::vector<int32_t> indicesData{
71  { 1, 2, 1, 1},
72  };
73 
74  std::vector<T> expectedOutput{
75  160, 161, 162, 163, 164, 165, 166, 167, 168, 169, 170, 171, 172, 173, 174, 175,
76  176, 177, 178, 179, 180, 181, 182, 183, 184, 185, 186, 187, 188, 189, 190, 191,
77 
78  128, 129, 130, 131, 132, 133, 134, 135, 136, 137, 138, 139, 140, 141, 142, 143,
79  144, 145, 146, 147, 148, 149, 150, 151, 152, 153, 154, 155, 156, 157, 158, 159
80  };
81 
82  // Builds up the structure of the network
83  armnn::INetworkPtr net = CreateGatherNdNetwork(paramsInfo, indicesInfo, outputInfo, indicesData);
84 
85  CHECK(net);
86 
87  std::map<int, std::vector<T>> inputTensorData = {{ 0, paramsData }};
88  std::map<int, std::vector<T>> expectedOutputData = {{ 0, expectedOutput }};
89 
90  EndToEndLayerTestImpl<ArmnnType, ArmnnType>(move(net), inputTensorData, expectedOutputData, backends);
91 }
92 
93 template<armnn::DataType ArmnnType, typename T = armnn::ResolveType<ArmnnType>>
94 void GatherNdMultiDimEndToEnd(const std::vector<BackendId>& backends)
95 {
96  armnn::TensorInfo paramsInfo({ 5, 5, 2 }, ArmnnType);
97  armnn::TensorInfo indicesInfo({ 2, 2, 3, 2 }, armnn::DataType::Signed32);
98  armnn::TensorInfo outputInfo({ 2, 2, 3, 2 }, ArmnnType);
99 
100  paramsInfo.SetQuantizationScale(1.0f);
101  paramsInfo.SetQuantizationOffset(0);
102  paramsInfo.SetConstant(true);
103  indicesInfo.SetConstant(true);
104  outputInfo.SetQuantizationScale(1.0f);
105  outputInfo.SetQuantizationOffset(0);
106 
107  // Creates structures for input & output.
108  std::vector<T> paramsData{
109  0, 1, 2, 3, 4, 5, 6, 7, 8, 9,
110  10, 11, 12, 13, 14, 15, 16, 17, 18, 19,
111  20, 21, 22, 23, 24, 25, 26, 27, 28, 29,
112  30, 31, 32, 33, 34, 35, 36, 37, 38, 39,
113  40, 41, 42, 43, 44, 45, 46, 47, 48, 49
114  };
115 
116  std::vector<int32_t> indicesData{
117  0, 0,
118  3, 3,
119  4, 4,
120 
121  0, 0,
122  1, 1,
123  2, 2,
124 
125  4, 4,
126  3, 3,
127  0, 0,
128 
129  2, 2,
130  1, 1,
131  0, 0
132  };
133 
134  std::vector<T> expectedOutput{
135  0, 1,
136  36, 37,
137  48, 49,
138 
139  0, 1,
140  12, 13,
141  24, 25,
142 
143  48, 49,
144  36, 37,
145  0, 1,
146 
147  24, 25,
148  12, 13,
149  0, 1
150  };
151 
152  // Builds up the structure of the network
153  armnn::INetworkPtr net = CreateGatherNdNetwork(paramsInfo, indicesInfo, outputInfo, indicesData);
154 
155  std::map<int, std::vector<T>> inputTensorData = {{ 0, paramsData }};
156  std::map<int, std::vector<T>> expectedOutputData = {{ 0, expectedOutput }};
157 
158  EndToEndLayerTestImpl<ArmnnType, ArmnnType>(move(net), inputTensorData, expectedOutputData, backends);
159 }
160 
161 } // anonymous namespace
Interface for a layer that is connectable to other layers via InputSlots and OutputSlots.
Definition: INetwork.hpp:68
A tensor defined by a TensorInfo (shape and data type) and an immutable backing store.
Definition: Tensor.hpp:327
void SetQuantizationScale(float scale)
Definition: Tensor.cpp:473
void SetConstant(const bool IsConstant=true)
Marks the data corresponding to this tensor info as constant.
Definition: Tensor.cpp:514
void SetQuantizationOffset(int32_t offset)
Definition: Tensor.cpp:489
void Connect(armnn::IConnectableLayer *from, armnn::IConnectableLayer *to, const armnn::TensorInfo &tensorInfo, unsigned int fromIndex, unsigned int toIndex)
Definition: TestUtils.cpp:14
std::unique_ptr< INetwork, void(*)(INetwork *network)> INetworkPtr
Definition: INetwork.hpp:238
static INetworkPtr Create(NetworkOptions networkOptions={})
Definition: Network.cpp:475