From 1f3b49be73d1fadf06f20c912aa160a5ab53a6a8 Mon Sep 17 00:00:00 2001 From: Jan Eilers Date: Tue, 8 Sep 2020 08:57:40 +0100 Subject: IVGCVSW-5197 Add support for 2nd input to ExpandDims of TfParser * ParseExpandDims did not support to pass the axis parameter as a second input tensor * Added related unit tests Signed-off-by: Jan Eilers Change-Id: I8217950f0b42beaf5b9eaebdcad04267e4443ba3 --- src/armnnTfParser/test/ExpandDims.cpp | 201 ++++++++++++++++++++++++++++++++++ 1 file changed, 201 insertions(+) (limited to 'src/armnnTfParser/test') diff --git a/src/armnnTfParser/test/ExpandDims.cpp b/src/armnnTfParser/test/ExpandDims.cpp index 57d472d41d..ad95641cd1 100644 --- a/src/armnnTfParser/test/ExpandDims.cpp +++ b/src/armnnTfParser/test/ExpandDims.cpp @@ -109,4 +109,205 @@ BOOST_FIXTURE_TEST_CASE(ParseExpandMinusThreeDim, ExpandMinusThreeDim) armnn::TensorShape({2, 1, 3, 5}))); } +struct ExpandDimsAsInputFixture : public armnnUtils::ParserPrototxtFixture +{ + ExpandDimsAsInputFixture(const std::string& expandDim, + const bool wrongDataType = false, + const std::string& numElements = "1") + { + std::string dataType = (wrongDataType) ? "DT_FLOAT" : "DT_INT32"; + std::string val = (wrongDataType) ? ("float_val: " + expandDim + ".0") : ("int_val: "+ expandDim); + + m_Prototext = R"( + node { + name: "a" + op: "Placeholder" + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "shape" + value { + shape { + dim { + size: 1 + } + dim { + size: 4 + } + } + } + } + } + node { + name: "b" + op: "Const" + attr { + key: "dtype" + value { + type: )" + dataType + R"( + } + } + attr { + key: "value" + value { + tensor { + dtype: )" + dataType + R"( + tensor_shape { + dim { + size: )" + numElements + R"( + } + } + )" + val + R"( + } + } + } + } + node { + name: "ExpandDims" + op: "ExpandDims" + input: "a" + input: "b" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "Tdim" + value { + type: DT_INT32 + } + } + } + versions { + producer: 134 + })"; + } +}; + +struct ExpandDimAsInput : ExpandDimsAsInputFixture +{ + ExpandDimAsInput() : ExpandDimsAsInputFixture("0") + { + Setup({{"a", {1,4}} ,{"b",{1,1}}}, { "ExpandDims" }); + } +}; + + +BOOST_FIXTURE_TEST_CASE(ParseExpandDimAsInput, ExpandDimAsInput) +{ + // Axis parameter that describes which axis/dim should be expanded is passed as a second input + BOOST_TEST((m_Parser->GetNetworkOutputBindingInfo("ExpandDims").second.GetShape() == + armnn::TensorShape({1, 1, 4}))); +} + +struct ExpandDimAsInputWrongDataType : ExpandDimsAsInputFixture +{ + ExpandDimAsInputWrongDataType() : ExpandDimsAsInputFixture("0", true, "1") {} +}; + +BOOST_FIXTURE_TEST_CASE(ParseExpandDimAsInputWrongDataType, ExpandDimAsInputWrongDataType) +{ + // Axis parameter that describes which axis/dim should be expanded is passed as a second input + // Axis parameter is of wrong data type (float instead of int32) + BOOST_REQUIRE_THROW(Setup({{"a", {1,4}} ,{"b",{1,1}}}, { "ExpandDims" }), armnn::ParseException); +} + +struct ExpandDimAsInputWrongShape : ExpandDimsAsInputFixture +{ + ExpandDimAsInputWrongShape() : ExpandDimsAsInputFixture("0", false, "2") {} +}; + +BOOST_FIXTURE_TEST_CASE(ParseExpandDimAsInputWrongShape, ExpandDimAsInputWrongShape) +{ + // Axis parameter that describes which axis/dim should be expanded is passed as a second input + // Axis parameter is of wrong shape + BOOST_REQUIRE_THROW(Setup({{"a", {1,4}} ,{"b",{1,1}}}, { "ExpandDims" }), armnn::ParseException); +} + +struct ExpandDimsAsNotConstInputFixture : public armnnUtils::ParserPrototxtFixture +{ + ExpandDimsAsNotConstInputFixture() + { + m_Prototext = R"( + node { + name: "a" + op: "Placeholder" + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "shape" + value { + shape { + dim { + size: 1 + } + dim { + size: 4 + } + } + } + } + } + node { + name: "b" + op: "Placeholder" + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "shape" + value { + shape { + dim { + size: 1 + } + } + } + } + } + node { + name: "ExpandDims" + op: "ExpandDims" + input: "a" + input: "b" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "Tdim" + value { + type: DT_INT32 + } + } + } + versions { + producer: 134 + })"; + } +}; + +BOOST_FIXTURE_TEST_CASE(ParseExpandDimAsNotConstInput, ExpandDimsAsNotConstInputFixture) +{ + // Axis parameter that describes which axis/dim should be expanded is passed as a second input. + // But is not a constant tensor --> not supported + BOOST_REQUIRE_THROW(Setup({{"a", {1,4}} ,{"b",{1,1}}}, { "ExpandDims" }), + armnn::ParseException); +} + BOOST_AUTO_TEST_SUITE_END() -- cgit v1.2.1