// // Copyright © 2017 Arm Ltd. All rights reserved. // SPDX-License-Identifier: MIT // #include #include "armnnTfParser/ITfParser.hpp" #include "ParserPrototxtFixture.hpp" BOOST_AUTO_TEST_SUITE(TensorflowParser) struct ExpandDimsFixture : public armnnUtils::ParserPrototxtFixture { ExpandDimsFixture(const std::string& expandDim) { m_Prototext = "node { \n" " name: \"graphInput\" \n" " op: \"Placeholder\" \n" " attr { \n" " key: \"dtype\" \n" " value { \n" " type: DT_FLOAT \n" " } \n" " } \n" " attr { \n" " key: \"shape\" \n" " value { \n" " shape { \n" " } \n" " } \n" " } \n" " } \n" "node { \n" " name: \"ExpandDims\" \n" " op: \"ExpandDims\" \n" " input: \"graphInput\" \n" " attr { \n" " key: \"T\" \n" " value { \n" " type: DT_FLOAT \n" " } \n" " } \n" " attr { \n" " key: \"Tdim\" \n" " value { \n"; m_Prototext += "i:" + expandDim; m_Prototext += " } \n" " } \n" "} \n"; SetupSingleInputSingleOutput({ 2, 3, 5 }, "graphInput", "ExpandDims"); } }; struct ExpandZeroDim : ExpandDimsFixture { ExpandZeroDim() : ExpandDimsFixture("0") {} }; struct ExpandTwoDim : ExpandDimsFixture { ExpandTwoDim() : ExpandDimsFixture("2") {} }; struct ExpandThreeDim : ExpandDimsFixture { ExpandThreeDim() : ExpandDimsFixture("3") {} }; struct ExpandMinusOneDim : ExpandDimsFixture { ExpandMinusOneDim() : ExpandDimsFixture("-1") {} }; struct ExpandMinusThreeDim : ExpandDimsFixture { ExpandMinusThreeDim() : ExpandDimsFixture("-3") {} }; BOOST_FIXTURE_TEST_CASE(ParseExpandZeroDim, ExpandZeroDim) { BOOST_TEST((m_Parser->GetNetworkOutputBindingInfo("ExpandDims").second.GetShape() == armnn::TensorShape({1, 2, 3, 5}))); } BOOST_FIXTURE_TEST_CASE(ParseExpandTwoDim, ExpandTwoDim) { BOOST_TEST((m_Parser->GetNetworkOutputBindingInfo("ExpandDims").second.GetShape() == armnn::TensorShape({2, 3, 1, 5}))); } BOOST_FIXTURE_TEST_CASE(ParseExpandThreeDim, ExpandThreeDim) { BOOST_TEST((m_Parser->GetNetworkOutputBindingInfo("ExpandDims").second.GetShape() == armnn::TensorShape({2, 3, 5, 1}))); } BOOST_FIXTURE_TEST_CASE(ParseExpandMinusOneDim, ExpandMinusOneDim) { BOOST_TEST((m_Parser->GetNetworkOutputBindingInfo("ExpandDims").second.GetShape() == armnn::TensorShape({2, 3, 5, 1}))); } BOOST_FIXTURE_TEST_CASE(ParseExpandMinusThreeDim, ExpandMinusThreeDim) { BOOST_TEST((m_Parser->GetNetworkOutputBindingInfo("ExpandDims").second.GetShape() == armnn::TensorShape({2, 1, 3, 5}))); } BOOST_AUTO_TEST_SUITE_END()