From 1153d1ed778602ec9190641968384bc68083488d Mon Sep 17 00:00:00 2001 From: Sadik Armagan Date: Wed, 1 Apr 2020 15:09:39 +0100 Subject: IVGCVSW-4441 Add Support for ANEURALNETWORKS_ELU to HAL 1.3 Driver * Read alpha parameter for ELU operation * Created ConvertionUtils_1_3 for 1.3 Driver * Added QAsymmS8 data type support to swizzle the tensor Signed-off-by: Sadik Armagan Change-Id: I9d66a4e8d5468efa305bb8f6c352f13d27602274 --- ConversionUtils_1_3.hpp | 68 +++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 68 insertions(+) create mode 100644 ConversionUtils_1_3.hpp (limited to 'ConversionUtils_1_3.hpp') diff --git a/ConversionUtils_1_3.hpp b/ConversionUtils_1_3.hpp new file mode 100644 index 00000000..5014e752 --- /dev/null +++ b/ConversionUtils_1_3.hpp @@ -0,0 +1,68 @@ +// +// Copyright © 2020 Arm Ltd. All rights reserved. +// SPDX-License-Identifier: MIT +// + +#pragma once + +#include "ConversionUtils_1_2.hpp" + +using Half = half_float::half; + +namespace armnn_driver +{ + +using namespace armnn; +using namespace android::nn; + +template +bool ConvertElu(const HalOperation& operation, const HalModel& model, ConversionData& data) +{ + using HalOperandType = typename HalPolicy::OperandType; + + LayerInputHandle input0 = ConvertToLayerInputHandle(operation, 0, model, data); + if (!input0.IsValid()) + { + return Fail("%s: Operation has invalid inputs", __func__); + } + + // Determine data type of input tensor + HalOperandType inputType; + if (!GetOperandType(operation, 0, model, inputType)) + { + return Fail("%s: Operation has invalid inputs", __func__); + } + + ActivationDescriptor desc; + desc.m_Function = ActivationFunction::Elu; + + // Read alpha + if (inputType == HalOperandType::TENSOR_FLOAT16) + { + Half alpha; + + if (!GetInputScalar(operation, 1, HalOperandType::FLOAT16, alpha, model, data)) + { + return Fail("%s: Operation has invalid inputs (FLOAT16)", __func__); + } + + desc.m_A = static_cast(alpha); + } + else if (inputType == HalOperandType::TENSOR_FLOAT32) + { + if (!GetInputScalar(operation, 1, HalOperandType::FLOAT32, desc.m_A, model, data)) + { + return Fail("%s: Operation has invalid inputs (FLOAT32)", __func__); + } + } + else + { + return Fail("%s: Unsupported input tensor type: %d", __func__, inputType); + } + + return ::ConvertToActivation(operation, __func__, desc, model, data); +} + +} // armnn_driver namespace \ No newline at end of file -- cgit v1.2.1