diff options
Diffstat (limited to 'src/armnnTfParser/test/ExpandDims.cpp')
-rw-r--r-- | src/armnnTfParser/test/ExpandDims.cpp | 201 |
1 files changed, 201 insertions, 0 deletions
diff --git a/src/armnnTfParser/test/ExpandDims.cpp b/src/armnnTfParser/test/ExpandDims.cpp index 57d472d41d..ad95641cd1 100644 --- a/src/armnnTfParser/test/ExpandDims.cpp +++ b/src/armnnTfParser/test/ExpandDims.cpp @@ -109,4 +109,205 @@ BOOST_FIXTURE_TEST_CASE(ParseExpandMinusThreeDim, ExpandMinusThreeDim) armnn::TensorShape({2, 1, 3, 5}))); } +struct ExpandDimsAsInputFixture : public armnnUtils::ParserPrototxtFixture<armnnTfParser::ITfParser> +{ + ExpandDimsAsInputFixture(const std::string& expandDim, + const bool wrongDataType = false, + const std::string& numElements = "1") + { + std::string dataType = (wrongDataType) ? "DT_FLOAT" : "DT_INT32"; + std::string val = (wrongDataType) ? ("float_val: " + expandDim + ".0") : ("int_val: "+ expandDim); + + m_Prototext = R"( + node { + name: "a" + op: "Placeholder" + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "shape" + value { + shape { + dim { + size: 1 + } + dim { + size: 4 + } + } + } + } + } + node { + name: "b" + op: "Const" + attr { + key: "dtype" + value { + type: )" + dataType + R"( + } + } + attr { + key: "value" + value { + tensor { + dtype: )" + dataType + R"( + tensor_shape { + dim { + size: )" + numElements + R"( + } + } + )" + val + R"( + } + } + } + } + node { + name: "ExpandDims" + op: "ExpandDims" + input: "a" + input: "b" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "Tdim" + value { + type: DT_INT32 + } + } + } + versions { + producer: 134 + })"; + } +}; + +struct ExpandDimAsInput : ExpandDimsAsInputFixture +{ + ExpandDimAsInput() : ExpandDimsAsInputFixture("0") + { + Setup({{"a", {1,4}} ,{"b",{1,1}}}, { "ExpandDims" }); + } +}; + + +BOOST_FIXTURE_TEST_CASE(ParseExpandDimAsInput, ExpandDimAsInput) +{ + // Axis parameter that describes which axis/dim should be expanded is passed as a second input + BOOST_TEST((m_Parser->GetNetworkOutputBindingInfo("ExpandDims").second.GetShape() == + armnn::TensorShape({1, 1, 4}))); +} + +struct ExpandDimAsInputWrongDataType : ExpandDimsAsInputFixture +{ + ExpandDimAsInputWrongDataType() : ExpandDimsAsInputFixture("0", true, "1") {} +}; + +BOOST_FIXTURE_TEST_CASE(ParseExpandDimAsInputWrongDataType, ExpandDimAsInputWrongDataType) +{ + // Axis parameter that describes which axis/dim should be expanded is passed as a second input + // Axis parameter is of wrong data type (float instead of int32) + BOOST_REQUIRE_THROW(Setup({{"a", {1,4}} ,{"b",{1,1}}}, { "ExpandDims" }), armnn::ParseException); +} + +struct ExpandDimAsInputWrongShape : ExpandDimsAsInputFixture +{ + ExpandDimAsInputWrongShape() : ExpandDimsAsInputFixture("0", false, "2") {} +}; + +BOOST_FIXTURE_TEST_CASE(ParseExpandDimAsInputWrongShape, ExpandDimAsInputWrongShape) +{ + // Axis parameter that describes which axis/dim should be expanded is passed as a second input + // Axis parameter is of wrong shape + BOOST_REQUIRE_THROW(Setup({{"a", {1,4}} ,{"b",{1,1}}}, { "ExpandDims" }), armnn::ParseException); +} + +struct ExpandDimsAsNotConstInputFixture : public armnnUtils::ParserPrototxtFixture<armnnTfParser::ITfParser> +{ + ExpandDimsAsNotConstInputFixture() + { + m_Prototext = R"( + node { + name: "a" + op: "Placeholder" + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "shape" + value { + shape { + dim { + size: 1 + } + dim { + size: 4 + } + } + } + } + } + node { + name: "b" + op: "Placeholder" + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "shape" + value { + shape { + dim { + size: 1 + } + } + } + } + } + node { + name: "ExpandDims" + op: "ExpandDims" + input: "a" + input: "b" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "Tdim" + value { + type: DT_INT32 + } + } + } + versions { + producer: 134 + })"; + } +}; + +BOOST_FIXTURE_TEST_CASE(ParseExpandDimAsNotConstInput, ExpandDimsAsNotConstInputFixture) +{ + // Axis parameter that describes which axis/dim should be expanded is passed as a second input. + // But is not a constant tensor --> not supported + BOOST_REQUIRE_THROW(Setup({{"a", {1,4}} ,{"b",{1,1}}}, { "ExpandDims" }), + armnn::ParseException); +} + BOOST_AUTO_TEST_SUITE_END() |