From 17ffff3f6708340695ca1433ed8b61955e15d7a5 Mon Sep 17 00:00:00 2001 From: Matteo Martincigh Date: Thu, 27 Jun 2019 14:12:55 +0100 Subject: IVGCVSW-3369 Add conversion method to HAL1.2 Policy for PReLU activation * Added ConvertPrelu method to HalPolicy V1.2 Signed-off-by: Matteo Martincigh Change-Id: I8248d2e2bd236295845da665b2d4e687478368ef --- 1.2/HalPolicy.cpp | 52 ++++++++++++++++++++++++++++++++++++++++++++++++++-- 1.2/HalPolicy.hpp | 3 ++- Android.mk | 8 ++++---- 3 files changed, 56 insertions(+), 7 deletions(-) diff --git a/1.2/HalPolicy.cpp b/1.2/HalPolicy.cpp index b3b1d69f..4e638cf4 100644 --- a/1.2/HalPolicy.cpp +++ b/1.2/HalPolicy.cpp @@ -121,19 +121,23 @@ bool HalPolicy::ConvertOperation(const Operation& operation, const Model& model, return hal_1_0::HalPolicy::ConvertOperation(v10Operation, v10Model, data); } - else if (HandledByV1_1(operation) && compliantWithV1_1(model)) + + if (HandledByV1_1(operation) && compliantWithV1_1(model)) { hal_1_1::HalPolicy::Operation v11Operation = ConvertToV1_1(operation); hal_1_1::HalPolicy::Model v11Model = convertToV1_1(model); return hal_1_1::HalPolicy::ConvertOperation(v11Operation, v11Model, data); } + switch (operation.type) { case V1_2::OperationType::CONV_2D: return ConvertConv2d(operation, model, data); case V1_2::OperationType::DEPTHWISE_CONV_2D: return ConvertDepthwiseConv2d(operation, model, data); + case V1_2::OperationType::PRELU: + return ConvertPrelu(operation, model, data); default: return Fail("%s: Operation type %s not supported in ArmnnDriver", __func__, toString(operation.type).c_str()); @@ -418,5 +422,49 @@ bool HalPolicy::ConvertDepthwiseConv2d(const Operation& operation, const Model& return SetupAndTrackLayerOutputSlot(operation, 0, *endLayer, model, data); } +bool HalPolicy::ConvertPrelu(const Operation& operation, const Model& model, ConversionData& data) +{ + LayerInputHandle input = ConvertToLayerInputHandle(operation, 0, model, data); + LayerInputHandle alpha = ConvertToLayerInputHandle(operation, 1, model, data); + + if (!input.IsValid() || !alpha.IsValid()) + { + return Fail("%s: Operation has invalid inputs", __func__); + } + + const Operand* output = GetOutputOperand(operation, 0, model); + + if (!output) + { + return Fail("%s: Could not read output 0", __func__); + } + + const armnn::TensorInfo& inputInfo = input.GetTensorInfo(); + const armnn::TensorInfo& alphaInfo = alpha.GetTensorInfo(); + const armnn::TensorInfo& outputInfo = GetTensorInfoForOperand(*output); + + if (!IsLayerSupportedForAnyBackend(__func__, + armnn::IsPreluSupported, + data.m_Backends, + inputInfo, + alphaInfo, + outputInfo)) + { + return false; + } + + armnn::IConnectableLayer* const layer = data.m_Network->AddPreluLayer(); + + if (!layer) + { + return Fail("%s: AddPreluLayer failed", __func__); + } + + input.Connect(layer->GetInputSlot(0)); + alpha.Connect(layer->GetInputSlot(1)); + + return SetupAndTrackLayerOutputSlot(operation, 0, *layer, model, data); +} + } // namespace hal_1_2 -} // namespace armnn_driver \ No newline at end of file +} // namespace armnn_driver diff --git a/1.2/HalPolicy.hpp b/1.2/HalPolicy.hpp index 516e1ebd..0966145e 100644 --- a/1.2/HalPolicy.hpp +++ b/1.2/HalPolicy.hpp @@ -31,7 +31,8 @@ public: private: static bool ConvertConv2d(const Operation& operation, const Model& model, ConversionData& data); static bool ConvertDepthwiseConv2d(const Operation& operation, const Model& model, ConversionData& data); + static bool ConvertPrelu(const Operation& operation, const Model& model, ConversionData& data); }; } // namespace hal_1_2 -} // namespace armnn_driver \ No newline at end of file +} // namespace armnn_driver diff --git a/Android.mk b/Android.mk index 8f598ca7..bee57dd0 100644 --- a/Android.mk +++ b/Android.mk @@ -327,15 +327,15 @@ LOCAL_SRC_FILES := \ 1.1/HalPolicy.cpp \ 1.2/ArmnnDriverImpl.cpp \ 1.2/HalPolicy.cpp \ - ArmnnDriverImpl.cpp \ - DriverOptions.cpp \ ArmnnDevice.cpp \ + ArmnnDriverImpl.cpp \ ArmnnPreparedModel.cpp \ ArmnnPreparedModel_1_2.cpp \ + ConversionUtils.cpp \ + DriverOptions.cpp \ ModelToINetworkConverter.cpp \ RequestThread.cpp \ - Utils.cpp \ - ConversionUtils.cpp + Utils.cpp LOCAL_STATIC_LIBRARIES := \ libneuralnetworks_common \ -- cgit v1.2.1