ArmNN
 21.02
DeserializeGather.cpp
Go to the documentation of this file.
1 //
2 // Copyright © 2017 Arm Ltd and Contributors. All rights reserved.
3 // SPDX-License-Identifier: MIT
4 //
5 
6 #include <boost/test/unit_test.hpp>
9 
10 #include <string>
11 
12 BOOST_AUTO_TEST_SUITE(Deserializer)
13 
14 struct GatherFixture : public ParserFlatbuffersSerializeFixture
15 {
16  explicit GatherFixture(const std::string& inputShape,
17  const std::string& indicesShape,
18  const std::string& input1Content,
19  const std::string& outputShape,
20  const std::string& axis,
21  const std::string dataType,
22  const std::string constDataType)
23  {
24  m_JsonString = R"(
25  {
26  inputIds: [0],
27  outputIds: [3],
28  layers: [
29  {
30  layer_type: "InputLayer",
31  layer: {
32  base: {
33  layerBindingId: 0,
34  base: {
35  index: 0,
36  layerName: "InputLayer",
37  layerType: "Input",
38  inputSlots: [{
39  index: 0,
40  connection: {sourceLayerIndex:0, outputSlotIndex:0 },
41  }],
42  outputSlots: [ {
43  index: 0,
44  tensorInfo: {
45  dimensions: )" + inputShape + R"(,
46  dataType: )" + dataType + R"(
47  }}]
48  }
49  }}},
50  {
51  layer_type: "ConstantLayer",
52  layer: {
53  base: {
54  index:1,
55  layerName: "ConstantLayer",
56  layerType: "Constant",
57  outputSlots: [ {
58  index: 0,
59  tensorInfo: {
60  dimensions: )" + indicesShape + R"(,
61  dataType: "Signed32",
62  },
63  }],
64  },
65  input: {
66  info: {
67  dimensions: )" + indicesShape + R"(,
68  dataType: )" + dataType + R"(
69  },
70  data_type: )" + constDataType + R"(,
71  data: {
72  data: )" + input1Content + R"(,
73  } }
74  },},
75  {
76  layer_type: "GatherLayer",
77  layer: {
78  base: {
79  index: 2,
80  layerName: "GatherLayer",
81  layerType: "Gather",
82  inputSlots: [
83  {
84  index: 0,
85  connection: {sourceLayerIndex:0, outputSlotIndex:0 },
86  },
87  {
88  index: 1,
89  connection: {sourceLayerIndex:1, outputSlotIndex:0 }
90  }],
91  outputSlots: [ {
92  index: 0,
93  tensorInfo: {
94  dimensions: )" + outputShape + R"(,
95  dataType: )" + dataType + R"(
96 
97  }}]},
98  descriptor: {
99  axis: )" + axis + R"(
100  }
101  }},
102  {
103  layer_type: "OutputLayer",
104  layer: {
105  base:{
106  layerBindingId: 0,
107  base: {
108  index: 3,
109  layerName: "OutputLayer",
110  layerType: "Output",
111  inputSlots: [{
112  index: 0,
113  connection: {sourceLayerIndex:2, outputSlotIndex:0 },
114  }],
115  outputSlots: [ {
116  index: 0,
117  tensorInfo: {
118  dimensions: )" + outputShape + R"(,
119  dataType: )" + dataType + R"(
120  },
121  }],
122  }}},
123  }]
124  } )";
125 
126  Setup();
127  }
128 };
129 
130 struct SimpleGatherFixtureFloat32 : GatherFixture
131 {
132  SimpleGatherFixtureFloat32() : GatherFixture("[ 3, 2, 3 ]", "[ 2, 3 ]", "[1, 2, 1, 2, 1, 0]",
133  "[ 2, 3, 2, 3 ]", "0", "Float32", "IntData") {}
134 };
135 
136 BOOST_FIXTURE_TEST_CASE(GatherFloat32, SimpleGatherFixtureFloat32)
137 {
138  RunTest<4, armnn::DataType::Float32>(0,
139  {{"InputLayer", { 1, 2, 3,
140  4, 5, 6,
141  7, 8, 9,
142  10, 11, 12,
143  13, 14, 15,
144  16, 17, 18 }}},
145  {{"OutputLayer", { 7, 8, 9,
146  10, 11, 12,
147  13, 14, 15,
148  16, 17, 18,
149  7, 8, 9,
150  10, 11, 12,
151  13, 14, 15,
152  16, 17, 18,
153  7, 8, 9,
154  10, 11, 12,
155  1, 2, 3,
156  4, 5, 6 }}});
157 }
158 
160 
BOOST_AUTO_TEST_SUITE(TensorflowLiteParser)
BOOST_FIXTURE_TEST_CASE(GatherFloat32, SimpleGatherFixtureFloat32)
BOOST_AUTO_TEST_SUITE_END()