ArmNN
 20.02
Gather.cpp
Go to the documentation of this file.
1 //
2 // Copyright © 2017 Arm Ltd. All rights reserved.
3 // SPDX-License-Identifier: MIT
4 //
5 
7 
10 
11 #include <boost/test/unit_test.hpp>
12 
13 BOOST_AUTO_TEST_SUITE(TensorflowParser)
14 
15 namespace {
16 // helper for setting the dimensions in prototxt
17 void dimsHelper(const std::vector<int>& dims, std::string& text){
18  for(u_int i = 0; i < dims.size(); ++i) {
19  text.append(R"(dim {
20  size: )");
21  text.append(std::to_string(dims[i]));
22  text.append(R"(
23  })");
24  }
25 }
26 
27 // helper for converting from integer to octal representation
28 void octalHelper(const std::vector<int>& indicesContent, std::string& text){
29  for(unsigned int i = 0; i < indicesContent.size(); ++i) {
30  text.append(armnnUtils::ConvertInt32ToOctalString(static_cast<int>(indicesContent[i])));
31  }
32 }
33 } // namespace
34 
// Test fixture that builds a small TensorFlow prototxt graph with three nodes:
//   input0 - Placeholder, dtype DT_FLOAT, "shape" attr dims taken from input0Dims
//   input1 - Const, dtype DT_INT32, tensor_shape dims from input1Dims and
//            tensor_content built from input1Content (octal-encoded)
//   output - Gather(input0, input1) with Tindices DT_INT32 and Tparams DT_FLOAT
// The finished text is handed to the ParserPrototxtFixture base through Setup().
//
// Parameters:
//   inputShape0   - shape passed to Setup() for "input0".
//   inputShape1   - shape passed to Setup() for "input1".
//   input1Content - values of the constant indices tensor.
//   input0Dims    - dims written into the Placeholder's "shape" attr.
//   input1Dims    - dims written into the Const's tensor_shape.
// NOTE(review): how Setup() consumes these shapes (notably for the Const node
// "input1") is defined in ParserPrototxtFixture, not here; confirm there.
35 struct GatherFixture : public armnnUtils::ParserPrototxtFixture<armnnTfParser::ITfParser>
36 {
37  GatherFixture(const armnn::TensorShape& inputShape0,
38  const armnn::TensorShape& inputShape1,
39  const std::vector<int>& input1Content,
40  const std::vector<int>& input0Dims,
41  const std::vector<int>& input1Dims)
42  {
// Open node "input0" (float Placeholder) up to its shape message; dims follow.
43  m_Prototext = R"(
44 node {
45  name: "input0"
46  op: "Placeholder"
47  attr {
48  key: "dtype"
49  value {
50  type: DT_FLOAT
51  }
52  }
53  attr {
54  key: "shape"
55  value {
56  shape {
57 )";
58  dimsHelper(input0Dims, m_Prototext);
// Close input0, then open node "input1" (int32 Const) up to its tensor_shape.
59  m_Prototext.append(R"(
60  }
61  }
62  }
63 }
64 node {
65  name: "input1"
66  op: "Const"
67  attr {
68  key: "dtype"
69  value {
70  type: DT_INT32
71  }
72  }
73  attr {
74  key: "value"
75  value {
76  tensor {
77  dtype: DT_INT32
78  tensor_shape {
79 )");
80  dimsHelper(input1Dims, m_Prototext);
// Close tensor_shape and open the quoted tensor_content string.
81  m_Prototext.append(R"(
82  }
83  tensor_content: ")");
// Octal-encoded index values go between the tensor_content quotes.
84  octalHelper(input1Content, m_Prototext);
// Close tensor_content and node "input1", then add the "output" Gather node.
85  m_Prototext.append(R"("
86  }
87  }
88  }
89 }
90 node {
91  name: "output"
92  op: "Gather"
93  input: "input0"
94  input: "input1"
95  attr {
96  key: "Tindices"
97  value {
98  type: DT_INT32
99  }
100  }
101  attr {
102  key: "Tparams"
103  value {
104  type: DT_FLOAT
105  }
106  }
107 }
108  )");
// Pass the per-input shapes and the requested output name "output" to the base fixture.
109  Setup({ { "input0", inputShape0 },
110  { "input1", inputShape1 } },
111  { "output" });
112 
113  }
114 };
115 
116 
117 struct GatherFixture1DParams1DIndices : public GatherFixture
118 {
119  GatherFixture1DParams1DIndices() : GatherFixture(
120  { 4, 1, 1, 1 },
121  { 4, 0, 0, 0 },
122  { 0, 2, 1, 3 },
123  { 4 },
124  { 4 }) {}
125 };
126 
127 struct GatherFixture1DParamsMultiDimIndices : public GatherFixture
128 {
129  GatherFixture1DParamsMultiDimIndices() : GatherFixture(
130  { 4, 1, 1 },
131  { 2, 2, 1, 1 },
132  { 0, 1, 1, 3 },
133  { 4 },
134  { 2, 2 }) {}
135 };
136 
137 struct GatherFixtureMultiDimParamMultiDimIndices : public GatherFixture
138 {
139  GatherFixtureMultiDimParamMultiDimIndices() : GatherFixture(
140  { 5, 2, 1 },
141  { 2, 1, 4 },
142  { 1, 3, 0, 2 },
143  { 5, 2 },
144  { 2, 2 }) {}
145 };
146 
147 BOOST_FIXTURE_TEST_CASE(ParseGather1DParams1DIndices, GatherFixture1DParams1DIndices)
148 {
149  RunTest<4>({ { "input0", { 1, 2, 3, 4 } } },
150 
151  { { "output", { 1, 3, 2, 4 } } });
152 }
153 
154 BOOST_FIXTURE_TEST_CASE(ParseGather1DParamsMultiDimIndices, GatherFixture1DParamsMultiDimIndices)
155 {
156  RunTest<4>({ { "input0", { 1, 2, 3, 4 } } },
157 
158  { { "output", { 1, 2, 2, 4 } } });
159 }
160 
161 BOOST_FIXTURE_TEST_CASE(ParseGatherMultiDimParamMultiDimIndices, GatherFixtureMultiDimParamMultiDimIndices)
162 {
163  RunTest<4>({ { "input0", { 1, 2, 3, 4, 5, 6, 7, 8, 9, 10 } } },
164 
165  { { "output", { 3, 4, 7, 8, 1, 2, 5, 6} } });
166 }
167 
BOOST_AUTO_TEST_SUITE(TensorflowLiteParser)
std::string ConvertInt32ToOctalString(int value)
Converts an int value into the Prototxt octal representation.
BOOST_AUTO_TEST_SUITE_END()
BOOST_FIXTURE_TEST_CASE(ParseGather1DParams1DIndices, GatherFixture1DParams1DIndices)
Definition: Gather.cpp:147