diff options
Diffstat (limited to 'src/armnnDeserializer/Deserializer.cpp')
-rw-r--r-- | src/armnnDeserializer/Deserializer.cpp | 53 |
1 files changed, 51 insertions, 2 deletions
diff --git a/src/armnnDeserializer/Deserializer.cpp b/src/armnnDeserializer/Deserializer.cpp index 6077d057c4..99ee0b5b2d 100644 --- a/src/armnnDeserializer/Deserializer.cpp +++ b/src/armnnDeserializer/Deserializer.cpp @@ -203,6 +203,7 @@ m_ParserFunctions(Layer_MAX+1, &Deserializer::ParseUnsupportedLayer) m_ParserFunctions[Layer_DequantizeLayer] = &Deserializer::ParseDequantize; m_ParserFunctions[Layer_DetectionPostProcessLayer] = &Deserializer::ParseDetectionPostProcess; m_ParserFunctions[Layer_DivisionLayer] = &Deserializer::ParseDivision; + m_ParserFunctions[Layer_ElementwiseUnaryLayer] = &Deserializer::ParseElementwiseUnary; m_ParserFunctions[Layer_EqualLayer] = &Deserializer::ParseEqual; m_ParserFunctions[Layer_FullyConnectedLayer] = &Deserializer::ParseFullyConnected; m_ParserFunctions[Layer_FloorLayer] = &Deserializer::ParseFloor; @@ -457,6 +458,25 @@ armnn::ComparisonOperation ToComparisonOperation(armnnSerializer::ComparisonOper } } +armnn::UnaryOperation ToUnaryOperation(armnnSerializer::UnaryOperation operation) +{ + switch (operation) + { + case armnnSerializer::UnaryOperation::UnaryOperation_Abs: + return armnn::UnaryOperation::Abs; + case armnnSerializer::UnaryOperation::UnaryOperation_Rsqrt: + return armnn::UnaryOperation::Rsqrt; + case armnnSerializer::UnaryOperation::UnaryOperation_Sqrt: + return armnn::UnaryOperation::Sqrt; + case armnnSerializer::UnaryOperation::UnaryOperation_Exp: + return armnn::UnaryOperation::Exp; + case armnnSerializer::UnaryOperation::UnaryOperation_Neg: + return armnn::UnaryOperation::Neg; + default: + throw armnn::InvalidArgumentException("Unary operation unknown"); + } +} + armnn::ResizeMethod ToResizeMethod(armnnSerializer::ResizeMethod method) { switch (method) @@ -926,7 +946,8 @@ void Deserializer::ParseAbs(armnnDeserializer::Deserializer::GraphPtr graph, uns auto layerName = GetLayerName(graph, layerIndex); - IConnectableLayer* layer = m_Network->AddAbsLayer(layerName.c_str()); + armnn::ElementwiseUnaryDescriptor descriptor(armnn::UnaryOperation::Abs); + IConnectableLayer* layer = m_Network->AddElementwiseUnaryLayer(descriptor, layerName.c_str()); armnn::TensorInfo outputTensorInfo = ToTensorInfo(outputs[0]); layer->GetOutputSlot(0).SetTensorInfo(outputTensorInfo); @@ -1496,6 +1517,33 @@ void Deserializer::ParseComparison(GraphPtr graph, unsigned int layerIndex) RegisterOutputSlots(graph, layerIndex, layer); } +void Deserializer::ParseElementwiseUnary(GraphPtr graph, unsigned int layerIndex) +{ + CHECK_LAYERS(graph, 0, layerIndex); + CHECK_LOCATION(); + + auto inputs = GetInputs(graph, layerIndex); + CHECK_VALID_SIZE(inputs.size(), 1); + + auto outputs = GetOutputs(graph, layerIndex); + CHECK_VALID_SIZE(outputs.size(), 1); + + auto fbLayer = graph->layers()->Get(layerIndex)->layer_as_ElementwiseUnaryLayer(); + auto fbDescriptor = fbLayer->descriptor(); + + armnn::ElementwiseUnaryDescriptor descriptor; + descriptor.m_Operation = ToUnaryOperation(fbDescriptor->operation()); + + const std::string& layerName = GetLayerName(graph, layerIndex); + IConnectableLayer* layer = m_Network->AddElementwiseUnaryLayer(descriptor, layerName.c_str()); + + armnn::TensorInfo outputTensorInfo = ToTensorInfo(outputs[0]); + layer->GetOutputSlot(0).SetTensorInfo(outputTensorInfo); + + RegisterInputSlots(graph, layerIndex, layer); + RegisterOutputSlots(graph, layerIndex, layer); +} + void Deserializer::ParseConcat(GraphPtr graph, unsigned int layerIndex) { CHECK_LAYERS(graph, 0, layerIndex); @@ -2135,8 +2183,9 @@ void Deserializer::ParseRsqrt(GraphPtr graph, unsigned int layerIndex) CHECK_VALID_SIZE(outputs.size(), 1); auto layerName = GetLayerName(graph, layerIndex); - IConnectableLayer* layer = m_Network->AddRsqrtLayer(layerName.c_str()); + armnn::ElementwiseUnaryDescriptor descriptor(armnn::UnaryOperation::Rsqrt); + IConnectableLayer* layer = m_Network->AddElementwiseUnaryLayer(descriptor, layerName.c_str()); armnn::TensorInfo outputTensorInfo = ToTensorInfo(outputs[0]); layer->GetOutputSlot(0).SetTensorInfo(outputTensorInfo); |