aboutsummaryrefslogtreecommitdiff
path: root/ConversionUtils_1_3.hpp
blob: 5014e7527b4b348ba412267a9c4b0a0e9e4eeb5c (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
//
// Copyright © 2020 Arm Ltd. All rights reserved.
// SPDX-License-Identifier: MIT
//

#pragma once

#include "ConversionUtils_1_2.hpp"

// Short alias for the half_float library's half type; ConvertElu uses it to read a
// FLOAT16 alpha scalar before widening it to float.
using Half = half_float::half;

namespace armnn_driver
{

// Make the names from armnn and android::nn usable without qualification inside
// armnn_driver (ConvertElu below refers to ActivationDescriptor and
// ActivationFunction unqualified).
// NOTE(review): these directives live in a header, so they also affect every file
// that includes it - confirm that is intended before adding more of them.
using namespace armnn;
using namespace android::nn;

// Converts a HAL ELU operation into an ArmNN activation with ActivationFunction::Elu.
//
// Operand layout read by this function:
//   input 0 - the tensor to activate; its operand type selects how alpha is read
//   input 1 - the scalar alpha: FLOAT16 when input 0 is TENSOR_FLOAT16,
//             FLOAT32 when input 0 is TENSOR_FLOAT32
//
// Parameters:
//   operation - the HAL operation being converted
//   model     - the HAL model the operation belongs to
//   data      - conversion state, forwarded to the helper functions
//
// Returns the result of ConvertToActivation when all inputs could be read. Otherwise
// returns the result of Fail(...), which is used when input 0 is invalid, when its
// operand type cannot be determined, when alpha cannot be read with the expected
// scalar type, or when input 0 is neither TENSOR_FLOAT16 nor TENSOR_FLOAT32.
template<typename HalPolicy,
         typename HalOperation = typename HalPolicy::Operation,
         typename HalModel     = typename HalPolicy::Model>
bool ConvertElu(const HalOperation& operation, const HalModel& model, ConversionData& data)
{
    using HalOperandType = typename HalPolicy::OperandType;

    // Validate input 0 up front so that a bad operand is reported before alpha is parsed.
    LayerInputHandle input0 = ConvertToLayerInputHandle<HalPolicy>(operation, 0, model, data);
    if (!input0.IsValid())
    {
        return Fail("%s: Operation has invalid inputs", __func__);
    }

    // Determine data type of input tensor
    HalOperandType inputType;
    if (!GetOperandType<HalPolicy>(operation, 0, model, inputType))
    {
        return Fail("%s: Operation has invalid inputs", __func__);
    }

    ActivationDescriptor desc;
    desc.m_Function = ActivationFunction::Elu;

    // Read alpha; the scalar type that is accepted depends on the tensor's data type.
    if (inputType == HalOperandType::TENSOR_FLOAT16)
    {
        Half alpha;

        if (!GetInputScalar<HalPolicy>(operation, 1, HalOperandType::FLOAT16, alpha, model, data))
        {
            return Fail("%s: Operation has invalid inputs (FLOAT16)", __func__);
        }

        // The descriptor stores alpha as a float, so widen the half-precision value.
        desc.m_A = static_cast<float>(alpha);
    }
    else if (inputType == HalOperandType::TENSOR_FLOAT32)
    {
        // A FLOAT32 alpha can be written straight into the descriptor.
        if (!GetInputScalar<HalPolicy>(operation, 1, HalOperandType::FLOAT32, desc.m_A, model, data))
        {
            return Fail("%s: Operation has invalid inputs (FLOAT32)", __func__);
        }
    }
    else
    {
        // inputType is a scoped enum, which is not promoted to int when forwarded to a
        // printf-style format; convert explicitly so the argument matches "%d".
        return Fail("%s: Unsupported input tensor type: %d", __func__, static_cast<int>(inputType));
    }

    return ::ConvertToActivation<HalPolicy>(operation, __func__, desc, model, data);
}

} // armnn_driver namespace