From 5e90aab1cc25681c3e02b4d4436c24ee43400e91 Mon Sep 17 00:00:00 2001 From: Georgios Pinitas Date: Fri, 14 Feb 2020 14:46:51 +0000 Subject: COMPMID-3059: Add TF parser support for StridedSlice Signed-off-by: Georgios Pinitas Change-Id: I31f25f26a50c9054b5650b1be127c84194b56be7 --- CMakeLists.txt | 1 + src/armnnTfLiteParser/test/StridedSlice.cpp | 1 - src/armnnTfParser/TfParser.cpp | 49 +++++ src/armnnTfParser/TfParser.hpp | 1 + src/armnnTfParser/test/Gather.cpp | 7 +- src/armnnTfParser/test/StridedSlice.cpp | 283 ++++++++++++++++++++++++++++ src/armnnUtils/ParserHelper.cpp | 30 +++ src/armnnUtils/ParserHelper.hpp | 5 + 8 files changed, 373 insertions(+), 4 deletions(-) create mode 100644 src/armnnTfParser/test/StridedSlice.cpp diff --git a/CMakeLists.txt b/CMakeLists.txt index a9b1e64e4c..7ce9c42801 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -711,6 +711,7 @@ if(BUILD_UNIT_TESTS) src/armnnTfParser/test/Split.cpp src/armnnTfParser/test/Squeeze.cpp src/armnnTfParser/test/Sub.cpp + src/armnnTfParser/test/StridedSlice.cpp ) endif() diff --git a/src/armnnTfLiteParser/test/StridedSlice.cpp b/src/armnnTfLiteParser/test/StridedSlice.cpp index 74c77c049f..91427a6420 100644 --- a/src/armnnTfLiteParser/test/StridedSlice.cpp +++ b/src/armnnTfLiteParser/test/StridedSlice.cpp @@ -8,7 +8,6 @@ #include "../TfLiteParser.hpp" #include -#include BOOST_AUTO_TEST_SUITE(TensorflowLiteParser) diff --git a/src/armnnTfParser/TfParser.cpp b/src/armnnTfParser/TfParser.cpp index af86619249..d65af2365b 100755 --- a/src/armnnTfParser/TfParser.cpp +++ b/src/armnnTfParser/TfParser.cpp @@ -368,6 +368,7 @@ const std::map TfParser::ms_Ope { "Softmax", &TfParser::ParseSoftmax }, { "Softplus", &TfParser::ParseSoftplus }, { "Split", &TfParser::ParseSplit }, + { "StridedSlice", &TfParser::ParseStridedSlice }, { "Tanh", &TfParser::ParseTanh }, { "MaxPool", &TfParser::ParseMaxPool }, { "AvgPool", &TfParser::ParseAvgPool }, @@ -2760,6 +2761,54 @@ ParsedTfOperationPtr TfParser::ParseSoftplus(const tensorflow::NodeDef& nodeDef, return AddActivationLayer(nodeDef, activationDesc); } +ParsedTfOperationPtr TfParser::ParseStridedSlice(const tensorflow::NodeDef& nodeDef, + const tensorflow::GraphDef& graphDef) +{ + boost::ignore_unused(graphDef); + + std::vector nodes = GetTfInputNodes(nodeDef); + unsigned int numInputs = static_cast(nodes.size()); + std::vector inputs = GetInputParsedTfOperationsChecked(nodeDef, numInputs); + + ParsedConstTfOperation* beginNode = + boost::polymorphic_downcast *>(inputs[1].m_IndexedValue); + std::vector beginTensorData; + beginNode->GetConstTensor(beginTensorData); + + ParsedConstTfOperation* endNode = + boost::polymorphic_downcast *>(inputs[2].m_IndexedValue); + std::vector endTensorData; + endNode->GetConstTensor(endTensorData); + + ParsedConstTfOperation* stridesNode = + boost::polymorphic_downcast *>(inputs[3].m_IndexedValue); + std::vector stridesTensorData; + stridesNode->GetConstTensor(stridesTensorData); + + StridedSliceDescriptor desc; + desc.m_Begin = beginTensorData; + desc.m_End = endTensorData; + desc.m_Stride = stridesTensorData; + desc.m_BeginMask = ReadMandatoryNodeInt32Attribute(nodeDef, "begin_mask"); + desc.m_EndMask = ReadMandatoryNodeInt32Attribute(nodeDef, "end_mask"); + desc.m_EllipsisMask = ReadMandatoryNodeInt32Attribute(nodeDef, "ellipsis_mask"); + desc.m_NewAxisMask = ReadMandatoryNodeInt32Attribute(nodeDef, "new_axis_mask"); + desc.m_ShrinkAxisMask = ReadMandatoryNodeInt32Attribute(nodeDef, "shrink_axis_mask"); + desc.m_DataLayout = armnn::DataLayout::NHWC; + IConnectableLayer* const layer = m_Network->AddStridedSliceLayer(desc, nodeDef.name().c_str()); + + IOutputSlot& prevLayerSlot = inputs[0].m_IndexedValue->ResolveArmnnOutputSlot(inputs[0].m_Index); + TensorInfo inputTensorInfo = prevLayerSlot.GetTensorInfo(); + + TensorInfo outputTensorInfo; + CalculateStridedSliceOutputTensorInfo(inputTensorInfo, desc, outputTensorInfo); + + prevLayerSlot.Connect(layer->GetInputSlot(0)); + layer->GetOutputSlot(0).SetTensorInfo(outputTensorInfo); + + return std::make_unique(this, nodeDef, layer); +} + ParsedTfOperationPtr TfParser::ParseTanh(const tensorflow::NodeDef& nodeDef, const tensorflow::GraphDef& graphDef) { boost::ignore_unused(graphDef); diff --git a/src/armnnTfParser/TfParser.hpp b/src/armnnTfParser/TfParser.hpp index 8442ca0fb3..a7d02be33d 100644 --- a/src/armnnTfParser/TfParser.hpp +++ b/src/armnnTfParser/TfParser.hpp @@ -156,6 +156,7 @@ private: ParsedTfOperationPtr ParseSoftmax(const tensorflow::NodeDef& nodeDef, const tensorflow::GraphDef& graphDef); ParsedTfOperationPtr ParseSoftplus(const tensorflow::NodeDef& nodeDef, const tensorflow::GraphDef& graphDef); ParsedTfOperationPtr ParseSplit(const tensorflow::NodeDef& nodeDef, const tensorflow::GraphDef& graphDef); + ParsedTfOperationPtr ParseStridedSlice(const tensorflow::NodeDef& nodeDef, const tensorflow::GraphDef& graphDef); ParsedTfOperationPtr ParseTanh(const tensorflow::NodeDef& nodeDef, const tensorflow::GraphDef& graphDef); ParsedTfOperationPtr ParseMaxPool(const tensorflow::NodeDef& nodeDef, const tensorflow::GraphDef& graphDef); ParsedTfOperationPtr ParseAvgPool(const tensorflow::NodeDef& nodeDef, const tensorflow::GraphDef& graphDef); diff --git a/src/armnnTfParser/test/Gather.cpp b/src/armnnTfParser/test/Gather.cpp index f40dc57556..a6c20fd63e 100644 --- a/src/armnnTfParser/test/Gather.cpp +++ b/src/armnnTfParser/test/Gather.cpp @@ -12,9 +12,10 @@ BOOST_AUTO_TEST_SUITE(TensorflowParser) +namespace { // helper for setting the dimensions in prototxt void dimsHelper(const std::vector& dims, std::string& text){ - for(u_int i=0; i& dims, std::string& text){ // helper for converting from integer to octal representation void octalHelper(const std::vector& indicesContent, std::string& text){ - for (unsigned int i = 0; i < indicesContent.size(); ++i) - { + for(unsigned int i = 0; i < indicesContent.size(); ++i) { text.append(armnnUtils::ConvertInt32ToOctalString(static_cast(indicesContent[i]))); } } +} // namespace struct GatherFixture : public armnnUtils::ParserPrototxtFixture { diff --git a/src/armnnTfParser/test/StridedSlice.cpp b/src/armnnTfParser/test/StridedSlice.cpp new file mode 100644 index 0000000000..89faf75679 --- /dev/null +++ b/src/armnnTfParser/test/StridedSlice.cpp @@ -0,0 +1,283 @@ +// +// Copyright © 2020 Arm Ltd. All rights reserved. +// SPDX-License-Identifier: MIT +// + +#include "armnnTfParser/ITfParser.hpp" + +#include "ParserPrototxtFixture.hpp" +#include + +#include + +BOOST_AUTO_TEST_SUITE(TensorflowParser) + +namespace { +// helper for setting the dimensions in prototxt +void shapeHelper(const armnn::TensorShape& shape, std::string& text){ + for(u_int i = 0; i < shape.GetNumDimensions(); ++i) { + text.append(R"(dim { + size: )"); + text.append(std::to_string(shape[i])); + text.append(R"( + })"); + } +} + +// helper for converting from integer to octal representation +void octalHelper(const std::vector& content, std::string& text){ + for (unsigned int i = 0; i < content.size(); ++i) + { + text.append(armnnUtils::ConvertInt32ToOctalString(static_cast(content[i]))); + } +} +} // namespace + +struct StridedSliceFixture : public armnnUtils::ParserPrototxtFixture +{ + StridedSliceFixture(const armnn::TensorShape& inputShape, + const std::vector& beginData, + const std::vector& endData, + const std::vector& stridesData, + int beginMask = 0, + int endMask = 0, + int ellipsisMask = 0, + int newAxisMask = 0, + int shrinkAxisMask = 0) + { + m_Prototext = R"( + node { + name: "input" + op: "Placeholder" + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "shape" + value { + shape {)"; + shapeHelper(inputShape, m_Prototext); + m_Prototext.append(R"( + } + } + } + } + node { + name: "begin" + op: "Const" + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + dim { + size: )"); + m_Prototext += std::to_string(beginData.size()); + m_Prototext.append(R"( + } + } + tensor_content: ")"); + octalHelper(beginData, m_Prototext); + m_Prototext.append(R"(" + } + } + } + } + node { + name: "end" + op: "Const" + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + dim { + size: )"); + m_Prototext += std::to_string(endData.size()); + m_Prototext.append(R"( + } + } + tensor_content: ")"); + octalHelper(endData, m_Prototext); + m_Prototext.append(R"(" + } + } + } + } + node { + name: "strides" + op: "Const" + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + dim { + size: )"); + m_Prototext += std::to_string(stridesData.size()); + m_Prototext.append(R"( + } + } + tensor_content: ")"); + octalHelper(stridesData, m_Prototext); + m_Prototext.append(R"(" + } + } + } + } + node { + name: "output" + op: "StridedSlice" + input: "input" + input: "begin" + input: "end" + input: "strides" + attr { + key: "begin_mask" + value { + i: )"); + m_Prototext += std::to_string(beginMask); + m_Prototext.append(R"( + } + } + attr { + key: "end_mask" + value { + i: )"); + m_Prototext += std::to_string(endMask); + m_Prototext.append(R"( + } + } + attr { + key: "ellipsis_mask" + value { + i: )"); + m_Prototext += std::to_string(ellipsisMask); + m_Prototext.append(R"( + } + } + attr { + key: "new_axis_mask" + value { + i: )"); + m_Prototext += std::to_string(newAxisMask); + m_Prototext.append(R"( + } + } + attr { + key: "shrink_axis_mask" + value { + i: )"); + m_Prototext += std::to_string(shrinkAxisMask); + m_Prototext.append(R"( + } + } + })"); + + Setup({ { "input", inputShape } }, { "output" }); + } +}; + +struct StridedSlice4DFixture : StridedSliceFixture +{ + StridedSlice4DFixture() : StridedSliceFixture({ 3, 2, 3, 1 }, // inputShape + { 1, 0, 0, 0 }, // beginData + { 2, 2, 3, 1 }, // endData + { 1, 1, 1, 1 } // stridesData + ) {} +}; + +BOOST_FIXTURE_TEST_CASE(StridedSlice4D, StridedSlice4DFixture) +{ + RunTest<4>( + {{"input", { 1.0f, 1.0f, 1.0f, 2.0f, 2.0f, 2.0f, + 3.0f, 3.0f, 3.0f, 4.0f, 4.0f, 4.0f, + 5.0f, 5.0f, 5.0f, 6.0f, 6.0f, 6.0f }}}, + {{"output", { 3.0f, 3.0f, 3.0f, 4.0f, 4.0f, 4.0f }}}); +} + +struct StridedSlice4DReverseFixture : StridedSliceFixture +{ + + StridedSlice4DReverseFixture() : StridedSliceFixture({ 3, 2, 3, 1 }, // inputShape + { 1, -1, 0, 0 }, // beginData + { 2, -3, 3, 1 }, // endData + { 1, -1, 1, 1 } // stridesData + ) {} +}; + +BOOST_FIXTURE_TEST_CASE(StridedSlice4DReverse, StridedSlice4DReverseFixture) +{ + RunTest<4>( + {{"input", { 1.0f, 1.0f, 1.0f, 2.0f, 2.0f, 2.0f, + 3.0f, 3.0f, 3.0f, 4.0f, 4.0f, 4.0f, + 5.0f, 5.0f, 5.0f, 6.0f, 6.0f, 6.0f }}}, + {{"output", { 4.0f, 4.0f, 4.0f, 3.0f, 3.0f, 3.0f }}}); +} + +struct StridedSliceSimpleStrideFixture : StridedSliceFixture +{ + StridedSliceSimpleStrideFixture() : StridedSliceFixture({ 3, 2, 3, 1 }, // inputShape + { 0, 0, 0, 0 }, // beginData + { 3, 2, 3, 1 }, // endData + { 2, 2, 2, 1 } // stridesData + ) {} +}; + +BOOST_FIXTURE_TEST_CASE(StridedSliceSimpleStride, StridedSliceSimpleStrideFixture) +{ + RunTest<4>( + {{"input", { 1.0f, 1.0f, 1.0f, 2.0f, 2.0f, 2.0f, + 3.0f, 3.0f, 3.0f, 4.0f, 4.0f, 4.0f, + 5.0f, 5.0f, 5.0f, 6.0f, 6.0f, 6.0f }}}, + {{"output", { 1.0f, 1.0f, + 5.0f, 5.0f }}}); +} + +struct StridedSliceSimpleRangeMaskFixture : StridedSliceFixture +{ + StridedSliceSimpleRangeMaskFixture() : StridedSliceFixture({ 3, 2, 3, 1 }, // inputShape + { 1, 1, 1, 1 }, // beginData + { 1, 1, 1, 1 }, // endData + { 1, 1, 1, 1 }, // stridesData + (1 << 4) - 1, // beginMask + (1 << 4) - 1 // endMask + ) {} +}; + +BOOST_FIXTURE_TEST_CASE(StridedSliceSimpleRangeMask, StridedSliceSimpleRangeMaskFixture) +{ + RunTest<4>( + {{"input", { 1.0f, 1.0f, 1.0f, 2.0f, 2.0f, 2.0f, + 3.0f, 3.0f, 3.0f, 4.0f, 4.0f, 4.0f, + 5.0f, 5.0f, 5.0f, 6.0f, 6.0f, 6.0f }}}, + {{"output", { 1.0f, 1.0f, 1.0f, 2.0f, 2.0f, 2.0f, + 3.0f, 3.0f, 3.0f, 4.0f, 4.0f, 4.0f, + 5.0f, 5.0f, 5.0f, 6.0f, 6.0f, 6.0f }}}); +} + +BOOST_AUTO_TEST_SUITE_END() diff --git a/src/armnnUtils/ParserHelper.cpp b/src/armnnUtils/ParserHelper.cpp index ca6e42696e..9406553dff 100644 --- a/src/armnnUtils/ParserHelper.cpp +++ b/src/armnnUtils/ParserHelper.cpp @@ -101,4 +101,34 @@ void CalculateReducedOutputTensoInfo(const armnn::TensorInfo& inputTensorInfo, } } + +void CalculateStridedSliceOutputTensorInfo(const armnn::TensorInfo& inputTensorInfo, + const armnn::StridedSliceDescriptor& desc, + armnn::TensorInfo& outputTensorInfo) +{ + const armnn::TensorShape& inputShape = inputTensorInfo.GetShape(); + + std::vector outputShapeVector; + for (unsigned int i = 0; i < inputTensorInfo.GetNumDimensions(); i++) + { + if (desc.m_ShrinkAxisMask & (1 << i)) + { + continue; + } + + int stride = desc.m_Stride[i]; + int start = desc.GetStartForAxis(inputShape, i); + int stop = desc.GetStopForAxis(inputShape, i, start); + + int newSize = stride > 0 ? ((stop - start) + stride - 1) / stride : + ((start - stop) - stride - 1) / -stride; + + newSize = std::max(0, newSize); + + outputShapeVector.push_back(static_cast(newSize)); + } + + armnn::TensorShape outputTensorShape(inputTensorInfo.GetNumDimensions(), &outputShapeVector[0]); + outputTensorInfo = armnn::TensorInfo(armnn::TensorShape(outputTensorShape), inputTensorInfo.GetDataType()); +} } // namespace armnnUtils diff --git a/src/armnnUtils/ParserHelper.hpp b/src/armnnUtils/ParserHelper.hpp index d63408804d..28c7964ac1 100644 --- a/src/armnnUtils/ParserHelper.hpp +++ b/src/armnnUtils/ParserHelper.hpp @@ -25,4 +25,9 @@ void CalculateReducedOutputTensoInfo(const armnn::TensorInfo& inputTensorInfo, bool keepDims, armnn::TensorInfo& outputTensorInfo); +/// Create output tensor info for a StridedSlice operator +void CalculateStridedSliceOutputTensorInfo(const armnn::TensorInfo& inputTensorInfo, + const armnn::StridedSliceDescriptor& desc, + armnn::TensorInfo& outputTensorInfo); + } // namespace armnnUtils -- cgit v1.2.1