aboutsummaryrefslogtreecommitdiff
path: root/src/armnnTfParser/test/ExpandDims.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/armnnTfParser/test/ExpandDims.cpp')
-rw-r--r--src/armnnTfParser/test/ExpandDims.cpp201
1 files changed, 201 insertions, 0 deletions
diff --git a/src/armnnTfParser/test/ExpandDims.cpp b/src/armnnTfParser/test/ExpandDims.cpp
index 57d472d41d..ad95641cd1 100644
--- a/src/armnnTfParser/test/ExpandDims.cpp
+++ b/src/armnnTfParser/test/ExpandDims.cpp
@@ -109,4 +109,205 @@ BOOST_FIXTURE_TEST_CASE(ParseExpandMinusThreeDim, ExpandMinusThreeDim)
armnn::TensorShape({2, 1, 3, 5})));
}
+struct ExpandDimsAsInputFixture : public armnnUtils::ParserPrototxtFixture<armnnTfParser::ITfParser>
+{
+ ExpandDimsAsInputFixture(const std::string& expandDim,
+ const bool wrongDataType = false,
+ const std::string& numElements = "1")
+ {
+ std::string dataType = (wrongDataType) ? "DT_FLOAT" : "DT_INT32";
+ std::string val = (wrongDataType) ? ("float_val: " + expandDim + ".0") : ("int_val: "+ expandDim);
+
+ m_Prototext = R"(
+ node {
+ name: "a"
+ op: "Placeholder"
+ attr {
+ key: "dtype"
+ value {
+ type: DT_FLOAT
+ }
+ }
+ attr {
+ key: "shape"
+ value {
+ shape {
+ dim {
+ size: 1
+ }
+ dim {
+ size: 4
+ }
+ }
+ }
+ }
+ }
+ node {
+ name: "b"
+ op: "Const"
+ attr {
+ key: "dtype"
+ value {
+ type: )" + dataType + R"(
+ }
+ }
+ attr {
+ key: "value"
+ value {
+ tensor {
+ dtype: )" + dataType + R"(
+ tensor_shape {
+ dim {
+ size: )" + numElements + R"(
+ }
+ }
+ )" + val + R"(
+ }
+ }
+ }
+ }
+ node {
+ name: "ExpandDims"
+ op: "ExpandDims"
+ input: "a"
+ input: "b"
+ attr {
+ key: "T"
+ value {
+ type: DT_FLOAT
+ }
+ }
+ attr {
+ key: "Tdim"
+ value {
+ type: DT_INT32
+ }
+ }
+ }
+ versions {
+ producer: 134
+ })";
+ }
+};
+
+struct ExpandDimAsInput : ExpandDimsAsInputFixture
+{
+ ExpandDimAsInput() : ExpandDimsAsInputFixture("0")
+ {
+ Setup({{"a", {1,4}} ,{"b",{1,1}}}, { "ExpandDims" });
+ }
+};
+
+
+BOOST_FIXTURE_TEST_CASE(ParseExpandDimAsInput, ExpandDimAsInput)
+{
+ // Axis parameter that describes which axis/dim should be expanded is passed as a second input
+ BOOST_TEST((m_Parser->GetNetworkOutputBindingInfo("ExpandDims").second.GetShape() ==
+ armnn::TensorShape({1, 1, 4})));
+}
+
+struct ExpandDimAsInputWrongDataType : ExpandDimsAsInputFixture
+{
+ ExpandDimAsInputWrongDataType() : ExpandDimsAsInputFixture("0", true, "1") {}
+};
+
+BOOST_FIXTURE_TEST_CASE(ParseExpandDimAsInputWrongDataType, ExpandDimAsInputWrongDataType)
+{
+ // Axis parameter that describes which axis/dim should be expanded is passed as a second input
+ // Axis parameter is of wrong data type (float instead of int32)
+ BOOST_REQUIRE_THROW(Setup({{"a", {1,4}} ,{"b",{1,1}}}, { "ExpandDims" }), armnn::ParseException);
+}
+
+struct ExpandDimAsInputWrongShape : ExpandDimsAsInputFixture
+{
+ ExpandDimAsInputWrongShape() : ExpandDimsAsInputFixture("0", false, "2") {}
+};
+
+BOOST_FIXTURE_TEST_CASE(ParseExpandDimAsInputWrongShape, ExpandDimAsInputWrongShape)
+{
+ // Axis parameter that describes which axis/dim should be expanded is passed as a second input
+ // Axis parameter is of wrong shape
+ BOOST_REQUIRE_THROW(Setup({{"a", {1,4}} ,{"b",{1,1}}}, { "ExpandDims" }), armnn::ParseException);
+}
+
+struct ExpandDimsAsNotConstInputFixture : public armnnUtils::ParserPrototxtFixture<armnnTfParser::ITfParser>
+{
+ ExpandDimsAsNotConstInputFixture()
+ {
+ m_Prototext = R"(
+ node {
+ name: "a"
+ op: "Placeholder"
+ attr {
+ key: "dtype"
+ value {
+ type: DT_FLOAT
+ }
+ }
+ attr {
+ key: "shape"
+ value {
+ shape {
+ dim {
+ size: 1
+ }
+ dim {
+ size: 4
+ }
+ }
+ }
+ }
+ }
+ node {
+ name: "b"
+ op: "Placeholder"
+ attr {
+ key: "dtype"
+ value {
+ type: DT_INT32
+ }
+ }
+ attr {
+ key: "shape"
+ value {
+ shape {
+ dim {
+ size: 1
+ }
+ }
+ }
+ }
+ }
+ node {
+ name: "ExpandDims"
+ op: "ExpandDims"
+ input: "a"
+ input: "b"
+ attr {
+ key: "T"
+ value {
+ type: DT_FLOAT
+ }
+ }
+ attr {
+ key: "Tdim"
+ value {
+ type: DT_INT32
+ }
+ }
+ }
+ versions {
+ producer: 134
+ })";
+ }
+};
+
+BOOST_FIXTURE_TEST_CASE(ParseExpandDimAsNotConstInput, ExpandDimsAsNotConstInputFixture)
+{
+ // Axis parameter that describes which axis/dim should be expanded is passed as a second input.
+ // But is not a constant tensor --> not supported
+ BOOST_REQUIRE_THROW(Setup({{"a", {1,4}} ,{"b",{1,1}}}, { "ExpandDims" }),
+ armnn::ParseException);
+}
+
BOOST_AUTO_TEST_SUITE_END()