aboutsummaryrefslogtreecommitdiff
path: root/src/armnnTfParser/test
diff options
context:
space:
mode:
authorJan Eilers <jan.eilers@arm.com>2020-09-08 08:57:40 +0100
committerKeithARM <keith.davis@arm.com>2020-09-10 09:23:30 +0000
commit1f3b49be73d1fadf06f20c912aa160a5ab53a6a8 (patch)
tree8bd9d08025f23c2054c21f1d9d76a29435fb6d74 /src/armnnTfParser/test
parent54940191dfe3a405dcc0fdf6516849082ae62cc7 (diff)
downloadarmnn-1f3b49be73d1fadf06f20c912aa160a5ab53a6a8.tar.gz
IVGCVSW-5197 Add support for 2nd input to ExpandDims of TfParser
* ParseExpandDims did not support to pass the axis parameter as a second input tensor * Added related unit tests Signed-off-by: Jan Eilers <jan.eilers@arm.com> Change-Id: I8217950f0b42beaf5b9eaebdcad04267e4443ba3
Diffstat (limited to 'src/armnnTfParser/test')
-rw-r--r--src/armnnTfParser/test/ExpandDims.cpp201
1 files changed, 201 insertions, 0 deletions
diff --git a/src/armnnTfParser/test/ExpandDims.cpp b/src/armnnTfParser/test/ExpandDims.cpp
index 57d472d41d..ad95641cd1 100644
--- a/src/armnnTfParser/test/ExpandDims.cpp
+++ b/src/armnnTfParser/test/ExpandDims.cpp
@@ -109,4 +109,205 @@ BOOST_FIXTURE_TEST_CASE(ParseExpandMinusThreeDim, ExpandMinusThreeDim)
armnn::TensorShape({2, 1, 3, 5})));
}
+struct ExpandDimsAsInputFixture : public armnnUtils::ParserPrototxtFixture<armnnTfParser::ITfParser>
+{
+ ExpandDimsAsInputFixture(const std::string& expandDim,
+ const bool wrongDataType = false,
+ const std::string& numElements = "1")
+ {
+ std::string dataType = (wrongDataType) ? "DT_FLOAT" : "DT_INT32";
+ std::string val = (wrongDataType) ? ("float_val: " + expandDim + ".0") : ("int_val: "+ expandDim);
+
+ m_Prototext = R"(
+ node {
+ name: "a"
+ op: "Placeholder"
+ attr {
+ key: "dtype"
+ value {
+ type: DT_FLOAT
+ }
+ }
+ attr {
+ key: "shape"
+ value {
+ shape {
+ dim {
+ size: 1
+ }
+ dim {
+ size: 4
+ }
+ }
+ }
+ }
+ }
+ node {
+ name: "b"
+ op: "Const"
+ attr {
+ key: "dtype"
+ value {
+ type: )" + dataType + R"(
+ }
+ }
+ attr {
+ key: "value"
+ value {
+ tensor {
+ dtype: )" + dataType + R"(
+ tensor_shape {
+ dim {
+ size: )" + numElements + R"(
+ }
+ }
+ )" + val + R"(
+ }
+ }
+ }
+ }
+ node {
+ name: "ExpandDims"
+ op: "ExpandDims"
+ input: "a"
+ input: "b"
+ attr {
+ key: "T"
+ value {
+ type: DT_FLOAT
+ }
+ }
+ attr {
+ key: "Tdim"
+ value {
+ type: DT_INT32
+ }
+ }
+ }
+ versions {
+ producer: 134
+ })";
+ }
+};
+
+struct ExpandDimAsInput : ExpandDimsAsInputFixture
+{
+ ExpandDimAsInput() : ExpandDimsAsInputFixture("0")
+ {
+ Setup({{"a", {1,4}} ,{"b",{1,1}}}, { "ExpandDims" });
+ }
+};
+
+
+BOOST_FIXTURE_TEST_CASE(ParseExpandDimAsInput, ExpandDimAsInput)
+{
+ // Axis parameter that describes which axis/dim should be expanded is passed as a second input
+ BOOST_TEST((m_Parser->GetNetworkOutputBindingInfo("ExpandDims").second.GetShape() ==
+ armnn::TensorShape({1, 1, 4})));
+}
+
+struct ExpandDimAsInputWrongDataType : ExpandDimsAsInputFixture
+{
+ ExpandDimAsInputWrongDataType() : ExpandDimsAsInputFixture("0", true, "1") {}
+};
+
+BOOST_FIXTURE_TEST_CASE(ParseExpandDimAsInputWrongDataType, ExpandDimAsInputWrongDataType)
+{
+ // Axis parameter that describes which axis/dim should be expanded is passed as a second input
+ // Axis parameter is of wrong data type (float instead of int32)
+ BOOST_REQUIRE_THROW(Setup({{"a", {1,4}} ,{"b",{1,1}}}, { "ExpandDims" }), armnn::ParseException);
+}
+
+struct ExpandDimAsInputWrongShape : ExpandDimsAsInputFixture
+{
+ ExpandDimAsInputWrongShape() : ExpandDimsAsInputFixture("0", false, "2") {}
+};
+
+BOOST_FIXTURE_TEST_CASE(ParseExpandDimAsInputWrongShape, ExpandDimAsInputWrongShape)
+{
+ // Axis parameter that describes which axis/dim should be expanded is passed as a second input
+ // Axis parameter is of wrong shape
+ BOOST_REQUIRE_THROW(Setup({{"a", {1,4}} ,{"b",{1,1}}}, { "ExpandDims" }), armnn::ParseException);
+}
+
+struct ExpandDimsAsNotConstInputFixture : public armnnUtils::ParserPrototxtFixture<armnnTfParser::ITfParser>
+{
+ ExpandDimsAsNotConstInputFixture()
+ {
+ m_Prototext = R"(
+ node {
+ name: "a"
+ op: "Placeholder"
+ attr {
+ key: "dtype"
+ value {
+ type: DT_FLOAT
+ }
+ }
+ attr {
+ key: "shape"
+ value {
+ shape {
+ dim {
+ size: 1
+ }
+ dim {
+ size: 4
+ }
+ }
+ }
+ }
+ }
+ node {
+ name: "b"
+ op: "Placeholder"
+ attr {
+ key: "dtype"
+ value {
+ type: DT_INT32
+ }
+ }
+ attr {
+ key: "shape"
+ value {
+ shape {
+ dim {
+ size: 1
+ }
+ }
+ }
+ }
+ }
+ node {
+ name: "ExpandDims"
+ op: "ExpandDims"
+ input: "a"
+ input: "b"
+ attr {
+ key: "T"
+ value {
+ type: DT_FLOAT
+ }
+ }
+ attr {
+ key: "Tdim"
+ value {
+ type: DT_INT32
+ }
+ }
+ }
+ versions {
+ producer: 134
+ })";
+ }
+};
+
+BOOST_FIXTURE_TEST_CASE(ParseExpandDimAsNotConstInput, ExpandDimsAsNotConstInputFixture)
+{
+ // Axis parameter that describes which axis/dim should be expanded is passed as a second input.
+ // But is not a constant tensor --> not supported
+ BOOST_REQUIRE_THROW(Setup({{"a", {1,4}} ,{"b",{1,1}}}, { "ExpandDims" }),
+ armnn::ParseException);
+}
+
BOOST_AUTO_TEST_SUITE_END()