// // Copyright © 2017 Arm Ltd. All rights reserved. // SPDX-License-Identifier: MIT // #include #include "armnnTfParser/ITfParser.hpp" #include "ParserPrototxtFixture.hpp" BOOST_AUTO_TEST_SUITE(TensorflowParser) struct ExpandDimsFixture : public armnnUtils::ParserPrototxtFixture { ExpandDimsFixture(const std::string& expandDim) { m_Prototext = "node { \n" " name: \"graphInput\" \n" " op: \"Placeholder\" \n" " attr { \n" " key: \"dtype\" \n" " value { \n" " type: DT_FLOAT \n" " } \n" " } \n" " attr { \n" " key: \"shape\" \n" " value { \n" " shape { \n" " } \n" " } \n" " } \n" " } \n" "node { \n" " name: \"ExpandDims\" \n" " op: \"ExpandDims\" \n" " input: \"graphInput\" \n" " attr { \n" " key: \"T\" \n" " value { \n" " type: DT_FLOAT \n" " } \n" " } \n" " attr { \n" " key: \"Tdim\" \n" " value { \n"; m_Prototext += "i:" + expandDim; m_Prototext += " } \n" " } \n" "} \n"; SetupSingleInputSingleOutput({ 2, 3, 5 }, "graphInput", "ExpandDims"); } }; struct ExpandZeroDim : ExpandDimsFixture { ExpandZeroDim() : ExpandDimsFixture("0") {} }; struct ExpandTwoDim : ExpandDimsFixture { ExpandTwoDim() : ExpandDimsFixture("2") {} }; struct ExpandThreeDim : ExpandDimsFixture { ExpandThreeDim() : ExpandDimsFixture("3") {} }; struct ExpandMinusOneDim : ExpandDimsFixture { ExpandMinusOneDim() : ExpandDimsFixture("-1") {} }; struct ExpandMinusThreeDim : ExpandDimsFixture { ExpandMinusThreeDim() : ExpandDimsFixture("-3") {} }; BOOST_FIXTURE_TEST_CASE(ParseExpandZeroDim, ExpandZeroDim) { BOOST_TEST((m_Parser->GetNetworkOutputBindingInfo("ExpandDims").second.GetShape() == armnn::TensorShape({1, 2, 3, 5}))); } BOOST_FIXTURE_TEST_CASE(ParseExpandTwoDim, ExpandTwoDim) { BOOST_TEST((m_Parser->GetNetworkOutputBindingInfo("ExpandDims").second.GetShape() == armnn::TensorShape({2, 3, 1, 5}))); } BOOST_FIXTURE_TEST_CASE(ParseExpandThreeDim, ExpandThreeDim) { BOOST_TEST((m_Parser->GetNetworkOutputBindingInfo("ExpandDims").second.GetShape() == armnn::TensorShape({2, 3, 5, 1}))); } BOOST_FIXTURE_TEST_CASE(ParseExpandMinusOneDim, ExpandMinusOneDim) { BOOST_TEST((m_Parser->GetNetworkOutputBindingInfo("ExpandDims").second.GetShape() == armnn::TensorShape({2, 3, 5, 1}))); } BOOST_FIXTURE_TEST_CASE(ParseExpandMinusThreeDim, ExpandMinusThreeDim) { BOOST_TEST((m_Parser->GetNetworkOutputBindingInfo("ExpandDims").second.GetShape() == armnn::TensorShape({2, 1, 3, 5}))); } struct ExpandDimsAsInputFixture : public armnnUtils::ParserPrototxtFixture { ExpandDimsAsInputFixture(const std::string& expandDim, const bool wrongDataType = false, const std::string& numElements = "1") { std::string dataType = (wrongDataType) ? "DT_FLOAT" : "DT_INT32"; std::string val = (wrongDataType) ? ("float_val: " + expandDim + ".0") : ("int_val: "+ expandDim); m_Prototext = R"( node { name: "a" op: "Placeholder" attr { key: "dtype" value { type: DT_FLOAT } } attr { key: "shape" value { shape { dim { size: 1 } dim { size: 4 } } } } } node { name: "b" op: "Const" attr { key: "dtype" value { type: )" + dataType + R"( } } attr { key: "value" value { tensor { dtype: )" + dataType + R"( tensor_shape { dim { size: )" + numElements + R"( } } )" + val + R"( } } } } node { name: "ExpandDims" op: "ExpandDims" input: "a" input: "b" attr { key: "T" value { type: DT_FLOAT } } attr { key: "Tdim" value { type: DT_INT32 } } } versions { producer: 134 })"; } }; struct ExpandDimAsInput : ExpandDimsAsInputFixture { ExpandDimAsInput() : ExpandDimsAsInputFixture("0") { Setup({{"a", {1,4}} ,{"b",{1,1}}}, { "ExpandDims" }); } }; BOOST_FIXTURE_TEST_CASE(ParseExpandDimAsInput, ExpandDimAsInput) { // Axis parameter that describes which axis/dim should be expanded is passed as a second input BOOST_TEST((m_Parser->GetNetworkOutputBindingInfo("ExpandDims").second.GetShape() == armnn::TensorShape({1, 1, 4}))); } struct ExpandDimAsInputWrongDataType : ExpandDimsAsInputFixture { ExpandDimAsInputWrongDataType() : ExpandDimsAsInputFixture("0", true, "1") {} }; BOOST_FIXTURE_TEST_CASE(ParseExpandDimAsInputWrongDataType, ExpandDimAsInputWrongDataType) { // Axis parameter that describes which axis/dim should be expanded is passed as a second input // Axis parameter is of wrong data type (float instead of int32) BOOST_REQUIRE_THROW(Setup({{"a", {1,4}} ,{"b",{1,1}}}, { "ExpandDims" }), armnn::ParseException); } struct ExpandDimAsInputWrongShape : ExpandDimsAsInputFixture { ExpandDimAsInputWrongShape() : ExpandDimsAsInputFixture("0", false, "2") {} }; BOOST_FIXTURE_TEST_CASE(ParseExpandDimAsInputWrongShape, ExpandDimAsInputWrongShape) { // Axis parameter that describes which axis/dim should be expanded is passed as a second input // Axis parameter is of wrong shape BOOST_REQUIRE_THROW(Setup({{"a", {1,4}} ,{"b",{1,1}}}, { "ExpandDims" }), armnn::ParseException); } struct ExpandDimsAsNotConstInputFixture : public armnnUtils::ParserPrototxtFixture { ExpandDimsAsNotConstInputFixture() { m_Prototext = R"( node { name: "a" op: "Placeholder" attr { key: "dtype" value { type: DT_FLOAT } } attr { key: "shape" value { shape { dim { size: 1 } dim { size: 4 } } } } } node { name: "b" op: "Placeholder" attr { key: "dtype" value { type: DT_INT32 } } attr { key: "shape" value { shape { dim { size: 1 } } } } } node { name: "ExpandDims" op: "ExpandDims" input: "a" input: "b" attr { key: "T" value { type: DT_FLOAT } } attr { key: "Tdim" value { type: DT_INT32 } } } versions { producer: 134 })"; } }; BOOST_FIXTURE_TEST_CASE(ParseExpandDimAsNotConstInput, ExpandDimsAsNotConstInputFixture) { // Axis parameter that describes which axis/dim should be expanded is passed as a second input. // But is not a constant tensor --> not supported BOOST_REQUIRE_THROW(Setup({{"a", {1,4}} ,{"b",{1,1}}}, { "ExpandDims" }), armnn::ParseException); } BOOST_AUTO_TEST_SUITE_END()