diff options
author | Sadik Armagan <sadik.armagan@arm.com> | 2019-02-12 14:31:45 +0000 |
---|---|---|
committer | Saoirse Stewart Arm <saoirse.stewart@arm.com> | 2019-02-13 09:10:35 +0000 |
commit | 5f45027909bba9f4abeeef6d8a265ed345d564ae (patch) | |
tree | 01bf69ba9db252a99b4dc6feccb20b08aabb7d70 /src/armnnDeserializeParser/DeserializeParser.cpp | |
parent | fb1437e86d8e01af9ee9cebe4c8cd9ff508ac779 (diff) | |
download | armnn-5f45027909bba9f4abeeef6d8a265ed345d564ae.tar.gz |
IVGCVSW-2640 Add Serializer & Deserializer for Mul
* Updated Serializer schema for Multiplication support
* Added support for Multiplication to Serializer and Deserializer
Change-Id: I10ad8ad4b37876a963ccdcf7074cb66f40531bde
Signed-off-by: Sadik Armagan <sadik.armagan@arm.com>
Diffstat (limited to 'src/armnnDeserializeParser/DeserializeParser.cpp')
-rw-r--r-- | src/armnnDeserializeParser/DeserializeParser.cpp | 25 |
1 files changed, 24 insertions, 1 deletions
diff --git a/src/armnnDeserializeParser/DeserializeParser.cpp b/src/armnnDeserializeParser/DeserializeParser.cpp index 5ba92d51e2..0368ccabf2 100644 --- a/src/armnnDeserializeParser/DeserializeParser.cpp +++ b/src/armnnDeserializeParser/DeserializeParser.cpp @@ -132,7 +132,8 @@ DeserializeParser::DeserializeParser() m_ParserFunctions(Layer_MAX+1, &DeserializeParser::ParseUnsupportedLayer) { // register supported layers - m_ParserFunctions[Layer_AdditionLayer] = &DeserializeParser::ParseAdd; + m_ParserFunctions[Layer_AdditionLayer] = &DeserializeParser::ParseAdd; + m_ParserFunctions[Layer_MultiplicationLayer] = &DeserializeParser::ParseMultiplication; } DeserializeParser::LayerBaseRawPtr DeserializeParser::GetBaseLayer(const GraphPtr& graphPtr, unsigned int layerIndex) @@ -145,6 +146,8 @@ DeserializeParser::LayerBaseRawPtr DeserializeParser::GetBaseLayer(const GraphPt return graphPtr->layers()->Get(layerIndex)->layer_as_AdditionLayer()->base(); case Layer::Layer_InputLayer: return graphPtr->layers()->Get(layerIndex)->layer_as_InputLayer()->base()->base(); + case Layer::Layer_MultiplicationLayer: + return graphPtr->layers()->Get(layerIndex)->layer_as_MultiplicationLayer()->base(); case Layer::Layer_OutputLayer: return graphPtr->layers()->Get(layerIndex)->layer_as_OutputLayer()->base()->base(); case Layer::Layer_NONE: @@ -582,6 +585,26 @@ void DeserializeParser::ParseAdd(unsigned int layerIndex) RegisterOutputSlots(layerIndex, layer); } +void DeserializeParser::ParseMultiplication(unsigned int layerIndex) +{ + CHECK_LAYERS(m_Graph, 0, layerIndex); + auto inputs = GetInputs(m_Graph, layerIndex); + CHECK_LOCATION(); + CHECK_VALID_SIZE(inputs.size(), 2); + + auto outputs = GetOutputs(m_Graph, layerIndex); + CHECK_VALID_SIZE(outputs.size(), 1); + + auto layerName = boost::str(boost::format("Multiplication:%1%") % layerIndex); + IConnectableLayer* layer = m_Network->AddMultiplicationLayer(layerName.c_str()); + + armnn::TensorInfo outputTensorInfo = ToTensorInfo(outputs[0]); + layer->GetOutputSlot(0).SetTensorInfo(outputTensorInfo); + + RegisterInputSlots(layerIndex, layer); + RegisterOutputSlots(layerIndex, layer); +} + } |