diff options
author | Saoirse Stewart <saoirse.stewart@arm.com> | 2019-02-19 15:54:14 +0000 |
---|---|---|
committer | Saoirse Stewart Arm <saoirse.stewart@arm.com> | 2019-02-19 15:55:46 +0000 |
commit | 263829c2163d79a28f98f24f9dd1e52e1c3cbbef (patch) | |
tree | bd904ce4b8aeaa14bc0622bbacefda26011733f2 /src/armnnDeserializeParser/DeserializeParser.cpp | |
parent | 4fbae33571871ce584e421657e8ffba299e89d67 (diff) | |
download | armnn-263829c2163d79a28f98f24f9dd1e52e1c3cbbef.tar.gz |
IVGCVSW-2642 Add Reshape to Serializer and Deserializer
Change-Id: Iccded3c6e3c0713c44f43231981440420591f94e
Signed-off-by: Saoirse Stewart <saoirse.stewart@arm.com>
Signed-off-by: Aron Virginas-Tar <Aron.Virginas-Tar@arm.com>
Diffstat (limited to 'src/armnnDeserializeParser/DeserializeParser.cpp')
-rw-r--r-- | src/armnnDeserializeParser/DeserializeParser.cpp | 100 |
1 files changed, 99 insertions, 1 deletions
diff --git a/src/armnnDeserializeParser/DeserializeParser.cpp b/src/armnnDeserializeParser/DeserializeParser.cpp index f47c23f0b5..de9b1a98c7 100644 --- a/src/armnnDeserializeParser/DeserializeParser.cpp +++ b/src/armnnDeserializeParser/DeserializeParser.cpp @@ -23,6 +23,9 @@ #include <Schema_generated.h> #include <fstream> +#include <algorithm> +#include <limits> +#include <numeric> using armnn::ParseException; using namespace armnn; @@ -128,6 +131,25 @@ void CheckTensorPtr(DeserializeParser::TensorRawPtr rawPtr, CheckGraph(GRAPH, LAYERS_INDEX, CHECK_LOCATION()) } +bool CheckShape(const armnn::TensorShape& actual, const std::vector<uint32_t>& expected) +{ + const unsigned int actualSize = actual.GetNumDimensions(); + if (actualSize != expected.size()) + { + return false; + } + + for (unsigned int i = 0u; i < actualSize; i++) + { + if (actual[i] != static_cast<unsigned int>(expected[i])) + { + return false; + } + } + + return true; +} + DeserializeParser::DeserializeParser() : m_Network(nullptr, nullptr), //May require LayerType_Max to be included @@ -137,6 +159,7 @@ m_ParserFunctions(Layer_MAX+1, &DeserializeParser::ParseUnsupportedLayer) m_ParserFunctions[Layer_AdditionLayer] = &DeserializeParser::ParseAdd; m_ParserFunctions[Layer_MultiplicationLayer] = &DeserializeParser::ParseMultiplication; m_ParserFunctions[Layer_Pooling2dLayer] = &DeserializeParser::ParsePooling2d; + m_ParserFunctions[Layer_ReshapeLayer] = &DeserializeParser::ParseReshape; m_ParserFunctions[Layer_SoftmaxLayer] = &DeserializeParser::ParseSoftmax; } @@ -156,6 +179,8 @@ DeserializeParser::LayerBaseRawPtr DeserializeParser::GetBaseLayer(const GraphPt return graphPtr->layers()->Get(layerIndex)->layer_as_OutputLayer()->base()->base(); case Layer::Layer_Pooling2dLayer: return graphPtr->layers()->Get(layerIndex)->layer_as_Pooling2dLayer()->base(); + case Layer::Layer_ReshapeLayer: + return graphPtr->layers()->Get(layerIndex)->layer_as_ReshapeLayer()->base(); case Layer::Layer_SoftmaxLayer: return graphPtr->layers()->Get(layerIndex)->layer_as_SoftmaxLayer()->base(); case Layer::Layer_NONE: @@ -247,12 +272,12 @@ DeserializeParser::LayerBaseRawPtrVector DeserializeParser::GetGraphOutputs(cons { CHECK_GRAPH(graphPtr, 0); const auto& numOutputs = graphPtr->outputIds()->size(); - LayerBaseRawPtrVector result(numOutputs); for (unsigned int i=0; i<numOutputs; ++i) { uint32_t outputId = graphPtr->outputIds()->Get(i); + result[i] = GetBaseLayer(graphPtr, static_cast<uint32_t>(outputId)); } return result; @@ -726,6 +751,79 @@ void DeserializeParser::ParsePooling2d(unsigned int layerIndex) RegisterOutputSlots(layerIndex, layer); } +armnn::TensorInfo DeserializeParser::OutputShapeOfReshape(const armnn::TensorInfo& inputTensorInfo, + const std::vector<uint32_t>& targetDimsIn) +{ + std::vector<unsigned int> outputDims(targetDimsIn.begin(), targetDimsIn.end()); + const auto stretchDim = std::find(targetDimsIn.begin(), targetDimsIn.end(), -1); + + if (stretchDim != targetDimsIn.end()) + { + if (std::find(std::next(stretchDim), targetDimsIn.end(), -1) != targetDimsIn.end()) + { + throw ParseException(boost::str( + boost::format("At most one component of shape can be -1 %1%") % CHECK_LOCATION().AsString())); + } + + auto targetNumElements = + boost::numeric_cast<unsigned int>( + std::accumulate(targetDimsIn.begin(), targetDimsIn.end(), -1, std::multiplies<int32_t>())); + + auto stretchIndex = static_cast<size_t>(std::distance(targetDimsIn.begin(), stretchDim)); + outputDims[stretchIndex] = inputTensorInfo.GetNumElements() / targetNumElements; + } + + TensorShape outputShape = TensorShape(static_cast<unsigned int>(outputDims.size()), outputDims.data()); + + armnn::TensorInfo reshapeInfo = inputTensorInfo; + reshapeInfo.SetShape(outputShape); + + return reshapeInfo; +} + +void DeserializeParser::ParseReshape(unsigned int layerIndex) +{ + CHECK_LAYERS(m_Graph, 0, layerIndex); + auto inputs = GetInputs(m_Graph, layerIndex); + + auto outputs = GetOutputs(m_Graph, layerIndex); + CHECK_VALID_SIZE(outputs.size(), 1); + + armnn::TensorInfo inputTensorInfo = ToTensorInfo(inputs[0]); + armnn::TensorInfo actualOutputTensorInfo = ToTensorInfo(outputs[0]); + + const auto targetDims = m_Graph->layers()->Get(layerIndex)->layer_as_ReshapeLayer()->descriptor()->targetShape(); + std::vector<uint32_t> outputDims(targetDims->begin(), targetDims->begin() + targetDims->size()); + + armnn::TensorInfo reshapeOutputTensorInfo = DeserializeParser::OutputShapeOfReshape(inputTensorInfo, outputDims); + const armnn::TensorShape& reshapeOutputTensorShape = reshapeOutputTensorInfo.GetShape(); + + const std::vector<uint32_t> expectedDims(outputs[0]->dimensions()->begin(), + outputs[0]->dimensions()->begin() + outputs[0]->dimensions()->size()); + + if (inputs.size() > 1 && !CheckShape(reshapeOutputTensorShape, expectedDims)) + { + std::stringstream ss; + ss << "New shape defined in reshape parameters " + << reshapeOutputTensorShape + << " does not equal output shape " + << actualOutputTensorInfo.GetShape() + << ": " + << CHECK_LOCATION().AsString(); + throw ParseException(ss.str()); + } + + armnn::ReshapeDescriptor reshapeDesc; + reshapeDesc.m_TargetShape = reshapeOutputTensorShape; + + auto layerName = boost::str(boost::format("Reshape:%1%") % layerIndex); + IConnectableLayer* layer = m_Network->AddReshapeLayer(reshapeDesc, layerName.c_str()); + layer->GetOutputSlot(0).SetTensorInfo(reshapeOutputTensorInfo); + + RegisterInputSlots(layerIndex, layer); + RegisterOutputSlots(layerIndex, layer); +} + void DeserializeParser::ParseSoftmax(unsigned int layerIndex) { CHECK_LAYERS(m_Graph, 0, layerIndex); |