diff options
author | Georgios Pinitas <georgios.pinitas@arm.com> | 2020-02-14 14:46:51 +0000 |
---|---|---|
committer | Derek Lamberti <derek.lamberti@arm.com> | 2020-02-18 10:10:07 +0000 |
commit | 5e90aab1cc25681c3e02b4d4436c24ee43400e91 (patch) | |
tree | 1121284177ec93b8e8e8ac07e98209639a562a8f /src/armnnTfParser/TfParser.cpp | |
parent | 0c2eeac6347533a1d3d456aebea492f5123388f3 (diff) | |
download | armnn-5e90aab1cc25681c3e02b4d4436c24ee43400e91.tar.gz |
COMPMID-3059: Add TF parser support for StridedSlice
Signed-off-by: Georgios Pinitas <georgios.pinitas@arm.com>
Change-Id: I31f25f26a50c9054b5650b1be127c84194b56be7
Diffstat (limited to 'src/armnnTfParser/TfParser.cpp')
-rwxr-xr-x | src/armnnTfParser/TfParser.cpp | 49 |
1 files changed, 49 insertions, 0 deletions
diff --git a/src/armnnTfParser/TfParser.cpp b/src/armnnTfParser/TfParser.cpp index af86619249..d65af2365b 100755 --- a/src/armnnTfParser/TfParser.cpp +++ b/src/armnnTfParser/TfParser.cpp @@ -368,6 +368,7 @@ const std::map<std::string, TfParser::OperationParsingFunction> TfParser::ms_Ope { "Softmax", &TfParser::ParseSoftmax }, { "Softplus", &TfParser::ParseSoftplus }, { "Split", &TfParser::ParseSplit }, + { "StridedSlice", &TfParser::ParseStridedSlice }, { "Tanh", &TfParser::ParseTanh }, { "MaxPool", &TfParser::ParseMaxPool }, { "AvgPool", &TfParser::ParseAvgPool }, @@ -2760,6 +2761,54 @@ ParsedTfOperationPtr TfParser::ParseSoftplus(const tensorflow::NodeDef& nodeDef, return AddActivationLayer(nodeDef, activationDesc); } +ParsedTfOperationPtr TfParser::ParseStridedSlice(const tensorflow::NodeDef& nodeDef, + const tensorflow::GraphDef& graphDef) +{ + boost::ignore_unused(graphDef); + + std::vector<OutputOfConstNodeDef> nodes = GetTfInputNodes(nodeDef); + unsigned int numInputs = static_cast<unsigned int>(nodes.size()); + std::vector<OutputOfParsedTfOperation> inputs = GetInputParsedTfOperationsChecked(nodeDef, numInputs); + + ParsedConstTfOperation<int32_t>* beginNode = + boost::polymorphic_downcast<ParsedConstTfOperation<int32_t> *>(inputs[1].m_IndexedValue); + std::vector<int32_t> beginTensorData; + beginNode->GetConstTensor(beginTensorData); + + ParsedConstTfOperation<int32_t>* endNode = + boost::polymorphic_downcast<ParsedConstTfOperation<int32_t> *>(inputs[2].m_IndexedValue); + std::vector<int32_t> endTensorData; + endNode->GetConstTensor(endTensorData); + + ParsedConstTfOperation<int32_t>* stridesNode = + boost::polymorphic_downcast<ParsedConstTfOperation<int32_t> *>(inputs[3].m_IndexedValue); + std::vector<int32_t> stridesTensorData; + stridesNode->GetConstTensor(stridesTensorData); + + StridedSliceDescriptor desc; + desc.m_Begin = beginTensorData; + desc.m_End = endTensorData; + desc.m_Stride = stridesTensorData; + desc.m_BeginMask = ReadMandatoryNodeInt32Attribute(nodeDef, "begin_mask"); + desc.m_EndMask = ReadMandatoryNodeInt32Attribute(nodeDef, "end_mask"); + desc.m_EllipsisMask = ReadMandatoryNodeInt32Attribute(nodeDef, "ellipsis_mask"); + desc.m_NewAxisMask = ReadMandatoryNodeInt32Attribute(nodeDef, "new_axis_mask"); + desc.m_ShrinkAxisMask = ReadMandatoryNodeInt32Attribute(nodeDef, "shrink_axis_mask"); + desc.m_DataLayout = armnn::DataLayout::NHWC; + IConnectableLayer* const layer = m_Network->AddStridedSliceLayer(desc, nodeDef.name().c_str()); + + IOutputSlot& prevLayerSlot = inputs[0].m_IndexedValue->ResolveArmnnOutputSlot(inputs[0].m_Index); + TensorInfo inputTensorInfo = prevLayerSlot.GetTensorInfo(); + + TensorInfo outputTensorInfo; + CalculateStridedSliceOutputTensorInfo(inputTensorInfo, desc, outputTensorInfo); + + prevLayerSlot.Connect(layer->GetInputSlot(0)); + layer->GetOutputSlot(0).SetTensorInfo(outputTensorInfo); + + return std::make_unique<SingleLayerParsedTfOperation>(this, nodeDef, layer); +} + ParsedTfOperationPtr TfParser::ParseTanh(const tensorflow::NodeDef& nodeDef, const tensorflow::GraphDef& graphDef) { boost::ignore_unused(graphDef); |