diff options
Diffstat (limited to 'src/armnnTfParser/TfParser.cpp')
-rwxr-xr-x | src/armnnTfParser/TfParser.cpp | 49 |
1 files changed, 49 insertions, 0 deletions
diff --git a/src/armnnTfParser/TfParser.cpp b/src/armnnTfParser/TfParser.cpp index af86619249..d65af2365b 100755 --- a/src/armnnTfParser/TfParser.cpp +++ b/src/armnnTfParser/TfParser.cpp @@ -368,6 +368,7 @@ const std::map<std::string, TfParser::OperationParsingFunction> TfParser::ms_Ope { "Softmax", &TfParser::ParseSoftmax }, { "Softplus", &TfParser::ParseSoftplus }, { "Split", &TfParser::ParseSplit }, + { "StridedSlice", &TfParser::ParseStridedSlice }, { "Tanh", &TfParser::ParseTanh }, { "MaxPool", &TfParser::ParseMaxPool }, { "AvgPool", &TfParser::ParseAvgPool }, @@ -2760,6 +2761,54 @@ ParsedTfOperationPtr TfParser::ParseSoftplus(const tensorflow::NodeDef& nodeDef, return AddActivationLayer(nodeDef, activationDesc); } +ParsedTfOperationPtr TfParser::ParseStridedSlice(const tensorflow::NodeDef& nodeDef, + const tensorflow::GraphDef& graphDef) +{ + boost::ignore_unused(graphDef); + + std::vector<OutputOfConstNodeDef> nodes = GetTfInputNodes(nodeDef); + unsigned int numInputs = static_cast<unsigned int>(nodes.size()); + std::vector<OutputOfParsedTfOperation> inputs = GetInputParsedTfOperationsChecked(nodeDef, numInputs); + + ParsedConstTfOperation<int32_t>* beginNode = + boost::polymorphic_downcast<ParsedConstTfOperation<int32_t> *>(inputs[1].m_IndexedValue); + std::vector<int32_t> beginTensorData; + beginNode->GetConstTensor(beginTensorData); + + ParsedConstTfOperation<int32_t>* endNode = + boost::polymorphic_downcast<ParsedConstTfOperation<int32_t> *>(inputs[2].m_IndexedValue); + std::vector<int32_t> endTensorData; + endNode->GetConstTensor(endTensorData); + + ParsedConstTfOperation<int32_t>* stridesNode = + boost::polymorphic_downcast<ParsedConstTfOperation<int32_t> *>(inputs[3].m_IndexedValue); + std::vector<int32_t> stridesTensorData; + stridesNode->GetConstTensor(stridesTensorData); + + StridedSliceDescriptor desc; + desc.m_Begin = beginTensorData; + desc.m_End = endTensorData; + desc.m_Stride = stridesTensorData; + desc.m_BeginMask = ReadMandatoryNodeInt32Attribute(nodeDef, "begin_mask"); + desc.m_EndMask = ReadMandatoryNodeInt32Attribute(nodeDef, "end_mask"); + desc.m_EllipsisMask = ReadMandatoryNodeInt32Attribute(nodeDef, "ellipsis_mask"); + desc.m_NewAxisMask = ReadMandatoryNodeInt32Attribute(nodeDef, "new_axis_mask"); + desc.m_ShrinkAxisMask = ReadMandatoryNodeInt32Attribute(nodeDef, "shrink_axis_mask"); + desc.m_DataLayout = armnn::DataLayout::NHWC; + IConnectableLayer* const layer = m_Network->AddStridedSliceLayer(desc, nodeDef.name().c_str()); + + IOutputSlot& prevLayerSlot = inputs[0].m_IndexedValue->ResolveArmnnOutputSlot(inputs[0].m_Index); + TensorInfo inputTensorInfo = prevLayerSlot.GetTensorInfo(); + + TensorInfo outputTensorInfo; + CalculateStridedSliceOutputTensorInfo(inputTensorInfo, desc, outputTensorInfo); + + prevLayerSlot.Connect(layer->GetInputSlot(0)); + layer->GetOutputSlot(0).SetTensorInfo(outputTensorInfo); + + return std::make_unique<SingleLayerParsedTfOperation>(this, nodeDef, layer); +} + ParsedTfOperationPtr TfParser::ParseTanh(const tensorflow::NodeDef& nodeDef, const tensorflow::GraphDef& graphDef) { boost::ignore_unused(graphDef); |