aboutsummaryrefslogtreecommitdiff
path: root/src/armnnTfParser/TfParser.cpp
diff options
context:
space:
mode:
authorjimfly01 <jim.flynn@arm.com>2018-12-18 16:24:51 +0000
committerJim Flynn Arm <jim.flynn@arm.com>2018-12-20 17:37:20 +0000
commita06bf31afabfb84e60740ea3219406ab13c8e6a6 (patch)
tree45c5b05799be32ec3c70ca3d25b3643a8baf3f0b /src/armnnTfParser/TfParser.cpp
parentf446432f4c21a64ffb92552c5e1906194fb98558 (diff)
downloadarmnn-a06bf31afabfb84e60740ea3219406ab13c8e6a6.tar.gz
IVGCVSW-2380 Add Greater operator to TfParser
* Unit tests in Greater.cpp Change-Id: Ifb3e4c33be2d6235e33889bb63e6abd78bd7d8b6
Diffstat (limited to 'src/armnnTfParser/TfParser.cpp')
-rw-r--r--src/armnnTfParser/TfParser.cpp13
1 files changed, 13 insertions, 0 deletions
diff --git a/src/armnnTfParser/TfParser.cpp b/src/armnnTfParser/TfParser.cpp
index b646437f36..45c039bb15 100644
--- a/src/armnnTfParser/TfParser.cpp
+++ b/src/armnnTfParser/TfParser.cpp
@@ -333,6 +333,7 @@ const std::map<std::string, TfParser::OperationParsingFunction> TfParser::ms_Ope
{ "DepthwiseConv2dNative", &TfParser::ParseDepthwiseConv2D },
{ "ExpandDims", &TfParser::ParseExpandDims },
{ "FusedBatchNorm", &TfParser::ParseFusedBatchNorm },
+ { "Greater", &TfParser::ParseGreater},
{ "ConcatV2", &TfParser::ParseConcat },
{ "LRN", &TfParser::ParseLrn },
{ "MatMul", &TfParser::ParseMatMul },
@@ -1592,6 +1593,18 @@ ParsedTfOperationPtr TfParser::ProcessElementwiseLayer(
return std::make_unique<SingleLayerParsedTfOperation>(this, nodeDef, layer);
}
+ParsedTfOperationPtr TfParser::ParseGreater(const tensorflow::NodeDef& nodeDef,
+ const tensorflow::GraphDef& graphDef)
+{
+ std::pair<armnn::IOutputSlot*, armnn::IOutputSlot*> inputLayers = ProcessElementwiseInputSlots(nodeDef, "Greater");
+ IOutputSlot* input0Slot = inputLayers.first;
+ IOutputSlot* input1Slot = inputLayers.second;
+
+ IConnectableLayer* const layer = m_Network->AddGreaterLayer(nodeDef.name().c_str());
+
+ return ProcessElementwiseLayer(input0Slot, input1Slot, layer, nodeDef);
+}
+
ParsedTfOperationPtr TfParser::ParseEqual(const tensorflow::NodeDef& nodeDef,
const tensorflow::GraphDef& graphDef)
{