From 451d95bff7d8cd2d759fbb1b31bbeddd9c332360 Mon Sep 17 00:00:00 2001 From: Bruno Goncalves Date: Tue, 12 Feb 2019 22:59:22 -0200 Subject: Add strided-slice parser to tf-lite Change-Id: I1821d7e8123c76823562dd2e8822c5293fcb18c3 Signed-off-by: Bruno Goncalves --- CMakeLists.txt | 1 + src/armnnTfLiteParser/TfLiteParser.cpp | 57 ++++++++ src/armnnTfLiteParser/TfLiteParser.hpp | 1 + src/armnnTfLiteParser/test/StridedSlice.cpp | 218 ++++++++++++++++++++++++++++ 4 files changed, 277 insertions(+) create mode 100644 src/armnnTfLiteParser/test/StridedSlice.cpp diff --git a/CMakeLists.txt b/CMakeLists.txt index 4cdce75196..18c0ca2f35 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -561,6 +561,7 @@ if(BUILD_UNIT_TESTS) src/armnnTfLiteParser/test/SpaceToBatchND.cpp src/armnnTfLiteParser/test/Sub.cpp src/armnnTfLiteParser/test/Squeeze.cpp + src/armnnTfLiteParser/test/StridedSlice.cpp src/armnnTfLiteParser/test/LoadModel.cpp src/armnnTfLiteParser/test/GetBuffer.cpp src/armnnTfLiteParser/test/OutputShapeOfSqueeze.cpp diff --git a/src/armnnTfLiteParser/TfLiteParser.cpp b/src/armnnTfLiteParser/TfLiteParser.cpp index 521c5db299..bc6316ce7e 100644 --- a/src/armnnTfLiteParser/TfLiteParser.cpp +++ b/src/armnnTfLiteParser/TfLiteParser.cpp @@ -436,6 +436,7 @@ TfLiteParser::TfLiteParser() m_ParserFunctions[tflite::BuiltinOperator_SOFTMAX] = &TfLiteParser::ParseSoftmax; m_ParserFunctions[tflite::BuiltinOperator_SPACE_TO_BATCH_ND] = &TfLiteParser::ParseSpaceToBatchND; m_ParserFunctions[tflite::BuiltinOperator_SQUEEZE] = &TfLiteParser::ParseSqueeze; + m_ParserFunctions[tflite::BuiltinOperator_STRIDED_SLICE] = &TfLiteParser::ParseStridedSlice; m_ParserFunctions[tflite::BuiltinOperator_SUB] = &TfLiteParser::ParseSub; m_ParserFunctions[tflite::BuiltinOperator_ADD] = &TfLiteParser::ParseAdd; m_ParserFunctions[tflite::BuiltinOperator_MUL] = &TfLiteParser::ParseMul; @@ -1191,6 +1192,62 @@ void TfLiteParser::ParseSqueeze(size_t subgraphIndex, size_t operatorIndex) RegisterOutputSlots(subgraphIndex, operatorIndex, layer, {outputTensorIndexes[0]}); } +void TfLiteParser::ParseStridedSlice(size_t subgraphIndex, size_t operatorIndex) +{ + CHECK_MODEL(m_Model, subgraphIndex, operatorIndex); + + auto inputs = GetInputs(m_Model, subgraphIndex, operatorIndex); + CHECK_VALID_SIZE(inputs.size(), 4); + + auto outputs = GetOutputs(m_Model, subgraphIndex, operatorIndex); + CHECK_VALID_SIZE(outputs.size(), 1); + + const auto & operatorPtr = m_Model->subgraphs[subgraphIndex]->operators[operatorIndex]; + const auto * options = operatorPtr->builtin_options.AsStridedSliceOptions(); + + StridedSliceDescriptor desc; + desc.m_BeginMask = options->begin_mask; + desc.m_EllipsisMask = options->ellipsis_mask; + desc.m_EndMask = options->end_mask; + desc.m_NewAxisMask = options->new_axis_mask; + desc.m_ShrinkAxisMask = options->shrink_axis_mask; + desc.m_DataLayout = armnn::DataLayout::NHWC; + + armnn::TensorInfo beginTensorInfo = ToTensorInfo(inputs[1]); + BufferRawPtr beginBufferPtr = GetBuffer(m_Model, inputs[1]->buffer); + + std::vector begin(beginTensorInfo.GetNumElements()); + ::memcpy(begin.data(), beginBufferPtr->data.data(), beginTensorInfo.GetNumBytes()); + + armnn::TensorInfo endTensorInfo = ToTensorInfo(inputs[2]); + BufferRawPtr endBufferPtr = GetBuffer(m_Model, inputs[2]->buffer); + + std::vector end(endTensorInfo.GetNumElements()); + ::memcpy(end.data(), endBufferPtr->data.data(), endTensorInfo.GetNumBytes()); + + armnn::TensorInfo strideTensorInfo = ToTensorInfo(inputs[3]); + BufferRawPtr strideBufferPtr = GetBuffer(m_Model, inputs[3]->buffer); + + std::vector stride(strideTensorInfo.GetNumElements()); + ::memcpy(stride.data(), strideBufferPtr->data.data(), strideTensorInfo.GetNumBytes()); + + desc.m_Begin = begin; + desc.m_End = end; + desc.m_Stride = stride; + + auto layerName = boost::str(boost::format("StridedSlice:%1%:%2%") % subgraphIndex % operatorIndex); + IConnectableLayer* layer = m_Network->AddStridedSliceLayer(desc, layerName.c_str()); + + armnn::TensorInfo outputTensorInfo = ToTensorInfo(outputs[0]); + layer->GetOutputSlot(0).SetTensorInfo(outputTensorInfo); + + auto inputTensorIndexes = AsUnsignedVector(GetInputTensorIds(m_Model, subgraphIndex, operatorIndex)); + RegisterInputSlots(subgraphIndex, operatorIndex, layer, {inputTensorIndexes[0]}); + + auto outputTensorIndexes = AsUnsignedVector(GetOutputTensorIds(m_Model, subgraphIndex, operatorIndex)); + RegisterOutputSlots(subgraphIndex, operatorIndex, layer, {outputTensorIndexes[0]}); +} + void TfLiteParser::ParseSub(size_t subgraphIndex, size_t operatorIndex) { CHECK_MODEL(m_Model, subgraphIndex, operatorIndex); diff --git a/src/armnnTfLiteParser/TfLiteParser.hpp b/src/armnnTfLiteParser/TfLiteParser.hpp index 69403769e6..3fe4809aa2 100644 --- a/src/armnnTfLiteParser/TfLiteParser.hpp +++ b/src/armnnTfLiteParser/TfLiteParser.hpp @@ -109,6 +109,7 @@ private: void ParseSoftmax(size_t subgraphIndex, size_t operatorIndex); void ParseSpaceToBatchND(size_t subgraphIndex, size_t operatorIndex); void ParseSqueeze(size_t subgraphIndex, size_t operatorIndex); + void ParseStridedSlice(size_t subgraphIndex, size_t operatorIndex); void ParseSub(size_t subgraphIndex, size_t operatorIndex); void ParseAdd(size_t subgraphIndex, size_t operatorIndex); void ParseMul(size_t subgraphIndex, size_t operatorIndex); diff --git a/src/armnnTfLiteParser/test/StridedSlice.cpp b/src/armnnTfLiteParser/test/StridedSlice.cpp new file mode 100644 index 0000000000..74c77c049f --- /dev/null +++ b/src/armnnTfLiteParser/test/StridedSlice.cpp @@ -0,0 +1,218 @@ +// +// Copyright © 2017 Arm Ltd. All rights reserved. +// SPDX-License-Identifier: MIT +// + +#include +#include "ParserFlatbuffersFixture.hpp" +#include "../TfLiteParser.hpp" + +#include +#include + +BOOST_AUTO_TEST_SUITE(TensorflowLiteParser) + +struct StridedSliceFixture : public ParserFlatbuffersFixture +{ + explicit StridedSliceFixture(const std::string & inputShape, + const std::string & outputShape, + const std::string & beginData, + const std::string & endData, + const std::string & stridesData, + int beginMask = 0, + int endMask = 0) + { + m_JsonString = R"( + { + "version": 3, + "operator_codes": [ { "builtin_code": "STRIDED_SLICE" } ], + "subgraphs": [ { + "tensors": [ + { + "shape": )" + inputShape + R"(, + "type": "FLOAT32", + "buffer": 0, + "name": "inputTensor", + "quantization": { + "min": [ 0.0 ], + "max": [ 255.0 ], + "scale": [ 1.0 ], + "zero_point": [ 0 ], + } + }, + { + "shape": [ 4 ], + "type": "INT32", + "buffer": 1, + "name": "beginTensor", + "quantization": { + } + }, + { + "shape": [ 4 ], + "type": "INT32", + "buffer": 2, + "name": "endTensor", + "quantization": { + } + }, + { + "shape": [ 4 ], + "type": "INT32", + "buffer": 3, + "name": "stridesTensor", + "quantization": { + } + }, + { + "shape": )" + outputShape + R"( , + "type": "FLOAT32", + "buffer": 4, + "name": "outputTensor", + "quantization": { + "min": [ 0.0 ], + "max": [ 255.0 ], + "scale": [ 1.0 ], + "zero_point": [ 0 ], + } + } + ], + "inputs": [ 0, 1, 2, 3 ], + "outputs": [ 4 ], + "operators": [ + { + "opcode_index": 0, + "inputs": [ 0, 1, 2, 3 ], + "outputs": [ 4 ], + "builtin_options_type": "StridedSliceOptions", + "builtin_options": { + "begin_mask": )" + std::to_string(beginMask) + R"(, + "end_mask": )" + std::to_string(endMask) + R"( + }, + "custom_options_format": "FLEXBUFFERS" + } + ], + } ], + "buffers" : [ + { }, + { "data": )" + beginData + R"(, }, + { "data": )" + endData + R"(, }, + { "data": )" + stridesData + R"(, }, + { } + ] + } + )"; + Setup(); + } +}; + +struct StridedSlice4DFixture : StridedSliceFixture +{ + StridedSlice4DFixture() : StridedSliceFixture("[ 3, 2, 3, 1 ]", // inputShape + "[ 1, 2, 3, 1 ]", // outputShape + "[ 1,0,0,0, 0,0,0,0, 0,0,0,0, 0,0,0,0 ]", // beginData + "[ 2,0,0,0, 2,0,0,0, 3,0,0,0, 1,0,0,0 ]", // endData + "[ 1,0,0,0, 1,0,0,0, 1,0,0,0, 1,0,0,0 ]" // stridesData + ) {} +}; + +BOOST_FIXTURE_TEST_CASE(StridedSlice4D, StridedSlice4DFixture) +{ + RunTest<4, armnn::DataType::Float32>( + 0, + {{"inputTensor", { 1.0f, 1.0f, 1.0f, 2.0f, 2.0f, 2.0f, + + 3.0f, 3.0f, 3.0f, 4.0f, 4.0f, 4.0f, + + 5.0f, 5.0f, 5.0f, 6.0f, 6.0f, 6.0f }}}, + + {{"outputTensor", { 3.0f, 3.0f, 3.0f, 4.0f, 4.0f, 4.0f }}}); +} + +struct StridedSlice4DReverseFixture : StridedSliceFixture +{ + StridedSlice4DReverseFixture() : StridedSliceFixture("[ 3, 2, 3, 1 ]", // inputShape + "[ 1, 2, 3, 1 ]", // outputShape + "[ 1,0,0,0, " + "255,255,255,255, " + "0,0,0,0, " + "0,0,0,0 ]", // beginData [ 1 -1 0 0 ] + "[ 2,0,0,0, " + "253,255,255,255, " + "3,0,0,0, " + "1,0,0,0 ]", // endData [ 2 -3 3 1 ] + "[ 1,0,0,0, " + "255,255,255,255, " + "1,0,0,0, " + "1,0,0,0 ]" // stridesData [ 1 -1 1 1 ] + ) {} +}; + +BOOST_FIXTURE_TEST_CASE(StridedSlice4DReverse, StridedSlice4DReverseFixture) +{ + RunTest<4, armnn::DataType::Float32>( + 0, + {{"inputTensor", { 1.0f, 1.0f, 1.0f, 2.0f, 2.0f, 2.0f, + + 3.0f, 3.0f, 3.0f, 4.0f, 4.0f, 4.0f, + + 5.0f, 5.0f, 5.0f, 6.0f, 6.0f, 6.0f }}}, + + {{"outputTensor", { 4.0f, 4.0f, 4.0f, 3.0f, 3.0f, 3.0f }}}); +} + +struct StridedSliceSimpleStrideFixture : StridedSliceFixture +{ + StridedSliceSimpleStrideFixture() : StridedSliceFixture("[ 3, 2, 3, 1 ]", // inputShape + "[ 2, 1, 2, 1 ]", // outputShape + "[ 0,0,0,0, 0,0,0,0, 0,0,0,0, 0,0,0,0 ]", // beginData + "[ 3,0,0,0, 2,0,0,0, 3,0,0,0, 1,0,0,0 ]", // endData + "[ 2,0,0,0, 2,0,0,0, 2,0,0,0, 1,0,0,0 ]" // stridesData + ) {} +}; + +BOOST_FIXTURE_TEST_CASE(StridedSliceSimpleStride, StridedSliceSimpleStrideFixture) +{ + RunTest<4, armnn::DataType::Float32>( + 0, + {{"inputTensor", { 1.0f, 1.0f, 1.0f, 2.0f, 2.0f, 2.0f, + + 3.0f, 3.0f, 3.0f, 4.0f, 4.0f, 4.0f, + + 5.0f, 5.0f, 5.0f, 6.0f, 6.0f, 6.0f }}}, + + {{"outputTensor", { 1.0f, 1.0f, + + 5.0f, 5.0f }}}); +} + +struct StridedSliceSimpleRangeMaskFixture : StridedSliceFixture +{ + StridedSliceSimpleRangeMaskFixture() : StridedSliceFixture("[ 3, 2, 3, 1 ]", // inputShape + "[ 3, 2, 3, 1 ]", // outputShape + "[ 1,0,0,0, 1,0,0,0, 1,0,0,0, 1,0,0,0 ]", // beginData + "[ 1,0,0,0, 1,0,0,0, 1,0,0,0, 1,0,0,0 ]", // endData + "[ 1,0,0,0, 1,0,0,0, 1,0,0,0, 1,0,0,0 ]", // stridesData + (1 << 4) - 1, // beginMask + (1 << 4) - 1 // endMask + ) {} +}; + +BOOST_FIXTURE_TEST_CASE(StridedSliceSimpleRangeMask, StridedSliceSimpleRangeMaskFixture) +{ + RunTest<4, armnn::DataType::Float32>( + 0, + {{"inputTensor", { 1.0f, 1.0f, 1.0f, 2.0f, 2.0f, 2.0f, + + 3.0f, 3.0f, 3.0f, 4.0f, 4.0f, 4.0f, + + 5.0f, 5.0f, 5.0f, 6.0f, 6.0f, 6.0f }}}, + + {{"outputTensor", { 1.0f, 1.0f, 1.0f, 2.0f, 2.0f, 2.0f, + + 3.0f, 3.0f, 3.0f, 4.0f, 4.0f, 4.0f, + + 5.0f, 5.0f, 5.0f, 6.0f, 6.0f, 6.0f }}}); +} + +BOOST_AUTO_TEST_SUITE_END() -- cgit v1.2.1