aboutsummaryrefslogtreecommitdiff
path: root/src/armnnTfParser/TfParser.cpp
diff options
context:
space:
mode:
authorjimfly01 <jim.flynn@arm.com>2018-12-19 13:14:46 +0000
committerjimfly01 <jim.flynn@arm.com>2018-12-20 14:19:25 +0000
commit84c70e65a193aa5faa959d305af82783fa8f78b5 (patch)
tree4e22d949b863c21c6bb7aa1bb92e8465a337dd79 /src/armnnTfParser/TfParser.cpp
parent4fa0916386e720e254bee6b9fd1576e90ba6a42f (diff)
downloadarmnn-84c70e65a193aa5faa959d305af82783fa8f78b5.tar.gz
IVGCVSW-2367 Add Equal Operator to TfParser
* Unit tests in Equal.cpp * Fixed error in Network::AddEqualLayer * Refactored TfParser::Minimum/Equal to get rid of duplicate code Change-Id: I0ed6f888eb391c995b88be20dc0c1b916dd14c3c
Diffstat (limited to 'src/armnnTfParser/TfParser.cpp')
-rw-r--r--src/armnnTfParser/TfParser.cpp48
1 files changed, 40 insertions, 8 deletions
diff --git a/src/armnnTfParser/TfParser.cpp b/src/armnnTfParser/TfParser.cpp
index 74742a97b3..b646437f36 100644
--- a/src/armnnTfParser/TfParser.cpp
+++ b/src/armnnTfParser/TfParser.cpp
@@ -353,6 +353,7 @@ const std::map<std::string, TfParser::OperationParsingFunction> TfParser::ms_Ope
{ "AvgPool", &TfParser::ParseAvgPool },
{ "Maximum", &TfParser::ParseMaximum },
{ "Minimum", &TfParser::ParseMinimum },
+ { "Equal", &TfParser::ParseEqual },
{ "Pad", &TfParser::ParsePad },
{ "Sub", &TfParser::ParseSub },
};
@@ -1530,8 +1531,8 @@ ParsedTfOperationPtr TfParser::ParseMaximum(const tensorflow::NodeDef& nodeDef,
}
}
-ParsedTfOperationPtr TfParser::ParseMinimum(const tensorflow::NodeDef& nodeDef,
- const tensorflow::GraphDef& graphDef)
+std::pair<armnn::IOutputSlot*, armnn::IOutputSlot*> TfParser::ProcessElementwiseInputSlots(
+ const tensorflow::NodeDef& nodeDef, const std::string& layerName)
{
std::vector<OutputOfParsedTfOperation> inputs = GetInputParsedTfOperationsChecked(nodeDef, 2);
@@ -1555,15 +1556,22 @@ ParsedTfOperationPtr TfParser::ParseMinimum(const tensorflow::NodeDef& nodeDef,
else
{
throw ParseException(
- boost::str(
- boost::format("Unsupported broadcast configuration for Minimum operation %1% %2%")
- % nodeDef.name()
- % CHECK_LOCATION().AsString()));
+ boost::str(
+ boost::format("Unsupported broadcast configuration for %1% operation %2% %3%")
+ % layerName
+ % nodeDef.name()
+ % CHECK_LOCATION().AsString()));
}
}
+ return {input0Slot, input1Slot};
+}
- IConnectableLayer* const layer = m_Network->AddMinimumLayer(nodeDef.name().c_str());
-
+ParsedTfOperationPtr TfParser::ProcessElementwiseLayer(
+ IOutputSlot* input0Slot,
+ IOutputSlot* input1Slot,
+ IConnectableLayer* const layer,
+ const tensorflow::NodeDef& nodeDef)
+{
input0Slot->Connect(layer->GetInputSlot(0));
input1Slot->Connect(layer->GetInputSlot(1));
@@ -1584,6 +1592,30 @@ ParsedTfOperationPtr TfParser::ParseMinimum(const tensorflow::NodeDef& nodeDef,
return std::make_unique<SingleLayerParsedTfOperation>(this, nodeDef, layer);
}
+ParsedTfOperationPtr TfParser::ParseEqual(const tensorflow::NodeDef& nodeDef,
+ const tensorflow::GraphDef& graphDef)
+{
+ std::pair<armnn::IOutputSlot*, armnn::IOutputSlot*> inputLayers = ProcessElementwiseInputSlots(nodeDef, "Equal");
+ IOutputSlot* input0Slot = inputLayers.first;
+ IOutputSlot* input1Slot = inputLayers.second;
+
+ IConnectableLayer* const layer = m_Network->AddEqualLayer(nodeDef.name().c_str());
+
+ return ProcessElementwiseLayer(input0Slot, input1Slot, layer, nodeDef);
+}
+
+ParsedTfOperationPtr TfParser::ParseMinimum(const tensorflow::NodeDef& nodeDef,
+ const tensorflow::GraphDef& graphDef)
+{
+ std::pair<armnn::IOutputSlot*, armnn::IOutputSlot*> inputLayers = ProcessElementwiseInputSlots(nodeDef, "Minimum");
+ IOutputSlot* input0Slot = inputLayers.first;
+ IOutputSlot* input1Slot = inputLayers.second;
+
+ IConnectableLayer* const layer = m_Network->AddMinimumLayer(nodeDef.name().c_str());
+
+ return ProcessElementwiseLayer(input0Slot, input1Slot, layer, nodeDef);
+}
+
ParsedTfOperationPtr TfParser::ParseSub(const tensorflow::NodeDef& nodeDef, const tensorflow::GraphDef& graphDef)
{
std::vector<OutputOfParsedTfOperation> inputs = GetInputParsedTfOperationsChecked(nodeDef, 2);