about summary refs log tree commit diff
path: root/src/armnnDeserializeParser/DeserializeParser.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/armnnDeserializeParser/DeserializeParser.cpp')
-rw-r--r--src/armnnDeserializeParser/DeserializeParser.cpp100
1 file changed, 99 insertions(+), 1 deletion(-)
diff --git a/src/armnnDeserializeParser/DeserializeParser.cpp b/src/armnnDeserializeParser/DeserializeParser.cpp
index f47c23f0b5..de9b1a98c7 100644
--- a/src/armnnDeserializeParser/DeserializeParser.cpp
+++ b/src/armnnDeserializeParser/DeserializeParser.cpp
@@ -23,6 +23,9 @@
#include <Schema_generated.h>
#include <fstream>
+#include <algorithm>
+#include <limits>
+#include <numeric>
using armnn::ParseException;
using namespace armnn;
@@ -128,6 +131,25 @@ void CheckTensorPtr(DeserializeParser::TensorRawPtr rawPtr,
CheckGraph(GRAPH, LAYERS_INDEX, CHECK_LOCATION())
}
+bool CheckShape(const armnn::TensorShape& actual, const std::vector<uint32_t>& expected)
+{
+ const unsigned int actualSize = actual.GetNumDimensions();
+ if (actualSize != expected.size())
+ {
+ return false;
+ }
+
+ for (unsigned int i = 0u; i < actualSize; i++)
+ {
+ if (actual[i] != static_cast<unsigned int>(expected[i]))
+ {
+ return false;
+ }
+ }
+
+ return true;
+}
+
DeserializeParser::DeserializeParser()
: m_Network(nullptr, nullptr),
//May require LayerType_Max to be included
@@ -137,6 +159,7 @@ m_ParserFunctions(Layer_MAX+1, &DeserializeParser::ParseUnsupportedLayer)
m_ParserFunctions[Layer_AdditionLayer] = &DeserializeParser::ParseAdd;
m_ParserFunctions[Layer_MultiplicationLayer] = &DeserializeParser::ParseMultiplication;
m_ParserFunctions[Layer_Pooling2dLayer] = &DeserializeParser::ParsePooling2d;
+ m_ParserFunctions[Layer_ReshapeLayer] = &DeserializeParser::ParseReshape;
m_ParserFunctions[Layer_SoftmaxLayer] = &DeserializeParser::ParseSoftmax;
}
@@ -156,6 +179,8 @@ DeserializeParser::LayerBaseRawPtr DeserializeParser::GetBaseLayer(const GraphPt
return graphPtr->layers()->Get(layerIndex)->layer_as_OutputLayer()->base()->base();
case Layer::Layer_Pooling2dLayer:
return graphPtr->layers()->Get(layerIndex)->layer_as_Pooling2dLayer()->base();
+ case Layer::Layer_ReshapeLayer:
+ return graphPtr->layers()->Get(layerIndex)->layer_as_ReshapeLayer()->base();
case Layer::Layer_SoftmaxLayer:
return graphPtr->layers()->Get(layerIndex)->layer_as_SoftmaxLayer()->base();
case Layer::Layer_NONE:
@@ -247,12 +272,12 @@ DeserializeParser::LayerBaseRawPtrVector DeserializeParser::GetGraphOutputs(cons
{
CHECK_GRAPH(graphPtr, 0);
const auto& numOutputs = graphPtr->outputIds()->size();
-
LayerBaseRawPtrVector result(numOutputs);
for (unsigned int i=0; i<numOutputs; ++i)
{
uint32_t outputId = graphPtr->outputIds()->Get(i);
+
result[i] = GetBaseLayer(graphPtr, static_cast<uint32_t>(outputId));
}
return result;
@@ -726,6 +751,79 @@ void DeserializeParser::ParsePooling2d(unsigned int layerIndex)
RegisterOutputSlots(layerIndex, layer);
}
+armnn::TensorInfo DeserializeParser::OutputShapeOfReshape(const armnn::TensorInfo& inputTensorInfo,
+ const std::vector<uint32_t>& targetDimsIn)
+{
+ std::vector<unsigned int> outputDims(targetDimsIn.begin(), targetDimsIn.end());
+ const auto stretchDim = std::find(targetDimsIn.begin(), targetDimsIn.end(), -1);
+
+ if (stretchDim != targetDimsIn.end())
+ {
+ if (std::find(std::next(stretchDim), targetDimsIn.end(), -1) != targetDimsIn.end())
+ {
+ throw ParseException(boost::str(
+ boost::format("At most one component of shape can be -1 %1%") % CHECK_LOCATION().AsString()));
+ }
+
+ auto targetNumElements =
+ boost::numeric_cast<unsigned int>(
+ std::accumulate(targetDimsIn.begin(), targetDimsIn.end(), -1, std::multiplies<int32_t>()));
+
+ auto stretchIndex = static_cast<size_t>(std::distance(targetDimsIn.begin(), stretchDim));
+ outputDims[stretchIndex] = inputTensorInfo.GetNumElements() / targetNumElements;
+ }
+
+ TensorShape outputShape = TensorShape(static_cast<unsigned int>(outputDims.size()), outputDims.data());
+
+ armnn::TensorInfo reshapeInfo = inputTensorInfo;
+ reshapeInfo.SetShape(outputShape);
+
+ return reshapeInfo;
+}
+
+void DeserializeParser::ParseReshape(unsigned int layerIndex)
+{
+ CHECK_LAYERS(m_Graph, 0, layerIndex);
+ auto inputs = GetInputs(m_Graph, layerIndex);
+
+ auto outputs = GetOutputs(m_Graph, layerIndex);
+ CHECK_VALID_SIZE(outputs.size(), 1);
+
+ armnn::TensorInfo inputTensorInfo = ToTensorInfo(inputs[0]);
+ armnn::TensorInfo actualOutputTensorInfo = ToTensorInfo(outputs[0]);
+
+ const auto targetDims = m_Graph->layers()->Get(layerIndex)->layer_as_ReshapeLayer()->descriptor()->targetShape();
+ std::vector<uint32_t> outputDims(targetDims->begin(), targetDims->begin() + targetDims->size());
+
+ armnn::TensorInfo reshapeOutputTensorInfo = DeserializeParser::OutputShapeOfReshape(inputTensorInfo, outputDims);
+ const armnn::TensorShape& reshapeOutputTensorShape = reshapeOutputTensorInfo.GetShape();
+
+ const std::vector<uint32_t> expectedDims(outputs[0]->dimensions()->begin(),
+ outputs[0]->dimensions()->begin() + outputs[0]->dimensions()->size());
+
+ if (inputs.size() > 1 && !CheckShape(reshapeOutputTensorShape, expectedDims))
+ {
+ std::stringstream ss;
+ ss << "New shape defined in reshape parameters "
+ << reshapeOutputTensorShape
+ << " does not equal output shape "
+ << actualOutputTensorInfo.GetShape()
+ << ": "
+ << CHECK_LOCATION().AsString();
+ throw ParseException(ss.str());
+ }
+
+ armnn::ReshapeDescriptor reshapeDesc;
+ reshapeDesc.m_TargetShape = reshapeOutputTensorShape;
+
+ auto layerName = boost::str(boost::format("Reshape:%1%") % layerIndex);
+ IConnectableLayer* layer = m_Network->AddReshapeLayer(reshapeDesc, layerName.c_str());
+ layer->GetOutputSlot(0).SetTensorInfo(reshapeOutputTensorInfo);
+
+ RegisterInputSlots(layerIndex, layer);
+ RegisterOutputSlots(layerIndex, layer);
+}
+
void DeserializeParser::ParseSoftmax(unsigned int layerIndex)
{
CHECK_LAYERS(m_Graph, 0, layerIndex);