aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorSadik <sadik.armagan@arm.com>2018-09-19 15:30:00 +0100
committerMatthew Bentham <matthew.bentham@arm.com>2018-10-10 16:16:56 +0100
commitb94967bae76196f28a86edc90b4bbe12618e2360 (patch)
tree6cb277afad0348eb338c1ccdb5a0210d7062bd38
parentb9bf946704608c9a5e48189e338cc6efebae1d65 (diff)
downloadarmnn-b94967bae76196f28a86edc90b4bbe12618e2360.tar.gz
IVGCVSW-1650 Add Support for Reshape layer on TF Lite parser
* Added Reshape operator support for the TfLite Parser.

Change-Id: I64a5650dac089905a402be4a9cb6032aa0d81f00
-rw-r--r--CMakeLists.txt1
-rw-r--r--src/armnnTfLiteParser/TfLiteParser.cpp64
-rw-r--r--src/armnnTfLiteParser/TfLiteParser.hpp4
-rw-r--r--src/armnnTfLiteParser/test/Reshape.cpp141
4 files changed, 209 insertions, 1 deletion
diff --git a/CMakeLists.txt b/CMakeLists.txt
index e4ed9b4515..429046142f 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -728,6 +728,7 @@ if(BUILD_UNIT_TESTS)
src/armnnTfLiteParser/test/GetSubgraphInputsOutputs.cpp
src/armnnTfLiteParser/test/GetInputsOutputs.cpp
src/armnnTfLiteParser/test/Activations.cpp
+ src/armnnTfLiteParser/test/Reshape.cpp
)
endif()
diff --git a/src/armnnTfLiteParser/TfLiteParser.cpp b/src/armnnTfLiteParser/TfLiteParser.cpp
index dd1f5773af..13e4604490 100644
--- a/src/armnnTfLiteParser/TfLiteParser.cpp
+++ b/src/armnnTfLiteParser/TfLiteParser.cpp
@@ -24,6 +24,7 @@
#include <fstream>
#include <algorithm>
#include <limits>
+#include <numeric>
using namespace armnn;
using armnn::CheckLocation;
@@ -457,6 +458,7 @@ TfLiteParser::TfLiteParser()
m_ParserFunctions[tflite::BuiltinOperator_SQUEEZE] = &TfLiteParser::ParseSqueeze;
m_ParserFunctions[tflite::BuiltinOperator_RELU] = &TfLiteParser::ParseRelu;
m_ParserFunctions[tflite::BuiltinOperator_RELU6] = &TfLiteParser::ParseRelu6;
+ m_ParserFunctions[tflite::BuiltinOperator_RESHAPE] = &TfLiteParser::ParseReshape;
}
void TfLiteParser::ResetParser()
@@ -1033,6 +1035,68 @@ void TfLiteParser::ParseRelu6(size_t subgraphIndex, size_t operatorIndex)
RegisterOutputSlots(subgraphIndex, operatorIndex, layer, {outputTensorIndexes[0]});
}
+armnn::TensorInfo TfLiteParser::OutputShapeOfReshape(const armnn::TensorInfo & inputTensorInfo,
+ const std::vector<int32_t> & targetDimsIn)
+{
+ std::vector<unsigned int> outputDims(targetDimsIn.begin(), targetDimsIn.end());
+ const auto stretchDim = std::find(targetDimsIn.begin(), targetDimsIn.end(), -1);
+
+ if (stretchDim != targetDimsIn.end())
+ {
+ if (std::find(std::next(stretchDim), targetDimsIn.end(), -1) != targetDimsIn.end())
+ {
+ throw ParseException(
+ boost::str(
+ boost::format("At most one component of shape can be -1 %1%") % CHECK_LOCATION().AsString()));
+ }
+
+ auto targetNumElements =
+ boost::numeric_cast<unsigned int>(
+ std::accumulate(targetDimsIn.begin(), targetDimsIn.end(), -1, std::multiplies<int32_t>()));
+
+ auto stretchIndex = static_cast<size_t>(std::distance(targetDimsIn.begin(), stretchDim));
+ outputDims[stretchIndex] = inputTensorInfo.GetNumElements() / targetNumElements;
+ }
+
+ TensorShape outputShape = TensorShape(static_cast<unsigned int>(outputDims.size()), outputDims.data());
+
+ TensorInfo reshapeInfo = inputTensorInfo;
+ reshapeInfo.SetShape(outputShape);
+
+ return reshapeInfo;
+}
+
+void TfLiteParser::ParseReshape(size_t subgraphIndex, size_t operatorIndex)
+{
+ CHECK_MODEL(m_Model, subgraphIndex, operatorIndex);
+
+ auto inputs = GetInputs(m_Model, subgraphIndex, operatorIndex);
+ CHECK_VALID_SIZE(inputs.size(), 1);
+
+ auto outputs = GetOutputs(m_Model, subgraphIndex, operatorIndex);
+ CHECK_VALID_SIZE(outputs.size(), 1);
+
+ const auto & operatorPtr = m_Model->subgraphs[subgraphIndex]->operators[operatorIndex];
+ const auto * options = operatorPtr->builtin_options.AsReshapeOptions();
+
+ armnn::TensorInfo inputTensorInfo = ToTensorInfo(inputs[0]);
+ armnn::TensorInfo outputTensorInfo =
+ TfLiteParser::OutputShapeOfReshape(inputTensorInfo, options->new_shape);
+
+ ReshapeDescriptor reshapeDesc;
+ reshapeDesc.m_TargetShape = outputTensorInfo.GetShape();
+
+ auto layerName = boost::str(boost::format("Reshape:%1%:%2%") % subgraphIndex % operatorIndex);
+ IConnectableLayer* layer = m_Network->AddReshapeLayer(reshapeDesc, layerName.c_str());
+ layer->GetOutputSlot(0).SetTensorInfo(outputTensorInfo);
+
+ auto inputTensorIndexes = AsUnsignedVector(GetInputTensorIds(m_Model, subgraphIndex, operatorIndex));
+ RegisterInputSlots(subgraphIndex, operatorIndex, layer, {inputTensorIndexes[0]});
+
+ auto outputTensorIndexes = AsUnsignedVector(GetOutputTensorIds(m_Model, subgraphIndex, operatorIndex));
+ RegisterOutputSlots(subgraphIndex, operatorIndex, layer, {outputTensorIndexes[0]});
+}
+
armnn::IConnectableLayer* TfLiteParser::AddFusedActivationLayer(armnn::IConnectableLayer* prevLayer,
unsigned int outputSlot,
tflite::ActivationFunctionType activationType)
diff --git a/src/armnnTfLiteParser/TfLiteParser.hpp b/src/armnnTfLiteParser/TfLiteParser.hpp
index b9f81e4118..f949484a4f 100644
--- a/src/armnnTfLiteParser/TfLiteParser.hpp
+++ b/src/armnnTfLiteParser/TfLiteParser.hpp
@@ -74,7 +74,8 @@ public:
static BufferRawPtr GetBuffer(const ModelPtr& model, size_t bufferIndex);
static armnn::TensorInfo OutputShapeOfSqueeze(const std::vector<uint32_t> & squeezeDims,
const armnn::TensorInfo & inputTensorInfo);
-
+ static armnn::TensorInfo OutputShapeOfReshape(const armnn::TensorInfo & inputTensorInfo,
+ const std::vector<int32_t> & targetDimsIn);
private:
// No copying allowed until it is wanted and properly implemented
@@ -95,6 +96,7 @@ private:
void ParseSqueeze(size_t subgraphIndex, size_t operatorIndex);
void ParseRelu(size_t subgraphIndex, size_t operatorIndex);
void ParseRelu6(size_t subgraphIndex, size_t operatorIndex);
+ void ParseReshape(size_t subgraphIndex, size_t operatorIndex);
void RegisterProducerOfTensor(size_t subgraphIndex, size_t tensorIndex, armnn::IOutputSlot* slot);
void RegisterConsumerOfTensor(size_t subgraphIndex, size_t tensorIndex, armnn::IInputSlot* slot);
diff --git a/src/armnnTfLiteParser/test/Reshape.cpp b/src/armnnTfLiteParser/test/Reshape.cpp
new file mode 100644
index 0000000000..ae5a09a711
--- /dev/null
+++ b/src/armnnTfLiteParser/test/Reshape.cpp
@@ -0,0 +1,141 @@
+//
+// Copyright © 2017 Arm Ltd. All rights reserved.
+// SPDX-License-Identifier: MIT
+//
+
+#include <boost/test/unit_test.hpp>
+#include "ParserFlatbuffersFixture.hpp"
+#include "../TfLiteParser.hpp"
+
+#include <string>
+#include <iostream>
+
+BOOST_AUTO_TEST_SUITE(TensorflowLiteParser)
+
+struct ReshapeFixture : public ParserFlatbuffersFixture
+{
+ explicit ReshapeFixture(const std::string& inputShape,
+ const std::string& outputShape,
+ const std::string& newShape)
+ {
+ m_JsonString = R"(
+ {
+ "version": 3,
+ "operator_codes": [ { "builtin_code": "RESHAPE" } ],
+ "subgraphs": [ {
+ "tensors": [
+ {)";
+ m_JsonString += R"(
+ "shape" : )" + inputShape + ",";
+ m_JsonString += R"(
+ "type": "UINT8",
+ "buffer": 0,
+ "name": "inputTensor",
+ "quantization": {
+ "min": [ 0.0 ],
+ "max": [ 255.0 ],
+ "scale": [ 1.0 ],
+ "zero_point": [ 0 ],
+ }
+ },
+ {)";
+ m_JsonString += R"(
+ "shape" : )" + outputShape;
+ m_JsonString += R"(,
+ "type": "UINT8",
+ "buffer": 1,
+ "name": "outputTensor",
+ "quantization": {
+ "min": [ 0.0 ],
+ "max": [ 255.0 ],
+ "scale": [ 1.0 ],
+ "zero_point": [ 0 ],
+ }
+ }
+ ],
+ "inputs": [ 0 ],
+ "outputs": [ 1 ],
+ "operators": [
+ {
+ "opcode_index": 0,
+ "inputs": [ 0 ],
+ "outputs": [ 1 ],
+ "builtin_options_type": "ReshapeOptions",
+ "builtin_options": {)";
+ if (!newShape.empty())
+ {
+ m_JsonString += R"("new_shape" : )" + newShape;
+ }
+ m_JsonString += R"(},
+ "custom_options_format": "FLEXBUFFERS"
+ }
+ ],
+ } ],
+ "buffers" : [ {}, {} ]
+ }
+ )";
+
+ }
+};
+
+struct ReshapeFixtureWithReshapeDims : ReshapeFixture
+{
+ ReshapeFixtureWithReshapeDims() : ReshapeFixture("[ 1, 9 ]", "[ 3, 3 ]", "[ 3, 3 ]") {}
+};
+
+BOOST_FIXTURE_TEST_CASE(ParseReshapeWithReshapeDims, ReshapeFixtureWithReshapeDims)
+{
+ SetupSingleInputSingleOutput("inputTensor", "outputTensor");
+ RunTest<2, uint8_t>(0,
+ { 1, 2, 3, 4, 5, 6, 7, 8, 9 },
+ { 1, 2, 3, 4, 5, 6, 7, 8, 9 });
+ BOOST_TEST((m_Parser->GetNetworkOutputBindingInfo(0, "outputTensor").second.GetShape()
+ == armnn::TensorShape({3,3})));
+}
+
+struct ReshapeFixtureWithReshapeDimsFlatten : ReshapeFixture
+{
+ ReshapeFixtureWithReshapeDimsFlatten() : ReshapeFixture("[ 3, 3 ]", "[ 1, 9 ]", "[ -1 ]") {}
+};
+
+BOOST_FIXTURE_TEST_CASE(ParseReshapeWithReshapeDimsFlatten, ReshapeFixtureWithReshapeDimsFlatten)
+{
+ SetupSingleInputSingleOutput("inputTensor", "outputTensor");
+ RunTest<2, uint8_t>(0,
+ { 1, 2, 3, 4, 5, 6, 7, 8, 9 },
+ { 1, 2, 3, 4, 5, 6, 7, 8, 9 });
+ BOOST_TEST((m_Parser->GetNetworkOutputBindingInfo(0, "outputTensor").second.GetShape()
+ == armnn::TensorShape({1,9})));
+}
+
+struct ReshapeFixtureWithReshapeDimsFlattenTwoDims : ReshapeFixture
+{
+ ReshapeFixtureWithReshapeDimsFlattenTwoDims() : ReshapeFixture("[ 3, 2, 3 ]", "[ 2, 9 ]", "[ 2, -1 ]") {}
+};
+
+BOOST_FIXTURE_TEST_CASE(ParseReshapeWithReshapeDimsFlattenTwoDims, ReshapeFixtureWithReshapeDimsFlattenTwoDims)
+{
+ SetupSingleInputSingleOutput("inputTensor", "outputTensor");
+ RunTest<2, uint8_t>(0,
+ { 1, 1, 1, 2, 2, 2, 3, 3, 3, 4, 4, 4, 5, 5, 5, 6, 6, 6 },
+ { 1, 1, 1, 2, 2, 2, 3, 3, 3, 4, 4, 4, 5, 5, 5, 6, 6, 6 });
+ BOOST_TEST((m_Parser->GetNetworkOutputBindingInfo(0, "outputTensor").second.GetShape()
+ == armnn::TensorShape({2,9})));
+}
+
+struct ReshapeFixtureWithReshapeDimsFlattenOneDim : ReshapeFixture
+{
+ ReshapeFixtureWithReshapeDimsFlattenOneDim() : ReshapeFixture("[ 2, 9 ]", "[ 2, 3, 3 ]", "[ 2, -1, 3 ]") {}
+};
+
+BOOST_FIXTURE_TEST_CASE(ParseReshapeWithReshapeDimsFlattenOneDim, ReshapeFixtureWithReshapeDimsFlattenOneDim)
+{
+ SetupSingleInputSingleOutput("inputTensor", "outputTensor");
+ RunTest<3, uint8_t>(0,
+ { 1, 1, 1, 2, 2, 2, 3, 3, 3, 4, 4, 4, 5, 5, 5, 6, 6, 6 },
+ { 1, 1, 1, 2, 2, 2, 3, 3, 3, 4, 4, 4, 5, 5, 5, 6, 6, 6 });
+ BOOST_TEST((m_Parser->GetNetworkOutputBindingInfo(0, "outputTensor").second.GetShape()
+ == armnn::TensorShape({2,3,3})));
+}
+
+BOOST_AUTO_TEST_SUITE_END()