diff options
Diffstat (limited to 'src/armnnDeserializer/Deserializer.cpp')
-rw-r--r-- | src/armnnDeserializer/Deserializer.cpp | 75 |
1 files changed, 66 insertions, 9 deletions
diff --git a/src/armnnDeserializer/Deserializer.cpp b/src/armnnDeserializer/Deserializer.cpp index 702b060512..ed921880e0 100644 --- a/src/armnnDeserializer/Deserializer.cpp +++ b/src/armnnDeserializer/Deserializer.cpp @@ -1,5 +1,5 @@ // -// Copyright © 2017 Arm Ltd and Contributors. All rights reserved. +// Copyright © 2017,2019-2023 Arm Ltd and Contributors. All rights reserved. // SPDX-License-Identifier: MIT // @@ -229,6 +229,7 @@ m_ParserFunctions(Layer_MAX+1, &IDeserializer::DeserializerImpl::ParseUnsupporte m_ParserFunctions[Layer_DequantizeLayer] = &DeserializerImpl::ParseDequantize; m_ParserFunctions[Layer_DetectionPostProcessLayer] = &DeserializerImpl::ParseDetectionPostProcess; m_ParserFunctions[Layer_DivisionLayer] = &DeserializerImpl::ParseDivision; + m_ParserFunctions[Layer_ElementwiseBinaryLayer] = &DeserializerImpl::ParseElementwiseBinary; m_ParserFunctions[Layer_ElementwiseUnaryLayer] = &DeserializerImpl::ParseElementwiseUnary; m_ParserFunctions[Layer_EqualLayer] = &DeserializerImpl::ParseEqual; m_ParserFunctions[Layer_FullyConnectedLayer] = &DeserializerImpl::ParseFullyConnected; @@ -325,6 +326,8 @@ LayerBaseRawPtr IDeserializer::DeserializerImpl::GetBaseLayer(const GraphPtr& gr return graphPtr->layers()->Get(layerIndex)->layer_as_DivisionLayer()->base(); case Layer::Layer_EqualLayer: return graphPtr->layers()->Get(layerIndex)->layer_as_EqualLayer()->base(); + case Layer::Layer_ElementwiseBinaryLayer: + return graphPtr->layers()->Get(layerIndex)->layer_as_ElementwiseBinaryLayer()->base(); case Layer::Layer_ElementwiseUnaryLayer: return graphPtr->layers()->Get(layerIndex)->layer_as_ElementwiseUnaryLayer()->base(); case Layer::Layer_FullyConnectedLayer: @@ -562,7 +565,28 @@ armnn::LogicalBinaryOperation ToLogicalBinaryOperation(armnnSerializer::LogicalB } } -armnn::UnaryOperation ToUnaryOperation(armnnSerializer::UnaryOperation operation) +armnn::BinaryOperation ToElementwiseBinaryOperation(armnnSerializer::BinaryOperation operation) +{ + switch (operation) + { + case armnnSerializer::BinaryOperation::BinaryOperation_Add: + return armnn::BinaryOperation::Add; + case armnnSerializer::BinaryOperation::BinaryOperation_Div: + return armnn::BinaryOperation::Div; + case armnnSerializer::BinaryOperation::BinaryOperation_Maximum: + return armnn::BinaryOperation::Maximum; + case armnnSerializer::BinaryOperation::BinaryOperation_Minimum: + return armnn::BinaryOperation::Minimum; + case armnnSerializer::BinaryOperation::BinaryOperation_Mul: + return armnn::BinaryOperation::Mul; + case armnnSerializer::BinaryOperation::BinaryOperation_Sub: + return armnn::BinaryOperation::Sub; + default: + throw armnn::InvalidArgumentException("Binary operation unknown"); + } +} + +armnn::UnaryOperation ToElementwiseUnaryOperation(armnnSerializer::UnaryOperation operation) { switch (operation) { @@ -1226,7 +1250,8 @@ void IDeserializer::DeserializerImpl::ParseAdd(GraphPtr graph, unsigned int laye CHECK_VALID_SIZE(outputs.size(), 1); auto layerName = GetLayerName(graph, layerIndex); - IConnectableLayer* layer = m_Network->AddAdditionLayer(layerName.c_str()); + armnn::ElementwiseBinaryDescriptor descriptor(armnn::BinaryOperation::Add); + IConnectableLayer* layer = m_Network->AddElementwiseBinaryLayer(descriptor, layerName.c_str()); armnn::TensorInfo outputTensorInfo = ToTensorInfo(outputs[0]); layer->GetOutputSlot(0).SetTensorInfo(outputTensorInfo); @@ -1745,7 +1770,8 @@ void IDeserializer::DeserializerImpl::ParseDivision(GraphPtr graph, unsigned int CHECK_VALID_SIZE(outputs.size(), 1); auto layerName = GetLayerName(graph, layerIndex); - IConnectableLayer* layer = m_Network->AddDivisionLayer(layerName.c_str()); + armnn::ElementwiseBinaryDescriptor descriptor(armnn::BinaryOperation::Div); + IConnectableLayer* layer = m_Network->AddElementwiseBinaryLayer(descriptor, layerName.c_str()); armnn::TensorInfo outputTensorInfo = ToTensorInfo(outputs[0]); layer->GetOutputSlot(0).SetTensorInfo(outputTensorInfo); @@ -1935,7 +1961,8 @@ void IDeserializer::DeserializerImpl::ParseMinimum(GraphPtr graph, unsigned int CHECK_VALID_SIZE(outputs.size(), 1); auto layerName = GetLayerName(graph, layerIndex); - IConnectableLayer* layer = m_Network->AddMinimumLayer(layerName.c_str()); + armnn::ElementwiseBinaryDescriptor descriptor(armnn::BinaryOperation::Minimum); + IConnectableLayer* layer = m_Network->AddElementwiseBinaryLayer(descriptor, layerName.c_str()); armnn::TensorInfo outputTensorInfo = ToTensorInfo(outputs[0]); layer->GetOutputSlot(0).SetTensorInfo(outputTensorInfo); @@ -1955,7 +1982,8 @@ void IDeserializer::DeserializerImpl::ParseMaximum(GraphPtr graph, unsigned int CHECK_VALID_SIZE(outputs.size(), 1); auto layerName = GetLayerName(graph, layerIndex); - IConnectableLayer* layer = m_Network->AddMaximumLayer(layerName.c_str()); + armnn::ElementwiseBinaryDescriptor descriptor(armnn::BinaryOperation::Maximum); + IConnectableLayer* layer = m_Network->AddElementwiseBinaryLayer(descriptor, layerName.c_str()); armnn::TensorInfo outputTensorInfo = ToTensorInfo(outputs[0]); layer->GetOutputSlot(0).SetTensorInfo(outputTensorInfo); @@ -2030,6 +2058,33 @@ void IDeserializer::DeserializerImpl::ParseComparison(GraphPtr graph, unsigned i RegisterOutputSlots(graph, layerIndex, layer); } +void IDeserializer::DeserializerImpl::ParseElementwiseBinary(GraphPtr graph, unsigned int layerIndex) +{ + CHECK_LAYERS(graph, 0, layerIndex); + CHECK_LOCATION(); + + auto inputs = GetInputs(graph, layerIndex); + CHECK_VALID_SIZE(inputs.size(), 2); + + auto outputs = GetOutputs(graph, layerIndex); + CHECK_VALID_SIZE(outputs.size(), 1); + + auto fbLayer = graph->layers()->Get(layerIndex)->layer_as_ElementwiseBinaryLayer(); + auto fbDescriptor = fbLayer->descriptor(); + + armnn::ElementwiseBinaryDescriptor descriptor; + descriptor.m_Operation = ToElementwiseBinaryOperation(fbDescriptor->operation()); + + const std::string& layerName = GetLayerName(graph, layerIndex); + IConnectableLayer* layer = m_Network->AddElementwiseBinaryLayer(descriptor, layerName.c_str()); + + armnn::TensorInfo outputTensorInfo = ToTensorInfo(outputs[0]); + layer->GetOutputSlot(0).SetTensorInfo(outputTensorInfo); + + RegisterInputSlots(graph, layerIndex, layer); + RegisterOutputSlots(graph, layerIndex, layer); +} + void IDeserializer::DeserializerImpl::ParseElementwiseUnary(GraphPtr graph, unsigned int layerIndex) { CHECK_LAYERS(graph, 0, layerIndex); @@ -2045,7 +2100,7 @@ void IDeserializer::DeserializerImpl::ParseElementwiseUnary(GraphPtr graph, unsi auto fbDescriptor = fbLayer->descriptor(); armnn::ElementwiseUnaryDescriptor descriptor; - descriptor.m_Operation = ToUnaryOperation(fbDescriptor->operation()); + descriptor.m_Operation = ToElementwiseUnaryOperation(fbDescriptor->operation()); const std::string& layerName = GetLayerName(graph, layerIndex); IConnectableLayer* layer = m_Network->AddElementwiseUnaryLayer(descriptor, layerName.c_str()); @@ -2106,7 +2161,8 @@ void IDeserializer::DeserializerImpl::ParseMultiplication(GraphPtr graph, unsign CHECK_VALID_SIZE(outputs.size(), 1); auto layerName = GetLayerName(graph, layerIndex); - IConnectableLayer* layer = m_Network->AddMultiplicationLayer(layerName.c_str()); + armnn::ElementwiseBinaryDescriptor descriptor(armnn::BinaryOperation::Mul); + IConnectableLayer* layer = m_Network->AddElementwiseBinaryLayer(descriptor, layerName.c_str()); armnn::TensorInfo outputTensorInfo = ToTensorInfo(outputs[0]); layer->GetOutputSlot(0).SetTensorInfo(outputTensorInfo); @@ -3023,7 +3079,8 @@ void IDeserializer::DeserializerImpl::ParseSubtraction(GraphPtr graph, unsigned CHECK_VALID_SIZE(outputs.size(), 1); auto layerName = GetLayerName(graph, layerIndex); - IConnectableLayer* layer = m_Network->AddSubtractionLayer(layerName.c_str()); + armnn::ElementwiseBinaryDescriptor descriptor(armnn::BinaryOperation::Sub); + IConnectableLayer* layer = m_Network->AddElementwiseBinaryLayer(descriptor, layerName.c_str()); armnn::TensorInfo outputTensorInfo = ToTensorInfo(outputs[0]); layer->GetOutputSlot(0).SetTensorInfo(outputTensorInfo); |