aboutsummaryrefslogtreecommitdiff
path: root/src/backends/reference/workloads/Activation.cpp
diff options
context:
space:
mode:
authorDavid Beck <david.beck@arm.com>2018-09-24 13:18:27 +0100
committerMatthew Bentham <matthew.bentham@arm.com>2018-10-10 16:16:57 +0100
commitb4540bef0b0327683fe8e63f727c1212800dc2a9 (patch)
treee1ea8bb6ee981640a1c469ceb556ed648ffde411 /src/backends/reference/workloads/Activation.cpp
parent2d9dd36fb6bc20b370701ab15463359b9db35f18 (diff)
downloadarmnn-b4540bef0b0327683fe8e63f727c1212800dc2a9.tar.gz
IVGCVSW-1898 : Ref backend folder structure
* Reference backend is renamed to backends/reference as per https://confluence.arm.com/display/MLENG/Pluggable+backends Change-Id: I27a13c274eb60995dfb459e3c49c0e2f60bcd32c
Diffstat (limited to 'src/backends/reference/workloads/Activation.cpp')
-rw-r--r--src/backends/reference/workloads/Activation.cpp91
1 files changed, 91 insertions, 0 deletions
diff --git a/src/backends/reference/workloads/Activation.cpp b/src/backends/reference/workloads/Activation.cpp
new file mode 100644
index 0000000000..ef4903074b
--- /dev/null
+++ b/src/backends/reference/workloads/Activation.cpp
@@ -0,0 +1,91 @@
+//
+// Copyright © 2017 Arm Ltd. All rights reserved.
+// SPDX-License-Identifier: MIT
+//
+
+#include "Activation.hpp"
+
+#include <boost/log/trivial.hpp>
+
+#include <cmath>
+
+namespace armnn
+{
+
+void Activation(const float* in,
+ float* out,
+ const TensorInfo& tensorInfo,
+ ActivationFunction function,
+ float a,
+ float b)
+{
+ for (size_t i = 0; i<tensorInfo.GetNumElements(); i++)
+ {
+ float input = in[i];
+ float output;
+
+ // Compute the result of the activation function.
+ switch (function)
+ {
+ case ActivationFunction::Linear:
+ {
+ output = a * input + b;
+ break;
+ }
+ case ActivationFunction::Sigmoid:
+ {
+ output = 1.f / (1.f + expf(-input));
+ break;
+ }
+ case ActivationFunction::ReLu:
+ {
+ output = std::max(0.f, input);
+ break;
+ }
+ case ActivationFunction::BoundedReLu:
+ {
+ output = std::min(a, std::max(b, input));
+ break;
+ }
+ case ActivationFunction::SoftReLu:
+ {
+ output = logf(1.0f + expf(input));
+ break;
+ }
+ case ActivationFunction::LeakyReLu:
+ {
+ output = input > 0.0f ? input : (input * a);
+ break;
+ }
+ case ActivationFunction::Abs:
+ {
+ output = input < 0 ? -input : input;
+ break;
+ }
+ case ActivationFunction::Sqrt:
+ {
+ output = sqrtf(input);
+ break;
+ }
+ case ActivationFunction::Square:
+ {
+ output = input * input;
+ break;
+ }
+ case ActivationFunction::TanH:
+ {
+ output = a * tanhf(b * input);
+ break;
+ }
+ default:
+ {
+ BOOST_LOG_TRIVIAL(error) << "Unsupported activation function";
+ return;
+ }
+ }
+
+ out[i] = output;
+ }
+}
+
+} //namespace armnn