aboutsummaryrefslogtreecommitdiff
path: root/src/armnnTfParser
diff options
context:
space:
mode:
authorjimfly01 <jim.flynn@arm.com>2018-12-19 13:14:46 +0000
committerjimfly01 <jim.flynn@arm.com>2018-12-20 14:19:25 +0000
commit84c70e65a193aa5faa959d305af82783fa8f78b5 (patch)
tree4e22d949b863c21c6bb7aa1bb92e8465a337dd79 /src/armnnTfParser
parent4fa0916386e720e254bee6b9fd1576e90ba6a42f (diff)
downloadarmnn-84c70e65a193aa5faa959d305af82783fa8f78b5.tar.gz
IVGCVSW-2367 Add Equal Operator to TfParser
* Unit tests in Equal.cpp * Fixed error in Network::AddEqualLayer * Refactored TfParser::Minimum/Equal to get rid of duplicate code Change-Id: I0ed6f888eb391c995b88be20dc0c1b916dd14c3c
Diffstat (limited to 'src/armnnTfParser')
-rw-r--r--src/armnnTfParser/TfParser.cpp48
-rw-r--r--src/armnnTfParser/TfParser.hpp11
-rw-r--r--src/armnnTfParser/test/Equal.cpp139
3 files changed, 190 insertions, 8 deletions
diff --git a/src/armnnTfParser/TfParser.cpp b/src/armnnTfParser/TfParser.cpp
index 74742a97b3..b646437f36 100644
--- a/src/armnnTfParser/TfParser.cpp
+++ b/src/armnnTfParser/TfParser.cpp
@@ -353,6 +353,7 @@ const std::map<std::string, TfParser::OperationParsingFunction> TfParser::ms_Ope
{ "AvgPool", &TfParser::ParseAvgPool },
{ "Maximum", &TfParser::ParseMaximum },
{ "Minimum", &TfParser::ParseMinimum },
+ { "Equal", &TfParser::ParseEqual },
{ "Pad", &TfParser::ParsePad },
{ "Sub", &TfParser::ParseSub },
};
@@ -1530,8 +1531,8 @@ ParsedTfOperationPtr TfParser::ParseMaximum(const tensorflow::NodeDef& nodeDef,
}
}
-ParsedTfOperationPtr TfParser::ParseMinimum(const tensorflow::NodeDef& nodeDef,
- const tensorflow::GraphDef& graphDef)
+std::pair<armnn::IOutputSlot*, armnn::IOutputSlot*> TfParser::ProcessElementwiseInputSlots(
+ const tensorflow::NodeDef& nodeDef, const std::string& layerName)
{
std::vector<OutputOfParsedTfOperation> inputs = GetInputParsedTfOperationsChecked(nodeDef, 2);
@@ -1555,15 +1556,22 @@ ParsedTfOperationPtr TfParser::ParseMinimum(const tensorflow::NodeDef& nodeDef,
else
{
throw ParseException(
- boost::str(
- boost::format("Unsupported broadcast configuration for Minimum operation %1% %2%")
- % nodeDef.name()
- % CHECK_LOCATION().AsString()));
+ boost::str(
+ boost::format("Unsupported broadcast configuration for %1% operation %2% %3%")
+ % layerName
+ % nodeDef.name()
+ % CHECK_LOCATION().AsString()));
}
}
+ return {input0Slot, input1Slot};
+}
- IConnectableLayer* const layer = m_Network->AddMinimumLayer(nodeDef.name().c_str());
-
+ParsedTfOperationPtr TfParser::ProcessElementwiseLayer(
+ IOutputSlot* input0Slot,
+ IOutputSlot* input1Slot,
+ IConnectableLayer* const layer,
+ const tensorflow::NodeDef& nodeDef)
+{
input0Slot->Connect(layer->GetInputSlot(0));
input1Slot->Connect(layer->GetInputSlot(1));
@@ -1584,6 +1592,30 @@ ParsedTfOperationPtr TfParser::ParseMinimum(const tensorflow::NodeDef& nodeDef,
return std::make_unique<SingleLayerParsedTfOperation>(this, nodeDef, layer);
}
+ParsedTfOperationPtr TfParser::ParseEqual(const tensorflow::NodeDef& nodeDef,
+ const tensorflow::GraphDef& graphDef)
+{
+ std::pair<armnn::IOutputSlot*, armnn::IOutputSlot*> inputLayers = ProcessElementwiseInputSlots(nodeDef, "Equal");
+ IOutputSlot* input0Slot = inputLayers.first;
+ IOutputSlot* input1Slot = inputLayers.second;
+
+ IConnectableLayer* const layer = m_Network->AddEqualLayer(nodeDef.name().c_str());
+
+ return ProcessElementwiseLayer(input0Slot, input1Slot, layer, nodeDef);
+}
+
+ParsedTfOperationPtr TfParser::ParseMinimum(const tensorflow::NodeDef& nodeDef,
+ const tensorflow::GraphDef& graphDef)
+{
+ std::pair<armnn::IOutputSlot*, armnn::IOutputSlot*> inputLayers = ProcessElementwiseInputSlots(nodeDef, "Minimum");
+ IOutputSlot* input0Slot = inputLayers.first;
+ IOutputSlot* input1Slot = inputLayers.second;
+
+ IConnectableLayer* const layer = m_Network->AddMinimumLayer(nodeDef.name().c_str());
+
+ return ProcessElementwiseLayer(input0Slot, input1Slot, layer, nodeDef);
+}
+
ParsedTfOperationPtr TfParser::ParseSub(const tensorflow::NodeDef& nodeDef, const tensorflow::GraphDef& graphDef)
{
std::vector<OutputOfParsedTfOperation> inputs = GetInputParsedTfOperationsChecked(nodeDef, 2);
diff --git a/src/armnnTfParser/TfParser.hpp b/src/armnnTfParser/TfParser.hpp
index 5ca867c0f7..3aba60cc0a 100644
--- a/src/armnnTfParser/TfParser.hpp
+++ b/src/armnnTfParser/TfParser.hpp
@@ -13,6 +13,7 @@
#include <map>
#include <memory>
#include <unordered_map>
+#include <utility>
#include <vector>
namespace armnn
@@ -154,6 +155,7 @@ private:
ParsedTfOperationPtr ParseAvgPool(const tensorflow::NodeDef& nodeDef, const tensorflow::GraphDef& graphDef);
ParsedTfOperationPtr ParsePooling2d(const tensorflow::NodeDef& nodeDef, const tensorflow::GraphDef& graphDef,
armnn::PoolingAlgorithm pooltype);
+ ParsedTfOperationPtr ParseEqual(const tensorflow::NodeDef& nodeDef, const tensorflow::GraphDef& graphDef);
ParsedTfOperationPtr ParseMaximum(const tensorflow::NodeDef& nodeDef, const tensorflow::GraphDef& graphDef);
ParsedTfOperationPtr ParseMinimum(const tensorflow::NodeDef& nodeDef, const tensorflow::GraphDef& graphDef);
ParsedTfOperationPtr ParsePad(const tensorflow::NodeDef& nodeDef, const tensorflow::GraphDef& graphDef);
@@ -175,6 +177,15 @@ private:
armnn::IOutputSlot** outputOfLeakyRelu,
armnn::ActivationDescriptor & desc);
+ std::pair<armnn::IOutputSlot*, armnn::IOutputSlot*> ProcessElementwiseInputSlots(
+ const tensorflow::NodeDef& nodeDef, const std::string& layerName);
+
+ ParsedTfOperationPtr ProcessElementwiseLayer(
+ armnn::IOutputSlot* input0Slot,
+ armnn::IOutputSlot* input1Slot,
+ armnn::IConnectableLayer* const layer,
+ const tensorflow::NodeDef& nodeDef);
+
static std::pair<armnn::LayerBindingId, armnn::TensorInfo> GetBindingInfo(const std::string& layerName,
const char* bindingPointDesc,
const std::unordered_map<std::string, BindingPointInfo>& nameToBindingInfo);
diff --git a/src/armnnTfParser/test/Equal.cpp b/src/armnnTfParser/test/Equal.cpp
new file mode 100644
index 0000000000..43a1c6abb5
--- /dev/null
+++ b/src/armnnTfParser/test/Equal.cpp
@@ -0,0 +1,139 @@
+//
+// Copyright © 2017 Arm Ltd. All rights reserved.
+// SPDX-License-Identifier: MIT
+//
+
+#include <boost/test/unit_test.hpp>
+#include "armnnTfParser/ITfParser.hpp"
+#include "ParserPrototxtFixture.hpp"
+
+BOOST_AUTO_TEST_SUITE(TensorflowParser)
+
+ struct EqualFixture : public armnnUtils::ParserPrototxtFixture<armnnTfParser::ITfParser>
+ {
+ EqualFixture()
+ {
+ m_Prototext = R"(
+node {
+ name: "input0"
+ op: "Placeholder"
+ attr {
+ key: "dtype"
+ value {
+ type: DT_FLOAT
+ }
+ }
+ attr {
+ key: "shape"
+ value {
+ shape {
+ }
+ }
+ }
+}
+node {
+ name: "input1"
+ op: "Placeholder"
+ attr {
+ key: "dtype"
+ value {
+ type: DT_FLOAT
+ }
+ }
+ attr {
+ key: "shape"
+ value {
+ shape {
+ }
+ }
+ }
+}
+node {
+ name: "output"
+ op: "Equal"
+ input: "input0"
+ input: "input1"
+ attr {
+ key: "T"
+ value {
+ type: DT_FLOAT
+ }
+ }
+}
+ )";
+ }
+ };
+
+BOOST_FIXTURE_TEST_CASE(ParseEqualUnsupportedBroadcast, EqualFixture)
+{
+ BOOST_REQUIRE_THROW(Setup({ { "input0", {2, 3} },
+ { "input1", {1, 2, 2, 3} } },
+ { "output" }),
+ armnn::ParseException);
+}
+
+struct EqualFixtureAutoSetup : public EqualFixture
+{
+ EqualFixtureAutoSetup(const armnn::TensorShape& input0Shape,
+ const armnn::TensorShape& input1Shape)
+ : EqualFixture()
+ {
+ Setup({ { "input0", input0Shape },
+ { "input1", input1Shape } },
+ { "output" });
+ }
+};
+
+struct EqualTwoByTwo : public EqualFixtureAutoSetup
+{
+ EqualTwoByTwo() : EqualFixtureAutoSetup({2,2}, {2,2}) {}
+};
+
+BOOST_FIXTURE_TEST_CASE(ParseEqualTwoByTwo, EqualTwoByTwo)
+{
+ RunTest<2>({ { "input0", { 1.0f, 2.0f, 3.0f, 2.0f } },
+ { "input1", { 1.0f, 5.0f, 2.0f, 2.0f } } },
+ { { "output", { 1.0f, 0.0f, 0.0f, 1.0f } } });
+}
+
+struct EqualBroadcast1DAnd4D : public EqualFixtureAutoSetup
+{
+ EqualBroadcast1DAnd4D() : EqualFixtureAutoSetup({1}, {1,1,2,2}) {}
+};
+
+BOOST_FIXTURE_TEST_CASE(ParseEqualBroadcast1DToTwoByTwo, EqualBroadcast1DAnd4D)
+{
+ RunTest<4>({ { "input0", { 2.0f } },
+ { "input1", { 1.0f, 2.0f, 3.0f, 2.0f } } },
+ { { "output", { 0.0f, 1.0f, 0.0f, 1.0f } } });
+}
+
+struct EqualBroadcast4DAnd1D : public EqualFixtureAutoSetup
+{
+ EqualBroadcast4DAnd1D() : EqualFixtureAutoSetup({1,1,2,2}, {1}) {}
+};
+
+BOOST_FIXTURE_TEST_CASE(ParseEqualBroadcast4DAnd1D, EqualBroadcast4DAnd1D)
+{
+ RunTest<4>({ { "input0", { 1.0f, 2.0f, 3.0f, 2.0f } },
+ { "input1", { 3.0f } } },
+ { { "output", { 0.0f, 0.0f, 1.0f, 0.0f } } });
+}
+
+struct EqualMultiDimBroadcast : public EqualFixtureAutoSetup
+{
+ EqualMultiDimBroadcast() : EqualFixtureAutoSetup({1,1,2,1}, {1,2,1,3}) {}
+};
+
+BOOST_FIXTURE_TEST_CASE(ParseEqualMultiDimBroadcast, EqualMultiDimBroadcast)
+{
+ RunTest<4>({ { "input0", { 1.0f, 2.0f } },
+ { "input1", { 1.0f, 2.0f, 3.0f,
+ 3.0f, 2.0f, 2.0f } } },
+ { { "output", { 1.0f, 0.0f, 0.0f,
+ 0.0f, 1.0f, 0.0f,
+ 0.0f, 0.0f, 0.0f,
+ 0.0f, 1.0f, 1.0f } } });
+}
+
+BOOST_AUTO_TEST_SUITE_END()