aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorGeorgios Pinitas <georgios.pinitas@arm.com>2020-02-14 14:46:51 +0000
committerDerek Lamberti <derek.lamberti@arm.com>2020-02-18 10:10:07 +0000
commit5e90aab1cc25681c3e02b4d4436c24ee43400e91 (patch)
tree1121284177ec93b8e8e8ac07e98209639a562a8f
parent0c2eeac6347533a1d3d456aebea492f5123388f3 (diff)
downloadarmnn-5e90aab1cc25681c3e02b4d4436c24ee43400e91.tar.gz
COMPMID-3059: Add TF parser support for StridedSlice
Signed-off-by: Georgios Pinitas <georgios.pinitas@arm.com> Change-Id: I31f25f26a50c9054b5650b1be127c84194b56be7
-rw-r--r--CMakeLists.txt1
-rw-r--r--src/armnnTfLiteParser/test/StridedSlice.cpp1
-rwxr-xr-xsrc/armnnTfParser/TfParser.cpp49
-rw-r--r--src/armnnTfParser/TfParser.hpp1
-rw-r--r--src/armnnTfParser/test/Gather.cpp7
-rw-r--r--src/armnnTfParser/test/StridedSlice.cpp283
-rw-r--r--src/armnnUtils/ParserHelper.cpp30
-rw-r--r--src/armnnUtils/ParserHelper.hpp5
8 files changed, 373 insertions, 4 deletions
diff --git a/CMakeLists.txt b/CMakeLists.txt
index a9b1e64e4c..7ce9c42801 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -711,6 +711,7 @@ if(BUILD_UNIT_TESTS)
src/armnnTfParser/test/Split.cpp
src/armnnTfParser/test/Squeeze.cpp
src/armnnTfParser/test/Sub.cpp
+ src/armnnTfParser/test/StridedSlice.cpp
)
endif()
diff --git a/src/armnnTfLiteParser/test/StridedSlice.cpp b/src/armnnTfLiteParser/test/StridedSlice.cpp
index 74c77c049f..91427a6420 100644
--- a/src/armnnTfLiteParser/test/StridedSlice.cpp
+++ b/src/armnnTfLiteParser/test/StridedSlice.cpp
@@ -8,7 +8,6 @@
#include "../TfLiteParser.hpp"
#include <string>
-#include <iostream>
BOOST_AUTO_TEST_SUITE(TensorflowLiteParser)
diff --git a/src/armnnTfParser/TfParser.cpp b/src/armnnTfParser/TfParser.cpp
index af86619249..d65af2365b 100755
--- a/src/armnnTfParser/TfParser.cpp
+++ b/src/armnnTfParser/TfParser.cpp
@@ -368,6 +368,7 @@ const std::map<std::string, TfParser::OperationParsingFunction> TfParser::ms_Ope
{ "Softmax", &TfParser::ParseSoftmax },
{ "Softplus", &TfParser::ParseSoftplus },
{ "Split", &TfParser::ParseSplit },
+ { "StridedSlice", &TfParser::ParseStridedSlice },
{ "Tanh", &TfParser::ParseTanh },
{ "MaxPool", &TfParser::ParseMaxPool },
{ "AvgPool", &TfParser::ParseAvgPool },
@@ -2760,6 +2761,54 @@ ParsedTfOperationPtr TfParser::ParseSoftplus(const tensorflow::NodeDef& nodeDef,
return AddActivationLayer(nodeDef, activationDesc);
}
+ParsedTfOperationPtr TfParser::ParseStridedSlice(const tensorflow::NodeDef& nodeDef,
+ const tensorflow::GraphDef& graphDef)
+{
+ boost::ignore_unused(graphDef);
+
+ std::vector<OutputOfConstNodeDef> nodes = GetTfInputNodes(nodeDef);
+ unsigned int numInputs = static_cast<unsigned int>(nodes.size());
+ std::vector<OutputOfParsedTfOperation> inputs = GetInputParsedTfOperationsChecked(nodeDef, numInputs);
+
+ ParsedConstTfOperation<int32_t>* beginNode =
+ boost::polymorphic_downcast<ParsedConstTfOperation<int32_t> *>(inputs[1].m_IndexedValue);
+ std::vector<int32_t> beginTensorData;
+ beginNode->GetConstTensor(beginTensorData);
+
+ ParsedConstTfOperation<int32_t>* endNode =
+ boost::polymorphic_downcast<ParsedConstTfOperation<int32_t> *>(inputs[2].m_IndexedValue);
+ std::vector<int32_t> endTensorData;
+ endNode->GetConstTensor(endTensorData);
+
+ ParsedConstTfOperation<int32_t>* stridesNode =
+ boost::polymorphic_downcast<ParsedConstTfOperation<int32_t> *>(inputs[3].m_IndexedValue);
+ std::vector<int32_t> stridesTensorData;
+ stridesNode->GetConstTensor(stridesTensorData);
+
+ StridedSliceDescriptor desc;
+ desc.m_Begin = beginTensorData;
+ desc.m_End = endTensorData;
+ desc.m_Stride = stridesTensorData;
+ desc.m_BeginMask = ReadMandatoryNodeInt32Attribute(nodeDef, "begin_mask");
+ desc.m_EndMask = ReadMandatoryNodeInt32Attribute(nodeDef, "end_mask");
+ desc.m_EllipsisMask = ReadMandatoryNodeInt32Attribute(nodeDef, "ellipsis_mask");
+ desc.m_NewAxisMask = ReadMandatoryNodeInt32Attribute(nodeDef, "new_axis_mask");
+ desc.m_ShrinkAxisMask = ReadMandatoryNodeInt32Attribute(nodeDef, "shrink_axis_mask");
+ desc.m_DataLayout = armnn::DataLayout::NHWC;
+ IConnectableLayer* const layer = m_Network->AddStridedSliceLayer(desc, nodeDef.name().c_str());
+
+ IOutputSlot& prevLayerSlot = inputs[0].m_IndexedValue->ResolveArmnnOutputSlot(inputs[0].m_Index);
+ TensorInfo inputTensorInfo = prevLayerSlot.GetTensorInfo();
+
+ TensorInfo outputTensorInfo;
+ CalculateStridedSliceOutputTensorInfo(inputTensorInfo, desc, outputTensorInfo);
+
+ prevLayerSlot.Connect(layer->GetInputSlot(0));
+ layer->GetOutputSlot(0).SetTensorInfo(outputTensorInfo);
+
+ return std::make_unique<SingleLayerParsedTfOperation>(this, nodeDef, layer);
+}
+
ParsedTfOperationPtr TfParser::ParseTanh(const tensorflow::NodeDef& nodeDef, const tensorflow::GraphDef& graphDef)
{
boost::ignore_unused(graphDef);
diff --git a/src/armnnTfParser/TfParser.hpp b/src/armnnTfParser/TfParser.hpp
index 8442ca0fb3..a7d02be33d 100644
--- a/src/armnnTfParser/TfParser.hpp
+++ b/src/armnnTfParser/TfParser.hpp
@@ -156,6 +156,7 @@ private:
ParsedTfOperationPtr ParseSoftmax(const tensorflow::NodeDef& nodeDef, const tensorflow::GraphDef& graphDef);
ParsedTfOperationPtr ParseSoftplus(const tensorflow::NodeDef& nodeDef, const tensorflow::GraphDef& graphDef);
ParsedTfOperationPtr ParseSplit(const tensorflow::NodeDef& nodeDef, const tensorflow::GraphDef& graphDef);
+ ParsedTfOperationPtr ParseStridedSlice(const tensorflow::NodeDef& nodeDef, const tensorflow::GraphDef& graphDef);
ParsedTfOperationPtr ParseTanh(const tensorflow::NodeDef& nodeDef, const tensorflow::GraphDef& graphDef);
ParsedTfOperationPtr ParseMaxPool(const tensorflow::NodeDef& nodeDef, const tensorflow::GraphDef& graphDef);
ParsedTfOperationPtr ParseAvgPool(const tensorflow::NodeDef& nodeDef, const tensorflow::GraphDef& graphDef);
diff --git a/src/armnnTfParser/test/Gather.cpp b/src/armnnTfParser/test/Gather.cpp
index f40dc57556..a6c20fd63e 100644
--- a/src/armnnTfParser/test/Gather.cpp
+++ b/src/armnnTfParser/test/Gather.cpp
@@ -12,9 +12,10 @@
BOOST_AUTO_TEST_SUITE(TensorflowParser)
+namespace {
// helper for setting the dimensions in prototxt
void dimsHelper(const std::vector<int>& dims, std::string& text){
- for(u_int i=0; i<dims.size(); ++i){
+ for(u_int i = 0; i < dims.size(); ++i) {
text.append(R"(dim {
size: )");
text.append(std::to_string(dims[i]));
@@ -25,11 +26,11 @@ void dimsHelper(const std::vector<int>& dims, std::string& text){
// helper for converting from integer to octal representation
void octalHelper(const std::vector<int>& indicesContent, std::string& text){
- for (unsigned int i = 0; i < indicesContent.size(); ++i)
- {
+ for(unsigned int i = 0; i < indicesContent.size(); ++i) {
text.append(armnnUtils::ConvertInt32ToOctalString(static_cast<int>(indicesContent[i])));
}
}
+} // namespace
struct GatherFixture : public armnnUtils::ParserPrototxtFixture<armnnTfParser::ITfParser>
{
diff --git a/src/armnnTfParser/test/StridedSlice.cpp b/src/armnnTfParser/test/StridedSlice.cpp
new file mode 100644
index 0000000000..89faf75679
--- /dev/null
+++ b/src/armnnTfParser/test/StridedSlice.cpp
@@ -0,0 +1,283 @@
+//
+// Copyright © 2020 Arm Ltd. All rights reserved.
+// SPDX-License-Identifier: MIT
+//
+
+#include "armnnTfParser/ITfParser.hpp"
+
+#include "ParserPrototxtFixture.hpp"
+#include <PrototxtConversions.hpp>
+
+#include <boost/test/unit_test.hpp>
+
+BOOST_AUTO_TEST_SUITE(TensorflowParser)
+
+namespace {
+// helper for setting the dimensions in prototxt
+void shapeHelper(const armnn::TensorShape& shape, std::string& text){
+ for(u_int i = 0; i < shape.GetNumDimensions(); ++i) {
+ text.append(R"(dim {
+ size: )");
+ text.append(std::to_string(shape[i]));
+ text.append(R"(
+ })");
+ }
+}
+
+// helper for converting from integer to octal representation
+void octalHelper(const std::vector<int>& content, std::string& text){
+ for (unsigned int i = 0; i < content.size(); ++i)
+ {
+ text.append(armnnUtils::ConvertInt32ToOctalString(static_cast<int>(content[i])));
+ }
+}
+} // namespace
+
+struct StridedSliceFixture : public armnnUtils::ParserPrototxtFixture<armnnTfParser::ITfParser>
+{
+ StridedSliceFixture(const armnn::TensorShape& inputShape,
+ const std::vector<int>& beginData,
+ const std::vector<int>& endData,
+ const std::vector<int>& stridesData,
+ int beginMask = 0,
+ int endMask = 0,
+ int ellipsisMask = 0,
+ int newAxisMask = 0,
+ int shrinkAxisMask = 0)
+ {
+ m_Prototext = R"(
+ node {
+ name: "input"
+ op: "Placeholder"
+ attr {
+ key: "dtype"
+ value {
+ type: DT_FLOAT
+ }
+ }
+ attr {
+ key: "shape"
+ value {
+ shape {)";
+ shapeHelper(inputShape, m_Prototext);
+ m_Prototext.append(R"(
+ }
+ }
+ }
+ }
+ node {
+ name: "begin"
+ op: "Const"
+ attr {
+ key: "dtype"
+ value {
+ type: DT_INT32
+ }
+ }
+ attr {
+ key: "value"
+ value {
+ tensor {
+ dtype: DT_INT32
+ tensor_shape {
+ dim {
+ size: )");
+ m_Prototext += std::to_string(beginData.size());
+ m_Prototext.append(R"(
+ }
+ }
+ tensor_content: ")");
+ octalHelper(beginData, m_Prototext);
+ m_Prototext.append(R"("
+ }
+ }
+ }
+ }
+ node {
+ name: "end"
+ op: "Const"
+ attr {
+ key: "dtype"
+ value {
+ type: DT_INT32
+ }
+ }
+ attr {
+ key: "value"
+ value {
+ tensor {
+ dtype: DT_INT32
+ tensor_shape {
+ dim {
+ size: )");
+ m_Prototext += std::to_string(endData.size());
+ m_Prototext.append(R"(
+ }
+ }
+ tensor_content: ")");
+ octalHelper(endData, m_Prototext);
+ m_Prototext.append(R"("
+ }
+ }
+ }
+ }
+ node {
+ name: "strides"
+ op: "Const"
+ attr {
+ key: "dtype"
+ value {
+ type: DT_INT32
+ }
+ }
+ attr {
+ key: "value"
+ value {
+ tensor {
+ dtype: DT_INT32
+ tensor_shape {
+ dim {
+ size: )");
+ m_Prototext += std::to_string(stridesData.size());
+ m_Prototext.append(R"(
+ }
+ }
+ tensor_content: ")");
+ octalHelper(stridesData, m_Prototext);
+ m_Prototext.append(R"("
+ }
+ }
+ }
+ }
+ node {
+ name: "output"
+ op: "StridedSlice"
+ input: "input"
+ input: "begin"
+ input: "end"
+ input: "strides"
+ attr {
+ key: "begin_mask"
+ value {
+ i: )");
+ m_Prototext += std::to_string(beginMask);
+ m_Prototext.append(R"(
+ }
+ }
+ attr {
+ key: "end_mask"
+ value {
+ i: )");
+ m_Prototext += std::to_string(endMask);
+ m_Prototext.append(R"(
+ }
+ }
+ attr {
+ key: "ellipsis_mask"
+ value {
+ i: )");
+ m_Prototext += std::to_string(ellipsisMask);
+ m_Prototext.append(R"(
+ }
+ }
+ attr {
+ key: "new_axis_mask"
+ value {
+ i: )");
+ m_Prototext += std::to_string(newAxisMask);
+ m_Prototext.append(R"(
+ }
+ }
+ attr {
+ key: "shrink_axis_mask"
+ value {
+ i: )");
+ m_Prototext += std::to_string(shrinkAxisMask);
+ m_Prototext.append(R"(
+ }
+ }
+ })");
+
+ Setup({ { "input", inputShape } }, { "output" });
+ }
+};
+
+struct StridedSlice4DFixture : StridedSliceFixture
+{
+ StridedSlice4DFixture() : StridedSliceFixture({ 3, 2, 3, 1 }, // inputShape
+ { 1, 0, 0, 0 }, // beginData
+ { 2, 2, 3, 1 }, // endData
+ { 1, 1, 1, 1 } // stridesData
+ ) {}
+};
+
+BOOST_FIXTURE_TEST_CASE(StridedSlice4D, StridedSlice4DFixture)
+{
+ RunTest<4>(
+ {{"input", { 1.0f, 1.0f, 1.0f, 2.0f, 2.0f, 2.0f,
+ 3.0f, 3.0f, 3.0f, 4.0f, 4.0f, 4.0f,
+ 5.0f, 5.0f, 5.0f, 6.0f, 6.0f, 6.0f }}},
+ {{"output", { 3.0f, 3.0f, 3.0f, 4.0f, 4.0f, 4.0f }}});
+}
+
+struct StridedSlice4DReverseFixture : StridedSliceFixture
+{
+
+ StridedSlice4DReverseFixture() : StridedSliceFixture({ 3, 2, 3, 1 }, // inputShape
+ { 1, -1, 0, 0 }, // beginData
+ { 2, -3, 3, 1 }, // endData
+ { 1, -1, 1, 1 } // stridesData
+ ) {}
+};
+
+BOOST_FIXTURE_TEST_CASE(StridedSlice4DReverse, StridedSlice4DReverseFixture)
+{
+ RunTest<4>(
+ {{"input", { 1.0f, 1.0f, 1.0f, 2.0f, 2.0f, 2.0f,
+ 3.0f, 3.0f, 3.0f, 4.0f, 4.0f, 4.0f,
+ 5.0f, 5.0f, 5.0f, 6.0f, 6.0f, 6.0f }}},
+ {{"output", { 4.0f, 4.0f, 4.0f, 3.0f, 3.0f, 3.0f }}});
+}
+
+struct StridedSliceSimpleStrideFixture : StridedSliceFixture
+{
+ StridedSliceSimpleStrideFixture() : StridedSliceFixture({ 3, 2, 3, 1 }, // inputShape
+ { 0, 0, 0, 0 }, // beginData
+ { 3, 2, 3, 1 }, // endData
+ { 2, 2, 2, 1 } // stridesData
+ ) {}
+};
+
+BOOST_FIXTURE_TEST_CASE(StridedSliceSimpleStride, StridedSliceSimpleStrideFixture)
+{
+ RunTest<4>(
+ {{"input", { 1.0f, 1.0f, 1.0f, 2.0f, 2.0f, 2.0f,
+ 3.0f, 3.0f, 3.0f, 4.0f, 4.0f, 4.0f,
+ 5.0f, 5.0f, 5.0f, 6.0f, 6.0f, 6.0f }}},
+ {{"output", { 1.0f, 1.0f,
+ 5.0f, 5.0f }}});
+}
+
+struct StridedSliceSimpleRangeMaskFixture : StridedSliceFixture
+{
+ StridedSliceSimpleRangeMaskFixture() : StridedSliceFixture({ 3, 2, 3, 1 }, // inputShape
+ { 1, 1, 1, 1 }, // beginData
+ { 1, 1, 1, 1 }, // endData
+ { 1, 1, 1, 1 }, // stridesData
+ (1 << 4) - 1, // beginMask
+ (1 << 4) - 1 // endMask
+ ) {}
+};
+
+BOOST_FIXTURE_TEST_CASE(StridedSliceSimpleRangeMask, StridedSliceSimpleRangeMaskFixture)
+{
+ RunTest<4>(
+ {{"input", { 1.0f, 1.0f, 1.0f, 2.0f, 2.0f, 2.0f,
+ 3.0f, 3.0f, 3.0f, 4.0f, 4.0f, 4.0f,
+ 5.0f, 5.0f, 5.0f, 6.0f, 6.0f, 6.0f }}},
+ {{"output", { 1.0f, 1.0f, 1.0f, 2.0f, 2.0f, 2.0f,
+ 3.0f, 3.0f, 3.0f, 4.0f, 4.0f, 4.0f,
+ 5.0f, 5.0f, 5.0f, 6.0f, 6.0f, 6.0f }}});
+}
+
+BOOST_AUTO_TEST_SUITE_END()
diff --git a/src/armnnUtils/ParserHelper.cpp b/src/armnnUtils/ParserHelper.cpp
index ca6e42696e..9406553dff 100644
--- a/src/armnnUtils/ParserHelper.cpp
+++ b/src/armnnUtils/ParserHelper.cpp
@@ -101,4 +101,34 @@ void CalculateReducedOutputTensoInfo(const armnn::TensorInfo& inputTensorInfo,
}
}
+
+void CalculateStridedSliceOutputTensorInfo(const armnn::TensorInfo& inputTensorInfo,
+ const armnn::StridedSliceDescriptor& desc,
+ armnn::TensorInfo& outputTensorInfo)
+{
+ const armnn::TensorShape& inputShape = inputTensorInfo.GetShape();
+
+ std::vector<unsigned int> outputShapeVector;
+ for (unsigned int i = 0; i < inputTensorInfo.GetNumDimensions(); i++)
+ {
+ if (desc.m_ShrinkAxisMask & (1 << i))
+ {
+ continue;
+ }
+
+ int stride = desc.m_Stride[i];
+ int start = desc.GetStartForAxis(inputShape, i);
+ int stop = desc.GetStopForAxis(inputShape, i, start);
+
+ int newSize = stride > 0 ? ((stop - start) + stride - 1) / stride :
+ ((start - stop) - stride - 1) / -stride;
+
+ newSize = std::max(0, newSize);
+
+ outputShapeVector.push_back(static_cast<unsigned int>(newSize));
+ }
+
+ armnn::TensorShape outputTensorShape(inputTensorInfo.GetNumDimensions(), &outputShapeVector[0]);
+ outputTensorInfo = armnn::TensorInfo(armnn::TensorShape(outputTensorShape), inputTensorInfo.GetDataType());
+}
} // namespace armnnUtils
diff --git a/src/armnnUtils/ParserHelper.hpp b/src/armnnUtils/ParserHelper.hpp
index d63408804d..28c7964ac1 100644
--- a/src/armnnUtils/ParserHelper.hpp
+++ b/src/armnnUtils/ParserHelper.hpp
@@ -25,4 +25,9 @@ void CalculateReducedOutputTensoInfo(const armnn::TensorInfo& inputTensorInfo,
bool keepDims,
armnn::TensorInfo& outputTensorInfo);
+/// Create output tensor info for a StridedSlice operator
+void CalculateStridedSliceOutputTensorInfo(const armnn::TensorInfo& inputTensorInfo,
+ const armnn::StridedSliceDescriptor& desc,
+ armnn::TensorInfo& outputTensorInfo);
+
} // namespace armnnUtils