// // Copyright © 2020 Arm Ltd. All rights reserved. // SPDX-License-Identifier: MIT // #pragma once #include "ConversionUtils_1_2.hpp" using Half = half_float::half; namespace armnn_driver { using namespace armnn; using namespace android::nn; template bool ConvertElu(const HalOperation& operation, const HalModel& model, ConversionData& data) { using HalOperandType = typename HalPolicy::OperandType; LayerInputHandle input0 = ConvertToLayerInputHandle(operation, 0, model, data); if (!input0.IsValid()) { return Fail("%s: Operation has invalid inputs", __func__); } // Determine data type of input tensor HalOperandType inputType; if (!GetOperandType(operation, 0, model, inputType)) { return Fail("%s: Operation has invalid inputs", __func__); } ActivationDescriptor desc; desc.m_Function = ActivationFunction::Elu; // Read alpha if (inputType == HalOperandType::TENSOR_FLOAT16) { Half alpha; if (!GetInputScalar(operation, 1, HalOperandType::FLOAT16, alpha, model, data)) { return Fail("%s: Operation has invalid inputs (FLOAT16)", __func__); } desc.m_A = static_cast(alpha); } else if (inputType == HalOperandType::TENSOR_FLOAT32) { if (!GetInputScalar(operation, 1, HalOperandType::FLOAT32, desc.m_A, model, data)) { return Fail("%s: Operation has invalid inputs (FLOAT32)", __func__); } } else { return Fail("%s: Unsupported input tensor type: %d", __func__, inputType); } return ::ConvertToActivation(operation, __func__, desc, model, data); } } // armnn_driver namespace