diff options
Diffstat (limited to 'src/armnnDeserializer/Deserializer.cpp')
-rw-r--r-- | src/armnnDeserializer/Deserializer.cpp | 27 |
1 files changed, 27 insertions, 0 deletions
diff --git a/src/armnnDeserializer/Deserializer.cpp b/src/armnnDeserializer/Deserializer.cpp index 61a38f9cf3..0d81649115 100644 --- a/src/armnnDeserializer/Deserializer.cpp +++ b/src/armnnDeserializer/Deserializer.cpp @@ -12,6 +12,7 @@ #include <armnn/QuantizedLstmParams.hpp> #include <armnnUtils/Permute.hpp> +#include <armnnUtils/Transpose.hpp> #include <ParserHelper.hpp> #include <VerificationHelpers.hpp> @@ -241,6 +242,7 @@ m_ParserFunctions(Layer_MAX+1, &Deserializer::ParseUnsupportedLayer) m_ParserFunctions[Layer_SubtractionLayer] = &Deserializer::ParseSubtraction; m_ParserFunctions[Layer_SwitchLayer] = &Deserializer::ParseSwitch; m_ParserFunctions[Layer_TransposeConvolution2dLayer] = &Deserializer::ParseTransposeConvolution2d; + m_ParserFunctions[Layer_TransposeLayer] = &Deserializer::ParseTranspose; } Deserializer::LayerBaseRawPtr Deserializer::GetBaseLayer(const GraphPtr& graphPtr, unsigned int layerIndex) @@ -357,6 +359,8 @@ Deserializer::LayerBaseRawPtr Deserializer::GetBaseLayer(const GraphPtr& graphPt return graphPtr->layers()->Get(layerIndex)->layer_as_SwitchLayer()->base(); case Layer::Layer_TransposeConvolution2dLayer: return graphPtr->layers()->Get(layerIndex)->layer_as_TransposeConvolution2dLayer()->base(); + case Layer::Layer_TransposeLayer: + return graphPtr->layers()->Get(layerIndex)->layer_as_TransposeLayer()->base(); case Layer::Layer_NONE: default: throw ParseException(boost::str( @@ -2721,6 +2725,29 @@ void Deserializer::ParsePrelu(GraphPtr graph, unsigned int layerIndex) RegisterOutputSlots(graph, layerIndex, layer); } +void Deserializer::ParseTranspose(GraphPtr graph, unsigned int layerIndex) +{ + CHECK_LAYERS(graph, 0, layerIndex); + + auto dimsMapping = graph->layers()->Get(layerIndex)->layer_as_TransposeLayer()->descriptor()->dimMappings(); + + auto inputs = GetInputs(graph, layerIndex); + CHECK_VALID_SIZE(inputs.size(), 1); + + auto outputs = GetOutputs(graph, layerIndex); + CHECK_VALID_SIZE(outputs.size(), 1); + auto outputInfo = ToTensorInfo(outputs[0]); + + auto layerName = GetLayerName(graph, layerIndex); + const armnn::TransposeDescriptor descriptor(armnn::PermutationVector(dimsMapping->data(), dimsMapping->Length())); + + IConnectableLayer* layer = m_Network->AddTransposeLayer(descriptor, layerName.c_str()); + layer->GetOutputSlot(0).SetTensorInfo(outputInfo); + + RegisterInputSlots(graph, layerIndex, layer); + RegisterOutputSlots(graph, layerIndex, layer); +} + void Deserializer::ParseTransposeConvolution2d(GraphPtr graph, unsigned int layerIndex) { CHECK_LAYERS(graph, 0, layerIndex); |