aboutsummaryrefslogtreecommitdiff
path: root/src/armnnTfParser
diff options
context:
space:
mode:
authorjimfly01 <jim.flynn@arm.com>2018-12-18 16:24:51 +0000
committerJim Flynn Arm <jim.flynn@arm.com>2018-12-20 17:37:20 +0000
commita06bf31afabfb84e60740ea3219406ab13c8e6a6 (patch)
tree45c5b05799be32ec3c70ca3d25b3643a8baf3f0b /src/armnnTfParser
parentf446432f4c21a64ffb92552c5e1906194fb98558 (diff)
downloadarmnn-a06bf31afabfb84e60740ea3219406ab13c8e6a6.tar.gz
IVGCVSW-2380 Add Greater operator to TfParser
* Unit tests in Greater.cpp Change-Id: Ifb3e4c33be2d6235e33889bb63e6abd78bd7d8b6
Diffstat (limited to 'src/armnnTfParser')
-rw-r--r--src/armnnTfParser/TfParser.cpp13
-rw-r--r--src/armnnTfParser/TfParser.hpp1
-rw-r--r--src/armnnTfParser/test/Greater.cpp139
3 files changed, 153 insertions, 0 deletions
diff --git a/src/armnnTfParser/TfParser.cpp b/src/armnnTfParser/TfParser.cpp
index b646437f36..45c039bb15 100644
--- a/src/armnnTfParser/TfParser.cpp
+++ b/src/armnnTfParser/TfParser.cpp
@@ -333,6 +333,7 @@ const std::map<std::string, TfParser::OperationParsingFunction> TfParser::ms_Ope
{ "DepthwiseConv2dNative", &TfParser::ParseDepthwiseConv2D },
{ "ExpandDims", &TfParser::ParseExpandDims },
{ "FusedBatchNorm", &TfParser::ParseFusedBatchNorm },
+ { "Greater", &TfParser::ParseGreater},
{ "ConcatV2", &TfParser::ParseConcat },
{ "LRN", &TfParser::ParseLrn },
{ "MatMul", &TfParser::ParseMatMul },
@@ -1592,6 +1593,18 @@ ParsedTfOperationPtr TfParser::ProcessElementwiseLayer(
return std::make_unique<SingleLayerParsedTfOperation>(this, nodeDef, layer);
}
+ParsedTfOperationPtr TfParser::ParseGreater(const tensorflow::NodeDef& nodeDef,
+ const tensorflow::GraphDef& graphDef)
+{
+ std::pair<armnn::IOutputSlot*, armnn::IOutputSlot*> inputLayers = ProcessElementwiseInputSlots(nodeDef, "Greater");
+ IOutputSlot* input0Slot = inputLayers.first;
+ IOutputSlot* input1Slot = inputLayers.second;
+
+ IConnectableLayer* const layer = m_Network->AddGreaterLayer(nodeDef.name().c_str());
+
+ return ProcessElementwiseLayer(input0Slot, input1Slot, layer, nodeDef);
+}
+
ParsedTfOperationPtr TfParser::ParseEqual(const tensorflow::NodeDef& nodeDef,
const tensorflow::GraphDef& graphDef)
{
diff --git a/src/armnnTfParser/TfParser.hpp b/src/armnnTfParser/TfParser.hpp
index 3aba60cc0a..55797471e2 100644
--- a/src/armnnTfParser/TfParser.hpp
+++ b/src/armnnTfParser/TfParser.hpp
@@ -158,6 +158,7 @@ private:
ParsedTfOperationPtr ParseEqual(const tensorflow::NodeDef& nodeDef, const tensorflow::GraphDef& graphDef);
ParsedTfOperationPtr ParseMaximum(const tensorflow::NodeDef& nodeDef, const tensorflow::GraphDef& graphDef);
ParsedTfOperationPtr ParseMinimum(const tensorflow::NodeDef& nodeDef, const tensorflow::GraphDef& graphDef);
+ ParsedTfOperationPtr ParseGreater(const tensorflow::NodeDef& nodeDef, const tensorflow::GraphDef& graphDef);
ParsedTfOperationPtr ParsePad(const tensorflow::NodeDef& nodeDef, const tensorflow::GraphDef& graphDef);
ParsedTfOperationPtr ParseSub(const tensorflow::NodeDef& nodeDef, const tensorflow::GraphDef& graphDef);
ParsedTfOperationPtr AddActivationLayer(const tensorflow::NodeDef& nodeDef, armnn::ActivationDescriptor& desc);
diff --git a/src/armnnTfParser/test/Greater.cpp b/src/armnnTfParser/test/Greater.cpp
new file mode 100644
index 0000000000..f11c199599
--- /dev/null
+++ b/src/armnnTfParser/test/Greater.cpp
@@ -0,0 +1,139 @@
+//
+// Copyright © 2017 Arm Ltd. All rights reserved.
+// SPDX-License-Identifier: MIT
+//
+
+#include <boost/test/unit_test.hpp>
+#include "armnnTfParser/ITfParser.hpp"
+#include "ParserPrototxtFixture.hpp"
+
+BOOST_AUTO_TEST_SUITE(TensorflowParser)
+
+struct GreaterFixture : public armnnUtils::ParserPrototxtFixture<armnnTfParser::ITfParser>
+{
+ GreaterFixture()
+ {
+ m_Prototext = R"(
+node {
+ name: "input0"
+ op: "Placeholder"
+ attr {
+ key: "dtype"
+ value {
+ type: DT_FLOAT
+ }
+ }
+ attr {
+ key: "shape"
+ value {
+ shape {
+ }
+ }
+ }
+}
+node {
+ name: "input1"
+ op: "Placeholder"
+ attr {
+ key: "dtype"
+ value {
+ type: DT_FLOAT
+ }
+ }
+ attr {
+ key: "shape"
+ value {
+ shape {
+ }
+ }
+ }
+}
+node {
+ name: "output"
+ op: "Greater"
+ input: "input0"
+ input: "input1"
+ attr {
+ key: "T"
+ value {
+ type: DT_FLOAT
+ }
+ }
+}
+ )";
+ }
+};
+
+BOOST_FIXTURE_TEST_CASE(ParseGreaterUnsupportedBroadcast, GreaterFixture)
+{
+ BOOST_REQUIRE_THROW(Setup({ { "input0", {2, 3} },
+ { "input1", {1, 2, 2, 3} } },
+ { "output" }),
+ armnn::ParseException);
+}
+
+struct GreaterFixtureAutoSetup : public GreaterFixture
+{
+ GreaterFixtureAutoSetup(const armnn::TensorShape& input0Shape,
+ const armnn::TensorShape& input1Shape)
+ : GreaterFixture()
+ {
+ Setup({ { "input0", input0Shape },
+ { "input1", input1Shape } },
+ { "output" });
+ }
+};
+
+struct GreaterFixtureTwoByTwo : public GreaterFixtureAutoSetup
+{
+ GreaterFixtureTwoByTwo() : GreaterFixtureAutoSetup({2, 2}, {2, 2}) {}
+};
+
+BOOST_FIXTURE_TEST_CASE(ParseGreaterTwoByTwo, GreaterFixtureTwoByTwo)
+{
+ RunTest<2>({ { "input0", { 1.0f, 2.0f, 3.0f, 4.0f} },
+ { "input1", { 1.0f, 5.0f, 2.0f, 2.0f} } },
+ { { "output", { 0.0f, 0.0f, 1.0f, 1.0f} } });
+}
+
+struct GreaterBroadcast1DAnd4D : public GreaterFixtureAutoSetup
+{
+ GreaterBroadcast1DAnd4D() : GreaterFixtureAutoSetup({1}, {1,1,2,2}) {}
+};
+
+BOOST_FIXTURE_TEST_CASE(ParseGreaterBroadcast1DToTwoByTwo, GreaterBroadcast1DAnd4D)
+{
+ RunTest<4>({ { "input0", { 2.0f } },
+ { "input1", { 1.0f, 2.0f, 3.0f, 2.0f } } },
+ { { "output", { 1.0f, 0.0f, 0.0f, 0.0f } } });
+}
+
+struct GreaterBroadcast4DAnd1D : public GreaterFixtureAutoSetup
+{
+ GreaterBroadcast4DAnd1D() : GreaterFixtureAutoSetup({1,1,2,2}, {1}) {}
+};
+
+BOOST_FIXTURE_TEST_CASE(ParseGreaterBroadcast4DAnd1D, GreaterBroadcast4DAnd1D)
+{
+ RunTest<4>({ { "input0", { 1.0f, 2.0f, 3.0f, 2.0f } },
+ { "input1", { 3.0f } } },
+ { { "output", { 0.0f, 0.0f, 0.0f, 0.0f } } });
+}
+
+struct GreaterMultiDimBroadcast : public GreaterFixtureAutoSetup
+{
+ GreaterMultiDimBroadcast() : GreaterFixtureAutoSetup({1,1,2,1}, {1,2,1,3}) {}
+};
+
+BOOST_FIXTURE_TEST_CASE(ParseGreaterMultiDimBroadcast, GreaterMultiDimBroadcast)
+{
+ RunTest<4>({ { "input0", { 1.0f, 2.0f } },
+ { "input1", { 1.0f, 2.0f, 3.0f,
+ 3.0f, 2.0f, 2.0f } } },
+ { { "output", { 0.0f, 0.0f, 0.0f,
+ 1.0f, 0.0f, 0.0f,
+ 0.0f, 0.0f, 0.0f,
+ 0.0f, 0.0f, 0.0f } } });
+}
+
+BOOST_AUTO_TEST_SUITE_END()