diff options
-rw-r--r-- | 1.3/HalPolicy.cpp | 5 | ||||
-rw-r--r-- | 1.3/HalPolicy.hpp | 1 | ||||
-rw-r--r-- | ConversionUtils_1_3.hpp | 68 | ||||
-rw-r--r-- | Utils.cpp | 1 |
4 files changed, 71 insertions, 4 deletions
diff --git a/1.3/HalPolicy.cpp b/1.3/HalPolicy.cpp index 28d73197..b2b8a860 100644 --- a/1.3/HalPolicy.cpp +++ b/1.3/HalPolicy.cpp @@ -230,10 +230,7 @@ bool HalPolicy::ConvertElementwiseUnary(const Operation& operation, bool HalPolicy::ConvertElu(const Operation& operation, const Model& model, ConversionData& data) { ALOGV("hal_1_3::HalPolicy::ConvertElu()"); - ActivationDescriptor desc; - desc.m_Function = ActivationFunction::Elu; - - return ::ConvertToActivation<hal_1_3::HalPolicy>(operation, __func__, desc, model, data); + return ::ConvertElu<hal_1_3::HalPolicy>(operation, model, data); } bool HalPolicy::ConvertExpandDims(const Operation& operation, const Model& model, ConversionData& data) diff --git a/1.3/HalPolicy.hpp b/1.3/HalPolicy.hpp index e3f21b1b..c6019421 100644 --- a/1.3/HalPolicy.hpp +++ b/1.3/HalPolicy.hpp @@ -7,6 +7,7 @@ #include "../ConversionUtils.hpp" #include "../ConversionUtils_1_2.hpp" +#include "../ConversionUtils_1_3.hpp" #include <HalInterfaces.h> diff --git a/ConversionUtils_1_3.hpp b/ConversionUtils_1_3.hpp new file mode 100644 index 00000000..5014e752 --- /dev/null +++ b/ConversionUtils_1_3.hpp @@ -0,0 +1,68 @@ +// +// Copyright © 2020 Arm Ltd. All rights reserved. +// SPDX-License-Identifier: MIT +// + +#pragma once + +#include "ConversionUtils_1_2.hpp" + +using Half = half_float::half; + +namespace armnn_driver +{ + +using namespace armnn; +using namespace android::nn; + +template<typename HalPolicy, + typename HalOperation = typename HalPolicy::Operation, + typename HalModel = typename HalPolicy::Model> +bool ConvertElu(const HalOperation& operation, const HalModel& model, ConversionData& data) +{ + using HalOperandType = typename HalPolicy::OperandType; + + LayerInputHandle input0 = ConvertToLayerInputHandle<HalPolicy>(operation, 0, model, data); + if (!input0.IsValid()) + { + return Fail("%s: Operation has invalid inputs", __func__); + } + + // Determine data type of input tensor + HalOperandType inputType; + if (!GetOperandType<HalPolicy>(operation, 0, model, inputType)) + { + return Fail("%s: Operation has invalid inputs", __func__); + } + + ActivationDescriptor desc; + desc.m_Function = ActivationFunction::Elu; + + // Read alpha + if (inputType == HalOperandType::TENSOR_FLOAT16) + { + Half alpha; + + if (!GetInputScalar<HalPolicy>(operation, 1, HalOperandType::FLOAT16, alpha, model, data)) + { + return Fail("%s: Operation has invalid inputs (FLOAT16)", __func__); + } + + desc.m_A = static_cast<float>(alpha); + } + else if (inputType == HalOperandType::TENSOR_FLOAT32) + { + if (!GetInputScalar<HalPolicy>(operation, 1, HalOperandType::FLOAT32, desc.m_A, model, data)) + { + return Fail("%s: Operation has invalid inputs (FLOAT32)", __func__); + } + } + else + { + return Fail("%s: Unsupported input tensor type: %d", __func__, inputType); + } + + return ::ConvertToActivation<HalPolicy>(operation, __func__, desc, model, data); +} + +} // armnn_driver namespace
\ No newline at end of file @@ -54,6 +54,7 @@ void SwizzleAndroidNn4dTensorToArmNn(const armnn::TensorInfo& tensor, const void case armnn::DataType::Float32: case armnn::DataType::QAsymmU8: case armnn::DataType::QSymmS8: + case armnn::DataType::QAsymmS8: SwizzleAndroidNn4dTensorToArmNn(tensor.GetShape(), input, output, armnn::GetDataTypeSize(dataType), mappings); break; default: |