diff options
Diffstat (limited to 'src/armnnTfParser/test/Constant.cpp')
-rw-r--r-- | src/armnnTfParser/test/Constant.cpp | 321 |
1 files changed, 321 insertions, 0 deletions
diff --git a/src/armnnTfParser/test/Constant.cpp b/src/armnnTfParser/test/Constant.cpp new file mode 100644 index 0000000000..09587fc3d5 --- /dev/null +++ b/src/armnnTfParser/test/Constant.cpp @@ -0,0 +1,321 @@ +// +// Copyright © 2017 Arm Ltd. All rights reserved. +// See LICENSE file in the project root for full license information. +// + +#include <boost/test/unit_test.hpp> + +#include "armnnTfParser/ITfParser.hpp" + +#include "ParserPrototxtFixture.hpp" + +BOOST_AUTO_TEST_SUITE(TensorflowParser) + +// Tests that a Const node in Tensorflow can be converted to a ConstLayer in armnn (as opposed to most +// Const nodes which are used as weight inputs for convolutions etc. and are therefore not converted to +// armnn ConstLayers). +struct ConstantFixture : public ParserPrototxtFixture<armnnTfParser::ITfParser> +{ + ConstantFixture() + { + // input = tf.placeholder(tf.float32, name = "input") + // const = tf.constant([17], tf.float32, [1]) + // output = tf.add(input, const, name = "output") + m_Prototext = + R"( +node { + name: "input" + op: "Placeholder" + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "shape" + value { + shape { + unknown_rank: true + } + } + } +} +node { + name: "Const" + op: "Const" + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_FLOAT + tensor_shape { + dim { + size: 1 + } + } + float_val: 17.0 + } + } + } +} +node { + name: "output" + op: "Add" + input: "input" + input: "Const" + attr { + key: "T" + value { + type: DT_FLOAT + } + } +} + )"; + SetupSingleInputSingleOutput({ 1 }, "input", "output"); + } +}; + +BOOST_FIXTURE_TEST_CASE(Constant, ConstantFixture) +{ + RunTest<1>({1}, {18}); +} + + +// Tests that a single Const node in Tensorflow can be used twice by a dependant node. This should result in only +// a single armnn ConstLayer being created. +struct ConstantReusedFixture : public ParserPrototxtFixture<armnnTfParser::ITfParser> +{ + ConstantReusedFixture() + { + // const = tf.constant([17], tf.float32, [1]) + // output = tf.add(const, const, name = "output") + m_Prototext = + R"( +node { + name: "Const" + op: "Const" + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_FLOAT + tensor_shape { + dim { + size: 1 + } + } + float_val: 17.0 + } + } + } +} +node { + name: "output" + op: "Add" + input: "Const" + input: "Const" + attr { + key: "T" + value { + type: DT_FLOAT + } + } +} + )"; + Setup({}, { "output" }); + } +}; + +BOOST_FIXTURE_TEST_CASE(ConstantReused, ConstantReusedFixture) +{ + RunTest<1>({}, { { "output", { 34 } } }); +} + +template <int ListSize> +struct ConstantValueListFixture : public ParserPrototxtFixture<armnnTfParser::ITfParser> +{ + ConstantValueListFixture() + { + m_Prototext = + R"( +node { + name: "output" + op: "Const" + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_FLOAT + tensor_shape { + dim { + size: 2 + } + dim { + size: 3 + } + })"; + + double value = 0.75; + for (int i = 0; i < ListSize; i++, value += 0.25) + { + m_Prototext += std::string("float_val : ") + std::to_string(value) + "\n"; + } + + m_Prototext += + R"( + } + } + } +} + )"; + Setup({}, { "output" }); + } +}; + +using ConstantSingleValueListFixture = ConstantValueListFixture<1>; +using ConstantMultipleValueListFixture = ConstantValueListFixture<4>; +using ConstantMaxValueListFixture = ConstantValueListFixture<6>; + +BOOST_FIXTURE_TEST_CASE(ConstantSingleValueList, ConstantSingleValueListFixture) +{ + RunTest<2>({}, { { "output", { 0.75f, 0.75f, 0.75f, 0.75f, 0.75f, 0.75f } } }); +} +BOOST_FIXTURE_TEST_CASE(ConstantMultipleValueList, ConstantMultipleValueListFixture) +{ + RunTest<2>({}, { { "output", { 0.75f, 1.f, 1.25f, 1.5f, 1.5f, 1.5f } } }); +} +BOOST_FIXTURE_TEST_CASE(ConstantMaxValueList, ConstantMaxValueListFixture) +{ + RunTest<2>({}, { { "output", { 0.75f, 1.f, 1.25f, 1.50f, 1.75f, 2.f } } }); +} + +template <bool WithShape, bool WithContent, bool WithValueList> +struct ConstantCreateFixture : public ParserPrototxtFixture<armnnTfParser::ITfParser> +{ + ConstantCreateFixture() + { + m_Prototext = + R"( +node { + name: "output" + op: "Const" + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_FLOAT + )"; + + if (WithShape) + { + m_Prototext += + R"( +tensor_shape { + dim { + size: 2 + } + dim { + size: 2 + } +} + )"; + } + else + { + m_Prototext += + R"( +tensor_shape { +} + )"; + } + + if (WithContent) + { + m_Prototext += + R"( +tensor_content: "\000\000\200?\000\000\200?\000\000\200?\000\000\200?\000\000\200?" + )"; + } + + if (WithValueList) + { + m_Prototext += + R"( +float_val: 1.0 +float_val: 1.0 +float_val: 1.0 +float_val: 1.0 +float_val: 1.0 + )"; + } + + m_Prototext += + R"( + } + } + } +} + )"; + } +}; + +using ConstantCreateNoValueListFixture = ConstantCreateFixture<true, false, true>; +using ConstantCreateNoValueList2Fixture = ConstantCreateFixture<true, false, false>; +using ConstantCreateNoContentFixture = ConstantCreateFixture<true, true, false>; +using ConstantCreateNoContent2Fixture = ConstantCreateFixture<true, false, false>; +using ConstantCreateNoShapeFixture = ConstantCreateFixture<false, false, false>; +using ConstantCreateNoShape2Fixture = ConstantCreateFixture<false, true, false>; +using ConstantCreateNoShape3Fixture = ConstantCreateFixture<false, false, true>; + +BOOST_FIXTURE_TEST_CASE(ConstantCreateInvalidValueList, ConstantCreateNoValueListFixture) +{ + BOOST_REQUIRE_THROW(Setup({}, { "output" }), armnn::ParseException); +} +BOOST_FIXTURE_TEST_CASE(ConstantCreateInvalidValueList2, ConstantCreateNoValueList2Fixture) +{ + BOOST_REQUIRE_THROW(Setup({}, { "output" }), armnn::ParseException); +} +BOOST_FIXTURE_TEST_CASE(ConstantCreateInvalidContent, ConstantCreateNoContentFixture) +{ + BOOST_REQUIRE_THROW(Setup({}, { "output" }), armnn::ParseException); +} +BOOST_FIXTURE_TEST_CASE(ConstantCreateInvalidShape, ConstantCreateNoShapeFixture) +{ + BOOST_REQUIRE_THROW(Setup({}, { "output" }), armnn::ParseException); +} +BOOST_FIXTURE_TEST_CASE(ConstantCreateNoShape2, ConstantCreateNoShape2Fixture) +{ + BOOST_REQUIRE_THROW(Setup({}, { "output" }), armnn::ParseException); +} +BOOST_FIXTURE_TEST_CASE(ConstantCreateNoShape3, ConstantCreateNoShape3Fixture) +{ + Setup({}, { "output" }); + RunTest<1>({}, { { "output", { 1.f, 1.f, 1.f, 1.f, 1.f } } }); +} + +BOOST_AUTO_TEST_SUITE_END() |