diff options
author | surmeh01 <surabhi.mehta@arm.com> | 2018-03-29 16:29:27 +0100 |
---|---|---|
committer | surmeh01 <surabhi.mehta@arm.com> | 2018-03-29 16:29:27 +0100 |
commit | bceff2fb3fc68bb0aa88b886900c34b77340c826 (patch) | |
tree | d867d3e090d58d3012dfbbac456e9ea8c7f789bc /src/armnnTfParser/test/Squeeze.cpp | |
parent | 4fcda0101ec3d110c1d6d7bee5c83416b645528a (diff) | |
download | armnn-bceff2fb3fc68bb0aa88b886900c34b77340c826.tar.gz |
Release 18.03
Diffstat (limited to 'src/armnnTfParser/test/Squeeze.cpp')
-rw-r--r-- | src/armnnTfParser/test/Squeeze.cpp | 108 |
1 files changed, 108 insertions, 0 deletions
diff --git a/src/armnnTfParser/test/Squeeze.cpp b/src/armnnTfParser/test/Squeeze.cpp new file mode 100644 index 0000000000..d2d7d49494 --- /dev/null +++ b/src/armnnTfParser/test/Squeeze.cpp @@ -0,0 +1,108 @@ +// +// Copyright © 2017 Arm Ltd. All rights reserved. +// See LICENSE file in the project root for full license information. +// + +#include <boost/test/unit_test.hpp> +#include "armnnTfParser/ITfParser.hpp" +#include "ParserPrototxtFixture.hpp" + +BOOST_AUTO_TEST_SUITE(TensorflowParser) + + +template <bool withDimZero, bool withDimOne> +struct SqueezeFixture : public ParserPrototxtFixture<armnnTfParser::ITfParser> +{ + SqueezeFixture() + { + m_Prototext = + "node { \n" + " name: \"graphInput\" \n" + " op: \"Placeholder\" \n" + " attr { \n" + " key: \"dtype\" \n" + " value { \n" + " type: DT_FLOAT \n" + " } \n" + " } \n" + " attr { \n" + " key: \"shape\" \n" + " value { \n" + " shape { \n" + " } \n" + " } \n" + " } \n" + " } \n" + "node { \n" + " name: \"Squeeze\" \n" + " op: \"Squeeze\" \n" + " input: \"graphInput\" \n" + " attr { \n" + " key: \"T\" \n" + " value { \n" + " type: DT_FLOAT \n" + " } \n" + " } \n" + " attr { \n" + " key: \"squeeze_dims\" \n" + " value { \n" + " list {\n"; + + if (withDimZero) + { + m_Prototext += "i:0\n"; + } + + if (withDimOne) + { + m_Prototext += "i:1\n"; + } + + m_Prototext += + " } \n" + " } \n" + " } \n" + "} \n"; + + SetupSingleInputSingleOutput({ 1, 1, 2, 2 }, "graphInput", "Squeeze"); + } +}; + +typedef SqueezeFixture<false, false> ImpliedDimensionsSqueezeFixture; +typedef SqueezeFixture<true, false> ExplicitDimensionZeroSqueezeFixture; +typedef SqueezeFixture<false, true> ExplicitDimensionOneSqueezeFixture; +typedef SqueezeFixture<true, true> ExplicitDimensionsSqueezeFixture; + +BOOST_FIXTURE_TEST_CASE(ParseImplicitSqueeze, ImpliedDimensionsSqueezeFixture) +{ + BOOST_TEST((m_Parser->GetNetworkOutputBindingInfo("Squeeze").second.GetShape() == + armnn::TensorShape({2,2}))); + RunTest<2>({ 1.0f, 2.0f, 3.0f, 4.0f }, + { 1.0f, 2.0f, 3.0f, 4.0f }); +} + +BOOST_FIXTURE_TEST_CASE(ParseDimensionZeroSqueeze, ExplicitDimensionZeroSqueezeFixture) +{ + BOOST_TEST((m_Parser->GetNetworkOutputBindingInfo("Squeeze").second.GetShape() == + armnn::TensorShape({1,2,2}))); + RunTest<3>({ 1.0f, 2.0f, 3.0f, 4.0f }, + { 1.0f, 2.0f, 3.0f, 4.0f }); +} + +BOOST_FIXTURE_TEST_CASE(ParseDimensionOneSqueeze, ExplicitDimensionOneSqueezeFixture) +{ + BOOST_TEST((m_Parser->GetNetworkOutputBindingInfo("Squeeze").second.GetShape() == + armnn::TensorShape({1,2,2}))); + RunTest<3>({ 1.0f, 2.0f, 3.0f, 4.0f }, + { 1.0f, 2.0f, 3.0f, 4.0f }); +} + +BOOST_FIXTURE_TEST_CASE(ParseExplicitDimensionsSqueeze, ExplicitDimensionsSqueezeFixture) +{ + BOOST_TEST((m_Parser->GetNetworkOutputBindingInfo("Squeeze").second.GetShape() == + armnn::TensorShape({2,2}))); + RunTest<2>({ 1.0f, 2.0f, 3.0f, 4.0f }, + { 1.0f, 2.0f, 3.0f, 4.0f }); +} + +BOOST_AUTO_TEST_SUITE_END() |