ArmNN
 20.08
Gather.cpp
Go to the documentation of this file.
1 //
2 // Copyright © 2017 Arm Ltd and Contributors. All rights reserved.
3 // SPDX-License-Identifier: MIT
4 //
5 
7 
10 
11 #include <boost/test/unit_test.hpp>
12 
13 BOOST_AUTO_TEST_SUITE(TensorflowParser)
14 
15 namespace {
16 // helper for setting the dimensions in prototxt
17 void dimsHelper(const std::vector<int>& dims, std::string& text){
18  for(unsigned int i = 0; i < dims.size(); ++i) {
19  text.append(R"(dim {
20  size: )");
21  text.append(std::to_string(dims[i]));
22  text.append(R"(
23  })");
24  }
25 }
26 
27 // helper for converting from integer to octal representation
28 void octalHelper(const std::vector<int>& indicesContent, std::string& text){
29  for(unsigned int i = 0; i < indicesContent.size(); ++i) {
30  text.append(armnnUtils::ConvertInt32ToOctalString(static_cast<int>(indicesContent[i])));
31  }
32 }
33 } // namespace
34 
35 struct GatherFixture : public armnnUtils::ParserPrototxtFixture<armnnTfParser::ITfParser>
36 {
37  GatherFixture(const armnn::TensorShape& inputShape0,
38  const armnn::TensorShape& inputShape1,
39  const std::vector<int>& input1Content,
40  const std::vector<int>& input0Dims,
41  const std::vector<int>& input1Dims,
42  int axis = 0)
43  {
44  m_Prototext = R"(
45 node {
46  name: "input0"
47  op: "Placeholder"
48  attr {
49  key: "dtype"
50  value {
51  type: DT_FLOAT
52  }
53  }
54  attr {
55  key: "shape"
56  value {
57  shape {
58 )";
59  dimsHelper(input0Dims, m_Prototext);
60 
61  m_Prototext.append(R"(
62  }
63  }
64  }
65 }
66 node {
67  name: "input1"
68  op: "Const"
69  attr {
70  key: "dtype"
71  value {
72  type: DT_INT32
73  }
74  }
75  attr {
76  key: "value"
77  value {
78  tensor {
79  dtype: DT_INT32
80  tensor_shape {
81 )");
82  dimsHelper(input1Dims, m_Prototext);
83 
84  m_Prototext.append(R"(
85  }
86  tensor_content: ")");
87  octalHelper(input1Content, m_Prototext);
88  m_Prototext.append(R"("
89  }
90  }
91  }
92 }
93 node {
94  name: "output"
95  op: "Gather"
96  input: "input0"
97  input: "input1"
98  attr {
99  key: "Tindices"
100  value {
101  type: DT_INT32
102  }
103  }
104  attr {
105  key: "Tparams"
106  value {
107  type: DT_FLOAT
108  }
109  }
110  attr {
111  key: "axis"
112  value {
113  i: )");
114  m_Prototext += std::to_string(axis);
115 
116  m_Prototext.append(R"(
117  }
118  }
119 }
120  )");
121 
122  Setup({ { "input0", inputShape0 },
123  { "input1", inputShape1 } },
124  { "output" });
125 
126  }
127 };
128 
129 
130 struct GatherFixture1DParams1DIndices : public GatherFixture
131 {
132  GatherFixture1DParams1DIndices() : GatherFixture(
133  { 4, 1, 1, 1 },
134  { 4, 0, 0, 0 },
135  { 0, 2, 1, 3 },
136  { 4 },
137  { 4 },
138  0) {}
139 };
140 
141 struct GatherFixture1DParamsMultiDimIndices : public GatherFixture
142 {
143  GatherFixture1DParamsMultiDimIndices() : GatherFixture(
144  { 4, 1, 1 },
145  { 2, 2, 1, 1 },
146  { 0, 1, 1, 3 },
147  { 4 },
148  { 2, 2 },
149  0) {}
150 };
151 
152 struct GatherFixtureMultiDimParamMultiDimIndices : public GatherFixture
153 {
154  GatherFixtureMultiDimParamMultiDimIndices() : GatherFixture(
155  { 5, 2, 1 },
156  { 2, 1, 4 },
157  { 1, 3, 0, 2 },
158  { 5, 2 },
159  { 2, 2 },
160  0) {}
161 };
162 
163 BOOST_FIXTURE_TEST_CASE(ParseGather1DParams1DIndices, GatherFixture1DParams1DIndices)
164 {
165  RunTest<4>({ { "input0", { 1, 2, 3, 4 } } },
166 
167  { { "output", { 1, 3, 2, 4 } } });
168 }
169 
170 BOOST_FIXTURE_TEST_CASE(ParseGather1DParamsMultiDimIndices, GatherFixture1DParamsMultiDimIndices)
171 {
172  RunTest<4>({ { "input0", { 1, 2, 3, 4 } } },
173 
174  { { "output", { 1, 2, 2, 4 } } });
175 }
176 
177 BOOST_FIXTURE_TEST_CASE(ParseGatherMultiDimParamMultiDimIndices, GatherFixtureMultiDimParamMultiDimIndices)
178 {
179  RunTest<4>({ { "input0", { 1, 2, 3, 4, 5, 6, 7, 8, 9, 10 } } },
180 
181  { { "output", { 3, 4, 7, 8, 1, 2, 5, 6} } });
182 }
183 
BOOST_AUTO_TEST_SUITE(TensorflowLiteParser)
std::string ConvertInt32ToOctalString(int value)
Converts an int value into the Prototxt octal representation.
BOOST_AUTO_TEST_SUITE_END()
BOOST_FIXTURE_TEST_CASE(ParseGather1DParams1DIndices, GatherFixture1DParams1DIndices)
Definition: Gather.cpp:163