aboutsummaryrefslogtreecommitdiff
path: root/src/armnnTfParser/test/FullyConnected.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/armnnTfParser/test/FullyConnected.cpp')
-rw-r--r--src/armnnTfParser/test/FullyConnected.cpp38
1 files changed, 19 insertions, 19 deletions
diff --git a/src/armnnTfParser/test/FullyConnected.cpp b/src/armnnTfParser/test/FullyConnected.cpp
index 2a7b4951b7..e7f040e784 100644
--- a/src/armnnTfParser/test/FullyConnected.cpp
+++ b/src/armnnTfParser/test/FullyConnected.cpp
@@ -14,15 +14,15 @@ BOOST_AUTO_TEST_SUITE(TensorflowParser)
// In Tensorflow fully connected layers are expressed as a MatMul followed by an Add.
// The TfParser must detect this case and convert them to a FullyConnected layer.
-struct FullyConnectedFixture : public ParserPrototxtFixture<armnnTfParser::ITfParser>
+struct FullyConnectedFixture : public armnnUtils::ParserPrototxtFixture<armnnTfParser::ITfParser>
{
FullyConnectedFixture()
{
- // input = tf.placeholder(tf.float32, [1, 1], "input")
- // weights = tf.constant([2], tf.float32, [1, 1])
- // matmul = tf.matmul(input, weights)
- // bias = tf.constant([1], tf.float32)
- // output = tf.add(matmul, bias, name="output")
+ // Input = tf.placeholder(tf.float32, [1, 1], "input")
+ // Weights = tf.constant([2], tf.float32, [1, 1])
+ // Matmul = tf.matmul(input, weights)
+ // Bias = tf.constant([1], tf.float32)
+ // Output = tf.add(matmul, bias, name="output")
m_Prototext = R"(
node {
name: "input"
@@ -153,7 +153,7 @@ BOOST_FIXTURE_TEST_CASE(FullyConnected, FullyConnectedFixture)
// C-- A A -- C
// \ /
// A
-struct MatMulUsedInTwoFcFixture : public ParserPrototxtFixture<armnnTfParser::ITfParser>
+struct MatMulUsedInTwoFcFixture : public armnnUtils::ParserPrototxtFixture<armnnTfParser::ITfParser>
{
MatMulUsedInTwoFcFixture()
{
@@ -326,7 +326,7 @@ BOOST_FIXTURE_TEST_CASE(MatMulUsedInTwoFc, MatMulUsedInTwoFcFixture)
RunTest<1>({ 3 }, { 32 });
// Ideally we would check here that the armnn network has 5 layers:
// Input, 2 x FullyConnected (biased), Add and Output.
- // This would make sure the parser hasn't incorrectly added some unconnected layers corresponding to the MatMul
+ // This would make sure the parser hasn't incorrectly added some unconnected layers corresponding to the MatMul.
}
// Similar to MatMulUsedInTwoFc, but this time the Adds are 'staggered' (see diagram), which means that only one
@@ -338,16 +338,16 @@ BOOST_FIXTURE_TEST_CASE(MatMulUsedInTwoFc, MatMulUsedInTwoFcFixture)
// C2 -- A |
// \ /
// A
-struct MatMulUsedInTwoFcStaggeredFixture : public ParserPrototxtFixture<armnnTfParser::ITfParser>
+struct MatMulUsedInTwoFcStaggeredFixture : public armnnUtils::ParserPrototxtFixture<armnnTfParser::ITfParser>
{
MatMulUsedInTwoFcStaggeredFixture()
{
- // input = tf.placeholder(tf.float32, shape=[1,1], name = "input")
- // const1 = tf.constant([17], tf.float32, [1,1])
- // mul = tf.matmul(input, const1)
- // const2 = tf.constant([7], tf.float32, [1])
- // fc = tf.add(mul, const2)
- // output = tf.add(mul, fc, name="output")
+ // Input = tf.placeholder(tf.float32, shape=[1,1], name = "input")
+ // Const1 = tf.constant([17], tf.float32, [1,1])
+ // Mul = tf.matmul(input, const1)
+ // Monst2 = tf.constant([7], tf.float32, [1])
+ // Fc = tf.add(mul, const2)
+ // Output = tf.add(mul, fc, name="output")
m_Prototext = R"(
node {
name: "input"
@@ -484,13 +484,13 @@ BOOST_FIXTURE_TEST_CASE(MatMulUsedInTwoFcStaggered, MatMulUsedInTwoFcStaggeredFi
}
// A MatMul in isolation, not connected to an add. Should result in a non-biased FullyConnected layer.
-struct MatMulFixture : public ParserPrototxtFixture<armnnTfParser::ITfParser>
+struct MatMulFixture : public armnnUtils::ParserPrototxtFixture<armnnTfParser::ITfParser>
{
MatMulFixture()
{
- // input = tf.placeholder(tf.float32, shape = [1, 1], name = "input")
- // const = tf.constant([17], tf.float32, [1, 1])
- // output = tf.matmul(input, const, name = "output")
+ // Input = tf.placeholder(tf.float32, shape = [1, 1], name = "input")
+ // Const = tf.constant([17], tf.float32, [1, 1])
+ // Output = tf.matmul(input, const, name = "output")
m_Prototext = R"(
node {
name: "input"