diff options
Diffstat (limited to 'ConversionUtils_1_3.hpp')
-rw-r--r-- | ConversionUtils_1_3.hpp | 68 |
1 files changed, 68 insertions, 0 deletions
diff --git a/ConversionUtils_1_3.hpp b/ConversionUtils_1_3.hpp new file mode 100644 index 00000000..5014e752 --- /dev/null +++ b/ConversionUtils_1_3.hpp @@ -0,0 +1,68 @@ +// +// Copyright © 2020 Arm Ltd. All rights reserved. +// SPDX-License-Identifier: MIT +// + +#pragma once + +#include "ConversionUtils_1_2.hpp" + +using Half = half_float::half; + +namespace armnn_driver +{ + +using namespace armnn; +using namespace android::nn; + +template<typename HalPolicy, + typename HalOperation = typename HalPolicy::Operation, + typename HalModel = typename HalPolicy::Model> +bool ConvertElu(const HalOperation& operation, const HalModel& model, ConversionData& data) +{ + using HalOperandType = typename HalPolicy::OperandType; + + LayerInputHandle input0 = ConvertToLayerInputHandle<HalPolicy>(operation, 0, model, data); + if (!input0.IsValid()) + { + return Fail("%s: Operation has invalid inputs", __func__); + } + + // Determine data type of input tensor + HalOperandType inputType; + if (!GetOperandType<HalPolicy>(operation, 0, model, inputType)) + { + return Fail("%s: Operation has invalid inputs", __func__); + } + + ActivationDescriptor desc; + desc.m_Function = ActivationFunction::Elu; + + // Read alpha + if (inputType == HalOperandType::TENSOR_FLOAT16) + { + Half alpha; + + if (!GetInputScalar<HalPolicy>(operation, 1, HalOperandType::FLOAT16, alpha, model, data)) + { + return Fail("%s: Operation has invalid inputs (FLOAT16)", __func__); + } + + desc.m_A = static_cast<float>(alpha); + } + else if (inputType == HalOperandType::TENSOR_FLOAT32) + { + if (!GetInputScalar<HalPolicy>(operation, 1, HalOperandType::FLOAT32, desc.m_A, model, data)) + { + return Fail("%s: Operation has invalid inputs (FLOAT32)", __func__); + } + } + else + { + return Fail("%s: Unsupported input tensor type: %d", __func__, inputType); + } + + return ::ConvertToActivation<HalPolicy>(operation, __func__, desc, model, data); +} + +} // armnn_driver namespace
\ No newline at end of file |