about summary refs log tree commit diff
diff options
context:
space:
mode:
authorSadik Armagan <sadik.armagan@arm.com>2020-04-01 15:09:39 +0100
committerSadik Armagan <sadik.armagan@arm.com>2020-04-01 15:09:39 +0100
commit1153d1ed778602ec9190641968384bc68083488d (patch)
treecbe4cf64cba0fbba764c603eb8d391d25676d5ba
parent352d83857d28a37821a5a6643f0ed21115c66cd7 (diff)
downloadandroid-nn-driver-1153d1ed778602ec9190641968384bc68083488d.tar.gz
IVGCVSW-4441 Add Support for ANEURALNETWORKS_ELU to HAL 1.3 Driver
* Read alpha parameter for ELU operation
* Created ConversionUtils_1_3 for 1.3 Driver
* Added QAsymmS8 data type support to swizzle the tensor

Signed-off-by: Sadik Armagan <sadik.armagan@arm.com>
Change-Id: I9d66a4e8d5468efa305bb8f6c352f13d27602274
-rw-r--r--1.3/HalPolicy.cpp5
-rw-r--r--1.3/HalPolicy.hpp1
-rw-r--r--ConversionUtils_1_3.hpp68
-rw-r--r--Utils.cpp1
4 files changed, 71 insertions, 4 deletions
diff --git a/1.3/HalPolicy.cpp b/1.3/HalPolicy.cpp
index 28d7319..b2b8a86 100644
--- a/1.3/HalPolicy.cpp
+++ b/1.3/HalPolicy.cpp
@@ -230,10 +230,7 @@ bool HalPolicy::ConvertElementwiseUnary(const Operation& operation,
bool HalPolicy::ConvertElu(const Operation& operation, const Model& model, ConversionData& data)
{
    ALOGV("hal_1_3::HalPolicy::ConvertElu()");
- ActivationDescriptor desc;
- desc.m_Function = ActivationFunction::Elu;
-
- return ::ConvertToActivation<hal_1_3::HalPolicy>(operation, __func__, desc, model, data);
+ return ::ConvertElu<hal_1_3::HalPolicy>(operation, model, data); // delegate to the shared ConversionUtils_1_3 helper, which also reads the alpha scalar (input 1)
}
bool HalPolicy::ConvertExpandDims(const Operation& operation, const Model& model, ConversionData& data)
diff --git a/1.3/HalPolicy.hpp b/1.3/HalPolicy.hpp
index e3f21b1..c601942 100644
--- a/1.3/HalPolicy.hpp
+++ b/1.3/HalPolicy.hpp
@@ -7,6 +7,7 @@
#include "../ConversionUtils.hpp"
#include "../ConversionUtils_1_2.hpp"
+#include "../ConversionUtils_1_3.hpp"
#include <HalInterfaces.h>
diff --git a/ConversionUtils_1_3.hpp b/ConversionUtils_1_3.hpp
new file mode 100644
index 0000000..5014e75
--- /dev/null
+++ b/ConversionUtils_1_3.hpp
@@ -0,0 +1,68 @@
+//
+// Copyright © 2020 Arm Ltd. All rights reserved.
+// SPDX-License-Identifier: MIT
+//
+
+#pragma once
+
+#include "ConversionUtils_1_2.hpp"
+
+using Half = half_float::half;
+
+namespace armnn_driver
+{
+
+using namespace armnn;
+using namespace android::nn;
+
+template<typename HalPolicy,
+ typename HalOperation = typename HalPolicy::Operation,
+ typename HalModel = typename HalPolicy::Model>
+bool ConvertElu(const HalOperation& operation, const HalModel& model, ConversionData& data) // converts an ELU operation (input 0: tensor, input 1: scalar alpha) into an ArmNN ELU activation
+{
+ using HalOperandType = typename HalPolicy::OperandType;
+
+ LayerInputHandle input0 = ConvertToLayerInputHandle<HalPolicy>(operation, 0, model, data); // input 0: the tensor to be activated
+ if (!input0.IsValid())
+ {
+ return Fail("%s: Operation has invalid inputs", __func__);
+ }
+
+ // Determine data type of input tensor
+ HalOperandType inputType; // the tensor's operand type also dictates how the alpha scalar (input 1) is encoded
+ if (!GetOperandType<HalPolicy>(operation, 0, model, inputType))
+ {
+ return Fail("%s: Operation has invalid inputs", __func__);
+ }
+
+ ActivationDescriptor desc; // m_A carries alpha for the ELU activation function
+ desc.m_Function = ActivationFunction::Elu;
+
+ // Read alpha
+ if (inputType == HalOperandType::TENSOR_FLOAT16) // FLOAT16 tensors pass alpha as a half-precision scalar
+ {
+ Half alpha;
+
+ if (!GetInputScalar<HalPolicy>(operation, 1, HalOperandType::FLOAT16, alpha, model, data))
+ {
+ return Fail("%s: Operation has invalid inputs (FLOAT16)", __func__);
+ }
+
+ desc.m_A = static_cast<float>(alpha); // widen half to float for the descriptor
+ }
+ else if (inputType == HalOperandType::TENSOR_FLOAT32) // FLOAT32 tensors pass alpha as a float scalar, read directly into m_A
+ {
+ if (!GetInputScalar<HalPolicy>(operation, 1, HalOperandType::FLOAT32, desc.m_A, model, data))
+ {
+ return Fail("%s: Operation has invalid inputs (FLOAT32)", __func__);
+ }
+ }
+ else
+ {
+ return Fail("%s: Unsupported input tensor type: %d", __func__, inputType); // NOTE(review): passing an enum class through %d varargs is technically UB — consider static_cast<int>(inputType)
+ }
+
+ return ::ConvertToActivation<HalPolicy>(operation, __func__, desc, model, data); // common helper performs validation and creates the activation layer
+}
+
+} // armnn_driver namespace \ No newline at end of file
diff --git a/Utils.cpp b/Utils.cpp
index 40d6063..00d61c7 100644
--- a/Utils.cpp
+++ b/Utils.cpp
@@ -54,6 +54,7 @@ void SwizzleAndroidNn4dTensorToArmNn(const armnn::TensorInfo& tensor, const void
case armnn::DataType::Float32:
case armnn::DataType::QAsymmU8:
case armnn::DataType::QSymmS8:
+ case armnn::DataType::QAsymmS8:
SwizzleAndroidNn4dTensorToArmNn(tensor.GetShape(), input, output, armnn::GetDataTypeSize(dataType), mappings);
break;
default: