diff options
Diffstat (limited to 'src/armnnDeserializeParser')
-rw-r--r-- | src/armnnDeserializeParser/DeserializeParser.cpp | 100 | ||||
-rw-r--r-- | src/armnnDeserializeParser/DeserializeParser.hpp | 3 | ||||
-rw-r--r-- | src/armnnDeserializeParser/DeserializerSupport.md | 1 | ||||
-rw-r--r-- | src/armnnDeserializeParser/test/DeserializeReshape.cpp | 128 |
4 files changed, 231 insertions, 1 deletions
diff --git a/src/armnnDeserializeParser/DeserializeParser.cpp b/src/armnnDeserializeParser/DeserializeParser.cpp index f47c23f0b5..de9b1a98c7 100644 --- a/src/armnnDeserializeParser/DeserializeParser.cpp +++ b/src/armnnDeserializeParser/DeserializeParser.cpp @@ -23,6 +23,9 @@ #include <Schema_generated.h> #include <fstream> +#include <algorithm> +#include <limits> +#include <numeric> using armnn::ParseException; using namespace armnn; @@ -128,6 +131,25 @@ void CheckTensorPtr(DeserializeParser::TensorRawPtr rawPtr, CheckGraph(GRAPH, LAYERS_INDEX, CHECK_LOCATION()) } +bool CheckShape(const armnn::TensorShape& actual, const std::vector<uint32_t>& expected) +{ + const unsigned int actualSize = actual.GetNumDimensions(); + if (actualSize != expected.size()) + { + return false; + } + + for (unsigned int i = 0u; i < actualSize; i++) + { + if (actual[i] != static_cast<unsigned int>(expected[i])) + { + return false; + } + } + + return true; +} + DeserializeParser::DeserializeParser() : m_Network(nullptr, nullptr), //May require LayerType_Max to be included @@ -137,6 +159,7 @@ m_ParserFunctions(Layer_MAX+1, &DeserializeParser::ParseUnsupportedLayer) m_ParserFunctions[Layer_AdditionLayer] = &DeserializeParser::ParseAdd; m_ParserFunctions[Layer_MultiplicationLayer] = &DeserializeParser::ParseMultiplication; m_ParserFunctions[Layer_Pooling2dLayer] = &DeserializeParser::ParsePooling2d; + m_ParserFunctions[Layer_ReshapeLayer] = &DeserializeParser::ParseReshape; m_ParserFunctions[Layer_SoftmaxLayer] = &DeserializeParser::ParseSoftmax; } @@ -156,6 +179,8 @@ DeserializeParser::LayerBaseRawPtr DeserializeParser::GetBaseLayer(const GraphPt return graphPtr->layers()->Get(layerIndex)->layer_as_OutputLayer()->base()->base(); case Layer::Layer_Pooling2dLayer: return graphPtr->layers()->Get(layerIndex)->layer_as_Pooling2dLayer()->base(); + case Layer::Layer_ReshapeLayer: + return graphPtr->layers()->Get(layerIndex)->layer_as_ReshapeLayer()->base(); case Layer::Layer_SoftmaxLayer: return graphPtr->layers()->Get(layerIndex)->layer_as_SoftmaxLayer()->base(); case Layer::Layer_NONE: @@ -247,12 +272,12 @@ DeserializeParser::LayerBaseRawPtrVector DeserializeParser::GetGraphOutputs(cons { CHECK_GRAPH(graphPtr, 0); const auto& numOutputs = graphPtr->outputIds()->size(); - LayerBaseRawPtrVector result(numOutputs); for (unsigned int i=0; i<numOutputs; ++i) { uint32_t outputId = graphPtr->outputIds()->Get(i); + result[i] = GetBaseLayer(graphPtr, static_cast<uint32_t>(outputId)); } return result; @@ -726,6 +751,79 @@ void DeserializeParser::ParsePooling2d(unsigned int layerIndex) RegisterOutputSlots(layerIndex, layer); } +armnn::TensorInfo DeserializeParser::OutputShapeOfReshape(const armnn::TensorInfo& inputTensorInfo, + const std::vector<uint32_t>& targetDimsIn) +{ + std::vector<unsigned int> outputDims(targetDimsIn.begin(), targetDimsIn.end()); + const auto stretchDim = std::find(targetDimsIn.begin(), targetDimsIn.end(), -1); + + if (stretchDim != targetDimsIn.end()) + { + if (std::find(std::next(stretchDim), targetDimsIn.end(), -1) != targetDimsIn.end()) + { + throw ParseException(boost::str( + boost::format("At most one component of shape can be -1 %1%") % CHECK_LOCATION().AsString())); + } + + auto targetNumElements = + boost::numeric_cast<unsigned int>( + std::accumulate(targetDimsIn.begin(), targetDimsIn.end(), -1, std::multiplies<int32_t>())); + + auto stretchIndex = static_cast<size_t>(std::distance(targetDimsIn.begin(), stretchDim)); + outputDims[stretchIndex] = inputTensorInfo.GetNumElements() / targetNumElements; + } + + TensorShape outputShape = TensorShape(static_cast<unsigned int>(outputDims.size()), outputDims.data()); + + armnn::TensorInfo reshapeInfo = inputTensorInfo; + reshapeInfo.SetShape(outputShape); + + return reshapeInfo; +} + +void DeserializeParser::ParseReshape(unsigned int layerIndex) +{ + CHECK_LAYERS(m_Graph, 0, layerIndex); + auto inputs = GetInputs(m_Graph, layerIndex); + + auto outputs = GetOutputs(m_Graph, layerIndex); + CHECK_VALID_SIZE(outputs.size(), 1); + + armnn::TensorInfo inputTensorInfo = ToTensorInfo(inputs[0]); + armnn::TensorInfo actualOutputTensorInfo = ToTensorInfo(outputs[0]); + + const auto targetDims = m_Graph->layers()->Get(layerIndex)->layer_as_ReshapeLayer()->descriptor()->targetShape(); + std::vector<uint32_t> outputDims(targetDims->begin(), targetDims->begin() + targetDims->size()); + + armnn::TensorInfo reshapeOutputTensorInfo = DeserializeParser::OutputShapeOfReshape(inputTensorInfo, outputDims); + const armnn::TensorShape& reshapeOutputTensorShape = reshapeOutputTensorInfo.GetShape(); + + const std::vector<uint32_t> expectedDims(outputs[0]->dimensions()->begin(), + outputs[0]->dimensions()->begin() + outputs[0]->dimensions()->size()); + + if (inputs.size() > 1 && !CheckShape(reshapeOutputTensorShape, expectedDims)) + { + std::stringstream ss; + ss << "New shape defined in reshape parameters " + << reshapeOutputTensorShape + << " does not equal output shape " + << actualOutputTensorInfo.GetShape() + << ": " + << CHECK_LOCATION().AsString(); + throw ParseException(ss.str()); + } + + armnn::ReshapeDescriptor reshapeDesc; + reshapeDesc.m_TargetShape = reshapeOutputTensorShape; + + auto layerName = boost::str(boost::format("Reshape:%1%") % layerIndex); + IConnectableLayer* layer = m_Network->AddReshapeLayer(reshapeDesc, layerName.c_str()); + layer->GetOutputSlot(0).SetTensorInfo(reshapeOutputTensorInfo); + + RegisterInputSlots(layerIndex, layer); + RegisterOutputSlots(layerIndex, layer); +} + void DeserializeParser::ParseSoftmax(unsigned int layerIndex) { CHECK_LAYERS(m_Graph, 0, layerIndex); diff --git a/src/armnnDeserializeParser/DeserializeParser.hpp b/src/armnnDeserializeParser/DeserializeParser.hpp index 1edb5a9f23..666cbca33c 100644 --- a/src/armnnDeserializeParser/DeserializeParser.hpp +++ b/src/armnnDeserializeParser/DeserializeParser.hpp @@ -53,6 +53,8 @@ public: static int32_t GetBindingLayerInfo(const GraphPtr& graphPtr, unsigned int layerIndex); armnn::Pooling2dDescriptor GetPoolingDescriptor(PoolingDescriptor pooling2dDescriptor, unsigned int layerIndex); + static armnn::TensorInfo OutputShapeOfReshape(const armnn::TensorInfo & inputTensorInfo, + const std::vector<uint32_t> & targetDimsIn); private: // No copying allowed until it is wanted and properly implemented @@ -69,6 +71,7 @@ private: void ParseAdd(unsigned int layerIndex); void ParseMultiplication(unsigned int layerIndex); void ParsePooling2d(unsigned int layerIndex); + void ParseReshape(unsigned int layerIndex); void ParseSoftmax(unsigned int layerIndex); void RegisterOutputSlotOfConnection(uint32_t connectionIndex, armnn::IOutputSlot* slot); diff --git a/src/armnnDeserializeParser/DeserializerSupport.md b/src/armnnDeserializeParser/DeserializerSupport.md index d4925cc0ad..c03471af75 100644 --- a/src/armnnDeserializeParser/DeserializerSupport.md +++ b/src/armnnDeserializeParser/DeserializerSupport.md @@ -8,6 +8,7 @@ The Arm NN SDK Deserialize parser currently supports the following layers: * Addition * Multiplication +* Reshape * Softmax More machine learning layers will be supported in future releases. diff --git a/src/armnnDeserializeParser/test/DeserializeReshape.cpp b/src/armnnDeserializeParser/test/DeserializeReshape.cpp new file mode 100644 index 0000000000..21e60933f6 --- /dev/null +++ b/src/armnnDeserializeParser/test/DeserializeReshape.cpp @@ -0,0 +1,128 @@ +// +// Copyright © 2017 Arm Ltd. All rights reserved. +// SPDX-License-Identifier: MIT +// + +#include <boost/test/unit_test.hpp> +#include "ParserFlatbuffersSerializeFixture.hpp" +#include "../DeserializeParser.hpp" + +#include <string> +#include <iostream> + +BOOST_AUTO_TEST_SUITE(DeserializeParser) + +struct ReshapeFixture : public ParserFlatbuffersSerializeFixture +{ + explicit ReshapeFixture(const std::string &inputShape, + const std::string &targetShape, + const std::string &outputShape, + const std::string &dataType) + { + m_JsonString = R"( + { + inputIds: [0], + outputIds: [2], + layers: [ + { + layer_type: "InputLayer", + layer: { + base: { + layerBindingId: 0, + base: { + index: 0, + layerName: "InputLayer", + layerType: "Input", + inputSlots: [{ + index: 0, + connection: {sourceLayerIndex:0, outputSlotIndex:0 }, + }], + outputSlots: [ { + index: 0, + tensorInfo: { + dimensions: )" + inputShape + R"(, + dataType: )" + dataType + R"( + }}] + } + }}}, + { + layer_type: "ReshapeLayer", + layer: { + base: { + index: 1, + layerName: "ReshapeLayer", + layerType: "Reshape", + inputSlots: [{ + index: 0, + connection: {sourceLayerIndex:0, outputSlotIndex:0 }, + }], + outputSlots: [ { + index: 0, + tensorInfo: { + dimensions: )" + inputShape + R"(, + dataType: )" + dataType + R"( + + }}]}, + descriptor: { + targetShape: )" + targetShape + R"(, + } + + }}, + { + layer_type: "OutputLayer", + layer: { + base:{ + layerBindingId: 2, + base: { + index: 2, + layerName: "OutputLayer", + layerType: "Output", + inputSlots: [{ + index: 0, + connection: {sourceLayerIndex:0, outputSlotIndex:0 }, + }], + outputSlots: [ { + index: 0, + tensorInfo: { + dimensions: )" + outputShape + R"(, + dataType: )" + dataType + R"( + }, + }], + }}}, + }] + } + )"; + SetupSingleInputSingleOutput("InputLayer", "OutputLayer"); + } +}; + +struct SimpleReshapeFixture : ReshapeFixture +{ + SimpleReshapeFixture() : ReshapeFixture("[ 1, 9 ]", "[ 3, 3 ]", "[ 3, 3 ]", + "QuantisedAsymm8") {} +}; + +struct SimpleReshapeFixture2 : ReshapeFixture +{ + SimpleReshapeFixture2() : ReshapeFixture("[ 2, 2, 1, 1 ]", + "[ 2, 2, 1, 1 ]", + "[ 2, 2, 1, 1 ]", + "Float32") {} +}; + +BOOST_FIXTURE_TEST_CASE(ReshapeQuantisedAsymm8, SimpleReshapeFixture) +{ + RunTest<2, armnn::DataType::QuantisedAsymm8>(0, + { 1, 2, 3, 4, 5, 6, 7, 8, 9 }, + { 1, 2, 3, 4, 5, 6, 7, 8, 9 }); +} + +BOOST_FIXTURE_TEST_CASE(ReshapeFloat32, SimpleReshapeFixture2) +{ + RunTest<4, armnn::DataType::Float32>(0, + { 111, 85, 226, 3 }, + { 111, 85, 226, 3 }); +} + + +BOOST_AUTO_TEST_SUITE_END()
\ No newline at end of file |