aboutsummaryrefslogtreecommitdiff
path: root/ConversionUtils_1_3.hpp
diff options
context:
space:
mode:
Diffstat (limited to 'ConversionUtils_1_3.hpp')
-rw-r--r--ConversionUtils_1_3.hpp68
1 files changed, 68 insertions, 0 deletions
diff --git a/ConversionUtils_1_3.hpp b/ConversionUtils_1_3.hpp
new file mode 100644
index 00000000..5014e752
--- /dev/null
+++ b/ConversionUtils_1_3.hpp
@@ -0,0 +1,68 @@
+//
+// Copyright © 2020 Arm Ltd. All rights reserved.
+// SPDX-License-Identifier: MIT
+//
+
+#pragma once
+
+#include "ConversionUtils_1_2.hpp"
+
+using Half = half_float::half;
+
+namespace armnn_driver
+{
+
+using namespace armnn;
+using namespace android::nn;
+
+template<typename HalPolicy,
+ typename HalOperation = typename HalPolicy::Operation,
+ typename HalModel = typename HalPolicy::Model>
+bool ConvertElu(const HalOperation& operation, const HalModel& model, ConversionData& data)
+{
+ using HalOperandType = typename HalPolicy::OperandType;
+
+ LayerInputHandle input0 = ConvertToLayerInputHandle<HalPolicy>(operation, 0, model, data);
+ if (!input0.IsValid())
+ {
+ return Fail("%s: Operation has invalid inputs", __func__);
+ }
+
+ // Determine data type of input tensor
+ HalOperandType inputType;
+ if (!GetOperandType<HalPolicy>(operation, 0, model, inputType))
+ {
+ return Fail("%s: Operation has invalid inputs", __func__);
+ }
+
+ ActivationDescriptor desc;
+ desc.m_Function = ActivationFunction::Elu;
+
+ // Read alpha
+ if (inputType == HalOperandType::TENSOR_FLOAT16)
+ {
+ Half alpha;
+
+ if (!GetInputScalar<HalPolicy>(operation, 1, HalOperandType::FLOAT16, alpha, model, data))
+ {
+ return Fail("%s: Operation has invalid inputs (FLOAT16)", __func__);
+ }
+
+ desc.m_A = static_cast<float>(alpha);
+ }
+ else if (inputType == HalOperandType::TENSOR_FLOAT32)
+ {
+ if (!GetInputScalar<HalPolicy>(operation, 1, HalOperandType::FLOAT32, desc.m_A, model, data))
+ {
+ return Fail("%s: Operation has invalid inputs (FLOAT32)", __func__);
+ }
+ }
+ else
+ {
+ return Fail("%s: Unsupported input tensor type: %d", __func__, inputType);
+ }
+
+ return ::ConvertToActivation<HalPolicy>(operation, __func__, desc, model, data);
+}
+
+} // armnn_driver namespace \ No newline at end of file