Diffstat (limited to 'src/armnnDeserializeParser')
-rw-r--r--  src/armnnDeserializeParser/DeserializeParser.cpp        100
-rw-r--r--  src/armnnDeserializeParser/DeserializeParser.hpp          3
-rw-r--r--  src/armnnDeserializeParser/DeserializerSupport.md          1
-rw-r--r--  src/armnnDeserializeParser/test/DeserializeReshape.cpp  128
4 files changed, 231 insertions, 1 deletion
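
The substance of this change is DeserializeParser::OutputShapeOfReshape, which resolves an optional -1 ("stretch") component in the serialized target shape so that the reshaped tensor keeps the same total number of elements as its input. A minimal standalone sketch of that arithmetic follows; ResolveReshapeDims is a hypothetical helper written only for illustration, not part of the Arm NN API, and it omits the boost::numeric_cast and armnn::TensorShape plumbing used in the patch below.

// Standalone sketch of the stretch-dimension resolution performed by
// DeserializeParser::OutputShapeOfReshape in the patch below.
// ResolveReshapeDims is a hypothetical helper, not an Arm NN function.
#include <algorithm>
#include <cstdint>
#include <functional>
#include <iostream>
#include <numeric>
#include <stdexcept>
#include <vector>

std::vector<unsigned int> ResolveReshapeDims(unsigned int inputNumElements,
                                             const std::vector<int32_t>& targetDims)
{
    std::vector<unsigned int> outputDims(targetDims.begin(), targetDims.end());
    const auto stretchDim = std::find(targetDims.begin(), targetDims.end(), -1);

    if (stretchDim != targetDims.end())
    {
        // Only one component of the target shape may be left open as -1.
        if (std::find(std::next(stretchDim), targetDims.end(), -1) != targetDims.end())
        {
            throw std::invalid_argument("At most one component of shape can be -1");
        }

        // Seeding the fold with -1 cancels the -1 placeholder, leaving the
        // product of the known dimensions (e.g. {3, -1} -> 3).
        const auto knownProduct = static_cast<unsigned int>(
            std::accumulate(targetDims.begin(), targetDims.end(), -1, std::multiplies<int32_t>()));

        const auto stretchIndex = static_cast<size_t>(std::distance(targetDims.begin(), stretchDim));
        outputDims[stretchIndex] = inputNumElements / knownProduct;
    }

    return outputDims;
}

int main()
{
    // A tensor with 9 elements reshaped with target {3, -1} resolves to {3, 3}.
    for (unsigned int dim : ResolveReshapeDims(9, { 3, -1 }))
    {
        std::cout << dim << ' ';
    }
    std::cout << std::endl;
    return 0;
}

Compiled standalone, the example prints "3 3".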
diff --git a/src/armnnDeserializeParser/DeserializeParser.cpp b/src/armnnDeserializeParser/DeserializeParser.cpp
index f47c23f0b5..de9b1a98c7 100644
--- a/src/armnnDeserializeParser/DeserializeParser.cpp
+++ b/src/armnnDeserializeParser/DeserializeParser.cpp
@@ -23,6 +23,9 @@
#include <Schema_generated.h>
#include <fstream>
+#include <algorithm>
+#include <limits>
+#include <numeric>
using armnn::ParseException;
using namespace armnn;
@@ -128,6 +131,25 @@ void CheckTensorPtr(DeserializeParser::TensorRawPtr rawPtr,
CheckGraph(GRAPH, LAYERS_INDEX, CHECK_LOCATION())
}
+bool CheckShape(const armnn::TensorShape& actual, const std::vector<uint32_t>& expected)
+{
+ const unsigned int actualSize = actual.GetNumDimensions();
+ if (actualSize != expected.size())
+ {
+ return false;
+ }
+
+ for (unsigned int i = 0u; i < actualSize; i++)
+ {
+ if (actual[i] != static_cast<unsigned int>(expected[i]))
+ {
+ return false;
+ }
+ }
+
+ return true;
+}
+
DeserializeParser::DeserializeParser()
: m_Network(nullptr, nullptr),
//May require LayerType_Max to be included
@@ -137,6 +159,7 @@ m_ParserFunctions(Layer_MAX+1, &DeserializeParser::ParseUnsupportedLayer)
m_ParserFunctions[Layer_AdditionLayer] = &DeserializeParser::ParseAdd;
m_ParserFunctions[Layer_MultiplicationLayer] = &DeserializeParser::ParseMultiplication;
m_ParserFunctions[Layer_Pooling2dLayer] = &DeserializeParser::ParsePooling2d;
+ m_ParserFunctions[Layer_ReshapeLayer] = &DeserializeParser::ParseReshape;
m_ParserFunctions[Layer_SoftmaxLayer] = &DeserializeParser::ParseSoftmax;
}
@@ -156,6 +179,8 @@ DeserializeParser::LayerBaseRawPtr DeserializeParser::GetBaseLayer(const GraphPt
return graphPtr->layers()->Get(layerIndex)->layer_as_OutputLayer()->base()->base();
case Layer::Layer_Pooling2dLayer:
return graphPtr->layers()->Get(layerIndex)->layer_as_Pooling2dLayer()->base();
+ case Layer::Layer_ReshapeLayer:
+ return graphPtr->layers()->Get(layerIndex)->layer_as_ReshapeLayer()->base();
case Layer::Layer_SoftmaxLayer:
return graphPtr->layers()->Get(layerIndex)->layer_as_SoftmaxLayer()->base();
case Layer::Layer_NONE:
@@ -247,12 +272,12 @@ DeserializeParser::LayerBaseRawPtrVector DeserializeParser::GetGraphOutputs(cons
{
CHECK_GRAPH(graphPtr, 0);
const auto& numOutputs = graphPtr->outputIds()->size();
-
LayerBaseRawPtrVector result(numOutputs);
for (unsigned int i=0; i<numOutputs; ++i)
{
uint32_t outputId = graphPtr->outputIds()->Get(i);
+
result[i] = GetBaseLayer(graphPtr, static_cast<uint32_t>(outputId));
}
return result;
@@ -726,6 +751,79 @@ void DeserializeParser::ParsePooling2d(unsigned int layerIndex)
RegisterOutputSlots(layerIndex, layer);
}
+armnn::TensorInfo DeserializeParser::OutputShapeOfReshape(const armnn::TensorInfo& inputTensorInfo,
+ const std::vector<uint32_t>& targetDimsIn)
+{
+ std::vector<unsigned int> outputDims(targetDimsIn.begin(), targetDimsIn.end());
+ const auto stretchDim = std::find(targetDimsIn.begin(), targetDimsIn.end(), -1);
+
+ if (stretchDim != targetDimsIn.end())
+ {
+ if (std::find(std::next(stretchDim), targetDimsIn.end(), -1) != targetDimsIn.end())
+ {
+ throw ParseException(boost::str(
+ boost::format("At most one component of shape can be -1 %1%") % CHECK_LOCATION().AsString()));
+ }
+
+ auto targetNumElements =
+ boost::numeric_cast<unsigned int>(
+ std::accumulate(targetDimsIn.begin(), targetDimsIn.end(), -1, std::multiplies<int32_t>()));
+
+ auto stretchIndex = static_cast<size_t>(std::distance(targetDimsIn.begin(), stretchDim));
+ outputDims[stretchIndex] = inputTensorInfo.GetNumElements() / targetNumElements;
+ }
+
+ TensorShape outputShape = TensorShape(static_cast<unsigned int>(outputDims.size()), outputDims.data());
+
+ armnn::TensorInfo reshapeInfo = inputTensorInfo;
+ reshapeInfo.SetShape(outputShape);
+
+ return reshapeInfo;
+}
+
+void DeserializeParser::ParseReshape(unsigned int layerIndex)
+{
+ CHECK_LAYERS(m_Graph, 0, layerIndex);
+ auto inputs = GetInputs(m_Graph, layerIndex);
+
+ auto outputs = GetOutputs(m_Graph, layerIndex);
+ CHECK_VALID_SIZE(outputs.size(), 1);
+
+ armnn::TensorInfo inputTensorInfo = ToTensorInfo(inputs[0]);
+ armnn::TensorInfo actualOutputTensorInfo = ToTensorInfo(outputs[0]);
+
+ const auto targetDims = m_Graph->layers()->Get(layerIndex)->layer_as_ReshapeLayer()->descriptor()->targetShape();
+ std::vector<uint32_t> outputDims(targetDims->begin(), targetDims->begin() + targetDims->size());
+
+ armnn::TensorInfo reshapeOutputTensorInfo = DeserializeParser::OutputShapeOfReshape(inputTensorInfo, outputDims);
+ const armnn::TensorShape& reshapeOutputTensorShape = reshapeOutputTensorInfo.GetShape();
+
+ const std::vector<uint32_t> expectedDims(outputs[0]->dimensions()->begin(),
+ outputs[0]->dimensions()->begin() + outputs[0]->dimensions()->size());
+
+ if (inputs.size() > 1 && !CheckShape(reshapeOutputTensorShape, expectedDims))
+ {
+ std::stringstream ss;
+ ss << "New shape defined in reshape parameters "
+ << reshapeOutputTensorShape
+ << " does not equal output shape "
+ << actualOutputTensorInfo.GetShape()
+ << ": "
+ << CHECK_LOCATION().AsString();
+ throw ParseException(ss.str());
+ }
+
+ armnn::ReshapeDescriptor reshapeDesc;
+ reshapeDesc.m_TargetShape = reshapeOutputTensorShape;
+
+ auto layerName = boost::str(boost::format("Reshape:%1%") % layerIndex);
+ IConnectableLayer* layer = m_Network->AddReshapeLayer(reshapeDesc, layerName.c_str());
+ layer->GetOutputSlot(0).SetTensorInfo(reshapeOutputTensorInfo);
+
+ RegisterInputSlots(layerIndex, layer);
+ RegisterOutputSlots(layerIndex, layer);
+}
+
void DeserializeParser::ParseSoftmax(unsigned int layerIndex)
{
CHECK_LAYERS(m_Graph, 0, layerIndex);
diff --git a/src/armnnDeserializeParser/DeserializeParser.hpp b/src/armnnDeserializeParser/DeserializeParser.hpp
index 1edb5a9f23..666cbca33c 100644
--- a/src/armnnDeserializeParser/DeserializeParser.hpp
+++ b/src/armnnDeserializeParser/DeserializeParser.hpp
@@ -53,6 +53,8 @@ public:
static int32_t GetBindingLayerInfo(const GraphPtr& graphPtr, unsigned int layerIndex);
armnn::Pooling2dDescriptor GetPoolingDescriptor(PoolingDescriptor pooling2dDescriptor,
unsigned int layerIndex);
+ static armnn::TensorInfo OutputShapeOfReshape(const armnn::TensorInfo & inputTensorInfo,
+ const std::vector<uint32_t> & targetDimsIn);
private:
// No copying allowed until it is wanted and properly implemented
@@ -69,6 +71,7 @@ private:
void ParseAdd(unsigned int layerIndex);
void ParseMultiplication(unsigned int layerIndex);
void ParsePooling2d(unsigned int layerIndex);
+ void ParseReshape(unsigned int layerIndex);
void ParseSoftmax(unsigned int layerIndex);
void RegisterOutputSlotOfConnection(uint32_t connectionIndex, armnn::IOutputSlot* slot);
diff --git a/src/armnnDeserializeParser/DeserializerSupport.md b/src/armnnDeserializeParser/DeserializerSupport.md
index d4925cc0ad..c03471af75 100644
--- a/src/armnnDeserializeParser/DeserializerSupport.md
+++ b/src/armnnDeserializeParser/DeserializerSupport.md
@@ -8,6 +8,7 @@ The Arm NN SDK Deserialize parser currently supports the following layers:
* Addition
* Multiplication
+* Reshape
* Softmax
More machine learning layers will be supported in future releases.
diff --git a/src/armnnDeserializeParser/test/DeserializeReshape.cpp b/src/armnnDeserializeParser/test/DeserializeReshape.cpp
new file mode 100644
index 0000000000..21e60933f6
--- /dev/null
+++ b/src/armnnDeserializeParser/test/DeserializeReshape.cpp
@@ -0,0 +1,128 @@
+//
+// Copyright © 2017 Arm Ltd. All rights reserved.
+// SPDX-License-Identifier: MIT
+//
+
+#include <boost/test/unit_test.hpp>
+#include "ParserFlatbuffersSerializeFixture.hpp"
+#include "../DeserializeParser.hpp"
+
+#include <string>
+#include <iostream>
+
+BOOST_AUTO_TEST_SUITE(DeserializeParser)
+
+struct ReshapeFixture : public ParserFlatbuffersSerializeFixture
+{
+ explicit ReshapeFixture(const std::string &inputShape,
+ const std::string &targetShape,
+ const std::string &outputShape,
+ const std::string &dataType)
+ {
+ m_JsonString = R"(
+ {
+ inputIds: [0],
+ outputIds: [2],
+ layers: [
+ {
+ layer_type: "InputLayer",
+ layer: {
+ base: {
+ layerBindingId: 0,
+ base: {
+ index: 0,
+ layerName: "InputLayer",
+ layerType: "Input",
+ inputSlots: [{
+ index: 0,
+ connection: {sourceLayerIndex:0, outputSlotIndex:0 },
+ }],
+ outputSlots: [ {
+ index: 0,
+ tensorInfo: {
+ dimensions: )" + inputShape + R"(,
+ dataType: )" + dataType + R"(
+ }}]
+ }
+ }}},
+ {
+ layer_type: "ReshapeLayer",
+ layer: {
+ base: {
+ index: 1,
+ layerName: "ReshapeLayer",
+ layerType: "Reshape",
+ inputSlots: [{
+ index: 0,
+ connection: {sourceLayerIndex:0, outputSlotIndex:0 },
+ }],
+ outputSlots: [ {
+ index: 0,
+ tensorInfo: {
+ dimensions: )" + inputShape + R"(,
+ dataType: )" + dataType + R"(
+
+ }}]},
+ descriptor: {
+ targetShape: )" + targetShape + R"(,
+ }
+
+ }},
+ {
+ layer_type: "OutputLayer",
+ layer: {
+ base:{
+ layerBindingId: 2,
+ base: {
+ index: 2,
+ layerName: "OutputLayer",
+ layerType: "Output",
+ inputSlots: [{
+ index: 0,
+ connection: {sourceLayerIndex:0, outputSlotIndex:0 },
+ }],
+ outputSlots: [ {
+ index: 0,
+ tensorInfo: {
+ dimensions: )" + outputShape + R"(,
+ dataType: )" + dataType + R"(
+ },
+ }],
+ }}},
+ }]
+ }
+ )";
+ SetupSingleInputSingleOutput("InputLayer", "OutputLayer");
+ }
+};
+
+struct SimpleReshapeFixture : ReshapeFixture
+{
+ SimpleReshapeFixture() : ReshapeFixture("[ 1, 9 ]", "[ 3, 3 ]", "[ 3, 3 ]",
+ "QuantisedAsymm8") {}
+};
+
+struct SimpleReshapeFixture2 : ReshapeFixture
+{
+ SimpleReshapeFixture2() : ReshapeFixture("[ 2, 2, 1, 1 ]",
+ "[ 2, 2, 1, 1 ]",
+ "[ 2, 2, 1, 1 ]",
+ "Float32") {}
+};
+
+BOOST_FIXTURE_TEST_CASE(ReshapeQuantisedAsymm8, SimpleReshapeFixture)
+{
+ RunTest<2, armnn::DataType::QuantisedAsymm8>(0,
+ { 1, 2, 3, 4, 5, 6, 7, 8, 9 },
+ { 1, 2, 3, 4, 5, 6, 7, 8, 9 });
+}
+
+BOOST_FIXTURE_TEST_CASE(ReshapeFloat32, SimpleReshapeFixture2)
+{
+ RunTest<4, armnn::DataType::Float32>(0,
+ { 111, 85, 226, 3 },
+ { 111, 85, 226, 3 });
+}
+
+
+BOOST_AUTO_TEST_SUITE_END()
\ No newline at end of file