aboutsummaryrefslogtreecommitdiff
path: root/src/armnnTfParser/TfParser.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/armnnTfParser/TfParser.cpp')
-rwxr-xr-xsrc/armnnTfParser/TfParser.cpp49
1 files changed, 49 insertions, 0 deletions
diff --git a/src/armnnTfParser/TfParser.cpp b/src/armnnTfParser/TfParser.cpp
index af86619249..d65af2365b 100755
--- a/src/armnnTfParser/TfParser.cpp
+++ b/src/armnnTfParser/TfParser.cpp
@@ -368,6 +368,7 @@ const std::map<std::string, TfParser::OperationParsingFunction> TfParser::ms_Ope
{ "Softmax", &TfParser::ParseSoftmax },
{ "Softplus", &TfParser::ParseSoftplus },
{ "Split", &TfParser::ParseSplit },
+ { "StridedSlice", &TfParser::ParseStridedSlice },
{ "Tanh", &TfParser::ParseTanh },
{ "MaxPool", &TfParser::ParseMaxPool },
{ "AvgPool", &TfParser::ParseAvgPool },
@@ -2760,6 +2761,54 @@ ParsedTfOperationPtr TfParser::ParseSoftplus(const tensorflow::NodeDef& nodeDef,
return AddActivationLayer(nodeDef, activationDesc);
}
+ParsedTfOperationPtr TfParser::ParseStridedSlice(const tensorflow::NodeDef& nodeDef,
+ const tensorflow::GraphDef& graphDef)
+{
+ boost::ignore_unused(graphDef);
+
+ std::vector<OutputOfConstNodeDef> nodes = GetTfInputNodes(nodeDef);
+ unsigned int numInputs = static_cast<unsigned int>(nodes.size());
+ std::vector<OutputOfParsedTfOperation> inputs = GetInputParsedTfOperationsChecked(nodeDef, numInputs);
+
+ ParsedConstTfOperation<int32_t>* beginNode =
+ boost::polymorphic_downcast<ParsedConstTfOperation<int32_t> *>(inputs[1].m_IndexedValue);
+ std::vector<int32_t> beginTensorData;
+ beginNode->GetConstTensor(beginTensorData);
+
+ ParsedConstTfOperation<int32_t>* endNode =
+ boost::polymorphic_downcast<ParsedConstTfOperation<int32_t> *>(inputs[2].m_IndexedValue);
+ std::vector<int32_t> endTensorData;
+ endNode->GetConstTensor(endTensorData);
+
+ ParsedConstTfOperation<int32_t>* stridesNode =
+ boost::polymorphic_downcast<ParsedConstTfOperation<int32_t> *>(inputs[3].m_IndexedValue);
+ std::vector<int32_t> stridesTensorData;
+ stridesNode->GetConstTensor(stridesTensorData);
+
+ StridedSliceDescriptor desc;
+ desc.m_Begin = beginTensorData;
+ desc.m_End = endTensorData;
+ desc.m_Stride = stridesTensorData;
+ desc.m_BeginMask = ReadMandatoryNodeInt32Attribute(nodeDef, "begin_mask");
+ desc.m_EndMask = ReadMandatoryNodeInt32Attribute(nodeDef, "end_mask");
+ desc.m_EllipsisMask = ReadMandatoryNodeInt32Attribute(nodeDef, "ellipsis_mask");
+ desc.m_NewAxisMask = ReadMandatoryNodeInt32Attribute(nodeDef, "new_axis_mask");
+ desc.m_ShrinkAxisMask = ReadMandatoryNodeInt32Attribute(nodeDef, "shrink_axis_mask");
+ desc.m_DataLayout = armnn::DataLayout::NHWC;
+ IConnectableLayer* const layer = m_Network->AddStridedSliceLayer(desc, nodeDef.name().c_str());
+
+ IOutputSlot& prevLayerSlot = inputs[0].m_IndexedValue->ResolveArmnnOutputSlot(inputs[0].m_Index);
+ TensorInfo inputTensorInfo = prevLayerSlot.GetTensorInfo();
+
+ TensorInfo outputTensorInfo;
+ CalculateStridedSliceOutputTensorInfo(inputTensorInfo, desc, outputTensorInfo);
+
+ prevLayerSlot.Connect(layer->GetInputSlot(0));
+ layer->GetOutputSlot(0).SetTensorInfo(outputTensorInfo);
+
+ return std::make_unique<SingleLayerParsedTfOperation>(this, nodeDef, layer);
+}
+
ParsedTfOperationPtr TfParser::ParseTanh(const tensorflow::NodeDef& nodeDef, const tensorflow::GraphDef& graphDef)
{
boost::ignore_unused(graphDef);