ArmNN
 21.02
BroadcastForAdd.cpp
Go to the documentation of this file.
1 //
2 // Copyright © 2017 Arm Ltd. All rights reserved.
3 // SPDX-License-Identifier: MIT
4 //
5 
6 #include <boost/test/unit_test.hpp>
9 // This is a special case for add, which supports broadcasting.
10 BOOST_AUTO_TEST_SUITE(TensorflowParser)
11 
12 struct BroadcastForAddFixtureSlot1 : public armnnUtils::ParserPrototxtFixture<armnnTfParser::ITfParser>
13 {
14  BroadcastForAddFixtureSlot1()
15  {
16  m_Prototext = R"(
17  node {
18  name: "graphInput"
19  op: "Placeholder"
20  attr {
21  key: "dtype"
22  value {
23  type: DT_FLOAT
24  }
25  }
26  attr {
27  key: "shape"
28  value {
29  shape {
30  }
31  }
32  }
33  }
34  node {
35  name: "Const_1"
36  op: "Const"
37  attr {
38  key: "dtype"
39  value {
40  type: DT_FLOAT
41  }
42  }
43  attr {
44  key: "value"
45  value {
46  tensor {
47  dtype: DT_FLOAT
48  tensor_shape {
49  }
50  float_val: 4.0
51  float_val: 5.0
52  }
53  }
54  }
55  }
56  node {
57  name: "Add"
58  op: "Add"
59  input: "graphInput"
60  input: "Const_1"
61  attr {
62  key: "T"
63  value {
64  type: DT_FLOAT
65  }
66  }
67  }
68  )";
69 
70  SetupSingleInputSingleOutput({ 1, 2, 2, 2 }, "graphInput", "Add");
71  }
72 };
73 
74 struct BroadcastForAddFixtureSlot0 : public armnnUtils::ParserPrototxtFixture<armnnTfParser::ITfParser>
75 {
76  BroadcastForAddFixtureSlot0()
77  {
78  m_Prototext = R"(
79  node {
80  name: "graphInput"
81  op: "Placeholder"
82  attr {
83  key: "dtype"
84  value {
85  type: DT_FLOAT
86  }
87  }
88  attr {
89  key: "shape"
90  value {
91  shape {
92  }
93  }
94  }
95  }
96  node {
97  name: "Const_1"
98  op: "Const"
99  attr {
100  key: "dtype"
101  value {
102  type: DT_FLOAT
103  }
104  }
105  attr {
106  key: "value"
107  value {
108  tensor {
109  dtype: DT_FLOAT
110  tensor_shape {
111  }
112  float_val: 4.0
113  float_val: 5.0
114  }
115  }
116  }
117  }
118  node {
119  name: "Add"
120  op: "Add"
121  input: "Const_1"
122  input: "graphInput"
123  attr {
124  key: "T"
125  value {
126  type: DT_FLOAT
127  }
128  }
129  }
130  )";
131 
132  SetupSingleInputSingleOutput({ 1, 2, 2, 2 }, "graphInput", "Add");
133  }
134 };
135 
136 
137 BOOST_FIXTURE_TEST_CASE(ParseBroadcastForAddition1, BroadcastForAddFixtureSlot1)
138 {
139  RunTest<4>({ 1.0, 1.0, 2.0, 2.0, 3.0, 3.0, 4.0, 4.0 }, { 5.0, 6.0, 6.0, 7.0, 7.0, 8.0, 8.0, 9.0 });
140 }
141 
142 BOOST_FIXTURE_TEST_CASE(ParseBroadcastForAddition0, BroadcastForAddFixtureSlot0)
143 {
144  RunTest<4>({ 1.0, 1.0, 2.0, 2.0, 3.0, 3.0, 4.0, 4.0 }, { 5.0, 6.0, 6.0, 7.0, 7.0, 8.0, 8.0, 9.0 });
145 }
146 
147 
148 
BOOST_AUTO_TEST_SUITE(TensorflowLiteParser)
BOOST_FIXTURE_TEST_CASE(ParseBroadcastForAddition1, BroadcastForAddFixtureSlot1)
BOOST_AUTO_TEST_SUITE_END()
void SetupSingleInputSingleOutput(const std::string &inputName, const std::string &outputName)
Parses and loads the network defined by the m_Prototext string.