aboutsummaryrefslogtreecommitdiff
path: root/src/armnnTfParser/test/BiasAdd.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/armnnTfParser/test/BiasAdd.cpp')
-rw-r--r--src/armnnTfParser/test/BiasAdd.cpp104
1 files changed, 104 insertions, 0 deletions
diff --git a/src/armnnTfParser/test/BiasAdd.cpp b/src/armnnTfParser/test/BiasAdd.cpp
new file mode 100644
index 0000000000..e29aeb1057
--- /dev/null
+++ b/src/armnnTfParser/test/BiasAdd.cpp
@@ -0,0 +1,104 @@
+//
+// Copyright © 2017 Arm Ltd. All rights reserved.
+// See LICENSE file in the project root for full license information.
+//
+
+#include <boost/test/unit_test.hpp>
+#include "armnnTfParser/ITfParser.hpp"
+#include "ParserPrototxtFixture.hpp"
+
+BOOST_AUTO_TEST_SUITE(TensorflowParser)
+
+struct BiasAddFixture : public ParserPrototxtFixture<armnnTfParser::ITfParser>
+{
+ explicit BiasAddFixture(const std::string& dataFormat)
+ {
+ m_Prototext = R"(
+node {
+ name: "graphInput"
+ op: "Placeholder"
+ attr {
+ key: "dtype"
+ value {
+ type: DT_FLOAT
+ }
+ }
+ attr {
+ key: "shape"
+ value {
+ shape {
+ }
+ }
+ }
+}
+node {
+ name: "bias"
+ op: "Const"
+ attr {
+ key: "dtype"
+ value {
+ type: DT_FLOAT
+ }
+ }
+ attr {
+ key: "value"
+ value {
+ tensor {
+ dtype: DT_FLOAT
+ tensor_shape {
+ dim {
+ size: 3
+ }
+ }
+ float_val: 1
+ float_val: 2
+ float_val: 3
+ }
+ }
+ }
+}
+node {
+ name: "biasAdd"
+ op : "BiasAdd"
+ input: "graphInput"
+ input: "bias"
+ attr {
+ key: "T"
+ value {
+ type: DT_FLOAT
+ }
+ }
+ attr {
+ key: "data_format"
+ value {
+ s: ")" + dataFormat + R"("
+ }
+ }
+}
+)";
+
+ SetupSingleInputSingleOutput({ 1, 3, 1, 3 }, "graphInput", "biasAdd");
+ }
+};
+
+struct BiasAddFixtureNCHW : BiasAddFixture
+{
+ BiasAddFixtureNCHW() : BiasAddFixture("NCHW") {}
+};
+
+struct BiasAddFixtureNHWC : BiasAddFixture
+{
+ BiasAddFixtureNHWC() : BiasAddFixture("NHWC") {}
+};
+
+BOOST_FIXTURE_TEST_CASE(ParseBiasAddNCHW, BiasAddFixtureNCHW)
+{
+ RunTest<4>(std::vector<float>(9), { 1, 1, 1, 2, 2, 2, 3, 3, 3 });
+}
+
+BOOST_FIXTURE_TEST_CASE(ParseBiasAddNHWC, BiasAddFixtureNHWC)
+{
+ RunTest<4>(std::vector<float>(9), { 1, 2, 3, 1, 2, 3, 1, 2, 3 });
+}
+
+BOOST_AUTO_TEST_SUITE_END()