diff options
author | jimfly01 <jim.flynn@arm.com> | 2018-12-19 13:14:46 +0000 |
---|---|---|
committer | jimfly01 <jim.flynn@arm.com> | 2018-12-20 14:19:25 +0000 |
commit | 84c70e65a193aa5faa959d305af82783fa8f78b5 (patch) | |
tree | 4e22d949b863c21c6bb7aa1bb92e8465a337dd79 /src/armnnTfParser/TfParser.cpp | |
parent | 4fa0916386e720e254bee6b9fd1576e90ba6a42f (diff) | |
download | armnn-84c70e65a193aa5faa959d305af82783fa8f78b5.tar.gz |
IVGCVSW-2367 Add Equal Operator to TfParser
* Unit tests in Equal.cpp
* Fixed error in Network::AddEqualLayer
* Refactored TfParser::Minimum/Equal to get rid of duplicate code
Change-Id: I0ed6f888eb391c995b88be20dc0c1b916dd14c3c
Diffstat (limited to 'src/armnnTfParser/TfParser.cpp')
-rw-r--r-- | src/armnnTfParser/TfParser.cpp | 48 |
1 files changed, 40 insertions, 8 deletions
diff --git a/src/armnnTfParser/TfParser.cpp b/src/armnnTfParser/TfParser.cpp index 74742a97b3..b646437f36 100644 --- a/src/armnnTfParser/TfParser.cpp +++ b/src/armnnTfParser/TfParser.cpp @@ -353,6 +353,7 @@ const std::map<std::string, TfParser::OperationParsingFunction> TfParser::ms_Ope { "AvgPool", &TfParser::ParseAvgPool }, { "Maximum", &TfParser::ParseMaximum }, { "Minimum", &TfParser::ParseMinimum }, + { "Equal", &TfParser::ParseEqual }, { "Pad", &TfParser::ParsePad }, { "Sub", &TfParser::ParseSub }, }; @@ -1530,8 +1531,8 @@ ParsedTfOperationPtr TfParser::ParseMaximum(const tensorflow::NodeDef& nodeDef, } } -ParsedTfOperationPtr TfParser::ParseMinimum(const tensorflow::NodeDef& nodeDef, - const tensorflow::GraphDef& graphDef) +std::pair<armnn::IOutputSlot*, armnn::IOutputSlot*> TfParser::ProcessElementwiseInputSlots( + const tensorflow::NodeDef& nodeDef, const std::string& layerName) { std::vector<OutputOfParsedTfOperation> inputs = GetInputParsedTfOperationsChecked(nodeDef, 2); @@ -1555,15 +1556,22 @@ ParsedTfOperationPtr TfParser::ParseMinimum(const tensorflow::NodeDef& nodeDef, else { throw ParseException( - boost::str( - boost::format("Unsupported broadcast configuration for Minimum operation %1% %2%") - % nodeDef.name() - % CHECK_LOCATION().AsString())); + boost::str( + boost::format("Unsupported broadcast configuration for %1% operation %2% %3%") + % layerName + % nodeDef.name() + % CHECK_LOCATION().AsString())); } } + return {input0Slot, input1Slot}; +} - IConnectableLayer* const layer = m_Network->AddMinimumLayer(nodeDef.name().c_str()); - +ParsedTfOperationPtr TfParser::ProcessElementwiseLayer( + IOutputSlot* input0Slot, + IOutputSlot* input1Slot, + IConnectableLayer* const layer, + const tensorflow::NodeDef& nodeDef) +{ input0Slot->Connect(layer->GetInputSlot(0)); input1Slot->Connect(layer->GetInputSlot(1)); @@ -1584,6 +1592,30 @@ ParsedTfOperationPtr TfParser::ParseMinimum(const tensorflow::NodeDef& nodeDef, return std::make_unique<SingleLayerParsedTfOperation>(this, nodeDef, layer); } +ParsedTfOperationPtr TfParser::ParseEqual(const tensorflow::NodeDef& nodeDef, + const tensorflow::GraphDef& graphDef) +{ + std::pair<armnn::IOutputSlot*, armnn::IOutputSlot*> inputLayers = ProcessElementwiseInputSlots(nodeDef, "Equal"); + IOutputSlot* input0Slot = inputLayers.first; + IOutputSlot* input1Slot = inputLayers.second; + + IConnectableLayer* const layer = m_Network->AddEqualLayer(nodeDef.name().c_str()); + + return ProcessElementwiseLayer(input0Slot, input1Slot, layer, nodeDef); +} + +ParsedTfOperationPtr TfParser::ParseMinimum(const tensorflow::NodeDef& nodeDef, + const tensorflow::GraphDef& graphDef) +{ + std::pair<armnn::IOutputSlot*, armnn::IOutputSlot*> inputLayers = ProcessElementwiseInputSlots(nodeDef, "Minimum"); + IOutputSlot* input0Slot = inputLayers.first; + IOutputSlot* input1Slot = inputLayers.second; + + IConnectableLayer* const layer = m_Network->AddMinimumLayer(nodeDef.name().c_str()); + + return ProcessElementwiseLayer(input0Slot, input1Slot, layer, nodeDef); +} + ParsedTfOperationPtr TfParser::ParseSub(const tensorflow::NodeDef& nodeDef, const tensorflow::GraphDef& graphDef) { std::vector<OutputOfParsedTfOperation> inputs = GetInputParsedTfOperationsChecked(nodeDef, 2); |