aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorMatteo Martincigh <matteo.martincigh@arm.com>2019-06-27 14:12:55 +0100
committerMatthew Bentham <matthew.bentham@arm.com>2019-07-04 08:29:26 +0000
commit17ffff3f6708340695ca1433ed8b61955e15d7a5 (patch)
tree3924273dc63ec2f4d558b74a9cff06a00568abb9
parent535607ddba89b24cf33f431cdea7128ee01929ed (diff)
downloadandroid-nn-driver-17ffff3f6708340695ca1433ed8b61955e15d7a5.tar.gz
IVGCVSW-3369 Add conversion method to HAL1.2 Policy for PReLU activation
* Added ConvertPrelu method to HalPolicy V1.2 Signed-off-by: Matteo Martincigh <matteo.martincigh@arm.com> Change-Id: I8248d2e2bd236295845da665b2d4e687478368ef
-rw-r--r--1.2/HalPolicy.cpp52
-rw-r--r--1.2/HalPolicy.hpp3
-rw-r--r--Android.mk8
3 files changed, 56 insertions, 7 deletions
diff --git a/1.2/HalPolicy.cpp b/1.2/HalPolicy.cpp
index b3b1d69f..4e638cf4 100644
--- a/1.2/HalPolicy.cpp
+++ b/1.2/HalPolicy.cpp
@@ -121,19 +121,23 @@ bool HalPolicy::ConvertOperation(const Operation& operation, const Model& model,
return hal_1_0::HalPolicy::ConvertOperation(v10Operation, v10Model, data);
}
- else if (HandledByV1_1(operation) && compliantWithV1_1(model))
+
+ if (HandledByV1_1(operation) && compliantWithV1_1(model))
{
hal_1_1::HalPolicy::Operation v11Operation = ConvertToV1_1(operation);
hal_1_1::HalPolicy::Model v11Model = convertToV1_1(model);
return hal_1_1::HalPolicy::ConvertOperation(v11Operation, v11Model, data);
}
+
switch (operation.type)
{
case V1_2::OperationType::CONV_2D:
return ConvertConv2d(operation, model, data);
case V1_2::OperationType::DEPTHWISE_CONV_2D:
return ConvertDepthwiseConv2d(operation, model, data);
+ case V1_2::OperationType::PRELU:
+ return ConvertPrelu(operation, model, data);
default:
return Fail("%s: Operation type %s not supported in ArmnnDriver",
__func__, toString(operation.type).c_str());
@@ -418,5 +422,49 @@ bool HalPolicy::ConvertDepthwiseConv2d(const Operation& operation, const Model&
return SetupAndTrackLayerOutputSlot<hal_1_2::HalPolicy>(operation, 0, *endLayer, model, data);
}
+bool HalPolicy::ConvertPrelu(const Operation& operation, const Model& model, ConversionData& data)
+{
+ LayerInputHandle input = ConvertToLayerInputHandle<hal_1_2::HalPolicy>(operation, 0, model, data);
+ LayerInputHandle alpha = ConvertToLayerInputHandle<hal_1_2::HalPolicy>(operation, 1, model, data);
+
+ if (!input.IsValid() || !alpha.IsValid())
+ {
+ return Fail("%s: Operation has invalid inputs", __func__);
+ }
+
+ const Operand* output = GetOutputOperand<hal_1_2::HalPolicy>(operation, 0, model);
+
+ if (!output)
+ {
+ return Fail("%s: Could not read output 0", __func__);
+ }
+
+ const armnn::TensorInfo& inputInfo = input.GetTensorInfo();
+ const armnn::TensorInfo& alphaInfo = alpha.GetTensorInfo();
+ const armnn::TensorInfo& outputInfo = GetTensorInfoForOperand(*output);
+
+ if (!IsLayerSupportedForAnyBackend(__func__,
+ armnn::IsPreluSupported,
+ data.m_Backends,
+ inputInfo,
+ alphaInfo,
+ outputInfo))
+ {
+ return false;
+ }
+
+ armnn::IConnectableLayer* const layer = data.m_Network->AddPreluLayer();
+
+ if (!layer)
+ {
+ return Fail("%s: AddPreluLayer failed", __func__);
+ }
+
+ input.Connect(layer->GetInputSlot(0));
+ alpha.Connect(layer->GetInputSlot(1));
+
+ return SetupAndTrackLayerOutputSlot<hal_1_2::HalPolicy>(operation, 0, *layer, model, data);
+}
+
} // namespace hal_1_2
-} // namespace armnn_driver \ No newline at end of file
+} // namespace armnn_driver
diff --git a/1.2/HalPolicy.hpp b/1.2/HalPolicy.hpp
index 516e1ebd..0966145e 100644
--- a/1.2/HalPolicy.hpp
+++ b/1.2/HalPolicy.hpp
@@ -31,7 +31,8 @@ public:
private:
static bool ConvertConv2d(const Operation& operation, const Model& model, ConversionData& data);
static bool ConvertDepthwiseConv2d(const Operation& operation, const Model& model, ConversionData& data);
+ static bool ConvertPrelu(const Operation& operation, const Model& model, ConversionData& data);
};
} // namespace hal_1_2
-} // namespace armnn_driver \ No newline at end of file
+} // namespace armnn_driver
diff --git a/Android.mk b/Android.mk
index 8f598ca7..bee57dd0 100644
--- a/Android.mk
+++ b/Android.mk
@@ -327,15 +327,15 @@ LOCAL_SRC_FILES := \
1.1/HalPolicy.cpp \
1.2/ArmnnDriverImpl.cpp \
1.2/HalPolicy.cpp \
- ArmnnDriverImpl.cpp \
- DriverOptions.cpp \
ArmnnDevice.cpp \
+ ArmnnDriverImpl.cpp \
ArmnnPreparedModel.cpp \
ArmnnPreparedModel_1_2.cpp \
+ ConversionUtils.cpp \
+ DriverOptions.cpp \
ModelToINetworkConverter.cpp \
RequestThread.cpp \
- Utils.cpp \
- ConversionUtils.cpp
+ Utils.cpp
LOCAL_STATIC_LIBRARIES := \
libneuralnetworks_common \